Single Round Match 408 Div2 1000pt

練習がてらに解いてみる。問題は大まかに説明すると「赤玉と青玉が合計で奇数個入った袋があり、あなたがランダムに1個とり、敵が青玉を1個とる。最終的にあなたが最後に取る玉が青ならあなたの勝ち。」というもの。厳密な問題文はアカウントがあれば下のリンクで見に行ける。
http://www.topcoder.com/stat?c=problem_statement&pm=9754&rd=12180

さて、とりあえず問題文に書いてあることを簡潔に実装するとこうなった。

  double getProbability(int redCount, int blueCount) {
    if(redCount == 0) return 1.0;
    if(blueCount < 2) return 0.0;
    double p = 1.0 * redCount / (redCount + blueCount);
    return p * getProbability(redCount - 1, blueCount - 1) \
           + (1 - p) * getProbability(redCount, blueCount - 2);
  }

これで手元でのテストがだいたい通る。1個failするけど、それはテストに使っているTZTesterってプラグインのバグで、浮動小数点数で返した答えが正しいかをイコールで判断しているのが原因。サーバに送ってテストすると通る。submit!


で、Practice OptionからRun System Testを選んでもっとたくさんのテストケースを食わせてみる。案の定パフォーマンスを何も考えていないコードなので{131, 3000}が来たときに2秒の時間制限に引っかかってfailした。手元のコードのテストケースを書き換えて実行してみる。うは、60秒経っても終わらないwww


まぁ、それじゃあとりあえずこのコードでは何度も同じ値で関数が呼ばれるので、計算した値をキャッシュしておくことにしよう。Pythonでパフォーマンスを考えずに書くならここはタプルをキーにした辞書を作るところだけど、pairとかやるのもばからしいし、今回引数の範囲は上限が4000だと問題文に明記してあるのでこれでいいよね。

    int key = redCount * 4001 + blueCount;

さて、そんなわけでmapの説明を見つつ書いてみる。cacheに既に値があればそれを返し、なければ計算してcacheに入れてから返す。こんな感じかな。

typedef map<int, double> MID;
class MarblesInABag {
public:
  MID cache;
  double getProbability(int redCount, int blueCount) {
    int key = redCount * 4001 + blueCount;
    MID::iterator i = cache.find(key);
    if(i != cache.end()){
      return i->second;
    }
    if(redCount == 0) return 1.0;
    if(blueCount < 2) return 0.0;
    double p = 1.0 * redCount / (redCount + blueCount);
    double result = p * getProbability(redCount - 1, blueCount - 1)
                    + (1 - p) * getProbability(redCount, blueCount - 2);
    pair<int, double> value(key, result);
    cache.insert(cache.begin(), value);
    return result;
  }
}

さて、これで60秒以上掛かっていたのが2〜3秒で済むようになった。しかし制限時間は2秒だ。これではまだ遅い。まぁ、これでも遅いってのは実は想定の範囲内なんだけど。関数を何度も呼ぶのは効率が悪そうなので、ループで値の小さいところから埋めていくように変える。

  double getProbability(int redCount, int blueCount) {
    typedef pair<int, double> PID;
    const int MAX_PLUS_1 = 4001;
    int key = redCount * MAX_PLUS_1 + blueCount;
    // when r = 0
    for(int b(1); b <= blueCount; b += 2){
      cache.insert(cache.begin(), PID(b, 1.0));
    }

    for(int r(1); r <= redCount; ++r){
      int b(1 - (r % 2));
      cache.insert(cache.begin(), PID(r * MAX_PLUS_1 + b, 0.0));
      for(b += 2; b <= blueCount; b += 2){
	double p = 1.0 * r / (r + b);
	double result = p * cache.find((r - 1) * MAX_PLUS_1 + b - 1)->second \
	  + (1 - p) * cache.find(r * MAX_PLUS_1 + b - 2)->second;
	cache.insert(cache.begin(), PID(r * MAX_PLUS_1 + b, result));
      }
    }
    return cache.find(key)->second;
  }

なに、これでも2秒超える?関数呼び出しがボトルネックだと思っていたが違うのか?!えーと、ここから何が削れるんだ?やっぱりmapに全部突っ込んだのが間違いかな。。。必要なところだけをvectorに入れる形に変えてみる。

  double getProbability(int redCount, int blueCount) {
    const int MAX_PLUS_1 = 4001;
    vector<double> cache(MAX_PLUS_1);
    vector<double> new_cache(MAX_PLUS_1);

    // when r = 0
    for(int b(1); b <= blueCount; b += 2){
      cache[b] = 1.0;
    }

    for(int r(1); r <= redCount; ++r){
      int b(1 - (r % 2));
      new_cache[b] = 0.0;
      for(b += 2; b <= blueCount; b += 2){
	double p = 1.0 * r / (r + b);
	new_cache[b] = p * cache[b - 1] + (1 - p) * new_cache[b - 2];
      }
      cache.swap(new_cache);
    }
    return cache[blueCount];
  }

ふう。だいぶ縮んだ。さっきの「2〜3秒」ってのは実はコンパイル時間も含まれていたのでそれを除くと、1.2秒から0.077秒に縮んだ。15倍。submit。うん、全部のシステムテストに通った。


今回の教訓は「mapはそんなに速くない」ということか。この程度の規模の問題だったら余裕だと思ったんだけど、想像より遅かった。まぁ、最大4000x4000個の値がつっこまれるわけだからなぁ。vectorに変えて15倍速くなったことに関してはあんまり意外ではない。完全にサイズ固定で確保してswapしつつ書き込んでいるだけだから配列を使っているも同然だし、書き込みも読み出しもO(1)だし。もっと速くても驚かない。


さて、解けたのでめでたしめでたしと言うことで帰ろうと思ったらもう終電がなくなっていたorz