본문 바로가기

알고리즘/<4> The Greedy Approach

[알고리즘] 크루스칼 알고리즘, Kruskal's Algorithm | 설계 및 구현코드

  크루스칼 알고리즘은 프림 알고리즘과 달리 한 정점이 locally optimal한지 결정할 때 단순히 edge의 가중치를 기준으로 결정한다. 즉, edge의 가중치가 작으면 작을수록 locally optimal한 것이다. 가중치가 작은 것 부터, 즉 유리한 것 부터 먼저 담으려는 크루스칼 알고리즘의 의도가 참 말그대로 'Greedy' 해보인다. 하지만 크루스칼 알고리즘이 대책없이 edge들을 마구마구 담아대는 것은 아니다. 순환 경로가 발생하지 않도록 정점들의 관계를 검사하는 작업을 치른다. 이 과정에 대해 더 자세히 알아보도록 하자.

 

수도코드(high level)

 

F := ∅;  //edge의 집합 초기화
  
create disjoint subsets of V, one for each
vertex and containing only that vertex;

sort the edges in E in nondecreasing order;

While (the instance is not solved) {
  select next edge;  //selection procedure
  if(the edge connects 2 vertices in disjoint subsets) {  //feasibility check
    merge the subsets;
    add the edge to F;
  }
  
  if(all the subsets are merged)  //solution check
    the instance is solved;
}

 

문제 해결 과정

  위 수도코드를 보면 알 수 있듯 맨 처음 edge들을 가중치가 작은 것 부터 순서대로 정렬해놓은 후 edge들을 하나씩 가져간다. 즉 크기가 작은 edge부터 뽑아간다. 그 뒤 뽑아간 edge를 잇는 두 정점이 서로 다른 집합(disjoint set)에 있는지 검사한다. 이것이 바로 크루스칼 알고리즘이 edge들을 greedy하게 선택하되 순환경로가 만들어지는 것을 막기위해 매번 수행하는 검사이다. 여기서 집합은 두 정점이 연결되는 순간 그 두 정점은 같은 집합에 속하게 되는 것인데 수도코드의 맨 처음을 보면, 알고리즘 시작 전 그래프의 모든 정점들을 원소가 자기 자신 뿐인 서로소 집합에 담은 채로 시작함을 알 수 있다. 여기서 서로 다른 집합의 정점과 합치면서 집합이 커져나가는 것이다.  그 과정을 아래 그림으로 이해해보자.

 

 

 

해결 과정 예시

  위 그림을 통해 정점들이 서로 집합을 어떻게 만들어나가는지 감을 잡을 수 있었을 것이다. 그렇다면 이 집합을 크루스칼 알고리즘에서는 어떻게 구현하고, 또 정점들이 서로 같은 집합에 있는지, 다른 집합에 있는지는 어떤 방법으로 검사할 수 있을까? 이를 위해 크루스칼 알고리즘은 disjoint set 자체를 하나의 자료구조로 만들어두었다. 이 disjoint set에 대한 다양한 자료구조가 존재하는데 그 중 성능이 가장 좋은 자료구조를 하나 골라 이에 대해 설명하겠다.

 

<크루스칼 알고리즘에서 사용하는 자료구조>

  disjoint set은 initial(), find(), merge()등의 함수를 제공하는 자료구조이며 집합을 트리로 나타내고, 이를 다시 배열로 표현한다. 트리에서 자기 자신을 가리키는 것이 root가 된다.

 

① typedef index set_pointer;

  특정 키가 들어있는 배열의 인덱스를 set_pointer라고 정의한다.

typedef index set_pointer;

 

② struct nodetype { index parent; int depth; }

  트리를 이루는 노드 구조체를 정의한다. 부모의 인덱스를 가리키는 parent와 노드의 depth 값이 들어있다.

struct nodetype { index parent; int depth; }

 

③ void makeset (index i)

  자기 자신을 가리키는, 원소가 자기 하나뿐인 트리를 만든다.

void makeset (index i) { U[i].parent=i; U[i].depth=0; }

 

④ set_pointer find (index i)

  해당 노드가 속한 집합의 parent를 찾아 인덱스를 리턴한다.

set_pointer find (index i) {
  index j;
  j=i;
  while(U[j].parent!=j)
    j=U[j].parent;
  return j;
}

 

⑤ void merge (set_pointer p, set_pointer q)

  서로 다른 두 집합을 합친다. 이 때, 트리가 선형구조를 이루지 않도록 하기 위해 depth가 작은 것이 큰 것을 가리키도록 한다.

 void merge (set_pointer p, set_pointer q) {
      if ( U[p].depth == U[q].depth ) {
          U[p].depth +=1;
          U[q].parent = p;
      } else if ( Q[p].depth < U[q].depth) Q[p].parent = q;
     else U[q].parent = p;
     }

 

⑥ bool equal (set_pointer p, set_pointer q)

  부모의 인덱스가 서로 같은지 검사한다.

bool equal (set_pointer p, set_pointer q) {
    if(p==q) return TRUE;
    else return FALSE;
}

 

