Collections module reformatting and minor code refactoring (GH-20772)

This commit is contained in:
Raymond Hettinger 2020-06-10 23:17:58 -07:00 committed by GitHub
parent 896f4cf63f
commit 31d17798d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 271 additions and 91 deletions

View File

@ -14,17 +14,30 @@ list, set, and tuple.
'''
__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList',
'UserString', 'Counter', 'OrderedDict', 'ChainMap']
__all__ = [
'ChainMap',
'Counter',
'OrderedDict',
'UserDict',
'UserList',
'UserString',
'defaultdict',
'deque',
'namedtuple',
]
import _collections_abc
from operator import itemgetter as _itemgetter, eq as _eq
from keyword import iskeyword as _iskeyword
import sys as _sys
import heapq as _heapq
from _weakref import proxy as _proxy
from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
import sys as _sys
from itertools import chain as _chain
from itertools import repeat as _repeat
from itertools import starmap as _starmap
from keyword import iskeyword as _iskeyword
from operator import eq as _eq
from operator import itemgetter as _itemgetter
from reprlib import recursive_repr as _recursive_repr
from _weakref import proxy as _proxy
try:
from _collections import deque
@ -54,6 +67,7 @@ def __getattr__(name):
return obj
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
################################################################################
### OrderedDict
################################################################################
@ -408,10 +422,13 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
# Create all the named tuple methods to be added to the class namespace
s = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
namespace = {'_tuple_new': tuple_new, '__builtins__': None,
'__name__': f'namedtuple_{typename}'}
__new__ = eval(s, namespace)
namespace = {
'_tuple_new': tuple_new,
'__builtins__': None,
'__name__': f'namedtuple_{typename}',
}
code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
__new__ = eval(code, namespace)
__new__.__name__ = '__new__'
__new__.__doc__ = f'Create new instance of {typename}({arg_list})'
if defaults is not None:
@ -449,8 +466,14 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
return _tuple(self)
# Modify function metadata to help with introspection and debugging
for method in (__new__, _make.__func__, _replace,
__repr__, _asdict, __getnewargs__):
for method in (
__new__,
_make.__func__,
_replace,
__repr__,
_asdict,
__getnewargs__,
):
method.__qualname__ = f'{typename}.{method.__name__}'
# Build-up the class namespace dictionary
@ -566,7 +589,7 @@ class Counter(dict):
>>> c = Counter(a=4, b=2) # a new counter from keyword args
'''
super(Counter, self).__init__()
super().__init__()
self.update(iterable, **kwds)
def __missing__(self, key):
@ -650,7 +673,8 @@ class Counter(dict):
for elem, count in iterable.items():
self[elem] = count + self_get(elem, 0)
else:
super(Counter, self).update(iterable) # fast path when counter is empty
# fast path when counter is empty
super().update(iterable)
else:
_count_elements(self, iterable)
if kwds:
@ -733,13 +757,14 @@ class Counter(dict):
def __repr__(self):
if not self:
return '%s()' % self.__class__.__name__
return f'{self.__class__.__name__}()'
try:
items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
return '%s({%s})' % (self.__class__.__name__, items)
# dict() preserves the ordering returned by most_common()
d = dict(self.most_common())
except TypeError:
# handle case where values are not orderable
return '{0}({1!r})'.format(self.__class__.__name__, dict(self))
d = dict(self)
return f'{self.__class__.__name__}({d!r})'
# Multiset-style mathematical operations discussed in:
# Knuth TAOCP Volume II section 4.6.3 exercise 19
@ -1018,7 +1043,7 @@ class ChainMap(_collections_abc.MutableMapping):
try:
del self.maps[0][key]
except KeyError:
raise KeyError('Key not found in the first mapping: {!r}'.format(key))
raise KeyError(f'Key not found in the first mapping: {key!r}')
def popitem(self):
'Remove and return an item pair from maps[0]. Raise KeyError if maps[0] is empty.'
@ -1032,30 +1057,30 @@ class ChainMap(_collections_abc.MutableMapping):
try:
return self.maps[0].pop(key, *args)
except KeyError:
raise KeyError('Key not found in the first mapping: {!r}'.format(key))
raise KeyError(f'Key not found in the first mapping: {key!r}')
def clear(self):
    """Empty maps[0] in place; the mappings in maps[1:] are left untouched."""
    first_mapping = self.maps[0]
    first_mapping.clear()
