2.3. Matrix Multiply

The DSPLib contains one Matrix Multiply/GEMM (GEneral Matrix Multiply) solution. The GEMM has two input ports, each connected to a window of data. The inputs are denoted as Matrix A (inA) and Matrix B (inB). The number of rows of Matrix A is described by the template parameter TP_DIM_A. The number of columns of inA must be equal to the number of rows of inB; this shared dimension is denoted by the template parameter TP_DIM_AB. The number of columns of inB is denoted by TP_DIM_B.

An output port connects to a window, where the data for the output matrix will be stored. The output matrix has TP_DIM_A rows (the number of rows of inA) and TP_DIM_B columns (the number of columns of inB). The data types of both input matrices can be configured, and the data type of the output is derived from the inputs.
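
For reference, the computation performed corresponds to the standard matrix product. The following host-side sketch is illustrative only (it is not DSPLib code; the plain float type and row-major storage are assumptions) and shows the role of each dimension parameter:

void matmul_ref(const float* A, const float* B, float* out,
                int TP_DIM_A, int TP_DIM_AB, int TP_DIM_B) {
  // TP_DIM_A  : rows of A and rows of the output
  // TP_DIM_AB : columns of A == rows of B (the shared dimension)
  // TP_DIM_B  : columns of B and columns of the output
  for (int r = 0; r < TP_DIM_A; r++) {
    for (int c = 0; c < TP_DIM_B; c++) {
      float acc = 0.0f;
      for (int k = 0; k < TP_DIM_AB; k++) {
        acc += A[r * TP_DIM_AB + k] * B[k * TP_DIM_B + c]; // row-major indexing
      }
      out[r * TP_DIM_B + c] = acc;
    }
  }
}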

2.3.1. Entry Point

The graph entry point is the following:

xf::dsp::aie::blas::matrix_mult::matrix_mult_graph

2.3.2. Supported Types

The Matrix Multiply supports a matrix of elements of integer type (int16, cint16, int32 or cint32) multiplied by a matrix of elements of integer type. It also supports a matrix of elements of float type (float, cfloat) multiplied by a matrix of elements of float type. However, a mix of integer types and float types is not supported.

2.3.3. Template Parameters

To see details on the template parameters for the Matrix Multiply, see API Reference Overview.

2.3.4. Access Functions

To see details on the access functions for the Matrix Multiply, see API Reference Overview.

2.3.5. Ports

To see details on the ports for the Matrix Multiply, see API Reference Overview.

2.3.6. Design Notes

2.3.6.1. Tiling

Input matrices are processed in distinct blocks (tiles). Matrix elements must be rearranged into a specific tiled pattern before being presented to the matrix multiply kernels.

The following table demonstrates how a 16x16 input matrix should be rearranged into a 4x4 tiling pattern.

Note

Indices are quoted assuming a row-major matrix. A column-major matrix would be the transpose of the table below.

Table 7 : Matrix Multiply 4x4 tiling pattern
             Tile Col 0      | Tile Col 1      | Tile Col 2      | Tile Col 3
Tile Row 0     0   1   2   3 |   4   5   6   7 |   8   9  10  11 |  12  13  14  15
              16  17  18  19 |  20  21  22  23 |  24  25  26  27 |  28  29  30  31
              32  33  34  35 |  36  37  38  39 |  40  41  42  43 |  44  45  46  47
              48  49  50  51 |  52  53  54  55 |  56  57  58  59 |  60  61  62  63
Tile Row 1    64  65  66  67 |  68  69  70  71 |  72  73  74  75 |  76  77  78  79
              80  81  82  83 |  84  85  86  87 |  88  89  90  91 |  92  93  94  95
              96  97  98  99 | 100 101 102 103 | 104 105 106 107 | 108 109 110 111
             112 113 114 115 | 116 117 118 119 | 120 121 122 123 | 124 125 126 127
Tile Row 2   128 129 130 131 | 132 133 134 135 | 136 137 138 139 | 140 141 142 143
             144 145 146 147 | 148 149 150 151 | 152 153 154 155 | 156 157 158 159
             160 161 162 163 | 164 165 166 167 | 168 169 170 171 | 172 173 174 175
             176 177 178 179 | 180 181 182 183 | 184 185 186 187 | 188 189 190 191
Tile Row 3   192 193 194 195 | 196 197 198 199 | 200 201 202 203 | 204 205 206 207
             208 209 210 211 | 212 213 214 215 | 216 217 218 219 | 220 221 222 223
             224 225 226 227 | 228 229 230 231 | 232 233 234 235 | 236 237 238 239
             240 241 242 243 | 244 245 246 247 | 248 249 250 251 | 252 253 254 255

This is stored contiguously in memory as follows:

0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51, 4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55, 8, 9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59, 12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63, 64, 65, 66, 67, 80, 81, 82, 83, 96, 97, 98, 99, 112, 113, 114, 115, … , 204, 205, 206, 207, 220, 221, 222, 223, 236, 237, 238, 239, 252, 253, 254, 255
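
As a cross-check of the ordering above, the following host-side sketch (illustrative only; the function name, the row-major indexing, and the use of printf are assumptions rather than DSPLib code) prints the element indices of a row-major matrix in tiled order for a given tile size:

#include <cstdio>

// Print the element indices of a row-major (rows x cols) matrix in tiled order.
// Tiles are emitted one tile-row at a time, left to right; within each tile the
// elements are emitted row by row.
void print_tiled_order(int rows, int cols, int tile_rows, int tile_cols) {
  for (int tr = 0; tr < rows; tr += tile_rows) {       // tile row
    for (int tc = 0; tc < cols; tc += tile_cols) {     // tile column
      for (int r = tr; r < tr + tile_rows; r++) {      // row within the tile
        for (int c = tc; c < tc + tile_cols; c++) {    // column within the tile
          printf("%d, ", r * cols + c);
        }
      }
    }
  }
  printf("\n");
}

int main() {
  print_tiled_order(16, 16, 4, 4); // reproduces the 4x4 sequence shown above
  return 0;
}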

The following table demonstrates how a 16x16 input matrix should be rearranged into a 4x2 tiling pattern.

Table 8 : Matrix Multiply 4x2 tiling pattern
              Tile Col 0 | Tile Col 1 | Tile Col 2 | Tile Col 3 | Tile Col 4 | Tile Col 5 | Tile Col 6 | Tile Col 7
Tile Row 0      0   1    |   2   3    |   4   5    |   6   7    |   8   9    |  10  11    |  12  13    |  14  15
               16  17    |  18  19    |  20  21    |  22  23    |  24  25    |  26  27    |  28  29    |  30  31
               32  33    |  34  35    |  36  37    |  38  39    |  40  41    |  42  43    |  44  45    |  46  47
               48  49    |  50  51    |  52  53    |  54  55    |  56  57    |  58  59    |  60  61    |  62  63
Tile Row 1     64  65    |  66  67    |  68  69    |  70  71    |  72  73    |  74  75    |  76  77    |  78  79
               80  81    |  82  83    |  84  85    |  86  87    |  88  89    |  90  91    |  92  93    |  94  95
               96  97    |  98  99    | 100 101    | 102 103    | 104 105    | 106 107    | 108 109    | 110 111
              112 113    | 114 115    | 116 117    | 118 119    | 120 121    | 122 123    | 124 125    | 126 127
Tile Row 2    128 129    | 130 131    | 132 133    | 134 135    | 136 137    | 138 139    | 140 141    | 142 143
              144 145    | 146 147    | 148 149    | 150 151    | 152 153    | 154 155    | 156 157    | 158 159
              160 161    | 162 163    | 164 165    | 166 167    | 168 169    | 170 171    | 172 173    | 174 175
              176 177    | 178 179    | 180 181    | 182 183    | 184 185    | 186 187    | 188 189    | 190 191
Tile Row 3    192 193    | 194 195    | 196 197    | 198 199    | 200 201    | 202 203    | 204 205    | 206 207
              208 209    | 210 211    | 212 213    | 214 215    | 216 217    | 218 219    | 220 221    | 222 223
              224 225    | 226 227    | 228 229    | 230 231    | 232 233    | 234 235    | 236 237    | 238 239
              240 241    | 242 243    | 244 245    | 246 247    | 248 249    | 250 251    | 252 253    | 254 255

This is stored contiguously in memory as follows:

0, 1, 16, 17, 32, 33, 48, 49, 2, 3, 18, 19, 34, 35, 50, 51, …, 206, 207, 222, 223, 238, 239, 254, 255
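
Under the same illustrative sketch shown earlier, this 4x2 ordering corresponds to:

print_tiled_order(16, 16, 4, 2); // 0, 1, 16, 17, 32, 33, 48, 49, 2, 3, ...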

Multiplying a 16x16 matrix (with 4x4 tiling) by a 16x16 matrix (with 4x2 tiling) results in a 16x16 matrix with 4x2 tiling, because each 4x4 tile of A multiplied by a 4x2 tile of B produces a 4x2 output tile.

The following table specifies the tiling scheme used for a given data type combination and the corresponding output data type:

Table 9 : Matrix Multiply tiling pattern combination
Input Type A   Input Type B   Tiling Scheme A   Tiling Scheme B   Output Type
int16          int16          4x4               4x4               int16
int16          cint16         4x2               2x2               cint16
int16          int32          4x2               2x2               int32
int16          cint32         2x4               4x2               cint32
cint16         int16          4x4               4x2               cint16
cint16         cint16         4x4               4x2               cint16
cint16         int32          4x4               4x2               cint32
cint16         cint32         2x2               2x2               cint32
int32          int16          4x4               4x2               int32
int32          int32          4x4               4x2               int32
int32          cint16         4x4               4x2               cint32
int32          cint32         2x2               2x2               cint32
cint32         int16          2x4               4x2               cint32
cint32         cint16         2x2               2x2               cint32
cint32         int32          2x2               2x2               cint32
cint32         cint32         2x2               2x2               cint32
float          float          4x4               4x2               float
float          cfloat         2x4               4x2               cfloat
cfloat         float          2x4               4x2               cfloat
cfloat         cfloat         4x2               2x2               cfloat

The parameters TP_ADD_TILING_A, TP_ADD_TILING_B, and TP_ADD_DETILING_OUT control the inclusion of additional pre-processing / post-processing kernels that perform the required data storage re-ordering. When used with TP_DIM_A_LEADING, TP_DIM_B_LEADING, or TP_DIM_OUT_LEADING, the matrix is also transposed in the tiling kernel.

If the additional kernels are not selected, the matrix multiply kernels assume that the incoming data is already in the tiled format specified above. When the TP_CASC_LEN parameter is used, the matrix multiply operation is split along the TP_DIM_AB dimension and processed by TP_CASC_LEN kernels. The accumulated partial results of each kernel are passed down the cascade port to the next kernel in the cascade chain until the final kernel provides the expected output. Cascade connections are made internally to the matrix multiply graph.
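
Conceptually, each kernel in the cascade accumulates a partial product over its slice of the shared TP_DIM_AB dimension and adds it to the partial result received from the previous kernel. The following host-side sketch models that behaviour for a single output element (illustrative only; the float type, row-major indexing, and the assumption that TP_DIM_AB divides evenly by TP_CASC_LEN are not taken from the library):

// Model of the cascade split for a single output element out[r][c].
// Kernel k owns columns [k*slice, (k+1)*slice) of A and the matching rows of B,
// and adds its partial sum to the value received over the cascade input.
float cascade_element(const float* A, const float* B, int r, int c,
                      int TP_DIM_AB, int TP_DIM_B, int TP_CASC_LEN) {
  const int slice = TP_DIM_AB / TP_CASC_LEN; // assumes an even split
  float cascade_in = 0.0f;                   // the first kernel starts from zero
  for (int k = 0; k < TP_CASC_LEN; k++) {    // one iteration per cascade kernel
    float partial = cascade_in;
    for (int i = k * slice; i < (k + 1) * slice; i++) {
      partial += A[r * TP_DIM_AB + i] * B[i * TP_DIM_B + c];
    }
    cascade_in = partial;                    // passed down the cascade port
  }
  return cascade_in;                         // output of the final kernel
}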

Each AI Engine kernel in the cascade is given a sub-matrix of A and a sub-matrix of B, so the interface to the graph is an array of ports for both A and B.

Input Matrix A (16x16 - 4x4 Tile - Cascade Length 2):

Table 10 : Input Matrix A (16x16 - 4x4 Tile - Cascade Length 2)
                     AIE 0                 ||                AIE 1
             Tile Col 0      | Tile Col 1      || Tile Col 2      | Tile Col 3
Tile Row 0     0   1   2   3 |   4   5   6   7 ||   8   9  10  11 |  12  13  14  15
              16  17  18  19 |  20  21  22  23 ||  24  25  26  27 |  28  29  30  31
              32  33  34  35 |  36  37  38  39 ||  40  41  42  43 |  44  45  46  47
              48  49  50  51 |  52  53  54  55 ||  56  57  58  59 |  60  61  62  63
Tile Row 1    64  65  66  67 |  68  69  70  71 ||  72  73  74  75 |  76  77  78  79
              80  81  82  83 |  84  85  86  87 ||  88  89  90  91 |  92  93  94  95
              96  97  98  99 | 100 101 102 103 || 104 105 106 107 | 108 109 110 111
             112 113 114 115 | 116 117 118 119 || 120 121 122 123 | 124 125 126 127
Tile Row 2   128 129 130 131 | 132 133 134 135 || 136 137 138 139 | 140 141 142 143
             144 145 146 147 | 148 149 150 151 || 152 153 154 155 | 156 157 158 159
             160 161 162 163 | 164 165 166 167 || 168 169 170 171 | 172 173 174 175
             176 177 178 179 | 180 181 182 183 || 184 185 186 187 | 188 189 190 191
Tile Row 3   192 193 194 195 | 196 197 198 199 || 200 201 202 203 | 204 205 206 207
             208 209 210 211 | 212 213 214 215 || 216 217 218 219 | 220 221 222 223
             224 225 226 227 | 228 229 230 231 || 232 233 234 235 | 236 237 238 239
             240 241 242 243 | 244 245 246 247 || 248 249 250 251 | 252 253 254 255

Input Matrix B (16x16 - 4x2 Tile - Cascade Length 2):

Table 11 : Input Matrix B (16x16 - 4x2 Tile - Cascade Length 2)
                     Tile Col 0 | Tile Col 1 | Tile Col 2 | Tile Col 3 | Tile Col 4 | Tile Col 5 | Tile Col 6 | Tile Col 7
AIE 0  Tile Row 0      0   1    |   2   3    |   4   5    |   6   7    |   8   9    |  10  11    |  12  13    |  14  15
                      16  17    |  18  19    |  20  21    |  22  23    |  24  25    |  26  27    |  28  29    |  30  31
                      32  33    |  34  35    |  36  37    |  38  39    |  40  41    |  42  43    |  44  45    |  46  47
                      48  49    |  50  51    |  52  53    |  54  55    |  56  57    |  58  59    |  60  61    |  62  63
       Tile Row 1     64  65    |  66  67    |  68  69    |  70  71    |  72  73    |  74  75    |  76  77    |  78  79
                      80  81    |  82  83    |  84  85    |  86  87    |  88  89    |  90  91    |  92  93    |  94  95
                      96  97    |  98  99    | 100 101    | 102 103    | 104 105    | 106 107    | 108 109    | 110 111
                     112 113    | 114 115    | 116 117    | 118 119    | 120 121    | 122 123    | 124 125    | 126 127
AIE 1  Tile Row 2    128 129    | 130 131    | 132 133    | 134 135    | 136 137    | 138 139    | 140 141    | 142 143
                     144 145    | 146 147    | 148 149    | 150 151    | 152 153    | 154 155    | 156 157    | 158 159
                     160 161    | 162 163    | 164 165    | 166 167    | 168 169    | 170 171    | 172 173    | 174 175
                     176 177    | 178 179    | 180 181    | 182 183    | 184 185    | 186 187    | 188 189    | 190 191
       Tile Row 3    192 193    | 194 195    | 196 197    | 198 199    | 200 201    | 202 203    | 204 205    | 206 207
                     208 209    | 210 211    | 212 213    | 214 215    | 216 217    | 218 219    | 220 221    | 222 223
                     224 225    | 226 227    | 228 229    | 230 231    | 232 233    | 234 235    | 236 237    | 238 239
                     240 241    | 242 243    | 244 245    | 246 247    | 248 249    | 250 251    | 252 253    | 254 255

Find a full list of descriptions and parameters in the API Reference Overview.

Connections to the cascade ports can be made as follows:

for (int i = 0; i < P_CASC_LEN; i++) {
    connect<>(inA[i], mmultGraph.inA[i]); // one A input port per cascade kernel
    connect<>(inB[i], mmultGraph.inB[i]); // one B input port per cascade kernel
}
connect<>(mmultGraph.out, out);           // single output port

2.3.6.2. Constraints

A Matrix Multiply solution can consist of a cascade of kernels for the multiply operations themselves, plus tiling kernels on each input to each member of that cascade and a detiling kernel on the output. The tiling kernels convert the arrangement of matrix elements in memory into a form optimized for vector multiplication, or vice versa. In the entry-level graph, the following names are used to identify the various kernels:

‘m_MatmultKernels’ - This is the array of kernel pointers returned by getKernels, which point to the TP_CASC_LEN matrix multiply kernels in the cascade. These kernels perform the matrix multiply operations.

‘untiler’ - This is a single kernel on the output of the matrix multiply kernel or cascade of kernels. It performs the transformation from the tiled format to the output format.

‘tilerA’ - This is an array of TP_CASC_LEN kernels which connect 1:1 with the A input port of the matrix multiply kernels.

‘tilerB’ - This is an array of TP_CASC_LEN kernels which connect 1:1 with the B input port of the matrix multiply kernels.

2.3.7. Code Example Including Constraints

The following code example shows how the matrix_mult_graph class may be used within a user super-graph, including how to set the runtime<ratio> of internal kernels. This example shows the matrix multiply graph configured to multiply a 32x16 matrix by a 16x32 matrix, giving a 32x32 output matrix.

#include <adf.h>
#include "matrix_mult_graph.hpp"
#define T_DATA_A cint16
#define T_DATA_B cint16
#define P_DIM_A 32
#define P_DIM_AB 16
#define P_DIM_B 32
#define P_SHIFT 16
#define P_ROUND_MODE 0
#define P_DIM_A_LEADING 0
#define P_DIM_B_LEADING 1
#define P_DIM_OUT_LEADING 0
#define P_ADD_TILING_A 0
#define P_ADD_TILING_B 0
#define P_ADD_DETILING_OUT 0
#define P_INPUT_WINDOW_VSIZE_A 512
#define P_INPUT_WINDOW_VSIZE_B 512
#define P_CASC_LEN 1

class myMM : public adf::graph
{
public:
  adf::port<input> inA;
  adf::port<input> inB;
  adf::port<output> out;
  xf::dsp::aie::blas::matrix_mult::matrix_mult_graph<T_DATA_A,
                                              T_DATA_B,
                                              P_DIM_A,
                                              P_DIM_AB,
                                              P_DIM_B,
                                              P_SHIFT,
                                              P_ROUND_MODE,
                                              P_DIM_A_LEADING,
                                              P_DIM_B_LEADING,
                                              P_DIM_OUT_LEADING,
                                              P_ADD_TILING_A,
                                              P_ADD_TILING_B,
                                              P_ADD_DETILING_OUT,
                                              P_INPUT_WINDOW_VSIZE_A,
                                              P_INPUT_WINDOW_VSIZE_B,
                                              P_CASC_LEN> matrixMult;
  myMM()
  {
    // P_CASC_LEN = 1, so only index 0 of the inA/inB port arrays is used.
    adf::connect<> net0(inA, matrixMult.inA[0]);
    adf::connect<> net1(inB, matrixMult.inB[0]);
    adf::connect<> net2(matrixMult.out, out);
    adf::kernel *kernels = matrixMult.getKernels();
    for (int i = 0; i < P_CASC_LEN; i++)
    {
      adf::runtime<ratio>(kernels[i]) = 0.7;
      adf::runtime<ratio>(matrixMult.tilerA[i]) = 0.5;
      adf::runtime<ratio>(matrixMult.tilerB[i]) = 0.5;
    }
    adf::runtime<ratio>(matrixMult.untiler) = 0.5;
  }
};