矩阵快速幂可以将 O ( n ) O(n) O(n) 的 DP 优化成 O ( log n ) O(\log{n}) O(logn) 的时间复杂度。
前置知识——快速幂,可见:【算法基础:数学知识】4.4 快速幂
我们拿递推公式
d
p
[
i
]
=
d
p
[
i
−
1
]
+
[
i
−
2
]
dp[i] = dp[i - 1] + [i - 2]
dp[i]=dp[i−1]+[i−2] 举例,
为了将其表示成矩阵乘法,添加一个式子
d
p
[
i
−
1
]
=
d
p
[
i
−
1
]
dp[i - 1] = dp[i - 1]
dp[i−1]=dp[i−1]
两个式子合并可以得——
这是因为
d
p
[
i
]
=
1
∗
d
p
[
i
−
1
]
+
1
∗
[
i
−
2
]
dp[i] = 1 * dp[i - 1] + 1 * [i - 2]
dp[i]=1∗dp[i−1]+1∗[i−2] ,
d
p
[
i
−
1
]
=
1
∗
d
p
[
i
−
1
]
+
0
∗
[
i
−
2
]
dp[i - 1] = 1 * dp[i - 1] + 0 * [i - 2]
dp[i−1]=1∗dp[i−1]+0∗[i−2]
这样就可以对递推矩阵使用快速幂来将 n 次递推的时间复杂度简化到 logn。
https://leetcode.cn/problems/climbing-stairs/
提示:
1 <= n <= 45
class Solution {
public int climbStairs(int n) {
if (n <= 1) return n;
int[] dp = new int[n];
dp[0] = 1;
dp[1] = 2;
for (int i = 2; i < n; ++i) {
dp[i] = dp[i - 1] + dp[i - 2];
}
return dp[n - 1];
}
}
根据递推公式,可以得出矩阵快速幂的矩阵是什么。
class Solution {
public int climbStairs(int n) {
if (n <= 2) return n;
// m是根据递推公式来的
int[][] m = {
{1, 1},
{1, 0}
};
return pow(m, n - 2)[0][0];
}
public int[][] pow(int[][] m, int k) {
// dp[0] = 1,dp[1] = 2
int[][] res = {
{2, 0},
{1, 0}
};
for (; k != 0; k /= 2) {
if (k % 2 == 1) res = mul(m, res);
m = mul(m, m);
}
return res;
}
public int[][] mul(int[][] x, int[][] y) {
int[][] res = new int[2][2];
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
res[i][j] = x[i][0] * y[0][j] + x[i][1] * y[1][j];
}
}
return res;
}
}
也可以更简洁一些,从 dp[0] 开始写,代码如下:
class Solution {
public int climbStairs(int n) {
// m是根据递推公式来的
int[][] m = {
{1, 1},
{1, 0}
};
return pow(m, n)[0][0];
}
public int[][] pow(int[][] m, int k) {
// dp[0] = 1
int[][] res = {
{1, 0},
{0, 0}
};
for (; k != 0; k /= 2) {
if (k % 2 == 1) res = mul(m, res);
m = mul(m, m);
}
return res;
}
public int[][] mul(int[][] x, int[][] y) {
int[][] res = new int[2][2];
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
res[i][j] = x[i][0] * y[0][j] + x[i][1] * y[1][j];
}
}
return res;
}
}
https://leetcode.cn/problems/fibonacci-number/
提示:
0 <= n <= 30
跟上一题差不多,注意初始值变了。
class Solution {
public int fib(int n) {
if (n == 0) return n;
// m是根据递推公式来的
int[][] m = {
{1, 1},
{1, 0}
};
return pow(m, n - 1)[0][0];
}
public int[][] pow(int[][] m, int k) {
// dp[0] = 1
int[][] res = {
{1, 0},
{0, 0}
};
for (; k != 0; k /= 2) {
if (k % 2 == 1) res = mul(m, res);
m = mul(m, m);
}
return res;
}
public int[][] mul(int[][] x, int[][] y) {
int[][] res = new int[2][2];
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
res[i][j] = x[i][0] * y[0][j] + x[i][1] * y[1][j];
}
}
return res;
}
}
https://leetcode.cn/problems/n-th-tribonacci-number/
提示:
0 <= n <= 37
答案保证是一个 32 位整数,即 answer <= 2^31 - 1。
对矩阵稍作修改即可。
class Solution {
public int tribonacci(int n) {
if (n <= 1) return n;
// m是根据递推公式来的
int[][] m = {
{1, 1, 1},
{1, 0, 0},
{0, 1, 0}
};
return pow(m, n - 2)[0][0];
}
public int[][] pow(int[][] m, int k) {
// dp[0] = 0, dp[1] = 1, dp[2] = 1
int[][] res = {
{1, 0, 0},
{1, 0, 0},
{0, 0, 0}
};
for (; k != 0; k /= 2) {
if (k % 2 == 1) res = mul(m, res);
m = mul(m, m);
}
return res;
}
public int[][] mul(int[][] x, int[][] y) {
int[][] res = new int[3][3];
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
res[i][j] = x[i][0] * y[0][j] + x[i][1] * y[1][j] + x[i][2] * y[2][j];
}
}
return res;
}
}
https://leetcode.cn/problems/count-vowels-permutation/
提示:
1 <= n <= 2 * 10^4
class Solution {
final long MOD = (int)1e9 + 7;
public int countVowelPermutation(int n) {
long[][] dp = new long[n][5];
Arrays.fill(dp[0], 1);
for (int i = 1; i < n; ++i) {
dp[i][0] = dp[i - 1][1];
dp[i][1] = (dp[i - 1][0] + dp[i - 1][2]) % MOD;
dp[i][2] = (dp[i - 1][0] + dp[i - 1][1] + dp[i - 1][3] + dp[i - 1][4]) % MOD;
dp[i][3] = (dp[i - 1][2] + dp[i - 1][4]) % MOD;
dp[i][4] = dp[i - 1][0];
}
long ans = 0;
for (long x: dp[n - 1]) ans = (ans + x) % MOD;
return (int)ans;
}
}
class Solution {
final long MOD = (int)1e9 + 7;
public int countVowelPermutation(int n) {
long[][] m = {
{0, 1, 0, 0, 0},
{1, 0, 1, 0, 0},
{1, 1, 0, 1, 1},
{0, 0, 1, 0, 1},
{1, 0, 0, 0, 0}
};
long[][] res = pow(m, n - 1);
long ans = 0;
for (int i = 0; i < 5; ++i) ans = (ans + res[i][0]) % MOD;
return (int)ans;
}
public long[][] pow(long[][] m, int k) {
long[][] res = {
{1, 0, 0, 0, 0},
{1, 0, 0, 0, 0},
{1, 0, 0, 0, 0},
{1, 0, 0, 0, 0},
{1, 0, 0, 0, 0}
};
for (; k != 0; k /= 2) {
if (k % 2 == 1) res = mul(m, res);
m = mul(m, m);
}
return res;
}
public long[][] mul(long[][] x, long[][] y) {
long[][] res = new long[5][5];
for (int i = 0; i < 5; ++i) {
for (int j = 0; j < 5; ++j) {
res[i][j] = (x[i][0] * y[0][j] + x[i][1] * y[1][j] + x[i][2] * y[2][j] + x[i][3] * y[3][j] + x[i][4] * y[4][j]) % MOD;
}
}
return res;
}
}
https://leetcode.cn/problems/student-attendance-record-ii/
提示:
1 <= n <= 10^5
class Solution {
final int MOD = (int)1e9 + 7;
public int checkRecord(int n) {
// 长度,A 的数量,结尾连续 L 的数量
int[][][] dp = new int[n + 1][2][3];
dp[0][0][0] = 1;
for (int i = 1; i <= n; ++i) {
// 以P结尾
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 3; ++k) {
dp[i][j][0] = (dp[i][j][0] + dp[i - 1][j][k]) % MOD;
}
}
// 以A结尾
for (int k = 0; k < 3; ++k) {
dp[i][1][0] = (dp[i][1][0] + dp[i - 1][0][k]) % MOD;
}
// 以L结尾
for (int j = 0; j < 2; ++j) {
for (int k = 1; k < 3; ++k) {
dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j][k - 1]) % MOD;
}
}
}
int ans = 0;
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 3; ++k) {
ans = (ans + dp[n][j][k]) % MOD;
}
}
return ans;
}
}
在这里插入代码片
https://leetcode.cn/problems/domino-and-tromino-tiling/
提示:
1 <= n <= 1000
class Solution {
final int MOD = (int)1e9 + 7;
public int numTilings(int n) {
// 0空,1上,2下,3满
int[][] dp = new int[n][4];
dp[0][0] = dp[0][3] = 1;
for (int i = 1; i < n; ++i) {
dp[i][0] = dp[i - 1][3];
dp[i][1] = (dp[i - 1][0] + dp[i - 1][2]) % MOD;
dp[i][2] = (dp[i - 1][0] + dp[i - 1][1]) % MOD;
dp[i][3] = (((dp[i - 1][0] + dp[i - 1][1]) % MOD + dp[i - 1][2]) % MOD + dp[i - 1][3]) % MOD;
}
return dp[n - 1][3];
}
}
https://leetcode.cn/problems/domino-and-tromino-tiling/solutions/1968516/by-endlesscheng-umpp/
class Solution {
final long MOD = (long)1e9 + 7;
public int numTilings(int n) {
if (n <= 2) return n;
long[] dp = new long[n + 1];
dp[0] = 1;
dp[1] = 1;
dp[2] = 2;
for (int i = 3; i <= n; ++i) {
dp[i] = (dp[i - 1] * 2 + dp[i - 3]) % MOD;
}
return (int)dp[n];
}
}
class Solution {
final int MOD = (int)1e9 + 7;
public int numTilings(int n) {
// 0空,1上,2下,3满
long[][] m = {
{0, 0, 0, 1},
{1, 0, 1, 0},
{1, 1, 0, 0},
{1, 1, 1, 1}
};
return (int)pow(m, n - 1)[3][0];
}
public long[][] pow(long[][] m, int k) {
long[][] res = {
{1, 0, 0, 0},
{0, 0, 0, 0},
{0, 0, 0, 0},
{1, 0, 0, 0}
};
for (; k != 0; k >>= 1) {
if ((k & 1) == 1) res = mul(m, res);
m = mul(m, m);
}
return res;
}
public long[][] mul(long[][] a, long[][] b) {
long[][] c = new long[4][4];
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
c[i][j] = (a[i][0] * b[0][j] + a[i][1] * b[1][j] + a[i][2] * b[2][j] + a[i][3] * b[3][j]) % MOD;
}
}
return c;
}
}
# 相关链接
[【力扣周赛】第 362 场周赛(⭐差分&匹配&状态压缩DP&矩阵快速幂优化DP&KMP)](https://blog.csdn.net/qq_43406895/article/details/132824604)