• 《算法导论》:矩阵乘法的Strassen算法代码实现


    一、原文及伪代码

    第四章-矩阵乘法的Strassen算法

    SQUARE-MATRIX-MULTIPLY-RECURSIVE(A,B)
    1 n = A.rows                                   //A的行数
    2 let C be a new n*n matrix                    //让C变成新的n*n矩阵
    3 if n == 1
    4     c11 = a11 * b11
    5 else partition A,B,and C as in equations     //将三个矩阵各自分成4个部分
          //分别求出四个元素
    6     C11 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B11) 
             + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B21)
    7     C12 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B12) 
             + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B22)
    8     C21 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B11) 
             + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B21)
    9     C22 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B12) 
             + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B22)
    10 return C
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    Strassen(A,B)
    let C be a new n*n matrix
    if A.rows == 1:
        C = A * B
    else partition A,B,and C //步骤1:将三个矩阵各自分为四部分
        //步骤2:计算10个S
        S1=B12-B22
        S2=A11+A12
        S3=A21+A22
        S4=B21-B11
        S5=A11+A22
        S6=B11+B22
        S7=A12-A22
        S8=B21+B22
        S9=A11-A21
        S10=B11+B12
        //步骤3:递归计算7个矩阵积
        P1=Strassen(A11,S1)
        P2=Strassen(S2,B22)
        P3=Strassen(S3,B11)
        P4=Strassen(A22,S4)
        P5=Strassen(S5,S6)
        P6=Strassen(S7,S8)
        P7=Strassen(S9,S10)
        //步骤4:不同Pi的加减运算
        C11=P5+P4-P2+P6
        C12=P1+P2
        C21=P3+P4
        C22=P5+P1-P3-P7
    return C
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    二、C++代码

    #include <cstddef>
    #include <cstdlib>
    #include <iostream>
    #include <vector>
    using namespace std;
    // Square-matrix helper routines plus a Strassen-based multiplication.
    // Every matrix is passed as an array of row pointers (T**), and every
    // routine treats its arguments as MatrixSize x MatrixSize (or N x N).
    template<typename T>
    class Strassen_class {
    public:
        // MatrixResult = MatrixA + MatrixB, element by element.
        void ADD(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize);
        // MatrixResult = MatrixA - MatrixB, element by element.
        void SUB(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize);
        // MatrixResult = MatrixA * MatrixB using the naive triple loop.
        void MUL(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize);
        // Fills MatrixA and MatrixB with values.
        void FillMatrix(T** MatrixA, T** MatrixB, int length);
        // Prints MatrixA.
        void PrintMatrix(T** MatrixA, int MatrixSize);
        // Strassen-based multiplication of the N x N matrices MatrixA and MatrixB,
        // with the result written to MatrixC.
        void Strassen(int N, T** MatrixA, T** MatrixB, T** MatrixC);
    };
    // Element-wise matrix addition: MatrixResult = MatrixA + MatrixB.
    // All three arguments are MatrixSize x MatrixSize matrices given as arrays
    // of row pointers; every cell of MatrixResult is overwritten.
    template<typename T>
    void Strassen_class<T>::ADD(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize)
    {
        for (int row = 0; row < MatrixSize; ++row)
        {
            // Look up the three row pointers once per row.
            T* lhs = MatrixA[row];
            T* rhs = MatrixB[row];
            T* sum = MatrixResult[row];
            for (int col = 0; col < MatrixSize; ++col)
            {
                sum[col] = lhs[col] + rhs[col];
            }
        }
    }
    // Element-wise matrix subtraction: MatrixResult = MatrixA - MatrixB.
    // All three arguments are MatrixSize x MatrixSize matrices given as arrays
    // of row pointers; every cell of MatrixResult is overwritten.
    template<typename T>
    void Strassen_class<T>::SUB(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize)
    {
        for (int row = 0; row < MatrixSize; ++row)
        {
            // Look up the three row pointers once per row.
            T* minuend = MatrixA[row];
            T* subtrahend = MatrixB[row];
            T* difference = MatrixResult[row];
            for (int col = 0; col < MatrixSize; ++col)
            {
                difference[col] = minuend[col] - subtrahend[col];
            }
        }
    }
    // Naive matrix multiplication: MatrixResult = MatrixA * MatrixB.
    // Each result cell is zeroed and then accumulated in place as the dot
    // product of a row of MatrixA and a column of MatrixB.
    template<typename T>
    void Strassen_class<T>::MUL(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize)
    {
        for (int row = 0; row < MatrixSize; ++row)
        {
            for (int col = 0; col < MatrixSize; ++col)
            {
                // Accumulate directly in the destination cell.
                T& cell = MatrixResult[row][col];
                cell = 0;
                for (int inner = 0; inner < MatrixSize; ++inner)
                {
                    cell = cell + MatrixA[row][inner] * MatrixB[inner][col];
                }
            }
        }
    }
    // Fills the length x length matrices MatrixA and MatrixB.
    // Each cell of MatrixA receives rand() % 5 (a value from 0 to 4) and the
    // matching cell of MatrixB receives the value just stored in MatrixA.
    template<typename T>
    void Strassen_class<T>::FillMatrix(T** MatrixA, T** MatrixB, int length)
    {
        for (int r = 0; r < length; ++r)
        {
            for (int c = 0; c < length; ++c)
            {
                T& source = MatrixA[r][c];
                source = rand() % 5;
                MatrixB[r][c] = source;
            }
        }
    }
    // Prints a MatrixSize x MatrixSize matrix to cout.
    // Output is a blank line, then one line per row with a tab after every
    // value, then a closing blank line.
    template<typename T>
    void Strassen_class<T>::PrintMatrix(T** MatrixA, int MatrixSize)
    {
        cout << endl;
        for (int r = 0; r < MatrixSize; ++r)
        {
            T* line = MatrixA[r];
            for (int c = 0; c < MatrixSize; ++c)
            {
                cout << line[c] << "\t";
            }
            // The row is complete once its last column has been written.
            cout << endl;
        }
        cout << endl;
    }
    
    // Strassen's algorithm.
    //
    // Computes MatrixC = MatrixA * MatrixB, where all three are N x N matrices
    // given as arrays of row pointers. Every cell of MatrixC is overwritten.
    //
    // Sizes of at most 64, and odd sizes, are handed to MUL. Any other size is
    // split into four HalfSize x HalfSize quadrants, the seven quadrant products
    // M1..M7 are computed recursively, and the quadrants of MatrixC are
    // assembled from them.
    template<typename T>
    void Strassen_class<T>::Strassen(int N, T * *MatrixA, T * *MatrixB, T * *MatrixC)
    {
        // Small inputs are multiplied directly. Odd sizes take the same path
        // because they cannot be cut into four equal quadrants: halving them
        // would drop the last row and column and leave part of MatrixC unwritten.
        if (N <= 64 || N % 2 != 0)
        {
            MUL(MatrixA, MatrixB, MatrixC, N);
            return;
        }

        const int HalfSize = N / 2;
        const std::size_t half = static_cast<std::size_t>(HalfSize);

        // Working storage for 21 HalfSize x HalfSize blocks:
        // 4 quadrants each of A, B and C, the 7 products, and 2 scratch blocks.
        // The vectors own the memory, so it is released on every exit path.
        const std::size_t blockCount = 21;
        std::vector<T> cells(blockCount * half * half);
        std::vector<T*> rowTable(blockCount * half);
        for (std::size_t r = 0; r < rowTable.size(); r++)
        {
            // Row r of the table points at the r-th run of `half` cells, so block b
            // occupies table rows [b * half, (b + 1) * half).
            rowTable[r] = &cells[0] + r * half;
        }

        T** const base = &rowTable[0];
        T** A11 = base + 0 * half;
        T** A12 = base + 1 * half;
        T** A21 = base + 2 * half;
        T** A22 = base + 3 * half;

        T** B11 = base + 4 * half;
        T** B12 = base + 5 * half;
        T** B21 = base + 6 * half;
        T** B22 = base + 7 * half;

        T** C11 = base + 8 * half;
        T** C12 = base + 9 * half;
        T** C21 = base + 10 * half;
        T** C22 = base + 11 * half;

        T** M1 = base + 12 * half;
        T** M2 = base + 13 * half;
        T** M3 = base + 14 * half;
        T** M4 = base + 15 * half;
        T** M5 = base + 16 * half;
        T** M6 = base + 17 * half;
        T** M7 = base + 18 * half;

        T** AResult = base + 19 * half;
        T** BResult = base + 20 * half;

        // Copy the four quadrants of MatrixA and MatrixB into their own blocks.
        for (int i = 0; i < HalfSize; i++)
        {
            for (int j = 0; j < HalfSize; j++)
            {
                A11[i][j] = MatrixA[i][j];
                A12[i][j] = MatrixA[i][j + HalfSize];
                A21[i][j] = MatrixA[i + HalfSize][j];
                A22[i][j] = MatrixA[i + HalfSize][j + HalfSize];

                B11[i][j] = MatrixB[i][j];
                B12[i][j] = MatrixB[i][j + HalfSize];
                B21[i][j] = MatrixB[i + HalfSize][j];
                B22[i][j] = MatrixB[i + HalfSize][j + HalfSize];
            }
        }

        // The seven recursive products.
        // M1 = A11 (B12 - B22)
        SUB(B12, B22, BResult, HalfSize);
        Strassen(HalfSize, A11, BResult, M1);

        // M2 = (A11 + A12) B22
        ADD(A11, A12, AResult, HalfSize);
        Strassen(HalfSize, AResult, B22, M2);

        // M3 = (A21 + A22) B11
        ADD(A21, A22, AResult, HalfSize);
        Strassen(HalfSize, AResult, B11, M3);

        // M4 = A22 (B21 - B11)
        SUB(B21, B11, BResult, HalfSize);
        Strassen(HalfSize, A22, BResult, M4);

        // M5 = (A11 + A22)(B11 + B22)
        ADD(A11, A22, AResult, HalfSize);
        ADD(B11, B22, BResult, HalfSize);
        Strassen(HalfSize, AResult, BResult, M5);

        // M6 = (A12 - A22)(B21 + B22)
        SUB(A12, A22, AResult, HalfSize);
        ADD(B21, B22, BResult, HalfSize);
        Strassen(HalfSize, AResult, BResult, M6);

        // M7 = (A11 - A21)(B11 + B12)
        // This product must go to M7: C11 below still needs M6, and C22 reads M7.
        SUB(A11, A21, AResult, HalfSize);
        ADD(B11, B12, BResult, HalfSize);
        Strassen(HalfSize, AResult, BResult, M7);

        // C11 = M5 + M4 - M2 + M6
        ADD(M5, M4, AResult, HalfSize);
        SUB(M6, M2, BResult, HalfSize);
        ADD(AResult, BResult, C11, HalfSize);

        // C12 = M1 + M2
        ADD(M1, M2, C12, HalfSize);

        // C21 = M3 + M4
        ADD(M3, M4, C21, HalfSize);

        // C22 = M5 + M1 - M3 - M7, computed as (M5 + M1) - (M7 + M3)
        ADD(M5, M1, AResult, HalfSize);
        ADD(M7, M3, BResult, HalfSize);
        SUB(AResult, BResult, C22, HalfSize);

        // Assemble the four result quadrants into MatrixC.
        for (int i = 0; i < HalfSize; i++)
        {
            for (int j = 0; j < HalfSize; j++)
            {
                MatrixC[i][j] = C11[i][j];
                MatrixC[i][j + HalfSize] = C12[i][j];
                MatrixC[i + HalfSize][j] = C21[i][j];
                MatrixC[i + HalfSize][j + HalfSize] = C22[i][j];
            }
        }
    }
    
    int main()
    {
        Strassen_class<int> stra;//定义Strassen_class类对象
        int MatrixSize = 0;
    
        int** MatrixA;    //存放矩阵A
        int** MatrixB;    //存放矩阵B
        int** MatrixC;    //存放结果矩阵
        cout << "\n请输入矩阵大小(必须是2的幂指数值(例如:32,64,512,..): ";
        cin >> MatrixSize;
    
        int N = MatrixSize;//for readiblity.
    
        //申请内存
        MatrixA = new int* [MatrixSize];
        MatrixB = new int* [MatrixSize];
        MatrixC = new int* [MatrixSize];
        //申请空间
        for (int i = 0; i < MatrixSize; i++)
        {
            MatrixA[i] = new int[MatrixSize];
            MatrixB[i] = new int[MatrixSize];
            MatrixC[i] = new int[MatrixSize];
        }
        stra.FillMatrix(MatrixA, MatrixB, MatrixSize);  //矩阵赋值
        stra.Strassen(N, MatrixA, MatrixB, MatrixC); //strassen矩阵相乘算法
        cout << "\n矩阵运算结果... \n";
        stra.PrintMatrix(MatrixC, MatrixSize);
        return 0;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
  • 相关阅读:
    算法-贪心-112. 雷达设备
    Harmony Next 文件命令操作(发送、读取、媒体文件查询)
    低代码开发浅析
    springboot+火车票预订系统 毕业设计-附源码091029
    【JVM笔记】方法调用与返回字节码指令
    每日一问:Java中抽象类与抽象方法
    关于算子mindspore.nn.Conv2dTranspose没有output_padding参数
    go版本1.16.5 运行项目出现undefined: math.MaxInt报错
    hive表字段跟字段对应的值转为json数组
    一文速学-Pandas多文件批次聚合处理详解+实例代码
  • 原文地址:https://blog.csdn.net/m0_61843614/article/details/126692420