2.3. Matrix Multiply
The DSPLib contains one Matrix Multiply/GEMM (GEneral Matrix Multiply) solution. The GEMM has two input ports, each connected to a window of data. The inputs are denoted Matrix A (inA) and Matrix B (inB). The template parameter TP_DIM_A sets the number of rows of A. The number of columns of A must equal the number of rows of B; this shared dimension is set by the template parameter TP_DIM_AB. The template parameter TP_DIM_B sets the number of columns of B.
An output port connects to a window in which the output matrix is stored. The output matrix has TP_DIM_A rows (the rows of A) and TP_DIM_B columns (the columns of B). The data types of both input matrices are configurable, and the data type of the output is derived from the input types.
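To illustrate only the dimension rules, the following host-side C++ sketch computes the same product as a plain reference model. It assumes row-major storage in flat `std::vector`s and applies no output shift or rounding; `reference_gemm` is a hypothetical name, not part of the DSPLib API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference model (not the AIE kernel): computes
// out (dimA x dimB) = inA (dimA x dimAB) * inB (dimAB x dimB), where dimA,
// dimAB and dimB play the roles of TP_DIM_A, TP_DIM_AB and TP_DIM_B.
// All matrices are row-major; no output shift or rounding is applied.
template <typename T>
std::vector<T> reference_gemm(const std::vector<T>& inA, const std::vector<T>& inB,
                              std::size_t dimA, std::size_t dimAB, std::size_t dimB) {
    assert(inA.size() == dimA * dimAB);  // columns of A == TP_DIM_AB
    assert(inB.size() == dimAB * dimB);  // rows of B == TP_DIM_AB
    std::vector<T> out(dimA * dimB, T{});
    for (std::size_t r = 0; r < dimA; ++r)           // output rows = rows of A
        for (std::size_t c = 0; c < dimB; ++c) {     // output columns = columns of B
            T acc{};
            for (std::size_t k = 0; k < dimAB; ++k)  // shared dimension
                acc += inA[r * dimAB + k] * inB[k * dimB + c];
            out[r * dimB + c] = acc;
        }
    return out;
}
```

For example, multiplying a 2x3 matrix by a 3x2 matrix with `reference_gemm(a, b, 2, 3, 2)` returns four elements, that is, a TP_DIM_A x TP_DIM_B result.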
2.3.1. Entry Point
The graph entry point is the following:
xf::dsp::aie::blas::matrix_mult::matrix_mult_graph
2.3.2. Supported Types
The Matrix Multiply supports multiplying a matrix of integer elements (int16, cint16, int32, or cint32) by another matrix of integer elements, and a matrix of floating-point elements (float or cfloat) by another matrix of floating-point elements. Mixing integer and floating-point types (for example, int16 by float) is not supported.
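The mixing rule can be expressed as a compile-time check. This sketch is hypothetical: it defines host-side stand-ins for `cint16`, `cint32`, and `cfloat` instead of using the AI Engine headers, and `is_supported_pair` is not a DSPLib name.

```cpp
#include <complex>
#include <cstdint>
#include <type_traits>

// Hypothetical host-side stand-ins for the AI Engine element types
// (assumption: for illustration only, not the real AIE definitions).
struct cint16 { std::int16_t real, imag; };
struct cint32 { std::int32_t real, imag; };
using cfloat = std::complex<float>;

// Integer element types: int16, int32, cint16, cint32.
template <typename T> struct is_int_elem : std::false_type {};
template <> struct is_int_elem<std::int16_t> : std::true_type {};
template <> struct is_int_elem<std::int32_t> : std::true_type {};
template <> struct is_int_elem<cint16> : std::true_type {};
template <> struct is_int_elem<cint32> : std::true_type {};

// Floating-point element types: float, cfloat.
template <typename T> struct is_float_elem : std::false_type {};
template <> struct is_float_elem<float> : std::true_type {};
template <> struct is_float_elem<cfloat> : std::true_type {};

// A pair is supported only if both types are integer or both are float.
template <typename TA, typename TB>
constexpr bool is_supported_pair =
    (is_int_elem<TA>::value && is_int_elem<TB>::value) ||
    (is_float_elem<TA>::value && is_float_elem<TB>::value);

static_assert(is_supported_pair<cint16, std::int32_t>, "integer x integer is supported");
static_assert(!is_supported_pair<std::int16_t, float>, "integer x float is not supported");
```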
2.3.3. Template Parameters
To see details on the template parameters for the Matrix Multiply, see API Reference Overview.
2.3.4. Access functions
To see details on the access functions for the Matrix Multiply, see API Reference Overview.
2.3.5. Ports
To see details on the ports for the Matrix Multiply, see API Reference Overview.
2.3.6. Design Notes
2.3.6.1. Tiling
Input matrices are processed in distinct blocks (tiles), so the matrix elements must be rearranged into a specific tiled pattern.
The following table demonstrates how a 16x16 input matrix should be rearranged into a 4x4 tiling pattern.
Note: Indices are given assuming a row-major matrix. For a column-major matrix, the table below would be transposed.
| | Tile Col 0 | | | | Tile Col 1 | | | | Tile Col 2 | | | | Tile Col 3 | | | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Tile Row 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
| | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
| | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
| | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
| Tile Row 1 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 |
| | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 |
| | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 |
| | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 |
| Tile Row 2 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 |
| | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 |
| | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 |
| | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 |
| Tile Row 3 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 |
| | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 |
| | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 |
| | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 |
This tiled matrix is stored contiguously in memory as follows:
0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51, 4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55, 8, 9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59, 12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63, 64, 65, 66, 67, 80, 81, 82, 83, 96, 97, 98, 99, 112, 113, 114, 115, … , 204, 205, 206, 207, 220, 221, 222, 223, 236, 237, 238, 239, 252, 253, 254, 255
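The storage order above can be generated programmatically. This sketch (`tiled_order` is a hypothetical name) returns the row-major index of each element in the order it is stored, assuming, as the listing shows, that tiles are visited left to right along each tile row and that elements inside a tile are read row by row.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Returns, in storage order, the row-major indices of a rows x cols matrix
// rearranged into tileRows x tileCols tiles. Tiles are visited left to right
// along each tile row; elements inside each tile are read row by row.
std::vector<std::size_t> tiled_order(std::size_t rows, std::size_t cols,
                                     std::size_t tileRows, std::size_t tileCols) {
    assert(rows % tileRows == 0 && cols % tileCols == 0);
    std::vector<std::size_t> order;
    order.reserve(rows * cols);
    for (std::size_t tr = 0; tr < rows / tileRows; ++tr)          // tile row
        for (std::size_t tc = 0; tc < cols / tileCols; ++tc)      // tile column
            for (std::size_t r = 0; r < tileRows; ++r)            // row inside tile
                for (std::size_t c = 0; c < tileCols; ++c)        // column inside tile
                    order.push_back((tr * tileRows + r) * cols + tc * tileCols + c);
    return order;
}
```

`tiled_order(16, 16, 4, 4)` begins 0, 1, 2, 3, 16, 17, 18, 19, matching the listing, and calling it with `tileCols = 2` gives the 4x2 ordering (0, 1, 16, 17, 32, 33, 48, 49, 2, 3, ...).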
The following table demonstrates how a 16x16 input matrix should be rearranged into a 4x2 tiling pattern.
| | Tile Col 0 | | Tile Col 1 | | Tile Col 2 | | Tile Col 3 | | Tile Col 4 | | Tile Col 5 | | Tile Col 6 | | Tile Col 7 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Tile Row 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
| | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
| | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
| | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
| Tile Row 1 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 |
| | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 |
| | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 |
| | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 |
| Tile Row 2 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 |
| | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 |
| | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 |
| | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 |
| Tile Row 3 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 |
| | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 |
| | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 |
| | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 |
This tiled matrix is stored contiguously in memory as follows:
0, 1, 16, 17, 32, 33, 48, 49, 2, 3, 18, 19, 34, 35, 50, 51, …, 206, 207, 222, 223, 238, 239, 254, 255
Multiplying a 16x16 matrix with 4x4 tiling by a 16x16 matrix with 4x2 tiling results in a 16x16 matrix with 4x2 tiling.
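The shape rule behind this statement can be sketched with a small helper: an A tile of M x K elements multiplied by a B tile of K x N elements gives an M x N output tile, so the tile columns of A must match the tile rows of B (which holds for every row of the tiling table that follows). `Tile`, `tiles_compatible`, and `output_tile` are hypothetical names used only for illustration.

```cpp
// Tile shape as rows x columns, e.g. {4, 2} for "4x2".
struct Tile { int rows; int cols; };

// An A tile (M x K) can multiply a B tile (K x N) only when K matches.
constexpr bool tiles_compatible(Tile a, Tile b) { return a.cols == b.rows; }

// The resulting output tile is M x N: tile rows of A by tile columns of B.
constexpr Tile output_tile(Tile a, Tile b) { return Tile{a.rows, b.cols}; }

// The example above: 4x4 tiling of A by 4x2 tiling of B gives 4x2 tiling.
static_assert(tiles_compatible(Tile{4, 4}, Tile{4, 2}), "4x4 by 4x2 is compatible");
static_assert(output_tile(Tile{4, 4}, Tile{4, 2}).rows == 4 &&
              output_tile(Tile{4, 4}, Tile{4, 2}).cols == 2, "output tile is 4x2");
```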
The following table specifies the tiling scheme (tile rows x tile columns) used for each input data type combination, and the corresponding output data type:
Input Type A | Input Type B | Tiling Scheme A | Tiling Scheme B | Output Type |
---|---|---|---|---|
int16 | int16 | 4x4 | 4x4 | int16 |
int16 | cint16 | 4x2 | 2x2 | cint16 |
int16 | int32 | 4x2 | 2x2 | int32 |
int16 | cint32 | 2x4 | 4x2 | cint32 |
cint16 | int16 | 4x4 | 4x2 | cint16 |
cint16 | cint16 | 4x4 | 4x2 | cint16 |
cint16 | int32 | 4x4 | 4x2 | cint32 |
cint16 | cint32 | 2x2 | 2x2 | cint32 |
int32 | int16 | 4x4 | 4x2 | int32 |
int32 | int32 | 4x4 | 4x2 | int32 |
int32 | cint16 | 4x4 | 4x2 | cint32 |
int32 | cint32 | 2x2 | 2x2 | cint32 |
cint32 | int16 | 2x4 | 4x2 | cint32 |
cint32 | cint16 | 2x2 | 2x2 | cint32 |
cint32 | int32 | 2x2 | 2x2 | cint32 |
cint32 | cint32 | 2x2 | 2x2 | cint32 |
float | float | 4x4 | 4x2 | float |
float | cfloat | 2x4 | 4x2 | cfloat |
cfloat | float | 2x4 | 4x2 | cfloat |
cfloat | cfloat | 4x2 | 2x2 | cfloat |
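For tooling that needs to pick a tiling scheme, such as a host-side step that pre-tiles data when the tiling kernels are not used, the table can be encoded as data. This is a sketch under the assumption that type names are matched as plain strings; `TilingEntry` and `find_tiling` are hypothetical helpers, not DSPLib functions.

```cpp
#include <cassert>
#include <cstring>

// One row of the tiling table (tile shapes are rows x columns).
struct TilingEntry {
    const char* typeA;
    const char* typeB;
    int aRows, aCols;  // tiling scheme of A
    int bRows, bCols;  // tiling scheme of B
    const char* typeOut;
};

constexpr TilingEntry kTilingTable[] = {
    {"int16", "int16", 4, 4, 4, 4, "int16"},    {"int16", "cint16", 4, 2, 2, 2, "cint16"},
    {"int16", "int32", 4, 2, 2, 2, "int32"},    {"int16", "cint32", 2, 4, 4, 2, "cint32"},
    {"cint16", "int16", 4, 4, 4, 2, "cint16"},  {"cint16", "cint16", 4, 4, 4, 2, "cint16"},
    {"cint16", "int32", 4, 4, 4, 2, "cint32"},  {"cint16", "cint32", 2, 2, 2, 2, "cint32"},
    {"int32", "int16", 4, 4, 4, 2, "int32"},    {"int32", "int32", 4, 4, 4, 2, "int32"},
    {"int32", "cint16", 4, 4, 4, 2, "cint32"},  {"int32", "cint32", 2, 2, 2, 2, "cint32"},
    {"cint32", "int16", 2, 4, 4, 2, "cint32"},  {"cint32", "cint16", 2, 2, 2, 2, "cint32"},
    {"cint32", "int32", 2, 2, 2, 2, "cint32"},  {"cint32", "cint32", 2, 2, 2, 2, "cint32"},
    {"float", "float", 4, 4, 4, 2, "float"},    {"float", "cfloat", 2, 4, 4, 2, "cfloat"},
    {"cfloat", "float", 2, 4, 4, 2, "cfloat"},  {"cfloat", "cfloat", 4, 2, 2, 2, "cfloat"},
};

// Returns the matching table row, or nullptr for an unsupported combination
// (for example, any integer/float mix).
const TilingEntry* find_tiling(const char* typeA, const char* typeB) {
    assert(typeA != nullptr && typeB != nullptr);
    for (const auto& e : kTilingTable)
        if (std::strcmp(e.typeA, typeA) == 0 && std::strcmp(e.typeB, typeB) == 0)
            return &e;
    return nullptr;
}
```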
The parameters TP_ADD_TILING_A, TP_ADD_TILING_B, and TP_ADD_DETILING_OUT control whether an additional pre-processing or post-processing kernel is included to perform the required reordering of data in storage. When used with TP_DIM_A_LEADING, TP_DIM_B_LEADING, or TP_DIM_OUT_LEADING, the tiling kernel also transposes the matrix.
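The transpose performed by the tiling kernel can be modeled on the host. This sketch assumes that a column-major input (the case selected through the `*_LEADING` parameters) should end up in the same tiled layout as the equivalent row-major input; `tile_matrix` and its `colMajor` flag are hypothetical and do not describe the kernel's actual implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of a tiling pass with an optional transpose. colMajor == true means
// the input buffer holds the logical rows x cols matrix in column-major order
// (assumption). The output is always the tiled layout of the logical matrix:
// tiles visited along each tile row, elements read row by row inside a tile.
template <typename T>
std::vector<T> tile_matrix(const std::vector<T>& in, std::size_t rows, std::size_t cols,
                           std::size_t tileRows, std::size_t tileCols, bool colMajor) {
    assert(in.size() == rows * cols);
    assert(rows % tileRows == 0 && cols % tileCols == 0);
    std::vector<T> out;
    out.reserve(in.size());
    for (std::size_t tr = 0; tr < rows; tr += tileRows)
        for (std::size_t tc = 0; tc < cols; tc += tileCols)
            for (std::size_t r = tr; r < tr + tileRows; ++r)
                for (std::size_t c = tc; c < tc + tileCols; ++c)
                    // Read logical element (r, c) from whichever order the input uses.
                    out.push_back(colMajor ? in[c * rows + r] : in[r * cols + c]);
    return out;
}
```

Feeding the same logical matrix once in row-major and once in column-major order produces identical tiled output, which is the effect the transposing tiler is meant to achieve.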
If these additional kernels are not selected, the matrix multiply kernels assume that the incoming data is already in the format specified above. When the TP_CASC_LEN parameter is used, the matrix multiply operation is split along the TP_DIM_AB dimension and processed by TP_CASC_LEN kernels. Each kernel passes its accumulated partial results over the cascade port to the next kernel in the chain, and the final kernel produces the expected output. The cascade connections are made internally within the matrix multiply graph.
Each AI Engine kernel in the array is given a sub-matrix, so the interface to the graph is an array of ports for both A and B.
Input Matrix A (16x16 - 4x4 Tile - Cascade Length 2):
| | AIE 0, Tile Col 0 | | | | AIE 0, Tile Col 1 | | | | AIE 1, Tile Col 2 | | | | AIE 1, Tile Col 3 | | | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Tile Row 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
| | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
| | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
| | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
| Tile Row 1 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 |
| | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 |
| | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 |
| | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 |
| Tile Row 2 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 |
| | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 |
| | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 |
| | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 |
| Tile Row 3 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 |
| | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 |
| | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 |
| | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 |
Input Matrix B (16x16 - 4x2 Tile - Cascade Length 2):
| | Tile Col 0 | | Tile Col 1 | | Tile Col 2 | | Tile Col 3 | | Tile Col 4 | | Tile Col 5 | | Tile Col 6 | | Tile Col 7 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| AIE 0, Tile Row 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
| | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
| | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
| | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
| AIE 0, Tile Row 1 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 |
| | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 |
| | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 |
| | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 |
| AIE 1, Tile Row 2 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 |
| | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 |
| | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 |
| | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 |
| AIE 1, Tile Row 3 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 |
| | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 |
| | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 |
| | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 |
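The split shown in the two tables above can be modeled on the host: cascade kernel k receives the k-th slice of the columns of A and the matching slice of the rows of B, and adds its partial product to the running sum received over the cascade. This sketch assumes TP_DIM_AB divides evenly by TP_CASC_LEN and uses plain `int` data with a wide accumulator; `cascade_gemm` is a hypothetical name, and the real kernels operate on tiled data.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model (not the AIE kernels): the shared dimension dimAB is split
// into cascLen equal slices. Kernel k multiplies columns [k*s, (k+1)*s) of A
// by rows [k*s, (k+1)*s) of B and adds the partial product to the accumulator
// received from kernel k-1; the last kernel's accumulator is the final output.
std::vector<long long> cascade_gemm(const std::vector<int>& a, const std::vector<int>& b,
                                    std::size_t dimA, std::size_t dimAB, std::size_t dimB,
                                    std::size_t cascLen) {
    assert(a.size() == dimA * dimAB && b.size() == dimAB * dimB);
    assert(cascLen > 0 && dimAB % cascLen == 0);
    const std::size_t slice = dimAB / cascLen;
    std::vector<long long> acc(dimA * dimB, 0);  // partial sums passed down the cascade
    for (std::size_t k = 0; k < cascLen; ++k)    // one iteration per cascade kernel
        for (std::size_t r = 0; r < dimA; ++r)
            for (std::size_t c = 0; c < dimB; ++c)
                for (std::size_t i = k * slice; i < (k + 1) * slice; ++i)
                    acc[r * dimB + c] += static_cast<long long>(a[r * dimAB + i]) * b[i * dimB + c];
    return acc;
}
```

For integer data the result is identical for any cascade length, since each kernel only contributes a disjoint part of the sum over TP_DIM_AB.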
Find a full list of descriptions and parameters in the API Reference Overview.
Connections to the cascade ports can be made as follows:
```cpp
for (int i = 0; i < P_CASC_LEN; i++) {
    connect<>(inA[i], mmultGraph.inA[i]);
    connect<>(inB[i], mmultGraph.inB[i]);
}
connect<>(mmultGraph.out, out);
```
2.3.6.2. Constraints
A Matrix Multiply solution can consist of a cascade of kernels that perform the multiply operations themselves, plus a tiling kernel on each input of each member of that cascade and a tiling kernel on the output. The tiling kernels convert the arrangement of matrix elements in memory into an arrangement optimized for vector multiplication, or vice versa. In the entry-level graph, the kernels are identified by the following names:
‘m_MatmultKernels’ - The array of TP_CASC_LEN matrix multiply kernels in the cascade; getKernels returns a pointer to this array. These kernels perform the matrix multiply operations.
‘untiler’ - A single kernel on the output of the matrix multiply kernel or cascade of kernels. It transforms the data from the tiled format to the output format.
‘tilerA’ - An array of TP_CASC_LEN kernels that connect 1:1 with the A input ports of the matrix multiply kernels.
‘tilerB’ - An array of TP_CASC_LEN kernels that connect 1:1 with the B input ports of the matrix multiply kernels.
2.3.7. Code Example including constraints
The following code example shows how the matrix_mult_graph class may be used within a user super-graph, including how to set the runtime<ratio> of internal kernels. In this example, the matrix multiplier is configured to multiply a 32x16 matrix by a 16x32 matrix, giving a 32x32 matrix.
```cpp
#include <adf.h>
#include "matrix_mult_graph.hpp"

#define T_DATA_A cint16
#define T_DATA_B cint16
#define P_DIM_A 32
#define P_DIM_AB 16
#define P_DIM_B 32
#define P_SHIFT 16
#define P_ROUND_MODE 0
#define P_DIM_A_LEADING 0
#define P_DIM_B_LEADING 1
#define P_DIM_OUT_LEADING 0
#define P_ADD_TILING_A 0
#define P_ADD_TILING_B 0
#define P_ADD_DETILING_OUT 0
#define P_INPUT_WINDOW_VSIZE_A 512  // P_DIM_A * P_DIM_AB elements
#define P_INPUT_WINDOW_VSIZE_B 512  // P_DIM_AB * P_DIM_B elements
#define P_CASC_LEN 1

class myMM : public adf::graph {
   public:
    // The matrix multiply graph exposes one A port and one B port per cascade kernel.
    adf::port<adf::input> inA[P_CASC_LEN];
    adf::port<adf::input> inB[P_CASC_LEN];
    adf::port<adf::output> out;

    xf::dsp::aie::blas::matrix_mult::matrix_mult_graph<
        T_DATA_A, T_DATA_B, P_DIM_A, P_DIM_AB, P_DIM_B, P_SHIFT, P_ROUND_MODE, P_DIM_A_LEADING,
        P_DIM_B_LEADING, P_DIM_OUT_LEADING, P_ADD_TILING_A, P_ADD_TILING_B, P_ADD_DETILING_OUT,
        P_INPUT_WINDOW_VSIZE_A, P_INPUT_WINDOW_VSIZE_B, P_CASC_LEN>
        matrixMult;

    myMM() {
        for (int i = 0; i < P_CASC_LEN; i++) {
            adf::connect<>(inA[i], matrixMult.inA[i]);
            adf::connect<>(inB[i], matrixMult.inB[i]);
        }
        adf::connect<>(matrixMult.out, out);

        // Set the runtime ratio of the matrix multiply kernels in the cascade.
        adf::kernel* kernels = matrixMult.getKernels();
        for (int i = 0; i < P_CASC_LEN; i++) {
            adf::runtime<adf::ratio>(kernels[i]) = 0.7;
            // Assumption: tiler kernels exist only when the tiling stage is enabled.
            if (P_ADD_TILING_A) adf::runtime<adf::ratio>(matrixMult.tilerA[i]) = 0.5;
            if (P_ADD_TILING_B) adf::runtime<adf::ratio>(matrixMult.tilerB[i]) = 0.5;
        }
        if (P_ADD_DETILING_OUT) adf::runtime<adf::ratio>(matrixMult.untiler) = 0.5;
    }
};
```