题面:
假设有
n
n
n堆石子,你需要合并成
1
1
1堆,可以选择相邻的两堆石子合并,每次合并消耗两堆石子的总重量。问合并完所有石子的总代价最小为多少。
数据范围: 1 ⩽ n ⩽ 500 1 \leqslant n \leqslant 500 1⩽n⩽500
暴力递归写法
#include
using namespace std;
const int N = 510;
int n;
int a[N], sum[N];
int dfs(int l, int r){
if(l == r) return 0;
int ans = 1 << 30;
for(int i = l; i <= r - 1; i ++ ){//枚举中间的分界点
ans = min(ans, dfs(l, i) + dfs(i + 1, r));
}
return ans + sum[r] - sum[l - 1];
}
int main(){
scanf("%d",&n);
for(int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
for(int i = 1; i <= n; i ++ ){
sum[i] = sum[i - 1] + a[i];
}
printf("%d\n", dfs(1, n));
return 0;
}
记忆化搜索写法 O ( n 3 ) O(n^3) O(n3)
#include
using namespace std;
const int N = 510;
int n;
int a[N], sum[N], f[N][N];
int dfs(int l, int r){//表示合并l-r所有的石子的代价
if(f[l][r] != -1) return f[l][r];
if(l == r){
f[l][r] = 0;
return 0;
}
int ans = 1 << 30;
for(int i = l; i <= r - 1; i ++ ){//枚举中间的分界点
ans = min(ans, dfs(l, i) + dfs(i + 1, r));
}
f[l][r] = ans + sum[r] - sum[l - 1];
return f[l][r];
}
int main(){
scanf("%d",&n);
for(int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
for(int i = 1; i <= n; i ++ ){
sum[i] = sum[i - 1] + a[i];
}
memset(f, -1, sizeof f);
printf("%d\n", dfs(1, n));
return 0;
}
区间dp写法
#include
using namespace std;
const int N = 510;
int n;
int a[N], sum[N], f[N][N];
int main(){
scanf("%d",&n);
for(int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
for(int i = 1; i <= n; i ++ ){
sum[i] = sum[i - 1] + a[i];
}
memset(f, 127, sizeof f);
for(int i = 1; i <= n; i ++ ) f[i][i] = 0;
for(int i = 1; i <= n; i ++ )
for(int j = 1; j + i - 1 <= n; j ++ )
for(int k = j; k <= j + i - 1; k ++ )
f[j][j + i - 1] = min(f[j][j + i - 1], f[j][k] + f[k + 1][j + i - 1] + sum[j + i - 1] - sum[j - 1]);
printf("%d\n", f[1][n]);
return 0;
}
区间dp比记忆化搜索快3倍左右
题面:
假设有
n
n
n堆石子,
n
n
n堆石子围城一个环,你可以选择相邻的两堆石子合并,你需要合并成
1
1
1堆,每次合并消耗两堆石子的总重量。问合并完所有石子的总代价最小为多少。
数据范围: 1 ⩽ n ⩽ 250 1 \leqslant n \leqslant 250 1⩽n⩽250
思路:
根据环形问题,我们通常会把环形问题线性化。
#include
using namespace std;
const int N = 510;
int a[N], f[N][N], n, sum[N];
int ans;
int main(){
scanf("%d",&n);
for(int i = 1; i <= n; i ++ ){
scanf("%d",&a[i]);
a[i + n] = a[i];
}
n *= 2;
for(int i = 1; i <= n; i ++ ) sum[i] = sum[i - 1] + a[i];
memset(f, 127, sizeof f);
for(int i = 1; i <= n; i ++ ) f[i][i] = 0;
for(int i = 1; i <= n; i ++ ){
for(int j = 1; j + i - 1 <= n; j ++ ){
for(int k = j; k - j + 1 < i; k ++ ){
f[j][j + i - 1] = min(f[j][j + i - 1], f[j][k] + f[k + 1][j + i - 1] + sum[j + i - 1] - sum[j - 1]);
}
}
}
int ans = 2e9;
for(int i = 1; i <= n / 2; i ++ ){
ans = min(ans, f[i][i + n / 2 - 1]);
}
printf("%d\n",ans);
return 0;
}