MLIR-AIE
AIEVecToLLVM.cpp
Go to the documentation of this file.
1//===- AIEVecToLLVM.cpp - AIEVec to LLVM dialect conversion ---------------===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2022 Xilinx Inc.
8// (c) Copyright 2024 Advanced Micro Devices Inc.
9//
10//===----------------------------------------------------------------------===//
11
12#include "../PassDetail.h"
13
20#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21#include "mlir/Conversion/LLVMCommon/Pattern.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/Dialect/UB/IR/UBOps.h"
24#include "mlir/IR/TypeUtilities.h"
25#include <sstream>
26
27using namespace mlir;
28
29namespace xilinx::aievec {
30
31inline static Value bitcastValueToType(OpBuilder &builder, Location loc,
32 Value val, Type dstTy) {
33 return builder.create<LLVM::BitcastOp>(loc, dstTy, val).getResult();
34}
35
36// This function emits the instructions required to widen a 128b input vector
37// into a 512b encoded as a vector<16xi32>. It first bitcasts it to a
38// vector<4xi32> to respect the intrinsic signature.
39inline static Value widen128bVectorValueTo512b(OpBuilder &builder, Location loc,
40 Value val) {
41 return builder
42 .create<xllvm::VectorSetI512I128IntrOp>(
43 loc, VectorType::get({16}, builder.getI32Type()),
44 bitcastValueToType(builder, loc, val,
45 VectorType::get({4}, builder.getI32Type())))
46 .getResult();
47}
48
49// This function emits the instructions required to widen a 256b input vector
50// into a 512b encoded as a vector<16xi32>. It first bitcasts it to a
51// vector<8xi32> to respect the intrinsic signature. It will also materialize
52// a constant 0, used as an insertion index.
53inline static Value widen256bVectorValueTo512b(OpBuilder &builder, Location loc,
54 Value val) {
55 auto cst0 =
56 builder.create<LLVM::ConstantOp>(loc, builder.getI32Type(), (int32_t)0);
57 return builder
58 .create<xllvm::VectorSetI512I256IntrOp>(
59 loc, VectorType::get({16}, builder.getI32Type()),
60 bitcastValueToType(builder, loc, val,
61 VectorType::get({8}, builder.getI32Type())),
62 cst0)
63 .getResult();
64}
65
66// This function emits the sequence of operations that forces a value into a
67// specific type. This may include widening vectors to match a specific bit
68// length.
69static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
70 Type type) {
71 auto valTy = val.getType();
72 if (valTy == type)
73 return val;
74 auto srcVecTy = dyn_cast<VectorType>(valTy);
75 if (srcVecTy) {
76 auto dstVecTy = dyn_cast<VectorType>(type);
77 assert(dstVecTy && "vector values cannot be forced into a non-vector type");
78 assert(srcVecTy.getRank() == 1 && dstVecTy.getRank() == 1 &&
79 "only flat 1D vectors can be force casted");
80 int64_t dstVecLength =
81 dstVecTy.getElementTypeBitWidth() * dstVecTy.getShape()[0];
82 int64_t srcVecLength =
83 srcVecTy.getElementTypeBitWidth() * srcVecTy.getShape()[0];
84 if (srcVecLength != dstVecLength) {
85 assert(srcVecLength < dstVecLength &&
86 "only widening forced casts are supported");
87 assert(dstVecLength == 512 &&
88 (srcVecLength == 128 || srcVecLength == 256) &&
89 "only 128b to 512b and 256b to 512b forced casts are supported");
90 if (srcVecLength == 128)
91 val = widen128bVectorValueTo512b(builder, loc, val);
92 else
93 val = widen256bVectorValueTo512b(builder, loc, val);
94 }
95 }
96 return bitcastValueToType(builder, loc, val, type);
97}
98
99// This function emits the sequence of operations that forces a range of values
100// to match the signature specified by the TypeRange. It can be used to convert
101// the parameters of an op being converted to the types accepted by an
102// intrinsic with a fixed signature that treats its inputs as "bags of bits".
103static SmallVector<Value> forceCastOperandsToSignature(OpBuilder &builder,
104 Location loc,
105 ValueRange operands,
106 TypeRange signature) {
107 return llvm::to_vector(llvm::map_range(
108 llvm::zip_equal(operands, signature), [&](auto &&vt) -> Value {
109 return forceCastValueToType(builder, loc, std::get<0>(vt),
110 std::get<1>(vt));
111 }));
112}
113
115 uint32_t start;
116 uint32_t offsets;
117 uint32_t offsets_hi;
118 uint32_t step;
119 uint32_t square;
120};
121
// Packs the AIE2 vmac/vmul control fields into a single control word.
//
// sgn_x: Sign mask of matrix X. If it is one, matrix X is interpreted as
//        signed; otherwise it is treated as unsigned.
// sgn_y: Sign mask of matrix Y. If it is one, matrix Y is interpreted as
//        signed; otherwise it is treated as unsigned.
// amode/bmode/variant: configure the accumulator width, the multiplication
//        precision, and the multiplication mode.
// zero_acc: Zeroing of acc1. If it is one, acc1 is zeroed.
// shift16: Shift mask of acc1. If a bit is set, the <<16 operation is
//        executed on acc1.
// sub_mul: Negation mask of the matrix multiplication result. If it is one,
//        the result of the operation is negated.
// sub_acc1: Negation mask of acc1. If it is one, acc1 is negated.
// sub_acc2: Negation mask of acc2. If it is one, acc2 is negated.
// sub_mask: Negation mask of complex multiplications. Negates a term of a
//        complex multiplication.
//
// Lowest bit of each field in the returned word:
//   zero_acc: 0, amode: 1, bmode: 3, variant: 5, sgn_y: 8, sgn_x: 9,
//   shift16: 10, sub_mul: 11, sub_acc1: 12, sub_acc2: 13, sub_mask: 16.
// Fields are OR-ed in without masking, so a value too wide for its slot
// spills into the neighbouring field.
static inline int aiev2_vmac_compute_control(int sgn_x, int sgn_y, int amode,
                                             int bmode, int variant,
                                             int zero_acc, int shift16,
                                             int sub_mul, int sub_acc1,
                                             int sub_acc2, int sub_mask) {
  unsigned control = 0;
  control |= static_cast<unsigned>(zero_acc) << 0;
  control |= static_cast<unsigned>(amode) << 1;
  control |= static_cast<unsigned>(bmode) << 3;
  control |= static_cast<unsigned>(variant) << 5;
  control |= static_cast<unsigned>(sgn_y) << 8;
  control |= static_cast<unsigned>(sgn_x) << 9;
  control |= static_cast<unsigned>(shift16) << 10;
  control |= static_cast<unsigned>(sub_mul) << 11;
  control |= static_cast<unsigned>(sub_acc1) << 12;
  control |= static_cast<unsigned>(sub_acc2) << 13;
  control |= static_cast<unsigned>(sub_mask) << 16;
  return static_cast<int>(control);
}
148
149std::string getVectorTypeString(VectorType type, bool abbrev = false,
150 bool acc = false) {
151 std::stringstream ss;
152 auto size = getVectorLaneSize(type);
153 ss << "v" << size;
154 if (auto intType = dyn_cast<IntegerType>(type.getElementType())) {
155 ss << (acc ? "acc" : abbrev ? "i" : "int") << intType.getWidth();
156 } else if (dyn_cast<FloatType>(type.getElementType())) {
157 ss << (abbrev ? "f" : "float");
158 }
159 return ss.str();
160}
161
162std::string getMulOrFMAIntrinsicName(Operation *op) {
163 std::string baseName;
164 Value lhs, result;
165 if (auto mulOp = dyn_cast<aievec::aie1::MulOp>(op)) {
166 baseName = "mul";
167 lhs = mulOp.getLhs();
168 result = mulOp.getResult();
169 } else if (auto fmaOp = dyn_cast<aievec::aie1::FMAOp>(op)) {
170 baseName = "mac";
171 lhs = fmaOp.getLhs();
172 result = fmaOp.getResult();
173 }
174 VectorType resultType = cast<VectorType>(result.getType());
175 int resultSize = getVectorLaneSize(resultType);
176 std::stringstream ss;
177 ss << "llvm.aie.";
178 if (dyn_cast<IntegerType>(resultType.getElementType())) {
179 ss << baseName;
180 ss << resultSize << "."
181 << getVectorTypeString(cast<VectorType>(lhs.getType()));
182 } else if (dyn_cast<FloatType>(resultType.getElementType())) {
183 ss << "vfp" << baseName;
184 }
185 return ss.str();
186}
187
// Squashes the easy-to-read 16-bit square encoding into the 8-bit encoding
// the configuration register uses. The low two bits of each of the four
// nibbles of `square` are packed, lowest nibble first, into consecutive 2-bit
// slots of the result. The upper two bits of every nibble and all bits above
// bit 15 are ignored.
uint32_t encodeSquare(uint32_t square) {
  uint32_t packed = 0;
  for (unsigned slot = 0; slot < 4; ++slot) {
    uint32_t selector = (square >> (4 * slot)) & 0x3u;
    packed |= selector << (2 * slot);
  }
  return packed & 0xFFu;
}
198
199// Encode the configuration register with buffer parameters and options
200// TODO: struct to handle this?
201void encodeConf(uint32_t conf[2], const BufferParams &x, const BufferParams &z,
202 bool sub) {
203 conf[0] |= ((x.step & 0x3F) << 0) | ((z.step & 0x3F) << 8);
204 conf[1] |= (encodeSquare(x.square) << 0) | (encodeSquare(z.square) << 8);
205 conf[1] |= sub << 17;
206}
207
209 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::AddOp> {
210public:
211 using ConvertOpToLLVMPattern<aievec::aie1::AddOp>::ConvertOpToLLVMPattern;
212
213 LogicalResult
214 matchAndRewrite(aievec::aie1::AddOp op, OpAdaptor adaptor,
215 ConversionPatternRewriter &rewriter) const override {
216 op.emitWarning() << "aie.add conversion is not implemented\n";
217 return failure();
218 }
219};
220
222 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::SubOp> {
223public:
224 using ConvertOpToLLVMPattern<aievec::aie1::SubOp>::ConvertOpToLLVMPattern;
225
226 LogicalResult
227 matchAndRewrite(aievec::aie1::SubOp op, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter) const override {
229 op.emitWarning() << "aie.sub conversion is not implemented\n";
230 return failure();
231 }
232};
233
235 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::FMAOp> {
236public:
237 using ConvertOpToLLVMPattern<aievec::aie1::FMAOp>::ConvertOpToLLVMPattern;
238
239 LogicalResult
240 matchAndRewrite(aievec::aie1::FMAOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override {
242 auto module = op->getParentOfType<ModuleOp>();
243 MLIRContext *context = rewriter.getContext();
244
245 auto startType = IntegerType::get(context, 32);
246 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
247 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
248
249 // If the intrinsic declaration doesn't exist, create it
250 std::string intrinsicName = getMulOrFMAIntrinsicName(op);
251 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
252 StringAttr::get(context, intrinsicName));
253
254 if (!func) {
255 OpBuilder::InsertionGuard guard(rewriter);
256 rewriter.setInsertionPointToStart(module.getBody());
257 func = rewriter.create<LLVM::LLVMFuncOp>(
258 rewriter.getUnknownLoc(), intrinsicName,
259 LLVM::LLVMFunctionType::get(
260 op.getResult().getType(),
261 {op.getLhs().getType(), op.getRhs().getType(),
262 op.getAcc().getType(), startType, /* xstart */
263 startType, /* ystart */
264 startType, /* zstart */
265 offsetsType, /* xoffsets */
266 offsetsType, /* zoffsets */
267 confType}));
268 }
269
270 // Parse the string attribute values
271 BufferParams x = {};
272 BufferParams z = {};
273 op.getXstart().getAsInteger(0, x.start);
274 op.getXoffsets().getAsInteger(0, x.offsets);
275 op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
276 op.getXstep().getAsInteger(0, x.step);
277 op.getXsquare().getAsInteger(0, x.square);
278 op.getZstart().getAsInteger(0, z.start);
279 op.getZoffsets().getAsInteger(0, z.offsets);
280 op.getZoffsetsHi().getAsInteger(0, z.offsets_hi);
281 op.getZstep().getAsInteger(0, z.step);
282 op.getZsquare().getAsInteger(0, z.square);
283
284 // Encode the configuration register
285 uint32_t conf[2] = {0, 0};
286 encodeConf(conf, x, z, op.getFmsub());
287
288 // Create the constants and replace the op
289 auto xstartVal = rewriter.create<LLVM::ConstantOp>(
290 op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
291 auto ystartVal = rewriter.create<LLVM::ConstantOp>(
292 op->getLoc(), startType, rewriter.getI32IntegerAttr(0));
293 auto zstartVal = rewriter.create<LLVM::ConstantOp>(
294 op->getLoc(), startType, rewriter.getI32IntegerAttr(z.start));
295 auto xoffsetsVal = rewriter.create<LLVM::ConstantOp>(
296 op->getLoc(), offsetsType,
297 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
298 auto zoffsetsVal = rewriter.create<LLVM::ConstantOp>(
299 op->getLoc(), offsetsType,
300 rewriter.getI32VectorAttr({(int32_t)z.offsets, (int32_t)z.offsets_hi}));
301 auto confVal = rewriter.create<LLVM::ConstantOp>(
302 op->getLoc(), confType,
303 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
304 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
305 op, func,
306 ValueRange{op.getLhs(), op.getRhs(), op.getAcc(), xstartVal, ystartVal,
307 zstartVal, xoffsetsVal, zoffsetsVal, confVal});
308 return success();
309 }
310};
311
313 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::MulOp> {
314public:
315 using ConvertOpToLLVMPattern<aievec::aie1::MulOp>::ConvertOpToLLVMPattern;
316
317 LogicalResult
318 matchAndRewrite(aievec::aie1::MulOp op, OpAdaptor adaptor,
319 ConversionPatternRewriter &rewriter) const override {
320 auto module = op->getParentOfType<ModuleOp>();
321 MLIRContext *context = rewriter.getContext();
322
323 auto startType = IntegerType::get(context, 32);
324 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
325 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
326
327 // If the intrinsic declaration doesn't exist, create it
328 std::string intrinsicName = getMulOrFMAIntrinsicName(op);
329 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
330 StringAttr::get(context, intrinsicName));
331
332 if (!func) {
333 OpBuilder::InsertionGuard guard(rewriter);
334 rewriter.setInsertionPointToStart(module.getBody());
335 func = rewriter.create<LLVM::LLVMFuncOp>(
336 rewriter.getUnknownLoc(), intrinsicName,
337 LLVM::LLVMFunctionType::get(op.getResult().getType(),
338 {op.getLhs().getType(),
339 op.getRhs().getType(),
340 startType, /* xstart */
341 startType, /* ystart */
342 startType, /* zstart */
343 offsetsType, /* xoffsets */
344 offsetsType, /* zoffsets */
345 confType}));
346 }
347
348 // Parse the string attribute values
349 BufferParams x = {};
350 BufferParams z = {};
351 op.getXstart().getAsInteger(0, x.start);
352 op.getXoffsets().getAsInteger(0, x.offsets);
353 op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
354 op.getXstep().getAsInteger(0, x.step);
355 op.getXsquare().getAsInteger(0, x.square);
356 op.getZstart().getAsInteger(0, z.start);
357 op.getZoffsets().getAsInteger(0, z.offsets);
358 op.getZoffsetsHi().getAsInteger(0, z.offsets_hi);
359 op.getZstep().getAsInteger(0, z.step);
360 op.getZsquare().getAsInteger(0, z.square);
361
362 // Encode the configuration register
363 uint32_t conf[2] = {0, 0};
364 encodeConf(conf, x, z, false);
365
366 // Create the constants and replace the op
367 auto xstartVal = rewriter.create<LLVM::ConstantOp>(
368 op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
369 auto ystartVal = rewriter.create<LLVM::ConstantOp>(
370 op->getLoc(), startType, rewriter.getI32IntegerAttr(0));
371 auto zstartVal = rewriter.create<LLVM::ConstantOp>(
372 op->getLoc(), startType, rewriter.getI32IntegerAttr(z.start));
373 auto xoffsetsVal = rewriter.create<LLVM::ConstantOp>(
374 op->getLoc(), offsetsType,
375 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
376 auto zoffsetsVal = rewriter.create<LLVM::ConstantOp>(
377 op->getLoc(), offsetsType,
378 rewriter.getI32VectorAttr({(int32_t)z.offsets, (int32_t)z.offsets_hi}));
379 auto confVal = rewriter.create<LLVM::ConstantOp>(
380 op->getLoc(), confType,
381 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
382 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
383 op, func,
384 ValueRange{op.getLhs(), op.getRhs(), xstartVal, ystartVal, zstartVal,
385 xoffsetsVal, zoffsetsVal, confVal});
386 return success();
387 }
388};
389
391 : public mlir::ConvertOpToLLVMPattern<aievec::MulElemOp> {
392public:
393 using ConvertOpToLLVMPattern<aievec::MulElemOp>::ConvertOpToLLVMPattern;
394
395 MulElemOpConversion(const LLVMTypeConverter &typeConverter,
396 Aie2Fp32Emulation aie2Fp32EmulationOption)
397 : ConvertOpToLLVMPattern(typeConverter),
399
400 Aie2Fp32Emulation aie2Fp32EmulationOption;
401
403 enum class Kind {
404 // DtIn0_DtIn1_DtRes_CxMxKxN
411 // TODO: I16_I16_I64_16x1x2x1
412 };
413
415 int conf;
416 };
417
418 static DecodedMulElemOp decodeMulElemOp(OpAdaptor op) {
419 auto lhs = op.getLhs();
420 auto lhsVecTy = cast<VectorType>(lhs.getType());
421 auto lhsScaTy = lhsVecTy.getElementType();
422 unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
423
424 // Integer types
425 if (llvm::isa<IntegerType>(lhsScaTy)) {
426 if (lhsBitWidth == 8) {
428 aiev2_vmac_compute_control(
429 /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/1,
430 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
431 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
432 /*sub_mask=*/0)};
433 } else if (lhsBitWidth == 16) {
435 aiev2_vmac_compute_control(
436 /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/3,
437 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
438 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
439 /*sub_mask=*/0)};
440 } else if (lhsBitWidth == 32) {
441 // emulated I32 mul_elem
443 }
444 } else {
445 // Float types
446 if (lhsBitWidth == 16) {
448 aiev2_vmac_compute_control(
449 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
450 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
451 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
452 /*sub_mask=*/0)};
453 } else if (lhsBitWidth == 32) {
454 // emulated FP32 mul_elem
456 }
457 }
458
460 }
461
462 // This conversion pattern implements the below CPP emulated I32 mul_elem.
463 // INTRINSIC(v16acc64)
464 // mul_elem_16_2(v16int32 a0, v16int32 a1, v16int32 b0, v16int32 b1) {
465 // v32uint16 a_lo = (v32uint16)shuffle(a0, a1, 2);
466 // v32int16 a_hi = (v32int16)shuffle(a0, a1, 3);
467 // v32uint16 b_lo = (v32uint16)shuffle(b0, b1, 2);
468 // v32int16 b_hi = (v32int16)shuffle(b0, b1, 3);
469 // v16acc64 acc = ::mul_elem_16_2(a_hi, b_hi);
470 // acc = mac_elem_16_2_conf(a_hi, 1, b_lo, false, acc, 0, 1, 0, 0);
471 // acc = mac_elem_16_2_conf(a_lo, false, b_hi, 1, acc, 0, 0, 0, 0);
472 // acc = mac_elem_16_2_conf(a_lo, false, b_lo, false, acc, 0, 1, 0, 0);
473 // return acc;
474 // }
475 // Caller example when handling the elementwise mul of two v16int32 vectors.
476 // v16int32 v1 = LHS();
477 // v16int32 v2 = RHS();
478 // v16acc64 v3 = mul_elem_16_2(v1, broadcast_zero_s32(), v2,
479 // undef_v16int32());
480 // Explantion:
481 // a_lo = low_part(a0[0]--a0[15], a1[0]--a1[15])
482 // a_hi = high_part(a0[0]--a0[15], a1[0]--a1[15])
483 // b_lo = low_part(b0[0]--b0[15], b1[0]--b1[15])
484 // b_hi = high_part(b0[0]--b0[15], b1[0]--b1[15])
485 // The firt `acc` is from mul_elem_16_2(a_hi, b_hi), which performs 16 channel
486 // of 1x2x1 matmul, acc[0] = a_hi[0]*b_hi[0]+a_hi[16]*b_hi[16], ... , acc[15]
487 // = a_hi[15]*b_hi[15]+a_hi[31]*b_hi[31]. Then, the first MAC performs `acc`
488 // left shift 16bit, and then 16 channel of 1x2x1 matmul (a_hi, b_lo)
489 // accumulating to `acc`. The second MAC performs 16 channel of 1x2x1 matmul
490 // (a_lo, b_hi) accumulating to `acc`. Finally, the third MAC performs 16
491 // channel of 1x2x1 matmul (a_lo, b_hi) accumulating to `acc`.
492 LogicalResult
493 convertToEmulatedI32MulElem(aievec::MulElemOp op, OpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter) const {
495
496 Location loc = op.getLoc();
497 auto zeroCst = rewriter.create<LLVM::ConstantOp>(
498 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
499 auto a0 = adaptor.getLhs();
500 auto a1 = rewriter.create<xllvm::VectorBroadcast32I512IntrOp>(
501 loc, VectorType::get({16}, rewriter.getI32Type()), zeroCst);
502 auto b0 = adaptor.getRhs();
503 auto b1 = rewriter.create<xllvm::UndefV16I32IntrOp>(
504 loc, VectorType::get({16}, rewriter.getI32Type()));
505
506 // 4* Shuffle
507 auto a_lo = rewriter.create<xllvm::VectorShuffleIntrOp>(
508 loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
509 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
510 rewriter.getI32IntegerAttr(2)));
511 auto a_hi = rewriter.create<xllvm::VectorShuffleIntrOp>(
512 loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
513 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
514 rewriter.getI32IntegerAttr(3)));
515 auto b_lo = rewriter.create<xllvm::VectorShuffleIntrOp>(
516 loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
517 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
518 rewriter.getI32IntegerAttr(2)));
519 auto b_hi = rewriter.create<xllvm::VectorShuffleIntrOp>(
520 loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
521 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
522 rewriter.getI32IntegerAttr(3)));
523 // MUL + 3 * MAC
524 auto mulConfCst = rewriter.create<LLVM::ConstantOp>(
525 loc, rewriter.getI32Type(),
526 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
527 /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
528 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
529 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));
530 auto mulConfOp = rewriter.create<xllvm::MulConfAcc64IntrOp>(
531 loc, VectorType::get({16}, rewriter.getI64Type()),
532 forceCastOperandsToSignature(
533 rewriter, loc,
534 /*operands=*/{a_hi, b_hi, mulConfCst},
535 /*signature=*/
536 {VectorType::get({64}, rewriter.getI8Type()),
537 VectorType::get({16}, rewriter.getI32Type()),
538 rewriter.getI32Type()}));
539
540 auto createMacConfOp = [&](SmallVector<Value> operands,
541 int macConf) -> Value {
542 operands.push_back(rewriter.create<LLVM::ConstantOp>(
543 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(macConf)));
544 return rewriter
545 .create<xllvm::MacConfAcc64IntrOp>(
546 loc, VectorType::get({16}, rewriter.getI64Type()),
547 forceCastOperandsToSignature(
548 rewriter, loc,
549 /*operands=*/operands,
550 /*signature=*/
551 {VectorType::get({64}, rewriter.getI8Type()),
552 VectorType::get({16}, rewriter.getI32Type()),
553 VectorType::get({16}, rewriter.getI64Type()),
554 rewriter.getI32Type()}))
555 .getResult();
556 };
557 auto acc64Val = mulConfOp.getResult();
558 acc64Val = createMacConfOp(
559 SmallVector<Value>{a_hi, b_lo, acc64Val},
560 aiev2_vmac_compute_control(
561 /*sgn_x=*/1, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
562 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
563 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
564 acc64Val = createMacConfOp(
565 SmallVector<Value>{a_lo, b_hi, acc64Val},
566 aiev2_vmac_compute_control(
567 /*sgn_x=*/0, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
568 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
569 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
570 acc64Val = createMacConfOp(
571 SmallVector<Value>{a_lo, b_lo, acc64Val},
572 aiev2_vmac_compute_control(
573 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
574 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
575 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
576
577 // create bitcast for result
578 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
579 acc64Val);
580 return success();
581 }
582
583 // This conversion pattern implements the below CPP emulated FP32 mul_elem.
584 // inline v16accfloat mul_elem_16_accuracy_safe(v16float v1, v16float v2) {
585 // v32bfloat16 a = broadcast_zero_to_v32bfloat16();
586 // v32bfloat16 b = broadcast_zero_to_v32bfloat16();
587 // v32bfloat16 c = broadcast_zero_to_v32bfloat16();
588 // v32bfloat16 d = broadcast_zero_to_v32bfloat16();
589 // v32bfloat16 e = broadcast_zero_to_v32bfloat16();
590 // v32bfloat16 f = broadcast_zero_to_v32bfloat16();
591 // v32bfloat16 dummy0 = broadcast_one_to_v32bfloat16();
592 // a = insert(a,0,to_v16bfloat16((v16accfloat)v1));
593 // v16accfloat acc0 = msc_elem_16_2(a, dummy0, (v16accfloat)v1);
594 // b = insert(b,0,to_v16bfloat16(acc0));
595 // c = insert(c,0,to_v16bfloat16(msc_elem_16_2(b, dummy0, acc0)));
596 // d = insert(d,0,to_v16bfloat16((v16accfloat)v2));
597 // v16accfloat acc1 = msc_elem_16_2(d, dummy0, (v16accfloat)v2);
598 // e = insert(e,0,to_v16bfloat16(acc1));
599 // f = insert(f,0,to_v16bfloat16(msc_elem_16_2(e, dummy0, acc1)));
600 // return
601 // mac_elem_16_2(a,d,mac_elem_16_2(a,e,mac_elem_16_2(b,d,mac_elem_16_2(
602 // d,c,mac_elem_16_2(b,e,mac_elem_16_2(a,f,mac_elem_16_2(
603 // b,f,mac_elem_16_2(c,e,mul_elem_16_2(c,f)))))))));
604 // }
605 // Caller example when handling the elementwise mul of two v16float vectors.
606 // v16float v1 = LHS(); v16float v2 = RHS();
607 // v16accfloat v3 = mul_elem_16(v1, v2);
608 // Explantion: For v32bfloat16 `a`, the first half v16bf16 contains `most
609 // significant 7 bits of mantissa` from v1, and the second half v16bf16 are
610 // zeros. For v16accfloat `acc0`, the MSC equals to "(original `v1` with 23
611 // bits of mantissa) - (`a` with MSB 7 bits of mantissa from v1)". For
612 // v32bfloat16 `b`, the first half v16bf16 contains `[7:13] bits of mantissa
613 // from v1` from v1, and the second half v16bf16 are zeros. For v32bfloat16
614 // `c`, the first half v16bf16 contains `[14:20] bits of mantissa from v1`
615 // from v1, and the second half v16bf16 are zeros. Hence, we can represent
616 // v16float in three v32bfloat16 and then perform 9 MUL/MAC in v32bfloat16 to
617 // get the final elementwise multiplication result.
618
619 LogicalResult
620 convertToEmulatedFP32MulElem(aievec::MulElemOp op, OpAdaptor adaptor,
621 ConversionPatternRewriter &rewriter) const {
622 Location loc = op.getLoc();
623 auto zeroCst = rewriter.create<LLVM::ConstantOp>(
624 loc, rewriter.getBF16Type(),
625 rewriter.getZeroAttr(rewriter.getBF16Type()));
626 auto aZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
627 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
628 auto bZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
629 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
630 auto cZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
631 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
632 auto dZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
633 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
634 auto eZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
635 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
636 auto fZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
637 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
638 auto oneCst = rewriter.create<LLVM::ConstantOp>(
639 loc, rewriter.getBF16Type(),
640 rewriter.getOneAttr(rewriter.getBF16Type()));
641 auto dummy0 = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
642 loc, VectorType::get({32}, rewriter.getBF16Type()), oneCst);
643 auto zeroCstI32 = rewriter.create<LLVM::ConstantOp>(
644 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
645 auto mscMacMulConfCst = rewriter.create<LLVM::ConstantOp>(
646 loc, rewriter.getI32Type(),
647 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
648 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
649 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
650 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));
651
652 auto extractV16FP32ToThreeV16BF16 =
653 [&](Value inputV16FP32, Value aZeros, Value bZeros,
654 Value cZeros) -> std::tuple<Value, Value, Value> {
655 // a = insert(a,0,to_v16bfloat16((v16accfloat)v1));
656 auto inputBitCasted =
657 forceCastValueToType(rewriter, loc, inputV16FP32,
658 VectorType::get({8}, rewriter.getI64Type()));
659 auto v1ToBF16 = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
660 loc, VectorType::get({16}, rewriter.getBF16Type()), inputBitCasted);
661 auto a = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
662 loc, VectorType::get({32}, rewriter.getBF16Type()), aZeros, v1ToBF16,
663 zeroCstI32);
664
665 // v16accfloat acc0 = msc_elem_16_2(a, dummy0, (v16accfloat)v1);
666 auto acc0 = rewriter.create<xllvm::MscConfBF16IntrOp>(
667 loc, VectorType::get({8}, rewriter.getI64Type()), a, dummy0,
668 inputBitCasted, mscMacMulConfCst);
669
670 // b = insert(b,0,to_v16bfloat16(acc0));
671 auto acc0ToBF16 = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
672 loc, VectorType::get({16}, rewriter.getBF16Type()), acc0);
673 auto b = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
674 loc, VectorType::get({32}, rewriter.getBF16Type()), bZeros,
675 acc0ToBF16, zeroCstI32);
676
677 // c = insert(c,0,to_v16bfloat16(msc_elem_16_2(b, dummy0, acc0)));
678 auto acc0Mscb = rewriter.create<xllvm::MscConfBF16IntrOp>(
679 loc, VectorType::get({8}, rewriter.getI64Type()), b, dummy0, acc0,
680 mscMacMulConfCst);
681 auto acc0MscbToBF16 =
682 rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
683 loc, VectorType::get({16}, rewriter.getBF16Type()), acc0Mscb);
684 auto c = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
685 loc, VectorType::get({32}, rewriter.getBF16Type()), cZeros,
686 acc0MscbToBF16, zeroCstI32);
687 return std::make_tuple(a.getResult(), b.getResult(), c.getResult());
688 };
689
690 // Get v16vfloat16 a, b, c for representing v16float v1
691 auto [a, b, c] =
692 extractV16FP32ToThreeV16BF16(adaptor.getLhs(), aZeros, bZeros, cZeros);
693 // Get v16vfloat16 d, e, f for representing v16float v2
694 auto [d, e, f] =
695 extractV16FP32ToThreeV16BF16(adaptor.getRhs(), dZeros, eZeros, fZeros);
696
697 // Create 1 MUL and 2/5/8 MACs depending on the Aie2Fp32EmulationOption
698 auto createMacOps = [&](Value lhs, Value rhs, Value acc) -> Value {
699 return rewriter
700 .create<xllvm::MacConfBF16IntrOp>(
701 loc, VectorType::get({8}, rewriter.getI64Type()), lhs, rhs, acc,
702 mscMacMulConfCst)
703 .getResult();
704 };
705
706 Value finalMacVal;
707 if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyFast) {
708 // Fast and Accurate option. float a*b would require 6 mac operations.
709 // Input fp32 number is split in to 3 bfloat16 numbers to extract all the
710 // bits of the mantissa. float a,b; both a and b are split in to 3
711 // bfloat16 numbers each. Hence there would be 9 mac operations in
712 // multiplication of a and b. In the 9 mac operations to emulate fp32 mul,
713 // mac operations with LSBs are ignored. (3 last terms). This helps
714 // improve cycle count of mul and has least impact on accuracy of result.
715 // This is the default option to the aiecompiler
716 auto afMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
717 loc, VectorType::get({8}, rewriter.getI64Type()), a, f,
718 mscMacMulConfCst);
719 finalMacVal = createMacOps(
720 a, d,
721 createMacOps(
722 a, e,
723 createMacOps(b, d,
724 createMacOps(d, c, createMacOps(b, e, afMul)))));
725 } else if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyLow) {
726 // Fast and least accurate option. float a*b would require 3 mac
727 // operations.
728 // Input fp32 number is split in to 2 bfloat16 numbers. Hence not all the
729 // bits from mantissa can be used. float a,b; Both a and b are split in to
730 // 2 bfloat16 numbers each. Hence there would be 4 mac operations in
731 // multiplication of a and b. In the 4 mac operations to emulate fp32 mul,
732 // mac operations with LSBs are ignored. (1 last term). This helps improve
733 // cycle count of mul float a, b;
734 auto bdMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
735 loc, VectorType::get({8}, rewriter.getI64Type()), b, d,
736 mscMacMulConfCst);
737 finalMacVal = createMacOps(a, d, createMacOps(a, e, bdMul));
738 } else {
739 // aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracySafe
740 // Most accurate option since input fp32 number is split in to 3 bfloat16
741 // numbers to extract all the bits of the mantissa. float a*b would
742 // require 9 mac operations due to 3 bfloat16 splits each.
743 auto cfMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
744 loc, VectorType::get({8}, rewriter.getI64Type()), c, f,
745 mscMacMulConfCst);
746 finalMacVal = createMacOps(
747 a, d,
748 createMacOps(
749 a, e,
750 createMacOps(
751 b, d,
752 createMacOps(
753 d, c,
754 createMacOps(
755 b, e,
756 createMacOps(
757 a, f,
758 createMacOps(b, f,
759 createMacOps(c, e, cfMul))))))));
760 }
761
762 // create bitcast for result
763 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
764 finalMacVal);
765 return success();
766 }
767
768 LogicalResult
769 matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor,
770 ConversionPatternRewriter &rewriter) const override {
771 Location loc = op.getLoc();
772 auto decodedMulElemOp = decodeMulElemOp(adaptor);
773
774 if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::UNSUPPORTED) {
775 op.emitWarning() << "aievec.mul_elem conversion is not supported.\n";
776 return failure();
777 }
778
779 // Handle the emulated I32/FP32 mul_elem
780 if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1) {
781 return convertToEmulatedI32MulElem(op, adaptor, rewriter);
782 } else if (decodedMulElemOp.kind ==
784 return convertToEmulatedFP32MulElem(op, adaptor, rewriter);
785 }
786
787 // create constant for config
788 auto confCst = rewriter.create<LLVM::ConstantOp>(
789 loc, rewriter.getI32Type(),
790 rewriter.getI32IntegerAttr(decodedMulElemOp.conf));
791 Value mulElemOp = nullptr;
792 SmallVector<Value> operands({adaptor.getLhs(), adaptor.getRhs(), confCst});
793
794 // create xllvm intrinsic
795 if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I16_I16_I32_32x1x1x1 ||
796 decodedMulElemOp.kind == DecodedMulElemOp::Kind::I8_I8_I32_32x1x2x1) {
797 mulElemOp = rewriter.create<xllvm::MulConfAcc32IntrOp>(
798 loc, VectorType::get({16}, rewriter.getI64Type()),
799 forceCastOperandsToSignature(
800 rewriter, loc, operands,
801 {VectorType::get({64}, rewriter.getI8Type()),
802 VectorType::get({16}, rewriter.getI32Type()),
803 rewriter.getI32Type()}));
804 } else if (decodedMulElemOp.kind ==
806 mulElemOp = rewriter.create<xllvm::MulConfBF16IntrOp>(
807 loc, VectorType::get({8}, rewriter.getI64Type()),
808 forceCastOperandsToSignature(
809 rewriter, loc, operands,
810 {VectorType::get({32}, rewriter.getBF16Type()),
811 VectorType::get({32}, rewriter.getBF16Type()),
812 rewriter.getI32Type()}));
813 }
814
815 // create bitcast for result
816 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
817 mulElemOp);
818 return success();
819 }
820};
821
822class UPSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPSOp> {
823public:
824 using ConvertOpToLLVMPattern<aievec::UPSOp>::ConvertOpToLLVMPattern;
825
826 LogicalResult
827 matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor,
828 ConversionPatternRewriter &rewriter) const override {
829 Location loc = op.getLoc();
830
831 Value result = op.getResult();
832 VectorType resultType = cast<VectorType>(result.getType());
833 VectorType flatResTy = getFlattenedVectorType(resultType);
834 Type resultScaTy = resultType.getElementType();
835 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
836 int resultLanes = getVectorLaneSize(resultType);
837 int resultVectorSize = resultBitWidth * resultLanes;
838
839 Value opSrcVal = adaptor.getSource();
840 auto srcVecTy = cast<VectorType>(opSrcVal.getType());
841 auto fltSrcVecTy = getFlattenedVectorType(srcVecTy);
842 if (srcVecTy != fltSrcVecTy)
843 opSrcVal =
844 rewriter
845 .create<vector::ShapeCastOp>(op.getLoc(), fltSrcVecTy, opSrcVal)
846 .getResult();
847
848 // create xllvm intrinsic
849 // Integer types
850 Value upsIntrOp = nullptr;
851 if (llvm::isa<IntegerType>(resultScaTy)) {
852 // create constant for sign
853 auto signCst = rewriter.create<LLVM::ConstantOp>(
854 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
855 auto shiftCst = rewriter.create<LLVM::ConstantOp>(
856 loc, rewriter.getI32Type(),
857 rewriter.getI32IntegerAttr(op.getShift()));
858
859 SmallVector<Value> operands({opSrcVal, shiftCst, signCst});
860 if (resultVectorSize == 512) {
861 if (resultBitWidth == 32) {
862 // v16int16 -> v16acc32
863 upsIntrOp = rewriter.create<xllvm::Acc32V16I256UpsIntrOp>(
864 loc, VectorType::get({8}, rewriter.getI64Type()),
865 forceCastOperandsToSignature(
866 rewriter, loc, operands,
867 {VectorType::get({16}, rewriter.getI16Type()),
868 rewriter.getI32Type(), rewriter.getI32Type()}));
869 } else if (resultBitWidth == 64) {
870 // v8int32 -> v8acc64
871 upsIntrOp = rewriter.create<xllvm::Acc64V8I256UpsIntrOp>(
872 loc, VectorType::get({8}, rewriter.getI64Type()),
873 forceCastOperandsToSignature(
874 rewriter, loc, operands,
875 {VectorType::get({8}, rewriter.getI32Type()),
876 rewriter.getI32Type(), rewriter.getI32Type()}));
877 }
878 } else if (resultVectorSize == 1024) {
879 Value src = opSrcVal;
880 VectorType srcType = cast<VectorType>(src.getType());
881 Type srcScaType = srcType.getElementType();
882 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
883
884 if (resultBitWidth == 32 && srcBitWidth == 16) {
885 // v32int16 -> v32acc32
886 upsIntrOp = rewriter.create<xllvm::Acc32V32I512UpsIntrOp>(
887 loc, VectorType::get({16}, rewriter.getI64Type()),
888 forceCastOperandsToSignature(
889 rewriter, loc, operands,
890 {VectorType::get({32}, rewriter.getI16Type()),
891 rewriter.getI32Type(), rewriter.getI32Type()}));
892 } else if (resultBitWidth == 64 && srcBitWidth == 32) {
893 // v16int32 -> v16acc64
894 upsIntrOp = rewriter.create<xllvm::Acc64V16I512UpsIntrOp>(
895 loc, VectorType::get({16}, rewriter.getI64Type()),
896 forceCastOperandsToSignature(
897 rewriter, loc, operands,
898 {VectorType::get({16}, rewriter.getI32Type()),
899 rewriter.getI32Type(), rewriter.getI32Type()}));
900 } else if (resultBitWidth == 64 && srcBitWidth == 16) {
901 // v16int16 -> v16acc64
902 upsIntrOp = rewriter.create<xllvm::Acc64V16I256UpsIntrOp>(
903 loc, VectorType::get({16}, rewriter.getI64Type()),
904 forceCastOperandsToSignature(
905 rewriter, loc, operands,
906 {VectorType::get({16}, rewriter.getI16Type()),
907 rewriter.getI32Type(), rewriter.getI32Type()}));
908 } else if (resultBitWidth == 32 && srcBitWidth == 8) {
909 // v32int8 -> v32acc32
910 upsIntrOp = rewriter.create<xllvm::Acc32V32I256UpsIntrOp>(
911 loc, VectorType::get({16}, rewriter.getI64Type()),
912 forceCastOperandsToSignature(
913 rewriter, loc, operands,
914 {VectorType::get({32}, rewriter.getI8Type()),
915 rewriter.getI32Type(), rewriter.getI32Type()}));
916 }
917 }
918 } else {
919 // Float types
920 if (resultVectorSize == 512) {
921 // v16bfloat16 -> v16accfloat
922 upsIntrOp = rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
923 loc, VectorType::get({8}, rewriter.getI64Type()),
924 forceCastOperandsToSignature(
925 rewriter, loc, {opSrcVal},
926 {VectorType::get({16}, rewriter.getBF16Type())}));
927 } else if (resultVectorSize == 1024) {
928 // v32bfloat16 -> v32accfloat
929 // The CPP example of the implementation is below:
930 // INTRINSIC(v32accfloat) ups_to_v32accfloat(v32bfloat16 a) {
931 // v16accfloat x0 = ups_to_v16accfloat(extract_v16bfloat16(a, 0));
932 // v16accfloat x1 = ups_to_v16accfloat(extract_v16bfloat16(a, 1));
933 // return concat(x0, x1);
934 // }
935 auto indexZeroCst = rewriter.create<LLVM::ConstantOp>(
936 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
937 auto indexOneCst = rewriter.create<LLVM::ConstantOp>(
938 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
939 auto extractUps = [&](Value source, Value index) -> Value {
940 auto extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
941 loc, VectorType::get({8}, rewriter.getI32Type()),
942 forceCastOperandsToSignature(
943 rewriter, loc, {source, index},
944 {VectorType::get({16}, rewriter.getI32Type()),
945 rewriter.getI32Type()}));
946 return rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
947 loc, VectorType::get({8}, rewriter.getI64Type()),
948 forceCastOperandsToSignature(
949 rewriter, loc, {extOp},
950 {VectorType::get({16}, rewriter.getBF16Type())}));
951 };
952 auto resLo = extractUps(opSrcVal, indexZeroCst);
953 auto resHi = extractUps(opSrcVal, indexOneCst);
954 // Concat the two 512-bit vector to a 1024-bit vector.
955 // Note that given sources a0 and a1, the result is [a1; a0].
956 upsIntrOp = rewriter.create<xllvm::ConcatI1024I512IntrOp>(
957 loc, VectorType::get({32}, rewriter.getI32Type()),
958 forceCastOperandsToSignature(
959 rewriter, loc, {resLo, resHi},
960 {VectorType::get({16}, rewriter.getI32Type()),
961 VectorType::get({16}, rewriter.getI32Type())}));
962 }
963 }
964
965 if (!upsIntrOp) {
966 op.emitWarning() << "aievec.ups is not supported.\n";
967 return failure();
968 }
969
970 // create bitcast for result if needed
971 if (flatResTy != upsIntrOp.getType())
972 upsIntrOp = rewriter.create<LLVM::BitcastOp>(loc, flatResTy, upsIntrOp);
973
974 if (flatResTy != resultType)
975 upsIntrOp =
976 rewriter.create<vector::ShapeCastOp>(loc, resultType, upsIntrOp);
977
978 rewriter.replaceOp(op, upsIntrOp);
979
980 return success();
981 }
982};
983
984class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
985public:
986 using ConvertOpToLLVMPattern<aievec::SRSOp>::ConvertOpToLLVMPattern;
987
988 LogicalResult
989 matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor,
990 ConversionPatternRewriter &rewriter) const override {
991 Location loc = op.getLoc();
992
993 Value result = op.getResult();
994 VectorType resultType = cast<VectorType>(result.getType());
995 Type resultScaTy = resultType.getElementType();
996 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
997 int resultLanes = getVectorLaneSize(resultType);
998 int resultVectorSize = resultBitWidth * resultLanes;
999
1000 // Integer types
1001 Value srsIntrOp = nullptr;
1002 if (llvm::isa<IntegerType>(resultScaTy)) {
1003 // create constant for sign
1004 auto signCst = rewriter.create<LLVM::ConstantOp>(
1005 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1006
1007 // create xllvm intrinsic
1008 SmallVector<Value> operands(
1009 {adaptor.getSource(), adaptor.getShift(), signCst});
1010 if (resultVectorSize == 512) {
1011 if (resultBitWidth == 16) {
1012 // v32acc32 -> v32int16
1013 srsIntrOp = rewriter.create<xllvm::I512V32Acc32SrsIntrOp>(
1014 loc, VectorType::get({32}, rewriter.getI16Type()),
1015 forceCastOperandsToSignature(
1016 rewriter, loc, operands,
1017 {VectorType::get({16}, rewriter.getI64Type()),
1018 rewriter.getI32Type(), rewriter.getI32Type()}));
1019 } else if (resultBitWidth == 32) {
1020 // v16acc64 -> v16int32
1021 srsIntrOp = rewriter.create<xllvm::I512V16Acc64SrsIntrOp>(
1022 loc, VectorType::get({16}, rewriter.getI32Type()),
1023 forceCastOperandsToSignature(
1024 rewriter, loc, operands,
1025 {VectorType::get({16}, rewriter.getI64Type()),
1026 rewriter.getI32Type(), rewriter.getI32Type()}));
1027 }
1028 } else if (resultVectorSize == 256) {
1029 Value src = adaptor.getSource();
1030 VectorType srcType = cast<VectorType>(src.getType());
1031 Type srcScaType = srcType.getElementType();
1032 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
1033
1034 if (resultBitWidth == 16 && srcBitWidth == 32) {
1035 // v16acc32 -> v16int16
1036 srsIntrOp = rewriter.create<xllvm::I256V16Acc32SrsIntrOp>(
1037 loc, VectorType::get({16}, rewriter.getI16Type()),
1038 forceCastOperandsToSignature(
1039 rewriter, loc, operands,
1040 {VectorType::get({8}, rewriter.getI64Type()),
1041 rewriter.getI32Type(), rewriter.getI32Type()}));
1042 } else if (resultBitWidth == 8 && srcBitWidth == 32) {
1043 // v32acc32 -> v32int8
1044 srsIntrOp = rewriter.create<xllvm::I256V32Acc32SrsIntrOp>(
1045 loc, VectorType::get({32}, rewriter.getI8Type()),
1046 forceCastOperandsToSignature(
1047 rewriter, loc, operands,
1048 {VectorType::get({16}, rewriter.getI64Type()),
1049 rewriter.getI32Type(), rewriter.getI32Type()}));
1050 } else if (resultBitWidth == 16 && srcBitWidth == 64) {
1051 // v16acc64 -> v16int16
1052 srsIntrOp = rewriter.create<xllvm::I256V16Acc64SrsIntrOp>(
1053 loc, VectorType::get({16}, rewriter.getI16Type()),
1054 forceCastOperandsToSignature(
1055 rewriter, loc, operands,
1056 {VectorType::get({16}, rewriter.getI64Type()),
1057 rewriter.getI32Type(), rewriter.getI32Type()}));
1058 } else if (resultBitWidth == 32 && srcBitWidth == 64) {
1059 // v8acc64 -> v8int32
1060 srsIntrOp = rewriter.create<xllvm::I256V8Acc64SrsIntrOp>(
1061 loc, VectorType::get({8}, rewriter.getI32Type()),
1062 forceCastOperandsToSignature(
1063 rewriter, loc, operands,
1064 {VectorType::get({8}, rewriter.getI64Type()),
1065 rewriter.getI32Type(), rewriter.getI32Type()}));
1066 }
1067 }
1068 } else {
1069 // Float types
1070 if (resultVectorSize == 256) {
1071 // v16accfloat -> v16bfloat16
1072 srsIntrOp = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
1073 loc, VectorType::get({16}, rewriter.getBF16Type()),
1074 forceCastOperandsToSignature(
1075 rewriter, loc, {adaptor.getSource()},
1076 {VectorType::get({8}, rewriter.getI64Type())}));
1077 } else if (resultVectorSize == 512) {
1078 // v32accfloat -> v32bfloat16
1079 // The CPP example of the implementation is below:
1080 // v32bfloat16 to_v32bfloat16(v32accfloat acc) {
1081 // v16bfloat16 x0 = to_v16bfloat16(extract_v16accfloat(acc, 0));
1082 // v16bfloat16 x1 = to_v16bfloat16(extract_v16accfloat(acc, 1));
1083 // return concat(x0, x1);
1084 // }
1085 auto indexZeroCst = rewriter.create<LLVM::ConstantOp>(
1086 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1087 auto indexOneCst = rewriter.create<LLVM::ConstantOp>(
1088 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1089 auto extractSrs = [&](Value source, Value index) -> Value {
1090 auto extOp = rewriter.create<xllvm::ExtI512I1024IntrOp>(
1091 loc, VectorType::get({16}, rewriter.getI32Type()),
1092 forceCastOperandsToSignature(
1093 rewriter, loc, {source, index},
1094 {VectorType::get({32}, rewriter.getI32Type()),
1095 rewriter.getI32Type()}));
1096 return rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
1097 loc, VectorType::get({16}, rewriter.getBF16Type()),
1098 forceCastOperandsToSignature(
1099 rewriter, loc, {extOp},
1100 {VectorType::get({8}, rewriter.getI64Type())}));
1101 };
1102 auto resLo = extractSrs(adaptor.getSource(), indexZeroCst);
1103 auto resHi = extractSrs(adaptor.getSource(), indexOneCst);
1104 // Concat the two 256-bit vector to a 512-bit vector.
1105 // Note that given sources a0 and a1, the result is [a1; a0].
1106 srsIntrOp = rewriter.create<xllvm::ConcatI512I256IntrOp>(
1107 loc, VectorType::get({16}, rewriter.getI32Type()),
1108 forceCastOperandsToSignature(
1109 rewriter, loc, {resLo, resHi},
1110 {VectorType::get({8}, rewriter.getI32Type()),
1111 VectorType::get({8}, rewriter.getI32Type())}));
1112 }
1113 }
1114
1115 if (!srsIntrOp) {
1116 op.emitWarning() << "aievec.srs is not supported.\n";
1117 return failure();
1118 }
1119
1120 // create bitcast for result if needed
1121 if (op.getResult().getType() != srsIntrOp.getType()) {
1122 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1123 srsIntrOp);
1124 } else {
1125 rewriter.replaceOp(op, srsIntrOp);
1126 }
1127
1128 return success();
1129 }
1130};
1131
1132class UPDOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPDOp> {
1133public:
1134 using ConvertOpToLLVMPattern<aievec::UPDOp>::ConvertOpToLLVMPattern;
1135
1136 static std::string getIntrinsicName(aievec::UPDOp op, int loadSize) {
1137 auto resultType = cast<VectorType>(op.getResult().getType());
1138 std::stringstream ss;
1139 ss << "llvm.aie.upd.";
1140 ss << (loadSize == 128 ? 'v' : loadSize == 256 ? 'w' : 'x') << ".";
1141 ss << getVectorTypeString(resultType) << ".";
1142 // The index affects which intrinsic to call
1143 ss << (op.getIndex() == 0 ? "lo" : "hi");
1144 return ss.str();
1145 }
1146
1147 LogicalResult
1148 matchAndRewrite(aievec::UPDOp op, OpAdaptor adaptor,
1149 ConversionPatternRewriter &rewriter) const override {
1150 auto module = op->getParentOfType<ModuleOp>();
1151 MLIRContext *context = rewriter.getContext();
1152
1153 // A bit more complicated: load the vector, then update result vector
1154 // AIE1 is capable of 128-bit on one bank and 256-bit loads on even-odd
1155 // banks Identify size of update
1156 int vecSizeInBits =
1157 getVectorSizeInBits(cast<VectorType>(op.getResult().getType()));
1158
1159 auto ptr = this->getStridedElementPtr(
1160 rewriter, op->getLoc(), cast<MemRefType>(op.getSource().getType()),
1161 adaptor.getSource(), adaptor.getIndices());
1162
1163 // TODO: handle the offset field
1164
1165 if (vecSizeInBits <= 256) {
1166 // Total <=256-bit updates are much simpler:
1167 // we can do a direct load into the vector register
1168 // look at the indices to calculate the address
1169 auto vectorPtrType = LLVM::LLVMPointerType::get(
1170 getContext(),
1171 cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
1172 auto castedPtr =
1173 rewriter.create<LLVM::BitcastOp>(op->getLoc(), vectorPtrType, ptr);
1174 auto vecType = cast<VectorType>(op.getResult().getType());
1175 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, vecType, castedPtr, 1);
1176 } else {
1177 // Total >256-bit updates will require upd ops to fill the whole vector
1178 // each UDP op represents one of these 256-bit loads and updates
1179
1180 // Determine the load size
1181 // TODO: no examples of 1024-bit output vectors: doesn't feel right
1182 // to attempt a 512-bit load to do an update like this
1183 int loadSize = vecSizeInBits == 256 ? 128
1184 : vecSizeInBits == 512 ? 256
1185 : 512;
1186
1187 // Create a vectorType for the load proper
1188 // Load half of the final result vector
1189 auto resultType = cast<VectorType>(op.getResult().getType());
1190 int lanes = getVectorLaneSize(resultType);
1191 auto loadType =
1192 VectorType::get({(int64_t)lanes / 2}, resultType.getElementType());
1193
1194 // Load the vector
1195 auto vectorPtrType = LLVM::LLVMPointerType::get(
1196 getContext(),
1197 cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
1198 auto castedPtr =
1199 rewriter.create<LLVM::BitcastOp>(op->getLoc(), vectorPtrType, ptr);
1200 auto loadValue =
1201 rewriter.create<LLVM::LoadOp>(op->getLoc(), loadType, castedPtr, 1);
1202
1203 // Get set up for the intrinsic
1204 std::string intrinsicName = getIntrinsicName(op, loadSize);
1205
1206 // If the intrinsic declaration doesn't exist, create it
1207 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
1208 StringAttr::get(context, intrinsicName));
1209
1210 if (!func) {
1211 OpBuilder::InsertionGuard guard(rewriter);
1212 rewriter.setInsertionPointToStart(module.getBody());
1213 func = rewriter.create<LLVM::LLVMFuncOp>(
1214 rewriter.getUnknownLoc(), intrinsicName,
1215 LLVM::LLVMFunctionType::get(resultType, {resultType, loadType}));
1216 }
1217
1218 // Determine what the destination is
1219 Value destValue;
1220 if (adaptor.getVector()) {
1221 // This UPD is using an existing destination vector
1222 destValue = adaptor.getVector();
1223 } else {
1224 // If this UPD is not working off of an existing destination vector,
1225 // create an undefined vector as the destination
1226
1227 // TODO: determine if the undef intrinsic is needed or if an LLVM
1228 // undef suffices destValue =
1229 // rewriter.create<LLVM::UndefOp>(op->getLoc(), resultType);
1230
1231 std::stringstream ss;
1232 ss << "llvm.aie." << getVectorTypeString(resultType) << ".undef";
1233 std::string intrinsicName = ss.str();
1234
1235 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
1236 StringAttr::get(rewriter.getContext(), intrinsicName));
1237
1238 if (!func) {
1239 OpBuilder::InsertionGuard guard(rewriter);
1240 rewriter.setInsertionPointToStart(module.getBody());
1241 func = rewriter.create<LLVM::LLVMFuncOp>(
1242 rewriter.getUnknownLoc(), intrinsicName,
1243 LLVM::LLVMFunctionType::get(resultType, {}));
1244 }
1245 destValue =
1246 rewriter.create<LLVM::CallOp>(op->getLoc(), func, ValueRange{})
1247 ->getOpResult(0);
1248 }
1249
1250 // Create our call
1251 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
1252 op, func, ValueRange{destValue, loadValue});
1253 }
1254
1255 return success();
1256 }
1257};
1258
1260 : public mlir::ConvertOpToLLVMPattern<aievec::ConcatOp> {
1261public:
1262 using ConvertOpToLLVMPattern<aievec::ConcatOp>::ConvertOpToLLVMPattern;
1263
1264 LogicalResult
1265 matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor,
1266 ConversionPatternRewriter &rewriter) const override {
1267 Location loc = op.getLoc();
1268
1269 SmallVector<Value> sources = adaptor.getSources();
1270 Value src = sources.front();
1271 VectorType srcType = cast<VectorType>(src.getType());
1272 Type srcScalarType = srcType.getElementType();
1273 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
1274 int srcLanes = getVectorLaneSize(srcType);
1275 int srcVectorSize = srcBitWidth * srcLanes;
1276
1277 Value result = op.getResult();
1278 VectorType resultType = cast<VectorType>(result.getType());
1279 Type resultScaTy = resultType.getElementType();
1280 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1281 int resultLanes = getVectorLaneSize(resultType);
1282 int resultVectorSize = resultBitWidth * resultLanes;
1283
1284 if (sources.size() != 2 && sources.size() != 4) {
1285 op.emitWarning() << "aievec.concat with " << sources.size()
1286 << " operands is not supported.\n";
1287 return failure();
1288 }
1289
1290 // create xllvm intrinsic
1291 Value concatOp = nullptr;
1292 if (srcVectorSize == 256 && resultVectorSize == 512) {
1293 concatOp = rewriter.create<xllvm::ConcatI512I256IntrOp>(
1294 loc, VectorType::get({16}, rewriter.getI32Type()),
1295 forceCastOperandsToSignature(
1296 rewriter, loc, adaptor.getSources(),
1297 {VectorType::get({8}, rewriter.getI32Type()),
1298 VectorType::get({8}, rewriter.getI32Type())}));
1299 } else if (srcVectorSize == 256 && resultVectorSize == 1024) {
1300 concatOp = rewriter.create<xllvm::ConcatI1024I256IntrOp>(
1301 loc, VectorType::get({32}, rewriter.getI32Type()),
1302 forceCastOperandsToSignature(
1303 rewriter, loc, adaptor.getSources(),
1304 {VectorType::get({8}, rewriter.getI32Type()),
1305 VectorType::get({8}, rewriter.getI32Type()),
1306 VectorType::get({8}, rewriter.getI32Type()),
1307 VectorType::get({8}, rewriter.getI32Type())}));
1308 } else if (srcVectorSize == 512 && resultVectorSize == 1024) {
1309 concatOp = rewriter.create<xllvm::ConcatI1024I512IntrOp>(
1310 loc, VectorType::get({32}, rewriter.getI32Type()),
1311 forceCastOperandsToSignature(
1312 rewriter, loc, adaptor.getSources(),
1313 {VectorType::get({16}, rewriter.getI32Type()),
1314 VectorType::get({16}, rewriter.getI32Type())}));
1315 } else {
1316 op.emitWarning() << "aievec.concat with " << srcVectorSize
1317 << "-bit operands, and " << resultVectorSize
1318 << "-bit result is not supported.\n";
1319 return failure();
1320 }
1321
1322 // create bitcast for result
1323 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1324 concatOp);
1325
1326 return success();
1327 }
1328};
1329
1330class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
1331public:
1332 using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;
1333
1334 LogicalResult
1335 matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor,
1336 ConversionPatternRewriter &rewriter) const override {
1337 Location loc = op.getLoc();
1338
1339 Value src = adaptor.getSource();
1340 VectorType srcType = cast<VectorType>(src.getType());
1341 Type srcScalarType = srcType.getElementType();
1342 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
1343 int srcLanes = getVectorLaneSize(srcType);
1344 int srcVectorSize = srcBitWidth * srcLanes;
1345
1346 Value result = op.getResult();
1347 VectorType resultType = cast<VectorType>(result.getType());
1348 Type resultScaTy = resultType.getElementType();
1349 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1350 int resultLanes = getVectorLaneSize(resultType);
1351 int resultVectorSize = resultBitWidth * resultLanes;
1352
1353 // create constant for index
1354 auto indexCst = rewriter.create<LLVM::ConstantOp>(
1355 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(op.getIndex()));
1356
1357 // create xllvm intrinsic
1358 SmallVector<Value> operands({adaptor.getSource(), indexCst});
1359 Value extOp = nullptr;
1360 // Integer types
1361 if (resultVectorSize == 256 && srcVectorSize == 512) {
1362 extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
1363 loc, VectorType::get({8}, rewriter.getI32Type()),
1364 forceCastOperandsToSignature(
1365 rewriter, loc, operands,
1366 {VectorType::get({16}, rewriter.getI32Type()),
1367 rewriter.getI32Type()}));
1368 } else if (resultVectorSize == 512 && srcVectorSize == 1024) {
1369 extOp = rewriter.create<xllvm::ExtI512I1024IntrOp>(
1370 loc, VectorType::get({16}, rewriter.getI32Type()),
1371 forceCastOperandsToSignature(
1372 rewriter, loc, operands,
1373 {VectorType::get({32}, rewriter.getI32Type()),
1374 rewriter.getI32Type()}));
1375 } else if (resultVectorSize == 256 && srcVectorSize == 1024) {
1376 extOp = rewriter.create<xllvm::ExtI256I1024IntrOp>(
1377 loc, VectorType::get({8}, rewriter.getI32Type()),
1378 forceCastOperandsToSignature(
1379 rewriter, loc, operands,
1380 {VectorType::get({32}, rewriter.getI32Type()),
1381 rewriter.getI32Type()}));
1382 } else if (resultVectorSize == 128 && srcVectorSize == 512) {
1383 auto shiftOp = adaptor.getSource();
1384 if (op.getIndex() > 0) {
1385 auto undefOp = rewriter.create<xllvm::UndefV16I32IntrOp>(
1386 loc, VectorType::get({16}, rewriter.getI32Type()));
1387 auto stepCst = rewriter.create<LLVM::ConstantOp>(
1388 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1389 auto shiftCst = rewriter.create<LLVM::ConstantOp>(
1390 loc, rewriter.getI32Type(),
1391 rewriter.getI32IntegerAttr(op.getIndex() * 16));
1392 SmallVector<Value> shiftOperands{adaptor.getSource(), undefOp, stepCst,
1393 shiftCst};
1394 // Right shift the source vector in index * 16 bytes (i.e. in index *
1395 // 128 bits). The integer index is expected to be 0 to 3.
1396 shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
1397 loc, VectorType::get({16}, rewriter.getI32Type()),
1398 forceCastOperandsToSignature(
1399 rewriter, loc, shiftOperands,
1400 {VectorType::get({16}, rewriter.getI32Type()),
1401 VectorType::get({16}, rewriter.getI32Type()),
1402 rewriter.getI32Type(), rewriter.getI32Type()}));
1403 }
1404 // The underlying intrinsic takes a source vector and extract the lowest
1405 // 128-bit. i.e. it always extracts the input vector with index = 0.
1406 extOp = rewriter.create<xllvm::ExtI128I512IntrOp>(
1407 loc, VectorType::get({4}, rewriter.getI32Type()),
1408 forceCastOperandsToSignature(
1409 rewriter, loc, /*operands=*/{shiftOp},
1410 {VectorType::get({16}, rewriter.getI32Type())}));
1411 } else {
1412 op.emitWarning() << "aievec.ext with " << srcVectorSize
1413 << "-bit source, and " << resultVectorSize
1414 << "-bit result is not supported.\n";
1415 return failure();
1416 }
1417
1418 // create bitcast for result
1419 if (op.getResult().getType() != extOp.getType()) {
1420 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1421 extOp);
1422 } else {
1423 rewriter.replaceOp(op, extOp);
1424 }
1425
1426 return success();
1427 }
1428};
1429
1431 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::SelectOp> {
1432public:
1433 using ConvertOpToLLVMPattern<aievec::aie1::SelectOp>::ConvertOpToLLVMPattern;
1434
1435 static std::string getIntrinsicName(aievec::aie1::SelectOp op) {
1436 auto xbuffType = cast<VectorType>(op.getXbuff().getType());
1437 std::stringstream ss;
1438 ss << "llvm.aie.prim." << getVectorTypeString(xbuffType) << ".select";
1439 return ss.str();
1440 }
1441
1442 LogicalResult
1443 matchAndRewrite(aievec::aie1::SelectOp op, OpAdaptor adaptor,
1444 ConversionPatternRewriter &rewriter) const override {
1445 auto module = op->getParentOfType<ModuleOp>();
1446 MLIRContext *context = rewriter.getContext();
1447
1448 auto selectType = IntegerType::get(context, 32);
1449 auto startType = IntegerType::get(context, 32);
1450 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
1451 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
1452
1453 // If the intrinsic declaration doesn't exist, create it
1454 std::string intrinsicName = getIntrinsicName(op);
1455 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
1456 StringAttr::get(context, intrinsicName));
1457
1458 if (!func) {
1459 OpBuilder::InsertionGuard guard(rewriter);
1460 rewriter.setInsertionPointToStart(module.getBody());
1461 func = rewriter.create<LLVM::LLVMFuncOp>(
1462 rewriter.getUnknownLoc(), intrinsicName,
1463 LLVM::LLVMFunctionType::get(op.getResult().getType(),
1464 {op.getXbuff().getType(), selectType,
1465 startType, /* xstart */
1466 startType, /* ystart */
1467 offsetsType, /* xoffsets */
1468 offsetsType, /* yoffsets */
1469 confType}));
1470 }
1471
1472 // Parse the string attribute values
1473 uint32_t select = 0;
1474 BufferParams x = {};
1475 BufferParams y = {};
1476 BufferParams z = {};
1477
1478 op.getSelect().getAsInteger(0, select);
1479 op.getXstart().getAsInteger(0, x.start);
1480 op.getXoffsets().getAsInteger(0, x.offsets);
1481 op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
1482 op.getXsquare().getAsInteger(0, x.square);
1483 op.getYstart().getAsInteger(0, y.start);
1484 op.getYoffsets().getAsInteger(0, y.offsets);
1485 op.getYoffsetsHi().getAsInteger(0, y.offsets_hi);
1486 op.getYsquare().getAsInteger(0, y.square);
1487
1488 // Encode the configuration register
1489 uint32_t conf[2] = {0, 0};
1490 encodeConf(conf, x, z, false);
1491 conf[1] |= encodeSquare(y.square) << 21;
1492
1493 // Create the constants and replace the op
1494 auto selectVal = rewriter.create<LLVM::ConstantOp>(
1495 op->getLoc(), selectType, rewriter.getI32IntegerAttr(select));
1496 auto xstartVal = rewriter.create<LLVM::ConstantOp>(
1497 op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
1498 auto ystartVal = rewriter.create<LLVM::ConstantOp>(
1499 op->getLoc(), startType, rewriter.getI32IntegerAttr(y.start));
1500 auto xoffsetsVal = rewriter.create<LLVM::ConstantOp>(
1501 op->getLoc(), offsetsType,
1502 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
1503 auto yoffsetsVal = rewriter.create<LLVM::ConstantOp>(
1504 op->getLoc(), offsetsType,
1505 rewriter.getI32VectorAttr({(int32_t)y.offsets, (int32_t)y.offsets_hi}));
1506 auto confVal = rewriter.create<LLVM::ConstantOp>(
1507 op->getLoc(), confType,
1508 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
1509 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
1510 op, func,
1511 ValueRange{op.getXbuff(), selectVal, xstartVal, ystartVal, xoffsetsVal,
1512 yoffsetsVal, confVal});
1513 return success();
1514 }
1515};
1516
1517class PackOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::PackOp> {
1518public:
1519 using ConvertOpToLLVMPattern<aievec::PackOp>::ConvertOpToLLVMPattern;
1520
1521 static std::string getIntrinsicName(aievec::PackOp op) {
1522 auto sourceType = cast<VectorType>(op.getSource().getType());
1523 std::stringstream ss;
1524 ss << "llvm.aie.pack." << getVectorTypeString(sourceType);
1525 return ss.str();
1526 }
1527
1528 LogicalResult
1529 matchAndRewrite(aievec::PackOp op, OpAdaptor adaptor,
1530 ConversionPatternRewriter &rewriter) const override {
1531 auto module = op->getParentOfType<ModuleOp>();
1532 MLIRContext *context = rewriter.getContext();
1533
1534 // If the intrinsic declaration doesn't exist, create it
1535 std::string intrinsicName = getIntrinsicName(op);
1536 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
1537 StringAttr::get(context, intrinsicName));
1538
1539 if (!func) {
1540 OpBuilder::InsertionGuard guard(rewriter);
1541 rewriter.setInsertionPointToStart(module.getBody());
1542 func = rewriter.create<LLVM::LLVMFuncOp>(
1543 rewriter.getUnknownLoc(), intrinsicName,
1544 LLVM::LLVMFunctionType::get(op.getResult().getType(),
1545 {op.getSource().getType()}));
1546 }
1547
1548 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, func,
1549 ValueRange{op.getSource()});
1550 return success();
1551 }
1552};
1553
1555 : public mlir::ConvertOpToLLVMPattern<aievec::UnpackOp> {
1556public:
1557 using ConvertOpToLLVMPattern<aievec::UnpackOp>::ConvertOpToLLVMPattern;
1558
1559 LogicalResult
1560 matchAndRewrite(aievec::UnpackOp op, OpAdaptor adaptor,
1561 ConversionPatternRewriter &rewriter) const override {
1562 op.emitWarning() << "aie.unpack conversion is not implemented\n";
1563 return failure();
1564 }
1565};
1566
1568 : public mlir::ConvertOpToLLVMPattern<aievec::BroadcastOp> {
1569public:
1570 using ConvertOpToLLVMPattern<aievec::BroadcastOp>::ConvertOpToLLVMPattern;
1571
1572 LogicalResult
1573 matchAndRewrite(aievec::BroadcastOp op, OpAdaptor adaptor,
1574 ConversionPatternRewriter &rewriter) const override {
1575 op.emitWarning() << "aie.broadcast conversion is not implemented\n";
1576 return failure();
1577 }
1578};
1579
1580class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
1581public:
1582 using ConvertOpToLLVMPattern<aievec::MaxOp>::ConvertOpToLLVMPattern;
1583
1584 LogicalResult
1585 matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor,
1586 ConversionPatternRewriter &rewriter) const override {
1587 Location loc = op.getLoc();
1588
1589 VectorType resultType = cast<VectorType>(op.getResult().getType());
1590 Type resultScaTy = resultType.getElementType();
1591 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1592 int resultLanes = getVectorLaneSize(resultType);
1593 int resultVectorSize = resultBitWidth * resultLanes;
1594
1595 // aievec.max op has the AllTypesMatch constraint on lhs/rhs/res
1596 if (resultVectorSize != 512) {
1597 op.emitWarning() << "aievec.max conversion with " << resultVectorSize
1598 << "-bit result is not supported.\n";
1599 return failure();
1600 }
1601
1602 // create xllvm intrinsic
1603 Value maxOp = nullptr;
1604 if (llvm::isa<IntegerType>(resultScaTy)) {
1605 // create constant for third operand `cmp`
1606 // Note: `cmp` is implicitly treated as `sign` to the vmax intrinsic
1607 auto cmpCst = rewriter.create<LLVM::ConstantOp>(
1608 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1609 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
1610 if (resultBitWidth == 8) {
1611 maxOp = rewriter.create<xllvm::VectorMaxLt8IntrOp>(
1612 loc,
1613 mlir::LLVM::LLVMStructType::getLiteral(
1614 rewriter.getContext(),
1615 {VectorType::get({64}, rewriter.getI8Type()),
1616 VectorType::get({2}, rewriter.getI32Type())}),
1617 forceCastOperandsToSignature(
1618 rewriter, loc, operands,
1619 {VectorType::get({64}, rewriter.getI8Type()),
1620 VectorType::get({64}, rewriter.getI8Type()),
1621 rewriter.getI32Type()}));
1622 } else if (resultBitWidth == 16) {
1623 maxOp = rewriter.create<xllvm::VectorMaxLt16IntrOp>(
1624 loc,
1625 mlir::LLVM::LLVMStructType::getLiteral(
1626 rewriter.getContext(),
1627 {VectorType::get({32}, rewriter.getI16Type()),
1628 rewriter.getI32Type()}),
1629 forceCastOperandsToSignature(
1630 rewriter, loc, operands,
1631 {VectorType::get({32}, rewriter.getI16Type()),
1632 VectorType::get({32}, rewriter.getI16Type()),
1633 rewriter.getI32Type()}));
1634 } else if (resultBitWidth == 32) {
1635 maxOp = rewriter.create<xllvm::VectorMaxLt32IntrOp>(
1636 loc,
1637 mlir::LLVM::LLVMStructType::getLiteral(
1638 rewriter.getContext(),
1639 {VectorType::get({16}, rewriter.getI32Type()),
1640 rewriter.getI32Type()}),
1641 forceCastOperandsToSignature(
1642 rewriter, loc, operands,
1643 {VectorType::get({16}, rewriter.getI32Type()),
1644 VectorType::get({16}, rewriter.getI32Type()),
1645 rewriter.getI32Type()}));
1646 }
1647 } else {
1648 if (resultBitWidth == 16) {
1649 maxOp = rewriter.create<xllvm::VectorMaxLtBf16IntrOp>(
1650 loc,
1651 mlir::LLVM::LLVMStructType::getLiteral(
1652 rewriter.getContext(),
1653 {VectorType::get({32}, rewriter.getBF16Type()),
1654 rewriter.getI32Type()}),
1655 forceCastOperandsToSignature(
1656 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs()},
1657 {VectorType::get({32}, rewriter.getBF16Type()),
1658 VectorType::get({32}, rewriter.getBF16Type())}));
1659 }
1660 }
1661
1662 if (!maxOp) {
1663 // We have checked the lhs/rhs/res to be 512-bit vectors. Hence, a
1664 // possible failure here is due to unsupported element datatype.
1665 op.emitWarning() << "aievec.max conversion fails due to unsupported "
1666 "element data type.\n";
1667 return failure();
1668 }
1669
1670 // create llvm.extractvalue for the first element in the LLVMStruct
1671 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, maxOp,
1672 /*position=*/0);
1673
1674 return success();
1675 }
1676};
1677
1678class MinOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MinOp> {
1679public:
1680 using ConvertOpToLLVMPattern<aievec::MinOp>::ConvertOpToLLVMPattern;
1681
1682 LogicalResult
1683 matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor,
1684 ConversionPatternRewriter &rewriter) const override {
1685 Location loc = op.getLoc();
1686
1687 VectorType resultType = cast<VectorType>(op.getResult().getType());
1688 Type resultScaTy = resultType.getElementType();
1689 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1690 int resultLanes = getVectorLaneSize(resultType);
1691 int resultVectorSize = resultBitWidth * resultLanes;
1692
1693 // aievec.min op has the AllTypesMatch constraint on lhs/rhs/res
1694 if (resultVectorSize != 512) {
1695 op.emitWarning() << "aievec.min conversion with " << resultVectorSize
1696 << "-bit result is not supported.\n";
1697 return failure();
1698 }
1699
1700 // create xllvm intrinsic
1701 Value minOp = nullptr;
1702 if (llvm::isa<IntegerType>(resultScaTy)) {
1703 // create constant for third operand `cmp`
1704 // Note: `cmp` is implicitly treated as `sign` to the vmin intrinsic
1705 auto cmpCst = rewriter.create<LLVM::ConstantOp>(
1706 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1707 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
1708 if (resultBitWidth == 8) {
1709 minOp = rewriter.create<xllvm::VectorMinGe8IntrOp>(
1710 loc,
1711 mlir::LLVM::LLVMStructType::getLiteral(
1712 rewriter.getContext(),
1713 {VectorType::get({64}, rewriter.getI8Type()),
1714 VectorType::get({2}, rewriter.getI32Type())}),
1715 forceCastOperandsToSignature(
1716 rewriter, loc, operands,
1717 {VectorType::get({64}, rewriter.getI8Type()),
1718 VectorType::get({64}, rewriter.getI8Type()),
1719 rewriter.getI32Type()}));
1720 } else if (resultBitWidth == 16) {
1721 minOp = rewriter.create<xllvm::VectorMinGe16IntrOp>(
1722 loc,
1723 mlir::LLVM::LLVMStructType::getLiteral(
1724 rewriter.getContext(),
1725 {VectorType::get({32}, rewriter.getI16Type()),
1726 rewriter.getI32Type()}),
1727 forceCastOperandsToSignature(
1728 rewriter, loc, operands,
1729 {VectorType::get({32}, rewriter.getI16Type()),
1730 VectorType::get({32}, rewriter.getI16Type()),
1731 rewriter.getI32Type()}));
1732 } else if (resultBitWidth == 32) {
1733 minOp = rewriter.create<xllvm::VectorMinGe32IntrOp>(
1734 loc,
1735 mlir::LLVM::LLVMStructType::getLiteral(
1736 rewriter.getContext(),
1737 {VectorType::get({16}, rewriter.getI32Type()),
1738 rewriter.getI32Type()}),
1739 forceCastOperandsToSignature(
1740 rewriter, loc, operands,
1741 {VectorType::get({16}, rewriter.getI32Type()),
1742 VectorType::get({16}, rewriter.getI32Type()),
1743 rewriter.getI32Type()}));
1744 }
1745 } else {
1746 if (resultBitWidth == 16) {
1747 minOp = rewriter.create<xllvm::VectorMinGeBf16IntrOp>(
1748 loc,
1749 mlir::LLVM::LLVMStructType::getLiteral(
1750 rewriter.getContext(),
1751 {VectorType::get({32}, rewriter.getBF16Type()),
1752 rewriter.getI32Type()}),
1753 forceCastOperandsToSignature(
1754 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs()},
1755 {VectorType::get({32}, rewriter.getBF16Type()),
1756 VectorType::get({32}, rewriter.getBF16Type())}));
1757 }
1758 }
1759
1760 if (!minOp) {
1761 // We have checked the lhs/rhs/res to be 512-bit vectors. Hence, a
1762 // possible failure here is due to unsupported element datatype.
1763 op.emitWarning() << "aievec.min conversion fails due to unsupported "
1764 "element data type.\n";
1765 return failure();
1766 }
1767
1768 // create llvm.extractvalue for the first element in the LLVMStruct
1769 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, minOp,
1770 /*position=*/0);
1771
1772 return success();
1773 }
1774};
1775
1777 : public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
1778public:
1779 using ConvertOpToLLVMPattern<
1780 aievec::BroadcastScalarOp>::ConvertOpToLLVMPattern;
1781
1782 LogicalResult
1783 matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor,
1784 ConversionPatternRewriter &rewriter) const override {
1785 Location loc = op.getLoc();
1786
1787 Value result = op.getResult();
1788 VectorType resultType = cast<VectorType>(result.getType());
1789 Type resultScaTy = resultType.getElementType();
1790 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1791 int resultLanes = getVectorLaneSize(resultType);
1792 int resultVectorSize = resultBitWidth * resultLanes;
1793
1794 if (resultVectorSize != 512) {
1795 op.emitWarning()
1796 << "aievec.broadcast_scalar conversion with result vector size "
1797 << resultVectorSize << " is not implemented.\n";
1798 return failure();
1799 }
1800
1801 // Integer types
1802 if (llvm::isa<IntegerType>(resultScaTy)) {
1803 Value src = adaptor.getSource();
1804 Type srcType = src.getType();
1805 unsigned srcBitWidth = srcType.getIntOrFloatBitWidth();
1806
1807 if (srcBitWidth < 32) {
1808 src = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI32Type(), src);
1809 }
1810
1811 if (resultBitWidth == 8) {
1812 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast8I512IntrOp>(
1813 op, VectorType::get({64}, rewriter.getI8Type()), src);
1814 } else if (resultBitWidth == 16) {
1815 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast16I512IntrOp>(
1816 op, VectorType::get({32}, rewriter.getI16Type()), src);
1817 } else if (resultBitWidth == 32) {
1818 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast32I512IntrOp>(
1819 op, VectorType::get({16}, rewriter.getI32Type()), src);
1820 } else {
1821 op.emitWarning()
1822 << "aievec.broadcast_scalar conversion with result bitwidth "
1823 << resultBitWidth << " is not implemented.\n";
1824 return failure();
1825 }
1826 } else {
1827 // Float types
1828 if (resultBitWidth == 16) {
1829 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast16BF512IntrOp>(
1830 op, VectorType::get({32}, rewriter.getBF16Type()),
1831 adaptor.getSource());
1832 } else if (resultBitWidth == 32) {
1833 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcastfloatI512IntrOp>(
1834 op, VectorType::get({16}, rewriter.getF32Type()),
1835 adaptor.getSource());
1836 } else {
1837 op.emitWarning()
1838 << "aievec.broadcast_scalar conversion with result bitwidth "
1839 << resultBitWidth << " is not implemented.\n";
1840 return failure();
1841 }
1842 }
1843
1844 return success();
1845 }
1846};
1847
1848class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
1849public:
1850 using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;
1851
1852 LogicalResult
1853 matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor,
1854 ConversionPatternRewriter &rewriter) const override {
1855 Location loc = op.getLoc();
1856
1857 Value result = op.getResult();
1858 VectorType resultType = cast<VectorType>(result.getType());
1859 Type resultScaTy = resultType.getElementType();
1860 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1861 int resultLanes = getVectorLaneSize(resultType);
1862 int resultVectorSize = resultBitWidth * resultLanes;
1863
1864 if (resultVectorSize != 512) {
1865 op.emitWarning() << "aievec.shift conversion with result vector size "
1866 << resultVectorSize << " is not implemented.\n";
1867 return failure();
1868 }
1869
1870 // assume step is always zero
1871 auto stepCst = rewriter.create<LLVM::ConstantOp>(
1872 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1873
1874 // create xllvm intrinsic
1875 Value shiftOp = nullptr;
1876 SmallVector<Value> operands(
1877 {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()});
1878 if (llvm::isa<IntegerType>(resultScaTy)) {
1879 // Integer types
1880 shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
1881 loc, VectorType::get({16}, rewriter.getI32Type()),
1882 forceCastOperandsToSignature(
1883 rewriter, loc, operands,
1884 {VectorType::get({16}, rewriter.getI32Type()),
1885 VectorType::get({16}, rewriter.getI32Type()),
1886 rewriter.getI32Type(), rewriter.getI32Type()}));
1887 } else {
1888 // Float types
1889 shiftOp = rewriter.create<xllvm::VectorShiftBF512BF512IntrOp>(
1890 loc, VectorType::get({32}, rewriter.getBF16Type()),
1891 forceCastOperandsToSignature(
1892 rewriter, loc, operands,
1893 {VectorType::get({32}, rewriter.getBF16Type()),
1894 VectorType::get({32}, rewriter.getBF16Type()),
1895 rewriter.getI32Type(), rewriter.getI32Type()}));
1896 }
1897
1898 // create bitcast for result
1899 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1900 shiftOp);
1901
1902 return success();
1903 }
1904};
1905
1907 : public mlir::ConvertOpToLLVMPattern<aievec::ExtElemOp> {
1908public:
1909 using ConvertOpToLLVMPattern<aievec::ExtElemOp>::ConvertOpToLLVMPattern;
1910
1911 LogicalResult
1912 matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor,
1913 ConversionPatternRewriter &rewriter) const override {
1914 Location loc = op.getLoc();
1915
1916 Type resultType = op.getResult().getType();
1917 unsigned resultBitWidth = resultType.getIntOrFloatBitWidth();
1918
1919 Value src = adaptor.getSource();
1920 VectorType srcType = cast<VectorType>(src.getType());
1921 Type srcScalarType = srcType.getElementType();
1922 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
1923 int srcLanes = getVectorLaneSize(srcType);
1924 int srcVectorSize = srcBitWidth * srcLanes;
1925
1926 if (srcVectorSize != 512) {
1927 op.emitWarning() << "aievec.ext_elem conversion with source vector size "
1928 << srcVectorSize << " is not supported.\n";
1929 return failure();
1930 }
1931
1932 // create constant for sign
1933 auto signCst = rewriter.create<LLVM::ConstantOp>(
1934 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1935
1936 // create xllvm intrinsic
1937 Value extElemOp = nullptr;
1938 SmallVector<Value> operands(
1939 {adaptor.getSource(), adaptor.getIndex(), signCst});
1940 if (resultBitWidth == 8) {
1941 extElemOp = rewriter.create<xllvm::VectorExtractElem8I512IntrOp>(
1942 loc, rewriter.getI32Type(),
1943 forceCastOperandsToSignature(
1944 rewriter, loc, operands,
1945 {VectorType::get({64}, rewriter.getI8Type()),
1946 rewriter.getI32Type(), rewriter.getI32Type()}));
1947 } else if (resultBitWidth == 16) {
1948 extElemOp = rewriter.create<xllvm::VectorExtractElem16I512IntrOp>(
1949 loc, rewriter.getI32Type(),
1950 forceCastOperandsToSignature(
1951 rewriter, loc, operands,
1952 {VectorType::get({32}, rewriter.getI16Type()),
1953 rewriter.getI32Type(), rewriter.getI32Type()}));
1954 } else if (resultBitWidth == 32) {
1955 extElemOp = rewriter.create<xllvm::VectorExtractElem32I512IntrOp>(
1956 loc, rewriter.getI32Type(),
1957 forceCastOperandsToSignature(
1958 rewriter, loc, operands,
1959 {VectorType::get({16}, rewriter.getI32Type()),
1960 rewriter.getI32Type(), rewriter.getI32Type()}));
1961 } else {
1962 op.emitWarning() << "aievec.ext_elem conversion with result bit width "
1963 << resultBitWidth << " is not implemented.\n";
1964 return failure();
1965 }
1966
1967 // create truncation op (and bitcast op)
1968 if (llvm::isa<IntegerType>(resultType)) {
1969 if (resultBitWidth < 32) {
1970 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType, extElemOp);
1971 } else {
1972 rewriter.replaceOp(op, extElemOp);
1973 }
1974 } else {
1975 // Float types
1976 if (resultBitWidth == 16) {
1977 extElemOp = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI16Type(),
1978 extElemOp);
1979 }
1980 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, resultType, extElemOp);
1981 }
1982
1983 return success();
1984 }
1985};
1986
1988 : public mlir::ConvertOpToLLVMPattern<aievec::FMAElemOp> {
1989public:
1990 using ConvertOpToLLVMPattern<aievec::FMAElemOp>::ConvertOpToLLVMPattern;
1991
1992 LogicalResult
1993 matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor,
1994 ConversionPatternRewriter &rewriter) const override {
1995 auto loc = fmaOp.getLoc();
1996 auto lhs = adaptor.getLhs();
1997 auto rhs = adaptor.getRhs();
1998 auto acc = adaptor.getAcc();
1999 auto lhsTy = cast<VectorType>(lhs.getType());
2000 auto rhsTy = cast<VectorType>(rhs.getType());
2001 auto accTy = cast<VectorType>(acc.getType());
2002 auto flatLhsTy = getFlattenedVectorType(lhsTy);
2003 auto flatRhsTy = getFlattenedVectorType(rhsTy);
2004 auto flatAccTy = getFlattenedVectorType(accTy);
2005
2006 // Flatten operands, if needed
2007 if (lhsTy != flatLhsTy)
2008 lhs = rewriter.create<vector::ShapeCastOp>(loc, flatLhsTy, lhs);
2009 if (rhsTy != flatRhsTy)
2010 rhs = rewriter.create<vector::ShapeCastOp>(loc, flatRhsTy, rhs);
2011 if (accTy != flatAccTy)
2012 acc = rewriter.create<vector::ShapeCastOp>(loc, flatAccTy, acc);
2013
2014 // Build vmac configuration constant
2015 Type i32ty = rewriter.getI32Type();
2016 auto confCst = rewriter.create<LLVM::ConstantOp>(
2017 loc, i32ty,
2018 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
2019 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
2020 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
2021 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2022 /*sub_mask=*/0)));
2023
2024 // Insert vmac intrinsic
2025 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
2026 auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type());
2027 auto macIntrOp = rewriter.create<xllvm::MacConfBF16IntrOp>(
2028 loc, v8i64Ty,
2029 forceCastOperandsToSignature(rewriter, loc, {lhs, rhs, acc, confCst},
2030 {v32bf16Ty, v32bf16Ty, v8i64Ty, i32ty}));
2031
2032 // Recast/Reshape result
2033 auto resVal =
2034 forceCastValueToType(rewriter, loc, macIntrOp.getResult(), flatAccTy);
2035 if (flatAccTy != accTy)
2036 resVal = rewriter.create<vector::ShapeCastOp>(loc, accTy, resVal);
2037
2038 rewriter.replaceOp(fmaOp, resVal);
2039 return success();
2040 }
2041};
2042
2044 : public mlir::ConvertOpToLLVMPattern<aievec::MatMulOp> {
2045 using ConvertOpToLLVMPattern<aievec::MatMulOp>::ConvertOpToLLVMPattern;
2046
2047 struct DecodedMatMulOp {
2048 typedef enum { I32, I64, BF16 } Kind;
2049
2050 Kind kind;
2051 Value lhs;
2052 Value rhs;
2053 Value acc;
2054 int conf;
2055 };
2056
2057 static DecodedMatMulOp decodeMatMulOp(OpAdaptor op) {
2058 Value lhs = op.getLhs();
2059 Value rhs = op.getRhs();
2060 Value acc = op.getAcc();
2061 auto accVecTy = cast<VectorType>(acc.getType());
2062 if (isa<Float32Type>(accVecTy.getElementType()))
2063 // <4x8xbf16> x <8x4xbf16> + <4x4xf32>
2064 return {DecodedMatMulOp::Kind::BF16, lhs, rhs, acc,
2065 aiev2_vmac_compute_control(
2066 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
2067 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2068 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2069 /*sub_mask=*/0)};
2070
2071 int signX = 0, signY = 0;
2072 auto lhsVecTy = cast<VectorType>(lhs.getType());
2073 auto lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
2074 if (auto extSIOp = lhs.getDefiningOp<arith::ExtSIOp>()) {
2075 lhs = extSIOp.getIn();
2076 lhsVecTy = cast<VectorType>(lhs.getType());
2077 lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
2078 signX = 1;
2079 } else if (auto extUIOp = lhs.getDefiningOp<arith::ExtUIOp>()) {
2080 lhs = extUIOp.getIn();
2081 lhsVecTy = cast<VectorType>(lhs.getType());
2082 lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
2083 } else {
2084 // NOTE: We're choosing 'signed' by default
2085 if (!lhsScaTy.isUnsigned())
2086 signX = 1;
2087 }
2088 auto lhsShape = lhsVecTy.getShape();
2089
2090 auto rhsVecTy = cast<VectorType>(rhs.getType());
2091 auto rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
2092 if (auto extSIOp = rhs.getDefiningOp<arith::ExtSIOp>()) {
2093 rhs = extSIOp.getIn();
2094 rhsVecTy = cast<VectorType>(rhs.getType());
2095 rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
2096 signY = 1;
2097 } else if (auto extUIOp = rhs.getDefiningOp<arith::ExtUIOp>()) {
2098 rhs = extUIOp.getIn();
2099 rhsVecTy = cast<VectorType>(rhs.getType());
2100 rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
2101 } else {
2102 // NOTE: We're choosing 'signed' by default
2103 if (!rhsScaTy.isUnsigned())
2104 signY = 1;
2105 }
2106
2107 unsigned lhsBitWidth = lhsScaTy.getWidth();
2108 unsigned rhsBitWidth = rhsScaTy.getWidth();
2109 auto accScaTy = cast<IntegerType>(accVecTy.getElementType());
2110 unsigned accBitWidth = accScaTy.getWidth();
2111 if (accBitWidth == 32) {
2112 if (lhsBitWidth == 8) {
2113 if (rhsBitWidth == 4) {
2114 // <4x16xi8> x <16x8xi4> + <4x8xi32>
2115 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
2116 aiev2_vmac_compute_control(
2117 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
2118 /*bmode=*/0,
2119 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2120 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2121 /*sub_mask=*/0)};
2122 } else {
2123 // <4x8xi8> x <8x8xi8> + <4x8xi32>
2124 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
2125 aiev2_vmac_compute_control(
2126 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
2127 /*bmode=*/1,
2128 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2129 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2130 /*sub_mask=*/0)};
2131 }
2132 } else {
2133 if (rhsBitWidth == 8) {
2134 // <4x4xi16> x <4x8xi8> + <4x8xi32>
2135 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
2136 aiev2_vmac_compute_control(
2137 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
2138 /*bmode=*/2,
2139 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2140 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2141 /*sub_mask=*/0)};
2142 } else {
2143 // <4x2xi16> x <2x8xi16> + <4x8xi32>
2144 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
2145 aiev2_vmac_compute_control(
2146 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
2147 /*bmode=*/3,
2148 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2149 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2150 /*sub_mask=*/0)};
2151 }
2152 }
2153 }
2154
2155 if (lhsBitWidth == 16) {
2156 if (rhsBitWidth == 8) {
2157 if (lhsShape == ArrayRef<int64_t>({2, 8})) {
2158 // <2x8xi16> x <8x8xi8> + <2x8xi64>
2159 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2160 aiev2_vmac_compute_control(
2161 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1,
2162 /*bmode=*/2,
2163 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2164 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2165 /*sub_mask=*/0)};
2166 }
2167 // <4x8xi16> x <8x4xi8> + <4x4xi64>
2168 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2169 aiev2_vmac_compute_control(
2170 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/2,
2171 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
2172 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2173 /*sub_mask=*/0)};
2174 }
2175 if (lhsShape == ArrayRef<int64_t>({2, 4})) {
2176 // <2x4xi16> x <4x8xi16> + <2x8xi64>
2177 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2178 aiev2_vmac_compute_control(
2179 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/3,
2180 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2181 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2182 /*sub_mask=*/0)};
2183 }
2184 // <4x4xi16> x <4x4xi16> + <4x4xi64>
2185 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2186 aiev2_vmac_compute_control(
2187 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/3,
2188 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
2189 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2190 /*sub_mask=*/0)};
2191 }
2192 // <4x2xi32> x <2x4xi16> + <4x4xi64>
2193 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2194 aiev2_vmac_compute_control(
2195 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/0,
2196 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2197 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2198 /*sub_mask=*/0)};
2199 }
2200
2201 LogicalResult
2202 matchAndRewrite(aievec::MatMulOp op, OpAdaptor adaptor,
2203 ConversionPatternRewriter &rewriter) const override {
2204 auto decodedMatMulOp = decodeMatMulOp(adaptor);
2205
2206 Location loc = op.getLoc();
2207 // Flatten the inputs
2208 auto lhsFlattenedVecTy =
2209 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.lhs.getType()));
2210 decodedMatMulOp.lhs = rewriter.create<vector::ShapeCastOp>(
2211 loc, lhsFlattenedVecTy, decodedMatMulOp.lhs);
2212 auto rhsFlattenedVecTy =
2213 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.rhs.getType()));
2214 decodedMatMulOp.rhs = rewriter.create<vector::ShapeCastOp>(
2215 loc, rhsFlattenedVecTy, decodedMatMulOp.rhs);
2216 auto accFlattenedVecTy =
2217 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.acc.getType()));
2218 decodedMatMulOp.acc = rewriter.create<vector::ShapeCastOp>(
2219 loc, accFlattenedVecTy, decodedMatMulOp.acc);
2220
2221 Type i32ty = rewriter.getI32Type();
2222 auto confCst = rewriter.create<LLVM::ConstantOp>(
2223 loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf));
2224 SmallVector<Value> operands({decodedMatMulOp.lhs, decodedMatMulOp.rhs,
2225 decodedMatMulOp.acc, confCst});
2226 Value matMulResVal;
2227 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::BF16)
2228 matMulResVal =
2229 rewriter
2230 .create<xllvm::MacConfBF16IntrOp>(
2231 loc, VectorType::get({8}, rewriter.getI64Type()),
2232 forceCastOperandsToSignature(
2233 rewriter, loc, operands,
2234 {VectorType::get({32}, rewriter.getBF16Type()),
2235 VectorType::get({32}, rewriter.getBF16Type()),
2236 VectorType::get({8}, rewriter.getI64Type()), i32ty}))
2237 .getResult();
2238 else {
2239 SmallVector<Type> intrFuncSig(
2240 {VectorType::get({64}, rewriter.getI8Type()),
2241 VectorType::get({16}, i32ty),
2242 VectorType::get({16}, rewriter.getI64Type()), i32ty});
2243 VectorType v16xi64ty = VectorType::get({16}, rewriter.getI64Type());
2244 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I32)
2245 matMulResVal = rewriter
2246 .create<xllvm::MacConfAcc32IntrOp>(
2247 loc, v16xi64ty,
2248 forceCastOperandsToSignature(
2249 rewriter, loc, operands, intrFuncSig))
2250 .getResult();
2251 else
2252 matMulResVal = rewriter
2253 .create<xllvm::MacConfAcc64IntrOp>(
2254 loc, v16xi64ty,
2255 forceCastOperandsToSignature(
2256 rewriter, loc, operands, intrFuncSig))
2257 .getResult();
2258 }
2259
2260 auto castFromAcc =
2261 bitcastValueToType(rewriter, loc, matMulResVal, accFlattenedVecTy);
2262
2263 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
2264 castFromAcc);
2265
2266 return success();
2267 }
2268};
2269
2270// This pattern folds aievec.cast op. For AIE2, the accumulators are in 32/64
2271// bits, and the vectors are in 4/8/16/32 bits. Hence, we don't have to
2272// explicitly express the casting between accumulators and vectors at the LLVM
2273// dialect level. The backend LLVM compiler will decide the correct accumulator
2274// or vector registers given the ops and intrinsics.
2275class FoldAIECastOps : public mlir::ConvertOpToLLVMPattern<aievec::CastOp> {
2276 using ConvertOpToLLVMPattern<aievec::CastOp>::ConvertOpToLLVMPattern;
2277
2278 LogicalResult
2279 matchAndRewrite(aievec::CastOp castOp, OpAdaptor adaptor,
2280 ConversionPatternRewriter &rewriter) const override {
2281 rewriter.replaceOp(castOp, adaptor.getSource());
2282 return success();
2283 }
2284};
2285
2287 : public mlir::ConvertOpToLLVMPattern<aievec::ShuffleOp> {
2288 using ConvertOpToLLVMPattern<aievec::ShuffleOp>::ConvertOpToLLVMPattern;
2289
2290 LogicalResult
2291 matchAndRewrite(aievec::ShuffleOp shuffleOp, OpAdaptor adaptor,
2292 ConversionPatternRewriter &rewriter) const override {
2293 auto loc = shuffleOp.getLoc();
2294 auto lhs = adaptor.getLhs();
2295 auto rhs = adaptor.getRhs();
2296 auto i32ty = rewriter.getI32Type();
2297 auto v16xi32ty = VectorType::get({16}, i32ty);
2298 if (!rhs)
2299 rhs = rewriter.create<xllvm::UndefV16I32IntrOp>(loc, v16xi32ty);
2300
2301 auto modeAttrVal =
2302 rewriter
2303 .create<LLVM::ConstantOp>(loc, i32ty,
2304 static_cast<int32_t>(shuffleOp.getMode()))
2305 .getResult();
2306 auto vShuffleVal = rewriter
2307 .create<xllvm::VectorShuffleIntrOp>(
2308 loc, v16xi32ty,
2309 forceCastOperandsToSignature(
2310 rewriter, loc,
2311 /*operands=*/{lhs, rhs, modeAttrVal},
2312 /*signature=*/{v16xi32ty, v16xi32ty, i32ty}))
2313 .getResult();
2314
2315 vShuffleVal = forceCastValueToType(rewriter, loc, vShuffleVal,
2316 shuffleOp.getResult().getType());
2317
2318 rewriter.replaceOp(shuffleOp, vShuffleVal);
2319
2320 return success();
2321 }
2322};
2323
2325 mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
2326 Aie2Fp32Emulation aie2Fp32EmulationOption) {
2327 // clang-format off
2328 patterns.add<AddOpConversion,
2349 ShuffleOpConversion>(converter);
2350 patterns.add<MulElemOpConversion>(converter, aie2Fp32EmulationOption);
2351 // clang-format on
2352}
2353
2355 : ConvertAIEVecToLLVMBase<ConvertAIEVecToLLVMPass> {
2356 void runOnOperation() override {
2357 RewritePatternSet patterns(&getContext());
2358 LLVMTypeConverter converter(&getContext());
2359
2360 // Don't convert vector types, we want to handle multi-dimensional
2361 // vector on our own.
2362 converter.addConversion(
2363 [&](VectorType type) -> std::optional<Type> { return type; });
2364
2365 populateAIEVecToLLVMConversionPatterns(converter, patterns,
2366 aie2Fp32Emulation);
2367
2368 LLVMConversionTarget target(getContext());
2369 target.addIllegalDialect<xilinx::aievec::AIEVecDialect,
2370 xilinx::aievec::aie1::AIEVecAIE1Dialect>();
2371 target.addLegalDialect<arith::ArithDialect, vector::VectorDialect,
2372 xilinx::xllvm::XLLVMDialect, ub::UBDialect>();
2373 if (failed(applyPartialConversion(getOperation(), target,
2374 std::move(patterns))))
2375 signalPassFailure();
2376 }
2377};
2378
2379std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
2381 return std::make_unique<ConvertAIEVecToLLVMPass>();
2382}
2383
2384} // namespace xilinx::aievec
LogicalResult matchAndRewrite(aievec::aie1::AddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult convertToEmulatedFP32MulElem(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult convertToEmulatedI32MulElem(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
MulElemOpConversion(const LLVMTypeConverter &typeConverter, Aie2Fp32Emulation aie2Fp32EmulationOption)
LogicalResult matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedMulElemOp decodeMulElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::aie1::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::PackOp op)
LogicalResult matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::aie1::SelectOp op)
LogicalResult matchAndRewrite(aievec::aie1::SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::SubOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::UPDOp op, int loadSize)
LogicalResult matchAndRewrite(aievec::UPDOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UnpackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
int32_t getVectorSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:66
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55
uint32_t encodeSquare(uint32_t square)
void encodeConf(uint32_t conf[2], const BufferParams &x, const BufferParams &z, bool sub)
void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns)
std::string getMulOrFMAIntrinsicName(Operation *op)
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createConvertAIEVecToLLVMPass()
std::string getVectorTypeString(VectorType type, bool abbrev=false, bool acc=false)