알고리즘/Graph

그래프 알고리즘 코드 정리 [3: 최소신장트리(MST)]

Algorithmus 2022. 11. 28. 01:01

최소신장트리(MST)

각 간선(edge)의 weight가 1이나 0이 아닌 경우에 모든 노드를 한 트리로 묶은 것(spanning tree)의 weight의 합이 최소가 되는 것을 MST (minimum spanning tree) 라 한다. 이를 구하는 방법은 union find를 쓰는 Kruskal과, Dijkstra의 변형인 Prim이 있는데, 통상 CP에서는 전자를 많이 쓴다고 한다.

UnionFind

Kruskal을 쓰려면 union find 자료구조를 정의해서 사용해야 한다. 다음 코드는 path compression과 union-by-rank가 모두 적용된 클래스이다.

class UnionFind:
	def __init__(self, N: int) -> None:  # N: number of vertices
		self.p = [-1 for i in range(N)]  # parent of vertex i
		self.rank = [0 for _ in range(N)]
		self.setsize = [1 for _ in range(N)]
		self.numsets = N

	# finds the "representative" node in a's component
	def find(self, i: int) -> int:
		if self.p[i] == -1: return i  # was set -1 to avoid self-loop
		self.p[i] = self.find(self.p[i])  # path compression
		return self.p[i]

	# returns whether two nodes are in the same connected component
	def is_same_set(self, x: int, y: int) -> bool:
		return self.find(x) == self.find(y)

	def num_disjoint_sets(self):
		return self.numsets

	def size_of_set(self, i):
		return self.setsize[self.find(i)]

	# returns whether the merge changed connectivity
	def union(self, i: int, j: int) -> bool:
		if i == j or self.is_same_set(i, j): return False
		x, y = self.find(i), self.find(j)
		# to ensure size(x) < size(y)
		if self.rank[x] >= self.rank[y]: x, y = y, x
		self.p[x] = y  # union-by-rank: maintain max height of trees
		self.rank[y] += 1  # was: += rank[x]
		self.setsize[y] += self.setsize[x]
		self.numsets -= 1
		return True

위 클래스를 이용하여 시험해보는 예제이다.

>>> uf = UnionFind(5)
>>> uf
<__main__.UnionFind object at 0x000001F78C018310>
>>> uf.num_disjoint_sets()
5
>>> uf.union(0, 1)
True
>>> uf.num_disjoint_sets()
4
>>> uf.union(2, 3)
True
>>> uf.num_disjoint_sets()
3
>>> uf.union(4, 3)
True
>>> uf.num_disjoint_sets()
2
>>> uf.is_same_set(0, 3)
False
>>> uf.is_same_set(4, 3)
True
>>> for i in range(5):
...     print(f"node {i}'s set is {uf.find(i)} with size {uf.size_of_set(i)}")
...
node 0's set is 0 with size 2
node 1's set is 0 with size 2
node 2's set is 2 with size 3
node 3's set is 2 with size 3
node 4's set is 2 with size 3

위의 경우에, union을 할 때 마다 합병하는 tree에 대해 몇 개의 tree가 지금까지 합병되었는지 그 숫자를 나타내는 rank를 사용했지만, 이를 달리 설정하면 height 등으로 변형할 수 있다.

Kruskal

Edge list {w, u, v} 형태로 자료를 저장해, w로 정렬 후, tree에 추가시 사이클이 안생기면서 가장 w가 작은 edge부터 v-1개를 추가한다(v개 node에 대해 cycle이 안생기면서 모든 점을 연결하는 spanning tree의 edge 수는 v-1개임). 여기서 사이클이 안생기는지 여부를 검증하기 위해 union find 클래스의 is_same_set을 활용한다. 이를 설명하기 위해 아래 예를 보자.

0 - 1 - 2           0 - 1 - 2
     \        =>         \ /
4     3             4     3
<no cycle>           <cycle>

