bpo-39576: Prevent memory error for overly optimistic precisions (GH-18581) (#18584)
(cherry picked from commit 90930e6545
)
Authored-by: Stefan Krah <skrah@bytereef.org>
This commit is contained in:
parent
d0a464e31a
commit
b6271025c6
|
@ -5476,6 +5476,41 @@ class CWhitebox(unittest.TestCase):
|
|||
self.assertEqual(Decimal.from_float(cls(101.1)),
|
||||
Decimal.from_float(101.1))
|
||||
|
||||
def test_maxcontext_exact_arith(self):
|
||||
|
||||
# Make sure that exact operations do not raise MemoryError due
|
||||
# to huge intermediate values when the context precision is very
|
||||
# large.
|
||||
|
||||
# The following functions fill the available precision and are
|
||||
# therefore not suitable for large precisions (by design of the
|
||||
# specification).
|
||||
MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus',
|
||||
'logical_and', 'logical_or', 'logical_xor',
|
||||
'next_toward', 'rotate', 'shift']
|
||||
|
||||
Decimal = C.Decimal
|
||||
Context = C.Context
|
||||
localcontext = C.localcontext
|
||||
|
||||
# Here only some functions that are likely candidates for triggering a
|
||||
# MemoryError are tested. deccheck.py has an exhaustive test.
|
||||
maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX)
|
||||
with localcontext(maxcontext):
|
||||
self.assertEqual(Decimal(0).exp(), 1)
|
||||
self.assertEqual(Decimal(1).ln(), 0)
|
||||
self.assertEqual(Decimal(1).log10(), 0)
|
||||
self.assertEqual(Decimal(10**2).log10(), 2)
|
||||
self.assertEqual(Decimal(10**223).log10(), 223)
|
||||
self.assertEqual(Decimal(10**19).logb(), 19)
|
||||
self.assertEqual(Decimal(4).sqrt(), 2)
|
||||
self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5'))
|
||||
self.assertEqual(divmod(Decimal(10), 3), (3, 1))
|
||||
self.assertEqual(Decimal(10) // 3, 3)
|
||||
self.assertEqual(Decimal(4) / 2, 2)
|
||||
self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))
|
||||
|
||||
|
||||
@requires_docstrings
|
||||
@unittest.skipUnless(C, "test requires C version")
|
||||
class SignatureTest(unittest.TestCase):
|
||||
|
|
|
@ -3781,6 +3781,43 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b,
|
|||
const mpd_context_t *ctx, uint32_t *status)
|
||||
{
|
||||
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status);
|
||||
|
||||
if (*status & MPD_Malloc_error) {
|
||||
/* Inexact quotients (the usual case) fill the entire context precision,
|
||||
* which can lead to malloc() failures for very high precisions. Retry
|
||||
* the operation with a lower precision in case the result is exact.
|
||||
*
|
||||
* We need an upper bound for the number of digits of a_coeff / b_coeff
|
||||
* when the result is exact. If a_coeff' * 1 / b_coeff' is in lowest
|
||||
* terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable
|
||||
* bound.
|
||||
*
|
||||
* 1 / b_coeff' is exact iff b_coeff' exclusively has prime factors 2 or 5.
|
||||
* The largest amount of digits is generated if b_coeff' is a power of 2 or
|
||||
* a power of 5 and is less than or equal to log5(b_coeff') <= log2(b_coeff').
|
||||
*
|
||||
* We arrive at a total upper bound:
|
||||
*
|
||||
* maxdigits(a_coeff') + maxdigits(1 / b_coeff') <=
|
||||
* a->digits + log2(b_coeff) =
|
||||
* a->digits + log10(b_coeff) / log10(2) <=
|
||||
* a->digits + b->digits * 4;
|
||||
*/
|
||||
uint32_t workstatus = 0;
|
||||
mpd_context_t workctx = *ctx;
|
||||
workctx.prec = a->digits + b->digits * 4;
|
||||
if (workctx.prec >= ctx->prec) {
|
||||
return; /* No point in retrying, keep the original error. */
|
||||
}
|
||||
|
||||
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus);
|
||||
if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */
|
||||
*status = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
mpd_seterror(q, *status, status);
|
||||
}
|
||||
}
|
||||
|
||||
/* Internal function. */
|
||||
|
@ -7702,9 +7739,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
|
|||
/* END LIBMPDEC_ONLY */
|
||||
|
||||
/* Algorithm from decimal.py */
|
||||
void
|
||||
mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
|
||||
uint32_t *status)
|
||||
static void
|
||||
_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
|
||||
uint32_t *status)
|
||||
{
|
||||
mpd_context_t maxcontext;
|
||||
MPD_NEW_STATIC(c,0,0,0,0);
|
||||
|
@ -7836,6 +7873,40 @@ malloc_error:
|
|||
goto out;
|
||||
}
|
||||
|
||||
void
|
||||
mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
|
||||
uint32_t *status)
|
||||
{
|
||||
_mpd_qsqrt(result, a, ctx, status);
|
||||
|
||||
if (*status & (MPD_Malloc_error|MPD_Division_impossible)) {
|
||||
/* The above conditions can occur at very high context precisions
|
||||
* if intermediate values get too large. Retry the operation with
|
||||
* a lower context precision in case the result is exact.
|
||||
*
|
||||
* If the result is exact, an upper bound for the number of digits
|
||||
* is the number of digits in the input.
|
||||
*
|
||||
* NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2
|
||||
*/
|
||||
uint32_t workstatus = 0;
|
||||
mpd_context_t workctx = *ctx;
|
||||
workctx.prec = a->digits;
|
||||
|
||||
if (workctx.prec >= ctx->prec) {
|
||||
return; /* No point in repeating this, keep the original error. */
|
||||
}
|
||||
|
||||
_mpd_qsqrt(result, a, &workctx, &workstatus);
|
||||
if (workstatus == 0) {
|
||||
*status = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
mpd_seterror(result, *status, status);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/******************************************************************************/
|
||||
/* Base conversions */
|
||||
|
|
|
@ -125,6 +125,12 @@ ContextFunctions = {
|
|||
'special': ('context.__reduce_ex__', 'context.create_decimal_from_float')
|
||||
}
|
||||
|
||||
# Functions that set no context flags but whose result can differ depending
|
||||
# on prec, Emin and Emax.
|
||||
MaxContextSkip = ['is_normal', 'is_subnormal', 'logical_invert', 'next_minus',
|
||||
'next_plus', 'number_class', 'logical_and', 'logical_or',
|
||||
'logical_xor', 'next_toward', 'rotate', 'shift']
|
||||
|
||||
# Functions that require a restricted exponent range for reasonable runtimes.
|
||||
UnaryRestricted = [
|
||||
'__ceil__', '__floor__', '__int__', '__trunc__',
|
||||
|
@ -344,6 +350,20 @@ class TestSet(object):
|
|||
self.pex = RestrictedList() # Python exceptions for P.Decimal
|
||||
self.presults = RestrictedList() # P.Decimal results
|
||||
|
||||
# If the above results are exact, unrounded and not clamped, repeat
|
||||
# the operation with a maxcontext to ensure that huge intermediate
|
||||
# values do not cause a MemoryError.
|
||||
self.with_maxcontext = False
|
||||
self.maxcontext = context.c.copy()
|
||||
self.maxcontext.prec = C.MAX_PREC
|
||||
self.maxcontext.Emax = C.MAX_EMAX
|
||||
self.maxcontext.Emin = C.MIN_EMIN
|
||||
self.maxcontext.clear_flags()
|
||||
|
||||
self.maxop = RestrictedList() # converted C.Decimal operands
|
||||
self.maxex = RestrictedList() # Python exceptions for C.Decimal
|
||||
self.maxresults = RestrictedList() # C.Decimal results
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# SkipHandler: skip known discrepancies
|
||||
|
@ -545,13 +565,17 @@ def function_as_string(t):
|
|||
if t.contextfunc:
|
||||
cargs = t.cop
|
||||
pargs = t.pop
|
||||
maxargs = t.maxop
|
||||
cfunc = "c_func: %s(" % t.funcname
|
||||
pfunc = "p_func: %s(" % t.funcname
|
||||
maxfunc = "max_func: %s(" % t.funcname
|
||||
else:
|
||||
cself, cargs = t.cop[0], t.cop[1:]
|
||||
pself, pargs = t.pop[0], t.pop[1:]
|
||||
maxself, maxargs = t.maxop[0], t.maxop[1:]
|
||||
cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname)
|
||||
pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname)
|
||||
maxfunc = "max_func: %s.%s(" % (repr(maxself), t.funcname)
|
||||
|
||||
err = cfunc
|
||||
for arg in cargs:
|
||||
|
@ -565,6 +589,14 @@ def function_as_string(t):
|
|||
err = err.rstrip(", ")
|
||||
err += ")"
|
||||
|
||||
if t.with_maxcontext:
|
||||
err += "\n"
|
||||
err += maxfunc
|
||||
for arg in maxargs:
|
||||
err += "%s, " % repr(arg)
|
||||
err = err.rstrip(", ")
|
||||
err += ")"
|
||||
|
||||
return err
|
||||
|
||||
def raise_error(t):
|
||||
|
@ -577,9 +609,24 @@ def raise_error(t):
|
|||
err = "Error in %s:\n\n" % t.funcname
|
||||
err += "input operands: %s\n\n" % (t.op,)
|
||||
err += function_as_string(t)
|
||||
err += "\n\nc_result: %s\np_result: %s\n\n" % (t.cresults, t.presults)
|
||||
err += "c_exceptions: %s\np_exceptions: %s\n\n" % (t.cex, t.pex)
|
||||
err += "%s\n\n" % str(t.context)
|
||||
|
||||
err += "\n\nc_result: %s\np_result: %s\n" % (t.cresults, t.presults)
|
||||
if t.with_maxcontext:
|
||||
err += "max_result: %s\n\n" % (t.maxresults)
|
||||
else:
|
||||
err += "\n"
|
||||
|
||||
err += "c_exceptions: %s\np_exceptions: %s\n" % (t.cex, t.pex)
|
||||
if t.with_maxcontext:
|
||||
err += "max_exceptions: %s\n\n" % t.maxex
|
||||
else:
|
||||
err += "\n"
|
||||
|
||||
err += "%s\n" % str(t.context)
|
||||
if t.with_maxcontext:
|
||||
err += "%s\n" % str(t.maxcontext)
|
||||
else:
|
||||
err += "\n"
|
||||
|
||||
raise VerifyError(err)
|
||||
|
||||
|
@ -603,6 +650,13 @@ def raise_error(t):
|
|||
# are printed to stdout.
|
||||
# ======================================================================
|
||||
|
||||
def all_nan(a):
|
||||
if isinstance(a, C.Decimal):
|
||||
return a.is_nan()
|
||||
elif isinstance(a, tuple):
|
||||
return all(all_nan(v) for v in a)
|
||||
return False
|
||||
|
||||
def convert(t, convstr=True):
|
||||
""" t is the testset. At this stage the testset contains a tuple of
|
||||
operands t.op of various types. For decimal methods the first
|
||||
|
@ -617,10 +671,12 @@ def convert(t, convstr=True):
|
|||
for i, op in enumerate(t.op):
|
||||
|
||||
context.clear_status()
|
||||
t.maxcontext.clear_flags()
|
||||
|
||||
if op in RoundModes:
|
||||
t.cop.append(op)
|
||||
t.pop.append(op)
|
||||
t.maxop.append(op)
|
||||
|
||||
elif not t.contextfunc and i == 0 or \
|
||||
convstr and isinstance(op, str):
|
||||
|
@ -638,11 +694,25 @@ def convert(t, convstr=True):
|
|||
p = None
|
||||
pex = e.__class__
|
||||
|
||||
try:
|
||||
C.setcontext(t.maxcontext)
|
||||
maxop = C.Decimal(op)
|
||||
maxex = None
|
||||
except (TypeError, ValueError, OverflowError) as e:
|
||||
maxop = None
|
||||
maxex = e.__class__
|
||||
finally:
|
||||
C.setcontext(context.c)
|
||||
|
||||
t.cop.append(c)
|
||||
t.cex.append(cex)
|
||||
|
||||
t.pop.append(p)
|
||||
t.pex.append(pex)
|
||||
|
||||
t.maxop.append(maxop)
|
||||
t.maxex.append(maxex)
|
||||
|
||||
if cex is pex:
|
||||
if str(c) != str(p) or not context.assert_eq_status():
|
||||
raise_error(t)
|
||||
|
@ -652,14 +722,21 @@ def convert(t, convstr=True):
|
|||
else:
|
||||
raise_error(t)
|
||||
|
||||
# The exceptions in the maxcontext operation can legitimately
|
||||
# differ, only test that maxex implies cex:
|
||||
if maxex is not None and cex is not maxex:
|
||||
raise_error(t)
|
||||
|
||||
elif isinstance(op, Context):
|
||||
t.context = op
|
||||
t.cop.append(op.c)
|
||||
t.pop.append(op.p)
|
||||
t.maxop.append(t.maxcontext)
|
||||
|
||||
else:
|
||||
t.cop.append(op)
|
||||
t.pop.append(op)
|
||||
t.maxop.append(op)
|
||||
|
||||
return 1
|
||||
|
||||
|
@ -673,6 +750,7 @@ def callfuncs(t):
|
|||
t.rc and t.rp are the results of the operation.
|
||||
"""
|
||||
context.clear_status()
|
||||
t.maxcontext.clear_flags()
|
||||
|
||||
try:
|
||||
if t.contextfunc:
|
||||
|
@ -700,6 +778,35 @@ def callfuncs(t):
|
|||
t.rp = None
|
||||
t.pex.append(e.__class__)
|
||||
|
||||
# If the above results are exact, unrounded, normal etc., repeat the
|
||||
# operation with a maxcontext to ensure that huge intermediate values
|
||||
# do not cause a MemoryError.
|
||||
if (t.funcname not in MaxContextSkip and
|
||||
not context.c.flags[C.InvalidOperation] and
|
||||
not context.c.flags[C.Inexact] and
|
||||
not context.c.flags[C.Rounded] and
|
||||
not context.c.flags[C.Subnormal] and
|
||||
not context.c.flags[C.Clamped] and
|
||||
not context.clamp and # results are padded to context.prec if context.clamp==1.
|
||||
not any(isinstance(v, C.Context) for v in t.cop)): # another context is used.
|
||||
t.with_maxcontext = True
|
||||
try:
|
||||
if t.contextfunc:
|
||||
maxargs = t.maxop
|
||||
t.rmax = getattr(t.maxcontext, t.funcname)(*maxargs)
|
||||
else:
|
||||
maxself = t.maxop[0]
|
||||
maxargs = t.maxop[1:]
|
||||
try:
|
||||
C.setcontext(t.maxcontext)
|
||||
t.rmax = getattr(maxself, t.funcname)(*maxargs)
|
||||
finally:
|
||||
C.setcontext(context.c)
|
||||
t.maxex.append(None)
|
||||
except (TypeError, ValueError, OverflowError, MemoryError) as e:
|
||||
t.rmax = None
|
||||
t.maxex.append(e.__class__)
|
||||
|
||||
def verify(t, stat):
|
||||
""" t is the testset. At this stage the testset contains the following
|
||||
tuples:
|
||||
|
@ -714,6 +821,9 @@ def verify(t, stat):
|
|||
"""
|
||||
t.cresults.append(str(t.rc))
|
||||
t.presults.append(str(t.rp))
|
||||
if t.with_maxcontext:
|
||||
t.maxresults.append(str(t.rmax))
|
||||
|
||||
if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal):
|
||||
# General case: both results are Decimals.
|
||||
t.cresults.append(t.rc.to_eng_string())
|
||||
|
@ -725,6 +835,12 @@ def verify(t, stat):
|
|||
t.presults.append(str(t.rp.imag))
|
||||
t.presults.append(str(t.rp.real))
|
||||
|
||||
if t.with_maxcontext and isinstance(t.rmax, C.Decimal):
|
||||
t.maxresults.append(t.rmax.to_eng_string())
|
||||
t.maxresults.append(t.rmax.as_tuple())
|
||||
t.maxresults.append(str(t.rmax.imag))
|
||||
t.maxresults.append(str(t.rmax.real))
|
||||
|
||||
nc = t.rc.number_class().lstrip('+-s')
|
||||
stat[nc] += 1
|
||||
else:
|
||||
|
@ -732,6 +848,9 @@ def verify(t, stat):
|
|||
if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple):
|
||||
if t.rc != t.rp:
|
||||
raise_error(t)
|
||||
if t.with_maxcontext and not isinstance(t.rmax, tuple):
|
||||
if t.rmax != t.rc:
|
||||
raise_error(t)
|
||||
stat[type(t.rc).__name__] += 1
|
||||
|
||||
# The return value lists must be equal.
|
||||
|
@ -744,6 +863,20 @@ def verify(t, stat):
|
|||
if not t.context.assert_eq_status():
|
||||
raise_error(t)
|
||||
|
||||
if t.with_maxcontext:
|
||||
# NaN payloads etc. depend on precision and clamp.
|
||||
if all_nan(t.rc) and all_nan(t.rmax):
|
||||
return
|
||||
# The return value lists must be equal.
|
||||
if t.maxresults != t.cresults:
|
||||
raise_error(t)
|
||||
# The Python exception lists (TypeError, etc.) must be equal.
|
||||
if t.maxex != t.cex:
|
||||
raise_error(t)
|
||||
# The context flags must be equal.
|
||||
if t.maxcontext.flags != t.context.c.flags:
|
||||
raise_error(t)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Main test loops
|
||||
|
|
Loading…
Reference in New Issue