diff --git a/Lib/pstats.py b/Lib/pstats.py index 4bb5cf17064..8744235ad00 100644 --- a/Lib/pstats.py +++ b/Lib/pstats.py @@ -104,7 +104,9 @@ class Stats: print(file=self.stream) def load_stats(self, arg): - if not arg: self.stats = {} + if arg is None: + self.stats = {} + return elif isinstance(arg, str): f = open(arg, 'rb') self.stats = marshal.load(f) @@ -114,13 +116,13 @@ class Stats: arg = time.ctime(file_stats.st_mtime) + " " + arg except: # in case this is not unix pass - self.files = [ arg ] + self.files = [arg] elif hasattr(arg, 'create_stats'): arg.create_stats() self.stats = arg.stats arg.stats = {} if not self.stats: - raise TypeError("Cannot create or construct a %r object from '%r''" + raise TypeError("Cannot create or construct a %r object from %r" % (self.__class__, arg)) return @@ -135,29 +137,29 @@ class Stats: self.max_name_len = len(func_std_string(func)) def add(self, *arg_list): - if not arg_list: return self - if len(arg_list) > 1: self.add(*arg_list[1:]) - other = arg_list[0] - if type(self) != type(other): - other = Stats(other) - self.files += other.files - self.total_calls += other.total_calls - self.prim_calls += other.prim_calls - self.total_tt += other.total_tt - for func in other.top_level: - self.top_level[func] = None + if not arg_list: + return self + for item in reversed(arg_list): + if type(self) != type(item): + item = Stats(item) + self.files += item.files + self.total_calls += item.total_calls + self.prim_calls += item.prim_calls + self.total_tt += item.total_tt + for func in item.top_level: + self.top_level[func] = None - if self.max_name_len < other.max_name_len: - self.max_name_len = other.max_name_len + if self.max_name_len < item.max_name_len: + self.max_name_len = item.max_name_len - self.fcn_list = None + self.fcn_list = None - for func, stat in other.stats.items(): - if func in self.stats: - old_func_stat = self.stats[func] - else: - old_func_stat = (0, 0, 0, 0, {},) - self.stats[func] = add_func_stats(old_func_stat, stat) + for func, stat in item.stats.items(): + if func in self.stats: + old_func_stat = self.stats[func] + else: + old_func_stat = (0, 0, 0, 0, {},) + self.stats[func] = add_func_stats(old_func_stat, stat) return self def dump_stats(self, filename): diff --git a/Lib/test/pstats.pck b/Lib/test/pstats.pck new file mode 100644 index 00000000000..c48ccb73a9f Binary files /dev/null and b/Lib/test/pstats.pck differ diff --git a/Lib/test/test_pstats.py b/Lib/test/test_pstats.py index 7c28465c3b7..9ebeebbfee6 100644 --- a/Lib/test/test_pstats.py +++ b/Lib/test/test_pstats.py @@ -1,5 +1,6 @@ import unittest from test import support +from io import StringIO import pstats @@ -8,8 +9,8 @@ class AddCallersTestCase(unittest.TestCase): """Tests for pstats.add_callers helper.""" def test_combine_results(self): - """pstats.add_callers should combine the call results of both target - and source by adding the call time. See issue1269.""" + # pstats.add_callers should combine the call results of both target + # and source by adding the call time. See issue1269. # new format: used by the cProfile module target = {"a": (1, 2, 3, 4)} source = {"a": (1, 2, 3, 4), "b": (5, 6, 7, 8)} @@ -22,9 +23,21 @@ class AddCallersTestCase(unittest.TestCase): self.assertEqual(new_callers, {'a': 2, 'b': 5}) +class StatsTestCase(unittest.TestCase): + def setUp(self): + stats_file = support.findfile('pstats.pck') + self.stats = pstats.Stats(stats_file) + + def test_add(self): + stream = StringIO() + stats = pstats.Stats(stream=stream) + stats.add(self.stats, self.stats) + + def test_main(): support.run_unittest( - AddCallersTestCase + AddCallersTestCase, + StatsTestCase, ) diff --git a/Misc/NEWS b/Misc/NEWS index a93cad94bc2..f5a7b0285f1 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -43,6 +43,8 @@ Core and Builtins Library ------- +- Issue #10166: Avoid recursion in pstats Stats.add() for many stats items. + - Issue #10163: Skip unreadable registry keys during mimetypes initialization.