最小全域木問題(クラスカル法とプリム法)

最小全域木問題を解くためのアルゴリズム「クラスカル法」「プリム法」を使ってみた.

  • 最小全域木について
  • クラスカル法
  • プリム法
  • PKUの問題
  • クラスカル法による解答
  • プリム法による解答
  • メモリ使用量と実行時間の比較

最小全域木について

まず,全域木(Spanning tree)とは連結グラフの全ての頂点とそのグラフを構成する辺の一部分のみで構成される木のこと.つまり,連結グラフから適当な辺を取り除いていき,閉路をもたない木の形にしたものが全域木となる.ここで,グラフの各辺に重みがある場合,重みの総和が最小になるように辺を選んで作った全域木のことを最小全域木(Minimum spanning tree)という.

最小全域木を求めるアルゴリズムとしては以下の二つが有名である.

  • クラスカル法 (Kruskal's algorithm)
  • プリム法 (Prim's algorithm)

いずれも貪欲法と呼ばれるアルゴリズムの一種である.

クラスカル法

クラスカル法の流れは以下の通り.

  1. グラフの各頂点がそれぞれの木に属するように,森(木の集合)Fを生成する(つまり頂点1個だけからなる木が頂点の個数だけ存在する).
  2. グラフの全ての辺を含む集合Eを生成する.
  3. Eが空集合になるまで,以下を繰り返す.
    • Eから重みが最小の辺eを取り出し,削除する.その辺eと繋がっている二つの頂点u, vが別々の木に属しているならば,辺eを森Fに加え,二つの木を連結しひとつの木にまとめる.
  4. 最終的に森Fが最小全域木となる.

二つの頂点u, vが別の木に属しているかどうかの判定にはUnion Findのアルゴリズムを使うことができる.

クラスカル法の実行イメージはこんな感じ.このアニメーションはRを使って作成した

クラスカル法の詳細については以下を参照.
クラスカル法 - Wikipedia

プリム法

プリム法の流れは以下の通り.

  1. VとEを空集合とする.
  2. グラフから任意の頂点をひとつ選び,Vに加える.
  3. Vがグラフのすべての頂点を含むまで,以下を繰り返す.
    • Vに含まれる頂点uと含まれない頂点vを結ぶ重みが最小の辺(u,v)をグラフから選び,Eに加える.そしてvをVに加える.
  4. 最終的にグラフ(V,E)が最小全域木となる.

詳しくは以下を参照.
プリム法 - Wikipedia

PKUの問題

PKU JudgeOnlineの問題集から最小全域木を求める問題を選び,クラスカル法とプリム法で解いてみた.今回解いた問題はこちら.

Problem 2031 - Building a Space Station


大雑把に言うと以下のような問題.

球の形をした各宇宙ステーション(cell)の座標と半径が与えられる.これらの宇宙ステーション間に通路を建設し,全宇宙ステーションを連結させたい.必要な通路長の総和の最小値を求めよ.

宇宙ステーションがグラフの頂点,通路がグラフの辺にあたる.ちなみに二つの宇宙ステーションA,Bについて「AB間の距離 < Aの半径 + Bの半径」が満たされる場合は,ABは既に繋がっているものとみなされる.

クラスカル法による解答

Union Find処理用のクラスDisjointSetを定義したり,通路(グラフの辺)を表す構造体corridorを作ったりといろいろ頑張った結果,ソースコードがまあまあ長くなってしまった.通路を長さ順でソートするのにvectorではなくmultiset >を使ったりすれば,構造体corridorを定義する必要がなくなるのでもう少し行数を減らせるかもしれない.その代わり実行時間は遅くなるだろうけど….

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;

// Union Find用
class DisjointSet {
private:
  vector<int> p;
  
public:
  DisjointSet(int size) : p(size, -1) {}

  bool unite_set(int x, int y) {
    x = root(x);
    y = root(y);
    
    if (x != y) {
      if (x > y) { swap(x, y); }
      p[x] += p[y];
      p[y] = x;
    }   

    return x != y;
  }

private:
  int root(int x) {
    return p[x] < 0 ? x : p[x] = root(p[x]);
  }
};

// 通路
struct corridor {
  int a, b;     // 繋がっているcell
  double len;   // 通路の長さ
  corridor(int a, int b, double len) : a(a), b(b), len(len) {}

  // ソート用
  bool operator <(const corridor& rhs) const {
    if (len != rhs.len) { return len < rhs.len; }
    if (a != rhs.a) { return a < rhs.a; }
    return b < rhs.b;
  }
};

double dist(double x1, double y1, double z1, double x2, double y2, double z2) {
  return sqrt((x1-x2)*(x1-x2) + (y1-y2)*(y1-y2) + (z1-z2)*(z1-z2));
}

int main(int argc, char const* argv[]) {
  int n;
  while (cin >> n && n) {
    double x[100], y[100], z[100], r[100];
    for (int i = 0; i < n; i++) {
      cin >> x[i] >> y[i] >> z[i] >> r[i];
    }

    // 全ての通路を長さでソートする
    vector<corridor> cor;
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < i; j++) {
        cor.push_back(corridor(i, j, max(0.0, dist(x[i], y[i], z[i], x[j], y[j], z[j]) - r[i] - r[j])));
      }
    }
    sort(cor.begin(), cor.end());

    DisjointSet ds(n);
    double ans = 0;   // 通路の長さの総和
    int count = 0;    // 選んだ通路の数

    for (vector<corridor>::iterator itr = cor.begin(); count < n - 1; itr++) {
      // aとbが別々の集合に属しているならばその通路を採用する
      if (ds.unite_set(itr->a, itr->b)) {
        ans += itr->len;
        count++;
      }
    }

    printf("%.3f\n", ans);
  }
  return 0;
}

