bpo-39576: Prevent memory error for overly optimistic precisions (GH-18581)

This commit is contained in:
Stefan Krah 2020-02-21 01:52:47 +01:00 committed by GitHub
parent 6c444d0dab
commit 90930e6545
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 245 additions and 6 deletions

View File

@ -5476,6 +5476,41 @@ class CWhitebox(unittest.TestCase):
self.assertEqual(Decimal.from_float(cls(101.1)), self.assertEqual(Decimal.from_float(cls(101.1)),
Decimal.from_float(101.1)) Decimal.from_float(101.1))
def test_maxcontext_exact_arith(self):
# Make sure that exact operations do not raise MemoryError due
# to huge intermediate values when the context precision is very
# large.
# The following functions fill the available precision and are
# therefore not suitable for large precisions (by design of the
# specification).
MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus',
'logical_and', 'logical_or', 'logical_xor',
'next_toward', 'rotate', 'shift']
Decimal = C.Decimal
Context = C.Context
localcontext = C.localcontext
# Here only some functions that are likely candidates for triggering a
# MemoryError are tested. deccheck.py has an exhaustive test.
maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX)
with localcontext(maxcontext):
self.assertEqual(Decimal(0).exp(), 1)
self.assertEqual(Decimal(1).ln(), 0)
self.assertEqual(Decimal(1).log10(), 0)
self.assertEqual(Decimal(10**2).log10(), 2)
self.assertEqual(Decimal(10**223).log10(), 223)
self.assertEqual(Decimal(10**19).logb(), 19)
self.assertEqual(Decimal(4).sqrt(), 2)
self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5'))
self.assertEqual(divmod(Decimal(10), 3), (3, 1))
self.assertEqual(Decimal(10) // 3, 3)
self.assertEqual(Decimal(4) / 2, 2)
self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))
@requires_docstrings @requires_docstrings
@unittest.skipUnless(C, "test requires C version") @unittest.skipUnless(C, "test requires C version")
class SignatureTest(unittest.TestCase): class SignatureTest(unittest.TestCase):

View File

@ -3781,6 +3781,43 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b,
const mpd_context_t *ctx, uint32_t *status) const mpd_context_t *ctx, uint32_t *status)
{ {
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status); _mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status);
if (*status & MPD_Malloc_error) {
/* Inexact quotients (the usual case) fill the entire context precision,
* which can lead to malloc() failures for very high precisions. Retry
* the operation with a lower precision in case the result is exact.
*
* We need an upper bound for the number of digits of a_coeff / b_coeff
* when the result is exact. If a_coeff' * 1 / b_coeff' is in lowest
* terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable
* bound.
*
* 1 / b_coeff' is exact iff b_coeff' exclusively has prime factors 2 or 5.
* The largest amount of digits is generated if b_coeff' is a power of 2 or
* a power of 5 and is less than or equal to log5(b_coeff') <= log2(b_coeff').
*
* We arrive at a total upper bound:
*
* maxdigits(a_coeff') + maxdigits(1 / b_coeff') <=
* a->digits + log2(b_coeff) =
* a->digits + log10(b_coeff) / log10(2) <=
* a->digits + b->digits * 4;
*/
uint32_t workstatus = 0;
mpd_context_t workctx = *ctx;
workctx.prec = a->digits + b->digits * 4;
if (workctx.prec >= ctx->prec) {
return; /* No point in retrying, keep the original error. */
}
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus);
if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */
*status = 0;
return;
}
mpd_seterror(q, *status, status);
}
} }
/* Internal function. */ /* Internal function. */
@ -7702,8 +7739,8 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
/* END LIBMPDEC_ONLY */ /* END LIBMPDEC_ONLY */
/* Algorithm from decimal.py */ /* Algorithm from decimal.py */
void static void
mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, _mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
uint32_t *status) uint32_t *status)
{ {
mpd_context_t maxcontext; mpd_context_t maxcontext;
@ -7836,6 +7873,40 @@ malloc_error:
goto out; goto out;
} }
void
mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
uint32_t *status)
{
_mpd_qsqrt(result, a, ctx, status);
if (*status & (MPD_Malloc_error|MPD_Division_impossible)) {
/* The above conditions can occur at very high context precisions
* if intermediate values get too large. Retry the operation with
* a lower context precision in case the result is exact.
*
* If the result is exact, an upper bound for the number of digits
* is the number of digits in the input.
*
* NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2
*/
uint32_t workstatus = 0;
mpd_context_t workctx = *ctx;
workctx.prec = a->digits;
if (workctx.prec >= ctx->prec) {
return; /* No point in repeating this, keep the original error. */
}
_mpd_qsqrt(result, a, &workctx, &workstatus);
if (workstatus == 0) {
*status = 0;
return;
}
mpd_seterror(result, *status, status);
}
}
/******************************************************************************/ /******************************************************************************/
/* Base conversions */ /* Base conversions */

