MLIR-AIE
AIEVecToLLVM.cpp
Go to the documentation of this file.
1//===- AIEVecToLLVM.cpp - AIEVec to LLVM dialect conversion ---------------===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2022 Xilinx Inc.
8// (c) Copyright 2024 Advanced Micro Devices Inc.
9//
10//===----------------------------------------------------------------------===//
11
12#include "../PassDetail.h"
13
20#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21#include "mlir/Conversion/LLVMCommon/Pattern.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/IR/TypeUtilities.h"
24#include <sstream>
25
26using namespace mlir;
27
28namespace xilinx::aievec {
29
30inline static Value bitcastValueToType(OpBuilder &builder, Location loc,
31 Value val, Type dstTy) {
32 return builder.create<LLVM::BitcastOp>(loc, dstTy, val).getResult();
33}
34
35// This function emits the instructions required to widen a 128b input vector
36// into a 512b encoded as a vector<16xi32>. It first bitcasts it to a
37// vector<4xi32> to respect the intrinsic signature.
38inline static Value widen128bVectorValueTo512b(OpBuilder &builder, Location loc,
39 Value val) {
40 return builder
41 .create<xllvm::VectorSetI512I128IntrOp>(
42 loc, VectorType::get({16}, builder.getI32Type()),
43 bitcastValueToType(builder, loc, val,
44 VectorType::get({4}, builder.getI32Type())))
45 .getResult();
46}
47
48// This function emits the instructions required to widen a 256b input vector
49// into a 512b encoded as a vector<16xi32>. It first bitcasts it to a
50// vector<8xi32> to respect the intrinsic signature. It will also materialize
51// a constant 0, used as an insertion index.
52inline static Value widen256bVectorValueTo512b(OpBuilder &builder, Location loc,
53 Value val) {
54 auto cst0 =
55 builder.create<LLVM::ConstantOp>(loc, builder.getI32Type(), (int32_t)0);
56 return builder
57 .create<xllvm::VectorSetI512I256IntrOp>(
58 loc, VectorType::get({16}, builder.getI32Type()),
59 bitcastValueToType(builder, loc, val,
60 VectorType::get({8}, builder.getI32Type())),
61 cst0)
62 .getResult();
63}
64
65// This function emits the sequence of operations that forces a value into a
66// specific type. This may include widening vectors to match a specific bit
67// length.
68static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
69 Type type) {
70 auto valTy = val.getType();
71 if (valTy == type)
72 return val;
73 auto srcVecTy = dyn_cast<VectorType>(valTy);
74 if (srcVecTy) {
75 auto dstVecTy = dyn_cast<VectorType>(type);
76 assert(dstVecTy && "vector values cannot be forced into a non-vector type");
77 assert(srcVecTy.getRank() == 1 && dstVecTy.getRank() == 1 &&
78 "only flat 1D vectors can be force casted");
79 int64_t dstVecLength =
80 dstVecTy.getElementTypeBitWidth() * dstVecTy.getShape()[0];
81 int64_t srcVecLength =
82 srcVecTy.getElementTypeBitWidth() * srcVecTy.getShape()[0];
83 if (srcVecLength != dstVecLength) {
84 assert(srcVecLength < dstVecLength &&
85 "only widening forced casts are supported");
86 assert(dstVecLength == 512 &&
87 (srcVecLength == 128 || srcVecLength == 256) &&
88 "only 128b to 512b and 256b to 512b forced casts are supported");
89 if (srcVecLength == 128)
90 val = widen128bVectorValueTo512b(builder, loc, val);
91 else
92 val = widen256bVectorValueTo512b(builder, loc, val);
93 }
94 }
95 return bitcastValueToType(builder, loc, val, type);
96}
97
98// This function emits the sequence of operations that forces a range of values
99// to match the signature specified by the TypeRange. It can be used to convert
100// the parameters of an op being converted to the types accepted by an
101// intrinsic with a fixed signature that treats its inputs as "bags of bits".
102static SmallVector<Value> forceCastOperandsToSignature(OpBuilder &builder,
103 Location loc,
104 ValueRange operands,
105 TypeRange signature) {
106 return llvm::to_vector(llvm::map_range(
107 llvm::zip_equal(operands, signature), [&](auto &&vt) -> Value {
108 return forceCastValueToType(builder, loc, std::get<0>(vt),
109 std::get<1>(vt));
110 }));
111}
112
114 uint32_t start;
115 uint32_t offsets;
116 uint32_t offsets_hi;
117 uint32_t step;
118 uint32_t square;
119};
120
// Packs the vmac configuration fields into a single control word.
//
// Bit layout produced (field value shifted left by the given amount):
//   zero_acc : bit 0   - if one, acc1 is zeroed.
//   amode    : bit 1   \
//   bmode    : bit 3    > config acc width, mul precision, and mul mode.
//   variant  : bit 5   /
//   sgn_y    : bit 8   - if one, matrix Y is interpreted as signed, else it
//                        is treated as unsigned.
//   sgn_x    : bit 9   - if one, matrix X is interpreted as signed, else it
//                        is treated as unsigned.
//   shift16  : bit 10  - if set, the <<16 operation is executed on acc1.
//   sub_mul  : bit 11  - if one, the multiplication result is negated.
//   sub_acc1 : bit 12  - if one, acc1 is negated.
//   sub_acc2 : bit 13  - if one, acc2 is negated.
//   sub_mask : bit 16  - negation mask of complex multiplications; negates a
//                        term of a complex multiplication.
static inline int aiev2_vmac_compute_control(int sgn_x, int sgn_y, int amode,
                                             int bmode, int variant,
                                             int zero_acc, int shift16,
                                             int sub_mul, int sub_acc1,
                                             int sub_acc2, int sub_mask) {
  // Shifts are done on unsigned values, exactly like the field-by-field
  // expression this replaces.
  const auto place = [](int value, unsigned bit) -> unsigned {
    return static_cast<unsigned>(value) << bit;
  };

  unsigned control = place(zero_acc, 0);
  control |= place(amode, 1);
  control |= place(bmode, 3);
  control |= place(variant, 5);
  control |= place(sgn_y, 8);
  control |= place(sgn_x, 9);
  control |= place(shift16, 10);
  control |= place(sub_mul, 11);
  control |= place(sub_acc1, 12);
  control |= place(sub_acc2, 13);
  control |= place(sub_mask, 16);
  return control;
}
147
148std::string getVectorTypeString(VectorType type, bool abbrev = false,
149 bool acc = false) {
150 std::stringstream ss;
151 auto size = getVectorLaneSize(type);
152 ss << "v" << size;
153 if (auto intType = dyn_cast<IntegerType>(type.getElementType())) {
154 ss << (acc ? "acc" : abbrev ? "i" : "int") << intType.getWidth();
155 } else if (dyn_cast<FloatType>(type.getElementType())) {
156 ss << (abbrev ? "f" : "float");
157 }
158 return ss.str();
159}
160
161std::string getMulOrFMAIntrinsicName(Operation *op) {
162 std::string baseName;
163 Value lhs, result;
164 if (auto mulOp = dyn_cast<aievec::aie1::MulOp>(op)) {
165 baseName = "mul";
166 lhs = mulOp.getLhs();
167 result = mulOp.getResult();
168 } else if (auto fmaOp = dyn_cast<aievec::aie1::FMAOp>(op)) {
169 baseName = "mac";
170 lhs = fmaOp.getLhs();
171 result = fmaOp.getResult();
172 }
173 VectorType resultType = cast<VectorType>(result.getType());
174 int resultSize = getVectorLaneSize(resultType);
175 std::stringstream ss;
176 ss << "llvm.aie.";
177 if (dyn_cast<IntegerType>(resultType.getElementType())) {
178 ss << baseName;
179 ss << resultSize << "."
180 << getVectorTypeString(cast<VectorType>(lhs.getType()));
181 } else if (dyn_cast<FloatType>(resultType.getElementType())) {
182 ss << "vfp" << baseName;
183 }
184 return ss.str();
185}
186
// Squashes the easy-to-read 16-bit square encoding into the 8-bit encoding
// the configuration register uses: the low two bits of each of the four
// 4-bit groups of `square` are packed next to each other, group 0 ending up
// in bits [1:0] and group 3 in bits [7:6].
uint32_t encodeSquare(uint32_t square) {
  uint32_t packed = 0;
  for (unsigned group = 0; group < 4; ++group) {
    const uint32_t twoBits = (square >> (4 * group)) & 0x3;
    packed |= twoBits << (2 * group);
  }
  return packed & 0xFF;
}
197
198// Encode the configuration register with buffer parameters and options
199// TODO: struct to handle this?
200void encodeConf(uint32_t conf[2], const BufferParams &x, const BufferParams &z,
201 bool sub) {
202 conf[0] |= ((x.step & 0x3F) << 0) | ((z.step & 0x3F) << 8);
203 conf[1] |= (encodeSquare(x.square) << 0) | (encodeSquare(z.square) << 8);
204 conf[1] |= sub << 17;
205}
206
208 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::AddOp> {
209public:
210 using ConvertOpToLLVMPattern<aievec::aie1::AddOp>::ConvertOpToLLVMPattern;
211
212 LogicalResult
213 matchAndRewrite(aievec::aie1::AddOp op, OpAdaptor adaptor,
214 ConversionPatternRewriter &rewriter) const override {
215 op.emitWarning() << "aie.add conversion is not implemented\n";
216 return failure();
217 }
218};
219
221 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::SubOp> {
222public:
223 using ConvertOpToLLVMPattern<aievec::aie1::SubOp>::ConvertOpToLLVMPattern;
224
225 LogicalResult
226 matchAndRewrite(aievec::aie1::SubOp op, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const override {
228 op.emitWarning() << "aie.sub conversion is not implemented\n";
229 return failure();
230 }
231};
232
234 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::FMAOp> {
235public:
236 using ConvertOpToLLVMPattern<aievec::aie1::FMAOp>::ConvertOpToLLVMPattern;
237
238 LogicalResult
239 matchAndRewrite(aievec::aie1::FMAOp op, OpAdaptor adaptor,
240 ConversionPatternRewriter &rewriter) const override {
241 auto module = op->getParentOfType<ModuleOp>();
242 MLIRContext *context = rewriter.getContext();
243
244 auto startType = IntegerType::get(context, 32);
245 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
246 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
247
248 // If the intrinsic declaration doesn't exist, create it
249 std::string intrinsicName = getMulOrFMAIntrinsicName(op);
250 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
251 StringAttr::get(context, intrinsicName));
252
253 if (!func) {
254 OpBuilder::InsertionGuard guard(rewriter);
255 rewriter.setInsertionPointToStart(module.getBody());
256 func = rewriter.create<LLVM::LLVMFuncOp>(
257 rewriter.getUnknownLoc(), intrinsicName,
258 LLVM::LLVMFunctionType::get(
259 op.getResult().getType(),
260 {op.getLhs().getType(), op.getRhs().getType(),
261 op.getAcc().getType(), startType, /* xstart */
262 startType, /* ystart */
263 startType, /* zstart */
264 offsetsType, /* xoffsets */
265 offsetsType, /* zoffsets */
266 confType}));
267 }
268
269 // Parse the string attribute values
270 BufferParams x = {};
271 BufferParams z = {};
272 op.getXstart().getAsInteger(0, x.start);
273 op.getXoffsets().getAsInteger(0, x.offsets);
274 op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
275 op.getXstep().getAsInteger(0, x.step);
276 op.getXsquare().getAsInteger(0, x.square);
277 op.getZstart().getAsInteger(0, z.start);
278 op.getZoffsets().getAsInteger(0, z.offsets);
279 op.getZoffsetsHi().getAsInteger(0, z.offsets_hi);
280 op.getZstep().getAsInteger(0, z.step);
281 op.getZsquare().getAsInteger(0, z.square);
282
283 // Encode the configuration register
284 uint32_t conf[2] = {0, 0};
285 encodeConf(conf, x, z, op.getFmsub());
286
287 // Create the constants and replace the op
288 auto xstartVal = rewriter.create<LLVM::ConstantOp>(
289 op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
290 auto ystartVal = rewriter.create<LLVM::ConstantOp>(
291 op->getLoc(), startType, rewriter.getI32IntegerAttr(0));
292 auto zstartVal = rewriter.create<LLVM::ConstantOp>(
293 op->getLoc(), startType, rewriter.getI32IntegerAttr(z.start));
294 auto xoffsetsVal = rewriter.create<LLVM::ConstantOp>(
295 op->getLoc(), offsetsType,
296 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
297 auto zoffsetsVal = rewriter.create<LLVM::ConstantOp>(
298 op->getLoc(), offsetsType,
299 rewriter.getI32VectorAttr({(int32_t)z.offsets, (int32_t)z.offsets_hi}));
300 auto confVal = rewriter.create<LLVM::ConstantOp>(
301 op->getLoc(), confType,
302 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
303 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
304 op, func,
305 ValueRange{op.getLhs(), op.getRhs(), op.getAcc(), xstartVal, ystartVal,
306 zstartVal, xoffsetsVal, zoffsetsVal, confVal});
307 return success();
308 }
309};
310
312 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::MulOp> {
313public:
314 using ConvertOpToLLVMPattern<aievec::aie1::MulOp>::ConvertOpToLLVMPattern;
315
316 LogicalResult
317 matchAndRewrite(aievec::aie1::MulOp op, OpAdaptor adaptor,
318 ConversionPatternRewriter &rewriter) const override {
319 auto module = op->getParentOfType<ModuleOp>();
320 MLIRContext *context = rewriter.getContext();
321
322 auto startType = IntegerType::get(context, 32);
323 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
324 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
325
326 // If the intrinsic declaration doesn't exist, create it
327 std::string intrinsicName = getMulOrFMAIntrinsicName(op);
328 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
329 StringAttr::get(context, intrinsicName));
330
331 if (!func) {
332 OpBuilder::InsertionGuard guard(rewriter);
333 rewriter.setInsertionPointToStart(module.getBody());
334 func = rewriter.create<LLVM::LLVMFuncOp>(
335 rewriter.getUnknownLoc(), intrinsicName,
336 LLVM::LLVMFunctionType::get(op.getResult().getType(),
337 {op.getLhs().getType(),
338 op.getRhs().getType(),
339 startType, /* xstart */
340 startType, /* ystart */
341 startType, /* zstart */
342 offsetsType, /* xoffsets */
343 offsetsType, /* zoffsets */
344 confType}));
345 }
346
347 // Parse the string attribute values
348 BufferParams x = {};
349 BufferParams z = {};
350 op.getXstart().getAsInteger(0, x.start);
351 op.getXoffsets().getAsInteger(0, x.offsets);
352 op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
353 op.getXstep().getAsInteger(0, x.step);
354 op.getXsquare().getAsInteger(0, x.square);
355 op.getZstart().getAsInteger(0, z.start);
356 op.getZoffsets().getAsInteger(0, z.offsets);
357 op.getZoffsetsHi().getAsInteger(0, z.offsets_hi);
358 op.getZstep().getAsInteger(0, z.step);
359 op.getZsquare().getAsInteger(0, z.square);
360
361 // Encode the configuration register
362 uint32_t conf[2] = {0, 0};
363 encodeConf(conf, x, z, false);
364
365 // Create the constants and replace the op
366 auto xstartVal = rewriter.create<LLVM::ConstantOp>(
367 op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
368 auto ystartVal = rewriter.create<LLVM::ConstantOp>(
369 op->getLoc(), startType, rewriter.getI32IntegerAttr(0));
370 auto zstartVal = rewriter.create<LLVM::ConstantOp>(
371 op->getLoc(), startType, rewriter.getI32IntegerAttr(z.start));
372 auto xoffsetsVal = rewriter.create<LLVM::ConstantOp>(
373 op->getLoc(), offsetsType,
374 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
375 auto zoffsetsVal = rewriter.create<LLVM::ConstantOp>(
376 op->getLoc(), offsetsType,
377 rewriter.getI32VectorAttr({(int32_t)z.offsets, (int32_t)z.offsets_hi}));
378 auto confVal = rewriter.create<LLVM::ConstantOp>(
379 op->getLoc(), confType,
380 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
381 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
382 op, func,
383 ValueRange{op.getLhs(), op.getRhs(), xstartVal, ystartVal, zstartVal,
384 xoffsetsVal, zoffsetsVal, confVal});
385 return success();
386 }
387};
388
390 : public mlir::ConvertOpToLLVMPattern<aievec::MulElemOp> {
391public:
392 using ConvertOpToLLVMPattern<aievec::MulElemOp>::ConvertOpToLLVMPattern;
393
394 MulElemOpConversion(const LLVMTypeConverter &typeConverter,
395 Aie2Fp32Emulation aie2Fp32EmulationOption)
396 : ConvertOpToLLVMPattern(typeConverter),
398
399 Aie2Fp32Emulation aie2Fp32EmulationOption;
400
402 enum class Kind {
403 // DtIn0_DtIn1_DtRes_CxMxKxN
410 // TODO: I16_I16_I64_16x1x2x1
411 };
412
414 int conf;
415 };
416
// Classifies a mul_elem op from the element type of its lhs operand: the
// scalar type (integer vs. anything else, handled as float) and its bit width
// select which `aiev2_vmac_compute_control(...)` configuration is computed.
// Integer widths 8, 16 and 32 and float widths 16 and 32 have dedicated
// branches; the 32-bit branches are marked as emulated.
// NOTE(review): the embedded listing numbers skip 426, 433, 441, 446, 454 and
// 458 in this function, so the statements that consume the computed
// configuration values and the final fall-through statement are not visible
// in this view -- confirm against the complete file before editing.
417 static DecodedMulElemOp decodeMulElemOp(OpAdaptor op) {
418 auto lhs = op.getLhs();
419 auto lhsVecTy = cast<VectorType>(lhs.getType());
420 auto lhsScaTy = lhsVecTy.getElementType();
421 unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
422
423 // Integer types
424 if (llvm::isa<IntegerType>(lhsScaTy)) {
425 if (lhsBitWidth == 8) {
427 aiev2_vmac_compute_control(
428 /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/1,
429 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
430 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
431 /*sub_mask=*/0)};
432 } else if (lhsBitWidth == 16) {
434 aiev2_vmac_compute_control(
435 /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/3,
436 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
437 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
438 /*sub_mask=*/0)};
439 } else if (lhsBitWidth == 32) {
440 // emulated I32 mul_elem
442 }
443 } else {
444 // Float types
445 if (lhsBitWidth == 16) {
447 aiev2_vmac_compute_control(
448 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
449 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
450 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
451 /*sub_mask=*/0)};
452 } else if (lhsBitWidth == 32) {
453 // emulated FP32 mul_elem
455 }
456 }
457
459 }
460
461 // This conversion pattern implements the below CPP emulated I32 mul_elem.
462 // INTRINSIC(v16acc64)
463 // mul_elem_16_2(v16int32 a0, v16int32 a1, v16int32 b0, v16int32 b1) {
464 // v32uint16 a_lo = (v32uint16)shuffle(a0, a1, 2);
465 // v32int16 a_hi = (v32int16)shuffle(a0, a1, 3);
466 // v32uint16 b_lo = (v32uint16)shuffle(b0, b1, 2);
467 // v32int16 b_hi = (v32int16)shuffle(b0, b1, 3);
468 // v16acc64 acc = ::mul_elem_16_2(a_hi, b_hi);
469 // acc = mac_elem_16_2_conf(a_hi, 1, b_lo, false, acc, 0, 1, 0, 0);
470 // acc = mac_elem_16_2_conf(a_lo, false, b_hi, 1, acc, 0, 0, 0, 0);
471 // acc = mac_elem_16_2_conf(a_lo, false, b_lo, false, acc, 0, 1, 0, 0);
472 // return acc;
473 // }
474 // Caller example when handling the elementwise mul of two v16int32 vectors.
475 // v16int32 v1 = LHS();
476 // v16int32 v2 = RHS();
477 // v16acc64 v3 = mul_elem_16_2(v1, broadcast_zero_s32(), v2,
478 // undef_v16int32());
479 // Explantion:
480 // a_lo = low_part(a0[0]--a0[15], a1[0]--a1[15])
481 // a_hi = high_part(a0[0]--a0[15], a1[0]--a1[15])
482 // b_lo = low_part(b0[0]--b0[15], b1[0]--b1[15])
483 // b_hi = high_part(b0[0]--b0[15], b1[0]--b1[15])
484 // The firt `acc` is from mul_elem_16_2(a_hi, b_hi), which performs 16 channel
485 // of 1x2x1 matmul, acc[0] = a_hi[0]*b_hi[0]+a_hi[16]*b_hi[16], ... , acc[15]
486 // = a_hi[15]*b_hi[15]+a_hi[31]*b_hi[31]. Then, the first MAC performs `acc`
487 // left shift 16bit, and then 16 channel of 1x2x1 matmul (a_hi, b_lo)
488 // accumulating to `acc`. The second MAC performs 16 channel of 1x2x1 matmul
489 // (a_lo, b_hi) accumulating to `acc`. Finally, the third MAC performs 16
490 // channel of 1x2x1 matmul (a_lo, b_hi) accumulating to `acc`.
491 LogicalResult
492 convertToEmulatedI32MulElem(aievec::MulElemOp op, OpAdaptor adaptor,
493 ConversionPatternRewriter &rewriter) const {
494
495 Location loc = op.getLoc();
496 auto zeroCst = rewriter.create<LLVM::ConstantOp>(
497 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
498 auto a0 = adaptor.getLhs();
499 auto a1 = rewriter.create<xllvm::VectorBroadcast32I512IntrOp>(
500 loc, VectorType::get({16}, rewriter.getI32Type()), zeroCst);
501 auto b0 = adaptor.getRhs();
502 auto b1 = rewriter.create<xllvm::UndefV16I32IntrOp>(
503 loc, VectorType::get({16}, rewriter.getI32Type()));
504
505 // 4* Shuffle
506 auto a_lo = rewriter.create<xllvm::VectorShuffleIntrOp>(
507 loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
508 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
509 rewriter.getI32IntegerAttr(2)));
510 auto a_hi = rewriter.create<xllvm::VectorShuffleIntrOp>(
511 loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
512 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
513 rewriter.getI32IntegerAttr(3)));
514 auto b_lo = rewriter.create<xllvm::VectorShuffleIntrOp>(
515 loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
516 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
517 rewriter.getI32IntegerAttr(2)));
518 auto b_hi = rewriter.create<xllvm::VectorShuffleIntrOp>(
519 loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
520 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
521 rewriter.getI32IntegerAttr(3)));
522 // MUL + 3 * MAC
523 auto mulConfCst = rewriter.create<LLVM::ConstantOp>(
524 loc, rewriter.getI32Type(),
525 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
526 /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
527 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
528 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));
529 auto mulConfOp = rewriter.create<xllvm::MulConfAcc64IntrOp>(
530 loc, VectorType::get({16}, rewriter.getI64Type()),
531 forceCastOperandsToSignature(
532 rewriter, loc,
533 /*operands=*/{a_hi, b_hi, mulConfCst},
534 /*signature=*/
535 {VectorType::get({64}, rewriter.getI8Type()),
536 VectorType::get({16}, rewriter.getI32Type()),
537 rewriter.getI32Type()}));
538
539 auto createMacConfOp = [&](SmallVector<Value> operands,
540 int macConf) -> Value {
541 operands.push_back(rewriter.create<LLVM::ConstantOp>(
542 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(macConf)));
543 return rewriter
544 .create<xllvm::MacConfAcc64IntrOp>(
545 loc, VectorType::get({16}, rewriter.getI64Type()),
546 forceCastOperandsToSignature(
547 rewriter, loc,
548 /*operands=*/operands,
549 /*signature=*/
550 {VectorType::get({64}, rewriter.getI8Type()),
551 VectorType::get({16}, rewriter.getI32Type()),
552 VectorType::get({16}, rewriter.getI64Type()),
553 rewriter.getI32Type()}))
554 .getResult();
555 };
556 auto acc64Val = mulConfOp.getResult();
557 acc64Val = createMacConfOp(
558 SmallVector<Value>{a_hi, b_lo, acc64Val},
559 aiev2_vmac_compute_control(
560 /*sgn_x=*/1, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
561 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
562 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
563 acc64Val = createMacConfOp(
564 SmallVector<Value>{a_lo, b_hi, acc64Val},
565 aiev2_vmac_compute_control(
566 /*sgn_x=*/0, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
567 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
568 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
569 acc64Val = createMacConfOp(
570 SmallVector<Value>{a_lo, b_lo, acc64Val},
571 aiev2_vmac_compute_control(
572 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
573 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
574 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
575
576 // create bitcast for result
577 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
578 acc64Val);
579 return success();
580 }
581
582 // This conversion pattern implements the below CPP emulated FP32 mul_elem.
583 // inline v16accfloat mul_elem_16_accuracy_safe(v16float v1, v16float v2) {
584 // v32bfloat16 a = broadcast_zero_to_v32bfloat16();
585 // v32bfloat16 b = broadcast_zero_to_v32bfloat16();
586 // v32bfloat16 c = broadcast_zero_to_v32bfloat16();
587 // v32bfloat16 d = broadcast_zero_to_v32bfloat16();
588 // v32bfloat16 e = broadcast_zero_to_v32bfloat16();
589 // v32bfloat16 f = broadcast_zero_to_v32bfloat16();
590 // v32bfloat16 dummy0 = broadcast_one_to_v32bfloat16();
591 // a = insert(a,0,to_v16bfloat16((v16accfloat)v1));
592 // v16accfloat acc0 = msc_elem_16_2(a, dummy0, (v16accfloat)v1);
593 // b = insert(b,0,to_v16bfloat16(acc0));
594 // c = insert(c,0,to_v16bfloat16(msc_elem_16_2(b, dummy0, acc0)));
595 // d = insert(d,0,to_v16bfloat16((v16accfloat)v2));
596 // v16accfloat acc1 = msc_elem_16_2(d, dummy0, (v16accfloat)v2);
597 // e = insert(e,0,to_v16bfloat16(acc1));
598 // f = insert(f,0,to_v16bfloat16(msc_elem_16_2(e, dummy0, acc1)));
599 // return
600 // mac_elem_16_2(a,d,mac_elem_16_2(a,e,mac_elem_16_2(b,d,mac_elem_16_2(
601 // d,c,mac_elem_16_2(b,e,mac_elem_16_2(a,f,mac_elem_16_2(
602 // b,f,mac_elem_16_2(c,e,mul_elem_16_2(c,f)))))))));
603 // }
604 // Caller example when handling the elementwise mul of two v16float vectors.
605 // v16float v1 = LHS(); v16float v2 = RHS();
606 // v16accfloat v3 = mul_elem_16(v1, v2);
607 // Explantion: For v32bfloat16 `a`, the first half v16bf16 contains `most
608 // significant 7 bits of mantissa` from v1, and the second half v16bf16 are
609 // zeros. For v16accfloat `acc0`, the MSC equals to "(original `v1` with 23
610 // bits of mantissa) - (`a` with MSB 7 bits of mantissa from v1)". For
611 // v32bfloat16 `b`, the first half v16bf16 contains `[7:13] bits of mantissa
612 // from v1` from v1, and the second half v16bf16 are zeros. For v32bfloat16
613 // `c`, the first half v16bf16 contains `[14:20] bits of mantissa from v1`
614 // from v1, and the second half v16bf16 are zeros. Hence, we can represent
615 // v16float in three v32bfloat16 and then perform 9 MUL/MAC in v32bfloat16 to
616 // get the final elementwise multiplication result.
617
618 LogicalResult
619 convertToEmulatedFP32MulElem(aievec::MulElemOp op, OpAdaptor adaptor,
620 ConversionPatternRewriter &rewriter) const {
621 Location loc = op.getLoc();
622 auto zeroCst = rewriter.create<LLVM::ConstantOp>(
623 loc, rewriter.getBF16Type(),
624 rewriter.getZeroAttr(rewriter.getBF16Type()));
625 auto aZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
626 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
627 auto bZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
628 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
629 auto cZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
630 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
631 auto dZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
632 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
633 auto eZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
634 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
635 auto fZeros = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
636 loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
637 auto oneCst = rewriter.create<LLVM::ConstantOp>(
638 loc, rewriter.getBF16Type(),
639 rewriter.getOneAttr(rewriter.getBF16Type()));
640 auto dummy0 = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
641 loc, VectorType::get({32}, rewriter.getBF16Type()), oneCst);
642 auto zeroCstI32 = rewriter.create<LLVM::ConstantOp>(
643 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
644 auto mscMacMulConfCst = rewriter.create<LLVM::ConstantOp>(
645 loc, rewriter.getI32Type(),
646 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
647 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
648 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
649 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));
650
651 auto extractV16FP32ToThreeV16BF16 =
652 [&](Value inputV16FP32, Value aZeros, Value bZeros,
653 Value cZeros) -> std::tuple<Value, Value, Value> {
654 // a = insert(a,0,to_v16bfloat16((v16accfloat)v1));
655 auto inputBitCasted =
656 forceCastValueToType(rewriter, loc, inputV16FP32,
657 VectorType::get({8}, rewriter.getI64Type()));
658 auto v1ToBF16 = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
659 loc, VectorType::get({16}, rewriter.getBF16Type()), inputBitCasted);
660 auto a = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
661 loc, VectorType::get({32}, rewriter.getBF16Type()), aZeros, v1ToBF16,
662 zeroCstI32);
663
664 // v16accfloat acc0 = msc_elem_16_2(a, dummy0, (v16accfloat)v1);
665 auto acc0 = rewriter.create<xllvm::MscConfBF16IntrOp>(
666 loc, VectorType::get({8}, rewriter.getI64Type()), a, dummy0,
667 inputBitCasted, mscMacMulConfCst);
668
669 // b = insert(b,0,to_v16bfloat16(acc0));
670 auto acc0ToBF16 = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
671 loc, VectorType::get({16}, rewriter.getBF16Type()), acc0);
672 auto b = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
673 loc, VectorType::get({32}, rewriter.getBF16Type()), bZeros,
674 acc0ToBF16, zeroCstI32);
675
676 // c = insert(c,0,to_v16bfloat16(msc_elem_16_2(b, dummy0, acc0)));
677 auto acc0Mscb = rewriter.create<xllvm::MscConfBF16IntrOp>(
678 loc, VectorType::get({8}, rewriter.getI64Type()), b, dummy0, acc0,
679 mscMacMulConfCst);
680 auto acc0MscbToBF16 =
681 rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
682 loc, VectorType::get({16}, rewriter.getBF16Type()), acc0Mscb);
683 auto c = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
684 loc, VectorType::get({32}, rewriter.getBF16Type()), cZeros,
685 acc0MscbToBF16, zeroCstI32);
686 return std::make_tuple(a.getResult(), b.getResult(), c.getResult());
687 };
688
689 // Get v16vfloat16 a, b, c for representing v16float v1
690 auto [a, b, c] =
691 extractV16FP32ToThreeV16BF16(adaptor.getLhs(), aZeros, bZeros, cZeros);
692 // Get v16vfloat16 d, e, f for representing v16float v2
693 auto [d, e, f] =
694 extractV16FP32ToThreeV16BF16(adaptor.getRhs(), dZeros, eZeros, fZeros);
695
696 // Create 1 MUL and 2/5/8 MACs depending on the Aie2Fp32EmulationOption
697 auto createMacOps = [&](Value lhs, Value rhs, Value acc) -> Value {
698 return rewriter
699 .create<xllvm::MacConfBF16IntrOp>(
700 loc, VectorType::get({8}, rewriter.getI64Type()), lhs, rhs, acc,
701 mscMacMulConfCst)
702 .getResult();
703 };
704
705 Value finalMacVal;
706 if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyFast) {
707 // Fast and Accurate option. float a*b would require 6 mac operations.
708 // Input fp32 number is split in to 3 bfloat16 numbers to extract all the
709 // bits of the mantissa. float a,b; both a and b are split in to 3
710 // bfloat16 numbers each. Hence there would be 9 mac operations in
711 // multiplication of a and b. In the 9 mac operations to emulate fp32 mul,
712 // mac operations with LSBs are ignored. (3 last terms). This helps
713 // improve cycle count of mul and has least impact on accuracy of result.
714 // This is the default option to the aiecompiler
715 auto afMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
716 loc, VectorType::get({8}, rewriter.getI64Type()), a, f,
717 mscMacMulConfCst);
718 finalMacVal = createMacOps(
719 a, d,
720 createMacOps(
721 a, e,
722 createMacOps(b, d,
723 createMacOps(d, c, createMacOps(b, e, afMul)))));
724 } else if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyLow) {
725 // Fast and least accurate option. float a*b would require 3 mac
726 // operations.
727 // Input fp32 number is split in to 2 bfloat16 numbers. Hence not all the
728 // bits from mantissa can be used. float a,b; Both a and b are split in to
729 // 2 bfloat16 numbers each. Hence there would be 4 mac operations in
730 // multiplication of a and b. In the 4 mac operations to emulate fp32 mul,
731 // mac operations with LSBs are ignored. (1 last term). This helps improve
732 // cycle count of mul float a, b;
733 auto bdMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
734 loc, VectorType::get({8}, rewriter.getI64Type()), b, d,
735 mscMacMulConfCst);
736 finalMacVal = createMacOps(a, d, createMacOps(a, e, bdMul));
737 } else {
738 // aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracySafe
739 // Most accurate option since input fp32 number is split in to 3 bfloat16
740 // numbers to extract all the bits of the mantissa. float a*b would
741 // require 9 mac operations due to 3 bfloat16 splits each.
742 auto cfMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
743 loc, VectorType::get({8}, rewriter.getI64Type()), c, f,
744 mscMacMulConfCst);
745 finalMacVal = createMacOps(
746 a, d,
747 createMacOps(
748 a, e,
749 createMacOps(
750 b, d,
751 createMacOps(
752 d, c,
753 createMacOps(
754 b, e,
755 createMacOps(
756 a, f,
757 createMacOps(b, f,
758 createMacOps(c, e, cfMul))))))));
759 }
760
761 // create bitcast for result
762 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
763 finalMacVal);
764 return success();
765 }
766
// Entry point of the aievec.mul_elem lowering.
//  1. Decodes the op via decodeMulElemOp(); a kind of UNSUPPORTED emits a
//     warning and fails the pattern.
//  2. Kind I32_I32_I64_32x1x2x1 is delegated to convertToEmulatedI32MulElem();
//     a second kind is delegated to convertToEmulatedFP32MulElem().
//  3. Otherwise an i32 constant is built from the decoded `conf`, the
//     {lhs, rhs, conf} operands are force-cast to the intrinsic signature, and
//     either MulConfAcc32IntrOp (kinds I16_I16_I32_32x1x1x1 and
//     I8_I8_I32_32x1x2x1) or MulConfBF16IntrOp is created.
//  4. The op is replaced by a bitcast of the intrinsic result to the op's
//     original result type.
// NOTE(review): the embedded listing numbers skip 782 and 804, so the kind
// compared against in the FP32-emulation branch and in the MulConfBF16IntrOp
// branch is not visible in this view -- confirm against the complete file.
// NOTE(review): `mulElemOp` starts as nullptr and is only assigned inside the
// two kind-specific branches; whether another kind can reach the final
// bitcast cannot be determined from the lines visible here.
767 LogicalResult
768 matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor,
769 ConversionPatternRewriter &rewriter) const override {
770 Location loc = op.getLoc();
771 auto decodedMulElemOp = decodeMulElemOp(adaptor);
772
773 if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::UNSUPPORTED) {
774 op.emitWarning() << "aievec.mul_elem conversion is not supported.\n";
775 return failure();
776 }
777
778 // Handle the emulated I32/FP32 mul_elem
779 if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1) {
780 return convertToEmulatedI32MulElem(op, adaptor, rewriter);
781 } else if (decodedMulElemOp.kind ==
783 return convertToEmulatedFP32MulElem(op, adaptor, rewriter);
784 }
785
786 // create constant for config
787 auto confCst = rewriter.create<LLVM::ConstantOp>(
788 loc, rewriter.getI32Type(),
789 rewriter.getI32IntegerAttr(decodedMulElemOp.conf));
790 Value mulElemOp = nullptr;
791 SmallVector<Value> operands({adaptor.getLhs(), adaptor.getRhs(), confCst});
792
793 // create xllvm intrinsic
794 if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I16_I16_I32_32x1x1x1 ||
795 decodedMulElemOp.kind == DecodedMulElemOp::Kind::I8_I8_I32_32x1x2x1) {
796 mulElemOp = rewriter.create<xllvm::MulConfAcc32IntrOp>(
797 loc, VectorType::get({16}, rewriter.getI64Type()),
798 forceCastOperandsToSignature(
799 rewriter, loc, operands,
800 {VectorType::get({64}, rewriter.getI8Type()),
801 VectorType::get({16}, rewriter.getI32Type()),
802 rewriter.getI32Type()}));
803 } else if (decodedMulElemOp.kind ==
805 mulElemOp = rewriter.create<xllvm::MulConfBF16IntrOp>(
806 loc, VectorType::get({8}, rewriter.getI64Type()),
807 forceCastOperandsToSignature(
808 rewriter, loc, operands,
809 {VectorType::get({32}, rewriter.getBF16Type()),
810 VectorType::get({32}, rewriter.getBF16Type()),
811 rewriter.getI32Type()}));
812 }
813
814 // create bitcast for result
815 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
816 mulElemOp);
817 return success();
818 }
819};
820
821class UPSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPSOp> {
822public:
823 using ConvertOpToLLVMPattern<aievec::UPSOp>::ConvertOpToLLVMPattern;
824
825 LogicalResult
826 matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor,
827 ConversionPatternRewriter &rewriter) const override {
828 Location loc = op.getLoc();
829
830 Value result = op.getResult();
831 VectorType resultType = cast<VectorType>(result.getType());
832 VectorType flatResTy = getFlattenedVectorType(resultType);
833 Type resultScaTy = resultType.getElementType();
834 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
835 int resultLanes = getVectorLaneSize(resultType);
836 int resultVectorSize = resultBitWidth * resultLanes;
837
838 Value opSrcVal = adaptor.getSource();
839 auto srcVecTy = cast<VectorType>(opSrcVal.getType());
840 auto fltSrcVecTy = getFlattenedVectorType(srcVecTy);
841 if (srcVecTy != fltSrcVecTy)
842 opSrcVal =
843 rewriter
844 .create<vector::ShapeCastOp>(op.getLoc(), fltSrcVecTy, opSrcVal)
845 .getResult();
846
847 // create xllvm intrinsic
848 // Integer types
849 Value upsIntrOp = nullptr;
850 if (llvm::isa<IntegerType>(resultScaTy)) {
851 // create constant for sign
852 auto signCst = rewriter.create<LLVM::ConstantOp>(
853 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
854 auto shiftCst = rewriter.create<LLVM::ConstantOp>(
855 loc, rewriter.getI32Type(),
856 rewriter.getI32IntegerAttr(op.getShift()));
857
858 SmallVector<Value> operands({opSrcVal, shiftCst, signCst});
859 if (resultVectorSize == 512) {
860 if (resultBitWidth == 32) {
861 // v16int16 -> v16acc32
862 upsIntrOp = rewriter.create<xllvm::Acc32V16I256UpsIntrOp>(
863 loc, VectorType::get({8}, rewriter.getI64Type()),
864 forceCastOperandsToSignature(
865 rewriter, loc, operands,
866 {VectorType::get({16}, rewriter.getI16Type()),
867 rewriter.getI32Type(), rewriter.getI32Type()}));
868 } else if (resultBitWidth == 64) {
869 // v8int32 -> v8acc64
870 upsIntrOp = rewriter.create<xllvm::Acc64V8I256UpsIntrOp>(
871 loc, VectorType::get({8}, rewriter.getI64Type()),
872 forceCastOperandsToSignature(
873 rewriter, loc, operands,
874 {VectorType::get({8}, rewriter.getI32Type()),
875 rewriter.getI32Type(), rewriter.getI32Type()}));
876 }
877 } else if (resultVectorSize == 1024) {
878 Value src = opSrcVal;
879 VectorType srcType = cast<VectorType>(src.getType());
880 Type srcScaType = srcType.getElementType();
881 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
882
883 if (resultBitWidth == 32 && srcBitWidth == 16) {
884 // v32int16 -> v32acc32
885 upsIntrOp = rewriter.create<xllvm::Acc32V32I512UpsIntrOp>(
886 loc, VectorType::get({16}, rewriter.getI64Type()),
887 forceCastOperandsToSignature(
888 rewriter, loc, operands,
889 {VectorType::get({32}, rewriter.getI16Type()),
890 rewriter.getI32Type(), rewriter.getI32Type()}));
891 } else if (resultBitWidth == 64 && srcBitWidth == 32) {
892 // v16int32 -> v16acc64
893 upsIntrOp = rewriter.create<xllvm::Acc64V16I512UpsIntrOp>(
894 loc, VectorType::get({16}, rewriter.getI64Type()),
895 forceCastOperandsToSignature(
896 rewriter, loc, operands,
897 {VectorType::get({16}, rewriter.getI32Type()),
898 rewriter.getI32Type(), rewriter.getI32Type()}));
899 } else if (resultBitWidth == 64 && srcBitWidth == 16) {
900 // v16int16 -> v16acc64
901 upsIntrOp = rewriter.create<xllvm::Acc64V16I256UpsIntrOp>(
902 loc, VectorType::get({16}, rewriter.getI64Type()),
903 forceCastOperandsToSignature(
904 rewriter, loc, operands,
905 {VectorType::get({16}, rewriter.getI16Type()),
906 rewriter.getI32Type(), rewriter.getI32Type()}));
907 } else if (resultBitWidth == 32 && srcBitWidth == 8) {
908 // v32int8 -> v32acc32
909 upsIntrOp = rewriter.create<xllvm::Acc32V32I256UpsIntrOp>(
910 loc, VectorType::get({16}, rewriter.getI64Type()),
911 forceCastOperandsToSignature(
912 rewriter, loc, operands,
913 {VectorType::get({32}, rewriter.getI8Type()),
914 rewriter.getI32Type(), rewriter.getI32Type()}));
915 }
916 }
917 } else {
918 // Float types
919 if (resultVectorSize == 512) {
920 // v16bfloat16 -> v16accfloat
921 upsIntrOp = rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
922 loc, VectorType::get({8}, rewriter.getI64Type()),
923 forceCastOperandsToSignature(
924 rewriter, loc, {opSrcVal},
925 {VectorType::get({16}, rewriter.getBF16Type())}));
926 } else if (resultVectorSize == 1024) {
927 // v32bfloat16 -> v32accfloat
928 // The CPP example of the implementation is below:
929 // INTRINSIC(v32accfloat) ups_to_v32accfloat(v32bfloat16 a) {
930 // v16accfloat x0 = ups_to_v16accfloat(extract_v16bfloat16(a, 0));
931 // v16accfloat x1 = ups_to_v16accfloat(extract_v16bfloat16(a, 1));
932 // return concat(x0, x1);
933 // }
934 auto indexZeroCst = rewriter.create<LLVM::ConstantOp>(
935 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
936 auto indexOneCst = rewriter.create<LLVM::ConstantOp>(
937 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
938 auto extractUps = [&](Value source, Value index) -> Value {
939 auto extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
940 loc, VectorType::get({8}, rewriter.getI32Type()),
941 forceCastOperandsToSignature(
942 rewriter, loc, {source, index},
943 {VectorType::get({16}, rewriter.getI32Type()),
944 rewriter.getI32Type()}));
945 return rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
946 loc, VectorType::get({8}, rewriter.getI64Type()),
947 forceCastOperandsToSignature(
948 rewriter, loc, {extOp},
949 {VectorType::get({16}, rewriter.getBF16Type())}));
950 };
951 auto resLo = extractUps(opSrcVal, indexZeroCst);
952 auto resHi = extractUps(opSrcVal, indexOneCst);
953 // Concat the two 512-bit vector to a 1024-bit vector.
954 // Note that given sources a0 and a1, the result is [a1; a0].
955 upsIntrOp = rewriter.create<xllvm::ConcatI1024I512IntrOp>(
956 loc, VectorType::get({32}, rewriter.getI32Type()),
957 forceCastOperandsToSignature(
958 rewriter, loc, {resLo, resHi},
959 {VectorType::get({16}, rewriter.getI32Type()),
960 VectorType::get({16}, rewriter.getI32Type())}));
961 }
962 }
963
964 if (!upsIntrOp) {
965 op.emitWarning() << "aievec.ups is not supported.\n";
966 return failure();
967 }
968
969 // create bitcast for result if needed
970 if (flatResTy != upsIntrOp.getType())
971 upsIntrOp = rewriter.create<LLVM::BitcastOp>(loc, flatResTy, upsIntrOp);
972
973 if (flatResTy != resultType)
974 upsIntrOp =
975 rewriter.create<vector::ShapeCastOp>(loc, resultType, upsIntrOp);
976
977 rewriter.replaceOp(op, upsIntrOp);
978
979 return success();
980 }
981};
982
983class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
984public:
985 using ConvertOpToLLVMPattern<aievec::SRSOp>::ConvertOpToLLVMPattern;
986
987 LogicalResult
988 matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor,
989 ConversionPatternRewriter &rewriter) const override {
990 Location loc = op.getLoc();
991
992 Value result = op.getResult();
993 VectorType resultType = cast<VectorType>(result.getType());
994 Type resultScaTy = resultType.getElementType();
995 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
996 int resultLanes = getVectorLaneSize(resultType);
997 int resultVectorSize = resultBitWidth * resultLanes;
998
999 // Integer types
1000 Value srsIntrOp = nullptr;
1001 if (llvm::isa<IntegerType>(resultScaTy)) {
1002 // create constant for sign
1003 auto signCst = rewriter.create<LLVM::ConstantOp>(
1004 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1005
1006 // create xllvm intrinsic
1007 SmallVector<Value> operands(
1008 {adaptor.getSource(), adaptor.getShift(), signCst});
1009 if (resultVectorSize == 512) {
1010 if (resultBitWidth == 16) {
1011 // v32acc32 -> v32int16
1012 srsIntrOp = rewriter.create<xllvm::I512V32Acc32SrsIntrOp>(
1013 loc, VectorType::get({32}, rewriter.getI16Type()),
1014 forceCastOperandsToSignature(
1015 rewriter, loc, operands,
1016 {VectorType::get({16}, rewriter.getI64Type()),
1017 rewriter.getI32Type(), rewriter.getI32Type()}));
1018 } else if (resultBitWidth == 32) {
1019 // v16acc64 -> v16int32
1020 srsIntrOp = rewriter.create<xllvm::I512V16Acc64SrsIntrOp>(
1021 loc, VectorType::get({16}, rewriter.getI32Type()),
1022 forceCastOperandsToSignature(
1023 rewriter, loc, operands,
1024 {VectorType::get({16}, rewriter.getI64Type()),
1025 rewriter.getI32Type(), rewriter.getI32Type()}));
1026 }
1027 } else if (resultVectorSize == 256) {
1028 Value src = adaptor.getSource();
1029 VectorType srcType = cast<VectorType>(src.getType());
1030 Type srcScaType = srcType.getElementType();
1031 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
1032
1033 if (resultBitWidth == 16 && srcBitWidth == 32) {
1034 // v16acc32 -> v16int16
1035 srsIntrOp = rewriter.create<xllvm::I256V16Acc32SrsIntrOp>(
1036 loc, VectorType::get({16}, rewriter.getI16Type()),
1037 forceCastOperandsToSignature(
1038 rewriter, loc, operands,
1039 {VectorType::get({8}, rewriter.getI64Type()),
1040 rewriter.getI32Type(), rewriter.getI32Type()}));
1041 } else if (resultBitWidth == 8 && srcBitWidth == 32) {
1042 // v32acc32 -> v32int8
1043 srsIntrOp = rewriter.create<xllvm::I256V32Acc32SrsIntrOp>(
1044 loc, VectorType::get({32}, rewriter.getI8Type()),
1045 forceCastOperandsToSignature(
1046 rewriter, loc, operands,
1047 {VectorType::get({16}, rewriter.getI64Type()),
1048 rewriter.getI32Type(), rewriter.getI32Type()}));
1049 } else if (resultBitWidth == 16 && srcBitWidth == 64) {
1050 // v16acc64 -> v16int16
1051 srsIntrOp = rewriter.create<xllvm::I256V16Acc64SrsIntrOp>(
1052 loc, VectorType::get({16}, rewriter.getI16Type()),
1053 forceCastOperandsToSignature(
1054 rewriter, loc, operands,
1055 {VectorType::get({16}, rewriter.getI64Type()),
1056 rewriter.getI32Type(), rewriter.getI32Type()}));
1057 } else if (resultBitWidth == 32 && srcBitWidth == 64) {
1058 // v8acc64 -> v8int32
1059 srsIntrOp = rewriter.create<xllvm::I256V8Acc64SrsIntrOp>(
1060 loc, VectorType::get({8}, rewriter.getI32Type()),
1061 forceCastOperandsToSignature(
1062 rewriter, loc, operands,
1063 {VectorType::get({8}, rewriter.getI64Type()),
1064 rewriter.getI32Type(), rewriter.getI32Type()}));
1065 }
1066 }
1067 } else {
1068 // Float types
1069 if (resultVectorSize == 256) {
1070 // v16accfloat -> v16bfloat16
1071 srsIntrOp = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
1072 loc, VectorType::get({16}, rewriter.getBF16Type()),
1073 forceCastOperandsToSignature(
1074 rewriter, loc, {adaptor.getSource()},
1075 {VectorType::get({8}, rewriter.getI64Type())}));
1076 } else if (resultVectorSize == 512) {
1077 // v32accfloat -> v32bfloat16
1078 // The CPP example of the implementation is below:
1079 // v32bfloat16 to_v32bfloat16(v32accfloat acc) {
1080 // v16bfloat16 x0 = to_v16bfloat16(extract_v16accfloat(acc, 0));
1081 // v16bfloat16 x1 = to_v16bfloat16(extract_v16accfloat(acc, 1));
1082 // return concat(x0, x1);
1083 // }
1084 auto indexZeroCst = rewriter.create<LLVM::ConstantOp>(
1085 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1086 auto indexOneCst = rewriter.create<LLVM::ConstantOp>(
1087 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1088 auto extractSrs = [&](Value source, Value index) -> Value {
1089 auto extOp = rewriter.create<xllvm::ExtI512I1024IntrOp>(
1090 loc, VectorType::get({16}, rewriter.getI32Type()),
1091 forceCastOperandsToSignature(
1092 rewriter, loc, {source, index},
1093 {VectorType::get({32}, rewriter.getI32Type()),
1094 rewriter.getI32Type()}));
1095 return rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
1096 loc, VectorType::get({16}, rewriter.getBF16Type()),
1097 forceCastOperandsToSignature(
1098 rewriter, loc, {extOp},
1099 {VectorType::get({8}, rewriter.getI64Type())}));
1100 };
1101 auto resLo = extractSrs(adaptor.getSource(), indexZeroCst);
1102 auto resHi = extractSrs(adaptor.getSource(), indexOneCst);
1103 // Concat the two 256-bit vector to a 512-bit vector.
1104 // Note that given sources a0 and a1, the result is [a1; a0].
1105 srsIntrOp = rewriter.create<xllvm::ConcatI512I256IntrOp>(
1106 loc, VectorType::get({16}, rewriter.getI32Type()),
1107 forceCastOperandsToSignature(
1108 rewriter, loc, {resLo, resHi},
1109 {VectorType::get({8}, rewriter.getI32Type()),
1110 VectorType::get({8}, rewriter.getI32Type())}));
1111 }
1112 }
1113
1114 if (!srsIntrOp) {
1115 op.emitWarning() << "aievec.srs is not supported.\n";
1116 return failure();
1117 }
1118
1119 // create bitcast for result if needed
1120 if (op.getResult().getType() != srsIntrOp.getType()) {
1121 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1122 srsIntrOp);
1123 } else {
1124 rewriter.replaceOp(op, srsIntrOp);
1125 }
1126
1127 return success();
1128 }
1129};
1130
1131class UPDOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPDOp> {
1132public:
1133 using ConvertOpToLLVMPattern<aievec::UPDOp>::ConvertOpToLLVMPattern;
1134
1135 static std::string getIntrinsicName(aievec::UPDOp op, int loadSize) {
1136 auto resultType = cast<VectorType>(op.getResult().getType());
1137 std::stringstream ss;
1138 ss << "llvm.aie.upd.";
1139 ss << (loadSize == 128 ? 'v' : loadSize == 256 ? 'w' : 'x') << ".";
1140 ss << getVectorTypeString(resultType) << ".";
1141 // The index affects which intrinsic to call
1142 ss << (op.getIndex() == 0 ? "lo" : "hi");
1143 return ss.str();
1144 }
1145
1146 LogicalResult
1147 matchAndRewrite(aievec::UPDOp op, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter) const override {
1149 auto module = op->getParentOfType<ModuleOp>();
1150 MLIRContext *context = rewriter.getContext();
1151
1152 // A bit more complicated: load the vector, then update result vector
1153 // AIE1 is capable of 128-bit on one bank and 256-bit loads on even-odd
1154 // banks Identify size of update
1155 int vecSizeInBits =
1156 getVectorSizeInBits(cast<VectorType>(op.getResult().getType()));
1157
1158 auto ptr = this->getStridedElementPtr(
1159 op->getLoc(), cast<MemRefType>(op.getSource().getType()),
1160 adaptor.getSource(), adaptor.getIndices(), rewriter);
1161
1162 // TODO: handle the offset field
1163
1164 if (vecSizeInBits <= 256) {
1165 // Total <=256-bit updates are much simpler:
1166 // we can do a direct load into the vector register
1167 // look at the indices to calculate the address
1168 auto vectorPtrType = LLVM::LLVMPointerType::get(
1169 getContext(),
1170 cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
1171 auto castedPtr =
1172 rewriter.create<LLVM::BitcastOp>(op->getLoc(), vectorPtrType, ptr);
1173 auto vecType = cast<VectorType>(op.getResult().getType());
1174 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, vecType, castedPtr, 1);
1175 } else {
1176 // Total >256-bit updates will require upd ops to fill the whole vector
1177 // each UDP op represents one of these 256-bit loads and updates
1178
1179 // Determine the load size
1180 // TODO: no examples of 1024-bit output vectors: doesn't feel right
1181 // to attempt a 512-bit load to do an update like this
1182 int loadSize = vecSizeInBits == 256 ? 128
1183 : vecSizeInBits == 512 ? 256
1184 : 512;
1185
1186 // Create a vectorType for the load proper
1187 // Load half of the final result vector
1188 auto resultType = cast<VectorType>(op.getResult().getType());
1189 int lanes = getVectorLaneSize(resultType);
1190 auto loadType =
1191 VectorType::get({(int64_t)lanes / 2}, resultType.getElementType());
1192
1193 // Load the vector
1194 auto vectorPtrType = LLVM::LLVMPointerType::get(
1195 getContext(),
1196 cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
1197 auto castedPtr =
1198 rewriter.create<LLVM::BitcastOp>(op->getLoc(), vectorPtrType, ptr);
1199 auto loadValue =
1200 rewriter.create<LLVM::LoadOp>(op->getLoc(), loadType, castedPtr, 1);
1201
1202 // Get set up for the intrinsic
1203 std::string intrinsicName = getIntrinsicName(op, loadSize);
1204
1205 // If the intrinsic declaration doesn't exist, create it
1206 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
1207 StringAttr::get(context, intrinsicName));
1208
1209 if (!func) {
1210 OpBuilder::InsertionGuard guard(rewriter);
1211 rewriter.setInsertionPointToStart(module.getBody());
1212 func = rewriter.create<LLVM::LLVMFuncOp>(
1213 rewriter.getUnknownLoc(), intrinsicName,
1214 LLVM::LLVMFunctionType::get(resultType, {resultType, loadType}));
1215 }
1216
1217 // Determine what the destination is
1218 Value destValue;
1219 if (adaptor.getVector()) {
1220 // This UPD is using an existing destination vector
1221 destValue = adaptor.getVector();
1222 } else {
1223 // If this UPD is not working off of an existing destination vector,
1224 // create an undefined vector as the destination
1225
1226 // TODO: determine if the undef intrinsic is needed or if an LLVM
1227 // undef suffices destValue =
1228 // rewriter.create<LLVM::UndefOp>(op->getLoc(), resultType);
1229
1230 std::stringstream ss;
1231 ss << "llvm.aie." << getVectorTypeString(resultType) << ".undef";
1232 std::string intrinsicName = ss.str();
1233
1234 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
1235 StringAttr::get(rewriter.getContext(), intrinsicName));
1236
1237 if (!func) {
1238 OpBuilder::InsertionGuard guard(rewriter);
1239 rewriter.setInsertionPointToStart(module.getBody());
1240 func = rewriter.create<LLVM::LLVMFuncOp>(
1241 rewriter.getUnknownLoc(), intrinsicName,
1242 LLVM::LLVMFunctionType::get(resultType, {}));
1243 }
1244 destValue =
1245 rewriter.create<LLVM::CallOp>(op->getLoc(), func, ValueRange{})
1246 ->getOpResult(0);
1247 }
1248
1249 // Create our call
1250 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
1251 op, func, ValueRange{destValue, loadValue});
1252 }
1253
1254 return success();
1255 }
1256};
1257
1259 : public mlir::ConvertOpToLLVMPattern<aievec::ConcatOp> {
1260public:
1261 using ConvertOpToLLVMPattern<aievec::ConcatOp>::ConvertOpToLLVMPattern;
1262
1263 LogicalResult
1264 matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor,
1265 ConversionPatternRewriter &rewriter) const override {
1266 Location loc = op.getLoc();
1267
1268 SmallVector<Value> sources = adaptor.getSources();
1269 Value src = sources.front();
1270 VectorType srcType = cast<VectorType>(src.getType());
1271 Type srcScalarType = srcType.getElementType();
1272 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
1273 int srcLanes = getVectorLaneSize(srcType);
1274 int srcVectorSize = srcBitWidth * srcLanes;
1275
1276 Value result = op.getResult();
1277 VectorType resultType = cast<VectorType>(result.getType());
1278 Type resultScaTy = resultType.getElementType();
1279 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1280 int resultLanes = getVectorLaneSize(resultType);
1281 int resultVectorSize = resultBitWidth * resultLanes;
1282
1283 if (sources.size() != 2 && sources.size() != 4) {
1284 op.emitWarning() << "aievec.concat with " << sources.size()
1285 << " operands is not supported.\n";
1286 return failure();
1287 }
1288
1289 // create xllvm intrinsic
1290 Value concatOp = nullptr;
1291 if (srcVectorSize == 256 && resultVectorSize == 512) {
1292 concatOp = rewriter.create<xllvm::ConcatI512I256IntrOp>(
1293 loc, VectorType::get({16}, rewriter.getI32Type()),
1294 forceCastOperandsToSignature(
1295 rewriter, loc, adaptor.getSources(),
1296 {VectorType::get({8}, rewriter.getI32Type()),
1297 VectorType::get({8}, rewriter.getI32Type())}));
1298 } else if (srcVectorSize == 256 && resultVectorSize == 1024) {
1299 concatOp = rewriter.create<xllvm::ConcatI1024I256IntrOp>(
1300 loc, VectorType::get({32}, rewriter.getI32Type()),
1301 forceCastOperandsToSignature(
1302 rewriter, loc, adaptor.getSources(),
1303 {VectorType::get({8}, rewriter.getI32Type()),
1304 VectorType::get({8}, rewriter.getI32Type()),
1305 VectorType::get({8}, rewriter.getI32Type()),
1306 VectorType::get({8}, rewriter.getI32Type())}));
1307 } else if (srcVectorSize == 512 && resultVectorSize == 1024) {
1308 concatOp = rewriter.create<xllvm::ConcatI1024I512IntrOp>(
1309 loc, VectorType::get({32}, rewriter.getI32Type()),
1310 forceCastOperandsToSignature(
1311 rewriter, loc, adaptor.getSources(),
1312 {VectorType::get({16}, rewriter.getI32Type()),
1313 VectorType::get({16}, rewriter.getI32Type())}));
1314 } else {
1315 op.emitWarning() << "aievec.concat with " << srcVectorSize
1316 << "-bit operands, and " << resultVectorSize
1317 << "-bit result is not supported.\n";
1318 return failure();
1319 }
1320
1321 // create bitcast for result
1322 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1323 concatOp);
1324
1325 return success();
1326 }
1327};
1328
1329class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
1330public:
1331 using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;
1332
1333 LogicalResult
1334 matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor,
1335 ConversionPatternRewriter &rewriter) const override {
1336 Location loc = op.getLoc();
1337
1338 Value src = adaptor.getSource();
1339 VectorType srcType = cast<VectorType>(src.getType());
1340 Type srcScalarType = srcType.getElementType();
1341 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
1342 int srcLanes = getVectorLaneSize(srcType);
1343 int srcVectorSize = srcBitWidth * srcLanes;
1344
1345 Value result = op.getResult();
1346 VectorType resultType = cast<VectorType>(result.getType());
1347 Type resultScaTy = resultType.getElementType();
1348 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1349 int resultLanes = getVectorLaneSize(resultType);
1350 int resultVectorSize = resultBitWidth * resultLanes;
1351
1352 // create constant for index
1353 auto indexCst = rewriter.create<LLVM::ConstantOp>(
1354 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(op.getIndex()));
1355
1356 // create xllvm intrinsic
1357 SmallVector<Value> operands({adaptor.getSource(), indexCst});
1358 Value extOp = nullptr;
1359 // Integer types
1360 if (resultVectorSize == 256 && srcVectorSize == 512) {
1361 extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
1362 loc, VectorType::get({8}, rewriter.getI32Type()),
1363 forceCastOperandsToSignature(
1364 rewriter, loc, operands,
1365 {VectorType::get({16}, rewriter.getI32Type()),
1366 rewriter.getI32Type()}));
1367 } else if (resultVectorSize == 512 && srcVectorSize == 1024) {
1368 extOp = rewriter.create<xllvm::ExtI512I1024IntrOp>(
1369 loc, VectorType::get({16}, rewriter.getI32Type()),
1370 forceCastOperandsToSignature(
1371 rewriter, loc, operands,
1372 {VectorType::get({32}, rewriter.getI32Type()),
1373 rewriter.getI32Type()}));
1374 } else if (resultVectorSize == 256 && srcVectorSize == 1024) {
1375 extOp = rewriter.create<xllvm::ExtI256I1024IntrOp>(
1376 loc, VectorType::get({8}, rewriter.getI32Type()),
1377 forceCastOperandsToSignature(
1378 rewriter, loc, operands,
1379 {VectorType::get({32}, rewriter.getI32Type()),
1380 rewriter.getI32Type()}));
1381 } else if (resultVectorSize == 128 && srcVectorSize == 512) {
1382 auto shiftOp = adaptor.getSource();
1383 if (op.getIndex() > 0) {
1384 auto undefOp = rewriter.create<xllvm::UndefV16I32IntrOp>(
1385 loc, VectorType::get({16}, rewriter.getI32Type()));
1386 auto stepCst = rewriter.create<LLVM::ConstantOp>(
1387 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1388 auto shiftCst = rewriter.create<LLVM::ConstantOp>(
1389 loc, rewriter.getI32Type(),
1390 rewriter.getI32IntegerAttr(op.getIndex() * 16));
1391 SmallVector<Value> shiftOperands{adaptor.getSource(), undefOp, stepCst,
1392 shiftCst};
1393 // Right shift the source vector in index * 16 bytes (i.e. in index *
1394 // 128 bits). The integer index is expected to be 0 to 3.
1395 shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
1396 loc, VectorType::get({16}, rewriter.getI32Type()),
1397 forceCastOperandsToSignature(
1398 rewriter, loc, shiftOperands,
1399 {VectorType::get({16}, rewriter.getI32Type()),
1400 VectorType::get({16}, rewriter.getI32Type()),
1401 rewriter.getI32Type(), rewriter.getI32Type()}));
1402 }
1403 // The underlying intrinsic takes a source vector and extract the lowest
1404 // 128-bit. i.e. it always extracts the input vector with index = 0.
1405 extOp = rewriter.create<xllvm::ExtI128I512IntrOp>(
1406 loc, VectorType::get({4}, rewriter.getI32Type()),
1407 forceCastOperandsToSignature(
1408 rewriter, loc, /*operands=*/{shiftOp},
1409 {VectorType::get({16}, rewriter.getI32Type())}));
1410 } else {
1411 op.emitWarning() << "aievec.ext with " << srcVectorSize
1412 << "-bit source, and " << resultVectorSize
1413 << "-bit result is not supported.\n";
1414 return failure();
1415 }
1416
1417 // create bitcast for result
1418 if (op.getResult().getType() != extOp.getType()) {
1419 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1420 extOp);
1421 } else {
1422 rewriter.replaceOp(op, extOp);
1423 }
1424
1425 return success();
1426 }
1427};
1428
1430 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::SelectOp> {
1431public:
1432 using ConvertOpToLLVMPattern<aievec::aie1::SelectOp>::ConvertOpToLLVMPattern;
1433
1434 static std::string getIntrinsicName(aievec::aie1::SelectOp op) {
1435 auto xbuffType = cast<VectorType>(op.getXbuff().getType());
1436 std::stringstream ss;
1437 ss << "llvm.aie.prim." << getVectorTypeString(xbuffType) << ".select";
1438 return ss.str();
1439 }
1440
1441 LogicalResult
1442 matchAndRewrite(aievec::aie1::SelectOp op, OpAdaptor adaptor,
1443 ConversionPatternRewriter &rewriter) const override {
1444 auto module = op->getParentOfType<ModuleOp>();
1445 MLIRContext *context = rewriter.getContext();
1446
1447 auto selectType = IntegerType::get(context, 32);
1448 auto startType = IntegerType::get(context, 32);
1449 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
1450 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
1451
1452 // If the intrinsic declaration doesn't exist, create it
1453 std::string intrinsicName = getIntrinsicName(op);
1454 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
1455 StringAttr::get(context, intrinsicName));
1456
1457 if (!func) {
1458 OpBuilder::InsertionGuard guard(rewriter);
1459 rewriter.setInsertionPointToStart(module.getBody());
1460 func = rewriter.create<LLVM::LLVMFuncOp>(
1461 rewriter.getUnknownLoc(), intrinsicName,
1462 LLVM::LLVMFunctionType::get(op.getResult().getType(),
1463 {op.getXbuff().getType(), selectType,
1464 startType, /* xstart */
1465 startType, /* ystart */
1466 offsetsType, /* xoffsets */
1467 offsetsType, /* yoffsets */
1468 confType}));
1469 }
1470
1471 // Parse the string attribute values
1472 uint32_t select = 0;
1473 BufferParams x = {};
1474 BufferParams y = {};
1475 BufferParams z = {};
1476
1477 op.getSelect().getAsInteger(0, select);
1478 op.getXstart().getAsInteger(0, x.start);
1479 op.getXoffsets().getAsInteger(0, x.offsets);
1480 op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
1481 op.getXsquare().getAsInteger(0, x.square);
1482 op.getYstart().getAsInteger(0, y.start);
1483 op.getYoffsets().getAsInteger(0, y.offsets);
1484 op.getYoffsetsHi().getAsInteger(0, y.offsets_hi);
1485 op.getYsquare().getAsInteger(0, y.square);
1486
1487 // Encode the configuration register
1488 uint32_t conf[2] = {0, 0};
1489 encodeConf(conf, x, z, false);
1490 conf[1] |= encodeSquare(y.square) << 21;
1491
1492 // Create the constants and replace the op
1493 auto selectVal = rewriter.create<LLVM::ConstantOp>(
1494 op->getLoc(), selectType, rewriter.getI32IntegerAttr(select));
1495 auto xstartVal = rewriter.create<LLVM::ConstantOp>(
1496 op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
1497 auto ystartVal = rewriter.create<LLVM::ConstantOp>(
1498 op->getLoc(), startType, rewriter.getI32IntegerAttr(y.start));
1499 auto xoffsetsVal = rewriter.create<LLVM::ConstantOp>(
1500 op->getLoc(), offsetsType,
1501 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
1502 auto yoffsetsVal = rewriter.create<LLVM::ConstantOp>(
1503 op->getLoc(), offsetsType,
1504 rewriter.getI32VectorAttr({(int32_t)y.offsets, (int32_t)y.offsets_hi}));
1505 auto confVal = rewriter.create<LLVM::ConstantOp>(
1506 op->getLoc(), confType,
1507 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
1508 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
1509 op, func,
1510 ValueRange{op.getXbuff(), selectVal, xstartVal, ystartVal, xoffsetsVal,
1511 yoffsetsVal, confVal});
1512 return success();
1513 }
1514};
1515
1516class PackOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::PackOp> {
1517public:
1518 using ConvertOpToLLVMPattern<aievec::PackOp>::ConvertOpToLLVMPattern;
1519
1520 static std::string getIntrinsicName(aievec::PackOp op) {
1521 auto sourceType = cast<VectorType>(op.getSource().getType());
1522 std::stringstream ss;
1523 ss << "llvm.aie.pack." << getVectorTypeString(sourceType);
1524 return ss.str();
1525 }
1526
1527 LogicalResult
1528 matchAndRewrite(aievec::PackOp op, OpAdaptor adaptor,
1529 ConversionPatternRewriter &rewriter) const override {
1530 auto module = op->getParentOfType<ModuleOp>();
1531 MLIRContext *context = rewriter.getContext();
1532
1533 // If the intrinsic declaration doesn't exist, create it
1534 std::string intrinsicName = getIntrinsicName(op);
1535 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
1536 StringAttr::get(context, intrinsicName));
1537
1538 if (!func) {
1539 OpBuilder::InsertionGuard guard(rewriter);
1540 rewriter.setInsertionPointToStart(module.getBody());
1541 func = rewriter.create<LLVM::LLVMFuncOp>(
1542 rewriter.getUnknownLoc(), intrinsicName,
1543 LLVM::LLVMFunctionType::get(op.getResult().getType(),
1544 {op.getSource().getType()}));
1545 }
1546
1547 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, func,
1548 ValueRange{op.getSource()});
1549 return success();
1550 }
1551};
1552
1554 : public mlir::ConvertOpToLLVMPattern<aievec::UnpackOp> {
1555public:
1556 using ConvertOpToLLVMPattern<aievec::UnpackOp>::ConvertOpToLLVMPattern;
1557
1558 LogicalResult
1559 matchAndRewrite(aievec::UnpackOp op, OpAdaptor adaptor,
1560 ConversionPatternRewriter &rewriter) const override {
1561 op.emitWarning() << "aie.unpack conversion is not implemented\n";
1562 return failure();
1563 }
1564};
1565
1567 : public mlir::ConvertOpToLLVMPattern<aievec::BroadcastOp> {
1568public:
1569 using ConvertOpToLLVMPattern<aievec::BroadcastOp>::ConvertOpToLLVMPattern;
1570
1571 LogicalResult
1572 matchAndRewrite(aievec::BroadcastOp op, OpAdaptor adaptor,
1573 ConversionPatternRewriter &rewriter) const override {
1574 op.emitWarning() << "aie.broadcast conversion is not implemented\n";
1575 return failure();
1576 }
1577};
1578
1579class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
1580public:
1581 using ConvertOpToLLVMPattern<aievec::MaxOp>::ConvertOpToLLVMPattern;
1582
1583 LogicalResult
1584 matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor,
1585 ConversionPatternRewriter &rewriter) const override {
1586 Location loc = op.getLoc();
1587
1588 VectorType resultType = cast<VectorType>(op.getResult().getType());
1589 Type resultScaTy = resultType.getElementType();
1590 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1591 int resultLanes = getVectorLaneSize(resultType);
1592 int resultVectorSize = resultBitWidth * resultLanes;
1593
1594 // aievec.max op has the AllTypesMatch constraint on lhs/rhs/res
1595 if (resultVectorSize != 512) {
1596 op.emitWarning() << "aievec.max conversion with " << resultVectorSize
1597 << "-bit result is not supported.\n";
1598 return failure();
1599 }
1600
1601 // create xllvm intrinsic
1602 Value maxOp = nullptr;
1603 if (llvm::isa<IntegerType>(resultScaTy)) {
1604 // create constant for third operand `cmp`
1605 // Note: `cmp` is implicitly treated as `sign` to the vmax intrinsic
1606 auto cmpCst = rewriter.create<LLVM::ConstantOp>(
1607 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1608 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
1609 if (resultBitWidth == 8) {
1610 maxOp = rewriter.create<xllvm::VectorMaxLt8IntrOp>(
1611 loc,
1612 mlir::LLVM::LLVMStructType::getLiteral(
1613 rewriter.getContext(),
1614 {VectorType::get({64}, rewriter.getI8Type()),
1615 VectorType::get({2}, rewriter.getI32Type())}),
1616 forceCastOperandsToSignature(
1617 rewriter, loc, operands,
1618 {VectorType::get({64}, rewriter.getI8Type()),
1619 VectorType::get({64}, rewriter.getI8Type()),
1620 rewriter.getI32Type()}));
1621 } else if (resultBitWidth == 16) {
1622 maxOp = rewriter.create<xllvm::VectorMaxLt16IntrOp>(
1623 loc,
1624 mlir::LLVM::LLVMStructType::getLiteral(
1625 rewriter.getContext(),
1626 {VectorType::get({32}, rewriter.getI16Type()),
1627 rewriter.getI32Type()}),
1628 forceCastOperandsToSignature(
1629 rewriter, loc, operands,
1630 {VectorType::get({32}, rewriter.getI16Type()),
1631 VectorType::get({32}, rewriter.getI16Type()),
1632 rewriter.getI32Type()}));
1633 } else if (resultBitWidth == 32) {
1634 maxOp = rewriter.create<xllvm::VectorMaxLt32IntrOp>(
1635 loc,
1636 mlir::LLVM::LLVMStructType::getLiteral(
1637 rewriter.getContext(),
1638 {VectorType::get({16}, rewriter.getI32Type()),
1639 rewriter.getI32Type()}),
1640 forceCastOperandsToSignature(
1641 rewriter, loc, operands,
1642 {VectorType::get({16}, rewriter.getI32Type()),
1643 VectorType::get({16}, rewriter.getI32Type()),
1644 rewriter.getI32Type()}));
1645 }
1646 } else {
1647 if (resultBitWidth == 16) {
1648 maxOp = rewriter.create<xllvm::VectorMaxLtBf16IntrOp>(
1649 loc,
1650 mlir::LLVM::LLVMStructType::getLiteral(
1651 rewriter.getContext(),
1652 {VectorType::get({32}, rewriter.getBF16Type()),
1653 rewriter.getI32Type()}),
1654 forceCastOperandsToSignature(
1655 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs()},
1656 {VectorType::get({32}, rewriter.getBF16Type()),
1657 VectorType::get({32}, rewriter.getBF16Type())}));
1658 }
1659 }
1660
1661 if (!maxOp) {
1662 // We have checked the lhs/rhs/res to be 512-bit vectors. Hence, a
1663 // possible failure here is due to unsupported element datatype.
1664 op.emitWarning() << "aievec.max conversion fails due to unsupported "
1665 "element data type.\n";
1666 return failure();
1667 }
1668
1669 // create llvm.extractvalue for the first element in the LLVMStruct
1670 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, maxOp,
1671 /*position=*/0);
1672
1673 return success();
1674 }
1675};
1676
1677class MinOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MinOp> {
1678public:
1679 using ConvertOpToLLVMPattern<aievec::MinOp>::ConvertOpToLLVMPattern;
1680
1681 LogicalResult
1682 matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor,
1683 ConversionPatternRewriter &rewriter) const override {
1684 Location loc = op.getLoc();
1685
1686 VectorType resultType = cast<VectorType>(op.getResult().getType());
1687 Type resultScaTy = resultType.getElementType();
1688 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1689 int resultLanes = getVectorLaneSize(resultType);
1690 int resultVectorSize = resultBitWidth * resultLanes;
1691
1692 // aievec.min op has the AllTypesMatch constraint on lhs/rhs/res
1693 if (resultVectorSize != 512) {
1694 op.emitWarning() << "aievec.min conversion with " << resultVectorSize
1695 << "-bit result is not supported.\n";
1696 return failure();
1697 }
1698
1699 // create xllvm intrinsic
1700 Value minOp = nullptr;
1701 if (llvm::isa<IntegerType>(resultScaTy)) {
1702 // create constant for third operand `cmp`
1703 // Note: `cmp` is implicitly treated as `sign` to the vmin intrinsic
1704 auto cmpCst = rewriter.create<LLVM::ConstantOp>(
1705 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1706 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
1707 if (resultBitWidth == 8) {
1708 minOp = rewriter.create<xllvm::VectorMinGe8IntrOp>(
1709 loc,
1710 mlir::LLVM::LLVMStructType::getLiteral(
1711 rewriter.getContext(),
1712 {VectorType::get({64}, rewriter.getI8Type()),
1713 VectorType::get({2}, rewriter.getI32Type())}),
1714 forceCastOperandsToSignature(
1715 rewriter, loc, operands,
1716 {VectorType::get({64}, rewriter.getI8Type()),
1717 VectorType::get({64}, rewriter.getI8Type()),
1718 rewriter.getI32Type()}));
1719 } else if (resultBitWidth == 16) {
1720 minOp = rewriter.create<xllvm::VectorMinGe16IntrOp>(
1721 loc,
1722 mlir::LLVM::LLVMStructType::getLiteral(
1723 rewriter.getContext(),
1724 {VectorType::get({32}, rewriter.getI16Type()),
1725 rewriter.getI32Type()}),
1726 forceCastOperandsToSignature(
1727 rewriter, loc, operands,
1728 {VectorType::get({32}, rewriter.getI16Type()),
1729 VectorType::get({32}, rewriter.getI16Type()),
1730 rewriter.getI32Type()}));
1731 } else if (resultBitWidth == 32) {
1732 minOp = rewriter.create<xllvm::VectorMinGe32IntrOp>(
1733 loc,
1734 mlir::LLVM::LLVMStructType::getLiteral(
1735 rewriter.getContext(),
1736 {VectorType::get({16}, rewriter.getI32Type()),
1737 rewriter.getI32Type()}),
1738 forceCastOperandsToSignature(
1739 rewriter, loc, operands,
1740 {VectorType::get({16}, rewriter.getI32Type()),
1741 VectorType::get({16}, rewriter.getI32Type()),
1742 rewriter.getI32Type()}));
1743 }
1744 } else {
1745 if (resultBitWidth == 16) {
1746 minOp = rewriter.create<xllvm::VectorMinGeBf16IntrOp>(
1747 loc,
1748 mlir::LLVM::LLVMStructType::getLiteral(
1749 rewriter.getContext(),
1750 {VectorType::get({32}, rewriter.getBF16Type()),
1751 rewriter.getI32Type()}),
1752 forceCastOperandsToSignature(
1753 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs()},
1754 {VectorType::get({32}, rewriter.getBF16Type()),
1755 VectorType::get({32}, rewriter.getBF16Type())}));
1756 }
1757 }
1758
1759 if (!minOp) {
1760 // We have checked the lhs/rhs/res to be 512-bit vectors. Hence, a
1761 // possible failure here is due to unsupported element datatype.
1762 op.emitWarning() << "aievec.min conversion fails due to unsupported "
1763 "element data type.\n";
1764 return failure();
1765 }
1766
1767 // create llvm.extractvalue for the first element in the LLVMStruct
1768 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, minOp,
1769 /*position=*/0);
1770
1771 return success();
1772 }
1773};
1774
1776 : public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
1777public:
1778 using ConvertOpToLLVMPattern<
1779 aievec::BroadcastScalarOp>::ConvertOpToLLVMPattern;
1780
1781 LogicalResult
1782 matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor,
1783 ConversionPatternRewriter &rewriter) const override {
1784 Location loc = op.getLoc();
1785
1786 Value result = op.getResult();
1787 VectorType resultType = cast<VectorType>(result.getType());
1788 Type resultScaTy = resultType.getElementType();
1789 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1790 int resultLanes = getVectorLaneSize(resultType);
1791 int resultVectorSize = resultBitWidth * resultLanes;
1792
1793 if (resultVectorSize != 512) {
1794 op.emitWarning()
1795 << "aievec.broadcast_scalar conversion with result vector size "
1796 << resultVectorSize << " is not implemented.\n";
1797 return failure();
1798 }
1799
1800 // Integer types
1801 if (llvm::isa<IntegerType>(resultScaTy)) {
1802 Value src = adaptor.getSource();
1803 Type srcType = src.getType();
1804 unsigned srcBitWidth = srcType.getIntOrFloatBitWidth();
1805
1806 if (srcBitWidth < 32) {
1807 src = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI32Type(), src);
1808 }
1809
1810 if (resultBitWidth == 8) {
1811 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast8I512IntrOp>(
1812 op, VectorType::get({64}, rewriter.getI8Type()), src);
1813 } else if (resultBitWidth == 16) {
1814 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast16I512IntrOp>(
1815 op, VectorType::get({32}, rewriter.getI16Type()), src);
1816 } else if (resultBitWidth == 32) {
1817 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast32I512IntrOp>(
1818 op, VectorType::get({16}, rewriter.getI32Type()), src);
1819 } else {
1820 op.emitWarning()
1821 << "aievec.broadcast_scalar conversion with result bitwidth "
1822 << resultBitWidth << " is not implemented.\n";
1823 return failure();
1824 }
1825 } else {
1826 // Float types
1827 if (resultBitWidth == 16) {
1828 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast16BF512IntrOp>(
1829 op, VectorType::get({32}, rewriter.getBF16Type()),
1830 adaptor.getSource());
1831 } else if (resultBitWidth == 32) {
1832 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcastfloatI512IntrOp>(
1833 op, VectorType::get({16}, rewriter.getF32Type()),
1834 adaptor.getSource());
1835 } else {
1836 op.emitWarning()
1837 << "aievec.broadcast_scalar conversion with result bitwidth "
1838 << resultBitWidth << " is not implemented.\n";
1839 return failure();
1840 }
1841 }
1842
1843 return success();
1844 }
1845};
1846
1847class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
1848public:
1849 using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;
1850
1851 LogicalResult
1852 matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor,
1853 ConversionPatternRewriter &rewriter) const override {
1854 Location loc = op.getLoc();
1855
1856 Value result = op.getResult();
1857 VectorType resultType = cast<VectorType>(result.getType());
1858 Type resultScaTy = resultType.getElementType();
1859 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1860 int resultLanes = getVectorLaneSize(resultType);
1861 int resultVectorSize = resultBitWidth * resultLanes;
1862
1863 if (resultVectorSize != 512) {
1864 op.emitWarning() << "aievec.shift conversion with result vector size "
1865 << resultVectorSize << " is not implemented.\n";
1866 return failure();
1867 }
1868
1869 // assume step is always zero
1870 auto stepCst = rewriter.create<LLVM::ConstantOp>(
1871 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1872
1873 // create xllvm intrinsic
1874 Value shiftOp = nullptr;
1875 SmallVector<Value> operands(
1876 {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()});
1877 if (llvm::isa<IntegerType>(resultScaTy)) {
1878 // Integer types
1879 shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
1880 loc, VectorType::get({16}, rewriter.getI32Type()),
1881 forceCastOperandsToSignature(
1882 rewriter, loc, operands,
1883 {VectorType::get({16}, rewriter.getI32Type()),
1884 VectorType::get({16}, rewriter.getI32Type()),
1885 rewriter.getI32Type(), rewriter.getI32Type()}));
1886 } else {
1887 // Float types
1888 shiftOp = rewriter.create<xllvm::VectorShiftBF512BF512IntrOp>(
1889 loc, VectorType::get({32}, rewriter.getBF16Type()),
1890 forceCastOperandsToSignature(
1891 rewriter, loc, operands,
1892 {VectorType::get({32}, rewriter.getBF16Type()),
1893 VectorType::get({32}, rewriter.getBF16Type()),
1894 rewriter.getI32Type(), rewriter.getI32Type()}));
1895 }
1896
1897 // create bitcast for result
1898 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1899 shiftOp);
1900
1901 return success();
1902 }
1903};
1904
1906 : public mlir::ConvertOpToLLVMPattern<aievec::ExtElemOp> {
1907public:
1908 using ConvertOpToLLVMPattern<aievec::ExtElemOp>::ConvertOpToLLVMPattern;
1909
1910 LogicalResult
1911 matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor,
1912 ConversionPatternRewriter &rewriter) const override {
1913 Location loc = op.getLoc();
1914
1915 Type resultType = op.getResult().getType();
1916 unsigned resultBitWidth = resultType.getIntOrFloatBitWidth();
1917
1918 Value src = adaptor.getSource();
1919 VectorType srcType = cast<VectorType>(src.getType());
1920 Type srcScalarType = srcType.getElementType();
1921 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
1922 int srcLanes = getVectorLaneSize(srcType);
1923 int srcVectorSize = srcBitWidth * srcLanes;
1924
1925 if (srcVectorSize != 512) {
1926 op.emitWarning() << "aievec.ext_elem conversion with source vector size "
1927 << srcVectorSize << " is not supported.\n";
1928 return failure();
1929 }
1930
1931 // create constant for sign
1932 auto signCst = rewriter.create<LLVM::ConstantOp>(
1933 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1934
1935 // create xllvm intrinsic
1936 Value extElemOp = nullptr;
1937 SmallVector<Value> operands(
1938 {adaptor.getSource(), adaptor.getIndex(), signCst});
1939 if (resultBitWidth == 8) {
1940 extElemOp = rewriter.create<xllvm::VectorExtractElem8I512IntrOp>(
1941 loc, rewriter.getI32Type(),
1942 forceCastOperandsToSignature(
1943 rewriter, loc, operands,
1944 {VectorType::get({64}, rewriter.getI8Type()),
1945 rewriter.getI32Type(), rewriter.getI32Type()}));
1946 } else if (resultBitWidth == 16) {
1947 extElemOp = rewriter.create<xllvm::VectorExtractElem16I512IntrOp>(
1948 loc, rewriter.getI32Type(),
1949 forceCastOperandsToSignature(
1950 rewriter, loc, operands,
1951 {VectorType::get({32}, rewriter.getI16Type()),
1952 rewriter.getI32Type(), rewriter.getI32Type()}));
1953 } else if (resultBitWidth == 32) {
1954 extElemOp = rewriter.create<xllvm::VectorExtractElem32I512IntrOp>(
1955 loc, rewriter.getI32Type(),
1956 forceCastOperandsToSignature(
1957 rewriter, loc, operands,
1958 {VectorType::get({16}, rewriter.getI32Type()),
1959 rewriter.getI32Type(), rewriter.getI32Type()}));
1960 } else {
1961 op.emitWarning() << "aievec.ext_elem conversion with result bit width "
1962 << resultBitWidth << " is not implemented.\n";
1963 return failure();
1964 }
1965
1966 // create truncation op (and bitcast op)
1967 if (llvm::isa<IntegerType>(resultType)) {
1968 if (resultBitWidth < 32) {
1969 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType, extElemOp);
1970 } else {
1971 rewriter.replaceOp(op, extElemOp);
1972 }
1973 } else {
1974 // Float types
1975 if (resultBitWidth == 16) {
1976 extElemOp = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI16Type(),
1977 extElemOp);
1978 }
1979 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, resultType, extElemOp);
1980 }
1981
1982 return success();
1983 }
1984};
1985
1987 : public mlir::ConvertOpToLLVMPattern<aievec::FMAElemOp> {
1988public:
1989 using ConvertOpToLLVMPattern<aievec::FMAElemOp>::ConvertOpToLLVMPattern;
1990
1991 LogicalResult
1992 matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor,
1993 ConversionPatternRewriter &rewriter) const override {
1994 auto loc = fmaOp.getLoc();
1995 auto lhs = adaptor.getLhs();
1996 auto rhs = adaptor.getRhs();
1997 auto acc = adaptor.getAcc();
1998 auto lhsTy = cast<VectorType>(lhs.getType());
1999 auto rhsTy = cast<VectorType>(rhs.getType());
2000 auto accTy = cast<VectorType>(acc.getType());
2001 auto flatLhsTy = getFlattenedVectorType(lhsTy);
2002 auto flatRhsTy = getFlattenedVectorType(rhsTy);
2003 auto flatAccTy = getFlattenedVectorType(accTy);
2004
2005 // Flatten operands, if needed
2006 if (lhsTy != flatLhsTy)
2007 lhs = rewriter.create<vector::ShapeCastOp>(loc, flatLhsTy, lhs);
2008 if (rhsTy != flatRhsTy)
2009 rhs = rewriter.create<vector::ShapeCastOp>(loc, flatRhsTy, rhs);
2010 if (accTy != flatAccTy)
2011 acc = rewriter.create<vector::ShapeCastOp>(loc, flatAccTy, acc);
2012
2013 // Build vmac configuration constant
2014 Type i32ty = rewriter.getI32Type();
2015 auto confCst = rewriter.create<LLVM::ConstantOp>(
2016 loc, i32ty,
2017 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
2018 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
2019 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
2020 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2021 /*sub_mask=*/0)));
2022
2023 // Insert vmac intrinsic
2024 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
2025 auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type());
2026 auto macIntrOp = rewriter.create<xllvm::MacConfBF16IntrOp>(
2027 loc, v8i64Ty,
2028 forceCastOperandsToSignature(rewriter, loc, {lhs, rhs, acc, confCst},
2029 {v32bf16Ty, v32bf16Ty, v8i64Ty, i32ty}));
2030
2031 // Recast/Reshape result
2032 auto resVal =
2033 forceCastValueToType(rewriter, loc, macIntrOp.getResult(), flatAccTy);
2034 if (flatAccTy != accTy)
2035 resVal = rewriter.create<vector::ShapeCastOp>(loc, accTy, resVal);
2036
2037 rewriter.replaceOp(fmaOp, resVal);
2038 return success();
2039 }
2040};
2041
2043 : public mlir::ConvertOpToLLVMPattern<aievec::MatMulOp> {
2044 using ConvertOpToLLVMPattern<aievec::MatMulOp>::ConvertOpToLLVMPattern;
2045
2046 struct DecodedMatMulOp {
2047 typedef enum { I32, I64, BF16 } Kind;
2048
2049 Kind kind;
2050 Value lhs;
2051 Value rhs;
2052 Value acc;
2053 int conf;
2054 };
2055
2056 static DecodedMatMulOp decodeMatMulOp(OpAdaptor op) {
2057 Value lhs = op.getLhs();
2058 Value rhs = op.getRhs();
2059 Value acc = op.getAcc();
2060 auto accVecTy = cast<VectorType>(acc.getType());
2061 if (isa<Float32Type>(accVecTy.getElementType()))
2062 // <4x8xbf16> x <8x4xbf16> + <4x4xf32>
2063 return {DecodedMatMulOp::Kind::BF16, lhs, rhs, acc,
2064 aiev2_vmac_compute_control(
2065 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
2066 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2067 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2068 /*sub_mask=*/0)};
2069
2070 int signX = 0, signY = 0;
2071 auto lhsVecTy = cast<VectorType>(lhs.getType());
2072 auto lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
2073 if (auto extSIOp = lhs.getDefiningOp<arith::ExtSIOp>()) {
2074 lhs = extSIOp.getIn();
2075 lhsVecTy = cast<VectorType>(lhs.getType());
2076 lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
2077 signX = 1;
2078 } else if (auto extUIOp = lhs.getDefiningOp<arith::ExtUIOp>()) {
2079 lhs = extUIOp.getIn();
2080 lhsVecTy = cast<VectorType>(lhs.getType());
2081 lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
2082 } else {
2083 // NOTE: We're choosing 'signed' by default
2084 if (!lhsScaTy.isUnsigned())
2085 signX = 1;
2086 }
2087 auto lhsShape = lhsVecTy.getShape();
2088
2089 auto rhsVecTy = cast<VectorType>(rhs.getType());
2090 auto rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
2091 if (auto extSIOp = rhs.getDefiningOp<arith::ExtSIOp>()) {
2092 rhs = extSIOp.getIn();
2093 rhsVecTy = cast<VectorType>(rhs.getType());
2094 rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
2095 signY = 1;
2096 } else if (auto extUIOp = rhs.getDefiningOp<arith::ExtUIOp>()) {
2097 rhs = extUIOp.getIn();
2098 rhsVecTy = cast<VectorType>(rhs.getType());
2099 rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
2100 } else {
2101 // NOTE: We're choosing 'signed' by default
2102 if (!rhsScaTy.isUnsigned())
2103 signY = 1;
2104 }
2105
2106 unsigned lhsBitWidth = lhsScaTy.getWidth();
2107 unsigned rhsBitWidth = rhsScaTy.getWidth();
2108 auto accScaTy = cast<IntegerType>(accVecTy.getElementType());
2109 unsigned accBitWidth = accScaTy.getWidth();
2110 if (accBitWidth == 32) {
2111 if (lhsBitWidth == 8) {
2112 if (rhsBitWidth == 4) {
2113 // <4x16xi8> x <16x8xi4> + <4x8xi32>
2114 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
2115 aiev2_vmac_compute_control(
2116 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
2117 /*bmode=*/0,
2118 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2119 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2120 /*sub_mask=*/0)};
2121 } else {
2122 // <4x8xi8> x <8x8xi8> + <4x8xi32>
2123 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
2124 aiev2_vmac_compute_control(
2125 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
2126 /*bmode=*/1,
2127 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2128 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2129 /*sub_mask=*/0)};
2130 }
2131 } else {
2132 if (rhsBitWidth == 8) {
2133 // <4x4xi16> x <4x8xi8> + <4x8xi32>
2134 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
2135 aiev2_vmac_compute_control(
2136 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
2137 /*bmode=*/2,
2138 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2139 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2140 /*sub_mask=*/0)};
2141 } else {
2142 // <4x2xi16> x <2x8xi16> + <4x8xi32>
2143 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
2144 aiev2_vmac_compute_control(
2145 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
2146 /*bmode=*/3,
2147 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2148 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2149 /*sub_mask=*/0)};
2150 }
2151 }
2152 }
2153
2154 if (lhsBitWidth == 16) {
2155 if (rhsBitWidth == 8) {
2156 if (lhsShape == ArrayRef<int64_t>({2, 8})) {
2157 // <2x8xi16> x <8x8xi8> + <2x8xi64>
2158 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2159 aiev2_vmac_compute_control(
2160 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1,
2161 /*bmode=*/2,
2162 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2163 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2164 /*sub_mask=*/0)};
2165 }
2166 // <4x8xi16> x <8x4xi8> + <4x4xi64>
2167 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2168 aiev2_vmac_compute_control(
2169 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/2,
2170 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
2171 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2172 /*sub_mask=*/0)};
2173 }
2174 if (lhsShape == ArrayRef<int64_t>({2, 4})) {
2175 // <2x4xi16> x <4x8xi16> + <2x8xi64>
2176 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2177 aiev2_vmac_compute_control(
2178 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/3,
2179 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2180 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2181 /*sub_mask=*/0)};
2182 }
2183 // <4x4xi16> x <4x4xi16> + <4x4xi64>
2184 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2185 aiev2_vmac_compute_control(
2186 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/3,
2187 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
2188 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2189 /*sub_mask=*/0)};
2190 }
2191 // <4x2xi32> x <2x4xi16> + <4x4xi64>
2192 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
2193 aiev2_vmac_compute_control(
2194 /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/0,
2195 /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
2196 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
2197 /*sub_mask=*/0)};
2198 }
2199
2200 LogicalResult
2201 matchAndRewrite(aievec::MatMulOp op, OpAdaptor adaptor,
2202 ConversionPatternRewriter &rewriter) const override {
2203 auto decodedMatMulOp = decodeMatMulOp(adaptor);
2204
2205 Location loc = op.getLoc();
2206 // Flatten the inputs
2207 auto lhsFlattenedVecTy =
2208 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.lhs.getType()));
2209 decodedMatMulOp.lhs = rewriter.create<vector::ShapeCastOp>(
2210 loc, lhsFlattenedVecTy, decodedMatMulOp.lhs);
2211 auto rhsFlattenedVecTy =
2212 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.rhs.getType()));
2213 decodedMatMulOp.rhs = rewriter.create<vector::ShapeCastOp>(
2214 loc, rhsFlattenedVecTy, decodedMatMulOp.rhs);
2215 auto accFlattenedVecTy =
2216 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.acc.getType()));
2217 decodedMatMulOp.acc = rewriter.create<vector::ShapeCastOp>(
2218 loc, accFlattenedVecTy, decodedMatMulOp.acc);
2219
2220 Type i32ty = rewriter.getI32Type();
2221 auto confCst = rewriter.create<LLVM::ConstantOp>(
2222 loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf));
2223 SmallVector<Value> operands({decodedMatMulOp.lhs, decodedMatMulOp.rhs,
2224 decodedMatMulOp.acc, confCst});
2225 Value matMulResVal;
2226 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::BF16)
2227 matMulResVal =
2228 rewriter
2229 .create<xllvm::MacConfBF16IntrOp>(
2230 loc, VectorType::get({8}, rewriter.getI64Type()),
2231 forceCastOperandsToSignature(
2232 rewriter, loc, operands,
2233 {VectorType::get({32}, rewriter.getBF16Type()),
2234 VectorType::get({32}, rewriter.getBF16Type()),
2235 VectorType::get({8}, rewriter.getI64Type()), i32ty}))
2236 .getResult();
2237 else {
2238 SmallVector<Type> intrFuncSig(
2239 {VectorType::get({64}, rewriter.getI8Type()),
2240 VectorType::get({16}, i32ty),
2241 VectorType::get({16}, rewriter.getI64Type()), i32ty});
2242 VectorType v16xi64ty = VectorType::get({16}, rewriter.getI64Type());
2243 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I32)
2244 matMulResVal = rewriter
2245 .create<xllvm::MacConfAcc32IntrOp>(
2246 loc, v16xi64ty,
2247 forceCastOperandsToSignature(
2248 rewriter, loc, operands, intrFuncSig))
2249 .getResult();
2250 else
2251 matMulResVal = rewriter
2252 .create<xllvm::MacConfAcc64IntrOp>(
2253 loc, v16xi64ty,
2254 forceCastOperandsToSignature(
2255 rewriter, loc, operands, intrFuncSig))
2256 .getResult();
2257 }
2258
2259 auto castFromAcc =
2260 bitcastValueToType(rewriter, loc, matMulResVal, accFlattenedVecTy);
2261
2262 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
2263 castFromAcc);
2264
2265 return success();
2266 }
2267};
2268
2269// This pattern folds aievec.cast op. For AIE2, the accumulators are in 32/64
2270// bits, and the vectors are in 4/8/16/32 bits. Hence, we don't have to
2271// explicitly express the casting between accumulators and vectors at the LLVM
2272// dialect level. The backend LLVM compiler will decide the correct accumulator
2273// or vector registers given the ops and intrinsics.
2274class FoldAIECastOps : public mlir::ConvertOpToLLVMPattern<aievec::CastOp> {
2275 using ConvertOpToLLVMPattern<aievec::CastOp>::ConvertOpToLLVMPattern;
2276
2277 LogicalResult
2278 matchAndRewrite(aievec::CastOp castOp, OpAdaptor adaptor,
2279 ConversionPatternRewriter &rewriter) const override {
2280 rewriter.replaceOp(castOp, adaptor.getSource());
2281 return success();
2282 }
2283};
2284
2286 : public mlir::ConvertOpToLLVMPattern<aievec::ShuffleOp> {
2287 using ConvertOpToLLVMPattern<aievec::ShuffleOp>::ConvertOpToLLVMPattern;
2288
2289 LogicalResult
2290 matchAndRewrite(aievec::ShuffleOp shuffleOp, OpAdaptor adaptor,
2291 ConversionPatternRewriter &rewriter) const override {
2292 auto loc = shuffleOp.getLoc();
2293 auto lhs = adaptor.getLhs();
2294 auto rhs = adaptor.getRhs();
2295 auto i32ty = rewriter.getI32Type();
2296 auto v16xi32ty = VectorType::get({16}, i32ty);
2297 if (!rhs)
2298 rhs = rewriter.create<xllvm::UndefV16I32IntrOp>(loc, v16xi32ty);
2299
2300 auto modeAttrVal =
2301 rewriter
2302 .create<LLVM::ConstantOp>(loc, i32ty,
2303 static_cast<int32_t>(shuffleOp.getMode()))
2304 .getResult();
2305 auto vShuffleVal = rewriter
2306 .create<xllvm::VectorShuffleIntrOp>(
2307 loc, v16xi32ty,
2308 forceCastOperandsToSignature(
2309 rewriter, loc,
2310 /*operands=*/{lhs, rhs, modeAttrVal},
2311 /*signature=*/{v16xi32ty, v16xi32ty, i32ty}))
2312 .getResult();
2313
2314 vShuffleVal = forceCastValueToType(rewriter, loc, vShuffleVal,
2315 shuffleOp.getResult().getType());
2316
2317 rewriter.replaceOp(shuffleOp, vShuffleVal);
2318
2319 return success();
2320 }
2321};
2322
2324 mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
2325 Aie2Fp32Emulation aie2Fp32EmulationOption) {
2326 // clang-format off
2327 patterns.add<AddOpConversion,
2348 ShuffleOpConversion>(converter);
2349 patterns.add<MulElemOpConversion>(converter, aie2Fp32EmulationOption);
2350 // clang-format on
2351}
2352
2354 : ConvertAIEVecToLLVMBase<ConvertAIEVecToLLVMPass> {
2355 void runOnOperation() override {
2356 RewritePatternSet patterns(&getContext());
2357 LLVMTypeConverter converter(&getContext());
2358
2359 // Don't convert vector types, we want to handle multi-dimensional
2360 // vector on our own.
2361 converter.addConversion(
2362 [&](VectorType type) -> std::optional<Type> { return type; });
2363
2364 populateAIEVecToLLVMConversionPatterns(converter, patterns,
2365 aie2Fp32Emulation);
2366
2367 LLVMConversionTarget target(getContext());
2368 target.addIllegalDialect<xilinx::aievec::AIEVecDialect,
2369 xilinx::aievec::aie1::AIEVecAIE1Dialect>();
2370 target.addLegalDialect<arith::ArithDialect, vector::VectorDialect,
2371 xilinx::xllvm::XLLVMDialect>();
2372 if (failed(applyPartialConversion(getOperation(), target,
2373 std::move(patterns))))
2374 signalPassFailure();
2375 }
2376};
2377
2378std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
2380 return std::make_unique<ConvertAIEVecToLLVMPass>();
2381}
2382
2383} // namespace xilinx::aievec
LogicalResult matchAndRewrite(aievec::aie1::AddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult convertToEmulatedFP32MulElem(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult convertToEmulatedI32MulElem(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
MulElemOpConversion(const LLVMTypeConverter &typeConverter, Aie2Fp32Emulation aie2Fp32EmulationOption)
LogicalResult matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedMulElemOp decodeMulElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::aie1::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::PackOp op)
LogicalResult matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::aie1::SelectOp op)
LogicalResult matchAndRewrite(aievec::aie1::SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::SubOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::UPDOp op, int loadSize)
LogicalResult matchAndRewrite(aievec::UPDOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UnpackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
int32_t getVectorSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:66
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55
uint32_t encodeSquare(uint32_t square)
void encodeConf(uint32_t conf[2], const BufferParams &x, const BufferParams &z, bool sub)
void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns)
std::string getMulOrFMAIntrinsicName(Operation *op)
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createConvertAIEVecToLLVMPass()
std::string getVectorTypeString(VectorType type, bool abbrev=false, bool acc=false)