Issue #23133: Pickling of ipaddress objects now produces more compact and

portable representation.
This commit is contained in:
Serhiy Storchaka 2015-01-18 22:36:33 +02:00
parent 3b225d8bfb
commit 5f38f5c502
3 changed files with 41 additions and 14 deletions

View File

@ -567,6 +567,9 @@ class _IPAddressBase(_TotalOrderingMixin):
except ValueError: except ValueError:
cls._report_invalid_netmask(ip_str) cls._report_invalid_netmask(ip_str)
def __reduce__(self):
return self.__class__, (str(self),)
class _BaseAddress(_IPAddressBase): class _BaseAddress(_IPAddressBase):
@ -576,11 +579,6 @@ class _BaseAddress(_IPAddressBase):
used by single IP addresses. used by single IP addresses.
""" """
def __init__(self, address):
if (not isinstance(address, bytes)
and '/' in str(address)):
raise AddressValueError("Unexpected '/' in %r" % address)
def __int__(self): def __int__(self):
return self._ip return self._ip
@ -626,6 +624,9 @@ class _BaseAddress(_IPAddressBase):
def _get_address_key(self): def _get_address_key(self):
return (self._version, self) return (self._version, self)
def __reduce__(self):
return self.__class__, (self._ip,)
class _BaseNetwork(_IPAddressBase): class _BaseNetwork(_IPAddressBase):
@ -1295,7 +1296,6 @@ class IPv4Address(_BaseV4, _BaseAddress):
AddressValueError: If ipaddress isn't a valid IPv4 address. AddressValueError: If ipaddress isn't a valid IPv4 address.
""" """
_BaseAddress.__init__(self, address)
_BaseV4.__init__(self, address) _BaseV4.__init__(self, address)
# Efficient constructor from integer. # Efficient constructor from integer.
@ -1313,6 +1313,8 @@ class IPv4Address(_BaseV4, _BaseAddress):
# Assume input argument to be string or any object representation # Assume input argument to be string or any object representation
# which converts into a formatted IP string. # which converts into a formatted IP string.
addr_str = str(address) addr_str = str(address)
if '/' in addr_str:
raise AddressValueError("Unexpected '/' in %r" % address)
self._ip = self._ip_int_from_string(addr_str) self._ip = self._ip_int_from_string(addr_str)
@property @property
@ -1446,6 +1448,8 @@ class IPv4Interface(IPv4Address):
def __hash__(self): def __hash__(self):
return self._ip ^ self._prefixlen ^ int(self.network.network_address) return self._ip ^ self._prefixlen ^ int(self.network.network_address)
__reduce__ = _IPAddressBase.__reduce__
@property @property
def ip(self): def ip(self):
return IPv4Address(self._ip) return IPv4Address(self._ip)
@ -1920,7 +1924,6 @@ class IPv6Address(_BaseV6, _BaseAddress):
AddressValueError: If address isn't a valid IPv6 address. AddressValueError: If address isn't a valid IPv6 address.
""" """
_BaseAddress.__init__(self, address)
_BaseV6.__init__(self, address) _BaseV6.__init__(self, address)
# Efficient constructor from integer. # Efficient constructor from integer.
@ -1938,6 +1941,8 @@ class IPv6Address(_BaseV6, _BaseAddress):
# Assume input argument to be string or any object representation # Assume input argument to be string or any object representation
# which converts into a formatted IP string. # which converts into a formatted IP string.
addr_str = str(address) addr_str = str(address)
if '/' in addr_str:
raise AddressValueError("Unexpected '/' in %r" % address)
self._ip = self._ip_int_from_string(addr_str) self._ip = self._ip_int_from_string(addr_str)
@property @property
@ -2134,6 +2139,8 @@ class IPv6Interface(IPv6Address):
def __hash__(self): def __hash__(self):
return self._ip ^ self._prefixlen ^ int(self.network.network_address) return self._ip ^ self._prefixlen ^ int(self.network.network_address)
__reduce__ = _IPAddressBase.__reduce__
@property @property
def ip(self): def ip(self):
return IPv6Address(self._ip) return IPv6Address(self._ip)

View File