View File

@ -125,6 +125,12 @@ ContextFunctions = {
'special': ('context.__reduce_ex__', 'context.create_decimal_from_float') 'special': ('context.__reduce_ex__', 'context.create_decimal_from_float')
} }
# Functions that set no context flags but whose result can differ depending
# on prec, Emin and Emax.
MaxContextSkip = ['is_normal', 'is_subnormal', 'logical_invert', 'next_minus',
'next_plus', 'number_class', 'logical_and', 'logical_or',
'logical_xor', 'next_toward', 'rotate', 'shift']
# Functions that require a restricted exponent range for reasonable runtimes. # Functions that require a restricted exponent range for reasonable runtimes.
UnaryRestricted = [ UnaryRestricted = [
'__ceil__', '__floor__', '__int__', '__trunc__', '__ceil__', '__floor__', '__int__', '__trunc__',
@ -344,6 +350,20 @@ class TestSet(object):
self.pex = RestrictedList() # Python exceptions for P.Decimal self.pex = RestrictedList() # Python exceptions for P.Decimal
self.presults = RestrictedList() # P.Decimal results self.presults = RestrictedList() # P.Decimal results
# If the above results are exact, unrounded and not clamped, repeat
# the operation with a maxcontext to ensure that huge intermediate
# values do not cause a MemoryError.
self.with_maxcontext = False
self.maxcontext = context.c.copy()
self.maxcontext.prec = C.MAX_PREC
self.maxcontext.Emax = C.MAX_EMAX
self.maxcontext.Emin = C.MIN_EMIN
self.maxcontext.clear_flags()
self.maxop = RestrictedList() # converted C.Decimal operands
self.maxex = RestrictedList() # Python exceptions for C.Decimal
self.maxresults = RestrictedList() # C.Decimal results
# ====================================================================== # ======================================================================
# SkipHandler: skip known discrepancies # SkipHandler: skip known discrepancies
@ -545,13 +565,17 @@ def function_as_string(t):
if t.contextfunc: if t.contextfunc:
cargs = t.cop cargs = t.cop
pargs = t.pop pargs = t.pop
maxargs = t.maxop
cfunc = "c_func: %s(" % t.funcname cfunc = "c_func: %s(" % t.funcname
pfunc = "p_func: %s(" % t.funcname pfunc = "p_func: %s(" % t.funcname
maxfunc = "max_func: %s(" % t.funcname
else: else:
cself, cargs = t.cop[0], t.cop[1:] cself, cargs = t.cop[0], t.cop[1:]
pself, pargs = t.pop[0], t.pop[1:] pself, pargs = t.pop[0], t.pop[1:]
maxself, maxargs = t.maxop[0], t.maxop[1:]
cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname) cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname)
pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname) pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname)
maxfunc = "max_func: %s.%s(" % (repr(maxself), t.funcname)
err = cfunc err = cfunc
for arg in cargs: for arg in cargs:
@ -565,6 +589,14 @@ def function_as_string(t):
err = err.rstrip(", ") err = err.rstrip(", ")
err += ")" err += ")"
if t.with_maxcontext:
err += "\n"
err += maxfunc
for arg in maxargs:
err += "%s, " % repr(arg)
err = err.rstrip(", ")
err += ")"
return err return err
def raise_error(t): def raise_error(t):
@ -577,9 +609,24 @@ def raise_error(t):
err = "Error in %s:\n\n" % t.funcname err = "Error in %s:\n\n" % t.funcname
err += "input operands: %s\n\n" % (t.op,) err += "input operands: %s\n\n" % (t.op,)
err += function_as_string(t) err += function_as_string(t)
err += "\n\nc_result: %s\np_result: %s\n\n" % (t.cresults, t.presults)
err += "c_exceptions: %s\np_exceptions: %s\n\n" % (t.cex, t.pex) err += "\n\nc_result: %s\np_result: %s\n" % (t.cresults, t.presults)
err += "%s\n\n" % str(t.context) if t.with_maxcontext:
err += "max_result: %s\n\n" % (t.maxresults)
else:
err += "\n"
err += "c_exceptions: %s\np_exceptions: %s\n" % (t.cex, t.pex)
if t.with_maxcontext:
err += "max_exceptions: %s\n\n" % t.maxex
else:
err += "\n"
err += "%s\n" % str(t.context)
if t.with_maxcontext:
err += "%s\n" % str(t.maxcontext)
else:
err += "\n"
raise VerifyError(err) raise VerifyError(err)
@ -603,6 +650,13 @@ def raise_error(t):
# are printed to stdout. # are printed to stdout.
# ====================================================================== # ======================================================================
def all_nan(a):
if isinstance(a, C.Decimal):
return a.is_nan()
elif isinstance(a, tuple):
return all(all_nan(v) for v in a)
return False
def convert(t, convstr=True): def convert(t, convstr=True):
""" t is the testset. At this stage the testset contains a tuple of """ t is the testset. At this stage the testset contains a tuple of
operands t.op of various types. For decimal methods the first operands t.op of various types. For decimal methods the first
@ -617,10 +671,12 @@ def convert(t, convstr=True):
for i, op in enumerate(t.op): for i, op in enumerate(t.op):
context.clear_status() context.clear_status()
t.maxcontext.clear_flags()
if op in RoundModes: if op in RoundModes:
t.cop.append(op) t.cop.append(op)
t.pop.append(op) t.pop.append(op)
t.maxop.append(op)
elif not t.contextfunc and i == 0 or \ elif not t.contextfunc and i == 0 or \
convstr and isinstance(op, str): convstr and isinstance(op, str):
@ -638,11 +694,25 @@ def convert(t, convstr=True):
p = None p = None
pex = e.__class__ pex = e.__class__
try:
C.setcontext(t.maxcontext)
maxop = C.Decimal(op)
maxex = None
except (TypeError, ValueError, OverflowError) as e:
maxop = None
maxex = e.__class__
finally:
C.setcontext(context.c)
t.cop.append(c) t.cop.append(c)
t.cex.append(cex) t.cex.append(cex)
t.pop.append(p) t.pop.append(p)
t.pex.append(pex) t.pex.append(pex)
t.maxop.append(maxop)
t.maxex.append(maxex)
if cex is pex: if cex is pex:
if str(c) != str(p) or not context.assert_eq_status(): if str(c) != str(p) or not context.assert_eq_status():
raise_error(t) raise_error(t)
@ -652,14 +722,21 @@ def convert(t, convstr=True):
else: else:
raise_error(t) raise_error(t)
# The exceptions in the maxcontext operation can legitimately
# differ, only test that maxex implies cex:
if maxex is not None and cex is not maxex:
raise_error(t)
elif isinstance(op, Context): elif isinstance(op, Context):
t.context = op t.context = op
t.cop.append(op.c) t.cop.append(op.c)
t.pop.append(op.p) t.pop.append(op.p)
t.maxop.append(t.maxcontext)
else: else:
t.cop.append(op) t.cop.append(op)
t.pop.append(op) t.pop.append(op)
t.maxop.append(op)
return 1 return 1
@ -673,6 +750,7 @@ def callfuncs(t):
t.rc and t.rp are the results of the operation. t.rc and t.rp are the results of the operation.
""" """
context.clear_status() context.clear_status()
t.maxcontext.clear_flags()
try: try:
if t.contextfunc: if t.contextfunc:
@ -700,6 +778,35 @@ def callfuncs(t):
t.rp = None t.rp = None
t.pex.append(e.__class__) t.pex.append(e.__class__)
# If the above results are exact, unrounded, normal etc., repeat the
# operation with a maxcontext to ensure that huge intermediate values
# do not cause a MemoryError.
if (t.funcname not in MaxContextSkip and
not context.c.flags[C.InvalidOperation] and
not context.c.flags[C.Inexact] and
not context.c.flags[C.Rounded] and
not context.c.flags[C.Subnormal] and
not context.c.flags[C.Clamped] and
not context.clamp and # results are padded to context.prec if context.clamp==1.
not any(isinstance(v, C.Context) for v in t.cop)): # another context is used.
t.with_maxcontext = True
try:
if t.contextfunc:
maxargs = t.maxop
t.rmax = getattr(t.maxcontext, t.funcname)(*maxargs)
else:
maxself = t.maxop[0]
maxargs = t.maxop[1:]
try:
C.setcontext(t.maxcontext)
t.rmax = getattr(maxself, t.funcname)(*maxargs)
finally:
C.setcontext(context.c)
t.maxex.append(None)
except (TypeError, ValueError, OverflowError, MemoryError) as e:
t.rmax = None
t.maxex.append(e.__class__)
def verify(t, stat): def verify(t, stat):
""" t is the testset. At this stage the testset contains the following """ t is the testset. At this stage the testset contains the following
tuples: tuples:
@ -714,6 +821,9 @@ def verify(t, stat):
""" """
t.cresults.append(str(t.rc)) t.cresults.append(str(t.rc))
t.presults.append(str(t.rp)) t.presults.append(str(t.rp))
if t.with_maxcontext:
t.maxresults.append(str(t.rmax))
if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal): if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal):
# General case: both results are Decimals. # General case: both results are Decimals.
t.cresults.append(t.rc.to_eng_string()) t.cresults.append(t.rc.to_eng_string())
@ -725,6 +835,12 @@ def verify(t, stat):
t.presults.append(str(t.rp.imag)) t.presults.append(str(t.rp.imag))
t.presults.append(str(t.rp.real)) t.presults.append(str(t.rp.real))
if t.with_maxcontext and isinstance(t.rmax, C.Decimal):
t.maxresults.append(t.rmax.to_eng_string())
t.maxresults.append(t.rmax.as_tuple())
t.maxresults.append(str(t.rmax.imag))
t.maxresults.append(str(t.rmax.real))
nc = t.rc.number_class().lstrip('+-s') nc = t.rc.number_class().lstrip('+-s')
stat[nc] += 1 stat[nc] += 1
else: else:
@ -732,6 +848,9 @@ def verify(t, stat):
if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple): if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple):
if t.rc != t.rp: if t.rc != t.rp:
raise_error(t) raise_error(t)
if t.with_maxcontext and not isinstance(t.rmax, tuple):
if t.rmax != t.rc:
raise_error(t)
stat[type(t.rc).__name__] += 1 stat[type(t.rc).__name__] += 1
# The return value lists must be equal. # The return value lists must be equal.
@ -744,6 +863,20 @@ def verify(t, stat):
if not t.context.assert_eq_status(): if not t.context.assert_eq_status():
raise_error(t) raise_error(t)
if t.with_maxcontext:
# NaN payloads etc. depend on precision and clamp.
if all_nan(t.rc) and all_nan(t.rmax):
return
# The return value lists must be equal.
if t.maxresults != t.cresults:
raise_error(t)
# The Python exception lists (TypeError, etc.) must be equal.
if t.maxex != t.cex:
raise_error(t)
# The context flags must be equal.
if t.maxcontext.flags != t.context.c.flags:
raise_error(t)
# ====================================================================== # ======================================================================
# Main test loops # Main test loops