def __ior__(self, other):
self.maps[0] |= other
self.maps[0].update(other)
return self
def __or__(self, other):
if isinstance(other, _collections_abc.Mapping):
m = self.maps[0].copy()
m.update(other)
return self.__class__(m, *self.maps[1:])
return NotImplemented
if not isinstance(other, _collections_abc.Mapping):
return NotImplemented
m = self.copy()
m.maps[0].update(other)
return m
def __ror__(self, other):
if isinstance(other, _collections_abc.Mapping):
m = dict(other)
for child in reversed(self.maps):
m.update(child)
return self.__class__(m)
return NotImplemented
if not isinstance(other, _collections_abc.Mapping):
return NotImplemented
m = dict(other)
for child in reversed(self.maps):
m.update(child)
return self.__class__(m)
################################################################################
@ -1072,15 +1097,22 @@ class UserDict(_collections_abc.MutableMapping):
if kwargs:
self.update(kwargs)
def __len__(self): return len(self.data)
def __len__(self):
return len(self.data)
def __getitem__(self, key):
    """Return the value stored under *key* in ``self.data``.

    When the key is absent and the class defines ``__missing__``, the
    result of ``__missing__(self, key)`` is returned instead; otherwise
    ``KeyError(key)`` is raised.
    """
    if key in self.data:
        return self.data[key]
    cls = self.__class__
    if not hasattr(cls, "__missing__"):
        raise KeyError(key)
    # Looked up on the class, not the instance, and called with self explicitly.
    return cls.__missing__(self, key)
def __setitem__(self, key, item): self.data[key] = item
def __delitem__(self, key): del self.data[key]
def __setitem__(self, key, item):
self.data[key] = item
def __delitem__(self, key):
del self.data[key]
def __iter__(self):
    """Return an iterator over ``self.data``."""
    iterator = iter(self.data)
    return iterator
@ -1089,7 +1121,8 @@ class UserDict(_collections_abc.MutableMapping):
return key in self.data
# Now, add the methods in dicts but not in MutableMapping
def __repr__(self): return repr(self.data)
def __repr__(self):
return repr(self.data)
def __or__(self, other):
if isinstance(other, UserDict):
@ -1097,12 +1130,14 @@ class UserDict(_collections_abc.MutableMapping):
if isinstance(other, dict):
return self.__class__(self.data | other)
return NotImplemented
def __ror__(self, other):
    """Implement ``other | self`` for a UserDict or dict left operand.

    Returns a new instance of this class built from the merged mapping,
    or NotImplemented for any other operand type.
    """
    if isinstance(other, UserDict):
        left = other.data
    elif isinstance(other, dict):
        left = other
    else:
        return NotImplemented
    # Keys from self.data win because it is the right-hand side of the union.
    return self.__class__(left | self.data)
def __ior__(self, other):
if isinstance(other, UserDict):
self.data |= other.data
@ -1138,13 +1173,13 @@ class UserDict(_collections_abc.MutableMapping):
return d
################################################################################
### UserList
################################################################################
class UserList(_collections_abc.MutableSequence):
"""A more or less complete user-defined wrapper around list objects."""
def __init__(self, initlist=None):
self.data = []
if initlist is not None:
@ -1155,35 +1190,60 @@ class UserList(_collections_abc.MutableSequence):
self.data[:] = initlist.data[:]
else:
self.data = list(initlist)
def __repr__(self): return repr(self.data)
def __lt__(self, other): return self.data < self.__cast(other)
def __le__(self, other): return self.data <= self.__cast(other)
def __eq__(self, other): return self.data == self.__cast(other)
def __gt__(self, other): return self.data > self.__cast(other)
def __ge__(self, other): return self.data >= self.__cast(other)
def __repr__(self):
return repr(self.data)
def __lt__(self, other):
return self.data < self.__cast(other)
def __le__(self, other):
return self.data <= self.__cast(other)
def __eq__(self, other):
return self.data == self.__cast(other)
def __gt__(self, other):
return self.data > self.__cast(other)
def __ge__(self, other):
return self.data >= self.__cast(other)
def __cast(self, other):
    """Return ``other.data`` when *other* is a UserList, else *other* itself."""
    if isinstance(other, UserList):
        return other.data
    return other
