torus711 のアレ

主に競技プログラミングの問題について書きます.PC 以外だと数式が表示されないかもしれないです

AtCoder Beginner Contest 372, E : K-th Largest Connected Components

問題概要

 $V = \{ 1, 2, \dots, n \}$ として,空グラフ $G = ( V, E = \{\} )$ を考える.$G$ に対し,以下の 2 種類からなるクエリを $q$ 個処理せよ:

  • クエリ 1: $u, v \in V$ が与えられる.$E \leftarrow E \cup \{ \{ u, v \} \}$ と更新する.
  • クエリ 2: $v \in V, k \in \mathbb Z_{ > 0 }$ が与えられる.$u$ と同じ連結成分に属する頂点の内,頂点番号が $k$ 番目に大きいものを出力する.存在しない場合はそのことを報告する.

制約

  • $1 \leq n, q \leq 2 \times 10^5$
  • $1 \leq k \leq 10$

解法

 $k$ の制約を見落としたまま不必要に難しい問題を解いてしまったので,(せっかくなので(?))そちらについて書きます.
 グラフの連結成分を管理したいということなので,この部分についてはセオリー通り(?)素集合データ構造(いわゆる Union-Find)を利用します.加えて,クエリ 2 に答えるために,連結成分ごとにそこに属する頂点を何らかのデータ構造に入れて管理します.
 このデータ構造を一旦 $\mathcal S$ とし,頂点 $v$ が属する連結成分に紐付けられるものを $\mathcal S_v$ で表します.初期状態では各 $v$ について $\mathcal S_v = \{ v \}$ です.クエリ 1 において異なる連結成分に属する 2 頂点 $u, v$ が指定されたとき,$\mathcal S_u \cup \mathcal S_v$ (データ構造のマージ)を計算する必要があります.これを工夫せずに処理すると最悪ケースで $\Omega( n^2 )$ 時間がかかり TLE しますが,界隈で「データ構造をマージする一般的なテク(マージテク)」と呼ばれる工夫 (Weighted Union Heiristic) をすることで,要素がマージによって移動させられる回数を $O( n \log n )$ 回に抑えることができます.
 以上により連結成分内の要素を列挙することまではできましたが,まだ $k$-th largest を求める部分が残っています.C++ の std::set はそのようなインターフェースを提供していないので,$\mathcal S$ の実装として用いることは難しいです.ということで何か別のものを用意する必要があるのですが,今回は Randomized Binary Search Tree (RBST) (see プログラミングコンテストでのデータ構造 2 ~平衡二分探索木編~) が手元にあったので,やや加筆して用いました.
 これで一応,$O( q \alpha( n ) + n \log^2 n )$ 時間になったはずです.

コード

#include <random>
#include <memory>
 
namespace RBST
{
	template < typename T >
	class set
	{
	private:
		// 乱数生成器
		std::mt19937 rng;
		std::uniform_int_distribution< int > rng_n;
		int rand( const int a, const int b )
		{
			rng_n.param( std::uniform_int_distribution< int >::param_type( a, b ) );
			return rng_n( rng );
		}
 
		// ノード
		struct Node
		{
			T value;
			Node *left = nullptr, *right = nullptr;
 
			unsigned int size_ = 1;
 
			Node( const T &v ) : value( v )
			{
				return;
			}
 
			~Node()
			{
				delete left;
				delete right;
				return;
			};
		};
 
		// 本体
		Node *root = nullptr;
 
	public:
		~set()
		{
			delete root;
		}
 
		unsigned int size() const
		{
			return size( root );
		}
 
		unsigned int size( Node * const t ) const
		{
			return t == nullptr ? 0 : t->size_;
		}
 
		Node *update( Node * const t )
		{
			t->size_ = size( t->left ) + size( t->right ) + 1;
			return t;
		}
 
		bool find( const T &v )
		{
			return find( root, v );
		}
 
		bool find( Node * const t, const T &v )
		{
			if ( t == nullptr )
			{
				return false;
			}
			if ( t->value == v )
			{
				return true;
			}
			return find( ( v < t->value ? t->left : t->right ), v );
		}
 
		Node *merge( Node *l, Node *r )
		{
			if ( l == nullptr || r == nullptr )
			{
				return l == nullptr ? r : l;
			}
 
			const int n = size( l ), m = size( r );
			if ( rand( 0, n + m ) < n )
			{
				l->right = merge( l->right, r );
				return update( l );
			}
			else
			{
				r->left = merge( l, r->left );
				return update( r );
			}
		}
 
