前言
Python以其语法简单、生态好而著名,但同时,Python也是出了名的慢。
这个帖子主要记录了某些场景下优化Python性能的小技巧,虽然说加速Python给人一种“矮子里面拔将军”的感觉。。
从一道题目说起
笔者在做最大二叉树时发现,令笔者郁闷的是两种不同的写法的性能相差1倍之多:
from typing import List
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None
class Solution:
    # 416ms
    def constructMaximumBinaryTree2(self, nums: List[int]) -> TreeNode:
        def arr2node(nums:List[int], left:int, right:int):
            if left < right:
                p = max(nums[left:right])
                index = nums.index(p)
                r = TreeNode(p)
                r.left = arr2node(nums, left, index) if left < index else None
                r.right = arr2node(nums, index+1, right) if index + 1 < right else None
                return r
            else:
                return None
        return arr2node(nums, 0, len(nums))
    # 208ms
    def constructMaximumBinaryTree(self, nums: List[int]) -> TreeNode:
        if nums == []:
            return None
        p = max(nums)
        index = nums.index(p)
        r = TreeNode(p)
        r.left = self.constructMaximumBinaryTree(nums[0:index]) if index > 0 else None
        r.right = self.constructMaximumBinaryTree(nums[index+1:]) if index + 1 < len(nums) else None
        return r
在扇面的代码中,前者是笔者写的,后者是他人的解法, 两者的时间复杂度均为O(n^2)。究竟是什么原因导致两者差距巨大?
用cProfile来检查性能
起初作者猜测: leetcode的检测有数十个测试样例,可能是微小的差别累积导致了明显的差距。笔者脑子闪过一些偏门知识:例如,List在元素少于32个时会直接传递值而不是引用(大概),有可能就是这个复制过程导致的。
笔者也多次执行同样的程序, 评测机给出的结果都稳定在400ms左右。
尽管笔者有诸多猜测,但是猜测只能提供思路,不能知道真正的瓶颈所在。
于是笔者用了cProfile这个库来进行性能测试:
from cProfile import Profile
LENS = 10000
def test_1():
    lis = get_arr(lens=LENS)
    s = Solution()
    _ = s.constructMaximumBinaryTree(lis)
def test_2():
    lis = get_arr(lens=LENS)
    s = Solution()
    _ = s.constructMaximumBinaryTree2(lis)
def main():
    prof = Profile()
    prof.runcall(test_1)
    prof.print_stats()
    print("-" * 20)
    prof2 = Profile()
    prof2.runcall(test_2)
    prof2.print_stats()
main()
为了让两者差异更加明显, 特意选了LENS = 10000。
运行, 有如下输出:
         84488 function calls (74489 primitive calls) in 0.044 seconds
   Ordered by: standard name
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  10000/1    0.016    0.000    0.033    0.033 654.py:31(constructMaximumBinaryTree)
        1    0.000    0.000    0.011    0.011 654.py:41(get_arr)
        1    0.001    0.001    0.001    0.001 654.py:42(<listcomp>)
        1    0.000    0.000    0.044    0.044 654.py:50(test_1)
    10000    0.006    0.000    0.006    0.000 654.py:8(__init__)
     9999    0.005    0.000    0.007    0.000 random.py:222(_randbelow)
        1    0.003    0.003    0.011    0.011 random.py:260(shuffle)
    10001    0.001    0.000    0.001    0.000 {built-in method builtins.len}
    10000    0.007    0.000    0.007    0.000 {built-in method builtins.max}
     9999    0.001    0.000    0.001    0.000 {method 'bit_length' of 'int' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
    14484    0.002    0.000    0.002    0.000 {method 'getrandbits' of '_random.Random' objects}
    10000    0.003    0.000    0.003    0.000 {method 'index' of 'list' objects}
--------------------
         74693 function calls (64694 primitive calls) in 1.012 seconds
   Ordered by: standard name
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.000    1.000 654.py:16(constructMaximumBinaryTree2)
  10000/1    0.021    0.000    1.000    1.000 654.py:17(arr2node)
        1    0.000    0.000    0.012    0.012 654.py:41(get_arr)
        1    0.001    0.001    0.001    0.001 654.py:42(<listcomp>)
        1    0.000    0.000    1.012    1.012 654.py:55(test_2)
    10000    0.005    0.000    0.005    0.000 654.py:8(__init__)
     9999    0.005    0.000    0.008    0.000 random.py:222(_randbelow)
        1    0.004    0.004    0.011    0.011 random.py:260(shuffle)
        1    0.000    0.000    0.000    0.000 typing.py:1095(__hash__)
        1    0.000    0.000    0.000    0.000 typing.py:676(inner)
        2    0.000    0.000    0.000    0.000 {built-in method builtins.len}
    10000    0.008    0.000    0.008    0.000 {built-in method builtins.max}
     9999    0.001    0.000    0.001    0.000 {method 'bit_length' of 'int' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
    14685    0.002    0.000    0.002    0.000 {method 'getrandbits' of '_random.Random' objects}
    10000    0.967    0.000    0.967    0.000 {method 'index' of 'list' objects}
如上图,笔者将目光聚焦到最后一行,惊讶地发现差别居然是.index()这个方法导致的!在第一个实现中,调用10000次index总用时0.003s, 而在后者这一数值达到0.967s.
如何解释这个差距? 我们返回源码,发现前者的实现中,传递的是完整的nums的引用,那么它的查询范围为一直为[0, n),也就是从头扫到尾。
而后者的实现中,每次都传递一个切片nums[a:b], 它的查询范围限制在[a, b), 随着递归车层次越深,这个范围会越小。
没想到这样一个小小的不同会导致性能产生如此大的差别。
参考资料