Python 标准库中非常有用的装饰器

开发 后端
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields root = [] # root of the circular doubly linked list root[:] = [root, root, None, None] # initialize by pointing to self

[[432227]]

众所周知,Python 语言灵活、简洁,对程序员友好,但在性能上有点不太令人满意,这一点通过一个递归的求斐波那契额函数就可以说明:

def fib(n): 
    if n <= 1: 
        return n 
    return fib(n - 1) + fib(n - 2) 
  • 1.
  • 2.
  • 3.
  • 4.

在我的 MBP 上计算 fib(40) 花费了 33 秒:

import time 
 
def main(): 
    start = time.time() 
    result = fib(40) 
    end = time.time() 
    cost = end - start 
    print(f"{result = } {cost = :.4f}"
 
if __name__ == '__main__'
    main() 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.

但是,假如使用标准库中的这个装饰器,那结果完全不一样

from functools import lru_cache 
 
@lru_cache 
def fib(n): 
    if n <= 1: 
        return n 
    return fib(n - 1) + fib(n - 2) 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.

这次的结果是 0 秒,你没看错,我保留了 4 位小数,后面的忽略了。

提升了多少倍?我已经计算不出来了。

为什么 lru_cache 装饰器这么牛逼,它到底做了什么事情?今天就来聊一聊这个最有用的装饰器。

如果看过计算机操作系统的话,你对 LRU 一定不会陌生,这就是著名的最近最久未使用缓存淘汰算法。

而 lru_cache 就是这个算法的具体实现。(这个算法可是面试经常考的哦,有的面试官要求现场手写代码)

现在,我们来看一个 lru_cache 的源代码,其中的英文注释,我已经为你翻译为中文:

def lru_cache(maxsize=128, typed=False): 
    """LRU 缓存装饰器 
 
    如果 *maxsize* 是 None, 将不会淘汰缓存,缓存大小也不做限制 
 
    如果 *typed* 是 True, 不同类型的参数将独立做缓存,比如 f(3.0) and f(3) 将认为是不同的函数调用而缓存在两个缓存节点上。 
 
    函数的参数必须可以被 hash 
 
    查看缓存信息使用的是命名元组 (hits, misses, maxsize, currsize) 
    查看缓存信息:user_func.cache_info().  清理缓存信息:user_func.cache_clear(). 
 
    LRU 算法:  http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU) 
 
    ""
 
    # lru_cache 的内部实现是线程安全的 
 
    if isinstance(maxsize, int): 
        # 负数转换为 0  
        if maxsize < 0: 
            maxsize = 0 
    elif callable(maxsize) and isinstance(typed, bool): 
        #如果被装饰的函数(user_function)直接通过 maxsize 参数传入  
        user_function, maxsize = maxsize, 128 
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) 
        return update_wrapper(wrapper, user_function) 
    elif maxsize is not None: 
        raise TypeError( 
            'Expected first argument to be an integer, a callable, or None'
 
    def decorating_function(user_function): 
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) 
        return update_wrapper(wrapper, user_function) 
 
    return decorating_function 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.

这里面有两个参数,一个是 maxsize,表示缓存的大小,当传入负数时,自动设置为 0,如果不传入 maxsize,或者设置为 None,表示缓存没有大小限制,此时没有缓存淘汰。还有一个是 type,当 type 传入 True 时,不同的参数类型会当作不同的 key 存到缓存当中。

接下来,lru_cache 的核心在这个函数上 _lru_cache_wrapper,建议有感情的阅读、背诵并默写。我们来看下它的源代码

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo): 
    # 所有 lru cache 实例共享的常量: 
    sentinel = object()          # 用来表示缓存未命中的唯一对象 
    make_key = _make_key         # build a key from the function arguments 
    PREV, NEXTKEY, RESULT = 0, 1, 2, 3   # names for the link fields 
 
    cache = {} 
    hits = misses = 0 
    full = False 
    cache_get = cache.get    # 绑定函数来获取缓存中 key 的值 
    cache_len = cache.__len__  # 绑定函数获取缓存大小 
    lock = RLock()           # 因为链表上的更新是线程不安全的 
    root = []                # 循环双向链表的根节点 
    root[:] = [root, root, None, None]     # 初始化根节点的前后指针都指向它自己 
 
    if maxsize == 0: 
 
        def wrapper(*args, **kwds): 
            # 没有缓存,仅更新统计信息 
            nonlocal misses 
            misses += 1 
            result = user_function(*args, **kwds) 
            return result 
 
    elif maxsize is None: 
 
        def wrapper(*args, **kwds): 
            # 仅仅排序,不考虑排序和缓存大小限制 
            nonlocal hits, misses 
            key = make_key(args, kwds, typed) 
            result = cache_get(key, sentinel) 
            if result is not sentinel: 
                hits += 1 
                return result 
            misses += 1 
            result = user_function(*args, **kwds) 
            cache[key] = result 
            return result 
 
    else
 
        def wrapper(*args, **kwds): 
            # 大小有限制,并跟踪最近使用的缓存 
            nonlocal root, hits, misses, full 
            key = make_key(args, kwds, typed) 
            with lock: 
                link = cache_get(key
                if link is not None: 
                    # 缓存命中,将命中的缓存移动到循环双向链表的头部 
                    link_prev, link_next, _key, result = link 
                    link_prev[NEXT] = link_next 
                    link_next[PREV] = link_prev 
                    last = root[PREV] 
                    last[NEXT] = root[PREV] = link 
                    link[PREV] = last 
                    link[NEXT] = root 
                    hits += 1 
                    return result 
                misses += 1 
            result = user_function(*args, **kwds) 
            with lock: 
                if key in cache: 
                    # 走到这里说明 key 已经放在了缓存,且锁已经释放了,链表已经更新了,这里什么也不需要做了,最后只需要返回计算的结果就可以了。 
                    pass 
                elif full
                    # 如果缓存满了, 使用最老的根节点来存储新节点就可以了,链表上不需要删除(是不是很聪明) 
                    oldroot = root 
                    oldroot[KEY] = key 
                    oldroot[RESULT] = result 
                    root = oldroot[NEXT
                    oldkey = root[KEY
                    oldresult = root[RESULT] 
                    root[KEY] = root[RESULT] = None 
                     
                    # 最后,我们需要从缓存中清除这个 key,因为它已经无效了。 
                    del cache[oldkey] 
                    # 新值放入缓存 
                    cache[key] = oldroot 
                else
                    # 如果没有满,将新的结果放入循环双向链表的头部 
                    last = root[PREV] 
                    link = [last, root, key, result] 
                    last[NEXT] = root[PREV] = cache[key] = link 
                    # 使用 cache_len 绑定方法而不是 len() 函数,后者可能会被包装在 lru_cache 本身中 
                    full = (cache_len() >= maxsize) 
            return result 
 
    def cache_info(): 
        """报告缓存统计信息""" 
        with lock: 
            return _CacheInfo(hits, misses, maxsize, cache_len()) 
 
    def cache_clear(): 
        """清理缓存信息""" 
        nonlocal hits, misses, full 
        with lock: 
            cache.clear() 
            root[:] = [root, root, None, None] 
            hits = misses = 0 
            full = False 
 
    wrapper.cache_info = cache_info 
    wrapper.cache_clear = cache_clear 
    return wrapper 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
  • 88.
  • 89.
  • 90.
  • 91.
  • 92.
  • 93.
  • 94.
  • 95.
  • 96.
  • 97.
  • 98.
  • 99.
  • 100.
  • 101.
  • 102.
  • 103.
  • 104.

如果我写的注释你都看明白了,那也不用看我下面的废话了,如果还有点不太明白,我啰嗦几句,也许你就明白了。

第一、所谓缓存,用的仍然是内存,为了快速存取,用的就是一个 hash 表,也就是 Python 的字典,都是在内存里的操作。

cache = {} 
  • 1.

第二、如果 maxsize == 0,就相当于没有使用缓存,每调用一次,未命中数就 + 1,代码逻辑是这样的:

def wrapper(*args, **kwds): 
    nonlocal misses 
    misses += 1 # 未命中数 
    result = user_function(*args, **kwds) 
    return result 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.

第三、如果 maxsize == None,相当于缓存无限制,也就不需要考虑淘汰,这个实现非常简单,我们直接在函数中用一个字典就可以实现,比如说:

cache = {} 
def fib(n): 
 
    if n in cache: 
        return cache[n] 
 
    if n <= 1: 
        return n 
    result = fib(n - 1) + fib(n - 2) 
    cache[n] = result 
    return result 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.

运行时间:

理解了这一点,在装饰器中,这段逻辑就不难看懂:

def wrapper(*args, **kwds): 
    nonlocal hits, misses 
    key = make_key(args, kwds, typed) 
    result = cache_get(key, sentinel) 
    if result is not sentinel: 
        hits += 1 
        return result 
    misses += 1 
    result = user_function(*args, **kwds) 
    cache[key] = result 
    return result 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.

第四、真正的缓存淘汰算法。

为了实现缓存(键值对)的淘汰,我们需要对缓存按时间进行排序,这就需要用到链表,链表的头部是最新插入的,尾部是最老插入的,当缓存数量已经达到最大值时,我们删除最久未使用的链尾节点,为了不删除链尾,我们可以使用循环链表,当缓存满了,直接更新链尾节点赋值为新节点,并把它做为新的链头就可以了。

当缓存命中时,我们需要把这个节点移动到链表的头部,保证链表的头部是最近经常使用的,为了移动方便,我们需要双向链表。

双向循环链表在 Python 中实现,可以简单的这么写:

PREV, NEXTKEY, RESULT = 0, 1, 2, 3   # names for the link fields 
root = []                # root of the circular doubly linked list 
root[:] = [root, root, None, None]     # initialize by pointing to self 
  • 1.
  • 2.
  • 3.

可能有些朋友看不懂最后那行代码:root[:] = [root, root, None, None],画个图你就理解了:

这些箭头指向的都是节点的内存地址,随着节点的增多,就是这个样子的:

对比这个图,再看源代码,就很容易看懂了。尤其是这块的代码逻辑,是面试常考的重点,如果你能手写出这样线程安全的 LRU 缓存淘汰算法,那无疑是非常优秀的。

其他 LRU 算法的实现

其他关于 LRU 算法的实现,我自己写了两个,可以看这里:

LRU 缓存淘汰算法-双链表+hash 表[1]

LRU 缓存淘汰算法-Python 有序字典[2]

最后的话

装饰器 lru_cache 的作用就是把函数的计算机结果保存下来,下次用的时候可以直接从 hash 表中取出,避免重复计算从而提升效率,简单点的,直接在函数中使用个字典就搞定了,复杂点的,请看 lru_cache 的代码实现。另一方面,递归函数慢的一个主要原因就是重复计算。

Python 标准库的源码,是学习编程最有营养的原料,当你有好奇心时,不妨去窥探一下源码,相信你有定会有新的收获。今天的分享就到这里,如果有收获的话,请点赞、在看、转发、关注,感谢你的支持。

参考资料

[1]

LRU 缓存淘汰算法-双链表+hash 表: https://github.com/somenzz/geekbang/blob/master/algorthms/lru_use_link_table.py

[2]

 

LRU 缓存淘汰算法-Python 有序字典: https://github.com/somenzz/geekbang/blob/master/algorthms/lru_use_ordered_dict.py

 

责任编辑:武晓燕 来源: Python七号
相关推荐

2020-10-29 10:00:55

Python函数文件

2021-11-30 08:26:59

C++设计模式

2009-03-24 14:23:59

PHP类库PHP开发PHP

2012-05-25 14:20:08

JavaScript

2025-02-26 11:05:03

2013-06-14 14:57:09

Java基础代码

2012-04-17 09:44:08

JavaScript

2022-06-27 19:01:04

Python应用程序数据

2010-07-30 09:07:12

PHP函数

2011-07-07 17:16:43

PHP

2017-08-02 13:32:18

编程Java程序片段

2023-02-19 15:22:22

React技巧

2011-04-06 14:08:14

jQuery

2020-08-27 10:10:46

NodeJs 软件Express

2022-09-02 23:08:04

JavaScript技巧开发

2023-06-13 15:15:02

JavaScript前端编程语言

2018-08-03 10:02:05

Linux命令

2021-08-17 10:34:19

Python数据科学机器学习

2013-11-05 10:03:22

Eclipse功能

2013-08-21 10:31:22

HTML5工具
点赞
收藏

51CTO技术栈公众号