学習の栞

学びたいことと、学ぶべきことと、学べることの区別がついてない人間の進捗管理

SRM501 div2 hard 1000pts : FoxAverageSequence 解説

はじめに

SRM501 div2 1000pts を解いたら計算量の落とし方がいい感じだった && 多くの人が自分よりひとつ悪いオーダーで通してて悔しかった ので解説記事を書いてみる

問題概要

美しい整数列Aは以下の制約を満たす。

  • 素数40以下
  • 0 \leq a_{i} \leq 40
  • a_{i} \leq \sum_{k=0}^{k=i-1}a_{i} / i
  • a_{i} \gt a_{i+1} \gt a_{i+2}となるiは存在しない

整数列seqが与えられる。seqの各項は-1以上40以下である。seqの-1の項を0以上40以下の任意の数に変えることが出来るとき、seqから得られる美しい整数列の総数を求めよ。

解法

dp(i,j,k,l)を、i個目までの和がj、i個目の要素がkで、i個目の要素がi-1個目の要素から{ l=0 : 減少していない, l=1 : 減少している}として、

  •  dp(i+1, j, k, 0) = \sum_{0 \leq k' \leq k}( dp(i,j-k,k',0) + dp(i,j-k,k',1) )
  •  dp(i+1, j, k, 1) = \sum_{k \lt k' \leq 40} dp(i,j-k,k',0)

に従って更新すればとりあえず答えは求まる。
但し、dp(i,j,k,l)はseqのi個目の要素が-1で無くseqのi個目の要素がkでない場合0である。また、k * i ≤ j - kを満たさない場合もdp(i,j,k,l)=0とする。

これを愚直に計算しようとすると、時間計算量はseqの要素数nとseqの各要素が取りうる値の最大値mを用いてO(n^{2}m^{3})となり微妙に怪しい*1

ということで、よく使う例の方法で計算量を落としにかかかる。
dp(i+1, j, k, 0) \\ = \sum_{0 \leq k' \leq k} ( dp(i, j-k, k', 0) + dp(i, j-k, k', 1) )
 = dp(i, j-k, k, 0) + dp(i, j-k, k, 1) \\ + \sum_{0 \leq k' \leq k-1} ( dp(i, j-k, k', 0) + dp(i, j-k, k', 1) )
 = dp(i, j-k, k, 0) + dp(i, j-k, k, 1) + dp(i+1, j-1, k-1, 0)

同じ要領で、2番目の式
 dp(i+1, j, k, 1) = \sum_{k \lt k' \leq 40} dp(i,j-k,k',0)
を変形すると
 dp(i+1,j,k,1) = dp(i,j-k,k+1,0) + dp(i+1, j+1, k+1, 1)
となる。

これでdp表の更新にかかるオーダーが減り、全体のオーダーは無事O(n^{2}m^{2})に落ちたのだが、seqのi個目の要素が-1で無い場合のdp表の更新を上の漸化式に従って行うと間違う。なぜならば、dp(i+1, j, k, 0)を計算する場合を例にとれば、 \sum_{0 \leq k' \leq k-1} ( dp(i, j-k, k', 0) + dp(i, j-k, k', 1) ) のつもりで書いている dp(i+1, j-1, k-1, 0) が、k-1とseqのi+1番目の要素が異なる場合、 dp(i+1, j-1, k-1, 0) = 0 になるためである。

seqのi番目の要素が-1で無い場合、dp(i,j,k,l)はseqのi番目の要素とkが一致した場合には

  •  dp(i+1, j, k, 0) = \sum_{0 \leq k' \leq k}( dp(i,j-k,k',0) + dp(i,j-k,k',1) )
  •  dp(i+1, j, k, 1) = \sum_{k \lt k' \leq 40} dp(i,j-k,k',0)

を使って更新し、そうでない場合は0にすればよい。各(i,j,l)の組に対してO(1)の更新がO(m)回と、O(m)の更新がO(1)回しか呼ばれないのだから、DP全体の計算量はO(n^{2}m^{2})に収まる。

実装

バグの箇所がなかなか見つからずにつらかった。書き直す気はない。ぺたり。

#include <string>
#include <vector>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <map>
#include <set>
#include <iostream>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <stack>
#include <queue>
#include <deque>
#include <list>
#include <utility>
#include <cassert>
using namespace std;

#define all(x) ((x).begin()),((x).end())
#define pb push_back
#define mkp make_pair
#define fi first
#define se second

typedef long long ll;
typedef pair<int,int> pii;
typedef pair<double,double> pdd;

const int mod = 1e9+7;
int dp[50][50*50][50][2];
vector<int> S;

int solve(int i, int j, int k, int l) {
	if(i < 0 || j < 0 || j > 40*40|| k < 0 || k > 40)
		return 0;
	if(dp[i][j][k][l] + 1)
		return dp[i][j][k][l];
	if(k * i > j - k)
		return dp[i][j][k][l] = 0;

	if(S[i] + 1) {
		if(S[i] != k)
			return dp[i][j][k][l] = 0;
		dp[i][j][k][l] = 0;
		for(int k_ = k + l; 0 <= k_ && k_ <= 40; k_ = k_ + (l == 0 ? -1 : 1)) {
			dp[i][j][k][l] = (dp[i][j][k][l] + solve(i-1,j-k,k_,0)) % mod;
			if(l == 0)
				dp[i][j][k][l] = (dp[i][j][k][l] + solve(i-1,j-k,k_,1)) % mod;
		}
	}
	else {
		dp[i][j][k][l] = 0;
		if(l == 0) {
			for(int l_ = 0; l_ < 2; l_++)
				dp[i][j][k][l] = (dp[i][j][k][l] + solve(i-1,j-k,k,l_)) % mod;
			dp[i][j][k][l] = (dp[i][j][k][l] + solve(i,j-1,k-1,0)) % mod;
		}
		else {
			dp[i][j][k][l] = (dp[i][j][k][l] + solve(i-1,j-k,k+1,0)) % mod;
			dp[i][j][k][l] = (dp[i][j][k][l] + solve(i,j+1,k+1,1)) % mod;
		}
	}
	return dp[i][j][k][l];
}

class FoxAverageSequence {
	public:
	int theCount(vector <int> seq) {
		S = seq;
		for(int i = 0; i < 50; i++)
			for(int j = 0; j < 50*50; j++)
				for(int k = 0; k < 50; k++)
					dp[i][j][k][0] = dp[i][j][k][1] = -1;
		for(int i = 0; i < 50*50; i++)
			for(int j = 0; j < 50; j++)
				dp[0][i][j][0] = dp[0][i][j][1] = 0;
		if(S[0] == -1) {
			for(int i = 0; i <= 40; i++) {
				dp[0][i][i][0] = 1;
			}
		}
		else
			dp[0][S[0]][S[0]][0] = 1;

		ll res = 0;
		for(int i = 0; i <= 40 * 40; i++) {
			for(int j = 0; j <= 40; j++) {
				for(int k = 0; k < 2; k++) {
					res = (res + solve(seq.size()-1,i,j,k)) % mod;
				}
			}
		}
		/*
		cout << endl;
		for(int i = 0; i < 5; i++) {
			for(int j = 0; j < 20; j++) {
				for(int k = 0; k < 50; k++) {
					cout << dp[i][j][k][0] << "," << dp[i][j][k][1] << " ";
				}
				cout << endl;
			}
			cout << endl;
		}
		// */
		return res;
	}
};

感想

O(n^{2}m^{3})解で解いた人がTLEせずにACしてるのが悔しかった*2。計算量の落とし方としては良くある方法*3だが、seqのi番目の要素が-1でないときにループを回してdp表を更新しても全体的な計算量が悪くならないという点がお洒落だと思った。

*1:通るらしいんだなぁ。。。これが。

*2:40^5 が大体 1e8 なので落ちても良さそう

*3:和をdp表のどっかから持ってきて計算量を落とす方法