diff --git a/Lib/lib2to3/fixer_base.py b/Lib/lib2to3/fixer_base.py index 2b8a31a84fa..16887aabc2e 100644 --- a/Lib/lib2to3/fixer_base.py +++ b/Lib/lib2to3/fixer_base.py @@ -33,6 +33,8 @@ class BaseFix(object): explicit = False # Is this ignored by refactor.py -f all? run_order = 5 # Fixers will be sorted by run order before execution # Lower numbers will be run first. + _accept_type = None # [Advanced and not public] This tells RefactoringTool + # which node type to accept when there's not a pattern. # Shortcut for access to Python grammar symbols syms = pygram.python_symbols diff --git a/Lib/lib2to3/fixes/fix_import.py b/Lib/lib2to3/fixes/fix_import.py index 0a98cc3d828..e2bd8e7192f 100644 --- a/Lib/lib2to3/fixes/fix_import.py +++ b/Lib/lib2to3/fixes/fix_import.py @@ -12,7 +12,7 @@ Becomes: # Local imports from .. import fixer_base -from os.path import dirname, join, exists, pathsep +from os.path import dirname, join, exists, sep from ..fixer_util import FromImport, syms, token @@ -84,7 +84,7 @@ class FixImport(fixer_base.BaseFix): # so can't be a relative import. if not exists(join(dirname(base_path), '__init__.py')): return False - for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']: + for ext in ['.py', sep, '.pyc', '.so', '.sl', '.pyd']: if exists(base_path + ext): return True return False diff --git a/Lib/lib2to3/fixes/fix_imports.py b/Lib/lib2to3/fixes/fix_imports.py index 00f0a75d178..9c815e14dae 100644 --- a/Lib/lib2to3/fixes/fix_imports.py +++ b/Lib/lib2to3/fixes/fix_imports.py @@ -84,8 +84,6 @@ def build_pattern(mapping=MAPPING): class FixImports(fixer_base.BaseFix): - order = "pre" # Pre-order tree traversal - # This is overridden in fix_imports2. mapping = MAPPING diff --git a/Lib/lib2to3/fixes/fix_long.py b/Lib/lib2to3/fixes/fix_long.py index 73aba485eca..6f3661cbab5 100644 --- a/Lib/lib2to3/fixes/fix_long.py +++ b/Lib/lib2to3/fixes/fix_long.py @@ -5,18 +5,15 @@ """ # Local imports -from .. import fixer_base -from ..fixer_util import Name, Number, is_probably_builtin +from lib2to3 import fixer_base +from lib2to3.fixer_util import is_probably_builtin class FixLong(fixer_base.BaseFix): PATTERN = "'long'" - static_int = Name(u"int") - def transform(self, node, results): if is_probably_builtin(node): - new = self.static_int.clone() - new.prefix = node.prefix - return new + node.value = u"int" + node.changed() diff --git a/Lib/lib2to3/fixes/fix_ne.py b/Lib/lib2to3/fixes/fix_ne.py index 23fe86983ee..7025980b48e 100644 --- a/Lib/lib2to3/fixes/fix_ne.py +++ b/Lib/lib2to3/fixes/fix_ne.py @@ -12,9 +12,11 @@ from .. import fixer_base class FixNe(fixer_base.BaseFix): # This is so simple that we don't need the pattern compiler. + _accept_type = token.NOTEQUAL + def match(self, node): # Override - return node.type == token.NOTEQUAL and node.value == u"<>" + return node.value == u"<>" def transform(self, node, results): new = pytree.Leaf(token.NOTEQUAL, u"!=", prefix=node.prefix) diff --git a/Lib/lib2to3/fixes/fix_numliterals.py b/Lib/lib2to3/fixes/fix_numliterals.py index a049aed76e7..b0c23f8041a 100644 --- a/Lib/lib2to3/fixes/fix_numliterals.py +++ b/Lib/lib2to3/fixes/fix_numliterals.py @@ -12,10 +12,11 @@ from ..fixer_util import Number class FixNumliterals(fixer_base.BaseFix): # This is so simple that we don't need the pattern compiler. + _accept_type = token.NUMBER + def match(self, node): # Override - return (node.type == token.NUMBER and - (node.value.startswith(u"0") or node.value[-1] in u"Ll")) + return (node.value.startswith(u"0") or node.value[-1] in u"Ll") def transform(self, node, results): val = node.value diff --git a/Lib/lib2to3/fixes/fix_operator.py b/Lib/lib2to3/fixes/fix_operator.py new file mode 100644 index 00000000000..9591667214f --- /dev/null +++ b/Lib/lib2to3/fixes/fix_operator.py @@ -0,0 +1,40 @@ +"""Fixer for operator.{isCallable,sequenceIncludes} + +operator.isCallable(obj) -> hasattr(obj, '__call__') +operator.sequenceIncludes(obj) -> operator.contains(obj) +""" + +# Local imports +from .. import fixer_base +from ..fixer_util import Call, Name, String + +class FixOperator(fixer_base.BaseFix): + + methods = "method=('isCallable'|'sequenceIncludes')" + func = "'(' func=any ')'" + PATTERN = """ + power< module='operator' + trailer< '.' {methods} > trailer< {func} > > + | + power< {methods} trailer< {func} > > + """.format(methods=methods, func=func) + + def transform(self, node, results): + method = results["method"][0] + + if method.value == u"sequenceIncludes": + if "module" not in results: + # operator may not be in scope, so we can't make a change. + self.warning(node, "You should use operator.contains here.") + else: + method.value = u"contains" + method.changed() + elif method.value == u"isCallable": + if "module" not in results: + self.warning(node, + "You should use hasattr(%s, '__call__') here." % + results["func"].value) + else: + func = results["func"] + args = [func.clone(), String(u", "), String(u"'__call__'")] + return Call(Name(u"hasattr"), args, prefix=node.prefix) diff --git a/Lib/lib2to3/fixes/fix_print.py b/Lib/lib2to3/fixes/fix_print.py index be29dce1324..6cad8ce59aa 100644 --- a/Lib/lib2to3/fixes/fix_print.py +++ b/Lib/lib2to3/fixes/fix_print.py @@ -26,20 +26,15 @@ parend_expr = patcomp.compile_pattern( ) -class FixPrint(fixer_base.ConditionalFix): +class FixPrint(fixer_base.BaseFix): PATTERN = """ simple_stmt< any* bare='print' any* > | print_stmt """ - skip_on = '__future__.print_function' - def transform(self, node, results): assert results - if self.should_skip(node): - return - bare_print = results.get("bare") if bare_print: diff --git a/Lib/lib2to3/fixes/fix_urllib.py b/Lib/lib2to3/fixes/fix_urllib.py index a89266effda..d11220c5e18 100644 --- a/Lib/lib2to3/fixes/fix_urllib.py +++ b/Lib/lib2to3/fixes/fix_urllib.py @@ -12,13 +12,13 @@ from ..fixer_util import Name, Comma, FromImport, Newline, attr_chain MAPPING = {'urllib': [ ('urllib.request', ['URLOpener', 'FancyURLOpener', 'urlretrieve', - '_urlopener', 'urlopen', 'urlcleanup']), + '_urlopener', 'urlopen', 'urlcleanup', + 'pathname2url', 'url2pathname']), ('urllib.parse', ['quote', 'quote_plus', 'unquote', 'unquote_plus', - 'urlencode', 'pathname2url', 'url2pathname', 'splitattr', - 'splithost', 'splitnport', 'splitpasswd', 'splitport', - 'splitquery', 'splittag', 'splittype', 'splituser', - 'splitvalue', ]), + 'urlencode', 'splitattr', 'splithost', 'splitnport', + 'splitpasswd', 'splitport', 'splitquery', 'splittag', + 'splittype', 'splituser', 'splitvalue', ]), ('urllib.error', ['ContentTooShortError'])], 'urllib2' : [ diff --git a/Lib/lib2to3/main.py b/Lib/lib2to3/main.py index cf4adf7c029..92388079b2c 100644 --- a/Lib/lib2to3/main.py +++ b/Lib/lib2to3/main.py @@ -4,19 +4,31 @@ Main program for 2to3. import sys import os +import difflib import logging import shutil import optparse from . import refactor + +def diff_texts(a, b, filename): + """Return a unified diff of two strings.""" + a = a.splitlines() + b = b.splitlines() + return difflib.unified_diff(a, b, filename, filename, + "(original)", "(refactored)", + lineterm="") + + class StdoutRefactoringTool(refactor.MultiprocessRefactoringTool): """ Prints output to stdout. """ - def __init__(self, fixers, options, explicit, nobackups): + def __init__(self, fixers, options, explicit, nobackups, show_diffs): self.nobackups = nobackups + self.show_diffs = show_diffs super(StdoutRefactoringTool, self).__init__(fixers, options, explicit) def log_error(self, msg, *args, **kwargs): @@ -42,9 +54,18 @@ class StdoutRefactoringTool(refactor.MultiprocessRefactoringTool): if not self.nobackups: shutil.copymode(backup, filename) - def print_output(self, lines): - for line in lines: - print line + def print_output(self, old, new, filename, equal): + if equal: + self.log_message("No changes to %s", filename) + else: + self.log_message("Refactored %s", filename) + if self.show_diffs: + for line in diff_texts(old, new, filename): + print line + + +def warn(msg): + print >> sys.stderr, "WARNING: %s" % (msg,) def main(fixer_pkg, args=None): @@ -70,9 +91,12 @@ def main(fixer_pkg, args=None): parser.add_option("-l", "--list-fixes", action="store_true", help="List available transformations (fixes/fix_*.py)") parser.add_option("-p", "--print-function", action="store_true", - help="Modify the grammar so that print() is a function") + help="DEPRECATED Modify the grammar so that print() is " + "a function") parser.add_option("-v", "--verbose", action="store_true", help="More verbose logging") + parser.add_option("--no-diffs", action="store_true", + help="Don't show diffs of the refactoring") parser.add_option("-w", "--write", action="store_true", help="Write back modified files") parser.add_option("-n", "--nobackups", action="store_true", default=False, @@ -81,6 +105,11 @@ def main(fixer_pkg, args=None): # Parse command line arguments refactor_stdin = False options, args = parser.parse_args(args) + if not options.write and options.no_diffs: + warn("not writing files and not printing diffs; that's not very useful") + if options.print_function: + warn("-p is deprecated; " + "detection of from __future__ import print_function is automatic") if not options.write and options.nobackups: parser.error("Can't use -n without -w") if options.list_fixes: @@ -90,13 +119,13 @@ def main(fixer_pkg, args=None): if not args: return 0 if not args: - print >>sys.stderr, "At least one file or directory argument required." - print >>sys.stderr, "Use --help to show usage." + print >> sys.stderr, "At least one file or directory argument required." + print >> sys.stderr, "Use --help to show usage." return 2 if "-" in args: refactor_stdin = True if options.write: - print >>sys.stderr, "Can't write to stdin." + print >> sys.stderr, "Can't write to stdin." return 2 # Set up logging handler @@ -104,7 +133,6 @@ def main(fixer_pkg, args=None): logging.basicConfig(format='%(name)s: %(message)s', level=level) # Initialize the refactoring tool - rt_opts = {"print_function" : options.print_function} avail_fixes = set(refactor.get_fixers_from_package(fixer_pkg)) unwanted_fixes = set(fixer_pkg + ".fix_" + fix for fix in options.nofix) explicit = set() @@ -119,8 +147,8 @@ def main(fixer_pkg, args=None): else: requested = avail_fixes.union(explicit) fixer_names = requested.difference(unwanted_fixes) - rt = StdoutRefactoringTool(sorted(fixer_names), rt_opts, sorted(explicit), - options.nobackups) + rt = StdoutRefactoringTool(sorted(fixer_names), None, sorted(explicit), + options.nobackups, not options.no_diffs) # Refactor all files and directories passed as arguments if not rt.errors: diff --git a/Lib/lib2to3/patcomp.py b/Lib/lib2to3/patcomp.py index e5e114500af..5c62a157807 100644 --- a/Lib/lib2to3/patcomp.py +++ b/Lib/lib2to3/patcomp.py @@ -14,7 +14,7 @@ __author__ = "Guido van Rossum " import os # Fairly local imports -from .pgen2 import driver, literals, token, tokenize, parse +from .pgen2 import driver, literals, token, tokenize, parse, grammar # Really local imports from . import pytree @@ -138,7 +138,7 @@ class PatternCompiler(object): node = nodes[0] if node.type == token.STRING: value = unicode(literals.evalString(node.value)) - return pytree.LeafPattern(content=value) + return pytree.LeafPattern(_type_of_literal(value), value) elif node.type == token.NAME: value = node.value if value.isupper(): @@ -179,6 +179,15 @@ TOKEN_MAP = {"NAME": token.NAME, "TOKEN": None} +def _type_of_literal(value): + if value[0].isalpha(): + return token.NAME + elif value in grammar.opmap: + return grammar.opmap[value] + else: + return None + + def pattern_convert(grammar, raw_node_info): """Converts raw node information to a Node or Leaf instance.""" type, value, context, children = raw_node_info diff --git a/Lib/lib2to3/pgen2/grammar.py b/Lib/lib2to3/pgen2/grammar.py index 7818c24d26b..0483424dc42 100644 --- a/Lib/lib2to3/pgen2/grammar.py +++ b/Lib/lib2to3/pgen2/grammar.py @@ -97,6 +97,19 @@ class Grammar(object): f.close() self.__dict__.update(d) + def copy(self): + """ + Copy the grammar. + """ + new = self.__class__() + for dict_attr in ("symbol2number", "number2symbol", "dfas", "keywords", + "tokens", "symbol2label"): + setattr(new, dict_attr, getattr(self, dict_attr).copy()) + new.labels = self.labels[:] + new.states = self.states[:] + new.start = self.start + return new + def report(self): """Dump the grammar tables to standard output, for debugging.""" from pprint import pprint diff --git a/Lib/lib2to3/pygram.py b/Lib/lib2to3/pygram.py index d63978d1e94..6cdb3a44e85 100644 --- a/Lib/lib2to3/pygram.py +++ b/Lib/lib2to3/pygram.py @@ -28,4 +28,8 @@ class Symbols(object): python_grammar = driver.load_grammar(_GRAMMAR_FILE) + python_symbols = Symbols(python_grammar) + +python_grammar_no_print_statement = python_grammar.copy() +del python_grammar_no_print_statement.keywords["print"] diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py index d1cdb5d8caf..97a540d7c65 100644 --- a/Lib/lib2to3/refactor.py +++ b/Lib/lib2to3/refactor.py @@ -14,14 +14,15 @@ __author__ = "Guido van Rossum " # Python imports import os import sys -import difflib import logging import operator -from collections import defaultdict +import collections +import StringIO +import warnings from itertools import chain # Local imports -from .pgen2 import driver, tokenize +from .pgen2 import driver, tokenize, token from . import pytree, pygram @@ -37,7 +38,12 @@ def get_all_fix_names(fixer_pkg, remove_prefix=True): fix_names.append(name[:-3]) return fix_names -def get_head_types(pat): + +class _EveryNode(Exception): + pass + + +def _get_head_types(pat): """ Accepts a pytree Pattern Node and returns a set of the pattern types which will match first. """ @@ -45,34 +51,50 @@ def get_head_types(pat): # NodePatters must either have no type and no content # or a type and content -- so they don't get any farther # Always return leafs + if pat.type is None: + raise _EveryNode return set([pat.type]) if isinstance(pat, pytree.NegatedPattern): if pat.content: - return get_head_types(pat.content) - return set([None]) # Negated Patterns don't have a type + return _get_head_types(pat.content) + raise _EveryNode # Negated Patterns don't have a type if isinstance(pat, pytree.WildcardPattern): # Recurse on each node in content r = set() for p in pat.content: for x in p: - r.update(get_head_types(x)) + r.update(_get_head_types(x)) return r raise Exception("Oh no! I don't understand pattern %s" %(pat)) -def get_headnode_dict(fixer_list): + +def _get_headnode_dict(fixer_list): """ Accepts a list of fixers and returns a dictionary of head node type --> fixer list. """ - head_nodes = defaultdict(list) + head_nodes = collections.defaultdict(list) + every = [] for fixer in fixer_list: - if not fixer.pattern: - head_nodes[None].append(fixer) - continue - for t in get_head_types(fixer.pattern): - head_nodes[t].append(fixer) - return head_nodes + if fixer.pattern: + try: + heads = _get_head_types(fixer.pattern) + except _EveryNode: + every.append(fixer) + else: + for node_type in heads: + head_nodes[node_type].append(fixer) + else: + if fixer._accept_type is not None: + head_nodes[fixer._accept_type].append(fixer) + else: + every.append(fixer) + for node_type in chain(pygram.python_grammar.symbol2number.itervalues(), + pygram.python_grammar.tokens): + head_nodes[node_type].extend(every) + return dict(head_nodes) + def get_fixers_from_package(pkg_name): """ @@ -101,13 +123,56 @@ else: _to_system_newlines = _identity +def _detect_future_print(source): + have_docstring = False + gen = tokenize.generate_tokens(StringIO.StringIO(source).readline) + def advance(): + tok = next(gen) + return tok[0], tok[1] + ignore = frozenset((token.NEWLINE, tokenize.NL, token.COMMENT)) + try: + while True: + tp, value = advance() + if tp in ignore: + continue + elif tp == token.STRING: + if have_docstring: + break + have_docstring = True + elif tp == token.NAME: + if value == u"from": + tp, value = advance() + if tp != token.NAME and value != u"__future__": + break + tp, value = advance() + if tp != token.NAME and value != u"import": + break + tp, value = advance() + if tp == token.OP and value == u"(": + tp, value = advance() + while tp == token.NAME: + if value == u"print_function": + return True + tp, value = advance() + if tp != token.OP and value != u",": + break + tp, value = advance() + else: + break + else: + break + except StopIteration: + pass + return False + + class FixerError(Exception): """A fixer could not be loaded.""" class RefactoringTool(object): - _default_options = {"print_function": False} + _default_options = {} CLASS_PREFIX = "Fix" # The prefix for fixer classes FILE_PREFIX = "fix_" # The prefix for modules with a fixer within @@ -124,20 +189,21 @@ class RefactoringTool(object): self.explicit = explicit or [] self.options = self._default_options.copy() if options is not None: + if "print_function" in options: + warnings.warn("the 'print_function' option is deprecated", + DeprecationWarning) self.options.update(options) self.errors = [] self.logger = logging.getLogger("RefactoringTool") self.fixer_log = [] self.wrote = False - if self.options["print_function"]: - del pygram.python_grammar.keywords["print"] self.driver = driver.Driver(pygram.python_grammar, convert=pytree.convert, logger=self.logger) self.pre_order, self.post_order = self.get_fixers() - self.pre_order_heads = get_headnode_dict(self.pre_order) - self.post_order_heads = get_headnode_dict(self.post_order) + self.pre_order_heads = _get_headnode_dict(self.pre_order) + self.post_order_heads = _get_headnode_dict(self.post_order) self.files = [] # List of files that were or should be modified @@ -196,8 +262,9 @@ class RefactoringTool(object): msg = msg % args self.logger.debug(msg) - def print_output(self, lines): - """Called with lines of output to give to the user.""" + def print_output(self, old_text, new_text, filename, equal): + """Called with the old version, new version, and filename of a + refactored file.""" pass def refactor(self, items, write=False, doctests_only=False): @@ -220,7 +287,8 @@ class RefactoringTool(object): dirnames.sort() filenames.sort() for name in filenames: - if not name.startswith(".") and name.endswith("py"): + if not name.startswith(".") and \ + os.path.splitext(name)[1].endswith("py"): fullname = os.path.join(dirpath, name) self.refactor_file(fullname, write, doctests_only) # Modify dirnames in-place to remove subdirs with leading dots @@ -276,12 +344,16 @@ class RefactoringTool(object): An AST corresponding to the refactored input stream; None if there were errors during the parse. """ + if _detect_future_print(data): + self.driver.grammar = pygram.python_grammar_no_print_statement try: tree = self.driver.parse_string(data) except Exception, err: self.log_error("Can't parse %s: %s: %s", name, err.__class__.__name__, err) return + finally: + self.driver.grammar = pygram.python_grammar self.log_debug("Refactoring %s", name) self.refactor_tree(tree, name) return tree @@ -298,7 +370,7 @@ class RefactoringTool(object): else: tree = self.refactor_string(input, "") if tree and tree.was_changed: - self.processed_file(str(tree), "", input) + self.processed_file(unicode(tree), "", input) else: self.log_debug("No changes in stdin") @@ -338,12 +410,11 @@ class RefactoringTool(object): if not fixers: return for node in traversal: - for fixer in fixers[node.type] + fixers[None]: + for fixer in fixers[node.type]: results = fixer.match(node) if results: new = fixer.transform(node, results) - if new is not None and (new != node or - str(new) != str(node)): + if new is not None: node.replace(new) node = new @@ -357,10 +428,11 @@ class RefactoringTool(object): old_text = self._read_python_source(filename)[0] if old_text is None: return - if old_text == new_text: + equal = old_text == new_text + self.print_output(old_text, new_text, filename, equal) + if equal: self.log_debug("No changes to %s", filename) return - self.print_output(diff_texts(old_text, new_text, filename)) if write: self.write_file(new_text, filename, old_text, encoding) else: @@ -451,7 +523,7 @@ class RefactoringTool(object): filename, lineno, err.__class__.__name__, err) return block if self.refactor_tree(tree, filename): - new = str(tree).splitlines(True) + new = unicode(tree).splitlines(True) # Undo the adjustment of the line numbers in wrap_toks() below. clipped, new = new[:lineno-1], new[lineno-1:] assert clipped == [u"\n"] * (lineno-1), clipped @@ -582,12 +654,3 @@ class MultiprocessRefactoringTool(RefactoringTool): else: return super(MultiprocessRefactoringTool, self).refactor_file( *args, **kwargs) - - -def diff_texts(a, b, filename): - """Return a unified diff of two strings.""" - a = a.splitlines() - b = b.splitlines() - return difflib.unified_diff(a, b, filename, filename, - "(original)", "(refactored)", - lineterm="") diff --git a/Lib/lib2to3/tests/data/different_encoding.py b/Lib/lib2to3/tests/data/different_encoding.py index 888f51f65cc..9f32bd04dc2 100644 --- a/Lib/lib2to3/tests/data/different_encoding.py +++ b/Lib/lib2to3/tests/data/different_encoding.py @@ -1,3 +1,6 @@ #!/usr/bin/env python -# -*- coding: iso-8859-1 -*- -print u'' +# -*- coding: utf-8 -*- +print u'ßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞ' + +def f(x): + print '%s\t-> α(%2i):%s β(%s)' diff --git a/Lib/lib2to3/tests/test_fixers.py b/Lib/lib2to3/tests/test_fixers.py index 189b4d2b585..a19de9541d5 100755 --- a/Lib/lib2to3/tests/test_fixers.py +++ b/Lib/lib2to3/tests/test_fixers.py @@ -18,8 +18,6 @@ class FixerTestCase(support.TestCase): def setUp(self, fix_list=None, fixer_pkg="lib2to3", options=None): if fix_list is None: fix_list = [self.fixer] - if options is None: - options = {"print_function" : False} self.refactor = support.get_refactorer(fixer_pkg, fix_list, options) self.fixer_log = [] self.filename = u"" @@ -58,8 +56,7 @@ class FixerTestCase(support.TestCase): def assert_runs_after(self, *names): fixes = [self.fixer] fixes.extend(names) - options = {"print_function" : False} - r = support.get_refactorer("lib2to3", fixes, options) + r = support.get_refactorer("lib2to3", fixes) (pre, post) = r.get_fixers() n = "fix_" + self.fixer if post and post[-1].__class__.__module__.endswith(n): @@ -379,18 +376,15 @@ class Test_print(FixerTestCase): self.unchanged(s) def test_idempotency_print_as_function(self): - print_stmt = pygram.python_grammar.keywords.pop("print") - try: - s = """print(1, 1+1, 1+1+1)""" - self.unchanged(s) + self.refactor.driver.grammar = pygram.python_grammar_no_print_statement + s = """print(1, 1+1, 1+1+1)""" + self.unchanged(s) - s = """print()""" - self.unchanged(s) + s = """print()""" + self.unchanged(s) - s = """print('')""" - self.unchanged(s) - finally: - pygram.python_grammar.keywords["print"] = print_stmt + s = """print('')""" + self.unchanged(s) def test_1(self): b = """print 1, 1+1, 1+1+1""" @@ -462,31 +456,15 @@ class Test_print(FixerTestCase): a = """print(file=sys.stderr)""" self.check(b, a) - # With from __future__ import print_function def test_with_future_print_function(self): - # XXX: These tests won't actually do anything until the parser - # is fixed so it won't crash when it sees print(x=y). - # When #2412 is fixed, the try/except block can be taken - # out and the tests can be run like normal. - # MvL: disable entirely for now, so that it doesn't print to stdout - return - try: - s = "from __future__ import print_function\n"\ - "print('Hai!', end=' ')" - self.unchanged(s) + s = "from __future__ import print_function\n" \ + "print('Hai!', end=' ')" + self.unchanged(s) - b = "print 'Hello, world!'" - a = "print('Hello, world!')" - self.check(b, a) + b = "print 'Hello, world!'" + a = "print('Hello, world!')" + self.check(b, a) - s = "from __future__ import *\n"\ - "print('Hai!', end=' ')" - self.unchanged(s) - except: - return - else: - self.assertFalse(True, "#2421 has been fixed -- printing tests "\ - "need to be updated!") class Test_exec(FixerTestCase): fixer = "exec" @@ -1705,6 +1683,11 @@ class Test_imports_fixer_order(FixerTestCase, ImportsFixerTests): for key in ('dbhash', 'dumbdbm', 'dbm', 'gdbm'): self.modules[key] = mapping1[key] + def test_after_local_imports_refactoring(self): + for fix in ("imports", "imports2"): + self.fixer = fix + self.assert_runs_after("import") + class Test_urllib(FixerTestCase): fixer = "urllib" @@ -3504,6 +3487,7 @@ class Test_itertools_imports(FixerTestCase): s = "from itertools import foo" self.unchanged(s) + class Test_import(FixerTestCase): fixer = "import" @@ -3538,8 +3522,7 @@ class Test_import(FixerTestCase): self.always_exists = False self.present_files = set(['__init__.py']) - expected_extensions = ('.py', os.path.pathsep, '.pyc', '.so', - '.sl', '.pyd') + expected_extensions = ('.py', os.path.sep, '.pyc', '.so', '.sl', '.pyd') names_to_test = (p("/spam/eggs.py"), "ni.py", p("../../shrubbery.py")) for name in names_to_test: @@ -3569,6 +3552,13 @@ class Test_import(FixerTestCase): self.present_files = set(["__init__.py", "bar.py"]) self.check(b, a) + def test_import_from_package(self): + b = "import bar" + a = "from . import bar" + self.always_exists = False + self.present_files = set(["__init__.py", "bar/"]) + self.check(b, a) + def test_comments_and_indent(self): b = "import bar # Foo" a = "from . import bar # Foo" @@ -4095,3 +4085,26 @@ class Test_getcwdu(FixerTestCase): b = """os.getcwdu ( )""" a = """os.getcwd ( )""" self.check(b, a) + + +class Test_operator(FixerTestCase): + + fixer = "operator" + + def test_operator_isCallable(self): + b = "operator.isCallable(x)" + a = "hasattr(x, '__call__')" + self.check(b, a) + + def test_operator_sequenceIncludes(self): + b = "operator.sequenceIncludes(x, y)" + a = "operator.contains(x, y)" + self.check(b, a) + + def test_bare_isCallable(self): + s = "isCallable(x)" + self.warns_unchanged(s, "You should use hasattr(x, '__call__') here.") + + def test_bare_sequenceIncludes(self): + s = "sequenceIncludes(x, y)" + self.warns_unchanged(s, "You should use operator.contains here.") diff --git a/Lib/lib2to3/tests/test_refactor.py b/Lib/lib2to3/tests/test_refactor.py index f545c7cf7b8..a8d209a9b07 100644 --- a/Lib/lib2to3/tests/test_refactor.py +++ b/Lib/lib2to3/tests/test_refactor.py @@ -7,9 +7,12 @@ import os import operator import StringIO import tempfile +import shutil import unittest +import warnings from lib2to3 import refactor, pygram, fixer_base +from lib2to3.pgen2 import token from . import support @@ -42,14 +45,11 @@ class TestRefactoringTool(unittest.TestCase): return refactor.RefactoringTool(fixers, options, explicit) def test_print_function_option(self): - gram = pygram.python_grammar - save = gram.keywords["print"] - try: - rt = self.rt({"print_function" : True}) - self.assertRaises(KeyError, operator.itemgetter("print"), - gram.keywords) - finally: - gram.keywords["print"] = save + with warnings.catch_warnings(record=True) as w: + refactor.RefactoringTool(_DEFAULT_FIXERS, {"print_function" : True}) + self.assertEqual(len(w), 1) + msg, = w + self.assertTrue(msg.category is DeprecationWarning) def test_fixer_loading_helpers(self): contents = ["explicit", "first", "last", "parrot", "preorder"] @@ -61,19 +61,63 @@ class TestRefactoringTool(unittest.TestCase): self.assertEqual(full_names, ["myfixes.fix_" + name for name in contents]) + def test_detect_future_print(self): + run = refactor._detect_future_print + self.assertFalse(run("")) + self.assertTrue(run("from __future__ import print_function")) + self.assertFalse(run("from __future__ import generators")) + self.assertFalse(run("from __future__ import generators, feature")) + input = "from __future__ import generators, print_function" + self.assertTrue(run(input)) + input ="from __future__ import print_function, generators" + self.assertTrue(run(input)) + input = "from __future__ import (print_function,)" + self.assertTrue(run(input)) + input = "from __future__ import (generators, print_function)" + self.assertTrue(run(input)) + input = "from __future__ import (generators, nested_scopes)" + self.assertFalse(run(input)) + input = """from __future__ import generators +from __future__ import print_function""" + self.assertTrue(run(input)) + self.assertFalse(run("from")) + self.assertFalse(run("from 4")) + self.assertFalse(run("from x")) + self.assertFalse(run("from x 5")) + self.assertFalse(run("from x im")) + self.assertFalse(run("from x import")) + self.assertFalse(run("from x import 4")) + input = "'docstring'\nfrom __future__ import print_function" + self.assertTrue(run(input)) + input = "'docstring'\n'somng'\nfrom __future__ import print_function" + self.assertFalse(run(input)) + input = "# comment\nfrom __future__ import print_function" + self.assertTrue(run(input)) + input = "# comment\n'doc'\nfrom __future__ import print_function" + self.assertTrue(run(input)) + input = "class x: pass\nfrom __future__ import print_function" + self.assertFalse(run(input)) + def test_get_headnode_dict(self): class NoneFix(fixer_base.BaseFix): - PATTERN = None + pass class FileInputFix(fixer_base.BaseFix): PATTERN = "file_input< any * >" + class SimpleFix(fixer_base.BaseFix): + PATTERN = "'name'" + no_head = NoneFix({}, []) with_head = FileInputFix({}, []) - d = refactor.get_headnode_dict([no_head, with_head]) - expected = {None: [no_head], - pygram.python_symbols.file_input : [with_head]} - self.assertEqual(d, expected) + simple = SimpleFix({}, []) + d = refactor._get_headnode_dict([no_head, with_head, simple]) + top_fixes = d.pop(pygram.python_symbols.file_input) + self.assertEqual(top_fixes, [with_head, no_head]) + name_fixes = d.pop(token.NAME) + self.assertEqual(name_fixes, [simple, no_head]) + for fixes in d.itervalues(): + self.assertEqual(fixes, [no_head]) def test_fixer_loading(self): from myfixes.fix_first import FixFirst @@ -106,10 +150,10 @@ class TestRefactoringTool(unittest.TestCase): class MyRT(refactor.RefactoringTool): - def print_output(self, lines): - diff_lines.extend(lines) + def print_output(self, old_text, new_text, filename, equal): + results.extend([old_text, new_text, filename, equal]) - diff_lines = [] + results = [] rt = MyRT(_DEFAULT_FIXERS) save = sys.stdin sys.stdin = StringIO.StringIO("def parrot(): pass\n\n") @@ -117,12 +161,10 @@ class TestRefactoringTool(unittest.TestCase): rt.refactor_stdin() finally: sys.stdin = save - expected = """--- (original) -+++ (refactored) -@@ -1,2 +1,2 @@ --def parrot(): pass -+def cheese(): pass""".splitlines() - self.assertEqual(diff_lines[:-1], expected) + expected = ["def parrot(): pass\n\n", + "def cheese(): pass\n\n", + "", False] + self.assertEqual(results, expected) def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS): def read_file(): @@ -145,6 +187,37 @@ class TestRefactoringTool(unittest.TestCase): test_file = os.path.join(FIXER_DIR, "parrot_example.py") self.check_file_refactoring(test_file, _DEFAULT_FIXERS) + def test_refactor_dir(self): + def check(structure, expected): + def mock_refactor_file(self, f, *args): + got.append(f) + save_func = refactor.RefactoringTool.refactor_file + refactor.RefactoringTool.refactor_file = mock_refactor_file + rt = self.rt() + got = [] + dir = tempfile.mkdtemp(prefix="2to3-test_refactor") + try: + os.mkdir(os.path.join(dir, "a_dir")) + for fn in structure: + open(os.path.join(dir, fn), "wb").close() + rt.refactor_dir(dir) + finally: + refactor.RefactoringTool.refactor_file = save_func + shutil.rmtree(dir) + self.assertEqual(got, + [os.path.join(dir, path) for path in expected]) + check([], []) + tree = ["nothing", + "hi.py", + ".dumb", + ".after.py", + "sappy"] + expected = ["hi.py"] + check(tree, expected) + tree = ["hi.py", + "a_dir/stuff.py"] + check(tree, tree) + def test_file_encoding(self): fn = os.path.join(TEST_DATA_DIR, "different_encoding.py") self.check_file_refactoring(fn) diff --git a/Lib/lib2to3/tests/test_util.py b/Lib/lib2to3/tests/test_util.py index ef3256d5a9a..6186b4ff74a 100644 --- a/Lib/lib2to3/tests/test_util.py +++ b/Lib/lib2to3/tests/test_util.py @@ -1,4 +1,4 @@ -""" Test suite for the code in fixes.util """ +""" Test suite for the code in fixer_util """ # Testing imports from . import support @@ -7,10 +7,10 @@ from . import support import os.path # Local imports -from .. import pytree -from .. import fixer_util -from ..fixer_util import Attr, Name - +from lib2to3.pytree import Node, Leaf +from lib2to3 import fixer_util +from lib2to3.fixer_util import Attr, Name, Call, Comma +from lib2to3.pgen2 import token def parse(code, strip_levels=0): # The topmost node is file_input, which we don't care about. @@ -24,7 +24,7 @@ def parse(code, strip_levels=0): class MacroTestCase(support.TestCase): def assertStr(self, node, string): if isinstance(node, (tuple, list)): - node = pytree.Node(fixer_util.syms.simple_stmt, node) + node = Node(fixer_util.syms.simple_stmt, node) self.assertEqual(str(node), string) @@ -78,6 +78,31 @@ class Test_Name(MacroTestCase): self.assertStr(Name("a", prefix="b"), "ba") +class Test_Call(MacroTestCase): + def _Call(self, name, args=None, prefix=None): + """Help the next test""" + children = [] + if isinstance(args, list): + for arg in args: + children.append(arg) + children.append(Comma()) + children.pop() + return Call(Name(name), children, prefix) + + def test(self): + kids = [None, + [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 2), + Leaf(token.NUMBER, 3)], + [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 3), + Leaf(token.NUMBER, 2), Leaf(token.NUMBER, 4)], + [Leaf(token.STRING, "b"), Leaf(token.STRING, "j", prefix=" ")] + ] + self.assertStr(self._Call("A"), "A()") + self.assertStr(self._Call("b", kids[1]), "b(1,2,3)") + self.assertStr(self._Call("a.b().c", kids[2]), "a.b().c(1,3,2,4)") + self.assertStr(self._Call("d", kids[3], prefix=" "), " d(b, j)") + + class Test_does_tree_import(support.TestCase): def _find_bind_rec(self, name, node): # Search a tree for a binding -- used to find the starting