IT in General/Python

데코레이터와 피보나치

Algorithmus 2022. 3. 3. 18:02

Python의 decorator에 대한 설명으로 대표적인 것이 이름을 출력하는 예제일 것이다. 그러나 실제로 그런 일을 할 일이 별로 없기 때문에 마음에 와닿지가 않아 이해도 잘 되지 않았다. 가장 실용적인 사례로서 DP(dynamic programming)을 위한 memoization의 경우에 활용되는 @cache()라는 decorator를 피보나치 수를 구하는 예제에 적용하여 그 내부 동작을 살펴본다. 먼저 팩토리얼을 구하는 코드를 소개한다. 아래 데코레이터는 @cache로서, DP의 top-down시 많이 활용하는 lru_cache(maxsize=None)과 동일한 결과를 돌려주는 것으로 3.9에서 추가된 함수이다.

@cache
def factorial(n):
    return n * factorial(n-1) if n else 1

>>> factorial(10)      # no previously cached result, makes 11 recursive calls
3628800
>>> factorial(5)       # just looks up cached value result
120
>>> factorial(12)      # makes two new recursive calls, the other 10 are cached
479001600

# [Source] https://docs.python.org/3/library/functools.html

 

먼저 factorial(10)을 호출하면 캐시된 것이 없기 때문에 재귀호출을 하여 값을 구한다. 그러나 factorial(5)와 같이 기존에 계산된 것이 저장된(cache or memoize) 경우에는 이 값을 불러온다. 이런 과정을 자동으로 담당해 주는 것이 @cache라는 데코레이터이다.

 

다음은 3.2부터 추가된 전통적 lru_cache를 이용한 피보나치 수열 예제이다. lru_cache(None)이라고 입력해도 동일하다.

 

@lru_cache(maxsize=None)
def fib(n):
    if n < 2:
        return n
    return fib(n-1) + fib(n-2)

>>> [fib(n) for n in range(16)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]

>>> fib.cache_info()
CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)

# [Source] Same as above.

 

lru_cache 및 cache 데코레이터에 대한 설명을 직접 Python 공식 코드를 보며 하겠다. 아래 코드가 해당 decorator의 줄기이다. 일단 데코레이터는 user_function을 받아 wrapper로 감싼 다음 그 결과를 돌려준다. 여기서 wrapper는 밑에 정의된 _lru_cache_wrapper라는 은닉된 함수이다.

def lru_cache(maxsize=128, typed=False):
    if isinstance(maxsize, int):
        maxsize = 0 if maxsize < 0
    elif callable(maxsize) and isinstance(typed, bool):
        # The user_function was passed in directly via the maxsize argument
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError(
            'Expected first argument to be an integer, a callable, or None')
            
    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)

    return decorating_function

 

위 함수의 바로 밑에 정의되었으며 _로 시작하므로 캡슐화된 내용이다 (줄기를 제외하고는 pass로 생략했다). 사용자는 lru_cache를 통해서만 접근가능하므로 실제 동작을 소스를 보기전에는 알 수 없다. 여기서 핵심은 cache = {} 이다. 여기에 예를 들어 maxsize is None인 경우에, wrapper를 정의하고, 이것이 user_function에 인수가 들어가서 연산된 결과를 cache[key]로 저장한 뒤, 그 결과를 돌려주는 것이다.

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else: pass

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

try:
    from _functools import _lru_cache_wrapper
except ImportError:
    pass


################################################################################
### cache -- simplified access to the infinity cache
################################################################################

def cache(user_function, /):
    'Simple lightweight unbounded cache.  Sometimes called "memoize".'
    return lru_cache(maxsize=None)(user_function)

 

마지막에 보이는 함수 정의인 cache는 Python 3.9에서 새로 생겼는데 결국 lru_cache(None)에 다름아니다.

반응형

'IT in General > Python' 카테고리의 다른 글

Immutable vs mutable  (0) 2023.10.22
Iterator와 Generator 사용방법  (0) 2023.03.15
any와 all, 그리고 return  (0) 2022.03.25
itertools.islice, tee, groupby, product  (0) 2022.03.03
functools.reduce  (0) 2022.03.02