• Strassen矩阵乘法问题(Java)


    Strassen矩阵乘法问题(Java)


    在这里插入图片描述


    1、前置介绍

    矩阵乘法是线性代数中最常见的问题之一 ,它在数值计算中有广泛的应用。 设AB是2个nXn矩阵,
    它们的乘积AB同样是一个nXn矩阵。 AB的乘积矩阵C中元素C[i][j]定义为:
    C [ i ] [ j ] = ∑ k = 1 n A [ i ] [ k ] B [ k ] [ j ] C[i][j] = \sum_{k=1}^{n}A[i][k]B[k][j] C[i][j]=k=1nA[i][k]B[k][j]

    在这里插入图片描述

    采用传统方法,时间复杂度为:O(n3)

    因为按照上述的定义来计算A和 B的乘积矩阵c,则每计算C的一个元素C[i][j],需要做n次乘法运算和n-1次加法运算。 因此,得到矩阵C的n2 个元素所需的计算时间为 O(n3) 。

    为解决计算计算效率问题,Strassen算法由此出现,该算法基本思想是分治,将计算2个n阶矩阵乘积所需的计算时间改进到0(nlog7) = 0(n2.81)

    我们知道,C11=A11*B11+A12*B21

    在这里插入图片描述

    矩阵A和B的示意图如下:

    在这里插入图片描述

    传统方法:

    在这里插入图片描述

    2个n阶方阵的乘积转换为8个n/2 阶方阵的乘积和4个n/2阶方阵的加法。

    由此可得:

    C11 = A11B11 + A12B21

    C12 = A11B12 + A12B22

    C21 = A21B11 + A22B21

    C22 = A21B12 + A22B22

    分治法:

    为了降低时间复杂度,必须减少乘法的次数。

    使用与上例类似的技术,将矩阵A,B和C中每一矩阵都分块成4个大小相等的子矩阵。由此可将方程C=AB重写为:

    在这里插入图片描述

    2个n阶方阵的乘积转换为7个n/2 阶方阵的乘积和18个n/2阶方阵的加减法。

    伪代码如下:

    // 递归维度分半算法:
    public void STRASSEN(n,A,B,C);
    {  
    if n=2 then MATRIX-MULTIPLY(A,B,C)
    / /结束循环,计算 两个2阶方阵的乘法         
    else{
      将矩阵A和B分块;
      STRASSEN(n/2,A11,B12-B22,M1);
      STRASSEN(n/2,A11+A12,B22,M2); 
      STRASSEN(n/2,A21+A22,B11,M3);
      STRASSEN(n/2,A22,B21-B11,M4);
      STRASSEN(n/2,A11+A22,B11+B22,M5);
      STRASSEN(n/2,A12-A22,B21+B22,M6);
      STRASSEN(n/2,A11-A21,B11+B12,M7);}
    }                
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    算法导论伪代码:

    在这里插入图片描述

    3、代码实现

    public class StrassenMatrixMultiply
    {
        public static void main(String[] args)
        {
            int[] a = new int[]
            {
                1, 1, 1, 1,
                2, 2, 2, 2,
                3, 3, 3, 3,
                4, 4, 4, 4
            };
    
            int[] b = new int[]
            {
                1, 2, 3, 4,
                1, 2, 3, 4,
                1, 2, 3, 4,
                1, 2, 3, 4
            };
    
            int length = 4;
    
            int[] c = sMM(a, b, length);
    
            for(int i = 0; i < c.length; i++)
            {
                System.out.print(c[i] + " ");
    
                if((i + 1) % length == 0) //换行
                    System.out.println();
            }
        }
    
        public static int[] sMM(int[] a, int[] b, int length) {
            if(length == 2) {
                return getResult(a, b);
            }
            else {
                int tlength = length / 2;
                // 把a数组分为四部分,进行分治递归
                int[] aa = new int[tlength * tlength];
                int[] ab = new int[tlength * tlength];
                int[] ac = new int[tlength * tlength];
                int[] ad = new int[tlength * tlength];
                // 把b数组分为四部分,进行分治递归
                int[] ba = new int[tlength * tlength];
                int[] bb = new int[tlength * tlength];
                int[] bc = new int[tlength * tlength];
                int[] bd = new int[tlength * tlength];
    
                // TODO 划分子矩阵
                for(int i = 0; i < length; i++) {
                    for(int j = 0; j < length; j++) {
                        /*
                         * 划分矩阵:
                         * 例子:将 4 * 4 的矩阵,变为 2 * 2 的矩阵,
                         * 那么原矩阵左上、右上、左下、右下的四个元素分别归为新矩阵
                        */
                        if(i < tlength) {
                            if(j < tlength) {
                                aa[i * tlength + j] = a[i * length + j];
                                ba[i * tlength + j] = b[i * length + j];
                            } else {
                                ab[i * tlength + (j - tlength)] = a[i * length + j];
                                bb[i * tlength + (j - tlength)] = b[i * length + j];
                            }
                        } else {
                            if(j < tlength) {
                                //i 大于 tlength 时,需要减去 tlength,j同理
                                //因为 b,c,d三个子矩阵有对应了父矩阵的后半部分
                                ac[(i - tlength) * tlength + j] = a[i * length + j];
                                bc[(i - tlength) * tlength + j] = b[i * length + j];
                            } else {
                                ad[(i - tlength) * tlength + (j - tlength)] = a[i * length + j];
                                bd[(i - tlength) * tlength + (j - tlength)] = b[i * length + j];
                            }
                        }
                    }
                }
    
                // TODO 分治递归
                int[] result = new int[length * length];
    
                // temp:4个临时矩阵
                int[] t1 = add(sMM(aa, ba, tlength), sMM(ab, bc, tlength));
                int[] t2 = add(sMM(aa, bb, tlength), sMM(ab, bd, tlength));
                int[] t3 = add(sMM(ac, ba, tlength), sMM(ad, bc, tlength));
                int[] t4 = add(sMM(ac, bb, tlength), sMM(ad, bd, tlength));
    
                // TODO 归并结果
                for(int i = 0; i < length; i++) {
                    for(int j = 0; j < length; j++) {
                        if (i < tlength){
                            if(j < tlength) {
                                result[i * length + j] = t1[i * tlength + j];
                            } else {
                                result[i * length + j] = t2[i * tlength + (j - tlength)];
                            }
                        } else {
                            if(j < tlength) {
                                result[i * length + j] = t3[(i - tlength) * tlength + j];
                            } else {
                                result[i * length + j] = t4[(i - tlength) * tlength + (j - tlength)];
                            }
                        }
                    }
                }
                return result;
            }
        }
    
        public static int[] getResult(int[] a, int[] b) {
            int p1 = a[0] * (b[1] - b[3]);
            int p2 = (a[0] + a[1]) * b[3];
            int p3 = (a[2] + a[3]) * b[0];
            int p4 = a[3] * (b[2] - b[0]);
            int p5 = (a[0] + a[3]) * (b[0] + b[3]);
            int p6 = (a[1] - a[3]) * (b[2] + b[3]);
            int p7 = (a[0] - a[2]) * (b[0] + b[1]);
    
            int c00 = p5 + p4 - p2 + p6;
            int c01 = p1 + p2;
            int c10 = p3 + p4;
            int c11 = p5 + p1 -p3 - p7;
    
            return new int[] {c00, c01, c10, c11};
        }
    
        public static int[] add(int[] a, int[] b) {
            int[] c = new int[a.length];
            for(int i = 0; i < a.length; i++) {
                c[i] = a[i] + b[i];
    	    }
            return c;
        }
    
        // TODO 返回一个数是不是2的幂次方
        public static boolean adjust(int x) {
            return (x & (x - 1)) == 0;
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141

    4、复杂度分析

    传统方法和分治法的复杂度比较,如下图所示;

    在这里插入图片描述

    T ( n ) = { O ( 1 ) , n = 2 7 T ( n / 2 ) + O ( n 2 ) , n > 2 T(n) = \left\{

    O(1),n=27T(n/2)+O(n2),n>2" role="presentation">O(1),n=27T(n/2)+O(n2),n>2
    \right. T(n)={O(1),n=27T(n/2)+O(n2),n>2

    T(n) = 0(nlog7 ) = 0(n2.81)

    5、参考资料

    • 算法分析与设计(第四版)
    • 算法导论第三版
    • 博客园
  • 相关阅读:
    线程的创建和两种线程实现方式的区别
    MTX-Ovalbumin 甲氨蝶呤修饰卵清蛋白 Methotrexate-Ovalbumin
    hive表向es集群同步数据20230830
    glibc 里的线程 id
    css让图片的某些区域拉伸,其他部分保持比例,起到类似于安卓中点九.9图的效果
    《C++避坑神器·二十》C++智能指针简单使用
    公众号搜题查题接口 搜题公众号制作教程 附赠题库接口
    【知识管理】总纲
    漏洞复现--时空智友企业流程化管控系统敏感信息泄露(POC)
    【接口】HTTP(3) |GET和POST两种基本请求方法有什么区别
  • 原文地址:https://blog.csdn.net/m0_52735414/article/details/128006532