Merged revisions 62080-62262 via svnmerge from

svn+ssh://pythondev@svn.python.org/sandbox/trunk/2to3/lib2to3

........
  r62092 | collin.winter | 2008-04-01 18:27:10 +0200 (Di, 01 Apr 2008) | 1 line

  Add get_prev_sibling() to complement pytree's get_next_sibling().
........
  r62226 | collin.winter | 2008-04-08 21:07:56 +0200 (Di, 08 Apr 2008) | 1 line

  Add min() and max() to the list of special contexts that don't require adding list() calls around dict methods.
........
  r62232 | collin.winter | 2008-04-09 00:12:38 +0200 (Mi, 09 Apr 2008) | 4 lines

  Fix for http://bugs.python.org/issue2596

  This extends fix_xrange to know about the (mostly) same special contexts as fix_dict (where a special context is something that is guaranteed to fully consume the iterable), adding list() calls where appropriate. It also special-cases "x in range(y)".
........
This commit is contained in:
Martin v. Löwis 2008-04-10 02:48:01 +00:00
parent c00eb73a30
commit 60a819d681
6 changed files with 119 additions and 13 deletions

View File

@ -29,10 +29,10 @@ from .. import patcomp
from ..pgen2 import token from ..pgen2 import token
from . import basefix from . import basefix
from .util import Name, Call, LParen, RParen, ArgList, Dot, set from .util import Name, Call, LParen, RParen, ArgList, Dot, set
from . import util
exempt = set(["sorted", "list", "set", "any", "all", "tuple", "sum"]) iter_exempt = util.consuming_calls | set(["iter"])
iter_exempt = exempt | set(["iter"])
class FixDict(basefix.BaseFix): class FixDict(basefix.BaseFix):
@ -92,7 +92,7 @@ class FixDict(basefix.BaseFix):
return results["func"].value in iter_exempt return results["func"].value in iter_exempt
else: else:
# list(d.keys()) -> list(d.keys()), etc. # list(d.keys()) -> list(d.keys()), etc.
return results["func"].value in exempt return results["func"].value in util.consuming_calls
if not isiter: if not isiter:
return False return False
# for ... in d.iterkeys() -> for ... in d.keys(), etc. # for ... in d.iterkeys() -> for ... in d.keys(), etc.

View File

@ -5,14 +5,55 @@
# Local imports # Local imports
from .import basefix from .import basefix
from .util import Name from .util import Name, Call, consuming_calls
from .. import patcomp
class FixXrange(basefix.BaseFix): class FixXrange(basefix.BaseFix):
PATTERN = """ PATTERN = """
power< name='xrange' trailer< '(' [any] ')' > > power< (name='range'|name='xrange') trailer< '(' [any] ')' > any* >
""" """
def transform(self, node, results): def transform(self, node, results):
name = results["name"]
if name.value == "xrange":
return self.transform_xrange(node, results)
elif name.value == "range":
return self.transform_range(node, results)
else:
raise ValueError(repr(name))
def transform_xrange(self, node, results):
name = results["name"] name = results["name"]
name.replace(Name("range", prefix=name.get_prefix())) name.replace(Name("range", prefix=name.get_prefix()))
def transform_range(self, node, results):
if not self.in_special_context(node):
arg = node.clone()
arg.set_prefix("")
call = Call(Name("list"), [arg])
call.set_prefix(node.get_prefix())
return call
return node
P1 = "power< func=NAME trailer< '(' node=any ')' > any* >"
p1 = patcomp.compile_pattern(P1)
P2 = """for_stmt< 'for' any 'in' node=any ':' any* >
| comp_for< 'for' any 'in' node=any any* >
| comparison< any 'in' node=any any*>
"""
p2 = patcomp.compile_pattern(P2)
def in_special_context(self, node):
if node.parent is None:
return False
results = {}
if (node.parent.parent is not None and
self.p1.match(node.parent.parent, results) and
results["node"] is node):
# list(d.keys()) -> list(d.keys()), etc.
return results["func"].value in consuming_calls
# for ... in d.iterkeys() -> for ... in d.keys(), etc.
return self.p2.match(node.parent, results) and results["node"] is node

View File

@ -182,6 +182,10 @@ except NameError:
### Misc ### Misc
########################################################### ###########################################################
consuming_calls = set(["sorted", "list", "set", "any", "all", "tuple", "sum",
"min", "max"])
def attr_chain(obj, attr): def attr_chain(obj, attr):
"""Follow an attribute chain. """Follow an attribute chain.

View File