		std::pair< Node*, Node* > split_by_index( Node *t, const unsigned int k )
		{
			if ( t == nullptr )
			{
				return std::make_pair( nullptr, nullptr );
			}
 
			if ( k <= size( t->left ) )
			{
				auto ts = split( t->left, k );
				t->left = ts.second;
				return std::make_pair( ts.first, update( t ) );
			}
			else
			{
				auto ts = split( t->right, k - size( t->left ) - 1 );
				t->right = ts.first;
				return std::make_pair( update( t ), ts.second );
			}
		}
 
		std::pair< Node*, Node* > split_by_value( Node *t, const T &v )
		{
			// v 以下の値をもつノードからなる部分木と,それ以外に分割する
 
			if ( t == nullptr )
			{
				return std::make_pair( nullptr, nullptr );
			}
 
			if ( t->value <= v )
			{
				auto ts = split_by_value( t->right, v );
				t->right = ts.first;
				return std::make_pair( update( t ), ts.second );
			}
			else
			{
				auto ts = split_by_value( t->left, v );
				t->left = ts.second;
				return std::make_pair( ts.first, update( t ) );
			}
		}
 
		bool insert( const T &v )
		{
			if ( find( v ) )
			{
				return false;
			}
 
			auto ts = split_by_value( root, v );
			root = merge( merge( ts.first, new Node( v ) ), ts.second );
 
			return true;
		}
 
		bool erase( const T &v )
		{
			if ( !find( v ) )
			{
				return false;
			}
 
			auto l = split_by_value( root, v - 1 );
			auto r = split_by_value( l.second, v );
 
			delete r.first;
 
			root = merge( l.first, r.second );
 
			return true;
		}
 
		T operator[]( const unsigned int k ) const
		{
			return nth_value( root , k );
		}
 
		vector< T > elems()
		{
			vector< T > res;
			traverse( res, root );
			return res;
		}

		void traverse( auto &res, Node * const t )
		{
			if ( !t )
			{
				return;
			}

			traverse( res, t->left );
			res.PB( t->value );
			traverse( res, t->right );

			return;
		}

	private:
		T nth_value( Node * const t, const unsigned k ) const
		{
			if ( size( t->left ) == k )
			{
				return t->value;
			}
			else if ( k <= size( t->left ) )
			{
				return nth_value( t->left, k );
			}
			else
			{
				return nth_value( t->right, k - size( t->left ) - 1 );
			}
		}
 
	public:
		// テスト用
		void dump() const
		{
			return dump( root );
		}
 
		void dump( Node * const t ) const
		{
			if ( t == nullptr )
			{
				return;
			}
 
			dump( t->left );
			cerr << "set element : " << t->value << endl;
			dump( t->right );
 
			return;
		}
	};
}

class DisjointSetForest; // 中身省略
// DisjointSetForest( N )
// find( x )
// same( x, y )
// unite( x, y )
// groups()
// groupSize( x )

int main()
{
	IN( int, N, Q );

	DisjointSetForest dsf( N );
	vector< unique_ptr< RBST::set< int > > > rbsts( N );
	REP( i, N )
	{
		rbsts[i] = make_unique< RBST::set< int > >();
		rbsts[i]->insert( -( i + 1 ) );
	}
	VI indices( N );
	iota( ALL( indices ), 0 );

	const auto merge = [&]( int i, int j )
	{
		if ( rbsts[i]->size() < rbsts[j]->size() )
		{
			swap( i, j );
		}

		const auto elems = rbsts[j]->elems();
		rbsts[j] = nullptr;
		FOR( a, elems )
		{
			rbsts[i]->insert( a );
		}

		return i;
	};

	REP( Q )
	{
		IN( int, T );
		if ( T == 1 )
		{
			IN( int, u, v );
			--u, --v;

			if ( dsf.same( u, v ) )
			{
				continue;
			}

			const int x = dsf.find( u );
			const int y = dsf.find( v );

			const int z = merge( indices[x], indices[y] );
			dsf.unite( u, v );
			indices[ dsf.find( u ) ] = z;
		}
		else
		{
			IN( int, u, k );
			--u;

			const auto &s = *rbsts[ indices[ dsf.find( u ) ] ];
			cout << ( SZ( s ) < k ? -1 : -s[ k - 1 ] ) << '\n';
		}
	}

	cout << flush;

	return 0;
}