def __contains__(self, item): return item in self.data
def __len__(self): return len(self.data)
def __contains__(self, item):
return item in self.data
def __len__(self):
return len(self.data)
def __getitem__(self, i):
    """Index or slice ``self.data``.

    A slice yields a new instance of this class wrapping the sliced data;
    any other index returns the element as-is.
    """
    selected = self.data[i]
    return self.__class__(selected) if isinstance(i, slice) else selected
def __setitem__(self, i, item): self.data[i] = item
def __delitem__(self, i): del self.data[i]
def __setitem__(self, i, item):
self.data[i] = item
def __delitem__(self, i):
del self.data[i]
def __add__(self, other):
    """Return a new instance of this class holding ``self.data`` followed by *other*.

    A UserList contributes its ``data``; an object of the same type as
    ``self.data`` is used directly; anything else is passed through ``list()``.
    """
    if isinstance(other, UserList):
        tail = other.data
    elif isinstance(other, type(self.data)):
        tail = other
    else:
        tail = list(other)
    return self.__class__(self.data + tail)
def __radd__(self, other):
    """Return a new instance of this class holding *other* followed by ``self.data``.

    A UserList contributes its ``data``; an object of the same type as
    ``self.data`` is used directly; anything else is passed through ``list()``.
    """
    if isinstance(other, UserList):
        head = other.data
    elif isinstance(other, type(self.data)):
        head = other
    else:
        head = list(other)
    return self.__class__(head + self.data)
def __iadd__(self, other):
if isinstance(other, UserList):
self.data += other.data
@ -1192,28 +1252,53 @@ class UserList(_collections_abc.MutableSequence):
else:
self.data += list(other)
return self
def __mul__(self, n):
return self.__class__(self.data*n)
return self.__class__(self.data * n)
__rmul__ = __mul__
def __imul__(self, n):
    """Implement ``self *= n``: apply ``*= n`` to ``self.data`` and return self."""
    self.data *= n
    return self
def __copy__(self):
    """Return a shallow copy of this object.

    The new object is created without calling ``__init__``, receives every
    entry of this instance's ``__dict__``, and gets its own sliced copy of
    ``data``.
    """
    cls = self.__class__
    clone = cls.__new__(cls)
    clone_state = clone.__dict__
    clone_state.update(self.__dict__)
    # Go through __dict__ directly so the copy does not trigger descriptors.
    clone_state["data"] = self.__dict__["data"][:]
    return clone
