Issue #14814: In the spirit of TOOWTDI, ditch the redundant version parameter to the factory functions by using the appropriate direct class references instead

This commit is contained in:
Nick Coghlan 2012-05-27 00:25:58 +10:00
parent 072b1e1485
commit 51c3067551
2 changed files with 48 additions and 94 deletions

View File

@ -36,34 +36,22 @@ class NetmaskValueError(ValueError):
"""A Value Error related to the netmask."""
def ip_address(address, version=None):
def ip_address(address):
"""Take an IP string/int and return an object of the correct type.
Args:
address: A string or integer, the IP address. Either IPv4 or
IPv6 addresses may be supplied; integers less than 2**32 will
be considered to be IPv4 by default.
version: An integer, 4 or 6. If set, don't try to automatically
determine what the IP address type is. Important for things
like ip_address(1), which could be IPv4, '192.0.2.1', or IPv6,
'2001:db8::1'.
Returns:
An IPv4Address or IPv6Address object.
Raises:
ValueError: if the *address* passed isn't either a v4 or a v6
address, or if the version is not None, 4, or 6.
address
"""
if version is not None:
if version == 4:
return IPv4Address(address)
elif version == 6:
return IPv6Address(address)
else:
raise ValueError()
try:
return IPv4Address(address)
except (AddressValueError, NetmaskValueError):
@ -78,35 +66,22 @@ def ip_address(address, version=None):
address)
def ip_network(address, version=None, strict=True):
def ip_network(address, strict=True):
"""Take an IP string/int and return an object of the correct type.
Args:
address: A string or integer, the IP network. Either IPv4 or
IPv6 networks may be supplied; integers less than 2**32 will
be considered to be IPv4 by default.
version: An integer, 4 or 6. If set, don't try to automatically
determine what the IP address type is. Important for things
like ip_network(1), which could be IPv4, '192.0.2.1/32', or IPv6,
'2001:db8::1/128'.
Returns:
An IPv4Network or IPv6Network object.
Raises:
ValueError: if the string passed isn't either a v4 or a v6
address. Or if the network has host bits set. Or if the version
is not None, 4, or 6.
address. Or if the network has host bits set.
"""
if version is not None:
if version == 4:
return IPv4Network(address, strict)
elif version == 6:
return IPv6Network(address, strict)
else:
raise ValueError()
try:
return IPv4Network(address, strict)
except (AddressValueError, NetmaskValueError):
@ -121,24 +96,20 @@ def ip_network(address, version=None, strict=True):
address)
def ip_interface(address, version=None):
def ip_interface(address):
"""Take an IP string/int and return an object of the correct type.
Args:
address: A string or integer, the IP address. Either IPv4 or
IPv6 addresses may be supplied; integers less than 2**32 will
be considered to be IPv4 by default.
version: An integer, 4 or 6. If set, don't try to automatically
determine what the IP address type is. Important for things
like ip_interface(1), which could be IPv4, '192.0.2.1/32', or IPv6,
'2001:db8::1/128'.
Returns:
An IPv4Interface or IPv6Interface object.
Raises:
ValueError: if the string passed isn't either a v4 or a v6
address. Or if the version is not None, 4, or 6.
address.
Notes:
The IPv?Interface classes describe an Address on a particular
@ -146,14 +117,6 @@ def ip_interface(address, version=None):
and Network classes.
"""
if version is not None:
if version == 4:
return IPv4Interface(address)
elif version == 6:
return IPv6Interface(address)
else:
raise ValueError()
try:
return IPv4Interface(address)
except (AddressValueError, NetmaskValueError):
@ -281,7 +244,7 @@ def summarize_address_range(first, last):
If the first and last objects are not the same version.
ValueError:
If the last object is not greater than the first.
If the version is not 4 or 6.
If the version of the first address is not 4 or 6.
"""
if (not (isinstance(first, _BaseAddress) and
@ -318,7 +281,7 @@ def summarize_address_range(first, last):
if current == ip._ALL_ONES:
break
first_int = current + 1
first = ip_address(first_int, version=first._version)
first = first.__class__(first_int)
def _collapse_addresses_recursive(addresses):
@ -586,12 +549,12 @@ class _BaseAddress(_IPAddressBase):
def __add__(self, other):
if not isinstance(other, int):
return NotImplemented
return ip_address(int(self) + other, version=self._version)
return self.__class__(int(self) + other)
def __sub__(self, other):
if not isinstance(other, int):
return NotImplemented
return ip_address(int(self) - other, version=self._version)
return self.__class__(int(self) - other)
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, str(self))
@ -612,13 +575,12 @@ class _BaseAddress(_IPAddressBase):
class _BaseNetwork(_IPAddressBase):
"""A generic IP object.
"""A generic IP network object.
This IP class contains the version independent methods which are
used by networks.
"""
def __init__(self, address):
self._cache = {}
@ -642,14 +604,14 @@ class _BaseNetwork(_IPAddressBase):
bcast = int(self.broadcast_address) - 1
while cur <= bcast:
cur += 1
yield ip_address(cur - 1, version=self._version)
yield self._address_class(cur - 1)
def __iter__(self):
cur = int(self.network_address)
bcast = int(self.broadcast_address)
while cur <= bcast:
cur += 1
yield ip_address(cur - 1, version=self._version)
yield self._address_class(cur - 1)
def __getitem__(self, n):
network = int(self.network_address)
@ -657,12 +619,12 @@ class _BaseNetwork(_IPAddressBase):
if n >= 0:
if network + n > broadcast:
raise IndexError
return ip_address(network + n, version=self._version)
return self._address_class(network + n)
else:
n += 1
if broadcast + n < network:
raise IndexError
return ip_address(broadcast + n, version=self._version)
return self._address_class(broadcast + n)
def __lt__(self, other):
if self._version != other._version:
@ -746,8 +708,8 @@ class _BaseNetwork(_IPAddressBase):
def broadcast_address(self):
x = self._cache.get('broadcast_address')
if x is None:
x = ip_address(int(self.network_address) | int(self.hostmask),
version=self._version)
x = self._address_class(int(self.network_address) |
int(self.hostmask))
self._cache['broadcast_address'] = x
return x
@ -755,15 +717,15 @@ class _BaseNetwork(_IPAddressBase):
def hostmask(self):
x = self._cache.get('hostmask')
if x is None:
x = ip_address(int(self.netmask) ^ self._ALL_ONES,
version=self._version)
x = self._address_class(int(self.netmask) ^ self._ALL_ONES)
self._cache['hostmask'] = x
return x
@property
def network(self):
return ip_network('%s/%d' % (str(self.network_address),
self.prefixlen))
# XXX (ncoghlan): This is redundant now and will likely be removed
return self.__class__('%s/%d' % (str(self.network_address),
self.prefixlen))
@property
def with_prefixlen(self):
@ -786,6 +748,10 @@ class _BaseNetwork(_IPAddressBase):
def version(self):
raise NotImplementedError('BaseNet has no version')
@property
def _address_class(self):
raise NotImplementedError('BaseNet has no associated address class')
@property
def prefixlen(self):
return self._prefixlen
@ -840,9 +806,8 @@ class _BaseNetwork(_IPAddressBase):
raise StopIteration
# Make sure we're comparing the network of other.
other = ip_network('%s/%s' % (str(other.network_address),
str(other.prefixlen)),
version=other._version)
other = other.__class__('%s/%s' % (str(other.network_address),
str(other.prefixlen)))
s1, s2 = self.subnets()
while s1 != other and s2 != other:
@ -973,9 +938,9 @@ class _BaseNetwork(_IPAddressBase):
'prefix length diff %d is invalid for netblock %s' % (
new_prefixlen, str(self)))
first = ip_network('%s/%s' % (str(self.network_address),
str(self._prefixlen + prefixlen_diff)),
version=self._version)
first = self.__class__('%s/%s' %
(str(self.network_address),
str(self._prefixlen + prefixlen_diff)))
yield first
current = first
@ -983,16 +948,17 @@ class _BaseNetwork(_IPAddressBase):
broadcast = current.broadcast_address
if broadcast == self.broadcast_address:
return
new_addr = ip_address(int(broadcast) + 1, version=self._version)
current = ip_network('%s/%s' % (str(new_addr), str(new_prefixlen)),
version=self._version)
new_addr = self._address_class(int(broadcast) + 1)
current = self.__class__('%s/%s' % (str(new_addr),
str(new_prefixlen)))
yield current
def masked(self):
"""Return the network object with the host bits masked out."""
return ip_network('%s/%d' % (self.network_address, self._prefixlen),
version=self._version)
# XXX (ncoghlan): This is redundant now and will likely be removed
return self.__class__('%s/%d' % (self.network_address,
self._prefixlen))
def supernet(self, prefixlen_diff=1, new_prefix=None):
"""The supernet containing the current network.
@ -1030,11 +996,10 @@ class _BaseNetwork(_IPAddressBase):
'current prefixlen is %d, cannot have a prefixlen_diff of %d' %
(self.prefixlen, prefixlen_diff))
# TODO (pmoody): optimize this.
t = ip_network('%s/%d' % (str(self.network_address),
self.prefixlen - prefixlen_diff),
version=self._version, strict=False)
return ip_network('%s/%d' % (str(t.network_address), t.prefixlen),
version=t._version)
t = self.__class__('%s/%d' % (str(self.network_address),
self.prefixlen - prefixlen_diff),
strict=False)
return t.__class__('%s/%d' % (str(t.network_address), t.prefixlen))
class _BaseV4(object):
@ -1391,6 +1356,9 @@ class IPv4Network(_BaseV4, _BaseNetwork):
.prefixlen: 27
"""
# Class to use when creating address objects
# TODO (ncoghlan): Investigate using IPv4Interface instead
_address_class = IPv4Address
# the valid octets for host and netmasks. only useful for IPv4.
_valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0))
@ -2071,6 +2039,10 @@ class IPv6Network(_BaseV6, _BaseNetwork):
"""
# Class to use when creating address objects
# TODO (ncoghlan): Investigate using IPv6Interface instead
_address_class = IPv6Address
def __init__(self, address, strict=True):
"""Instantiate a new IPv6 Network object.

