diff --git a/Lib/lib2to3/fixer_base.py b/Lib/lib2to3/fixer_base.py index 34d07280596..4b536a91432 100644 --- a/Lib/lib2to3/fixer_base.py +++ b/Lib/lib2to3/fixer_base.py @@ -94,10 +94,6 @@ class BaseFix(object): """ raise NotImplementedError() - def parenthesize(self, node): - """Wrapper around pygram.parenthesize().""" - return pygram.parenthesize(node) - def new_name(self, template="xxx_todo_changeme"): """Return a string suitable for use as an identifier diff --git a/Lib/lib2to3/fixer_util.py b/Lib/lib2to3/fixer_util.py index ea394e82dad..0c485f08d32 100644 --- a/Lib/lib2to3/fixer_util.py +++ b/Lib/lib2to3/fixer_util.py @@ -158,6 +158,9 @@ def is_list(node): ### Misc ########################################################### +def parenthesize(node): + return Node(syms.atom, [LParen(), node, RParen()]) + consuming_calls = set(["sorted", "list", "set", "any", "all", "tuple", "sum", "min", "max"]) @@ -232,20 +235,77 @@ def make_suite(node): suite.parent = parent return suite -def does_tree_import(package, name, node): - """ Returns true if name is imported from package at the - top level of the tree which node belongs to. - To cover the case of an import like 'import foo', use - Null for the package and 'foo' for the name. """ +def find_root(node): + """Find the top level namespace.""" # Scamper up to the top level namespace while node.type != syms.file_input: assert node.parent, "Tree is insane! root found before "\ "file_input node was found." node = node.parent + return node - binding = find_binding(name, node, package) +def does_tree_import(package, name, node): + """ Returns true if name is imported from package at the + top level of the tree which node belongs to. + To cover the case of an import like 'import foo', use + None for the package and 'foo' for the name. """ + binding = find_binding(name, find_root(node), package) return bool(binding) +def is_import(node): + """Returns true if the node is an import statement.""" + return node.type in (syms.import_name, syms.import_from) + +def touch_import(package, name, node): + """ Works like `does_tree_import` but adds an import statement + if it was not imported. """ + def is_import_stmt(node): + return node.type == syms.simple_stmt and node.children and \ + is_import(node.children[0]) + + root = find_root(node) + + if does_tree_import(package, name, root): + return + + add_newline_before = False + + # figure out where to insert the new import. First try to find + # the first import and then skip to the last one. + insert_pos = offset = 0 + for idx, node in enumerate(root.children): + if not is_import_stmt(node): + continue + for offset, node2 in enumerate(root.children[idx:]): + if not is_import_stmt(node2): + break + insert_pos = idx + offset + break + + # if there are no imports where we can insert, find the docstring. + # if that also fails, we stick to the beginning of the file + if insert_pos == 0: + for idx, node in enumerate(root.children): + if node.type == syms.simple_stmt and node.children and \ + node.children[0].type == token.STRING: + insert_pos = idx + 1 + add_newline_before + break + + if package is None: + import_ = Node(syms.import_name, [ + Leaf(token.NAME, 'import'), + Leaf(token.NAME, name, prefix=' ') + ]) + else: + import_ = FromImport(package, [Leaf(token.NAME, name, prefix=' ')]) + + children = [import_, Newline()] + if add_newline_before: + children.insert(0, Newline()) + root.insert_child(insert_pos, Node(syms.simple_stmt, children)) + + _def_syms = set([syms.classdef, syms.funcdef]) def find_binding(name, node, package=None): """ Returns the node which binds variable name, otherwise None. @@ -285,7 +345,7 @@ def find_binding(name, node, package=None): if ret: if not package: return ret - if ret.type in (syms.import_name, syms.import_from): + if is_import(ret): return ret return None diff --git a/Lib/lib2to3/fixes/fix_apply.py b/Lib/lib2to3/fixes/fix_apply.py index faede685c3f..5af13b716d8 100644 --- a/Lib/lib2to3/fixes/fix_apply.py +++ b/Lib/lib2to3/fixes/fix_apply.py @@ -9,7 +9,7 @@ This converts apply(func, v, k) into (func)(*v, **k).""" from .. import pytree from ..pgen2 import token from .. import fixer_base -from ..fixer_util import Call, Comma +from ..fixer_util import Call, Comma, parenthesize class FixApply(fixer_base.BaseFix): @@ -39,7 +39,7 @@ class FixApply(fixer_base.BaseFix): (func.type != syms.power or func.children[-2].type == token.DOUBLESTAR)): # Need to parenthesize - func = self.parenthesize(func) + func = parenthesize(func) func.set_prefix("") args = args.clone() args.set_prefix("") diff --git a/Lib/lib2to3/fixes/fix_has_key.py b/Lib/lib2to3/fixes/fix_has_key.py index fb7b07b3611..482f27da18b 100644 --- a/Lib/lib2to3/fixes/fix_has_key.py +++ b/Lib/lib2to3/fixes/fix_has_key.py @@ -33,7 +33,7 @@ CAVEATS: from .. import pytree from ..pgen2 import token from .. import fixer_base -from ..fixer_util import Name +from ..fixer_util import Name, parenthesize class FixHasKey(fixer_base.BaseFix): @@ -86,7 +86,7 @@ class FixHasKey(fixer_base.BaseFix): after = [n.clone() for n in after] if arg.type in (syms.comparison, syms.not_test, syms.and_test, syms.or_test, syms.test, syms.lambdef, syms.argument): - arg = self.parenthesize(arg) + arg = parenthesize(arg) if len(before) == 1: before = before[0] else: @@ -98,12 +98,12 @@ class FixHasKey(fixer_base.BaseFix): n_op = pytree.Node(syms.comp_op, (n_not, n_op)) new = pytree.Node(syms.comparison, (arg, n_op, before)) if after: - new = self.parenthesize(new) + new = parenthesize(new) new = pytree.Node(syms.power, (new,) + tuple(after)) if node.parent.type in (syms.comparison, syms.expr, syms.xor_expr, syms.and_expr, syms.shift_expr, syms.arith_expr, syms.term, syms.factor, syms.power): - new = self.parenthesize(new) + new = parenthesize(new) new.set_prefix(prefix) return new diff --git a/Lib/lib2to3/fixes/fix_imports.py b/Lib/lib2to3/fixes/fix_imports.py index e48c4f0904d..75770c915f9 100644 --- a/Lib/lib2to3/fixes/fix_imports.py +++ b/Lib/lib2to3/fixes/fix_imports.py @@ -42,6 +42,8 @@ MAPPING = {'StringIO': 'io', 'DocXMLRPCServer': 'xmlrpc.server', 'SimpleXMLRPCServer': 'xmlrpc.server', 'httplib': 'http.client', + 'htmlentitydefs' : 'html.entities', + 'HTMLParser' : 'html.parser', 'Cookie': 'http.cookies', 'cookielib': 'http.cookiejar', 'BaseHTTPServer': 'http.server', @@ -64,16 +66,17 @@ def build_pattern(mapping=MAPPING): mod_list = ' | '.join(["module_name='%s'" % key for key in mapping]) bare_names = alternates(mapping.keys()) - yield """name_import=import_name< 'import' ((%s) - | dotted_as_names< any* (%s) any* >) > + yield """name_import=import_name< 'import' ((%s) | + multiple_imports=dotted_as_names< any* (%s) any* >) > """ % (mod_list, mod_list) yield """import_from< 'from' (%s) 'import' ['('] ( any | import_as_name< any 'as' any > | import_as_names< any* >) [')'] > """ % mod_list - yield """import_name< 'import' - dotted_as_name< (%s) 'as' any > > - """ % mod_list + yield """import_name< 'import' (dotted_as_name< (%s) 'as' any > | + multiple_imports=dotted_as_names< + any* dotted_as_name< (%s) 'as' any > any* >) > + """ % (mod_list, mod_list) # Find usages of module members in code e.g. thread.foo(bar) yield "power< bare_with_attr=(%s) trailer<'.' any > any* >" % bare_names @@ -100,8 +103,8 @@ class FixImports(fixer_base.BaseFix): match = super(FixImports, self).match results = match(node) if results: - # Module usage could be in the trailier of an attribute lookup, so - # we might have nested matches when "bare_with_attr" is present. + # Module usage could be in the trailer of an attribute lookup, so we + # might have nested matches when "bare_with_attr" is present. if "bare_with_attr" not in results and \ any([match(obj) for obj in attr_chain(node, "parent")]): return False @@ -116,11 +119,21 @@ class FixImports(fixer_base.BaseFix): import_mod = results.get("module_name") if import_mod: new_name = self.mapping[(import_mod or mod_name).value] + import_mod.replace(Name(new_name, prefix=import_mod.get_prefix())) if "name_import" in results: # If it's not a "from x import x, y" or "import x as y" import, # marked its usage to be replaced. self.replace[import_mod.value] = new_name - import_mod.replace(Name(new_name, prefix=import_mod.get_prefix())) + if "multiple_imports" in results: + # This is a nasty hack to fix multiple imports on a + # line (e.g., "import StringIO, urlparse"). The problem is that I + # can't figure out an easy way to make a pattern recognize the + # keys of MAPPING randomly sprinkled in an import statement. + while True: + results = self.match(node) + if not results: + break + self.transform(node, results) else: # Replace usage of the module. bare_name = results["bare_with_attr"][0] diff --git a/Lib/lib2to3/fixes/fix_imports2.py b/Lib/lib2to3/fixes/fix_imports2.py index bcd7aa6db24..477bafd4a3a 100644 --- a/Lib/lib2to3/fixes/fix_imports2.py +++ b/Lib/lib2to3/fixes/fix_imports2.py @@ -11,6 +11,6 @@ MAPPING = { class FixImports2(fix_imports.FixImports): - order = "post" + run_order = 6 mapping = MAPPING diff --git a/Lib/lib2to3/fixes/fix_intern.py b/Lib/lib2to3/fixes/fix_intern.py index 921ba597818..66c616e666c 100644 --- a/Lib/lib2to3/fixes/fix_intern.py +++ b/Lib/lib2to3/fixes/fix_intern.py @@ -8,7 +8,7 @@ intern(s) -> sys.intern(s)""" # Local imports from .. import pytree from .. import fixer_base -from ..fixer_util import Name, Attr +from ..fixer_util import Name, Attr, touch_import class FixIntern(fixer_base.BaseFix): @@ -40,4 +40,5 @@ class FixIntern(fixer_base.BaseFix): newarglist, results["rpar"].clone()])] + after) new.set_prefix(node.get_prefix()) + touch_import(None, 'sys', node) return new diff --git a/Lib/lib2to3/fixes/fix_isinstance.py b/Lib/lib2to3/fixes/fix_isinstance.py new file mode 100644 index 00000000000..295577ae2b2 --- /dev/null +++ b/Lib/lib2to3/fixes/fix_isinstance.py @@ -0,0 +1,52 @@ +# Copyright 2008 Armin Ronacher. +# Licensed to PSF under a Contributor Agreement. + +"""Fixer that cleans up a tuple argument to isinstance after the tokens +in it were fixed. This is mainly used to remove double occurrences of +tokens as a leftover of the long -> int / unicode -> str conversion. + +eg. isinstance(x, (int, long)) -> isinstance(x, (int, int)) + -> isinstance(x, int) +""" + +from .. import fixer_base +from ..fixer_util import token + + +class FixIsinstance(fixer_base.BaseFix): + + PATTERN = """ + power< + 'isinstance' + trailer< '(' arglist< any ',' atom< '(' + args=testlist_gexp< any+ > + ')' > > ')' > + > + """ + + run_order = 6 + + def transform(self, node, results): + names_inserted = set() + testlist = results["args"] + args = testlist.children + new_args = [] + iterator = enumerate(args) + for idx, arg in iterator: + if arg.type == token.NAME and arg.value in names_inserted: + if idx < len(args) - 1 and args[idx + 1].type == token.COMMA: + iterator.next() + continue + else: + new_args.append(arg) + if arg.type == token.NAME: + names_inserted.add(arg.value) + if new_args and new_args[-1].type == token.COMMA: + del new_args[-1] + if len(new_args) == 1: + atom = testlist.parent + new_args[0].set_prefix(atom.get_prefix()) + atom.replace(new_args[0]) + else: + args[:] = new_args + node.changed() diff --git a/Lib/lib2to3/fixes/fix_long.py b/Lib/lib2to3/fixes/fix_long.py index f67f0261c12..5fd6af55145 100644 --- a/Lib/lib2to3/fixes/fix_long.py +++ b/Lib/lib2to3/fixes/fix_long.py @@ -2,8 +2,6 @@ # Licensed to PSF under a Contributor Agreement. """Fixer that turns 'long' into 'int' everywhere. - -This also strips the trailing 'L' or 'l' from long loterals. """ # Local imports @@ -14,22 +12,13 @@ from ..fixer_util import Name, Number class FixLong(fixer_base.BaseFix): - PATTERN = """ - (long_type = 'long' | number = NUMBER) - """ + PATTERN = "'long'" static_long = Name("long") static_int = Name("int") def transform(self, node, results): - long_type = results.get("long_type") - number = results.get("number") - new = None - if long_type: - assert node == self.static_long, node - new = self.static_int.clone() - if number and node.value[-1] in ("l", "L"): - new = Number(node.value[:-1]) - if new is not None: - new.set_prefix(node.get_prefix()) - return new + assert node == self.static_long, node + new = self.static_int.clone() + new.set_prefix(node.get_prefix()) + return new diff --git a/Lib/lib2to3/fixes/fix_reduce.py b/Lib/lib2to3/fixes/fix_reduce.py new file mode 100644 index 00000000000..89fa2b431e9 --- /dev/null +++ b/Lib/lib2to3/fixes/fix_reduce.py @@ -0,0 +1,33 @@ +# Copyright 2008 Armin Ronacher. +# Licensed to PSF under a Contributor Agreement. + +"""Fixer for reduce(). + +Makes sure reduce() is imported from the functools module if reduce is +used in that module. +""" + +from .. import pytree +from .. import fixer_base +from ..fixer_util import Name, Attr, touch_import + + + +class FixReduce(fixer_base.BaseFix): + + PATTERN = """ + power< 'reduce' + trailer< '(' + arglist< ( + (not(argument) any ',' + not(argument + > + """ + + def transform(self, node, results): + touch_import('functools', 'reduce', node) diff --git a/Lib/lib2to3/fixes/fix_repr.py b/Lib/lib2to3/fixes/fix_repr.py index 99e772272d9..0bc6ba6f635 100644 --- a/Lib/lib2to3/fixes/fix_repr.py +++ b/Lib/lib2to3/fixes/fix_repr.py @@ -5,7 +5,7 @@ # Local imports from .. import fixer_base -from ..fixer_util import Call, Name +from ..fixer_util import Call, Name, parenthesize class FixRepr(fixer_base.BaseFix): @@ -18,5 +18,5 @@ class FixRepr(fixer_base.BaseFix): expr = results["expr"].clone() if expr.type == self.syms.testlist1: - expr = self.parenthesize(expr) + expr = parenthesize(expr) return Call(Name("repr"), [expr], prefix=node.get_prefix()) diff --git a/Lib/lib2to3/fixes/fix_urllib.py b/Lib/lib2to3/fixes/fix_urllib.py index ea7e9ca0ce6..1e74bfff974 100644 --- a/Lib/lib2to3/fixes/fix_urllib.py +++ b/Lib/lib2to3/fixes/fix_urllib.py @@ -29,7 +29,7 @@ MAPPING = {'urllib': [ 'AbstractBasicAuthHandler', 'HTTPBasicAuthHandler', 'ProxyBasicAuthHandler', 'AbstractDigestAuthHandler', - 'HTTPDigestAuthHander', 'ProxyDigestAuthHandler', + 'HTTPDigestAuthHandler', 'ProxyDigestAuthHandler', 'HTTPHandler', 'HTTPSHandler', 'FileHandler', 'FTPHandler', 'CacheFTPHandler', 'UnknownHandler']), diff --git a/Lib/lib2to3/fixes/fix_xrange.py b/Lib/lib2to3/fixes/fix_xrange.py index 85efcd0a056..ca8f21ad2e1 100644 --- a/Lib/lib2to3/fixes/fix_xrange.py +++ b/Lib/lib2to3/fixes/fix_xrange.py @@ -12,7 +12,9 @@ from .. import patcomp class FixXrange(fixer_base.BaseFix): PATTERN = """ - power< (name='range'|name='xrange') trailer< '(' [any] ')' > any* > + power< + (name='range'|name='xrange') trailer< '(' args=any ')' > + rest=any* > """ def transform(self, node, results): @@ -30,11 +32,14 @@ class FixXrange(fixer_base.BaseFix): def transform_range(self, node, results): if not self.in_special_context(node): - arg = node.clone() - arg.set_prefix("") - call = Call(Name("list"), [arg]) - call.set_prefix(node.get_prefix()) - return call + range_call = Call(Name("range"), [results["args"].clone()]) + # Encase the range call in list(). + list_call = Call(Name("list"), [range_call], + prefix=node.get_prefix()) + # Put things that were after the range() call after the list call. + for n in results["rest"]: + list_call.append_child(n) + return list_call return node P1 = "power< func=NAME trailer< '(' node=any ')' > any* >" diff --git a/Lib/lib2to3/main.py b/Lib/lib2to3/main.py index 84f0f840a4c..a4d148d4558 100644 --- a/Lib/lib2to3/main.py +++ b/Lib/lib2to3/main.py @@ -5,6 +5,7 @@ Main program for 2to3. import sys import os import logging +import shutil import optparse from . import refactor @@ -39,6 +40,7 @@ class StdoutRefactoringTool(refactor.RefactoringTool): # Actually write the new file super(StdoutRefactoringTool, self).write_file(new_text, filename, old_text) + shutil.copymode(filename, backup) def print_output(self, lines): for line in lines: @@ -56,7 +58,7 @@ def main(fixer_pkg, args=None): Returns a suggested exit status (0, 1, 2). """ # Set up option parser - parser = optparse.OptionParser(usage="refactor.py [options] file|dir ...") + parser = optparse.OptionParser(usage="2to3 [options] file|dir ...") parser.add_option("-d", "--doctests_only", action="store_true", help="Fix up doctests only") parser.add_option("-f", "--fix", action="append", default=[], diff --git a/Lib/lib2to3/pygram.py b/Lib/lib2to3/pygram.py index 24f1bcf33fd..d63978d1e94 100644 --- a/Lib/lib2to3/pygram.py +++ b/Lib/lib2to3/pygram.py @@ -29,10 +29,3 @@ class Symbols(object): python_grammar = driver.load_grammar(_GRAMMAR_FILE) python_symbols = Symbols(python_grammar) - - -def parenthesize(node): - return pytree.Node(python_symbols.atom, - (pytree.Leaf(token.LPAR, "("), - node, - pytree.Leaf(token.RPAR, ")"))) diff --git a/Lib/lib2to3/pytree.py b/Lib/lib2to3/pytree.py index d3457335964..6de7cf138f1 100644 --- a/Lib/lib2to3/pytree.py +++ b/Lib/lib2to3/pytree.py @@ -279,18 +279,21 @@ class Node(Base): child.parent = self self.children[i].parent = None self.children[i] = child + self.changed() def insert_child(self, i, child): """Equivalent to 'node.children.insert(i, child)'. This method also sets the child's parent attribute appropriately.""" child.parent = self self.children.insert(i, child) + self.changed() def append_child(self, child): """Equivalent to 'node.children.append(child)'. This method also sets the child's parent attribute appropriately.""" child.parent = self self.children.append(child) + self.changed() class Leaf(Base): diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py index e7dd66efb82..2e7afeb30ee 100755 --- a/Lib/lib2to3/refactor.py +++ b/Lib/lib2to3/refactor.py @@ -123,8 +123,8 @@ class RefactoringTool(object): logger=self.logger) self.pre_order, self.post_order = self.get_fixers() - self.pre_order_mapping = get_headnode_dict(self.pre_order) - self.post_order_mapping = get_headnode_dict(self.post_order) + self.pre_order_heads = get_headnode_dict(self.pre_order) + self.post_order_heads = get_headnode_dict(self.post_order) self.files = [] # List of files that were or should be modified @@ -294,8 +294,8 @@ class RefactoringTool(object): for fixer in all_fixers: fixer.start_tree(tree, name) - self.traverse_by(self.pre_order_mapping, tree.pre_order()) - self.traverse_by(self.post_order_mapping, tree.post_order()) + self.traverse_by(self.pre_order_heads, tree.pre_order()) + self.traverse_by(self.post_order_heads, tree.post_order()) for fixer in all_fixers: fixer.finish_tree(tree, name) diff --git a/Lib/lib2to3/tests/benchmark.py b/Lib/lib2to3/tests/benchmark.py deleted file mode 100644 index e0f3e68f646..00000000000 --- a/Lib/lib2to3/tests/benchmark.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python2.5 -""" -This is a benchmarking script to test the speed of 2to3's pattern matching -system. It's equivalent to "refactor.py -f all" for every Python module -in sys.modules, but without engaging the actual transformations. -""" - -__author__ = "Collin Winter " - -# Python imports -import os.path -import sys -from time import time - -# Test imports -from .support import adjust_path -adjust_path() - -# Local imports -from .. import refactor - -### Mock code for refactor.py and the fixers -############################################################################### -class Options: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - self.verbose = False - -def dummy_transform(*args, **kwargs): - pass - -### Collect list of modules to match against -############################################################################### -files = [] -for mod in sys.modules.values(): - if mod is None or not hasattr(mod, '__file__'): - continue - f = mod.__file__ - if f.endswith('.pyc'): - f = f[:-1] - if f.endswith('.py'): - files.append(f) - -### Set up refactor and run the benchmark -############################################################################### -options = Options(fix=["all"], print_function=False, doctests_only=False) -refactor = refactor.RefactoringTool(options) -for fixer in refactor.fixers: - # We don't want them to actually fix the tree, just match against it. - fixer.transform = dummy_transform - -t = time() -for f in files: - print "Matching", f - refactor.refactor_file(f) -print "%d seconds to match %d files" % (time() - t, len(sys.modules)) diff --git a/Lib/lib2to3/tests/test_fixers.py b/Lib/lib2to3/tests/test_fixers.py index 209d5d76592..ca556a8947e 100755 --- a/Lib/lib2to3/tests/test_fixers.py +++ b/Lib/lib2to3/tests/test_fixers.py @@ -293,30 +293,30 @@ class Test_intern(FixerTestCase): def test_prefix_preservation(self): b = """x = intern( a )""" - a = """x = sys.intern( a )""" + a = """import sys\nx = sys.intern( a )""" self.check(b, a) b = """y = intern("b" # test )""" - a = """y = sys.intern("b" # test + a = """import sys\ny = sys.intern("b" # test )""" self.check(b, a) b = """z = intern(a+b+c.d, )""" - a = """z = sys.intern(a+b+c.d, )""" + a = """import sys\nz = sys.intern(a+b+c.d, )""" self.check(b, a) def test(self): b = """x = intern(a)""" - a = """x = sys.intern(a)""" + a = """import sys\nx = sys.intern(a)""" self.check(b, a) b = """z = intern(a+b+c.d,)""" - a = """z = sys.intern(a+b+c.d,)""" + a = """import sys\nz = sys.intern(a+b+c.d,)""" self.check(b, a) b = """intern("y%s" % 5).replace("y", "")""" - a = """sys.intern("y%s" % 5).replace("y", "")""" + a = """import sys\nsys.intern("y%s" % 5).replace("y", "")""" self.check(b, a) # These should not be refactored @@ -337,6 +337,35 @@ class Test_intern(FixerTestCase): s = """intern()""" self.unchanged(s) +class Test_reduce(FixerTestCase): + fixer = "reduce" + + def test_simple_call(self): + b = "reduce(a, b, c)" + a = "from functools import reduce\nreduce(a, b, c)" + self.check(b, a) + + def test_call_with_lambda(self): + b = "reduce(lambda x, y: x + y, seq)" + a = "from functools import reduce\nreduce(lambda x, y: x + y, seq)" + self.check(b, a) + + def test_unchanged(self): + s = "reduce(a)" + self.unchanged(s) + + s = "reduce(a, b=42)" + self.unchanged(s) + + s = "reduce(a, b, c, d)" + self.unchanged(s) + + s = "reduce(**c)" + self.unchanged(s) + + s = "reduce()" + self.unchanged(s) + class Test_print(FixerTestCase): fixer = "print" @@ -1044,33 +1073,39 @@ class Test_long(FixerTestCase): a = """z = type(x) in (int, int)""" self.check(b, a) - def test_4(self): - b = """a = 12L""" - a = """a = 12""" - self.check(b, a) - - def test_5(self): - b = """b = 0x12l""" - a = """b = 0x12""" - self.check(b, a) - - def test_unchanged_1(self): - s = """a = 12""" - self.unchanged(s) - - def test_unchanged_2(self): - s = """b = 0x12""" - self.unchanged(s) - - def test_unchanged_3(self): - s = """c = 3.14""" - self.unchanged(s) - def test_prefix_preservation(self): b = """x = long( x )""" a = """x = int( x )""" self.check(b, a) +class Test_isinstance(FixerTestCase): + fixer = "isinstance" + + def test_remove_multiple_items(self): + b = """isinstance(x, (int, int, int))""" + a = """isinstance(x, int)""" + self.check(b, a) + + b = """isinstance(x, (int, float, int, int, float))""" + a = """isinstance(x, (int, float))""" + self.check(b, a) + + b = """isinstance(x, (int, float, int, int, float, str))""" + a = """isinstance(x, (int, float, str))""" + self.check(b, a) + + b = """isinstance(foo() + bar(), (x(), y(), x(), int, int))""" + a = """isinstance(foo() + bar(), (x(), y(), x(), int))""" + self.check(b, a) + + def test_prefix_preservation(self): + b = """if isinstance( foo(), ( bar, bar, baz )) : pass""" + a = """if isinstance( foo(), ( bar, baz )) : pass""" + self.check(b, a) + + def test_unchanged(self): + self.unchanged("isinstance(x, (str, int))") + class Test_dict(FixerTestCase): fixer = "dict" @@ -1287,6 +1322,14 @@ class Test_xrange(FixerTestCase): a = """x = list(range(10, 3, 9)) + [4]""" self.check(b, a) + b = """x = range(10)[::-1]""" + a = """x = list(range(10))[::-1]""" + self.check(b, a) + + b = """x = range(10) [3]""" + a = """x = list(range(10)) [3]""" + self.check(b, a) + def test_xrange_in_for(self): b = """for i in xrange(10):\n j=i""" a = """for i in range(10):\n j=i""" @@ -1422,9 +1465,8 @@ class Test_xreadlines(FixerTestCase): s = "foo(xreadlines)" self.unchanged(s) -class Test_imports(FixerTestCase): - fixer = "imports" - from ..fixes.fix_imports import MAPPING as modules + +class ImportsFixerTests: def test_import_module(self): for old, new in self.modules.items(): @@ -1522,18 +1564,36 @@ class Test_imports(FixerTestCase): self.check(b, a) +class Test_imports(FixerTestCase, ImportsFixerTests): + fixer = "imports" + from ..fixes.fix_imports import MAPPING as modules -class Test_imports2(Test_imports): + def test_multiple_imports(self): + b = """import urlparse, cStringIO""" + a = """import urllib.parse, io""" + self.check(b, a) + + def test_multiple_imports_as(self): + b = """ + import copy_reg as bar, HTMLParser as foo, urlparse + s = urlparse.spam(bar.foo()) + """ + a = """ + import copyreg as bar, html.parser as foo, urllib.parse + s = urllib.parse.spam(bar.foo()) + """ + self.check(b, a) + + +class Test_imports2(FixerTestCase, ImportsFixerTests): fixer = "imports2" from ..fixes.fix_imports2 import MAPPING as modules -class Test_imports_fixer_order(Test_imports): - - fixer = None +class Test_imports_fixer_order(FixerTestCase, ImportsFixerTests): def setUp(self): - Test_imports.setUp(self, ['imports', 'imports2']) + super(Test_imports_fixer_order, self).setUp(['imports', 'imports2']) from ..fixes.fix_imports2 import MAPPING as mapping2 self.modules = mapping2.copy() from ..fixes.fix_imports import MAPPING as mapping1 diff --git a/Lib/lib2to3/tests/test_util.py b/Lib/lib2to3/tests/test_util.py index 5d021502fdd..95b566ab390 100644 --- a/Lib/lib2to3/tests/test_util.py +++ b/Lib/lib2to3/tests/test_util.py @@ -526,6 +526,33 @@ class Test_find_binding(support.TestCase): b = 7""" self.failIf(self.find_binding("a", s)) +class Test_touch_import(support.TestCase): + + def test_after_docstring(self): + node = parse('"""foo"""\nbar()') + fixer_util.touch_import(None, "foo", node) + self.assertEqual(str(node), '"""foo"""\nimport foo\nbar()\n\n') + + def test_after_imports(self): + node = parse('"""foo"""\nimport bar\nbar()') + fixer_util.touch_import(None, "foo", node) + self.assertEqual(str(node), '"""foo"""\nimport bar\nimport foo\nbar()\n\n') + + def test_beginning(self): + node = parse('bar()') + fixer_util.touch_import(None, "foo", node) + self.assertEqual(str(node), 'import foo\nbar()\n\n') + + def test_from_import(self): + node = parse('bar()') + fixer_util.touch_import("cgi", "escape", node) + self.assertEqual(str(node), 'from cgi import escape\nbar()\n\n') + + def test_name_import(self): + node = parse('bar()') + fixer_util.touch_import(None, "cgi", node) + self.assertEqual(str(node), 'import cgi\nbar()\n\n') + if __name__ == "__main__": import __main__