Switch the AF_* and SOCK_* constants in the socket module to IntEnum.
Closes #18720.
This commit is contained in:
parent
7e7cf8bc51
commit
b2ff3cf0e9
|
@ -48,6 +48,7 @@ import _socket
|
|||
from _socket import *
|
||||
|
||||
import os, sys, io
|
||||
from enum import IntEnum
|
||||
|
||||
try:
|
||||
import errno
|
||||
|
@ -60,6 +61,30 @@ EWOULDBLOCK = getattr(errno, 'EWOULDBLOCK', 11)
|
|||
__all__ = ["getfqdn", "create_connection"]
|
||||
__all__.extend(os._get_exports_list(_socket))
|
||||
|
||||
# Set up the socket.AF_* socket.SOCK_* constants as members of IntEnums for
|
||||
# nicer string representations.
|
||||
# Note that _socket only knows about the integer values. The public interface
|
||||
# in this module understands the enums and translates them back from integers
|
||||
# where needed (e.g. .family property of a socket object).
|
||||
AddressFamily = IntEnum('AddressFamily',
|
||||
{name: value for name, value in globals().items()
|
||||
if name.isupper() and name.startswith('AF_')})
|
||||
globals().update(AddressFamily.__members__)
|
||||
|
||||
SocketType = IntEnum('SocketType',
|
||||
{name: value for name, value in globals().items()
|
||||
if name.isupper() and name.startswith('SOCK_')})
|
||||
globals().update(SocketType.__members__)
|
||||
|
||||
def _intenum_converter(value, enum_klass):
|
||||
"""Convert a numeric family value to an IntEnum member.
|
||||
|
||||
If it's not a known member, return the numeric value itself.
|
||||
"""
|
||||
try:
|
||||
return enum_klass(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
_realsocket = socket
|
||||
|
||||
|
@ -91,6 +116,10 @@ class socket(_socket.socket):
|
|||
__slots__ = ["__weakref__", "_io_refs", "_closed"]
|
||||
|
||||
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
|
||||
# For user code address family and type values are IntEnum members, but
|
||||
# for the underlying _socket.socket they're just integers. The
|
||||
# constructor of _socket.socket converts the given argument to an
|
||||
# integer automatically.
|
||||
_socket.socket.__init__(self, family, type, proto, fileno)
|
||||
self._io_refs = 0
|
||||
self._closed = False
|
||||
|
@ -230,6 +259,18 @@ class socket(_socket.socket):
|
|||
self._closed = True
|
||||
return super().detach()
|
||||
|
||||
@property
|
||||
def family(self):
|
||||
"""Read-only access to the address family for this socket.
|
||||
"""
|
||||
return _intenum_converter(super().family, AddressFamily)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Read-only access to the socket type.
|
||||
"""
|
||||
return _intenum_converter(super().type, SocketType)
|
||||
|
||||
if os.name == 'nt':
|
||||
def get_inheritable(self):
|
||||
return os.get_handle_inheritable(self.fileno())
|
||||
|
@ -243,7 +284,6 @@ class socket(_socket.socket):
|
|||
get_inheritable.__doc__ = "Get the inheritable flag of the socket"
|
||||
set_inheritable.__doc__ = "Set the inheritable flag of the socket"
|
||||
|
||||
|
||||
def fromfd(fd, family, type, proto=0):
|
||||
""" fromfd(fd, family, type[, proto]) -> socket object
|
||||
|
||||
|
@ -469,3 +509,27 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
|
|||
raise err
|
||||
else:
|
||||
raise error("getaddrinfo returns an empty list")
|
||||
|
||||
def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
|
||||
"""Resolve host and port into list of address info entries.
|
||||
|
||||
Translate the host/port argument into a sequence of 5-tuples that contain
|
||||
all the necessary arguments for creating a socket connected to that service.
|
||||
host is a domain name, a string representation of an IPv4/v6 address or
|
||||
None. port is a string service name such as 'http', a numeric port number or
|
||||
None. By passing None as the value of host and port, you can pass NULL to
|
||||
the underlying C API.
|
||||
|
||||
The family, type and proto arguments can be optionally specified in order to
|
||||
narrow the list of addresses returned. Passing zero as a value for each of
|
||||
these arguments selects the full range of results.
|
||||
"""
|
||||
# We override this function since we want to translate the numeric family
|
||||
# and socket type values to enum constants.
|
||||
addrlist = []
|
||||
for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
|
||||
af, socktype, proto, canonname, sa = res
|
||||
addrlist.append((_intenum_converter(af, AddressFamily),
|
||||
_intenum_converter(socktype, SocketType),
|
||||
proto, canonname, sa))
|
||||
return addrlist
|
||||
|
|
|
@ -1161,9 +1161,12 @@ class GeneralModuleTests(unittest.TestCase):
|
|||
socket.getaddrinfo(HOST, 80)
|
||||
socket.getaddrinfo(HOST, None)
|
||||
# test family and socktype filters
|
||||
infos = socket.getaddrinfo(HOST, None, socket.AF_INET)
|
||||
for family, _, _, _, _ in infos:
|
||||
infos = socket.getaddrinfo(HOST, 80, socket.AF_INET, socket.SOCK_STREAM)
|
||||
for family, type, _, _, _ in infos:
|
||||
self.assertEqual(family, socket.AF_INET)
|
||||
self.assertEqual(str(family), 'AddressFamily.AF_INET')
|
||||
self.assertEqual(type, socket.SOCK_STREAM)
|
||||
self.assertEqual(str(type), 'SocketType.SOCK_STREAM')
|
||||
infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
|
||||
for _, socktype, _, _, _ in infos:
|
||||
self.assertEqual(socktype, socket.SOCK_STREAM)
|
||||
|
@ -1321,6 +1324,27 @@ class GeneralModuleTests(unittest.TestCase):
|
|||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
self.assertRaises(OverflowError, s.bind, (support.HOSTv6, 0, -10))
|
||||
|
||||
def test_str_for_enums(self):
|
||||
# Make sure that the AF_* and SOCK_* constants have enum-like string
|
||||
# reprs.
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
self.assertEqual(str(s.family), 'AddressFamily.AF_INET')
|
||||
self.assertEqual(str(s.type), 'SocketType.SOCK_STREAM')
|
||||
|
||||
@unittest.skipIf(os.name == 'nt', 'Will not work on Windows')
|
||||
def test_uknown_socket_family_repr(self):
|
||||
# Test that when created with a family that's not one of the known
|
||||
# AF_*/SOCK_* constants, socket.family just returns the number.
|
||||
#
|
||||
# To do this we fool socket.socket into believing it already has an
|
||||
# open fd because on this path it doesn't actually verify the family and
|
||||
# type and populates the socket object.
|
||||
#
|
||||
# On Windows this trick won't work, so the test is skipped.
|
||||
fd, _ = tempfile.mkstemp()
|
||||
with socket.socket(family=42424, type=13331, fileno=fd) as s:
|
||||
self.assertEqual(s.family, 42424)
|
||||
self.assertEqual(s.type, 13331)
|
||||
|
||||
@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
|
||||
class BasicCANTest(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue