Mister雑記

競プロやります。

ARC066-D 「Xor Sum」 (600)

atcoder.jp

概要

正の整数 Nが与えられる。このとき、

  •  a + b \leq u
  •  a  xor  b \leq v

なる非負整数の組 (a, b)が存在し、かつ 0 \leq u, v \leq Nを満たすような (u, v)の組の個数を 10^9+7で割った余りを求めよ。

制約

  •  1 \leq N \leq 10^{18}

考察

まず、 + xorの間には以下のような性質がある。

  •  a  xor  b \geq 0
  •  (a + b) - (a  xor  b) =  (a  and  b) \times 2 \geq 0

ここから 0 \leq u \leq vという関係がわかる。よって v 0から Nの間で固定して、それぞれ uが何通りの値を取りうるかを考えることにする。


ここで、 a  xor  bを最下位bitから順に決めていくことを考える。 v = a + bの最下位bitについて場合分けすると以下の通り。

  •  vの最下位bitが1(つまり奇数)の場合

最下位bitにだけ注目すると、 (a, b) = (0, 1), (1, 0)となる。 a  xor  bの最下位bitはどちらも1で、繰り上がりは常になし。

  •  vの最下位bitが0(つまり偶数)の場合

最下位bitにだけ注目すると、 (a, b) = (0, 0), (1, 1)となる。 a  xor  bの最下位bitはどちらも0だが、繰り上がりの有無によりそこから先が異なる。


以上を踏まえると、 vを固定したときに a  xor  bの取りうる値の数は以下のような再帰関数によって求まる。

ll rec(ll v) {
    if (v == 0) return 1;

    if (v & 1) {
        // 最下位は1で確定
        // 繰り上がりはないので、それより上についてそのまま考えればいい
        return rec(v >> 1);
    } else {
        // 最下位は0で確定
        // 繰り上がりが起こるか否かで場合分け
        return rec(v >> 1) + rec((v >> 1) - 1);
    }
}


このままではメモ化しても O(N \log N)でTLEなので、 v \in \{0, 1, \cdots, M\}を一気に処理することで高速化する。

上の遷移が奇数と偶数で分かれているので、こちらも vを奇数と偶数のグループに分ける。


  • 奇数グループ

奇数グループは \{1, 3, \cdots, 2 \lfloor (M - 1) / 2 \rfloor + 1\}となる。これらは \{0, 1, \cdots, \lfloor (M - 1) / 2 \rfloor\}へと遷移する。

  • 偶数グループ

偶数グループは \{0, 2, \cdots, 2 \lfloor M / 2 \rfloor\}となる。

繰り上がり無の場合、これらは \{0, 1, \cdots, \lfloor M / 2 \rfloor\}へと遷移する。

繰り上がり有の場合は注意が必要で、 0からは遷移ができない。よって  M \lt 2のケースを弾く必要がある。 そうでない場合、 \{2, 4, \cdots, 2\lfloor M / 2 \rfloor\} \{0, 1, \cdots, \lfloor M / 2 \rfloor - 1\}へと遷移する。


以上をメモ化再帰で実装すれば、計算量 O(\log N)で解が求まる。

実装例

解法によって実装量は大きく変わると思われる。私の解法は結果的に実装量が少なくなったが、遷移の \pm 1で大きく躓いてしまった。

#include <iostream>
#include <map>

using namespace std;
using ll = long long;

const ll MOD = 1000000007;

map<ll, ll> dp;

// v=0,1,...,Mにおける解の合計
ll rec(ll M) {
    if (M == 0) return 1;
    if (dp.count(M)) return dp[M];

    dp[M] = rec((M - 1) / 2);            // 奇数
    dp[M] += rec(M / 2);                 // 偶数 繰り上がり無
    if (M > 1) dp[M] += rec(M / 2 - 1);  // 偶数 繰り上がり有
    return dp[M] %= MOD;
}

int main() {
    ll N;
    cin >> N;
    cout << rec(N) << endl;
    return 0;
}

感想

  • はぁしんどい。
  • よく昔の自分は「最下位bitから決める」という方針に舵を切れたもんだと思う。
    • 裏では a  and  bがどんな値を取りうるか実験していたりする。

追記

よくよく考えたら xor andに変える必要がなかったので、そこ全体を修正しました。