@ -8,6 +8,7 @@ import unittest
import re import re
import contextlib import contextlib
import operator import operator
import pickle
import ipaddress import ipaddress
@ -82,6 +83,13 @@ class CommonTestMixin:
self.assertRaises(TypeError, hex, self.factory(1)) self.assertRaises(TypeError, hex, self.factory(1))
self.assertRaises(TypeError, bytes, self.factory(1)) self.assertRaises(TypeError, bytes, self.factory(1))
def pickle_test(self, addr):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
x = self.factory(addr)
y = pickle.loads(pickle.dumps(x, proto))
self.assertEqual(y, x)
class CommonTestMixin_v4(CommonTestMixin): class CommonTestMixin_v4(CommonTestMixin):
@ -247,6 +255,9 @@ class AddressTestCase_v4(BaseTestCase, CommonTestMixin_v4):
assertBadOctet("257.0.0.0", 257) assertBadOctet("257.0.0.0", 257)
assertBadOctet("192.168.0.999", 999) assertBadOctet("192.168.0.999", 999)
def test_pickle(self):
self.pickle_test('192.0.2.1')
class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6): class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6):
factory = ipaddress.IPv6Address factory = ipaddress.IPv6Address
@ -379,6 +390,9 @@ class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6):
assertBadPart("02001:db8::", "02001") assertBadPart("02001:db8::", "02001")
assertBadPart('2001:888888::1', "888888") assertBadPart('2001:888888::1', "888888")
def test_pickle(self):
self.pickle_test('2001:db8::')
class NetmaskTestMixin_v4(CommonTestMixin_v4): class NetmaskTestMixin_v4(CommonTestMixin_v4):
"""Input validation on interfaces and networks is very similar""" """Input validation on interfaces and networks is very similar"""
@ -446,6 +460,11 @@ class NetmaskTestMixin_v4(CommonTestMixin_v4):
class InterfaceTestCase_v4(BaseTestCase, NetmaskTestMixin_v4): class InterfaceTestCase_v4(BaseTestCase, NetmaskTestMixin_v4):
factory = ipaddress.IPv4Interface factory = ipaddress.IPv4Interface
def test_pickle(self):
self.pickle_test('192.0.2.0/27')
self.pickle_test('192.0.2.0/31') # IPV4LENGTH - 1
self.pickle_test('192.0.2.0') # IPV4LENGTH
class NetworkTestCase_v4(BaseTestCase, NetmaskTestMixin_v4): class NetworkTestCase_v4(BaseTestCase, NetmaskTestMixin_v4):
factory = ipaddress.IPv4Network factory = ipaddress.IPv4Network
@ -500,6 +519,11 @@ class NetmaskTestMixin_v6(CommonTestMixin_v6):
assertBadNetmask("::1", "pudding") assertBadNetmask("::1", "pudding")
assertBadNetmask("::", "::") assertBadNetmask("::", "::")
def test_pickle(self):
self.pickle_test('2001:db8::1000/124')
self.pickle_test('2001:db8::1000/127') # IPV6LENGTH - 1
self.pickle_test('2001:db8::1000') # IPV6LENGTH
class InterfaceTestCase_v6(BaseTestCase, NetmaskTestMixin_v6): class InterfaceTestCase_v6(BaseTestCase, NetmaskTestMixin_v6):
factory = ipaddress.IPv6Interface factory = ipaddress.IPv6Interface
@ -774,13 +798,6 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(128, ipaddress._count_righthand_zero_bits(0, 128)) self.assertEqual(128, ipaddress._count_righthand_zero_bits(0, 128))
self.assertEqual("IPv4Network('1.2.3.0/24')", repr(self.ipv4_network)) self.assertEqual("IPv4Network('1.2.3.0/24')", repr(self.ipv4_network))
def testMissingAddressVersion(self):
class Broken(ipaddress._BaseAddress):
pass
broken = Broken('127.0.0.1')
with self.assertRaisesRegex(NotImplementedError, "Broken.*version"):
broken.version
def testMissingNetworkVersion(self): def testMissingNetworkVersion(self):
class Broken(ipaddress._BaseNetwork): class Broken(ipaddress._BaseNetwork):
pass pass

View File

@ -203,6 +203,9 @@ Core and Builtins
Library Library
------- -------
- Issue #23133: Pickling of ipaddress objects now produces more compact and
portable representation.
- Issue #23248: Update ssl error codes from latest OpenSSL git master. - Issue #23248: Update ssl error codes from latest OpenSSL git master.
- Issue #23266: Much faster implementation of ipaddress.collapse_addresses() - Issue #23266: Much faster implementation of ipaddress.collapse_addresses()