torus711 のアレ

主に競技プログラミングの問題について書きます

TopCoder SRM 595, Division 2, Level 3 : LittleElephantAndXor

概要

整数 A, B, C が与えられる。
次の cond を満たすタプル ( x, y ) の数を求めよ

  • x <= A
  • y <= B
  • x XOR y <= C

解法

先頭から順に x, y のビットを決めていくと考えます。
このとき、x について次のことが言えます。

  • x が A 未満であることが確定していればそれ以降の bit に任意の値を入れられる
  • そうでないとき、A 以下になるような値しか入れられない

y と B,、x XOR y と C についても同様です。
欲しいのは x, y への valid な bit の割り当ての総数なので、これらの性質が一致する状態は同一視できます。
従って、次のような DP を考えることができます。
 dp[ i 桁考慮した ][ x が A 未満になった ][ y が B 未満になった ][ x XOR y が C 未満になった ] := 総数
dp[ i ][ j ][ k ][ l ] からの遷移は、x, y のある位置( bit )への値の割り当て方を全通り試すので 4 通りです。

x と A, y と B, x XOR y と C の各組について設定しようとしている bit が valid かどうかを考えます。
x の対応する bit に割り当てようとしている値を na とし、A の対応する bit を bit_a とすると、
 j == 1 || na <= bit_a
を満たすとき、その bit を割り当ては valid です。
他の値の組についても同様にして判定できます。
( x XOR y の対応する bit は y への割り当てを nb とすると na != nb (を評価した値)です )
全ての値の組について valid なとき、その割り当ては valid なので、j, k, l の次の値を nj, nk, nl とすると dp[ i + 1 ][ nj ][ nk ][ nl ] を更新します。

j, k, l から nj, nk, nl を求めるには、

nj = j || na < bit_a; // 過去に性質を満たしたか、この割り当てで満たすようになる

のようにします。

この計算が終わったあと、dp[ 全体の桁数 ][ j ][ k ][ l ] の総和が答えとなります。
入力が符号付き 32 bit 整数なので、31 桁やればよいです。

コード

typedef long long LL;

#define REP( i, m, n ) for ( int i = (int)( m ); i < (int)( n ); ++i )

LL dp[33][2][2][2];

class LittleElephantAndXor
{
public:
	long long getNumber( int A, int B, int C )
	{
		memset( dp, 0, sizeof( dp ) );

		// dp[ i 桁考慮 ][ A 未満確定 ][ B 未満確定 ][ C 未満確定 ] := 数
		dp[0][0][0][0] = 1;

		REP( i, 0, 31 )
		{
			REP( j, 0, 2 )
			{
				REP( k, 0, 2 )
				{
					REP( l, 0, 2 )
					{
						REP( s, 0, 1 << 2 )
						{
							// 0 bit 目:A の bit / 1 bit 目:B の bit
							const int na = !!( s & 1 << 0 );
							const int nb = !!( s & 1 << 1 );
							const int nc = na != nb;

							const int bit_a = !!( A & 1 << ( 30 - i ) );
							const int bit_b = !!( B & 1 << ( 30 - i ) );
							const int bit_c = !!( C & 1 << ( 30 - i ) );

							if ( !j && bit_a < na || !k && bit_b < nb || !l && bit_c < nc )
							{
								continue;
							}

							const int nj = j || na < bit_a;
							const int nk = k || nb < bit_b;
							const int nl = l || nc < bit_c;

							dp[ i + 1 ][ nj ][ nk ][ nl ] += dp[i][j][k][l];
						}
					}
				}
			}
		}
		
		LL res = 0;
		REP( j, 0, 2 )
		{
			REP( k, 0, 2 )
			{
				REP( l, 0, 2 )
				{
					res += dp[31][j][k][l];
				}
			}
		}

		return res;
	}
};