def append(self, item): self.data.append(item)
def insert(self, i, item): self.data.insert(i, item)
def pop(self, i=-1): return self.data.pop(i)
def remove(self, item): self.data.remove(item)
def clear(self): self.data.clear()
def copy(self): return self.__class__(self)
def count(self, item): return self.data.count(item)
def index(self, item, *args): return self.data.index(item, *args)
def reverse(self): self.data.reverse()
def sort(self, /, *args, **kwds): self.data.sort(*args, **kwds)
def append(self, item):
self.data.append(item)
def insert(self, i, item):
self.data.insert(i, item)
def pop(self, i=-1):
return self.data.pop(i)
def remove(self, item):
self.data.remove(item)
def clear(self):
self.data.clear()
def copy(self):
return self.__class__(self)
def count(self, item):
return self.data.count(item)
def index(self, item, *args):
return self.data.index(item, *args)
def reverse(self):
self.data.reverse()
def sort(self, /, *args, **kwds):
self.data.sort(*args, **kwds)
def extend(self, other):
if isinstance(other, UserList):
self.data.extend(other.data)
@ -1221,12 +1306,12 @@ class UserList(_collections_abc.MutableSequence):
self.data.extend(other)
################################################################################
### UserString
################################################################################
class UserString(_collections_abc.Sequence):
def __init__(self, seq):
if isinstance(seq, str):
self.data = seq
@ -1234,12 +1319,25 @@ class UserString(_collections_abc.Sequence):
self.data = seq.data[:]
else:
self.data = str(seq)
def __str__(self): return str(self.data)
def __repr__(self): return repr(self.data)
def __int__(self): return int(self.data)
def __float__(self): return float(self.data)
def __complex__(self): return complex(self.data)
def __hash__(self): return hash(self.data)
def __str__(self):
return str(self.data)
def __repr__(self):
return repr(self.data)
def __int__(self):
return int(self.data)
def __float__(self):
return float(self.data)
def __complex__(self):
return complex(self.data)
def __hash__(self):
return hash(self.data)
def __getnewargs__(self):
    """Return a one-element tuple containing a full slice of ``self.data``."""
    payload = self.data[:]
    return (payload,)
@ -1247,18 +1345,22 @@ class UserString(_collections_abc.Sequence):
if isinstance(string, UserString):
return self.data == string.data
return self.data == string
def __lt__(self, string):
    """Compare ``self.data < string``, unwrapping *string* when it is a UserString."""
    rhs = string.data if isinstance(string, UserString) else string
    return self.data < rhs
def __le__(self, string):
    """Compare ``self.data <= string``, unwrapping *string* when it is a UserString."""
    rhs = string.data if isinstance(string, UserString) else string
    return self.data <= rhs
def __gt__(self, string):
    """Compare ``self.data > string``, unwrapping *string* when it is a UserString."""
    rhs = string.data if isinstance(string, UserString) else string
    return self.data > rhs
def __ge__(self, string):
if isinstance(string, UserString):
return self.data >= string.data
@ -1269,110 +1371,188 @@ class UserString(_collections_abc.Sequence):
char = char.data
return char in self.data
def __len__(self): return len(self.data)
def __getitem__(self, index): return self.__class__(self.data[index])
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.__class__(self.data[index])
def __add__(self, other):
    """Return a new instance of this class holding ``self.data`` followed by *other*.

    A UserString contributes its ``data``; a str is used directly; any
    other object is converted with ``str()`` first.
    """
    if isinstance(other, UserString):
        suffix = other.data
    elif isinstance(other, str):
        suffix = other
    else:
        suffix = str(other)
    return self.__class__(self.data + suffix)
def __radd__(self, other):
    """Return a new instance of this class holding *other* followed by ``self.data``.

    A str is used directly; any other object is converted with ``str()``.
    """
    prefix = other if isinstance(other, str) else str(other)
    return self.__class__(prefix + self.data)
def __mul__(self, n):
return self.__class__(self.data*n)
return self.__class__(self.data * n)
__rmul__ = __mul__
def __mod__(self, args):
    """Apply ``self.data % args`` and wrap the result in this class."""
    formatted = self.data % args
    return self.__class__(formatted)
def __rmod__(self, template):
    """Apply ``str(template) % self`` and wrap the result in this class."""
    pattern = str(template)
    return self.__class__(pattern % self)
# the following methods are defined in alphabetical order:
def capitalize(self): return self.__class__(self.data.capitalize())
def capitalize(self):
return self.__class__(self.data.capitalize())
def casefold(self):
    """Return a new instance of this class wrapping ``self.data.casefold()``."""
    folded = self.data.casefold()
    return self.__class__(folded)
def center(self, width, *args):
    """Return a new instance of this class wrapping ``self.data.center(width, *args)``."""
    padded = self.data.center(width, *args)
    return self.__class__(padded)
def count(self, sub, start=0, end=_sys.maxsize):
    """Return ``self.data.count(sub, start, end)``, unwrapping a UserString *sub*."""
    needle = sub.data if isinstance(sub, UserString) else sub
    return self.data.count(needle, start, end)
