Heron의 방법으로 제곱근을 구하는 수치해석 로직을 통해 Iterator와 Generator의 쓰임에 대해 설명하는 예제를 준비해보았다. 아래 코드는 [1] 에서 참조했다 (책에 존재하던 코드의 오류는 필자가 본 페이지와 같이 수정해서 원 저자의 레포에 커밋하고 머지되었다).
아래 sqrt1 함수는 제곱근을 구할 대상이 될 숫자인 a를 입력받아, a의 제곱근에 대한 추정치를 Heron의 알고리즘으로 반복해서 구하게 된다. 동 알고리즘은 매 횟수마다 이전보다 더 정밀한 추정치를 구하게 되며, 이전 단계에서 구한 추정치와 이번 단계에서 구한 추정치의 차이가 0.001보다 적게 되면 반복을 멈추고 결과를 돌려주며 종료된다.
def sqrt1(a: float) -> float:
x = a / 2 # initial guess
x_n = a
while abs(x_n - x) > 0.001:
x = x_n
x_n = (x + (a / x)) / 2
return x_n
아래 sqrt2는 위에서 0.001이라고 정했던 threshold를 함수 사용자로부터 입력받을 수 있는 점만 다르다.
def sqrt2(a: float, threshold: float) -> float:
x = a / 2 # initial guess
x_n = a
while abs(x_n - x) > threshold:
x = x_n
x_n = (x + (a / x)) / 2
return x_n
아래 함수는 특정 threshold에 구애받지 않고 더 정밀한 추정치를 구하는 과정이 무한히 반복되며, 함수의 상태가 저장되었다가 함수를 외부에서 호출할 때 마다 추정치를 뱉어내게 된다. 함수의 끝에 return x 대신 위치한 yield x 가 바로 Heron의 알고리즘에 의해 업데이트 된 추정치를 생성(generate)해 함수 외부에 전달하는 역할을 한다고 볼 수 있다. sqrt3 함수의 첫 줄에 나타난 type hints가 그런 내용을 암시하고 있다. def sqrt3(a: float)
는 sqrt3
이라는 함수를 정의하되, 입력값(argument)으로 float type의 a 변수를 받는다는 뜻이며, 그 뒷 부분인 -> Iterator[float]
는 typing 패키지를 통해 정의되는 이터레이터로서 float 타입을 출력값(output)으로 낸다는 뜻이다. 이 type hints는 파이썬과 같은 동적 프로그래밍 언어에서는 필수적인 부분이 아니지만 코드를 읽는 사람을 위해서 코드의 일부분으로 포함되어 타입에 대한 가이드를 해주는 부분이다.
코드의 실행부에서는 itertools.islice(sqrt3(n), 10)
를 통해 sqrt3(n=25) 라는 호출 가능한 객체(Callables)와 그 객체를 몇 번 호출할지(여기서는 10번)를 지정해 실행한 결과를 얻는다. 바로 이것이 sqrt3을 25에 대해 실행해서 제곱근을 얻기 위해 sqrt3(25)를 10번 호출한 결과이다. 이 결과는 itertools라는 이터레이터 생성자(generator)를 호출해 얻은 것이므로 우리가 list, tuple 처럼 쉽게 읽을 수 있는 값의 형태는 아니다. 이를 list로 변환해 보면 아래 코드 맨 아래의 주석부분에 나타난 것처럼 해당 함수를 10회 실행한 결과를 차례로 볼 수 있다.
import itertools
from typing import Generic, TypeVar, Iterator, Tuple
A = TypeVar("A")
def sqrt3(a: float) -> Iterator[float]:
x = a / 2 # initial guess
while True:
x = (x + (a / x)) / 2
yield x
n = 25
iterations = list(itertools.islice(sqrt3(n), 10))
'''
[7.25,
5.349137931034482,
5.011394106532552,
5.000012953048684,
5.000000000016778,
5.0,
5.0,
5.0,
5.0,
5.0]
'''
만일, 위에 작성한 sqrt3 함수에 대해 threshold를 주고 싶으면 해당 함수를 변경하지 않고 그 밖의 함수를 통해 제어할 수 있다. 아래의 converge
함수를 쓰면 된다. 이터레이터인 values와 threshold를 해당 함수에 주면, values를 차례로 만들어 이동하면서 앞에서부터 두 값을 비교해서 그 값이 threshold보다 작으면 yield를 더 이상 하지 않는다 (pairwise는 python 최근 버전이 아니어서 itertools.pairwise가 없는 경우 아래처럼 구현해서 쓸 수 있다). 그러면 results에 해당 함수의 실행 결과를 담은 다음, 그 이터레이터를 islice로 잘라서(최대 한도 10000개 까지만) 쓴다. 함수가 각각의 기능을 구현하면, 이를 필요한 기능을 덧붙여 재사용 할 수 있는 것이다.
def converge(values: Iterator[float], threshold: float) -> Iterator[float]:
for a, b in itertools.pairwise(values):
yield a
if abs(a - b) < threshold:
break
def pairwise(values: Iterator[A]) -> Iterator[Tuple[A, A]]:
a = next(values, None)
if a is None:
return
for b in values:
yield (a, b)
a = b
results = converge(sqrt3(n), 0.001)
capped_results = itertools.islice(results, 10000)
list(capped_results)
# [7.25, 5.349137931034482, 5.011394106532552, 5.000012953048684]
참고문헌: [1] Ashwin Rao, Tikhon Jelvis, "Foundations of Reinforcement Learning with Applications in Finance", 2023, Ch2
'IT in General > Python' 카테고리의 다른 글
Immutable vs mutable (0) | 2023.10.22 |
---|---|
any와 all, 그리고 return (0) | 2022.03.25 |
데코레이터와 피보나치 (0) | 2022.03.03 |
itertools.islice, tee, groupby, product (0) | 2022.03.03 |
functools.reduce (0) | 2022.03.02 |