From c2480deddecd964e3557092cd151a599bef1b278 Mon Sep 17 00:00:00 2001 From: Tolik Zinovyev Date: Fri, 24 May 2024 12:22:06 +0000 Subject: [PATCH] core: simplify code, add asserts, remove undefined behavior around int64x64-128. --- src/core/model/int64x64-128.cc | 124 +++++++++++++++------------ src/core/model/int64x64-128.h | 5 +- src/core/model/int64x64.cc | 16 +++- src/core/test/int64x64-test-suite.cc | 29 +++++++ 4 files changed, 115 insertions(+), 59 deletions(-) diff --git a/src/core/model/int64x64-128.cc b/src/core/model/int64x64-128.cc index daf105741..c5fc43255 100644 --- a/src/core/model/int64x64-128.cc +++ b/src/core/model/int64x64-128.cc @@ -52,8 +52,8 @@ output_sign(const int128_t sa, const int128_t sb, uint128_t& ua, uint128_t& ub) { bool negA = sa < 0; bool negB = sb < 0; - ua = negA ? -sa : sa; - ub = negB ? -sb : sb; + ua = negA ? -static_cast<uint128_t>(sa) : sa; + ub = negB ? -static_cast<uint128_t>(sb) : sb; return negA != negB; } @@ -64,50 +64,64 @@ int64x64_t::Mul(const int64x64_t& o) uint128_t b; bool negative = output_sign(_v, o._v, a, b); uint128_t result = Umul(a, b); - _v = negative ? -result : result; + if (negative) + { + NS_ASSERT_MSG(result <= HP128_MASK_HI_BIT, "overflow detected"); + _v = -result; + } + else + { + NS_ASSERT_MSG(result < HP128_MASK_HI_BIT, "overflow detected"); + _v = result; + } } uint128_t int64x64_t::Umul(const uint128_t a, const uint128_t b) { - uint128_t aL = a & HP_MASK_LO; - uint128_t bL = b & HP_MASK_LO; - uint128_t aH = (a >> 64) & HP_MASK_LO; - uint128_t bH = (b >> 64) & HP_MASK_LO; + uint128_t al = a & HP_MASK_LO; + uint128_t bl = b & HP_MASK_LO; + uint128_t ah = a >> 64; + uint128_t bh = b >> 64; - uint128_t result; - uint128_t hiPart; - uint128_t loPart; - uint128_t midPart; - uint128_t res1; - uint128_t res2; + // Let Q(x) be the unsigned Q64.64 fixed point value represented by the uint128_t x: + // Q(x) * 2^64 = x = xh * 2^64 + xl.
+ // (Defining x this way avoids ambiguity about the meaning of the division operators in + // Q(x) = x / 2^64 = xh + xl / 2^64.) + // Then + // Q(a) = ah + al / 2^64 + // and + // Q(b) = bh + bl / 2^64. + // We need to find uint128_t c such that + // Q(c) = Q(a) * Q(b). + // Then + // c = Q(c) * 2^64 + // = (ah + al / 2^64) * (bh + bl / 2^64) * 2^64 + // = (ah * 2^64 + al) * (bh * 2^64 + bl) / 2^64 + // = ah * bh * 2^64 + (ah * bl + al * bh) + al * bl / 2^64. + // We compute the last part of c by (al * bl) >> 64 which truncates (instead of rounds) + // the LSB. If c exceeds 2^127, we might assert. This is because our caller + // (Mul function) will not be able to represent the result. - // Multiplying (a.h 2^64 + a.l) x (b.h 2^64 + b.l) = - // 2^128 a.h b.h + 2^64*(a.h b.l+b.h a.l) + a.l b.l - // get the low part a.l b.l - // multiply the fractional part - loPart = aL * bL; - // compute the middle part 2^64*(a.h b.l+b.h a.l) - midPart = aL * bH + aH * bL; - // compute the high part 2^128 a.h b.h - hiPart = aH * bH; - // if the high part is not zero, put a warning - NS_ABORT_MSG_IF((hiPart & HP_MASK_HI) != 0, - "High precision 128 bits multiplication error: multiplication overflow."); + uint128_t res = (al * bl) >> 64; + { + // ah, bh <= 2^63 and al, bl <= 2^64 - 1, so mid < 2^128 - 2^64 and there is no + // integer overflow. + uint128_t mid = ah * bl + al * bh; + // res < 2^64, so there is no integer overflow. + res += mid; + } + { + uint128_t high = ah * bh; + // If high > 2^63, then the result will overflow. + NS_ASSERT_MSG(high <= (static_cast<uint128_t>(1) << 63), "overflow detected"); + high <<= 64; + NS_ASSERT_MSG(res + high >= res, "overflow detected"); + // No overflow since res, high <= 2^127 and one of res, high is < 2^127.
+ res += high; + } - // Adding 64-bit terms to get 128-bit results, with carries - res1 = loPart >> 64; - res2 = midPart & HP_MASK_LO; - result = res1 + res2; - - res1 = midPart >> 64; - res2 = hiPart & HP_MASK_LO; - res1 += res2; - res1 <<= 64; - - result += res1; - - return result; + return res; } void @@ -189,7 +203,7 @@ void int64x64_t::MulByInvert(const int64x64_t& o) { bool negResult = _v < 0; - uint128_t a = negResult ? -_v : _v; + uint128_t a = negResult ? -static_cast<uint128_t>(_v) : _v; uint128_t result = UmulByInvert(a, o._v); _v = negResult ? -result : result; @@ -198,21 +212,25 @@ int64x64_t::MulByInvert(const int64x64_t& o) uint128_t int64x64_t::UmulByInvert(const uint128_t a, const uint128_t b) { - uint128_t result; - uint128_t ah; - uint128_t bh; - uint128_t al; - uint128_t bl; - uint128_t hi; - uint128_t mid; - ah = a >> 64; - bh = b >> 64; - al = a & HP_MASK_LO; - bl = b & HP_MASK_LO; - hi = ah * bh; - mid = ah * bl + al * bh; + // Since b is assumed to be the output of Invert(), b <= 2^127. + NS_ASSERT(b <= HP128_MASK_HI_BIT); + + uint128_t al = a & HP_MASK_LO; + uint128_t bl = b & HP_MASK_LO; + uint128_t ah = a >> 64; + uint128_t bh = b >> 64; + + // Since ah, bh <= 2^63, high <= 2^126 and there is no overflow. + uint128_t high = ah * bh; + + // Since ah, bh <= 2^63 and al, bl < 2^64, mid < 2^128 and there is + // no overflow. + uint128_t mid = ah * bl + al * bh; mid >>= 64; - result = hi + mid; + + // Since high <= 2^126 and mid < 2^64, result < 2^127 and there is no overflow. + uint128_t result = high + mid; + return result; } diff --git a/src/core/model/int64x64-128.h b/src/core/model/int64x64-128.h index 599c76e09..0c2a068b3 100644 --- a/src/core/model/int64x64-128.h +++ b/src/core/model/int64x64-128.h @@ -58,8 +58,6 @@ class int64x64_t static const uint128_t HP128_MASK_HI_BIT = (((int128_t)1) << 127); /// Mask for fraction part. static const uint64_t HP_MASK_LO = 0xffffffffffffffffULL; - /// Mask for sign + integer part.
- static const uint64_t HP_MASK_HI = ~HP_MASK_LO; /** * Floating point value of HP_MASK_LO + 1. * We really want: @@ -412,6 +410,7 @@ class int64x64_t /** * Implement `*=`. + * We assert if the product cannot be encoded in int64x64_t. * * \param [in] o The other factor. */ @@ -427,7 +426,7 @@ class int64x64_t * * Mathematically this should produce a Q128.128 value; * we keep the central 128 bits, representing the Q64.64 result. - * We assert on integer overflow beyond the 64-bit integer portion. + * We might assert if the result, in uint128_t format, exceeds 2^127. * * \param [in] a First factor. * \param [in] b Second factor. diff --git a/src/core/model/int64x64.cc b/src/core/model/int64x64.cc index a4155328d..2c390b050 100644 --- a/src/core/model/int64x64.cc +++ b/src/core/model/int64x64.cc @@ -23,6 +23,7 @@ #include <iomanip> // showpos #include <iostream> +#include <limits> #include <sstream> #include <stdint.h> @@ -69,9 +70,19 @@ std::ostream& operator<<(std::ostream& os, const int64x64_t& value) { const bool negative = (value < 0); - const int64x64_t absVal = (negative ? -value : value); - int64_t hi = absVal.GetHigh(); + uint64_t hi; + int64x64_t low; + if (value != int64x64_t(std::numeric_limits<int64_t>::min(), 0)) + { + const int64x64_t absVal = (negative ? -value : value); + hi = absVal.GetHigh(); + low = int64x64_t(0, absVal.GetLow()); + } + else + { + hi = static_cast<uint64_t>(1) << 63; + } // Save stream format flags auto precision = static_cast<std::size_t>(os.precision()); @@ -84,7 +95,6 @@ operator<<(std::ostream& os, const int64x64_t& value) std::ostringstream oss; oss << hi << "."; // collect the digits here so we can round properly - int64x64_t low(0, absVal.GetLow()); std::size_t places = 0; // Number of decimal places printed so far bool more = true; // Should we print more digits?
diff --git a/src/core/test/int64x64-test-suite.cc b/src/core/test/int64x64-test-suite.cc index aa444531f..877a965e8 100644 --- a/src/core/test/int64x64-test-suite.cc +++ b/src/core/test/int64x64-test-suite.cc @@ -575,6 +575,35 @@ Int64x64ArithmeticTestCase::DoRun() // Check special values Check(51, int64x64_t(0, 0x159fa87f8aeaad21ULL) * 10, int64x64_t(0, 0xd83c94fb6d2ac34aULL)); + { + auto x = int64x64_t(std::numeric_limits<int64_t>::min(), 0); + Check(52, x * 1, x); + Check(53, 1 * x, x); + } + { + int64x64_t x(1 << 30, (static_cast<uint64_t>(1) << 63) + 1); + auto ret = x * x; + int64x64_t expected(1152921505680588800, 4611686020574871553); + // The real difference between ret and expected is 2^-128. + int64x64_t tolerance = 0; + if (int64x64_t::implementation == int64x64_t::ld_impl) + { + tolerance = tol1; + } + Check(54, ret, expected, tolerance); + } + + // The following triggers an assert in int64x64-128.cc:Umul():117 + /* + { + auto x = int64x64_t(1LL << 31); // 2^31 + auto y = 2 * x; // 2^32 + Check(55, x, x); + Check(56, y, y); + auto z [[maybe_unused]] = x * y; // 2^63 < 0, triggers assert + Check(57, z, z); + } + */ } /**