Covert pickle tests to use unittest.

Extend tests to cover a few more cases.  For cPickle, test several of
the undocumented features.
This commit is contained in:
Jeremy Hylton 2001-10-15 21:38:56 +00:00
parent 499ab6a653
commit 6642653875
5 changed files with 285 additions and 184 deletions

View File

@ -1,13 +0,0 @@
test_cpickle
dumps()
loads()
ok
loads() DATA
ok
dumps() binary
loads() binary
ok
loads() BINDATA
ok
dumps() RECURSIVE
ok

View File

@ -1,13 +0,0 @@
test_pickle
dumps()
loads()
ok
loads() DATA
ok
dumps() binary
loads() binary
ok
loads() BINDATA
ok
dumps() RECURSIVE
ok

View File

@ -1,9 +1,27 @@
# test_pickle and test_cpickle both use this.
import unittest
from test_support import TestFailed, have_unicode
import sys
# break into multiple strings to please font-lock-mode
class C:
def __cmp__(self, other):
return cmp(self.__dict__, other.__dict__)
import __main__
__main__.C = C
C.__module__ = "__main__"
class myint(int):
def __init__(self, x):
self.str = str(x)
class initarg(C):
def __init__(self, a, b):
self.a = a
self.b = b
def __getinitargs__(self):
return self.a, self.b
# break into multiple strings to avoid confusing font-lock-mode
DATA = """(lp1
I0
aL1L
@ -58,17 +76,7 @@ BINDATA = ']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00' + \
'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh' + \
'\x06tq\nh\nK\x05e.'
class C:
def __cmp__(self, other):
return cmp(self.__dict__, other.__dict__)
import __main__
__main__.C = C
C.__module__ = "__main__"
# Call this with the module to be tested (pickle or cPickle).
def dotest(pickle):
def create_data():
c = C()
c.foo = 1
c.bar = 2
@ -86,153 +94,159 @@ def dotest(pickle):
x.append(y)
x.append(y)
x.append(5)
r = []
r.append(r)
return x
print "dumps()"
s = pickle.dumps(x)
class AbstractPickleTests(unittest.TestCase):
print "loads()"
x2 = pickle.loads(s)
if x2 == x:
print "ok"
else:
print "bad"
_testdata = create_data()
print "loads() DATA"
x2 = pickle.loads(DATA)
if x2 == x:
print "ok"
else:
print "bad"
print "dumps() binary"
s = pickle.dumps(x, 1)
print "loads() binary"
x2 = pickle.loads(s)
if x2 == x:
print "ok"
else:
print "bad"
print "loads() BINDATA"
x2 = pickle.loads(BINDATA)
if x2 == x:
print "ok"
else:
print "bad"
print "dumps() RECURSIVE"
s = pickle.dumps(r)
x2 = pickle.loads(s)
if x2 == r:
print "ok"
else:
print "bad"
# don't create cyclic garbage
del x2[0]
del r[0]
# Test protection against closed files
import tempfile, os
fn = tempfile.mktemp()
f = open(fn, "w")
f.close()
try:
pickle.dump(123, f)
except ValueError:
def setUp(self):
# subclass must define self.dumps, self.loads, self.error
pass
else:
print "dump to closed file should raise ValueError"
f = open(fn, "r")
f.close()
try:
pickle.load(f)
except ValueError:
pass
else:
print "load from closed file should raise ValueError"
os.remove(fn)
def test_misc(self):
# test various datatypes not tested by testdata
x = myint(4)
s = self.dumps(x)
y = self.loads(s)
self.assertEqual(x, y)
# Test specific bad cases
for i in range(10):
try:
x = pickle.loads('garyp')
except KeyError, y:
# pickle
del y
except pickle.BadPickleGet, y:
# cPickle
del y
else:
print "unexpected success!"
break
x = (1, ())
s = self.dumps(x)
y = self.loads(s)
self.assertEqual(x, y)
# Test insecure strings
insecure = ["abc", "2 + 2", # not quoted
"'abc' + 'def'", # not a single quoted string
"'abc", # quote is not closed
"'abc\"", # open quote and close quote don't match
"'abc' ?", # junk after close quote
# some tests of the quoting rules
"'abc\"\''",
"'\\\\a\'\'\'\\\'\\\\\''",
]
for s in insecure:
buf = "S" + s + "\012p0\012."
try:
x = pickle.loads(buf)
except ValueError:
pass
else:
print "accepted insecure string: %s" % repr(buf)
x = initarg(1, x)
s = self.dumps(x)
y = self.loads(s)
self.assertEqual(x, y)
# XXX test __reduce__ protocol?
def test_identity(self):
s = self.dumps(self._testdata)
x = self.loads(s)
self.assertEqual(x, self._testdata)
def test_constant(self):
x = self.loads(DATA)
self.assertEqual(x, self._testdata)
x = self.loads(BINDATA)
self.assertEqual(x, self._testdata)
def test_binary(self):
s = self.dumps(self._testdata, 1)
x = self.loads(s)
self.assertEqual(x, self._testdata)
def test_recursive_list(self):
l = []
l.append(l)
s = self.dumps(l)
x = self.loads(s)
self.assertEqual(x, l)
self.assertEqual(x, x[0])
self.assertEqual(id(x), id(x[0]))
def test_recursive_dict(self):
d = {}
d[1] = d
s = self.dumps(d)
x = self.loads(s)
self.assertEqual(x, d)
self.assertEqual(x[1], x)
self.assertEqual(id(x[1]), id(x))
def test_recursive_inst(self):
i = C()
i.attr = i
s = self.dumps(i)
x = self.loads(s)
self.assertEqual(x, i)
self.assertEqual(x.attr, x)
self.assertEqual(id(x.attr), id(x))
def test_recursive_multi(self):
l = []
d = {1:l}
i = C()
i.attr = d
l.append(i)
s = self.dumps(l)
x = self.loads(s)
self.assertEqual(x, l)
self.assertEqual(x[0], i)
self.assertEqual(x[0].attr, d)
self.assertEqual(x[0].attr[1], x)
self.assertEqual(x[0].attr[1][0], i)
self.assertEqual(x[0].attr[1][0].attr, d)
def test_garyp(self):
self.assertRaises(self.error, self.loads, 'garyp')
def test_insecure_strings(self):
insecure = ["abc", "2 + 2", # not quoted
"'abc' + 'def'", # not a single quoted string
"'abc", # quote is not closed
"'abc\"", # open quote and close quote don't match
"'abc' ?", # junk after close quote
# some tests of the quoting rules
"'abc\"\''",
"'\\\\a\'\'\'\\\'\\\\\''",
]
for s in insecure:
buf = "S" + s + "\012p0\012."
self.assertRaises(ValueError, self.loads, buf)
# Test some Unicode end cases
if have_unicode:
endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'),
unicode('<\n>'), unicode('<\\>')]
else:
endcases = []
for u in endcases:
try:
u2 = pickle.loads(pickle.dumps(u))
except Exception, msg:
print "Endcase exception: %s => %s(%s)" % \
(`u`, msg.__class__.__name__, str(msg))
else:
if u2 != u:
print "Endcase failure: %s => %s" % (`u`, `u2`)
def test_unicode(self):
endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'),
unicode('<\n>'), unicode('<\\>')]
for u in endcases:
p = self.dumps(u)
u2 = self.loads(p)
self.assertEqual(u2, u)
# Test the full range of Python ints.
n = sys.maxint
while n:
for expected in (-n, n):
for binary_mode in (0, 1):
s = pickle.dumps(expected, binary_mode)
got = pickle.loads(s)
if expected != got:
raise TestFailed("for %s-mode pickle of %d, pickle "
"string is %s, loaded back as %s" % (
binary_mode and "binary" or "text",
expected,
repr(s),
got))
n = n >> 1
def test_ints(self):
import sys
n = sys.maxint
while n:
for expected in (-n, n):
s = self.dumps(expected)
n2 = self.loads(s)
self.assertEqual(expected, n2)
n = n >> 1
# Fake a pickle from a sizeof(long)==8 box.
maxint64 = (1L << 63) - 1
data = 'I' + str(maxint64) + '\n.'
got = pickle.loads(data)
if maxint64 != got:
raise TestFailed("maxint64 test failed %r %r" % (maxint64, got))
# Try too with a bogus literal.
data = 'I' + str(maxint64) + 'JUNK\n.'
try:
got = pickle.loads(data)
except ValueError:
def test_maxint64(self):
maxint64 = (1L << 63) - 1
data = 'I' + str(maxint64) + '\n.'
got = self.loads(data)
self.assertEqual(got, maxint64)
# Try too with a bogus literal.
data = 'I' + str(maxint64) + 'JUNK\n.'
self.assertRaises(ValueError, self.loads, data)
def test_reduce(self):
pass
else:
raise TestFailed("should have raised error on bogus INT literal")
def test_getinitargs(self):
pass
class AbstractPickleModuleTests(unittest.TestCase):
def test_dump_closed_file(self):
import tempfile, os
fn = tempfile.mktemp()
f = open(fn, "w")
f.close()
self.assertRaises(ValueError, self.module.dump, 123, f)
os.remove(fn)
def test_load_closed_file(self):
import tempfile, os
fn = tempfile.mktemp()
f = open(fn, "w")
f.close()
self.assertRaises(ValueError, self.module.dump, 123, f)
os.remove(fn)

View File

@ -1,3 +1,86 @@
import cPickle
import pickletester
pickletester.dotest(cPickle)
from cStringIO import StringIO
from pickletester import AbstractPickleTests, AbstractPickleModuleTests
from test_support import run_unittest
class cPickleTests(AbstractPickleTests, AbstractPickleModuleTests):
def setUp(self):
self.dumps = cPickle.dumps
self.loads = cPickle.loads
error = cPickle.BadPickleGet
module = cPickle
class cPicklePicklerTests(AbstractPickleTests):
def dumps(self, arg, bin=0):
f = StringIO()
p = cPickle.Pickler(f, bin)
p.dump(arg)
f.seek(0)
return f.read()
def loads(self, buf):
f = StringIO(buf)
p = cPickle.Unpickler(f)
return p.load()
error = cPickle.BadPickleGet
class cPickleListPicklerTests(AbstractPickleTests):
def dumps(self, arg, bin=0):
p = cPickle.Pickler(bin)
p.dump(arg)
return p.getvalue()
def loads(self, *args):
f = StringIO(args[0])
p = cPickle.Unpickler(f)
return p.load()
error = cPickle.BadPickleGet
class cPickleFastPicklerTests(AbstractPickleTests):
def dumps(self, arg, bin=0):
f = StringIO()
p = cPickle.Pickler(f, bin)
p.fast = 1
p.dump(arg)
f.seek(0)
return f.read()
def loads(self, *args):
f = StringIO(args[0])
p = cPickle.Unpickler(f)
return p.load()
error = cPickle.BadPickleGet
def test_recursive_list(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_list,
self)
def test_recursive_inst(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_inst,
self)
def test_recursive_dict(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_dict,
self)
def test_recursive_multi(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_multi,
self)
if __name__ == "__main__":
run_unittest(cPickleTests)
run_unittest(cPicklePicklerTests)
run_unittest(cPickleListPicklerTests)
run_unittest(cPickleFastPicklerTests)

View File

@ -1,3 +1,33 @@
import pickle
import pickletester
pickletester.dotest(pickle)
from cStringIO import StringIO
from pickletester import AbstractPickleTests, AbstractPickleModuleTests
from test_support import run_unittest
class PickleTests(AbstractPickleTests, AbstractPickleModuleTests):
def setUp(self):
self.dumps = pickle.dumps
self.loads = pickle.loads
module = pickle
error = KeyError
class PicklerTests(AbstractPickleTests):
error = KeyError
def dumps(self, arg, bin=0):
f = StringIO()
p = pickle.Pickler(f, bin)
p.dump(arg)
f.seek(0)
return f.read()
def loads(self, buf):
f = StringIO(buf)
u = pickle.Unpickler(f)
return u.load()
if __name__ == "__main__":
run_unittest(PickleTests)
run_unittest(PicklerTests)