def removeprefix(self, prefix, /):
    """Return a new instance of this class wrapping ``self.data.removeprefix(prefix)``.

    A UserString *prefix* is unwrapped to its ``data`` first.
    """
    plain_prefix = prefix.data if isinstance(prefix, UserString) else prefix
    return self.__class__(self.data.removeprefix(plain_prefix))
def removesuffix(self, suffix, /):
    """Return a new instance of this class wrapping ``self.data.removesuffix(suffix)``.

    A UserString *suffix* is unwrapped to its ``data`` first.
    """
    plain_suffix = suffix.data if isinstance(suffix, UserString) else suffix
    return self.__class__(self.data.removesuffix(plain_suffix))
def encode(self, encoding='utf-8', errors='strict'):
    """Return ``self.data.encode(encoding, errors)``.

    Passing None for either argument selects its default
    ('utf-8' and 'strict' respectively).
    """
    if encoding is None:
        encoding = 'utf-8'
    if errors is None:
        errors = 'strict'
    return self.data.encode(encoding, errors)
def endswith(self, suffix, start=0, end=_sys.maxsize):
    """Return ``self.data.endswith(suffix, start, end)``.

    A UserString *suffix* is unwrapped to its ``data`` first, matching
    removesuffix(); previously such an argument was handed to the
    underlying call unchanged and raised TypeError.
    """
    if isinstance(suffix, UserString):
        suffix = suffix.data
    return self.data.endswith(suffix, start, end)
def expandtabs(self, tabsize=8):
    """Return a new instance of this class wrapping ``self.data.expandtabs(tabsize)``."""
    expanded = self.data.expandtabs(tabsize)
    return self.__class__(expanded)
def find(self, sub, start=0, end=_sys.maxsize):
    """Return ``self.data.find(sub, start, end)``, unwrapping a UserString *sub*."""
    needle = sub.data if isinstance(sub, UserString) else sub
    return self.data.find(needle, start, end)
def format(self, /, *args, **kwds):
    """Return ``self.data.format(*args, **kwds)``; the result is not re-wrapped."""
    rendered = self.data.format(*args, **kwds)
    return rendered
def format_map(self, mapping):
    """Return ``self.data.format_map(mapping)``; the result is not re-wrapped."""
    rendered = self.data.format_map(mapping)
    return rendered
def index(self, sub, start=0, end=_sys.maxsize):
    """Return ``self.data.index(sub, start, end)``.

    A UserString *sub* is unwrapped to its ``data`` first, matching find()
    and count(); previously such an argument was handed to the underlying
    call unchanged and raised TypeError.
    """
    if isinstance(sub, UserString):
        sub = sub.data
    return self.data.index(sub, start, end)
def isalpha(self): return self.data.isalpha()
def isalnum(self): return self.data.isalnum()
def isascii(self): return self.data.isascii()
def isdecimal(self): return self.data.isdecimal()
def isdigit(self): return self.data.isdigit()
def isidentifier(self): return self.data.isidentifier()
def islower(self): return self.data.islower()
def isnumeric(self): return self.data.isnumeric()
def isprintable(self): return self.data.isprintable()
def isspace(self): return self.data.isspace()
def istitle(self): return self.data.istitle()
def isupper(self): return self.data.isupper()
def join(self, seq): return self.data.join(seq)
def isalpha(self):
return self.data.isalpha()
def isalnum(self):
return self.data.isalnum()
def isascii(self):
return self.data.isascii()
def isdecimal(self):
return self.data.isdecimal()
def isdigit(self):
return self.data.isdigit()
def isidentifier(self):
return self.data.isidentifier()
def islower(self):
return self.data.islower()
def isnumeric(self):
return self.data.isnumeric()
def isprintable(self):
return self.data.isprintable()
def isspace(self):
return self.data.isspace()
def istitle(self):
return self.data.istitle()
def isupper(self):
return self.data.isupper()
def join(self, seq):
return self.data.join(seq)
def ljust(self, width, *args):
    """Return a new instance of this class wrapping ``self.data.ljust(width, *args)``."""
    padded = self.data.ljust(width, *args)
    return self.__class__(padded)