View File

@ -780,12 +780,6 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(self.ipv4_address.version, 4)
self.assertEqual(self.ipv6_address.version, 6)
with self.assertRaises(ValueError):
ipaddress.ip_address('1', version=[])
with self.assertRaises(ValueError):
ipaddress.ip_address('1', version=5)
def testMaxPrefixLength(self):
self.assertEqual(self.ipv4_interface.max_prefixlen, 32)
self.assertEqual(self.ipv6_interface.max_prefixlen, 128)
@ -1052,12 +1046,7 @@ class IpaddrUnitTest(unittest.TestCase):
def testForceVersion(self):
self.assertEqual(ipaddress.ip_network(1).version, 4)
self.assertEqual(ipaddress.ip_network(1, version=6).version, 6)
with self.assertRaises(ValueError):
ipaddress.ip_network(1, version='l')
with self.assertRaises(ValueError):
ipaddress.ip_network(1, version=3)
self.assertEqual(ipaddress.IPv6Network(1).version, 6)
def testWithStar(self):
self.assertEqual(str(self.ipv4_interface.with_prefixlen), "1.2.3.4/24")
@ -1148,13 +1137,6 @@ class IpaddrUnitTest(unittest.TestCase):
sixtofouraddr.sixtofour)
self.assertFalse(bad_addr.sixtofour)
def testIpInterfaceVersion(self):
with self.assertRaises(ValueError):
ipaddress.ip_interface(1, version=123)
with self.assertRaises(ValueError):
ipaddress.ip_interface(1, version='')
if __name__ == '__main__':
unittest.main()