노드 0, 1, 2, 3, 4가 있을때, 0, 1, 2, 3은 이미 한 set에 속한 상태이다. 이 때 edge (2, 3) 을 추가해도 될지 여부를 고민해 보면, 안되는 이유를 두 가지 설명할 수 있다. 첫째, edge (2, 3)을 추가하면 graph에 cycle이 생겨버려서 더 이상 tree가 아니게 된다 (tree는 DAG인 graph). 그리고 이와 동치인 둘째 이유로, node 2와 3은 이미 같은 set에 속하였다는 설명이 가능하다. 왜 둘은 동치인가. cycle이 생긴다는 것은 예를 들어 1, 2, 3을 보면 1에서 2로 접근하는 방법이 적어도 두 가지 있다는 것이다 (12, 132가 가능). 같은 set에 속한 노드끼리는 이미 서로 트리를 통해 접근이 가능한 상태이며 이 말은 그 set 내의 원소들 간에 연결을 추가로 만들 경우 다른 접근경로가 생긴다는 것으로서, 두 경로를 합치면 cycle과 같다. 이를 전형적 MST 문제에 적용해 구현해보면 아래와 같다 (문제: https://onlinejudge.org/external/116/11631.pdf)

import sys
while True:
    m, n = map(int, sys.stdin.readline().strip().split())
    if m == 0 and n == 0:
        break
    EL = []
    total_w = 0
    for _ in range(n):  # roads
        x, y, w = map(int, sys.stdin.readline().strip().split())
        total_w += w
        EL.append((w, x, y))
    EL.sort()
    uf = UnionFind(m)
    mst_weight = 0
    num_taken = 0
    for w, i, j in EL:
        if uf.is_same_set(i, j): continue
        uf.union(i, j)
        mst_weight += w
        num_taken += 1
        if num_taken == m-1: break
    ans = total_w - mst_weight
    print(ans)

Prim

Kruskal이나 Prim은 아무거나 무차별적으로 선택하면 되는 것처럼 알고리즘 책에는 나타나 있다. 왜냐면 복잡도가 O(ElogV) 로 같기 때문이다. 하지만 그것은 upper-bound일 뿐, 만일 Python < 3.11 처럼 느린 언어를 사용하여 채점되는 코테사이트의 경우에는 특정 알고리즘으로 할 경우 TLE가 나게 될 수 있다. 틀린 게 없는데 시간은 이미 많이 갔고, TLE가 나면 정신이 아득해진다. 그때 되어 c++로 다시 짜기도 어려우니 그렇다면 애초에 진입시부터 뭘 선택할지 고민을 해야 하는 상황이 발생한다. 문제의 본질과 알고리즘의 특성을 고려하면 시행착오를 거치지 않고도 더 나은 알고리즘을 선택 가능하다. "모두 다 해보고 더 좋은 걸 고르자"는 여유를 부릴 수 없는 순간에는 선험적 지식이 유용하다.
Kruskal은 edge를 weight 오름차순으로 정렬해서 가장 작은 weight를 갖는 것부터 tree에 추가하되, 추가하더라도 tree의 속성인 no cycle을 깨지 않는 edge만 추가하기 위해 검증과정을 거친다. 같은 그룹인지를 O(lg)만에 알기 위해, edge 추가시마다 각 node의 parent를 모두 업데이트해 두는 path compression을 한다. 그런데 dense graph의 경우는 cycle 위험을 검증해야 할 edge가 많아서 TLE가 날 수 있는 것 같다.
Prim을 보면 임의의 시작점으로부터 인접한 edge중에서 가장 weight가 작은 것(min)을 바로 선택하도록 heap 구조를 활용해 edge를 저장해둔다. 인접한 edge 중에서 사이클이 발생하지 않는지 여부는 이미 방문했는지를 나타내는 array를 통해 검증하기 때문에 Kruskal 대비 비용이 적게 든다. 코드를 보면 다음과 같다. (문제: http://uva.onlinejudge.org/external/112/11228.pdf)

import sys, heapq

cases = int(input())
for t in range(cases):
    n, r = map(int, sys.stdin.readline().strip().split())
    # n: the number of cities that comprise Graphland
    # r: threshold value to determine if two cities are in the same state
    nodes = []
    for i in range(n):
        x, y = map(int, sys.stdin.readline().strip().split())
        nodes.append((x, y))
    AL = [[] for _ in range(n)]
    for i in range(n):
        x1, y1 = nodes[i]
        for j in range(n):
            if i == j: continue
            x2, y2 = nodes[j]
            dist = ((x1-x2)**2 + (y1-y2)**2)**.5  # pythagorian
            AL[i].append((dist, j))
            #AL[j].append((dist, i))   already i, j has all reversed case
    pq = []
    taken = [False for _ in range(n)]
    
    #heapq.heappush(pq, (0, 0))  # must not include this, as process(0) will add first adjacent edge
    def process(u):
        taken[u] = True
        for w, v in AL[u]:
            if not taken[v]:
                heapq.heappush(pq, (w, v))
    
    process(0)
    road, rail = 0, 0
    num_taken, rail_taken = 0, 0
    while pq:
        w, u = heapq.heappop(pq)
        if taken[u]: continue
        if w <= r:
            road += w
        else:
            rail += w
            rail_taken += 1
        process(u)
        num_taken += 1
        if num_taken== n-1:
            break
    print(f"Case #{t+1}: {rail_taken+1} {round(road)} {round(rail)}")

그렇다면, Kruskal에서도 Prim처럼 visited[]를 통해 최적화 할 수는 없을까? 통상 visited는 노드에 대한 것이다. Prim은 인접 노드에 대한 탐색인 반면, Kruskal은 edge를 찾아 추가하는 것이므로 만일 기존에 추가한 edge의 양 끝 노드에 대해 방문했다고 표시하고 이를 나중에 추가시 제외해 버리면 틀리게 된다. 예를 들어, Kruskal에서 한 개의 edge를 추가한 상태에서, 다음으로 가벼운 weight을 갖는 edge를 추가하려는데 그게 기존에 추가했던 edge의 끝점 중 하나와 연결된 edge라면, node 방문 여부만을 가지고 이를 추가해도 될지를 결정하는 것을 부적절하다. edge 추가시는 꼭 사이클 여부를 점검해야 한다. 이를 UnionFind에서는 각 노드가 속한 그룹 관련 정보를 union 시마다 최신 상태로 업데이트해서 가지고 있기 때문에 신속히(O(lg), Path Compression시) 그 정보를 알 수 있기 때문에 쓰는 것이다.
참고로 사이클 탐지 알고리즘 중에서 UnionFind는 무방향 그래프의 경우에, floyd cycle이나 DFS/BFS는 방향 그래프에 사용할 수 있다.

반응형