def lower(self): return self.__class__(self.data.lower())
def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars))
def lower(self):
return self.__class__(self.data.lower())
def lstrip(self, chars=None):
return self.__class__(self.data.lstrip(chars))
maketrans = str.maketrans
def partition(self, sep):
    """Return ``self.data.partition(sep)``; the result is not re-wrapped."""
    pieces = self.data.partition(sep)
    return pieces
def replace(self, old, new, maxsplit=-1):
    """Return a new instance of this class wrapping ``self.data.replace(old, new, maxsplit)``.

    UserString values for *old* and *new* are unwrapped to their ``data``
    first. *maxsplit* is forwarded as the third argument of the
    underlying ``replace`` call.
    """
    target = old.data if isinstance(old, UserString) else old
    substitute = new.data if isinstance(new, UserString) else new
    return self.__class__(self.data.replace(target, substitute, maxsplit))
def rfind(self, sub, start=0, end=_sys.maxsize):
    """Return ``self.data.rfind(sub, start, end)``, unwrapping a UserString *sub*."""
    needle = sub.data if isinstance(sub, UserString) else sub
    return self.data.rfind(needle, start, end)
def rindex(self, sub, start=0, end=_sys.maxsize):
    """Return ``self.data.rindex(sub, start, end)``.

    A UserString *sub* is unwrapped to its ``data`` first, matching
    rfind(); previously such an argument was handed to the underlying
    call unchanged and raised TypeError.
    """
    if isinstance(sub, UserString):
        sub = sub.data
    return self.data.rindex(sub, start, end)
def rjust(self, width, *args):
    """Return a new instance of this class wrapping ``self.data.rjust(width, *args)``."""
    padded = self.data.rjust(width, *args)
    return self.__class__(padded)
def rpartition(self, sep):
    """Return ``self.data.rpartition(sep)``; the result is not re-wrapped."""
    pieces = self.data.rpartition(sep)
    return pieces
def rstrip(self, chars=None):
    """Return a new instance of this class wrapping ``self.data.rstrip(chars)``."""
    trimmed = self.data.rstrip(chars)
    return self.__class__(trimmed)
def split(self, sep=None, maxsplit=-1):
    """Return ``self.data.split(sep, maxsplit)``; the pieces are not re-wrapped."""
    pieces = self.data.split(sep, maxsplit)
    return pieces
def rsplit(self, sep=None, maxsplit=-1):
    """Return ``self.data.rsplit(sep, maxsplit)``; the pieces are not re-wrapped."""
    pieces = self.data.rsplit(sep, maxsplit)
    return pieces
def splitlines(self, keepends=False): return self.data.splitlines(keepends)
def splitlines(self, keepends=False):
return self.data.splitlines(keepends)
def startswith(self, prefix, start=0, end=_sys.maxsize):
    """Return ``self.data.startswith(prefix, start, end)``.

    A UserString *prefix* is unwrapped to its ``data`` first, matching
    removeprefix(); previously such an argument was handed to the
    underlying call unchanged and raised TypeError.
    """
    if isinstance(prefix, UserString):
        prefix = prefix.data
    return self.data.startswith(prefix, start, end)
def strip(self, chars=None): return self.__class__(self.data.strip(chars))
def swapcase(self): return self.__class__(self.data.swapcase())
def title(self): return self.__class__(self.data.title())
def strip(self, chars=None):
return self.__class__(self.data.strip(chars))
def swapcase(self):
return self.__class__(self.data.swapcase())
def title(self):
return self.__class__(self.data.title())
def translate(self, *args):
    """Return a new instance of this class wrapping ``self.data.translate(*args)``."""
    mapped = self.data.translate(*args)
    return self.__class__(mapped)
def upper(self): return self.__class__(self.data.upper())
def zfill(self, width): return self.__class__(self.data.zfill(width))
def upper(self):
return self.__class__(self.data.upper())
def zfill(self, width):
return self.__class__(self.data.zfill(width))