• 【精选】矩阵加速


    大家好,我是Weekoder!

    今天要讲的内容是矩阵加速!

    这时候就有人说了:

    Weekoder 这么蒻,怎么会矩阵啊。还给我们讲,真是十恶不赦!

    不不不,容我解释。在经过我的研究后,我发现基本的矩阵运算和矩阵加速都并没有那么难。只要继续往下看,相信你也能学会!

    注意:以下内容的学习难度将会用颜色表示,与洛谷题目难度顺序一致,即 <<<绿。(并不对标洛谷题目难度,只作为学习难易度参考)

    Part 1 Definition

    定义

    矩阵和二维数组很像,是由 m×n 个数排列成 mn 列的一张表,由于排列出来的表是一个矩形,故称其为矩阵。矩阵长这个样子:

    (a11a12a13a1na21a22a23a2na31a32a33a3nam1am2am3amn)

    可以看到,矩阵中的每个元素都有着对应的行和列,我们把一个矩阵记作 A,第 ij 列的元素即为 aij。更形式化的,写作:

    A=(aij)Fm×n

    其中 F数域,一般取为实数域 R 或复数域 C。(看不懂没事,蒟蒻自行走开QWQ)

    Part 2 Special matrices

    特殊矩阵

    1.零矩阵

    元素全部为 0 的矩阵称为零矩阵。像这样:

    (000000000)

    零矩阵记作 0m×n,就是在 0 下面加上矩阵的大小 m×n。你可以把零矩阵看做数字 0,任何数乘以 0 都得 0

    2.对角矩阵

    只有主对角线上的元素有值,其余元素为 0 的矩阵称为对角矩阵

    注:主对角线为矩阵中从左上角到右下角的一条对角线。

    (a1000a2000an)

    对角矩阵根据主对角线上的值,记作 diag(a1,a2,,an)

    3.单位矩阵

    主对角线上的元素均为 1,其余元素为 0 的矩阵称为单位矩阵

    (100010001)

    单位矩阵记作 I

    记得分数中的概念分数单位吗?矩阵单位和分数单位的“地位”差不多,代表的都是最基础的,最小的独立个体。你可以把单位矩阵看做数字 1,任何数乘以 1 都等于它本身。

    最基础的,常见的特殊矩阵就是这些了。当然,还有很多的特殊矩阵,不过我们暂时用不到。

    Part 3 Matrix operations

    矩阵运算

    1.相等

    若对于矩阵 A,B,所有的 i,j 都有 aij=bij 且矩阵的行和列相等,则称矩阵 A,B 相等。

    其实就是两个矩阵长得一模一样

    (a11a12a13a1na21a22a23a2na31a32a33a3nam1am2am3amn)=(b11b12b13b1nb21b22b23b2nb31b32b33b3nbm1bm2bm3bmn)

    2.矩阵加(减)法

    若要求 A,B 两个矩阵之和,即 C=A+B,则对于任意 i,j,满足 cij=aij+bij。要求矩阵行列相等。

    总结一句话:对应位置相加。

    (a11a12a13a1na21a22a23a2na31a32a33a3nam1am2am3amn)+(b11b12b13b1nb21b22b23b2nb31b32b33b3nbm1bm2bm3bmn)=(a11+b11a12+b12a13+b13a1n+b1na21+b21a22+b22a23+b23a2n+b2na31+b31a32+b32a33+b33a3n+b3nam1+bm1am2+bm2am3+bm3amn+bmn)

    矩阵加法满足交换律和结合律:

    A+B=B+A

    (A+B)+C=A+(B+C)

    减法同理,对应位置相减。

    (a11a12a13a1na21a22a23a2na31a32a33a3nam1am2am3amn)(b11b12b13b1nb21b22b23b2nb31b32b33b3nbm1bm2bm3bmn)=(a11b11a12b12a13b13a1nb1na21b21a22b22a23b23a2nb2na31b31a32b32a33b33a3nb3nam1bm1am2bm2am3bm3amnbmn)

    3.矩阵数乘

    λ(一个数字) 乘以矩阵 A,记作 λA,即为矩阵数乘运算。若有 B=λA,则对于任意 i,j 都满足 bij=λaij

    还是一句话:对应位置相乘。

    λ(a11a12a13a1na21a22a23a2na31a32a33a3nam1am2am3amn)=(λa11λa12λa13λa1nλa21λa22λa23λa2nλa31λa32λa33λa3nλam1λam2λam3λamn)

    Part 3.5 Matrix multiplication

    矩阵乘法

    虽然矩阵乘法也属于矩阵运算,但难度比前面的都高,而且是今天的重点内容,所以单独放出来讲,故记为 Part 3.5。(话说你们没有发现难度变成黄了吗)

    例题!(虽然难度是橙)

    先看矩阵乘法的定义:若有 nm 列矩阵 Amk 列的矩阵 BA 的行与 B 的列相等),则 nk 列的矩阵 C=A×B 满足

    cij=l=1mail×blj

    只要枚举 i,j(范围是 n,k),并套用公式就能用 O(n3) 的时间复杂度解决这个问题。

    我知道,这看起来根本不是新手蒟蒻能看懂的。那我就用人话来讲讲矩阵乘法。

    矩阵乘法并不是一个一个乘,而是行对应列乘。怎么个乘法呢?我们来看看下面两个矩阵相乘的例子。

    (523794)×(268109132441)

    第一个矩阵为 A,第二个矩阵为 B

    我们先取出 A第一行。像这样:

    (523)

    再取出 B第一列。像这样:

    (202)

    不对,你给我转过来。

    (202)

    现在终于可以相乘了。逐位相乘得出结果:

    (523)×(202)=(5×22×03×2)=(1006)

    得出了结果 (1006)。再将每一位相加:

    (1006)10+0+6=16

    还记得我们之前是怎么取的吗?我们取了 A第一行B第一列(注意加粗部分),所以答案就存储在 C第一行第一列。还没搞懂?更通用一点:我们取了 AxBy(注意加粗部分),所以答案就存储在 Cx 行第 y。也就是说,当我们想要获取矩阵 C 的第 xy 列的时候,就需要取 A 的第 x 行和 B 的第 y 列,相乘再相加。由于 A 的行数与 B 的列数相等,取出来的数列才可以逐位相乘(不然元素个数不一样)。而取出来的数列长度就是 m,所以可以用 O(m) 求和,总时间复杂度 O(nmk)=O(n3)

    最后,可以看看代码辅助理解。

    #include 
    using namespace std;
    
    const int N = 105;
    
    int n, m, k, a[N][N], b[N][N]; // 用二维数组存矩阵 A,B
    
    int main() {
        cin >> n >> m >> k;
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= m; j++)
                cin >> a[i][j]; // 输入矩阵 A
        for (int i = 1; i <= m; i++)
            for (int j = 1; j <= k; j++)
                cin >> b[i][j]; // 输入矩阵 B
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= k; j++) { // 枚举 C 矩阵 n 行 k 列的每个元素
      	        // 以下部分为模拟刚刚讲的矩阵乘法
                int sum = 0; // 求和,sum 即为 C_ij
                for (int l = 1; l <= m; l++)
                    sum += a[i][l] * b[l][j]; // 求和,A 的行和 B 的列,建议模拟一下过程加强理解
                cout << sum << " "; // 输出 sum(C_ij)
            }
            cout << "\n"; // 记得换行!
        }
        return 0; // 完美的结束
    } 
    

    这样就能愉快地切掉这道题了。请完成这道题再继续!

    矩阵乘法满足以下性质:

    结合律:(AB)C=A(BC)

    分配律:(A+B)C=AC+BC

    矩阵乘法不满足交换律。(这是重点!)

    有了矩阵乘法,我们还可以结合上面的特殊矩阵得到一些性质:

    A×I=A

    A×0m×n=0m×n

    Part 4 Matrix fast power

    矩阵封装 & 矩阵快速幂

    快到今天的主题了!上例题

    点开题目后的你 be like:

    这是啥呀?

    我来让题目描述“缩点水”:

    给定一个 nn 列的矩阵 A,求 Ak,即 A×A×A××A×Ak 次

    第一思路:暴力!直接做 k 次矩阵乘法,时间复杂度 O(kn3)。看看数据范围:

    0k1012

    考虑放弃做题。

    那我们该怎么优化呢?看到需要计算 Ak,我突然想到了一个算法:快速幂!但是矩阵快速幂该怎么写呢?答案是:和正常的快速幂一样,矩阵也能使用快速幂,只不过快速幂中的乘法变成了矩阵乘法。但是矩阵乘法太难写,有没有什么办法能让矩阵乘法也像普通的乘法一样,只要写一个 * 乘号就行了呢?

    注意:不会快速幂的话可以先简单看看我写的文章

    回到主题,有没有什么办法能只要写一个 * 乘号就能进行矩阵乘法呢?其实我们可以用结构体把矩阵封装起来,再用重载运算符就行了。关于重载运算符,可以参考这些资料

    定义一个矩阵类型的结构体可以写成这样:

    struct Matrix {
    	
    };
    

    我们需要在里面用一个二维数组存储矩阵。我们还可以写一个结构体初始化函数,只要定义了一个矩阵,就自动清零,免去清零的麻烦。

    struct Matrix {
    	int a[N][N]; // N 为矩阵大小
    	Matrix() {
    		memset(a, 0, sizeof a);
    	}
    };
    

    最后,把矩阵乘法写进去。

    struct Matrix {
        ll a[N][N];
        Matrix() {
            memset(a, 0, sizeof a);
        }
        Matrix operator*(const Matrix &x)const {
            Matrix res;
            for (int i = 1; i <= n; i++)
                for (int j = 1; j <= n; j++)
                    for (int k = 1; k <= n; k++)
                        res.a[i][j] = (res.a[i][j] % MOD + a[i][k] % MOD * x.a[k][j] % MOD) % MOD;
            return res;
        }
    }; 
    

    注意,这里一定要写成 a[i][k] * x.a[k][j],不能写成 x.a[i][k] * a[k][j],因为矩阵乘法不满足交换律!

    这样,结构体封装部分就完成了。

    我们要定义两个矩阵:abasea 是输入的矩阵,base 是答案矩阵,所以 base 需要初始化成 I(单位矩阵),写一个初始化函数 init,如下:

    void init() {
        for (int i = 1; i <= n; i++) base.a[i][i] =1;
    }
    

    初始化完以后,就可以执行快速幂了,计算 Ak 了,让 baseA。矩阵快速幂核心代码如下:

    void expow(ll b) {
        while (b) {
            if (b & 1) base = base * a;
            a = a * a, b >>= 1;
        }  
    }
    

    有一点需要注意的就是,不能写成 base *= a 等形式,因为重载运算符定义的是 *,没有定义 *=,所以需要将 *= 展开。

    最后,就可以输出 base 了。展示全部代码:

    #include 
    using namespace std;
    
    typedef long long ll;
    
    const int N = 105, MOD = 1e9 + 7;
    
    int n;
    ll k;
    
    struct Matrix {
        ll a[N][N];
        Matrix() {
            memset(a, 0, sizeof a);
        }
        Matrix operator*(const Matrix &x)const {
            Matrix res;
            for (int i = 1; i <= n; i++)
                for (int j = 1; j <= n; j++)
                    for (int k = 1; k <= n; k++)
                        res.a[i][j] = (res.a[i][j] % MOD + a[i][k] % MOD * x.a[k][j] % MOD) % MOD;
            return res;
        }
    }a, base; 
    
    void init() {
        for (int i = 1; i <= n; i++) base.a[i][i] =1;
    }
    
    void expow(ll b) {
        while (b) {
            if (b & 1) base = base * a;
            a = a * a, b >>= 1;
        }  
    }
    
    int main() {
        cin >> n >> k;
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++)
                cin >> a.a[i][j];
        init();
        expow(k);
        for (int i = 1; i <= n; putchar('\n'), i++)
            for (int j = 1; j <= n; j++)
                cout << base.a[i][j] << " ";
        return 0;
    } 
    

    Part 5 Matrix acceleration

    矩阵加速

    终于到了最后的 BOSS 关卡 了!你们有信心吗?加油!

    点击此处进入 BOSS 关卡 ......

    点开题目 BOSS 关卡 后的你 be like(梅开二度):

    这和矩阵有什么关系吗???

    我直接一个递推!

    • 对于 100% 的数据 1T1001n2×109

    O(Tn)2×1011 的复杂度实在无法接受。

    (呜呜呜我再也不学 c艹 了)

    没关系,先看看思路!

    因为发现当 x3 时答案为 1,所以这是最基础的情况。我们可以构造一个只有一列的矩阵:

    (a3a2a1)=(111)

    显然,这三个元素都是 1

    那么,假设我想要得到 a4,该怎么办呢?所以,我们需要进行一种运算,让上面的矩阵变化一下,像这样:

    (a3a2a1)(a4a3a2)

    更加通用一点:

    (axax1ax2)(ax+1axax1)

    可以发现,矩阵中的每个元素的项数都向前推进了 1。那么,我们大概可以写出伪代码:


    如果 x3

    输出 1

    否则

    执行运算 n3 次(重要!)

    并输出答案矩阵 11


    特判(对于特殊情况的判断)和输出应该没什么问题,主要是为什么运算恰好要执行 n3 次呢?稍微画个图模拟一下就好了。

    还是假设要获取 a4,则执行运算 43=1 次。在执行 1 次运算后,

    (a3a2a1)

    变为

    (a4a3a2)

    这样就刚好在第 11 列得到 a4 啦!

    那么,说了这么久,这个神秘的运算是什么呢?当当当当~,他就是我们的——矩阵乘法!

    没错,所谓的变换,其实就是乘上了一个特殊的矩阵!那么,这个矩阵长什么样呢?让我们一起来推理吧。

    (此处应配上推理の小曲)

    我们可以先列一个表格,表格的行代表矩阵 (a3a2a1) 的元素,列代表递推时与这些元素相关的元素。像这样:(表格可能在博客里渲染不出来,凑合着看吧,抱歉)

    ax ax1 ax2
    ax1
    ax2
    ax3

    好了,对于 ax,我们该怎么填他那一列呢?我们可以观察到递推式 ax=ax1+ax3,所以有:

    ax=ax1×1+ax2×0+ax3×1

    观察系数 1,0,1,把这些系数填入表格中:

    ax ax1 ax2
    ax1 1
    ax2 0
    ax3 1

    后面的也以此类推:

    ax1=ax1×1+ax2×0+ax3×0

    ax2=ax1×0+ax2×1+ax3×0

    ax ax1 ax2
    ax1 1 1 0
    ax2 0 0 1
    ax3 1 0 0

    这样,我们就可以推出这个神秘的矩阵了:

    (110001100)

    好了,现在我们终于知道了,一次神秘操作,就是将让 (a3a2a1) 这个矩阵乘上(110001100)。这时候就有人问了:

    一次矩阵乘法的时间复杂度还没有递推快,这根本就没有优化嘛。

    等等!我们把这个式子展开:

    (111)×(110001100)×(110001100)×(110001100)=(111)×(110001100)n3

    不是吧!这居然变成了一个矩阵快速幂?!!

    也就是说,我们可以用快速幂计算 (110001100)n3,并乘上初始矩阵 (111)。这样,我们成功地把时间复杂度从 O(Tn) 优化到了 O(Tlogn)!(矩阵快速幂是 O(logn),因为矩阵很小,矩阵乘法只计算 9 次,是一个很小的常数)

    下面奉上代码:(标准的矩阵加速思想)

    #include 
    using namespace std;
    
    typedef long long ll;
    
    const int MOD = 1e9 + 7;
    
    int T, n;
    
    struct Matrix {
        ll a[5][5];
        Matrix() {
            memset(a, 0, sizeof a);
        }
        Matrix operator*(const Matrix &x)const { // 矩阵乘法
            Matrix res;
            for (int i = 1; i <= 3; i++)
                for (int j = 1; j <= 3; j++)
                	for (int k = 1; k <= 3; k++)
                        res.a[i][j] = (res.a[i][j] % MOD + a[i][k] % MOD * x.a[k][j] % MOD) % MOD;
            return res;
        }
        void mems() {
        	memset(a, 0, sizeof a);
    	}
    }ans, base; 
    
    void init() { // 初始化两个矩阵
    	ans.mems(), base.mems(); // 记得清空!
    	ans.a[1][1] = ans.a[1][2] = ans.a[1][3] = 1;
    	base.a[1][1] = base.a[1][2] = base.a[2][3] = base.a[3][1] = 1;
    }
    
    void expow(int b) { // 矩阵快速幂,是在 ans 矩阵的基础上乘的
        while (b) {
            if (b & 1) ans = ans * base;
            base = base * base, b >>= 1;
        }  
    }
    
    int main() {    
        cin >> T;
        while (T --) {
            cin >> n;
            init(); // 初始化不能忘
            if (n <= 3) { // 特判
                cout << "1\n";
                continue;
            } 
            expow(n - 3); // 计算特殊矩阵的 n - 3 次方,已经乘到了 ans 里
            cout << ans.a[1][1] << "\n"; // 输出答案!芜湖!
        }
        return 0; // 快乐结束
    } 
    

    就这样,我们完成了矩阵加速递推。

    再次声明矩阵快速幂(矩阵加速)时间复杂度:O(N3logn),其中 N 为矩阵的行数(列数),n 为快速幂的规模 an

    小提示:关于 base 矩阵的构造

    就是这个 (110001100) 矩阵。

    可以这样:我们要推导出 (axax1ax2),那么这个矩阵从哪里来?当然是从 (ax1ax2ax3) 来。所以,表格才长这样:

    ax ax1 ax2
    ax1
    ax2
    ax3

    那么,能不能构造一个行列数各不相同的矩阵,而不是一个 n×n 的矩阵呢?答案是不可以,因为我们要计算 (110001100) 这种矩阵的幂,那如果行和列不相等,相乘的两个矩阵的行列也不相等,就无法进行矩阵乘法。比如这个:

    (110001)×(110001)

    可以看到,左边 2 行,右边 3 列,显然不相等,无法进行矩阵乘法。

    Part 6 Thank you!

    这篇文章花费了我很多时间,希望你喜欢!

    对了,你学会了吗?是不是,矩阵也并没有那么难?

    这应该是我的【精选】文章中的第一篇,没想到写的是矩阵方面的。

    总之,很感谢你的阅读!希望你能从我这学到点东西!

    再见!

本文作者:Weekoder

本文链接:https://www.cnblogs.com/weekoder/p/18237764

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

  • 相关阅读:
    web3再牛 也没能逃出这几个老巨头的手掌心
    R3F(React Three Fiber)经验篇
    《入门级-Cocos2dx4.0 塔防游戏开发》---第十课:游戏中餐单设置
    Day712. 封闭类-Java8后最重要新特性
    [go]golang中“var“与“:=“的区别
    使用grpcui测试ASP.NET core gRPC服务
    mac 编译问题记录
    高速公路测量计算CASIO程序全套
    C++ 多态语法点
    Qt/QML编程之路:基于QWidget编程及各种2D/3D/PIC绘制的示例(45)
  • 原文地址:https://www.cnblogs.com/Weekoder/p/18237764