본문 바로가기

Study/BOJ

BOJ - 최소 스패닝 트리(1197)

MST를 찾는 알고리즘 중 union-find를 이용하여 O(E*logE)에 찾을 수 있는 Kruskal algorithm을 이용하였다.

 

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
int V, E;
typedef struct graph
{
    int v1;
    int v2;
    int cost;
}graph;
vector<graph> v;
 
int getRoot(vector<int> parent, int idx)
{
    int i;
    for(i = idx; i != parent[i]; i = parent[i]);
    return i;
/*
    if (parent[idx] == idx) return idx;
    else return getRoot(parent, parent[idx]);
    */
}
 
bool find(vector<int> parent, int a, int b)
{
    a = getRoot(parent, a);
    b = getRoot(parent, b);
    if (a == b) return true;
    else return false;
}
 
void my_union(vector<int> &parent, int a, int b)
{
    a = getRoot(parent, a);
    b = getRoot(parent, b);
    if (a < b) parent[b] = a;
    else parent[a] = b;
}
 
bool comp(graph a, graph b)
{
    return a.cost < b.cost;
}
 
int main()
{
    cin.tie(NULL);
    ios_base::sync_with_stdio(false);
    cin >> V >> E;
    vector<int> parent(V + 1);
    for(int i=1; i<=V; i++)
        parent[i] = i;
    for(int i=0; i<E; i++)
    {
        graph g;
        cin >> g.v1 >> g.v2 >> g.cost;
        v.push_back(g);
    }
    sort(v.begin(), v.end(), comp);
    int sum = 0;
    for(int i=0; i<v.size(); i++)
    {
        if (!find(parent, v[i].v1, v[i].v2))
        {
            my_union(parent, v[i].v1, v[i].v2);
            sum += v[i].cost;
        }
    }
    printf("%d\n", sum);
    return (0);
}
cs

위와 같이 cost가 작은 순으로 sorting한 뒤, cost가 작은 edge부터 연결(즉, union)해주면 된다. 이 때, find 함수를 계속 호출하면서 현재 같은 set에 묶여있는지 확인해주어야 한다. (cycle check)

1
2
3
4
5
6
7
8
9
10
int getRoot(vector<int> parent, int idx)
{
    int i;
    for(i = idx; i != parent[i]; i = parent[i]);
    return i;
/*
    if (parent[idx] == idx) return idx;
    else return getRoot(parent, parent[idx]);
    */
}
cs

그런데, 주의해야 할 점이 있었다. 기존에는 getRoot를 찾을 때 재귀로 찾는 방법을 선택하였는데 시간초과가 떠버린 것이다. 

그래서 그냥 return하는 것이 아니라 parent[idx]의 값을 수정하면서 return하게 하는 path-compression을 이용하였는데도 시간초과가 떠서, 위와 같이 for문으로 구현하였더니 시간초과를 피할 수 있었다. 

 

추가)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int getRoot(vector<int> &parent, int idx)
{
    /*
    int i;
    for(i = idx; i != parent[i]; i = parent[i]);
    return i;*/
    if (parent[idx] == idx) return idx;
    else
    {
        int root = getRoot(parent, parent[idx]);
        parent[idx] = root;
        return parent[idx];
    }
}
cs

path-compression을 사용하면 훨씬 빠른데 시간이 줄지 않았던 이유는 parameter로 받을 때

vector<int> &parent로 받지 않아 함수 내에서 값을 수정하더라도 update가 안됐던 것이 문제였다.

그래서 이 방법을 사용하는 것이 훨씬 좋다! 

 

'Study > BOJ' 카테고리의 다른 글

BOJ - 네트워크 연결(1922)  (0) 2021.04.15
BOJ - 거짓말(1043)  (0) 2021.04.14
BOJ - 미네랄(2933)  (0) 2021.04.12
BOJ - 청소년 상어(19236)  (0) 2021.04.10
BOJ - 마법사 상어와 파이어스톰(20058)  (0) 2021.04.10