プリム法による解答

通路を長さ順にソートするのにmultisetを使ったりしたので,クラスカル法と比べてシンプルにまとまっている.

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <cmath>
using namespace std;

double dist(double x1, double y1, double z1, double x2, double y2, double z2) {
  return sqrt((x1-x2)*(x1-x2) + (y1-y2)*(y1-y2) + (z1-z2)*(z1-z2));
}

int main(int argc, char const* argv[]) {
  int n;
  while (cin >> n && n) {
    double x[100], y[100], z[100], r[100];
    for (int i = 0; i < n; i++) {
      cin >> x[i] >> y[i] >> z[i] >> r[i];
    }

    set<int> cell;   // 今まで選んだcellの集合
    double ans = 0;  // 通路の長さの総和

    multimap<double, int> corridor;   // cellから繋がっている通路
    corridor.insert(make_pair(0.0, 0));

    while (cell.size() < n) {
      // 今まで選んだcellに繋がっている通路のうち一番短いものを選ぶ
      double len = corridor.begin()->first;
      int c = corridor.begin()->second;
      corridor.erase(corridor.begin());

      // 選んだ通路に繋がるcellが既に集合の中にある場合は非採用
      // そうでない場合は通路を採用しcellを集合に加える
      if (!cell.insert(c).second) { continue; }
      ans += len;

      // 新たに集合に加えたcellに繋がる通路を探索候補に追加
      for (int i = 0; i < n; i++) {
        if (cell.find(i) == cell.end()) {
          corridor.insert(make_pair(max(0.0, dist(x[i], y[i], z[i], x[c], y[c], z[c]) - r[i] - r[c]), i));
        }
      }
    }

    printf("%.3f\n", ans);
  }
  return 0;
}

メモリ使用量と実行時間の比較

各コードをOnline Judgeにsubmitして走らせた時の実行時間とメモリ使用量は以下のようになった.Online Judgeの際にどういう入力が与えられるのかは非公開なので不明だが,submitしたそれぞれのプログラムに対して毎回同じ入力が与えられるため,実行時間の比較は一応可能となっている.ちなみにこの問題での実行時間の上限は1000ms,メモリ使用量の上限は30000KB.

アルゴリズム名実行時間メモリ使用量
クラスカル法188ms1044KB
プリム法454ms1072KB
クラスカル法では長さについてあらかじめソートしておいた辺の配列を端から順に探索していくだけなのに対して,プリム法では辺のmultisetから繰り返し要素を追加したり取り出したりしているため,この処理がボトルネックになって実行時間が長くなってしまったものと思われる.もう少し書き方を変更すればまた違った結果が得られるかもしれない.あくまでも参考程度に.