transform izip_longest #11424

This commit is contained in:
Benjamin Peterson 2011-03-07 22:50:37 -06:00
parent 923e6d29e5
commit 1af19d1ffd
3 changed files with 32 additions and 16 deletions

View File

@ -13,7 +13,7 @@ from ..fixer_util import Name
class FixItertools(fixer_base.BaseFix): class FixItertools(fixer_base.BaseFix):
BM_compatible = True BM_compatible = True
it_funcs = "('imap'|'ifilter'|'izip'|'ifilterfalse')" it_funcs = "('imap'|'ifilter'|'izip'|'izip_longest'|'ifilterfalse')"
PATTERN = """ PATTERN = """
power< it='itertools' power< it='itertools'
trailer< trailer<
@ -28,7 +28,8 @@ class FixItertools(fixer_base.BaseFix):
def transform(self, node, results): def transform(self, node, results):
prefix = None prefix = None
func = results['func'][0] func = results['func'][0]
if 'it' in results and func.value != u'ifilterfalse': if ('it' in results and
func.value not in (u'ifilterfalse', u'izip_longest')):
dot, it = (results['dot'], results['it']) dot, it = (results['dot'], results['it'])
# Remove the 'itertools' # Remove the 'itertools'
prefix = it.prefix prefix = it.prefix

View File

@ -31,9 +31,10 @@ class FixItertoolsImports(fixer_base.BaseFix):
if member_name in (u'imap', u'izip', u'ifilter'): if member_name in (u'imap', u'izip', u'ifilter'):
child.value = None child.value = None
child.remove() child.remove()
elif member_name == u'ifilterfalse': elif member_name in (u'ifilterfalse', u'izip_longest'):
node.changed() node.changed()
name_node.value = u'filterfalse' name_node.value = (u'filterfalse' if member_name[1] == u'f'
else u'zip_longest')
# Make sure the import statement is still sane # Make sure the import statement is still sane
children = imports.children[:] or [imports] children = imports.children[:] or [imports]

View File

@ -3623,16 +3623,24 @@ class Test_itertools(FixerTestCase):
a = """%s(f, a)""" a = """%s(f, a)"""
self.checkall(b, a) self.checkall(b, a)
def test_2(self): def test_qualified(self):
b = """itertools.ifilterfalse(a, b)""" b = """itertools.ifilterfalse(a, b)"""
a = """itertools.filterfalse(a, b)""" a = """itertools.filterfalse(a, b)"""
self.check(b, a) self.check(b, a)
def test_4(self): b = """itertools.izip_longest(a, b)"""
a = """itertools.zip_longest(a, b)"""
self.check(b, a)
def test_2(self):
b = """ifilterfalse(a, b)""" b = """ifilterfalse(a, b)"""
a = """filterfalse(a, b)""" a = """filterfalse(a, b)"""
self.check(b, a) self.check(b, a)
b = """izip_longest(a, b)"""
a = """zip_longest(a, b)"""
self.check(b, a)
def test_space_1(self): def test_space_1(self):
b = """ %s(f, a)""" b = """ %s(f, a)"""
a = """ %s(f, a)""" a = """ %s(f, a)"""
@ -3643,9 +3651,14 @@ class Test_itertools(FixerTestCase):
a = """ itertools.filterfalse(a, b)""" a = """ itertools.filterfalse(a, b)"""
self.check(b, a) self.check(b, a)
b = """ itertools.izip_longest(a, b)"""
a = """ itertools.zip_longest(a, b)"""
self.check(b, a)
def test_run_order(self): def test_run_order(self):
self.assert_runs_after('map', 'zip', 'filter') self.assert_runs_after('map', 'zip', 'filter')
class Test_itertools_imports(FixerTestCase): class Test_itertools_imports(FixerTestCase):
fixer = 'itertools_imports' fixer = 'itertools_imports'
@ -3696,17 +3709,18 @@ class Test_itertools_imports(FixerTestCase):
s = "from itertools import bar as bang" s = "from itertools import bar as bang"
self.unchanged(s) self.unchanged(s)
def test_ifilter(self): def test_ifilter_and_zip_longest(self):
b = "from itertools import ifilterfalse" for name in "filterfalse", "zip_longest":
a = "from itertools import filterfalse" b = "from itertools import i%s" % (name,)
a = "from itertools import %s" % (name,)
self.check(b, a) self.check(b, a)
b = "from itertools import imap, ifilterfalse, foo" b = "from itertools import imap, i%s, foo" % (name,)
a = "from itertools import filterfalse, foo" a = "from itertools import %s, foo" % (name,)
self.check(b, a) self.check(b, a)
b = "from itertools import bar, ifilterfalse, foo" b = "from itertools import bar, i%s, foo" % (name,)
a = "from itertools import bar, filterfalse, foo" a = "from itertools import bar, %s, foo" % (name,)
self.check(b, a) self.check(b, a)
def test_import_star(self): def test_import_star(self):