@ -167,13 +167,27 @@ class Base(object):
return None return None
# Can't use index(); we need to test by identity # Can't use index(); we need to test by identity
for i, sibling in enumerate(self.parent.children): for i, child in enumerate(self.parent.children):
if sibling is self: if child is self:
try: try:
return self.parent.children[i+1] return self.parent.children[i+1]
except IndexError: except IndexError:
return None return None
def get_prev_sibling(self):
"""Return the node immediately preceding the invocant in their
parent's children list. If the invocant does not have a previous
sibling, return None."""
if self.parent is None:
return None
# Can't use index(); we need to test by identity
for i, child in enumerate(self.parent.children):
if child is self:
if i == 0:
return None
return self.parent.children[i-1]
def get_suffix(self): def get_suffix(self):
"""Return the string immediately following the invocant node. This """Return the string immediately following the invocant node. This
is effectively equivalent to node.get_next_sibling().get_prefix()""" is effectively equivalent to node.get_next_sibling().get_prefix()"""

View File

@ -16,6 +16,8 @@ from os.path import dirname, pathsep
from .. import pygram from .. import pygram
from .. import pytree from .. import pytree
from .. import refactor from .. import refactor
from ..fixes import util
class Options: class Options:
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -1105,8 +1107,7 @@ class Test_dict(FixerTestCase):
self.check(b, a) self.check(b, a)
def test_unchanged(self): def test_unchanged(self):
wrappers = ["set", "sorted", "any", "all", "tuple", "sum"] for wrapper in util.consuming_calls:
for wrapper in wrappers:
s = "s = %s(d.keys())" % wrapper s = "s = %s(d.keys())" % wrapper
self.unchanged(s) self.unchanged(s)
@ -1254,26 +1255,54 @@ class Test_xrange(FixerTestCase):
a = """x = range( 0 , 10 , 2 )""" a = """x = range( 0 , 10 , 2 )"""
self.check(b, a) self.check(b, a)
def test_1(self): def test_single_arg(self):
b = """x = xrange(10)""" b = """x = xrange(10)"""
a = """x = range(10)""" a = """x = range(10)"""
self.check(b, a) self.check(b, a)
def test_2(self): def test_two_args(self):
b = """x = xrange(1, 10)""" b = """x = xrange(1, 10)"""
a = """x = range(1, 10)""" a = """x = range(1, 10)"""
self.check(b, a) self.check(b, a)
def test_3(self): def test_three_args(self):
b = """x = xrange(0, 10, 2)""" b = """x = xrange(0, 10, 2)"""
a = """x = range(0, 10, 2)""" a = """x = range(0, 10, 2)"""
self.check(b, a) self.check(b, a)
def test_4(self): def test_wrap_in_list(self):
b = """x = range(10, 3, 9)"""
a = """x = list(range(10, 3, 9))"""
self.check(b, a)
b = """x = foo(range(10, 3, 9))"""
a = """x = foo(list(range(10, 3, 9)))"""
self.check(b, a)
b = """x = range(10, 3, 9) + [4]"""
a = """x = list(range(10, 3, 9)) + [4]"""
self.check(b, a)
def test_xrange_in_for(self):
b = """for i in xrange(10):\n j=i""" b = """for i in xrange(10):\n j=i"""
a = """for i in range(10):\n j=i""" a = """for i in range(10):\n j=i"""
self.check(b, a) self.check(b, a)
b = """[i for i in xrange(10)]"""
a = """[i for i in range(10)]"""
self.check(b, a)
def test_range_in_for(self):
self.unchanged("for i in range(10): pass")
self.unchanged("[i for i in range(10)]")
def test_in_contains_test(self):
self.unchanged("x in range(10, 3, 9)")
def test_in_consuming_context(self):
for call in util.consuming_calls:
self.unchanged("a = %s(range(10))" % call)
class Test_raw_input(FixerTestCase): class Test_raw_input(FixerTestCase):
fixer = "raw_input" fixer = "raw_input"

View File

@ -319,6 +319,24 @@ class TestNodes(support.TestCase):
self.assertEqual(l2.get_next_sibling(), None) self.assertEqual(l2.get_next_sibling(), None)
self.assertEqual(p1.get_next_sibling(), None) self.assertEqual(p1.get_next_sibling(), None)
def testNodePrevSibling(self):
n1 = pytree.Node(1000, [])
n2 = pytree.Node(1000, [])
p1 = pytree.Node(1000, [n1, n2])
self.failUnless(n2.get_prev_sibling() is n1)
self.assertEqual(n1.get_prev_sibling(), None)
self.assertEqual(p1.get_prev_sibling(), None)
def testLeafPrevSibling(self):
l1 = pytree.Leaf(100, "a")
l2 = pytree.Leaf(100, "b")
p1 = pytree.Node(1000, [l1, l2])
self.failUnless(l2.get_prev_sibling() is l1)
self.assertEqual(l1.get_prev_sibling(), None)
self.assertEqual(p1.get_prev_sibling(), None)
class TestPatterns(support.TestCase): class TestPatterns(support.TestCase):