diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index 8d369e3986f..ce6bf0c8775 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -567,6 +567,9 @@ class _IPAddressBase(_TotalOrderingMixin): except ValueError: cls._report_invalid_netmask(ip_str) + def __reduce__(self): + return self.__class__, (str(self),) + class _BaseAddress(_IPAddressBase): @@ -576,11 +579,6 @@ class _BaseAddress(_IPAddressBase): used by single IP addresses. """ - def __init__(self, address): - if (not isinstance(address, bytes) - and '/' in str(address)): - raise AddressValueError("Unexpected '/' in %r" % address) - def __int__(self): return self._ip @@ -626,6 +624,9 @@ class _BaseAddress(_IPAddressBase): def _get_address_key(self): return (self._version, self) + def __reduce__(self): + return self.__class__, (self._ip,) + class _BaseNetwork(_IPAddressBase): @@ -1295,7 +1296,6 @@ class IPv4Address(_BaseV4, _BaseAddress): AddressValueError: If ipaddress isn't a valid IPv4 address. """ - _BaseAddress.__init__(self, address) _BaseV4.__init__(self, address) # Efficient constructor from integer. @@ -1313,6 +1313,8 @@ class IPv4Address(_BaseV4, _BaseAddress): # Assume input argument to be string or any object representation # which converts into a formatted IP string. addr_str = str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) self._ip = self._ip_int_from_string(addr_str) @property @@ -1446,6 +1448,8 @@ class IPv4Interface(IPv4Address): def __hash__(self): return self._ip ^ self._prefixlen ^ int(self.network.network_address) + __reduce__ = _IPAddressBase.__reduce__ + @property def ip(self): return IPv4Address(self._ip) @@ -1920,7 +1924,6 @@ class IPv6Address(_BaseV6, _BaseAddress): AddressValueError: If address isn't a valid IPv6 address. """ - _BaseAddress.__init__(self, address) _BaseV6.__init__(self, address) # Efficient constructor from integer. @@ -1938,6 +1941,8 @@ class IPv6Address(_BaseV6, _BaseAddress): # Assume input argument to be string or any object representation # which converts into a formatted IP string. addr_str = str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) self._ip = self._ip_int_from_string(addr_str) @property @@ -2134,6 +2139,8 @@ class IPv6Interface(IPv6Address): def __hash__(self): return self._ip ^ self._prefixlen ^ int(self.network.network_address) + __reduce__ = _IPAddressBase.__reduce__ + @property def ip(self): return IPv6Address(self._ip) diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index 1fbd2d5323f..02028bfe65b 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -8,6 +8,7 @@ import unittest import re import contextlib import operator +import pickle import ipaddress @@ -82,6 +83,13 @@ class CommonTestMixin: self.assertRaises(TypeError, hex, self.factory(1)) self.assertRaises(TypeError, bytes, self.factory(1)) + def pickle_test(self, addr): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + x = self.factory(addr) + y = pickle.loads(pickle.dumps(x, proto)) + self.assertEqual(y, x) + class CommonTestMixin_v4(CommonTestMixin): @@ -247,6 +255,9 @@ class AddressTestCase_v4(BaseTestCase, CommonTestMixin_v4): assertBadOctet("257.0.0.0", 257) assertBadOctet("192.168.0.999", 999) + def test_pickle(self): + self.pickle_test('192.0.2.1') + class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6): factory = ipaddress.IPv6Address @@ -379,6 +390,9 @@ class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6): assertBadPart("02001:db8::", "02001") assertBadPart('2001:888888::1', "888888") + def test_pickle(self): + self.pickle_test('2001:db8::') + class NetmaskTestMixin_v4(CommonTestMixin_v4): """Input validation on interfaces and networks is very similar""" @@ -446,6 +460,11 @@ class NetmaskTestMixin_v4(CommonTestMixin_v4): class InterfaceTestCase_v4(BaseTestCase, NetmaskTestMixin_v4): factory = ipaddress.IPv4Interface + def test_pickle(self): + self.pickle_test('192.0.2.0/27') + self.pickle_test('192.0.2.0/31') # IPV4LENGTH - 1 + self.pickle_test('192.0.2.0') # IPV4LENGTH + class NetworkTestCase_v4(BaseTestCase, NetmaskTestMixin_v4): factory = ipaddress.IPv4Network @@ -500,6 +519,11 @@ class NetmaskTestMixin_v6(CommonTestMixin_v6): assertBadNetmask("::1", "pudding") assertBadNetmask("::", "::") + def test_pickle(self): + self.pickle_test('2001:db8::1000/124') + self.pickle_test('2001:db8::1000/127') # IPV6LENGTH - 1 + self.pickle_test('2001:db8::1000') # IPV6LENGTH + class InterfaceTestCase_v6(BaseTestCase, NetmaskTestMixin_v6): factory = ipaddress.IPv6Interface @@ -774,13 +798,6 @@ class IpaddrUnitTest(unittest.TestCase): self.assertEqual(128, ipaddress._count_righthand_zero_bits(0, 128)) self.assertEqual("IPv4Network('1.2.3.0/24')", repr(self.ipv4_network)) - def testMissingAddressVersion(self): - class Broken(ipaddress._BaseAddress): - pass - broken = Broken('127.0.0.1') - with self.assertRaisesRegex(NotImplementedError, "Broken.*version"): - broken.version - def testMissingNetworkVersion(self): class Broken(ipaddress._BaseNetwork): pass diff --git a/Misc/NEWS b/Misc/NEWS index 245cfaf4edd..5373f5bca8b 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -203,6 +203,9 @@ Core and Builtins Library ------- +- Issue #23133: Pickling of ipaddress objects now produces more compact and + portable representation. + - Issue #23248: Update ssl error codes from latest OpenSSL git master. - Issue #23266: Much faster implementation of ipaddress.collapse_addresses()