⑦ void initial (int n)

  makeset() 함수를 이용하여 자기 자신을 가리키며 원소가 자신 뿐인 집합들로 초기화한다.

void initial (int n) {
    index  i;
    for(i=1; i<=n; i++)
      makeset(i);
}

 

 '집합' 과 '트리' 구조에 대한 이해를 돕기위해 그림으로 예를들면 아래와 같다.

 

▷ 트리 초기화

▷ 트리 병합

      v1과 v2를 병합하고, 그 뒤 v3과 병합하는 경우라고 가정

 

 

  이 모든 과정이 아래의 수도코드에 자세히 담겨있다.

 

수도코드(low level)

 

void kruskal(int n, int m, set_of_edges E, set_of_edges& F) {
  index i, j;
  set_pointer p, q;
  edge e;
  
  sort the m edges in E by weight in nondecreasing order;
  F=;
  initial(n); //서로소로 초기화
  while(number of edges in F is less than n-1) {
    e = edges with least weight not yet considered;
    i, j= indices of verticies connected by e;
    p = find(i); //vi가 속한 집합을 찾음
    q = find(j); //vj가 속한 집합을 찾음
    if(!equal(p,q)) {  //p와 q, 즉 두 정점의 루트가 같은지 비교
      merge(p,q);  //다르면 합침
      add e to F;
    }
  }
}

 

  위 수도코드를 바탕으로 구현한 알고리즘은 다음과 같다.

 

구현 코드

 

#include <stdio.h>

typedef struct
{
	int weight;
	int v1;
	int v2;
}edge;

typedef struct
{
	int parent;
	int depth;
}universe;

#define VERTEX_NUM 5
#define EDGE_NUM 7
#define TRUE 1
#define FALSE 0


void input(int vertex1, int vertex2, int w);
void kruskal(int n, int m, edge* E, edge* F);
void initial(universe* U, int n);
void makeset(universe* U, int i);
int find(universe* U, int i);
void merge(universe* U, int p, int q);
int equal(int p, int q);
int check(int* array);

edge set_of_edges[EDGE_NUM];
int index = -1;
int f_index = -1;

int main()
{
	edge result[EDGE_NUM];

	input(1, 2, 1);
	input(1, 3, 3);
	input(2, 3, 3);
	input(2, 4, 6);
	input(3, 4, 4);
	input(3, 5, 2);
	input(4, 5, 5);

	kruskal(VERTEX_NUM, EDGE_NUM, set_of_edges, result);
	
	for(int i=0;i<=f_index;i++)
	{
		printf("%d - %d , weight: %d\n", result[i].v1, result[i].v2, result[i].weight);
	}
}

void input(int vertex1, int vertex2, int w)
{
	set_of_edges[++index].v1 = vertex1;
	set_of_edges[index].v2 = vertex2;
	set_of_edges[index].weight = w;
}

void kruskal(int n, int m, edge *E, edge *F)
{
	universe U[VERTEX_NUM+1];
	int add[VERTEX_NUM+1] = { 0, };

	//sort
	for (int i = 0; i < m - 1; i++)
	{
		for (int j = i + 1; j < m; j++)
		{
			if (E[i].weight > E[j].weight)
			{
				edge temp = E[i];
				E[i] = E[j];
				E[j] = temp;
			}
		}
	}
	//init array U
	for (int i = 1; i <= VERTEX_NUM; i++)
	{
		U[i].parent = 0;
		U[i].depth = 0;
	}

	//init array F
	for (int i = 0; i < m; i++)
	{
		F[i].v1 = 0;
		F[i].v2 = 0;
		F[i].weight = 0;
	}

	//initial
	initial(U,VERTEX_NUM);


	int index = 0;
	while (1)
	{
		int i, j, p, q;
		edge e = E[index];

		i = E[index].v1;
		j = E[index].v2;
		
		p = find(U,i);
		q = find(U,j);
		
		if (!equal(p, q))
		{
			merge(U, p, q);
			F[++f_index] = E[index];
			
		}

		index++;
		
		int cnt = 0;
		for (int i = 1; i <= EDGE_NUM; i++)
		{
			if (U[i].parent == i)
				cnt++;
		}

		if (cnt == 1)
			break;
	}

}

void initial(universe* U,int n)
{
	for (int i = 1; i <= n; i++)
		makeset(U,i);
}

void makeset(universe* U,int i)
{
	U[i].parent = i;
	U[i].depth = 0;
}

int find(universe* U,int i)
{
	int j;
	j = i;
	
	while (U[j].parent != j)
	{
		j = U[j].parent;
	}
	return j;
}

void merge(universe* U,int p, int q)
{
	if (U[p].depth == U[q].depth)
	{
		U[p].depth += 1;
		U[q].parent = p;
	}
	else if (U[p].depth < U[q].depth)
		U[p].parent = q;
	else
		U[q].parent = p;
}

int equal(int p, int q)
{
	if (p == q)
		return TRUE;
	else
		return FALSE;
}

int check(int* array)
{
	int finish = TRUE;
	for (int i = 1; i <= VERTEX_NUM; i++)
	{
		if (array[i] == FALSE)
			finish = FALSE;
	}
	
	return finish;
}