MLIR-AIE
AIEVecToLLVM.cpp
Go to the documentation of this file.
1//===- AIEVecToLLVM.cpp - AIEVec to LLVM dialect conversion ---------------===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2022 Xilinx Inc.
8// (c) Copyright 2024 Advanced Micro Devices Inc.
9//
10//===----------------------------------------------------------------------===//
11
12#include "../PassDetail.h"
13
20#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21#include "mlir/Conversion/LLVMCommon/Pattern.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/Dialect/Math/IR/Math.h"
24#include "mlir/Dialect/UB/IR/UBOps.h"
25#include "mlir/IR/TypeUtilities.h"
26#include <sstream>
27
28namespace xilinx {
29using namespace mlir; // For LLVM::LLVMDialect in generated getDependentDialects
30#define GEN_PASS_DEF_CONVERTAIEVECTOLLVM
31#include "aie/Conversion/Passes.h.inc"
32} // namespace xilinx
33
34using namespace mlir;
35
36namespace xilinx::aievec {
37
38inline static Value bitcastValueToType(OpBuilder &builder, Location loc,
39 Value val, Type dstTy) {
40 return LLVM::BitcastOp::create(builder, loc, dstTy, val).getResult();
41}
42
43// This function emits the instructions required to widen a 128b input vector
44// into a 512b encoded as a vector<16xi32>. It first bitcasts it to a
45// vector<4xi32> to respect the intrinsic signature.
46inline static Value widen128bVectorValueTo512b(OpBuilder &builder, Location loc,
47 Value val) {
48 return xllvm::VectorSetI512I128IntrOp::create(
49 builder, loc, VectorType::get({16}, builder.getI32Type()),
50 bitcastValueToType(builder, loc, val,
51 VectorType::get({4}, builder.getI32Type())))
52 .getResult();
53}
54
55// This function emits the instructions required to widen a 256b input vector
56// into a 512b encoded as a vector<16xi32>. It first bitcasts it to a
57// vector<8xi32> to respect the intrinsic signature. It will also materialize
58// a constant 0, used as an insertion index.
59inline static Value widen256bVectorValueTo512b(OpBuilder &builder, Location loc,
60 Value val) {
61 auto cst0 =
62 LLVM::ConstantOp::create(builder, loc, builder.getI32Type(), (int32_t)0);
63 return xllvm::VectorSetI512I256IntrOp::create(
64 builder, loc, VectorType::get({16}, builder.getI32Type()),
65 bitcastValueToType(builder, loc, val,
66 VectorType::get({8}, builder.getI32Type())),
67 cst0)
68 .getResult();
69}
70
71// This function emits the sequence of operations that forces a value into a
72// specific type. This may include widening vectors to match a specific bit
73// length.
74static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
75 Type type) {
76 auto valTy = val.getType();
77 if (valTy == type)
78 return val;
79 auto srcVecTy = dyn_cast<VectorType>(valTy);
80 auto dstVecTy = dyn_cast<VectorType>(type);
81
82 if (srcVecTy) {
83 assert(dstVecTy && "vector values cannot be forced into a non-vector type");
84
85 // Flatten source vector if it's not rank-1
86 auto flatSrcVecTy = getFlattenedVectorType(srcVecTy);
87 if (srcVecTy != flatSrcVecTy)
88 val = vector::ShapeCastOp::create(builder, loc, flatSrcVecTy, val);
89
90 // Flatten destination type if it's not rank-1
91 auto flatDstVecTy = getFlattenedVectorType(dstVecTy);
92
93 int64_t dstVecLength =
94 flatDstVecTy.getElementTypeBitWidth() * flatDstVecTy.getShape()[0];
95 int64_t srcVecLength =
96 flatSrcVecTy.getElementTypeBitWidth() * flatSrcVecTy.getShape()[0];
97 if (srcVecLength != dstVecLength) {
98 assert(srcVecLength < dstVecLength &&
99 "only widening forced casts are supported");
100 assert(dstVecLength == 512 &&
101 (srcVecLength == 128 || srcVecLength == 256) &&
102 "only 128b to 512b and 256b to 512b forced casts are supported");
103 if (srcVecLength == 128)
104 val = widen128bVectorValueTo512b(builder, loc, val);
105 else
106 val = widen256bVectorValueTo512b(builder, loc, val);
107 }
108
109 // Bitcast to flat destination type (bitcast only supports flat vectors)
110 val = bitcastValueToType(builder, loc, val, flatDstVecTy);
111
112 // Reshape back to original destination shape if needed
113 if (flatDstVecTy != dstVecTy)
114 val = vector::ShapeCastOp::create(builder, loc, dstVecTy, val);
115
116 return val;
117 }
118
119 // Non-vector types can be bitcast directly
120 assert(!dstVecTy && "cannot force cast scalar to vector type");
121 return bitcastValueToType(builder, loc, val, type);
122}
123
124// This function emits the sequence of operations that forces a range of values
125// to match the signature specified by the TypeRange. It can be used to convert
126// the parameters of an op being converted to the types accepted by an
127// intrinsic with a fixed signature that treats its inputs as "bags of bits".
128static SmallVector<Value> forceCastOperandsToSignature(OpBuilder &builder,
129 Location loc,
130 ValueRange operands,
131 TypeRange signature) {
132 return llvm::to_vector(llvm::map_range(
133 llvm::zip_equal(operands, signature), [&](auto &&vt) -> Value {
134 return forceCastValueToType(builder, loc, std::get<0>(vt),
135 std::get<1>(vt));
136 }));
137}
138
139// Utility function to get or create a noinline scalar helper function.
140// This is used to create optimization barriers that prevent LLVM from
141// re-vectorizing unrolled scalar operations.
142//
143// Parameters:
144// - module: The parent module to insert the function into
145// - rewriter: The pattern rewriter
146// - opName: Base name of the operation (e.g., "fdiv", "addf", "mulf")
147// - device: Target device ("aie2", "aie2p", etc.)
148// - argTypes: Input argument types
149// - resultType: Return type
150// - bodyBuilder: Lambda that builds the function body given (builder, loc,
151// args)
152//
153// Returns: The helper function (created or existing)
154//
155// Function naming convention: __<device>_scalar_<opName>
156// Example: __aie2p_scalar_fdiv, __aie2_scalar_addf
157static LLVM::LLVMFuncOp getOrCreateScalarHelperFunc(
158 ModuleOp module, OpBuilder &rewriter, StringRef opName, StringRef device,
159 TypeRange argTypes, Type resultType,
160 std::function<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
161
162 // Build function name: __<device>_scalar_<opName>
163 std::string funcName = "__" + device.str() + "_scalar_" + opName.str();
164
165 // Check if function already exists
166 auto helperFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(funcName);
167 if (helperFunc)
168 return helperFunc;
169
170 // Create new function
171 OpBuilder::InsertionGuard guard(rewriter);
172 rewriter.setInsertionPointToStart(module.getBody());
173
174 // Convert TypeRange to SmallVector<Type> for LLVM::LLVMFunctionType::get
175 SmallVector<Type> argTypesVec(argTypes.begin(), argTypes.end());
176
177 helperFunc = LLVM::LLVMFuncOp::create(
178 rewriter, rewriter.getUnknownLoc(), funcName,
179 LLVM::LLVMFunctionType::get(resultType, argTypesVec));
180
181 // Mark as noinline to act as optimization barrier
182 helperFunc->setAttr("passthrough", rewriter.getArrayAttr(
183 {rewriter.getStringAttr("noinline")}));
184
185 // Add function body
186 auto *entryBlock = helperFunc.addEntryBlock(rewriter);
187 OpBuilder::InsertionGuard bodyGuard(rewriter);
188 rewriter.setInsertionPointToStart(entryBlock);
189
190 // Collect function arguments
191 SmallVector<Value> args;
192 for (unsigned i = 0; i < argTypes.size(); ++i)
193 args.push_back(entryBlock->getArgument(i));
194
195 // Call the body builder with the function arguments
196 bodyBuilder(rewriter, rewriter.getUnknownLoc(), args);
197
198 return helperFunc;
199}
200
// Buffer access parameters parsed from an aie1 mul/mac op's string attributes
// (xstart/xoffsets/xoffsets_hi/xstep/xsquare and the z-equivalents) and packed
// into the configuration register by encodeConf() below.
// NOTE(review): the struct-declaration line (original line 201) is missing
// from this extraction; these fields belong to the `BufferParams` struct used
// by encodeConf() and the mul/fma conversions — confirm against full source.
 202 uint32_t start;
 203 uint32_t offsets;
 204 uint32_t offsets_hi;
 205 uint32_t step;
 206 uint32_t square;
 207};
208
// Packs the AIE2 VMAC control fields into a single configuration word.
//
// Field meanings:
//   sgn_x:    if 1, matrix X is interpreted as signed, else unsigned.
//   sgn_y:    if 1, matrix Y is interpreted as signed, else unsigned.
//   amode/bmode/variant: configure acc width, mul precision, and mul mode.
//   zero_acc: if 1, acc1 is zeroed.
//   shift16:  if set, the <<16 operation is executed on acc1.
//   sub_mul:  if 1, the matrix multiplication result is negated.
//   sub_acc1: if 1, acc1 is negated.
//   sub_acc2: if 1, acc2 is negated.
//   sub_mask: negation mask for complex multiplications; negates a term of a
//             complex multiplication.
//
// Bit layout (LSB first): zero_acc@0, amode@1, bmode@3, variant@5, sgn_y@8,
// sgn_x@9, shift16@10, sub_mul@11, sub_acc1@12, sub_acc2@13, sub_mask@16.
static inline int aiev2_vmac_compute_control(int sgn_x, int sgn_y, int amode,
                                             int bmode, int variant,
                                             int zero_acc, int shift16,
                                             int sub_mul, int sub_acc1,
                                             int sub_acc2, int sub_mask) {
  unsigned control = (unsigned)zero_acc << 0;
  control |= (unsigned)amode << 1;
  control |= (unsigned)bmode << 3;
  control |= (unsigned)variant << 5;
  control |= ((unsigned)sgn_y << 8) | ((unsigned)sgn_x << 9);
  control |= (unsigned)shift16 << 10;
  control |= (unsigned)sub_mul << 11;
  control |= ((unsigned)sub_acc1 << 12) | ((unsigned)sub_acc2 << 13);
  control |= (unsigned)sub_mask << 16;
  return control;
}
235
236std::string getVectorTypeString(VectorType type, bool abbrev = false,
237 bool acc = false) {
238 std::stringstream ss;
239 auto size = getVectorLaneSize(type);
240 ss << "v" << size;
241 if (auto intType = dyn_cast<IntegerType>(type.getElementType())) {
242 ss << (acc ? "acc" : abbrev ? "i" : "int") << intType.getWidth();
243 } else if (dyn_cast<FloatType>(type.getElementType())) {
244 ss << (abbrev ? "f" : "float");
245 }
246 return ss.str();
247}
248
249std::string getMulOrFMAIntrinsicName(Operation *op) {
250 std::string baseName;
251 Value lhs, result;
252 if (auto mulOp = dyn_cast<aievec::aie1::MulOp>(op)) {
253 baseName = "mul";
254 lhs = mulOp.getLhs();
255 result = mulOp.getResult();
256 } else if (auto fmaOp = dyn_cast<aievec::aie1::FMAOp>(op)) {
257 baseName = "mac";
258 lhs = fmaOp.getLhs();
259 result = fmaOp.getResult();
260 }
261 VectorType resultType = cast<VectorType>(result.getType());
262 int resultSize = getVectorLaneSize(resultType);
263 std::stringstream ss;
264 ss << "llvm.aie.";
265 if (dyn_cast<IntegerType>(resultType.getElementType())) {
266 ss << baseName;
267 ss << resultSize << "."
268 << getVectorTypeString(cast<VectorType>(lhs.getType()));
269 } else if (dyn_cast<FloatType>(resultType.getElementType())) {
270 ss << "vfp" << baseName;
271 }
272 return ss.str();
273}
274
275// Squashes the easy-to-read 16-bit square encoding into
276// the 8-bit encoding the configuration register uses
277uint32_t encodeSquare(uint32_t square) {
278 uint32_t out = 0;
279 out |= ((square >> 0) & 0x3) << 0;
280 out |= ((square >> 4) & 0x3) << 2;
281 out |= ((square >> 8) & 0x3) << 4;
282 out |= ((square >> 12) & 0x3) << 6;
283 return out & 0xFF;
284}
285
286// Encode the configuration register with buffer parameters and options
287// TODO: struct to handle this?
288void encodeConf(uint32_t conf[2], const BufferParams &x, const BufferParams &z,
289 bool sub) {
290 conf[0] |= ((x.step & 0x3F) << 0) | ((z.step & 0x3F) << 8);
291 conf[1] |= (encodeSquare(x.square) << 0) | (encodeSquare(z.square) << 8);
292 conf[1] |= sub << 17;
293}
294
// Conversion pattern for the AIE1 `aievec.aie1.add` op. Lowering is not
// implemented: matchAndRewrite only emits a warning and returns failure, so
// the op is left untouched by the conversion.
// NOTE(review): the class-declaration line (original line 295) is missing
// from this extraction — confirm the pattern's name against the full source.
 296    : public mlir::ConvertOpToLLVMPattern<aievec::aie1::AddOp> {
 297public:
 298  using ConvertOpToLLVMPattern<aievec::aie1::AddOp>::ConvertOpToLLVMPattern;
 299
 300  LogicalResult
 301  matchAndRewrite(aievec::aie1::AddOp op, OpAdaptor adaptor,
 302                  ConversionPatternRewriter &rewriter) const override {
 303    op.emitWarning() << "aie.add conversion is not implemented\n";
 304    return failure();
 305  }
 306};
307
 308// AIE2 version of AddElemOp conversion
// Lowers aievec.add_elem to the AIE2 accumulator-add intrinsic. Only the FP32
// case is handled here; decodeAddElemOp classifies the operand element type
// and anything it does not recognize is rejected with a warning + failure.
// NOTE(review): this extraction dropped several original lines (the class
// declaration at 309, the DecodedAddElemOp struct/enum at 314-318, and the
// return statements at 328/333/336) — confirm against the full source.
 310    : public mlir::ConvertOpToLLVMPattern<aievec::AddElemOp> {
 311public:
 312  using ConvertOpToLLVMPattern<aievec::AddElemOp>::ConvertOpToLLVMPattern;
 313
 319
// Classifies the op by its lhs element type / bit width. The returned kind
// drives which intrinsic (if any) matchAndRewrite emits.
 320  static DecodedAddElemOp decodeAddElemOp(OpAdaptor op) {
 321    auto lhs = op.getLhs();
 322    auto lhsVecTy = cast<VectorType>(lhs.getType());
 323    auto lhsScaTy = lhsVecTy.getElementType();
 324    unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
 325
 326    // Integer types
 327    if (llvm::isa<IntegerType>(lhsScaTy)) {
 329    } else {
 330      // Float types
 331      if (lhsBitWidth == 32) {
 332        // FP32 add_elem
 334      }
 335    }
 337  }
 338
 339  LogicalResult
 340  matchAndRewrite(aievec::AddElemOp op, OpAdaptor adaptor,
 341                  ConversionPatternRewriter &rewriter) const override {
 342    Location loc = op.getLoc();
 343    auto decodedAddElemOp = decodeAddElemOp(adaptor);
 344
 345    if (decodedAddElemOp.kind == DecodedAddElemOp::Kind::UNSUPPORTED) {
 346      op.emitWarning() << "aievec.add_elem conversion is not supported.\n";
 347      return failure();
 348    }
 349
 350    // Handle the FP32 add_elem for AIE2 - uses packed I64 representation
// NOTE(review): the enum constant compared against on the next line was
// dropped by the extraction (original line 352).
 351    if (decodedAddElemOp.kind ==
 353      auto confCst = LLVM::ConstantOp::create(
 354          rewriter, loc, rewriter.getI32Type(),
 355          rewriter.getI32IntegerAttr(decodedAddElemOp.conf));
 356      SmallVector<Value> operands(
 357          {adaptor.getLhs(), adaptor.getRhs(), confCst});
 358
// The intrinsic takes two <8 x i64> accumulators plus an i32 conf word;
// forceCastOperandsToSignature bitcasts/widen the operands to match.
 359      auto addElemOp = xllvm::AddAccFloatAIE2IntrOp::create(
 360          rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
 361          forceCastOperandsToSignature(
 362              rewriter, loc, operands,
 363              {VectorType::get({8}, rewriter.getI64Type()),
 364               VectorType::get({8}, rewriter.getI64Type()),
 365               rewriter.getI32Type()}));
 366
 367      // create bitcast/shape_cast for result
 368      auto resultVal = forceCastValueToType(rewriter, loc, addElemOp,
 369                                            op.getResult().getType());
 370      rewriter.replaceOp(op, resultVal);
 371      return success();
 372    }
 373
 374    op.emitWarning() << "aievec.add_elem conversion is not supported.\n";
 375    return failure();
 376  }
 377};
378
 379// AIE2 version of SubElemOp conversion
// Lowers aievec.sub_elem to the AIE2 accumulator-subtract intrinsic. Mirrors
// the AddElemOp conversion above: only the FP32 case is handled, everything
// else is rejected with a warning + failure.
// NOTE(review): this extraction dropped several original lines (the class
// declaration at 380, the DecodedSubElemOp struct/enum at 385-389, and the
// return statements at 399/404/407) — confirm against the full source.
 381    : public mlir::ConvertOpToLLVMPattern<aievec::SubElemOp> {
 382public:
 383  using ConvertOpToLLVMPattern<aievec::SubElemOp>::ConvertOpToLLVMPattern;
 384
 390
// Classifies the op by its lhs element type / bit width.
 391  static DecodedSubElemOp decodeSubElemOp(OpAdaptor op) {
 392    auto lhs = op.getLhs();
 393    auto lhsVecTy = cast<VectorType>(lhs.getType());
 394    auto lhsScaTy = lhsVecTy.getElementType();
 395    unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
 396
 397    // Integer types
 398    if (llvm::isa<IntegerType>(lhsScaTy)) {
 400    } else {
 401      // Float types
 402      if (lhsBitWidth == 32) {
 403        // FP32 sub_elem
 405      }
 406    }
 408  }
 409
 410  LogicalResult
 411  matchAndRewrite(aievec::SubElemOp op, OpAdaptor adaptor,
 412                  ConversionPatternRewriter &rewriter) const override {
 413    Location loc = op.getLoc();
 414    auto decodedSubElemOp = decodeSubElemOp(adaptor);
 415
 416    if (decodedSubElemOp.kind == DecodedSubElemOp::Kind::UNSUPPORTED) {
 417      op.emitWarning() << "aievec.sub_elem conversion is not supported.\n";
 418      return failure();
 419    }
 420
 421    // Handle the FP32 sub_elem for AIE2 - uses packed I64 representation
// NOTE(review): the enum constant compared against on the next line was
// dropped by the extraction (original line 423).
 422    if (decodedSubElemOp.kind ==
 424      auto confCst = LLVM::ConstantOp::create(
 425          rewriter, loc, rewriter.getI32Type(),
 426          rewriter.getI32IntegerAttr(decodedSubElemOp.conf));
 427      SmallVector<Value> operands(
 428          {adaptor.getLhs(), adaptor.getRhs(), confCst});
 429
// The intrinsic takes two <8 x i64> accumulators plus an i32 conf word.
 430      auto subElemOp = xllvm::SubAccFloatAIE2IntrOp::create(
 431          rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
 432          forceCastOperandsToSignature(
 433              rewriter, loc, operands,
 434              {VectorType::get({8}, rewriter.getI64Type()),
 435               VectorType::get({8}, rewriter.getI64Type()),
 436               rewriter.getI32Type()}));
 437
 438      // create bitcast/shape_cast for result
 439      auto resultVal = forceCastValueToType(rewriter, loc, subElemOp,
 440                                            op.getResult().getType());
 441      rewriter.replaceOp(op, resultVal);
 442      return success();
 443    }
 444
 445    op.emitWarning() << "aievec.sub_elem conversion is not supported.\n";
 446    return failure();
 447  }
 448};
449
 450// AIE2p version of AddElemOp conversion
// Lowers aievec.add_elem for AIE2p. FP32 inputs at 16 or 32 lanes are widened
// to the 64-lane ACC2048 accumulator intrinsic and the relevant lanes are
// extracted back afterwards; other cases fail with a warning.
// NOTE(review): this extraction dropped several original lines (the class
// declaration at 451, the DecodedAddElemOp struct/enum at 456-464, and the
// return statements at 475/481/483/487) — confirm against the full source.
 452    : public mlir::ConvertOpToLLVMPattern<aievec::AddElemOp> {
 453public:
 454  using ConvertOpToLLVMPattern<aievec::AddElemOp>::ConvertOpToLLVMPattern;
 455
 465
// Classifies the op by its lhs element type, bit width and lane count.
 466  static DecodedAddElemOp decodeAddElemOp(OpAdaptor op) {
 467    auto lhs = op.getLhs();
 468    auto lhsVecTy = cast<VectorType>(lhs.getType());
 469    auto lhsScaTy = lhsVecTy.getElementType();
 470    unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
 471    int laneSize = getVectorLaneSize(lhsVecTy);
 472
 473    // Integer types
 474    if (llvm::isa<IntegerType>(lhsScaTy)) {
 476    } else {
 477      // Float types
 478      if (lhsBitWidth == 32) {
 479        // FP32 add_elem
 480        if (laneSize == 16) {
 482        } else if (laneSize == 32) {
 484        }
 485      }
 486    }
 488  }
 489
 490  LogicalResult
 491  matchAndRewrite(aievec::AddElemOp op, OpAdaptor adaptor,
 492                  ConversionPatternRewriter &rewriter) const override {
 493    Location loc = op.getLoc();
 494    auto decodedAddElemOp = decodeAddElemOp(adaptor);
 495
 496    if (decodedAddElemOp.kind == DecodedAddElemOp::Kind::UNSUPPORTED) {
 497      op.emitWarning() << "aievec.add_elem conversion is not supported.\n";
 498      return failure();
 499    }
 500
 501    // Handle the FP32 add_elem for AIE2p (16-lane)
 502    // We need to expand <16xf32> to <64xf32> for the ACC2048 intrinsic
// NOTE(review): the enum constant compared against on the next line was
// dropped by the extraction (original line 504).
 503    if (decodedAddElemOp.kind ==
 505      // Step 1: Bitcast <16 x float> to <8 x i64>
 506      auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type());
 507      auto lhsI64 =
 508          LLVM::BitcastOp::create(rewriter, loc, v8i64Ty, adaptor.getLhs());
 509      auto rhsI64 =
 510          LLVM::BitcastOp::create(rewriter, loc, v8i64Ty, adaptor.getRhs());
 511
 512      // Step 2: Shuffle <8 x i64> to <32 x i64> (expand with poison values)
 513      auto v32i64Ty = VectorType::get({32}, rewriter.getI64Type());
 514      SmallVector<int64_t> expandMask = {0, 1, 2, 3, 4, 5, 6, 7};
 515      for (int i = 8; i < 32; ++i)
 516        expandMask.push_back(-1); // poison values
 517
 518      auto lhsExpanded =
 519          vector::ShuffleOp::create(rewriter, loc, lhsI64, lhsI64, expandMask);
 520      auto rhsExpanded =
 521          vector::ShuffleOp::create(rewriter, loc, rhsI64, rhsI64, expandMask);
 522
 523      // Step 3: Bitcast <32 x i64> to <64 x float>
 524      auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
 525      auto lhsF32 =
 526          LLVM::BitcastOp::create(rewriter, loc, v64f32Ty, lhsExpanded);
 527      auto rhsF32 =
 528          LLVM::BitcastOp::create(rewriter, loc, v64f32Ty, rhsExpanded);
 529
 530      // Step 4: Call the ACC2048 intrinsic with conf=60
 531      auto confCst = LLVM::ConstantOp::create(
 532          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(60));
 533
 534      // Create the intrinsic call
 535      auto addResult = xllvm::AddACC2048AccFloatAIE2pIntrOp::create(
 536          rewriter, loc, v64f32Ty, lhsF32, rhsF32, confCst);
 537
 538      // Step 5: Bitcast <64 x float> back to <32 x i64>
 539      auto resultI64 =
 540          LLVM::BitcastOp::create(rewriter, loc, v32i64Ty, addResult);
 541
 542      // Step 6: Shuffle to extract first 8 elements <32 x i64> -> <8 x i64>
 543      SmallVector<int64_t> extractMask = {0, 1, 2, 3, 4, 5, 6, 7};
 544      auto resultExtracted = vector::ShuffleOp::create(rewriter, loc, resultI64,
 545                                                       resultI64, extractMask);
 546
 547      // Step 7: Bitcast <8 x i64> back to <16 x float>
 548      auto v16f32Ty = VectorType::get({16}, rewriter.getF32Type());
 549      auto finalResult =
 550          LLVM::BitcastOp::create(rewriter, loc, v16f32Ty, resultExtracted);
 551
 552      rewriter.replaceOp(op, finalResult);
 553      return success();
 554    }
 555
 556    // Handle the FP32 add_elem for AIE2p (32-lane)
 557    // Use ACC2048 intrinsic by padding to 64 lanes
// NOTE(review): the enum constant compared against on the next line was
// dropped by the extraction (original line 559).
 558    if (decodedAddElemOp.kind ==
 560      // Pad from <32 x float> to <64 x float> using shuffle
 561      SmallVector<int64_t> padMask;
 562      for (int i = 0; i < 32; ++i)
 563        padMask.push_back(i);
 564      for (int i = 32; i < 64; ++i)
 565        padMask.push_back(-1); // poison/undef
 566
 567      auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
 568      auto lhsPadded = vector::ShuffleOp::create(
 569          rewriter, loc, adaptor.getLhs(), adaptor.getLhs(), padMask);
 570      auto rhsPadded = vector::ShuffleOp::create(
 571          rewriter, loc, adaptor.getRhs(), adaptor.getRhs(), padMask);
 572
 573      // Call ACC2048 intrinsic
 574      auto confCst = LLVM::ConstantOp::create(
 575          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(60));
 576      auto addResult = xllvm::AddACC2048AccFloatAIE2pIntrOp::create(
 577          rewriter, loc, v64f32Ty, lhsPadded, rhsPadded, confCst);
 578
 579      // Extract first 32 elements from 64-element result
 580      SmallVector<int64_t> extractMask;
 581      for (int i = 0; i < 32; ++i)
 582        extractMask.push_back(i);
 583      auto finalResult = vector::ShuffleOp::create(rewriter, loc, addResult,
 584                                                   addResult, extractMask);
 585
 586      rewriter.replaceOp(op, finalResult);
 587      return success();
 588    }
 589
 590    op.emitWarning() << "aievec.add_elem conversion is not supported.\n";
 591    return failure();
 592  }
 593};
594
 595// AIE2p version of SubElemOp conversion
// Lowers aievec.sub_elem for AIE2p. Structurally identical to the AIE2p
// AddElemOp conversion above, but emits the ACC2048 subtract intrinsic.
// NOTE(review): this extraction dropped several original lines (the class
// declaration at 596, the DecodedSubElemOp struct/enum at 601-609, and the
// return statements at 620/626/628/632) — confirm against the full source.
 597    : public mlir::ConvertOpToLLVMPattern<aievec::SubElemOp> {
 598public:
 599  using ConvertOpToLLVMPattern<aievec::SubElemOp>::ConvertOpToLLVMPattern;
 600
 610
// Classifies the op by its lhs element type, bit width and lane count.
 611  static DecodedSubElemOp decodeSubElemOp(OpAdaptor op) {
 612    auto lhs = op.getLhs();
 613    auto lhsVecTy = cast<VectorType>(lhs.getType());
 614    auto lhsScaTy = lhsVecTy.getElementType();
 615    unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
 616    int laneSize = getVectorLaneSize(lhsVecTy);
 617
 618    // Integer types
 619    if (llvm::isa<IntegerType>(lhsScaTy)) {
 621    } else {
 622      // Float types
 623      if (lhsBitWidth == 32) {
 624        // FP32 sub_elem
 625        if (laneSize == 16) {
 627        } else if (laneSize == 32) {
 629        }
 630      }
 631    }
 633  }
 634
 635  LogicalResult
 636  matchAndRewrite(aievec::SubElemOp op, OpAdaptor adaptor,
 637                  ConversionPatternRewriter &rewriter) const override {
 638    Location loc = op.getLoc();
 639    auto decodedSubElemOp = decodeSubElemOp(adaptor);
 640
 641    if (decodedSubElemOp.kind == DecodedSubElemOp::Kind::UNSUPPORTED) {
 642      op.emitWarning() << "aievec.sub_elem conversion is not supported.\n";
 643      return failure();
 644    }
 645
 646    // Handle the FP32 sub_elem for AIE2p (16-lane)
 647    // We need to expand <16xf32> to <64xf32> for the ACC2048 intrinsic
// NOTE(review): the enum constant compared against on the next line was
// dropped by the extraction (original line 649).
 648    if (decodedSubElemOp.kind ==
 650      // Step 1: Bitcast <16 x float> to <8 x i64>
 651      auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type());
 652      auto lhsI64 =
 653          LLVM::BitcastOp::create(rewriter, loc, v8i64Ty, adaptor.getLhs());
 654      auto rhsI64 =
 655          LLVM::BitcastOp::create(rewriter, loc, v8i64Ty, adaptor.getRhs());
 656
 657      // Step 2: Shuffle <8 x i64> to <32 x i64> (expand with poison values)
 658      auto v32i64Ty = VectorType::get({32}, rewriter.getI64Type());
 659      SmallVector<int64_t> expandMask = {0, 1, 2, 3, 4, 5, 6, 7};
 660      for (int i = 8; i < 32; ++i)
 661        expandMask.push_back(-1); // poison values
 662
 663      auto lhsExpanded =
 664          vector::ShuffleOp::create(rewriter, loc, lhsI64, lhsI64, expandMask);
 665      auto rhsExpanded =
 666          vector::ShuffleOp::create(rewriter, loc, rhsI64, rhsI64, expandMask);
 667
 668      // Step 3: Bitcast <32 x i64> to <64 x float>
 669      auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
 670      auto lhsF32 =
 671          LLVM::BitcastOp::create(rewriter, loc, v64f32Ty, lhsExpanded);
 672      auto rhsF32 =
 673          LLVM::BitcastOp::create(rewriter, loc, v64f32Ty, rhsExpanded);
 674
 675      // Step 4: Call the ACC2048 intrinsic with conf=60
 676      auto confCst = LLVM::ConstantOp::create(
 677          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(60));
 678
 679      // Create the intrinsic call
 680      auto subResult = xllvm::SubACC2048AccFloatAIE2pIntrOp::create(
 681          rewriter, loc, v64f32Ty, lhsF32, rhsF32, confCst);
 682
 683      // Step 5: Bitcast <64 x float> back to <32 x i64>
 684      auto resultI64 =
 685          LLVM::BitcastOp::create(rewriter, loc, v32i64Ty, subResult);
 686
 687      // Step 6: Shuffle to extract first 8 elements <32 x i64> -> <8 x i64>
 688      SmallVector<int64_t> extractMask = {0, 1, 2, 3, 4, 5, 6, 7};
 689      auto resultExtracted = vector::ShuffleOp::create(rewriter, loc, resultI64,
 690                                                       resultI64, extractMask);
 691
 692      // Step 7: Bitcast <8 x i64> back to <16 x float>
 693      auto v16f32Ty = VectorType::get({16}, rewriter.getF32Type());
 694      auto finalResult =
 695          LLVM::BitcastOp::create(rewriter, loc, v16f32Ty, resultExtracted);
 696
 697      rewriter.replaceOp(op, finalResult);
 698      return success();
 699    }
 700
 701    // Handle the FP32 sub_elem for AIE2p (32-lane)
 702    // Use ACC2048 intrinsic by padding to 64 lanes
// NOTE(review): the enum constant compared against on the next line was
// dropped by the extraction (original line 704).
 703    if (decodedSubElemOp.kind ==
 705      // Pad from <32 x float> to <64 x float> using shuffle
 706      SmallVector<int64_t> padMask;
 707      for (int i = 0; i < 32; ++i)
 708        padMask.push_back(i);
 709      for (int i = 32; i < 64; ++i)
 710        padMask.push_back(-1); // poison/undef
 711
 712      auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
 713      auto lhsPadded = vector::ShuffleOp::create(
 714          rewriter, loc, adaptor.getLhs(), adaptor.getLhs(), padMask);
 715      auto rhsPadded = vector::ShuffleOp::create(
 716          rewriter, loc, adaptor.getRhs(), adaptor.getRhs(), padMask);
 717
 718      // Call ACC2048 intrinsic
 719      auto confCst = LLVM::ConstantOp::create(
 720          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(60));
 721      auto subResult = xllvm::SubACC2048AccFloatAIE2pIntrOp::create(
 722          rewriter, loc, v64f32Ty, lhsPadded, rhsPadded, confCst);
 723
 724      // Extract first 32 elements from 64-element result
 725      SmallVector<int64_t> extractMask;
 726      for (int i = 0; i < 32; ++i)
 727        extractMask.push_back(i);
 728      auto finalResult = vector::ShuffleOp::create(rewriter, loc, subResult,
 729                                                   subResult, extractMask);
 730
 731      rewriter.replaceOp(op, finalResult);
 732      return success();
 733    }
 734
 735    op.emitWarning() << "aievec.sub_elem conversion is not supported.\n";
 736    return failure();
 737  }
 738};
739
// Conversion pattern for the AIE1 `aievec.aie1.sub` op. Lowering is not
// implemented: matchAndRewrite only emits a warning and returns failure, so
// the op is left untouched by the conversion.
// NOTE(review): the class-declaration line (original line 740) is missing
// from this extraction — confirm the pattern's name against the full source.
 741    : public mlir::ConvertOpToLLVMPattern<aievec::aie1::SubOp> {
 742public:
 743  using ConvertOpToLLVMPattern<aievec::aie1::SubOp>::ConvertOpToLLVMPattern;
 744
 745  LogicalResult
 746  matchAndRewrite(aievec::aie1::SubOp op, OpAdaptor adaptor,
 747                  ConversionPatternRewriter &rewriter) const override {
 748    op.emitWarning() << "aie.sub conversion is not implemented\n";
 749    return failure();
 750  }
 751};
752
// Conversion pattern for the AIE1 `aievec.aie1.mac` (FMA) op. Declares (once)
// the matching "llvm.aie.*" intrinsic as an LLVM function, parses the op's
// string attributes into BufferParams, packs them into the two-word
// configuration register, and replaces the op with a call to the intrinsic.
// NOTE(review): the class-declaration line (original line 753) is missing
// from this extraction — confirm the pattern's name against the full source.
 754    : public mlir::ConvertOpToLLVMPattern<aievec::aie1::FMAOp> {
 755public:
 756  using ConvertOpToLLVMPattern<aievec::aie1::FMAOp>::ConvertOpToLLVMPattern;
 757
 758  LogicalResult
 759  matchAndRewrite(aievec::aie1::FMAOp op, OpAdaptor adaptor,
 760                  ConversionPatternRewriter &rewriter) const override {
 761    auto module = op->getParentOfType<ModuleOp>();
 762    MLIRContext *context = rewriter.getContext();
 763
// Scalar i32 for the start indices; 2 x i32 vectors for offsets and conf.
 764    auto startType = IntegerType::get(context, 32);
 765    auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
 766    auto confType = VectorType::get({2}, IntegerType::get(context, 32));
 767
 768    // If the intrinsic declaration doesn't exist, create it
 769    std::string intrinsicName = getMulOrFMAIntrinsicName(op);
 770    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
 771        StringAttr::get(context, intrinsicName));
 772
 773    if (!func) {
 774      OpBuilder::InsertionGuard guard(rewriter);
 775      rewriter.setInsertionPointToStart(module.getBody());
 776      func = LLVM::LLVMFuncOp::create(
 777          rewriter, rewriter.getUnknownLoc(), intrinsicName,
 778          LLVM::LLVMFunctionType::get(
 779              op.getResult().getType(),
 780              {op.getLhs().getType(), op.getRhs().getType(),
 781               op.getAcc().getType(), startType, /* xstart */
 782               startType,                        /* ystart */
 783               startType,                        /* zstart */
 784               offsetsType,                      /* xoffsets */
 785               offsetsType,                      /* zoffsets */
 786               confType}));
 787    }
 788
 789    // Parse the string attribute values
// getAsInteger returns true on parse failure; failures are ignored here and
// leave the corresponding field zero-initialized.
 790    BufferParams x = {};
 791    BufferParams z = {};
 792    op.getXstart().getAsInteger(0, x.start);
 793    op.getXoffsets().getAsInteger(0, x.offsets);
 794    op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
 795    op.getXstep().getAsInteger(0, x.step);
 796    op.getXsquare().getAsInteger(0, x.square);
 797    op.getZstart().getAsInteger(0, z.start);
 798    op.getZoffsets().getAsInteger(0, z.offsets);
 799    op.getZoffsetsHi().getAsInteger(0, z.offsets_hi);
 800    op.getZstep().getAsInteger(0, z.step);
 801    op.getZsquare().getAsInteger(0, z.square);
 802
 803    // Encode the configuration register
// The fmsub flag selects multiply-subtract (conf word 1, bit 17).
 804    uint32_t conf[2] = {0, 0};
 805    encodeConf(conf, x, z, op.getFmsub());
 806
 807    // Create the constants and replace the op
 808    auto xstartVal = LLVM::ConstantOp::create(
 809        rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
 810    auto ystartVal = LLVM::ConstantOp::create(rewriter, op->getLoc(), startType,
 811                                              rewriter.getI32IntegerAttr(0));
 812    auto zstartVal = LLVM::ConstantOp::create(
 813        rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(z.start));
 814    auto xoffsetsVal = LLVM::ConstantOp::create(
 815        rewriter, op->getLoc(), offsetsType,
 816        rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
 817    auto zoffsetsVal = LLVM::ConstantOp::create(
 818        rewriter, op->getLoc(), offsetsType,
 819        rewriter.getI32VectorAttr({(int32_t)z.offsets, (int32_t)z.offsets_hi}));
 820    auto confVal = LLVM::ConstantOp::create(
 821        rewriter, op->getLoc(), confType,
 822        rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
 823    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
 824        op, func,
 825        ValueRange{op.getLhs(), op.getRhs(), op.getAcc(), xstartVal, ystartVal,
 826                   zstartVal, xoffsetsVal, zoffsetsVal, confVal});
 827    return success();
 828  }
 829};
830
// Conversion pattern for the AIE1 `aievec.aie1.mul` op. Mirrors the FMA
// conversion above, minus the accumulator operand and the fmsub flag: it
// declares the matching "llvm.aie.*" intrinsic, packs the op's string
// attributes into the configuration register, and emits a call.
// NOTE(review): the class-declaration line (original line 831) is missing
// from this extraction — confirm the pattern's name against the full source.
 832    : public mlir::ConvertOpToLLVMPattern<aievec::aie1::MulOp> {
 833public:
 834  using ConvertOpToLLVMPattern<aievec::aie1::MulOp>::ConvertOpToLLVMPattern;
 835
 836  LogicalResult
 837  matchAndRewrite(aievec::aie1::MulOp op, OpAdaptor adaptor,
 838                  ConversionPatternRewriter &rewriter) const override {
 839    auto module = op->getParentOfType<ModuleOp>();
 840    MLIRContext *context = rewriter.getContext();
 841
// Scalar i32 for the start indices; 2 x i32 vectors for offsets and conf.
 842    auto startType = IntegerType::get(context, 32);
 843    auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
 844    auto confType = VectorType::get({2}, IntegerType::get(context, 32));
 845
 846    // If the intrinsic declaration doesn't exist, create it
 847    std::string intrinsicName = getMulOrFMAIntrinsicName(op);
 848    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
 849        StringAttr::get(context, intrinsicName));
 850
 851    if (!func) {
 852      OpBuilder::InsertionGuard guard(rewriter);
 853      rewriter.setInsertionPointToStart(module.getBody());
 854      func = LLVM::LLVMFuncOp::create(
 855          rewriter, rewriter.getUnknownLoc(), intrinsicName,
 856          LLVM::LLVMFunctionType::get(op.getResult().getType(),
 857                                      {op.getLhs().getType(),
 858                                       op.getRhs().getType(),
 859                                       startType,   /* xstart */
 860                                       startType,   /* ystart */
 861                                       startType,   /* zstart */
 862                                       offsetsType, /* xoffsets */
 863                                       offsetsType, /* zoffsets */
 864                                       confType}));
 865    }
 866
 867    // Parse the string attribute values
// getAsInteger returns true on parse failure; failures are ignored here and
// leave the corresponding field zero-initialized.
 868    BufferParams x = {};
 869    BufferParams z = {};
 870    op.getXstart().getAsInteger(0, x.start);
 871    op.getXoffsets().getAsInteger(0, x.offsets);
 872    op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
 873    op.getXstep().getAsInteger(0, x.step);
 874    op.getXsquare().getAsInteger(0, x.square);
 875    op.getZstart().getAsInteger(0, z.start);
 876    op.getZoffsets().getAsInteger(0, z.offsets);
 877    op.getZoffsetsHi().getAsInteger(0, z.offsets_hi);
 878    op.getZstep().getAsInteger(0, z.step);
 879    op.getZsquare().getAsInteger(0, z.square);
 880
 881    // Encode the configuration register
// A plain mul never subtracts, hence sub=false.
 882    uint32_t conf[2] = {0, 0};
 883    encodeConf(conf, x, z, false);
 884
 885    // Create the constants and replace the op
 886    auto xstartVal = LLVM::ConstantOp::create(
 887        rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
 888    auto ystartVal = LLVM::ConstantOp::create(rewriter, op->getLoc(), startType,
 889                                              rewriter.getI32IntegerAttr(0));
 890    auto zstartVal = LLVM::ConstantOp::create(
 891        rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(z.start));
 892    auto xoffsetsVal = LLVM::ConstantOp::create(
 893        rewriter, op->getLoc(), offsetsType,
 894        rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
 895    auto zoffsetsVal = LLVM::ConstantOp::create(
 896        rewriter, op->getLoc(), offsetsType,
 897        rewriter.getI32VectorAttr({(int32_t)z.offsets, (int32_t)z.offsets_hi}));
 898    auto confVal = LLVM::ConstantOp::create(
 899        rewriter, op->getLoc(), confType,
 900        rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
 901    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
 902        op, func,
 903        ValueRange{op.getLhs(), op.getRhs(), xstartVal, ystartVal, zstartVal,
 904                   xoffsetsVal, zoffsetsVal, confVal});
 905    return success();
 906  }
 907};
908
// Lowers `aievec.mul_elem` to AIE2 MulConf/MacConf intrinsics.
// NOTE(review): this rendering elided the `class MulElemOpConversion` line
// and several hyperlinked member lines below; the gaps are marked inline.
    : public mlir::ConvertOpToLLVMPattern<aievec::MulElemOp> {
public:
  using ConvertOpToLLVMPattern<aievec::MulElemOp>::ConvertOpToLLVMPattern;

  // Constructor that additionally captures the fp32 emulation strategy chosen
  // via pass options (consumed by convertToEmulatedFP32MulElem below).
  MulElemOpConversion(const LLVMTypeConverter &typeConverter,
                      Aie2Fp32Emulation aie2Fp32EmulationOption)
      : ConvertOpToLLVMPattern(typeConverter),
  // NOTE(review): ctor initializer tail elided by this rendering -- presumably
  // `aie2Fp32EmulationOption(aie2Fp32EmulationOption) {}`; confirm upstream.

  // Accuracy/speed trade-off used when emulating fp32 mul_elem with bf16 MACs.
  Aie2Fp32Emulation aie2Fp32EmulationOption;

  // Supported operand/result combinations, named DtIn0_DtIn1_DtRes_CxMxKxN.
  enum class Kind {
    // DtIn0_DtIn1_DtRes_CxMxKxN
    // NOTE(review): enumerator lines elided by this rendering.
    // TODO: I16_I16_I64_16x1x2x1
  };

  // NOTE(review): `struct DecodedMulElemOp {` and its `Kind kind;` field were
  // elided by this rendering; only the `conf` field is visible.
    int conf;
  };
936
  // Inspects the (converted) lhs operand and classifies the mul_elem into one
  // of the supported Kind variants together with the matching vmac config
  // word. Classification only looks at the lhs element type/width; rhs is
  // assumed to match (enforced by the op's verifier -- TODO confirm).
  static DecodedMulElemOp decodeMulElemOp(OpAdaptor op) {
    auto lhs = op.getLhs();
    auto lhsVecTy = cast<VectorType>(lhs.getType());
    auto lhsScaTy = lhsVecTy.getElementType();
    unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();

    // Integer types
    if (llvm::isa<IntegerType>(lhsScaTy)) {
      if (lhsBitWidth == 8) {
        // NOTE(review): the `return {DecodedMulElemOp::Kind::...,` line was
        // elided by this rendering; only the config expression is visible.
            aiev2_vmac_compute_control(
                /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/1,
                /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
                /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                /*sub_mask=*/0)};
      } else if (lhsBitWidth == 16) {
        // NOTE(review): `return {...Kind...,` line elided here as well.
            aiev2_vmac_compute_control(
                /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/3,
                /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
                /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                /*sub_mask=*/0)};
      } else if (lhsBitWidth == 32) {
        // emulated I32 mul_elem
        // NOTE(review): return statement elided by this rendering.
      }
    } else {
      // Float types
      if (lhsBitWidth == 16) {
        // NOTE(review): `return {...Kind...,` line elided here.
            aiev2_vmac_compute_control(
                /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
                /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
                /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                /*sub_mask=*/0)};
      } else if (lhsBitWidth == 32) {
        // emulated FP32 mul_elem
        // NOTE(review): return statement elided by this rendering.
      }
    }

    // NOTE(review): the fallthrough `return {...UNSUPPORTED...}` line was
    // elided by this rendering.
  }
980
981 // This conversion pattern implements the below CPP emulated I32 mul_elem.
982 // INTRINSIC(v16acc64)
983 // mul_elem_16_2(v16int32 a0, v16int32 a1, v16int32 b0, v16int32 b1) {
984 // v32uint16 a_lo = (v32uint16)shuffle(a0, a1, 2);
985 // v32int16 a_hi = (v32int16)shuffle(a0, a1, 3);
986 // v32uint16 b_lo = (v32uint16)shuffle(b0, b1, 2);
987 // v32int16 b_hi = (v32int16)shuffle(b0, b1, 3);
988 // v16acc64 acc = ::mul_elem_16_2(a_hi, b_hi);
989 // acc = mac_elem_16_2_conf(a_hi, 1, b_lo, false, acc, 0, 1, 0, 0);
990 // acc = mac_elem_16_2_conf(a_lo, false, b_hi, 1, acc, 0, 0, 0, 0);
991 // acc = mac_elem_16_2_conf(a_lo, false, b_lo, false, acc, 0, 1, 0, 0);
992 // return acc;
993 // }
994 // Caller example when handling the elementwise mul of two v16int32 vectors.
995 // v16int32 v1 = LHS();
996 // v16int32 v2 = RHS();
997 // v16acc64 v3 = mul_elem_16_2(v1, broadcast_zero_s32(), v2,
998 // undef_v16int32());
// Explanation:
// a_lo = low_part(a0[0]--a0[15], a1[0]--a1[15])
// a_hi = high_part(a0[0]--a0[15], a1[0]--a1[15])
// b_lo = low_part(b0[0]--b0[15], b1[0]--b1[15])
// b_hi = high_part(b0[0]--b0[15], b1[0]--b1[15])
// The first `acc` is from mul_elem_16_2(a_hi, b_hi), which performs 16
// channels of 1x2x1 matmul, acc[0] = a_hi[0]*b_hi[0]+a_hi[16]*b_hi[16], ... ,
// acc[15] = a_hi[15]*b_hi[15]+a_hi[31]*b_hi[31]. Then, the first MAC shifts
// `acc` left by 16 bits and accumulates 16 channels of 1x2x1 matmul (a_hi,
// b_lo) into `acc`. The second MAC accumulates 16 channels of 1x2x1 matmul
// (a_lo, b_hi) into `acc`. Finally, the third MAC again shifts `acc` left by
// 16 bits and accumulates 16 channels of 1x2x1 matmul (a_lo, b_lo) into `acc`.
1011 LogicalResult
1012 convertToEmulatedI32MulElem(aievec::MulElemOp op, OpAdaptor adaptor,
1013 ConversionPatternRewriter &rewriter) const {
1014
1015 Location loc = op.getLoc();
1016 auto zeroCst = LLVM::ConstantOp::create(
1017 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1018 auto a0 = adaptor.getLhs();
1019 auto a1 = xllvm::VectorBroadcast32I512IntrOp::create(
1020 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), zeroCst);
1021 auto b0 = adaptor.getRhs();
1022 auto b1 = xllvm::UndefV16I32IntrOp::create(
1023 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()));
1024
1025 // 4* Shuffle
1026 auto a_lo = xllvm::VectorShuffleIntrOp::create(
1027 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
1028 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1029 rewriter.getI32IntegerAttr(2)));
1030 auto a_hi = xllvm::VectorShuffleIntrOp::create(
1031 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
1032 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1033 rewriter.getI32IntegerAttr(3)));
1034 auto b_lo = xllvm::VectorShuffleIntrOp::create(
1035 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
1036 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1037 rewriter.getI32IntegerAttr(2)));
1038 auto b_hi = xllvm::VectorShuffleIntrOp::create(
1039 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
1040 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1041 rewriter.getI32IntegerAttr(3)));
1042 // MUL + 3 * MAC
1043 auto mulConfCst = LLVM::ConstantOp::create(
1044 rewriter, loc, rewriter.getI32Type(),
1045 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
1046 /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
1047 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
1048 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));
1049 auto mulConfOp = xllvm::MulConfAcc64IntrOp::create(
1050 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1051 forceCastOperandsToSignature(
1052 rewriter, loc,
1053 /*operands=*/{a_hi, b_hi, mulConfCst},
1054 /*signature=*/
1055 {VectorType::get({64}, rewriter.getI8Type()),
1056 VectorType::get({16}, rewriter.getI32Type()),
1057 rewriter.getI32Type()}));
1058
1059 auto createMacConfOp = [&](SmallVector<Value> operands,
1060 int macConf) -> Value {
1061 operands.push_back(
1062 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1063 rewriter.getI32IntegerAttr(macConf)));
1064 return xllvm::MacConfAcc64IntrOp::create(
1065 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1066 forceCastOperandsToSignature(
1067 rewriter, loc,
1068 /*operands=*/operands,
1069 /*signature=*/
1070 {VectorType::get({64}, rewriter.getI8Type()),
1071 VectorType::get({16}, rewriter.getI32Type()),
1072 VectorType::get({16}, rewriter.getI64Type()),
1073 rewriter.getI32Type()}))
1074 .getResult();
1075 };
1076 auto acc64Val = mulConfOp.getResult();
1077 acc64Val = createMacConfOp(
1078 SmallVector<Value>{a_hi, b_lo, acc64Val},
1079 aiev2_vmac_compute_control(
1080 /*sgn_x=*/1, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
1081 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
1082 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
1083 acc64Val = createMacConfOp(
1084 SmallVector<Value>{a_lo, b_hi, acc64Val},
1085 aiev2_vmac_compute_control(
1086 /*sgn_x=*/0, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
1087 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
1088 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
1089 acc64Val = createMacConfOp(
1090 SmallVector<Value>{a_lo, b_lo, acc64Val},
1091 aiev2_vmac_compute_control(
1092 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
1093 /*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
1094 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
1095
1096 // create bitcast/shape_cast for result
1097 auto resultVal =
1098 forceCastValueToType(rewriter, loc, acc64Val, op.getResult().getType());
1099 rewriter.replaceOp(op, resultVal);
1100 return success();
1101 }
1102
1103 // This conversion pattern implements the below CPP emulated FP32 mul_elem.
1104 // inline v16accfloat mul_elem_16_accuracy_safe(v16float v1, v16float v2) {
1105 // v32bfloat16 a = broadcast_zero_to_v32bfloat16();
1106 // v32bfloat16 b = broadcast_zero_to_v32bfloat16();
1107 // v32bfloat16 c = broadcast_zero_to_v32bfloat16();
1108 // v32bfloat16 d = broadcast_zero_to_v32bfloat16();
1109 // v32bfloat16 e = broadcast_zero_to_v32bfloat16();
1110 // v32bfloat16 f = broadcast_zero_to_v32bfloat16();
1111 // v32bfloat16 dummy0 = broadcast_one_to_v32bfloat16();
1112 // a = insert(a,0,to_v16bfloat16((v16accfloat)v1));
1113 // v16accfloat acc0 = msc_elem_16_2(a, dummy0, (v16accfloat)v1);
1114 // b = insert(b,0,to_v16bfloat16(acc0));
1115 // c = insert(c,0,to_v16bfloat16(msc_elem_16_2(b, dummy0, acc0)));
1116 // d = insert(d,0,to_v16bfloat16((v16accfloat)v2));
1117 // v16accfloat acc1 = msc_elem_16_2(d, dummy0, (v16accfloat)v2);
1118 // e = insert(e,0,to_v16bfloat16(acc1));
1119 // f = insert(f,0,to_v16bfloat16(msc_elem_16_2(e, dummy0, acc1)));
1120 // return
1121 // mac_elem_16_2(a,d,mac_elem_16_2(a,e,mac_elem_16_2(b,d,mac_elem_16_2(
1122 // d,c,mac_elem_16_2(b,e,mac_elem_16_2(a,f,mac_elem_16_2(
1123 // b,f,mac_elem_16_2(c,e,mul_elem_16_2(c,f)))))))));
1124 // }
1125 // Caller example when handling the elementwise mul of two v16float vectors.
1126 // v16float v1 = LHS(); v16float v2 = RHS();
1127 // v16accfloat v3 = mul_elem_16(v1, v2);
// Explanation: For v32bfloat16 `a`, the first half v16bf16 contains the most
// significant 7 bits of mantissa from v1, and the second half v16bf16 is
// zeros. For v16accfloat `acc0`, the MSC equals "(original `v1` with 23 bits
// of mantissa) - (`a` with the 7 MSBs of mantissa from v1)". For v32bfloat16
// `b`, the first half v16bf16 contains bits [7:13] of the mantissa from v1,
// and the second half v16bf16 is zeros. For v32bfloat16 `c`, the first half
// v16bf16 contains bits [14:20] of the mantissa from v1, and the second half
// v16bf16 is zeros. Hence, we can represent a v16float in three v32bfloat16
// values and then perform 9 MUL/MACs in v32bfloat16 to obtain the final
// elementwise multiplication result.
1138
  // Emulates a 32-bit float elementwise multiply using bf16 MUL/MAC
  // intrinsics, following the CPP reference in the comment above. Each fp32
  // operand is decomposed into up to three bf16 vectors holding successive
  // 7-bit mantissa slices; the number of partial products actually combined
  // (9, 6, or 3) is selected by `aie2Fp32EmulationOption`. Always succeeds.
  LogicalResult
  convertToEmulatedFP32MulElem(aievec::MulElemOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
    Location loc = op.getLoc();
    // Six zero-initialized v32bf16 scratch vectors (a..f in the reference).
    auto zeroCst =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getBF16Type(),
                                 rewriter.getZeroAttr(rewriter.getBF16Type()));
    auto aZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
        rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
    auto bZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
        rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
    auto cZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
        rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
    auto dZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
        rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
    auto eZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
        rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
    auto fZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
        rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
    // `dummy0` is an all-ones multiplicand used by the MSC steps below.
    auto oneCst =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getBF16Type(),
                                 rewriter.getOneAttr(rewriter.getBF16Type()));
    auto dummy0 = xllvm::VectorBroadcast16BF512IntrOp::create(
        rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), oneCst);
    auto zeroCstI32 = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
    // Single config word shared by every MSC/MAC/MUL bf16 intrinsic below.
    auto mscMacMulConfCst = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(),
        rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
            /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
            /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
            /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));

    // Splits one v16f32 value into three v32bf16 vectors holding the top,
    // middle, and bottom mantissa slices (a/b/c in the reference CPP).
    auto extractV16FP32ToThreeV16BF16 =
        [&](Value inputV16FP32, Value aZeros, Value bZeros,
            Value cZeros) -> std::tuple<Value, Value, Value> {
      // a = insert(a,0,to_v16bfloat16((v16accfloat)v1));
      auto inputBitCasted =
          forceCastValueToType(rewriter, loc, inputV16FP32,
                               VectorType::get({8}, rewriter.getI64Type()));
      auto v1ToBF16 = xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
          inputBitCasted);
      auto a = xllvm::UpdBF512BF256IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), aZeros,
          v1ToBF16, zeroCstI32);

      // v16accfloat acc0 = msc_elem_16_2(a, dummy0, (v16accfloat)v1);
      auto acc0 = xllvm::MscConfBF16IntrOp::create(
          rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), a, dummy0,
          inputBitCasted, mscMacMulConfCst);

      // b = insert(b,0,to_v16bfloat16(acc0));
      auto acc0ToBF16 = xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()), acc0);
      auto b = xllvm::UpdBF512BF256IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), bZeros,
          acc0ToBF16, zeroCstI32);

      // c = insert(c,0,to_v16bfloat16(msc_elem_16_2(b, dummy0, acc0)));
      auto acc0Mscb = xllvm::MscConfBF16IntrOp::create(
          rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), b, dummy0,
          acc0, mscMacMulConfCst);
      auto acc0MscbToBF16 = xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
          acc0Mscb);
      auto c = xllvm::UpdBF512BF256IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), cZeros,
          acc0MscbToBF16, zeroCstI32);
      return std::make_tuple(a.getResult(), b.getResult(), c.getResult());
    };

    // Get v32bfloat16 a, b, c representing v16float v1
    auto [a, b, c] =
        extractV16FP32ToThreeV16BF16(adaptor.getLhs(), aZeros, bZeros, cZeros);
    // Get v32bfloat16 d, e, f representing v16float v2
    auto [d, e, f] =
        extractV16FP32ToThreeV16BF16(adaptor.getRhs(), dZeros, eZeros, fZeros);

    // Create 1 MUL and 2/5/8 MACs depending on the Aie2Fp32EmulationOption
    auto createMacOps = [&](Value lhs, Value rhs, Value acc) -> Value {
      return xllvm::MacConfBF16IntrOp::create(
                 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
                 lhs, rhs, acc, mscMacMulConfCst)
          .getResult();
    };

    Value finalMacVal;
    if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyFast) {
      // Fast and Accurate option. float a*b would require 6 mac operations.
      // Input fp32 number is split in to 3 bfloat16 numbers to extract all the
      // bits of the mantissa. float a,b; both a and b are split in to 3
      // bfloat16 numbers each. Hence there would be 9 mac operations in
      // multiplication of a and b. In the 9 mac operations to emulate fp32 mul,
      // mac operations with LSBs are ignored. (3 last terms). This helps
      // improve cycle count of mul and has least impact on accuracy of result.
      // This is the default option to the aiecompiler
      auto afMul = xllvm::MulConfBF16IntrOp::create(
          rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), a, f,
          mscMacMulConfCst);
      finalMacVal = createMacOps(
          a, d,
          createMacOps(
              a, e,
              createMacOps(b, d,
                           createMacOps(d, c, createMacOps(b, e, afMul)))));
    } else if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyLow) {
      // Fast and least accurate option. float a*b would require 3 mac
      // operations.
      // Input fp32 number is split in to 2 bfloat16 numbers. Hence not all the
      // bits from mantissa can be used. float a,b; Both a and b are split in to
      // 2 bfloat16 numbers each. Hence there would be 4 mac operations in
      // multiplication of a and b. In the 4 mac operations to emulate fp32 mul,
      // mac operations with LSBs are ignored. (1 last term). This helps improve
      // cycle count of mul float a, b;
      auto bdMul = xllvm::MulConfBF16IntrOp::create(
          rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), b, d,
          mscMacMulConfCst);
      finalMacVal = createMacOps(a, d, createMacOps(a, e, bdMul));
    } else {
      // aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracySafe
      // Most accurate option since input fp32 number is split in to 3 bfloat16
      // numbers to extract all the bits of the mantissa. float a*b would
      // require 9 mac operations due to 3 bfloat16 splits each.
      auto cfMul = xllvm::MulConfBF16IntrOp::create(
          rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), c, f,
          mscMacMulConfCst);
      finalMacVal = createMacOps(
          a, d,
          createMacOps(
              a, e,
              createMacOps(
                  b, d,
                  createMacOps(
                      d, c,
                      createMacOps(
                          b, e,
                          createMacOps(
                              a, f,
                              createMacOps(b, f,
                                           createMacOps(c, e, cfMul))))))));
    }

    // create bitcast/shape_cast for result
    auto resultVal = forceCastValueToType(rewriter, loc, finalMacVal,
                                          op.getResult().getType());
    rewriter.replaceOp(op, resultVal);
    return success();
  }
1288
  // Entry point: decodes the mul_elem variant, dispatches the emulated
  // I32/FP32 paths, and otherwise emits the direct MulConf intrinsic.
  LogicalResult
  matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto decodedMulElemOp = decodeMulElemOp(adaptor);

    if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::UNSUPPORTED) {
      op.emitWarning() << "aievec.mul_elem conversion is not supported.\n";
      return failure();
    }

    // Handle the emulated I32/FP32 mul_elem
    if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1) {
      return convertToEmulatedI32MulElem(op, adaptor, rewriter);
    } else if (decodedMulElemOp.kind ==
      // NOTE(review): the Kind comparison's continuation line was elided by
      // this rendering -- presumably the FP32 emulation Kind.
      return convertToEmulatedFP32MulElem(op, adaptor, rewriter);
    }

    // create constant for config
    auto confCst = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(),
        rewriter.getI32IntegerAttr(decodedMulElemOp.conf));
    Value mulElemOp = nullptr;
    SmallVector<Value> operands({adaptor.getLhs(), adaptor.getRhs(), confCst});

    // create xllvm intrinsic
    if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I16_I16_I32_32x1x1x1 ||
        decodedMulElemOp.kind == DecodedMulElemOp::Kind::I8_I8_I32_32x1x2x1) {
      mulElemOp = xllvm::MulConfAcc32IntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({64}, rewriter.getI8Type()),
               VectorType::get({16}, rewriter.getI32Type()),
               rewriter.getI32Type()}));
    } else if (decodedMulElemOp.kind ==
      // NOTE(review): Kind comparison continuation line elided here --
      // presumably the 16-lane bf16 Kind, given the bf16 padding below.
      // Create zero vector using the exact pattern from working reference:
      // vbroadcast16.I512(0) -> bitcast to bf16 -> extract lower 256 bits
      auto zero32 = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
      auto zeros_i16 = xllvm::VectorBroadcast16I512IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getI16Type()), zero32);
      auto zeros_bf16 = LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
          zeros_i16);
      auto zeroVec = xllvm::ExtBF256BF512IntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
          zeros_bf16, zero32);

      // Use set+upd pattern to match working reference
      auto idx1 = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                           rewriter.getI32IntegerAttr(1));

      // Set lhs at lower 256 bits, then update upper 256 bits with zeros
      auto lhsSet = xllvm::VectorSetBF512BF256IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
          adaptor.getLhs(), zero32);
      auto lhsConcat = xllvm::UpdBF512BF256IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), lhsSet,
          zeroVec, idx1);

      // Set rhs at lower 256 bits, then update upper 256 bits with zeros
      auto rhsSet = xllvm::VectorSetBF512BF256IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
          adaptor.getRhs(), zero32);
      auto rhsConcat = xllvm::UpdBF512BF256IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), rhsSet,
          zeroVec, idx1);

      // Call bf.mul16.conf with padded vectors
      mulElemOp = xllvm::MulConfBF16IntrOp::create(
          rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), lhsConcat,
          rhsConcat, confCst);
    }

    // create bitcast/shape_cast for result
    auto resultVal = forceCastValueToType(rewriter, loc, mulElemOp,
                                          op.getResult().getType());
    rewriter.replaceOp(op, resultVal);
    return success();
  }
1372};
1373
// AIE2p version of MulElemOp conversion
// NOTE(review): the class declaration line and several hyperlinked member
// lines (struct opening, UNSUPPORTED enumerator, `Kind kind;`) were elided
// by this rendering; the gaps are marked inline.
    : public mlir::ConvertOpToLLVMPattern<aievec::MulElemOp> {
public:
  using ConvertOpToLLVMPattern<aievec::MulElemOp>::ConvertOpToLLVMPattern;

  // NOTE(review): `struct DecodedMulElemOp {` elided here.
  // Supported bf16 shapes; lane count selects the intrinsic width below.
  enum class Kind {
    BF16_BF16_FP32_16x1x1x1, // 16-lane bf16 -> 16-lane f32
    BF16_BF16_FP32_32x1x2x1, // 32-lane bf16 -> 32-lane f32
    BF16_BF16_FP32_64x1x2x1, // 64-lane bf16 -> 64-lane f32
    // NOTE(review): UNSUPPORTED enumerator elided here.
  };
  // NOTE(review): `Kind kind;` field elided here.
    int conf;
  };
1390
  // Classifies the mul_elem by lhs element type and lane count. Only bf16
  // inputs (16/32/64 lanes) are supported on AIE2p; everything else decodes
  // to UNSUPPORTED.
  static DecodedMulElemOp decodeMulElemOp(OpAdaptor op) {
    auto lhs = op.getLhs();
    auto lhsVecTy = cast<VectorType>(lhs.getType());
    auto lhsScaTy = lhsVecTy.getElementType();
    unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
    int lhsLanes = getVectorLaneSize(lhsVecTy);

    // Integer types - not supported for AIE2p elementwise mul
    if (llvm::isa<IntegerType>(lhsScaTy)) {
      // NOTE(review): `return {...UNSUPPORTED...}` line elided by this
      // rendering.
    } else {
      // Float types
      if (lhsBitWidth == 16) {
        // BF16 mul_elem
        if (lhsLanes == 16) {
          // 16-lane bfloat16 uses I512.I512.ACC512 intrinsic
          // NOTE(review): return statement elided by this rendering.
        } else if (lhsLanes == 32) {
          // 32-lane bfloat16 uses I512.I512.ACC1024 intrinsic
          // NOTE(review): return statement elided by this rendering.
        } else if (lhsLanes == 64) {
          // 64-lane bfloat16 uses I1024.I1024.ACC2048 intrinsic
          // NOTE(review): return statement elided by this rendering.
        }
      }
    }
    // NOTE(review): fallthrough `return {...UNSUPPORTED...}` line elided.
  }
1419
  // Entry point: emits the AIE2p bf16 MulConf intrinsic matching the decoded
  // lane count (16 lanes are padded to 32 first).
  LogicalResult
  matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto decodedMulElemOp = decodeMulElemOp(adaptor);

    if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::UNSUPPORTED) {
      op.emitWarning() << "aievec.mul_elem conversion is not supported for "
                          "AIE2p.\n";
      return failure();
    }

    // Create constant for config
    auto confCst = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(),
        rewriter.getI32IntegerAttr(decodedMulElemOp.conf));

    Value mulElemOp = nullptr;

    // Handle BF16 mul_elem for AIE2p
    if (decodedMulElemOp.kind ==
      // NOTE(review): Kind comparison continuation line elided by this
      // rendering -- presumably BF16_BF16_FP32_16x1x1x1.
      // 16-lane bfloat16: <16 x bfloat> x <16 x bfloat> -> <16 x float>
      // The intrinsic requires <32 x bfloat> inputs, so we need to pad

      // Pad LHS from 16 to 32 bfloat16 using shuffle
      SmallVector<int64_t> padMask;
      for (int i = 0; i < 16; ++i)
        padMask.push_back(i);
      for (int i = 16; i < 32; ++i)
        padMask.push_back(-1); // poison/undef

      auto lhsPadded = vector::ShuffleOp::create(
          rewriter, loc, adaptor.getLhs(), adaptor.getLhs(), padMask);
      auto rhsPadded = vector::ShuffleOp::create(
          rewriter, loc, adaptor.getRhs(), adaptor.getRhs(), padMask);

      SmallVector<Value> operands({lhsPadded, rhsPadded, confCst});

      // Call I512.I512.ACC512 intrinsic
      mulElemOp = xllvm::MulConfBF16I512ACC512AIE2pIntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getF32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({32}, rewriter.getBF16Type()),
               VectorType::get({32}, rewriter.getBF16Type()),
               rewriter.getI32Type()}));
    } else if (decodedMulElemOp.kind ==
      // NOTE(review): Kind comparison continuation line elided -- presumably
      // BF16_BF16_FP32_32x1x2x1.
      // 32-lane bfloat16: <32 x bfloat> x <32 x bfloat> -> <32 x float>
      SmallVector<Value> operands(
          {adaptor.getLhs(), adaptor.getRhs(), confCst});
      mulElemOp = xllvm::MulConfBF16I512ACC1024AIE2pIntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getF32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({32}, rewriter.getBF16Type()),
               VectorType::get({32}, rewriter.getBF16Type()),
               rewriter.getI32Type()}));
    } else if (decodedMulElemOp.kind ==
      // NOTE(review): Kind comparison continuation line elided -- presumably
      // BF16_BF16_FP32_64x1x2x1.
      // 64-lane bfloat16: <64 x bfloat> x <64 x bfloat> -> <64 x float>
      SmallVector<Value> operands(
          {adaptor.getLhs(), adaptor.getRhs(), confCst});
      mulElemOp = xllvm::MulConfBF16I1024ACC2048AIE2pIntrOp::create(
          rewriter, loc, VectorType::get({64}, rewriter.getF32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({64}, rewriter.getBF16Type()),
               VectorType::get({64}, rewriter.getBF16Type()),
               rewriter.getI32Type()}));
    }

    // create bitcast/shape_cast for result
    auto resultVal = forceCastValueToType(rewriter, loc, mulElemOp,
                                          op.getResult().getType());
    rewriter.replaceOp(op, resultVal);
    return success();
  }
1499};
1500
// AIE2p version of FMAElemOp conversion. Uses native F32 accumulators
// and AIE2p-specific MAC intrinsics.
// NOTE(review): the `class FMAElemOpAIE2pConversion` line was elided by this
// rendering.
    : public mlir::ConvertOpToLLVMPattern<aievec::FMAElemOp> {
public:
  using ConvertOpToLLVMPattern<aievec::FMAElemOp>::ConvertOpToLLVMPattern;
1507
1508 LogicalResult
1509 matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor,
1510 ConversionPatternRewriter &rewriter) const override {
1511 auto loc = fmaOp.getLoc();
1512 auto lhs = adaptor.getLhs();
1513 auto rhs = adaptor.getRhs();
1514 auto acc = adaptor.getAcc();
1515 auto lhsTy = cast<VectorType>(lhs.getType());
1516 auto accTy = cast<VectorType>(acc.getType());
1517 auto flatLhsTy = getFlattenedVectorType(lhsTy);
1518 auto flatAccTy = getFlattenedVectorType(accTy);
1519
1520 // Flatten operands, if needed
1521 if (lhsTy != flatLhsTy)
1522 lhs = vector::ShapeCastOp::create(rewriter, loc, flatLhsTy, lhs);
1523 if (cast<VectorType>(rhs.getType()) != flatLhsTy)
1524 rhs = vector::ShapeCastOp::create(rewriter, loc, flatLhsTy, rhs);
1525 if (accTy != flatAccTy)
1526 acc = vector::ShapeCastOp::create(rewriter, loc, flatAccTy, acc);
1527
1528 if (!flatLhsTy.getElementType().isBF16()) {
1529 fmaOp.emitWarning()
1530 << "aievec.mac_elem AIE2p conversion only supports bf16 inputs.\n";
1531 return failure();
1532 }
1533
1534 Type i32ty = rewriter.getI32Type();
1535 auto confCst = LLVM::ConstantOp::create(
1536 rewriter, loc, i32ty,
1537 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
1538 /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
1539 /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
1540 /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
1541 /*sub_mask=*/0)));
1542
1543 unsigned lhsLanes = flatLhsTy.getNumElements();
1544 Value macIntrOp = nullptr;
1545
1546 if (lhsLanes == 16) {
1547 // 16-lane bf16: pad to v32bf16, use I512.I512.ACC512 intrinsic
1548 SmallVector<int64_t> padMask;
1549 for (int i = 0; i < 16; ++i)
1550 padMask.push_back(i);
1551 for (int i = 16; i < 32; ++i)
1552 padMask.push_back(-1); // poison
1553
1554 auto lhsPadded =
1555 vector::ShuffleOp::create(rewriter, loc, lhs, lhs, padMask);
1556 auto rhsPadded =
1557 vector::ShuffleOp::create(rewriter, loc, rhs, rhs, padMask);
1558
1559 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
1560 auto v16f32Ty = VectorType::get({16}, rewriter.getF32Type());
1561 macIntrOp = xllvm::MacConfBF16I512ACC512AIE2pIntrOp::create(
1562 rewriter, loc, v16f32Ty,
1563 forceCastOperandsToSignature(
1564 rewriter, loc, {lhsPadded, rhsPadded, acc, confCst},
1565 {v32bf16Ty, v32bf16Ty, v16f32Ty, i32ty}));
1566 } else if (lhsLanes == 32) {
1567 // 32-lane bf16: direct, use I512.I512.ACC1024 intrinsic
1568 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
1569 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
1570 macIntrOp = xllvm::MacConfBF16I512ACC1024AIE2pIntrOp::create(
1571 rewriter, loc, v32f32Ty,
1572 forceCastOperandsToSignature(
1573 rewriter, loc, {lhs, rhs, acc, confCst},
1574 {v32bf16Ty, v32bf16Ty, v32f32Ty, i32ty}));
1575 } else {
1576 fmaOp.emitWarning()
1577 << "aievec.mac_elem AIE2p conversion: unsupported lane count "
1578 << lhsLanes << ".\n";
1579 return failure();
1580 }
1581
1582 // Recast/Reshape result
1583 auto resVal = forceCastValueToType(rewriter, loc, macIntrOp, flatAccTy);
1584 if (flatAccTy != accTy)
1585 resVal = vector::ShapeCastOp::create(rewriter, loc, accTy, resVal);
1586
1587 rewriter.replaceOp(fmaOp, resVal);
1588 return success();
1589 }
1590};
1591
// Enum to represent different AIE target architectures
// (presumably used to select which conversion patterns get populated -- the
// consuming code is outside this chunk; confirm against the pass setup).
enum class AIEArch {
  AIE2,
  AIE2p,
};
1597
1598class UPSOpAIE2Conversion : public mlir::ConvertOpToLLVMPattern<aievec::UPSOp> {
1599public:
1600 using ConvertOpToLLVMPattern<aievec::UPSOp>::ConvertOpToLLVMPattern;
1601
1602 LogicalResult
1603 matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor,
1604 ConversionPatternRewriter &rewriter) const override {
1605 Location loc = op.getLoc();
1606
1607 Value result = op.getResult();
1608 VectorType resultType = cast<VectorType>(result.getType());
1609 VectorType flatResTy = getFlattenedVectorType(resultType);
1610 Type resultScaTy = resultType.getElementType();
1611 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1612 int resultLanes = getVectorLaneSize(resultType);
1613 int resultVectorSize = resultBitWidth * resultLanes;
1614
1615 Value opSrcVal = adaptor.getSource();
1616 auto srcVecTy = cast<VectorType>(opSrcVal.getType());
1617 auto fltSrcVecTy = getFlattenedVectorType(srcVecTy);
1618 if (srcVecTy != fltSrcVecTy)
1619 opSrcVal = vector::ShapeCastOp::create(rewriter, op.getLoc(), fltSrcVecTy,
1620 opSrcVal)
1621 .getResult();
1622
1623 // create xllvm intrinsic
1624 // Integer types
1625 Value upsIntrOp = nullptr;
1626 if (llvm::isa<IntegerType>(resultScaTy)) {
1627 // create constant for sign
1628 auto signCst = LLVM::ConstantOp::create(
1629 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1630 auto shiftCst =
1631 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1632 rewriter.getI32IntegerAttr(op.getShift()));
1633
1634 SmallVector<Value> operands({opSrcVal, shiftCst, signCst});
1635 if (resultVectorSize == 512) {
1636 if (resultBitWidth == 32) {
1637 // v16int16 -> v16acc32
1638 upsIntrOp = xllvm::Acc32V16I256UpsAIE2IntrOp::create(
1639 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1640 forceCastOperandsToSignature(
1641 rewriter, loc, operands,
1642 {VectorType::get({16}, rewriter.getI16Type()),
1643 rewriter.getI32Type(), rewriter.getI32Type()}));
1644 } else if (resultBitWidth == 64) {
1645 // v8int32 -> v8acc64
1646 upsIntrOp = xllvm::Acc64V8I256UpsAIE2IntrOp::create(
1647 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1648 forceCastOperandsToSignature(
1649 rewriter, loc, operands,
1650 {VectorType::get({8}, rewriter.getI32Type()),
1651 rewriter.getI32Type(), rewriter.getI32Type()}));
1652 }
1653 } else if (resultVectorSize == 1024) {
1654 Value src = opSrcVal;
1655 VectorType srcType = cast<VectorType>(src.getType());
1656 Type srcScaType = srcType.getElementType();
1657 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
1658
1659 if (resultBitWidth == 32 && srcBitWidth == 16) {
1660 // v32int16 -> v32acc32
1661 upsIntrOp = xllvm::Acc32V32I512UpsAIE2IntrOp::create(
1662 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1663 forceCastOperandsToSignature(
1664 rewriter, loc, operands,
1665 {VectorType::get({32}, rewriter.getI16Type()),
1666 rewriter.getI32Type(), rewriter.getI32Type()}));
1667 } else if (resultBitWidth == 64 && srcBitWidth == 32) {
1668 // v16int32 -> v16acc64
1669 upsIntrOp = xllvm::Acc64V16I512UpsAIE2IntrOp::create(
1670 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1671 forceCastOperandsToSignature(
1672 rewriter, loc, operands,
1673 {VectorType::get({16}, rewriter.getI32Type()),
1674 rewriter.getI32Type(), rewriter.getI32Type()}));
1675 } else if (resultBitWidth == 64 && srcBitWidth == 16) {
1676 // v16int16 -> v16acc64
1677 upsIntrOp = xllvm::Acc64V16I256UpsAIE2IntrOp::create(
1678 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1679 forceCastOperandsToSignature(
1680 rewriter, loc, operands,
1681 {VectorType::get({16}, rewriter.getI16Type()),
1682 rewriter.getI32Type(), rewriter.getI32Type()}));
1683 } else if (resultBitWidth == 32 && srcBitWidth == 8) {
1684 // v32int8 -> v32acc32
1685 upsIntrOp = xllvm::Acc32V32I256UpsAIE2IntrOp::create(
1686 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1687 forceCastOperandsToSignature(
1688 rewriter, loc, operands,
1689 {VectorType::get({32}, rewriter.getI8Type()),
1690 rewriter.getI32Type(), rewriter.getI32Type()}));
1691 }
1692 }
1693 } else {
1694 // Float types
1695 // AIE2p uses native F32 types, AIE2 uses packed I64 types
1696 if (resultVectorSize == 512) {
1697 // v16bfloat16 -> v16accfloat
1698 upsIntrOp = xllvm::Vector16BF16ToV16AccFloatAIE2IntrOp::create(
1699 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1700 forceCastOperandsToSignature(
1701 rewriter, loc, {opSrcVal},
1702 {VectorType::get({16}, rewriter.getBF16Type())}));
1703 } else if (resultVectorSize == 1024) {
1704 // v32bfloat16 -> v32accfloat
1705 // The CPP example of the implementation is below:
1706 // INTRINSIC(v32accfloat) ups_to_v32accfloat(v32bfloat16 a) {
1707 // v16accfloat x0 = ups_to_v16accfloat(extract_v16bfloat16(a, 0));
1708 // v16accfloat x1 = ups_to_v16accfloat(extract_v16bfloat16(a, 1));
1709 // return concat(x0, x1);
1710 // }
1711 auto indexZeroCst =
1712 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1713 rewriter.getI32IntegerAttr(0));
1714 auto indexOneCst =
1715 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1716 rewriter.getI32IntegerAttr(1));
1717 auto extractUps = [&](Value source, Value index) -> Value {
1718 auto extOp = xllvm::ExtI256I512IntrOp::create(
1719 rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
1720 forceCastOperandsToSignature(
1721 rewriter, loc, {source, index},
1722 {VectorType::get({16}, rewriter.getI32Type()),
1723 rewriter.getI32Type()}));
1724 return xllvm::Vector16BF16ToV16AccFloatAIE2IntrOp::create(
1725 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1726 forceCastOperandsToSignature(
1727 rewriter, loc, {extOp},
1728 {VectorType::get({16}, rewriter.getBF16Type())}));
1729 };
1730 auto resLo = extractUps(opSrcVal, indexZeroCst);
1731 auto resHi = extractUps(opSrcVal, indexOneCst);
1732 // Concat the two 512-bit vector to a 1024-bit vector.
1733 // Note that given sources a0 and a1, the result is [a1; a0].
1734 upsIntrOp = xllvm::ConcatI1024I512IntrOp::create(
1735 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
1736 forceCastOperandsToSignature(
1737 rewriter, loc, {resLo, resHi},
1738 {VectorType::get({16}, rewriter.getI32Type()),
1739 VectorType::get({16}, rewriter.getI32Type())}));
1740 }
1741 }
1742
1743 if (!upsIntrOp) {
1744 op.emitWarning() << "aievec.ups is not supported.\n";
1745 return failure();
1746 }
1747
1748 // create bitcast for result if needed
1749 if (flatResTy != upsIntrOp.getType())
1750 upsIntrOp = LLVM::BitcastOp::create(rewriter, loc, flatResTy, upsIntrOp);
1751
1752 if (flatResTy != resultType)
1753 upsIntrOp =
1754 vector::ShapeCastOp::create(rewriter, loc, resultType, upsIntrOp);
1755
1756 rewriter.replaceOp(op, upsIntrOp);
1757
1758 return success();
1759 }
1760};
1761
1762// TODO: Split the op at AIEVec dialect level
1764 : public mlir::ConvertOpToLLVMPattern<aievec::UPSOp> {
1765public:
1766 using ConvertOpToLLVMPattern<aievec::UPSOp>::ConvertOpToLLVMPattern;
1767
  // Lowers aievec.ups (shift-up into an accumulator vector) to AIE2p xllvm
  // intrinsics. The intrinsic is selected by the result element kind (integer
  // vs. float), the total result size in bits and, where one result size maps
  // to several intrinsics, by the source element width / total source size.
  // 2048-bit results have no single intrinsic and are assembled from two
  // 1024-bit halves via vector.shuffle. Emits a warning and returns failure()
  // for unsupported type combinations.
  LogicalResult
  matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Characterize the result vector: flattened (1-D) type, element bit
    // width, lane count, and total size in bits.
    Value result = op.getResult();
    VectorType resultType = cast<VectorType>(result.getType());
    VectorType flatResTy = getFlattenedVectorType(resultType);
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    // Flatten a multi-dimensional source to 1-D first; all intrinsics below
    // operate on flat vector types.
    Value opSrcVal = adaptor.getSource();
    auto srcVecTy = cast<VectorType>(opSrcVal.getType());
    auto fltSrcVecTy = getFlattenedVectorType(srcVecTy);
    if (srcVecTy != fltSrcVecTy)
      opSrcVal = vector::ShapeCastOp::create(rewriter, op.getLoc(), fltSrcVecTy,
                                             opSrcVal)
                     .getResult();

    // create xllvm intrinsic
    // Integer types
    Value upsIntrOp = nullptr;
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // create constant for sign (always 1, i.e. signed, for integer ups)
      auto signCst = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
      auto shiftCst =
          LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                   rewriter.getI32IntegerAttr(op.getShift()));

      // All integer ups intrinsics share the (source, shift, sign) signature.
      SmallVector<Value> operands({opSrcVal, shiftCst, signCst});
      if (resultVectorSize == 512) {
        if (resultBitWidth == 32) {
          // v16int16 -> v16acc32
          upsIntrOp = xllvm::Acc32V16I256UpsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI16Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 64) {
          // v8int32 -> v8acc64
          upsIntrOp = xllvm::Acc64V8I256UpsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({8}, rewriter.getI32Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        }
      } else if (resultVectorSize == 1024) {
        // Several intrinsics produce 1024-bit accumulators; disambiguate by
        // source element width and total source size.
        Value src = opSrcVal;
        VectorType srcType = cast<VectorType>(src.getType());
        Type srcScaType = srcType.getElementType();
        unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
        int srcLanes = getVectorLaneSize(srcType);
        int srcVectorSize = srcBitWidth * srcLanes;

        if (resultBitWidth == 32 && srcBitWidth == 16 && srcVectorSize == 512) {
          // v32int16 -> v32acc32
          upsIntrOp = xllvm::Acc32V32I512UpsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({32}, rewriter.getI16Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 64 && srcBitWidth == 32 &&
                   srcVectorSize == 512) {
          // v16int32 -> v16acc64
          upsIntrOp = xllvm::Acc64V16I512UpsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI32Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 64 && srcBitWidth == 16 &&
                   srcVectorSize == 256) {
          // v16int16 -> v16acc64
          upsIntrOp = xllvm::Acc64V16I256UpsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI16Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 32 && srcBitWidth == 8 &&
                   srcVectorSize == 256) {
          // v32int8 -> v32acc32
          upsIntrOp = xllvm::Acc32V32I256UpsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({32}, rewriter.getI8Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        }
      } else if (resultVectorSize == 2048) {
        Value src = opSrcVal;
        VectorType srcType = cast<VectorType>(src.getType());
        Type srcScaType = srcType.getElementType();
        unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
        int srcLanes = getVectorLaneSize(srcType);
        int srcVectorSize = srcBitWidth * srcLanes;

        if (resultBitWidth == 32 && srcBitWidth == 8 && srcVectorSize == 512) {
          // v64int8 -> v64acc32
          upsIntrOp = xllvm::Acc32V64I512UpsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({64}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({64}, rewriter.getI8Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 64 && srcBitWidth == 16 &&
                   srcVectorSize == 512) {
          // v32int16 -> v32acc64
          upsIntrOp = xllvm::Acc64V32I512UpsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getI64Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({32}, rewriter.getI16Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 32 && srcBitWidth == 16 &&
                   srcVectorSize == 1024) {
          // v64int16 -> v64acc32
          // Extract 2 chunks of v32int16 and convert each to v32acc32
          auto index0Cst =
              LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                       rewriter.getI32IntegerAttr(0));
          auto index1Cst =
              LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                       rewriter.getI32IntegerAttr(1));

          // Extracts the 512-bit half selected by `index` from the 1024-bit
          // source and runs the v32int16 -> v32acc32 ups intrinsic on it.
          auto extractUps2048 = [&](Value source, Value index, Value shiftCst,
                                    Value signCst) -> Value {
            // Use vector::ShuffleOp to extract 512-bit from 1024-bit
            // Cast source to v32xi32 for shuffling
            auto v32i32Source = forceCastValueToType(
                rewriter, loc, source,
                VectorType::get({32}, rewriter.getI32Type()));

            // Determine shuffle mask based on index
            // index 0: elements [0-15]
            // index 1: elements [16-31]
            SmallVector<int64_t> shuffleMask;
            if (auto constIndex = index.getDefiningOp<LLVM::ConstantOp>()) {
              auto indexAttr = cast<IntegerAttr>(constIndex.getValue());
              int64_t idxVal = indexAttr.getInt();
              int startIdx = idxVal * 16;
              for (int i = 0; i < 16; ++i) {
                shuffleMask.push_back(startIdx + i);
              }
            } else {
              // Default to index 0 if not constant
              for (int i = 0; i < 16; ++i) {
                shuffleMask.push_back(i);
              }
            }

            auto extOp = vector::ShuffleOp::create(rewriter, loc, v32i32Source,
                                                   v32i32Source, shuffleMask);

            return xllvm::Acc32V32I512UpsAIE2pIntrOp::create(
                rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
                forceCastOperandsToSignature(
                    rewriter, loc, {extOp, shiftCst, signCst},
                    {VectorType::get({32}, rewriter.getI16Type()),
                     rewriter.getI32Type(), rewriter.getI32Type()}));
          };

          auto res0 = extractUps2048(opSrcVal, index0Cst, shiftCst, signCst);
          auto res1 = extractUps2048(opSrcVal, index1Cst, shiftCst, signCst);

          // Concat two 1024-bit vectors to a 2048-bit vector using
          // vector::ShuffleOp
          SmallVector<int64_t> concatMask;
          for (int i = 0; i < 64; ++i) {
            concatMask.push_back(i);
          }
          upsIntrOp =
              vector::ShuffleOp::create(rewriter, loc, res0, res1, concatMask);
        }
      }
    } else {
      // Float types
      // AIE2p uses native F32 types, AIE2 uses packed I64 types
      if (resultVectorSize == 512) {
        // v16bfloat16 -> v16accfloat
        upsIntrOp = xllvm::Vector16BF16ToV16AccFloatAIE2pIntrOp::create(
            rewriter, loc, VectorType::get({16}, rewriter.getF32Type()),
            forceCastOperandsToSignature(
                rewriter, loc, {opSrcVal},
                {VectorType::get({16}, rewriter.getBF16Type())}));
      } else if (resultVectorSize == 1024) {
        // v32bfloat16 -> v32accfloat
        upsIntrOp = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
            rewriter, loc, VectorType::get({32}, rewriter.getF32Type()),
            forceCastOperandsToSignature(
                rewriter, loc, {opSrcVal},
                {VectorType::get({32}, rewriter.getBF16Type())}));
      } else if (resultVectorSize == 2048) {
        // v64bfloat16 -> v64accfloat
        // Extract 2 chunks of v32bfloat16 and convert each to v32accfloat
        auto index0Cst =
            LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(0));
        auto index1Cst =
            LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(1));

        // Extracts the 512-bit half selected by `index` and widens it with
        // the v32bfloat16 -> v32accfloat intrinsic.
        auto extractUps2048 = [&](Value source, Value index) -> Value {
          // Use vector::ShuffleOp to extract 512-bit from 1024-bit
          // Cast source to v32xi32 for shuffling
          auto v32i32Source = forceCastValueToType(
              rewriter, loc, source,
              VectorType::get({32}, rewriter.getI32Type()));

          // Determine shuffle mask based on index
          // index 0: elements [0-15]
          // index 1: elements [16-31]
          SmallVector<int64_t> shuffleMask;
          if (auto constIndex = index.getDefiningOp<LLVM::ConstantOp>()) {
            auto indexAttr = cast<IntegerAttr>(constIndex.getValue());
            int64_t idxVal = indexAttr.getInt();
            int startIdx = idxVal * 16;
            for (int i = 0; i < 16; ++i) {
              shuffleMask.push_back(startIdx + i);
            }
          } else {
            // Default to index 0 if not constant
            for (int i = 0; i < 16; ++i) {
              shuffleMask.push_back(i);
            }
          }

          auto extOp = vector::ShuffleOp::create(rewriter, loc, v32i32Source,
                                                 v32i32Source, shuffleMask);

          return xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getF32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, {extOp},
                  {VectorType::get({32}, rewriter.getBF16Type())}));
        };

        auto res0 = extractUps2048(opSrcVal, index0Cst);
        auto res1 = extractUps2048(opSrcVal, index1Cst);

        // Concat two 1024-bit vectors to a 2048-bit vector using
        // vector::ShuffleOp
        auto v32i32Res0 = forceCastValueToType(
            rewriter, loc, res0, VectorType::get({32}, rewriter.getI32Type()));
        auto v32i32Res1 = forceCastValueToType(
            rewriter, loc, res1, VectorType::get({32}, rewriter.getI32Type()));

        SmallVector<int64_t> concatMask;
        for (int i = 0; i < 64; ++i) {
          concatMask.push_back(i);
        }
        upsIntrOp = vector::ShuffleOp::create(rewriter, loc, v32i32Res0,
                                              v32i32Res1, concatMask);
      }
    }

    // No case above matched: this type combination has no AIE2p lowering.
    if (!upsIntrOp) {
      op.emitWarning() << "aievec.ups is not supported.\n";
      return failure();
    }

    // create bitcast for result if needed
    if (flatResTy != upsIntrOp.getType())
      upsIntrOp = LLVM::BitcastOp::create(rewriter, loc, flatResTy, upsIntrOp);

    // Restore the original (possibly multi-dimensional) result shape.
    if (flatResTy != resultType)
      upsIntrOp =
          vector::ShapeCastOp::create(rewriter, loc, resultType, upsIntrOp);

    rewriter.replaceOp(op, upsIntrOp);

    return success();
  }
2047};
2048
// Lowers aievec.srs (shift-round-saturate from an accumulator down to a
// narrower vector) to AIE2 xllvm intrinsics. The intrinsic is selected by the
// result element kind (integer vs. float), the total result size in bits and,
// for 256-bit results, the source (accumulator) element width. Emits a
// warning and returns failure() for unsupported type combinations.
class SRSOpAIE2Conversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
public:
  using ConvertOpToLLVMPattern<aievec::SRSOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Characterize the result vector: element bit width, lane count, and
    // total size in bits.
    Value result = op.getResult();
    VectorType resultType = cast<VectorType>(result.getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    // Integer types
    Value srsIntrOp = nullptr;
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // create constant for sign from the op's sign attribute
      auto signCst =
          LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                   rewriter.getI32IntegerAttr(op.getSign()));

      // create xllvm intrinsic
      // All integer srs intrinsics share the (source, shift, sign) signature.
      SmallVector<Value> operands(
          {adaptor.getSource(), adaptor.getShift(), signCst});
      if (resultVectorSize == 512) {
        if (resultBitWidth == 16) {
          // v32acc32 -> v32int16
          srsIntrOp = xllvm::I512V32Acc32SrsAIE2IntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getI16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 32) {
          // v16acc64 -> v16int32
          srsIntrOp = xllvm::I512V16Acc64SrsAIE2IntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        }
      } else if (resultVectorSize == 256) {
        // Several intrinsics produce 256-bit results; disambiguate by the
        // accumulator element width.
        Value src = adaptor.getSource();
        VectorType srcType = cast<VectorType>(src.getType());
        Type srcScaType = srcType.getElementType();
        unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();

        if (resultBitWidth == 16 && srcBitWidth == 32) {
          // v16acc32 -> v16int16
          srsIntrOp = xllvm::I256V16Acc32SrsAIE2IntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({8}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 8 && srcBitWidth == 32) {
          // v32acc32 -> v32int8
          srsIntrOp = xllvm::I256V32Acc32SrsAIE2IntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getI8Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 16 && srcBitWidth == 64) {
          // v16acc64 -> v16int16
          srsIntrOp = xllvm::I256V16Acc64SrsAIE2IntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 32 && srcBitWidth == 64) {
          // v8acc64 -> v8int32
          srsIntrOp = xllvm::I256V8Acc64SrsAIE2IntrOp::create(
              rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({8}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        }
      }
    } else {
      // Float types
      if (resultVectorSize == 256) {
        // v16accfloat -> v16bfloat16
        srsIntrOp = xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
            rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getSource()},
                {VectorType::get({8}, rewriter.getI64Type())}));
      } else if (resultVectorSize == 512) {
        // v32accfloat -> v32bfloat16
        // The CPP example of the implementation is below:
        //   v32bfloat16 to_v32bfloat16(v32accfloat acc) {
        //     v16bfloat16 x0 = to_v16bfloat16(extract_v16accfloat(acc, 0));
        //     v16bfloat16 x1 = to_v16bfloat16(extract_v16accfloat(acc, 1));
        //     return concat(x0, x1);
        //   }
        auto indexZeroCst =
            LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(0));
        auto indexOneCst =
            LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(1));
        // Extracts the 512-bit half selected by `index` from the 1024-bit
        // accumulator and narrows it to v16bfloat16.
        auto extractSrs = [&](Value source, Value index) -> Value {
          auto extOp = xllvm::ExtI512I1024IntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, {source, index},
                  {VectorType::get({32}, rewriter.getI32Type()),
                   rewriter.getI32Type()}));
          return xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, {extOp},
                  {VectorType::get({8}, rewriter.getI64Type())}));
        };
        auto resLo = extractSrs(adaptor.getSource(), indexZeroCst);
        auto resHi = extractSrs(adaptor.getSource(), indexOneCst);
        // Concat the two 256-bit vector to a 512-bit vector.
        // Note that given sources a0 and a1, the result is [a1; a0].
        srsIntrOp = xllvm::ConcatI512I256IntrOp::create(
            rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
            forceCastOperandsToSignature(
                rewriter, loc, {resLo, resHi},
                {VectorType::get({8}, rewriter.getI32Type()),
                 VectorType::get({8}, rewriter.getI32Type())}));
      }
    }

    // No case above matched: this type combination has no AIE2 lowering.
    if (!srsIntrOp) {
      op.emitWarning() << "aievec.srs is not supported.\n";
      return failure();
    }

    // create bitcast/shape_cast for result if needed
    auto resultVal = forceCastValueToType(rewriter, loc, srsIntrOp,
                                          op.getResult().getType());
    rewriter.replaceOp(op, resultVal);

    return success();
  }
};
2189
2190// TODO: Split the op at AIEVec dialect level
2192 : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
2193public:
2194 using ConvertOpToLLVMPattern<aievec::SRSOp>::ConvertOpToLLVMPattern;
2195
  // Lowers aievec.srs (shift-round-saturate from an accumulator down to a
  // narrower vector) to AIE2p xllvm intrinsics. The intrinsic is selected by
  // the result element kind (integer vs. float), the total result size in
  // bits and the source (accumulator) element width. 1024-bit results have no
  // single intrinsic and are assembled from two 512-bit halves via
  // vector.shuffle. Emits a warning and returns failure() for unsupported
  // type combinations.
  LogicalResult
  matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Characterize the result vector: element bit width, lane count, and
    // total size in bits.
    Value result = op.getResult();
    VectorType resultType = cast<VectorType>(result.getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    // Integer types
    Value srsIntrOp = nullptr;
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // create constant for sign from the op's sign attribute
      auto signCst =
          LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                   rewriter.getI32IntegerAttr(op.getSign()));

      // create xllvm intrinsic
      // All integer srs intrinsics share the (source, shift, sign) signature.
      SmallVector<Value> operands(
          {adaptor.getSource(), adaptor.getShift(), signCst});
      if (resultVectorSize == 512) {
        Value src = adaptor.getSource();
        VectorType srcType = cast<VectorType>(src.getType());
        Type srcScaType = srcType.getElementType();
        unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();

        if (resultBitWidth == 16 && srcBitWidth == 32) {
          // v32acc32 -> v32int16
          srsIntrOp = xllvm::I512V32Acc32SrsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getI16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({32}, rewriter.getI32Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 16 && srcBitWidth == 64) {
          // v32acc64 -> v32int16
          srsIntrOp = xllvm::I512V32Acc64SrsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getI16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({32}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 32 && srcBitWidth == 64) {
          // v16acc64 -> v16int32
          srsIntrOp = xllvm::I512V16Acc64SrsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 8 && srcBitWidth == 32) {
          // v64acc32 -> v64int8
          srsIntrOp = xllvm::I512V64Acc32SrsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({64}, rewriter.getI8Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({64}, rewriter.getI32Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        }
      } else if (resultVectorSize == 256) {
        Value src = adaptor.getSource();
        VectorType srcType = cast<VectorType>(src.getType());
        Type srcScaType = srcType.getElementType();
        unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();

        if (resultBitWidth == 16 && srcBitWidth == 32) {
          // v16acc32 -> v16int16
          srsIntrOp = xllvm::I256V16Acc32SrsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI32Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 8 && srcBitWidth == 32) {
          // v32acc32 -> v32int8
          srsIntrOp = xllvm::I256V32Acc32SrsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getI8Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({32}, rewriter.getI32Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 16 && srcBitWidth == 64) {
          // v16acc64 -> v16int16
          srsIntrOp = xllvm::I256V16Acc64SrsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({16}, rewriter.getI16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({16}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        } else if (resultBitWidth == 32 && srcBitWidth == 64) {
          // v8acc64 -> v8int32
          srsIntrOp = xllvm::I256V8Acc64SrsAIE2pIntrOp::create(
              rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, operands,
                  {VectorType::get({8}, rewriter.getI64Type()),
                   rewriter.getI32Type(), rewriter.getI32Type()}));
        }
      } else if (resultVectorSize == 1024) {
        Value src = adaptor.getSource();
        VectorType srcType = cast<VectorType>(src.getType());
        Type srcScaType = srcType.getElementType();
        unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();

        if (resultBitWidth == 16 && srcBitWidth == 32) {
          // v64acc32 -> v64int16
          // Extract 2 chunks of v32acc32 and convert each to v32int16
          auto index0Cst =
              LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                       rewriter.getI32IntegerAttr(0));
          auto index1Cst =
              LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                       rewriter.getI32IntegerAttr(1));

          // Extracts the 1024-bit half selected by `index` from the 2048-bit
          // accumulator and narrows it with the v32acc32 -> v32int16
          // intrinsic.
          auto extractSrs1024 = [&](Value source, Value index, Value shiftCst,
                                    Value signCst) -> Value {
            // Use vector::ShuffleOp to extract 1024-bit from 2048-bit
            // Cast source to v64xi32 for shuffling
            auto v64i32Source = forceCastValueToType(
                rewriter, loc, source,
                VectorType::get({64}, rewriter.getI32Type()));

            // Determine shuffle mask based on index
            // index 0: elements [0-31]
            // index 1: elements [32-63]
            SmallVector<int64_t> shuffleMask;
            if (auto constIndex = index.getDefiningOp<LLVM::ConstantOp>()) {
              auto indexAttr = cast<IntegerAttr>(constIndex.getValue());
              int64_t idxVal = indexAttr.getInt();
              int startIdx = idxVal * 32;
              for (int i = 0; i < 32; ++i) {
                shuffleMask.push_back(startIdx + i);
              }
            } else {
              // Default to index 0 if not constant
              for (int i = 0; i < 32; ++i) {
                shuffleMask.push_back(i);
              }
            }

            auto extOp = vector::ShuffleOp::create(rewriter, loc, v64i32Source,
                                                   v64i32Source, shuffleMask);

            return xllvm::I512V32Acc32SrsAIE2pIntrOp::create(
                rewriter, loc, VectorType::get({32}, rewriter.getI16Type()),
                forceCastOperandsToSignature(
                    rewriter, loc, {extOp, shiftCst, signCst},
                    {VectorType::get({32}, rewriter.getI32Type()),
                     rewriter.getI32Type(), rewriter.getI32Type()}));
          };

          auto res0 =
              extractSrs1024(src, index0Cst, adaptor.getShift(), signCst);
          auto res1 =
              extractSrs1024(src, index1Cst, adaptor.getShift(), signCst);

          // Concat two 512-bit vectors to a 1024-bit vector using
          // vector::ShuffleOp
          auto v16i32Res0 = forceCastValueToType(
              rewriter, loc, res0,
              VectorType::get({16}, rewriter.getI32Type()));
          auto v16i32Res1 = forceCastValueToType(
              rewriter, loc, res1,
              VectorType::get({16}, rewriter.getI32Type()));

          SmallVector<int64_t> concatMask;
          for (int i = 0; i < 32; ++i) {
            concatMask.push_back(i);
          }
          srsIntrOp = vector::ShuffleOp::create(rewriter, loc, v16i32Res0,
                                                v16i32Res1, concatMask);
        }
      }
    } else {
      // Float types
      // AIE2p uses native F32 types, AIE2 uses packed I64 types
      if (resultVectorSize == 256) {
        // v16accfloat -> v16bfloat16
        srsIntrOp = xllvm::Vector16AccFloatToV16BF16AIE2pIntrOp::create(
            rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getSource()},
                {VectorType::get({16}, rewriter.getF32Type())}));
      } else if (resultVectorSize == 512) {
        // v32accfloat -> v32bfloat16
        srsIntrOp = xllvm::Vector32AccFloatToV32BF16AIE2pIntrOp::create(
            rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getSource()},
                {VectorType::get({32}, rewriter.getF32Type())}));
      } else if (resultVectorSize == 1024) {
        // v64accfloat -> v64bfloat16
        // Extract 2 chunks of v32accfloat and convert each to v32bfloat16
        auto index0Cst =
            LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(0));
        auto index1Cst =
            LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(1));

        // Extracts the 1024-bit half selected by `index` from the 2048-bit
        // accumulator and narrows it to v32bfloat16.
        auto extractSrs1024 = [&](Value source, Value index) -> Value {
          // Use vector::ShuffleOp to extract 1024-bit from 2048-bit
          // Cast source to v64xi32 for shuffling
          auto v64i32Source = forceCastValueToType(
              rewriter, loc, source,
              VectorType::get({64}, rewriter.getI32Type()));

          // Determine shuffle mask based on index
          // index 0: elements [0-31]
          // index 1: elements [32-63]
          SmallVector<int64_t> shuffleMask;
          if (auto constIndex = index.getDefiningOp<LLVM::ConstantOp>()) {
            auto indexAttr = cast<IntegerAttr>(constIndex.getValue());
            int64_t idxVal = indexAttr.getInt();
            int startIdx = idxVal * 32;
            for (int i = 0; i < 32; ++i) {
              shuffleMask.push_back(startIdx + i);
            }
          } else {
            // Default to index 0 if not constant
            for (int i = 0; i < 32; ++i) {
              shuffleMask.push_back(i);
            }
          }

          auto extOp = vector::ShuffleOp::create(rewriter, loc, v64i32Source,
                                                 v64i32Source, shuffleMask);

          return xllvm::Vector32AccFloatToV32BF16AIE2pIntrOp::create(
              rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
              forceCastOperandsToSignature(
                  rewriter, loc, {extOp},
                  {VectorType::get({32}, rewriter.getF32Type())}));
        };

        auto res0 = extractSrs1024(adaptor.getSource(), index0Cst);
        auto res1 = extractSrs1024(adaptor.getSource(), index1Cst);

        // Concat two 512-bit vectors to a 1024-bit vector using
        // vector::ShuffleOp
        auto v16i32Res0 = forceCastValueToType(
            rewriter, loc, res0, VectorType::get({16}, rewriter.getI32Type()));
        auto v16i32Res1 = forceCastValueToType(
            rewriter, loc, res1, VectorType::get({16}, rewriter.getI32Type()));

        SmallVector<int64_t> concatMask;
        for (int i = 0; i < 32; ++i) {
          concatMask.push_back(i);
        }
        srsIntrOp = vector::ShuffleOp::create(rewriter, loc, v16i32Res0,
                                              v16i32Res1, concatMask);
      }
    }

    // No case above matched: this type combination has no AIE2p lowering.
    if (!srsIntrOp) {
      op.emitWarning() << "aievec.srs is not supported.\n";
      return failure();
    }

    // create bitcast/shape_cast for result if needed
    auto resultVal = forceCastValueToType(rewriter, loc, srsIntrOp,
                                          op.getResult().getType());
    rewriter.replaceOp(op, resultVal);

    return success();
  }
2465};
2466
2467class UPDOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPDOp> {
2468public:
2469 using ConvertOpToLLVMPattern<aievec::UPDOp>::ConvertOpToLLVMPattern;
2470
2471 static std::string getIntrinsicName(aievec::UPDOp op, int loadSize) {
2472 auto resultType = cast<VectorType>(op.getResult().getType());
2473 std::stringstream ss;
2474 ss << "llvm.aie.upd.";
2475 ss << (loadSize == 128 ? 'v' : loadSize == 256 ? 'w' : 'x') << ".";
2476 ss << getVectorTypeString(resultType) << ".";
2477 // The index affects which intrinsic to call
2478 ss << (op.getIndex() == 0 ? "lo" : "hi");
2479 return ss.str();
2480 }
2481
2482 LogicalResult
2483 matchAndRewrite(aievec::UPDOp op, OpAdaptor adaptor,
2484 ConversionPatternRewriter &rewriter) const override {
2485 auto module = op->getParentOfType<ModuleOp>();
2486 MLIRContext *context = rewriter.getContext();
2487
2488 // A bit more complicated: load the vector, then update result vector
2489 // AIE1 is capable of 128-bit on one bank and 256-bit loads on even-odd
2490 // banks Identify size of update
2491 int vecSizeInBits =
2492 getVectorSizeInBits(cast<VectorType>(op.getResult().getType()));
2493
2494 auto ptr = this->getStridedElementPtr(
2495 rewriter, op->getLoc(), cast<MemRefType>(op.getSource().getType()),
2496 adaptor.getSource(), adaptor.getIndices());
2497
2498 // TODO: handle the offset field
2499
2500 if (vecSizeInBits <= 256) {
2501 // Total <=256-bit updates are much simpler:
2502 // we can do a direct load into the vector register
2503 // look at the indices to calculate the address
2504 auto vectorPtrType = LLVM::LLVMPointerType::get(
2505 getContext(),
2506 cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
2507 auto castedPtr =
2508 LLVM::BitcastOp::create(rewriter, op->getLoc(), vectorPtrType, ptr);
2509 auto vecType = cast<VectorType>(op.getResult().getType());
2510 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, vecType, castedPtr, 1);
2511 } else {
2512 // Total >256-bit updates will require upd ops to fill the whole vector
2513 // each UDP op represents one of these 256-bit loads and updates
2514
2515 // Determine the load size
2516 // TODO: no examples of 1024-bit output vectors: doesn't feel right
2517 // to attempt a 512-bit load to do an update like this
2518 int loadSize = vecSizeInBits == 256 ? 128
2519 : vecSizeInBits == 512 ? 256
2520 : 512;
2521
2522 // Create a vectorType for the load proper
2523 // Load half of the final result vector
2524 auto resultType = cast<VectorType>(op.getResult().getType());
2525 int lanes = getVectorLaneSize(resultType);
2526 auto loadType =
2527 VectorType::get({(int64_t)lanes / 2}, resultType.getElementType());
2528
2529 // Load the vector
2530 auto vectorPtrType = LLVM::LLVMPointerType::get(
2531 getContext(),
2532 cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
2533 auto castedPtr =
2534 LLVM::BitcastOp::create(rewriter, op->getLoc(), vectorPtrType, ptr);
2535 auto loadValue =
2536 LLVM::LoadOp::create(rewriter, op->getLoc(), loadType, castedPtr, 1);
2537
2538 // Get set up for the intrinsic
2539 std::string intrinsicName = getIntrinsicName(op, loadSize);
2540
2541 // If the intrinsic declaration doesn't exist, create it
2542 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
2543 StringAttr::get(context, intrinsicName));
2544
2545 if (!func) {
2546 OpBuilder::InsertionGuard guard(rewriter);
2547 rewriter.setInsertionPointToStart(module.getBody());
2548 func = LLVM::LLVMFuncOp::create(
2549 rewriter, rewriter.getUnknownLoc(), intrinsicName,
2550 LLVM::LLVMFunctionType::get(resultType, {resultType, loadType}));
2551 }
2552
2553 // Determine what the destination is
2554 Value destValue;
2555 if (adaptor.getVector()) {
2556 // This UPD is using an existing destination vector
2557 destValue = adaptor.getVector();
2558 } else {
2559 // If this UPD is not working off of an existing destination vector,
2560 // create an undefined vector as the destination
2561
2562 // TODO: determine if the undef intrinsic is needed or if an LLVM
2563 // undef suffices destValue =
2564 // LLVM::UndefOp::create(rewriter, op->getLoc(), resultType);
2565
2566 std::stringstream ss;
2567 ss << "llvm.aie." << getVectorTypeString(resultType) << ".undef";
2568 std::string intrinsicName = ss.str();
2569
2570 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
2571 StringAttr::get(rewriter.getContext(), intrinsicName));
2572
2573 if (!func) {
2574 OpBuilder::InsertionGuard guard(rewriter);
2575 rewriter.setInsertionPointToStart(module.getBody());
2576 func = LLVM::LLVMFuncOp::create(
2577 rewriter, rewriter.getUnknownLoc(), intrinsicName,
2578 LLVM::LLVMFunctionType::get(resultType, {}));
2579 }
2580 destValue =
2581 LLVM::CallOp::create(rewriter, op->getLoc(), func, ValueRange{})
2582 ->getOpResult(0);
2583 }
2584
2585 // Create our call
2586 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
2587 op, func, ValueRange{destValue, loadValue});
2588 }
2589
2590 return success();
2591 }
2592};
2593
2595 : public mlir::ConvertOpToLLVMPattern<aievec::ConcatOp> {
2596public:
2597 using ConvertOpToLLVMPattern<aievec::ConcatOp>::ConvertOpToLLVMPattern;
2598
2599 LogicalResult
2600 matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor,
2601 ConversionPatternRewriter &rewriter) const override {
2602 Location loc = op.getLoc();
2603
2604 SmallVector<Value> sources = adaptor.getSources();
2605 Value src = sources.front();
2606 VectorType srcType = cast<VectorType>(src.getType());
2607 Type srcScalarType = srcType.getElementType();
2608 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
2609 int srcLanes = getVectorLaneSize(srcType);
2610 int srcVectorSize = srcBitWidth * srcLanes;
2611
2612 Value result = op.getResult();
2613 VectorType resultType = cast<VectorType>(result.getType());
2614 Type resultScaTy = resultType.getElementType();
2615 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
2616 int resultLanes = getVectorLaneSize(resultType);
2617 int resultVectorSize = resultBitWidth * resultLanes;
2618
2619 if (sources.size() != 2 && sources.size() != 4) {
2620 op.emitWarning() << "aievec.concat with " << sources.size()
2621 << " operands is not supported.\n";
2622 return failure();
2623 }
2624
2625 // create xllvm intrinsic
2626 Value concatOp = nullptr;
2627 if (srcVectorSize == 256 && resultVectorSize == 512) {
2628 concatOp = xllvm::ConcatI512I256IntrOp::create(
2629 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
2630 forceCastOperandsToSignature(
2631 rewriter, loc, adaptor.getSources(),
2632 {VectorType::get({8}, rewriter.getI32Type()),
2633 VectorType::get({8}, rewriter.getI32Type())}));
2634 } else if (srcVectorSize == 256 && resultVectorSize == 1024) {
2635 concatOp = xllvm::ConcatI1024I256IntrOp::create(
2636 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
2637 forceCastOperandsToSignature(
2638 rewriter, loc, adaptor.getSources(),
2639 {VectorType::get({8}, rewriter.getI32Type()),
2640 VectorType::get({8}, rewriter.getI32Type()),
2641 VectorType::get({8}, rewriter.getI32Type()),
2642 VectorType::get({8}, rewriter.getI32Type())}));
2643 } else if (srcVectorSize == 512 && resultVectorSize == 1024) {
2644 concatOp = xllvm::ConcatI1024I512IntrOp::create(
2645 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
2646 forceCastOperandsToSignature(
2647 rewriter, loc, adaptor.getSources(),
2648 {VectorType::get({16}, rewriter.getI32Type()),
2649 VectorType::get({16}, rewriter.getI32Type())}));
2650 } else {
2651 op.emitWarning() << "aievec.concat with " << srcVectorSize
2652 << "-bit operands, and " << resultVectorSize
2653 << "-bit result is not supported.\n";
2654 return failure();
2655 }
2656
2657 // create bitcast/shape_cast for result
2658 auto resultVal =
2659 forceCastValueToType(rewriter, loc, concatOp, op.getResult().getType());
2660 rewriter.replaceOp(op, resultVal);
2661
2662 return success();
2663 }
2664};
2665
// Lowers aievec.ext to the AIE2 extract intrinsics. Integer paths extract
// 256b from 512b, 512b from 1024b, or 256b from 1024b; the 128b-from-512b
// path first shifts the requested 128-bit slice to the bottom of the
// register, because the underlying intrinsic only extracts the low lanes.
class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
public:
  using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Total bit widths of source and result select which intrinsic applies.
    Value src = adaptor.getSource();
    VectorType srcType = cast<VectorType>(src.getType());
    Type srcScalarType = srcType.getElementType();
    unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
    int srcLanes = getVectorLaneSize(srcType);
    int srcVectorSize = srcBitWidth * srcLanes;

    Value result = op.getResult();
    VectorType resultType = cast<VectorType>(result.getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    // create constant for index
    auto indexCst =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                 rewriter.getI32IntegerAttr(op.getIndex()));

    // create xllvm intrinsic
    SmallVector<Value> operands({adaptor.getSource(), indexCst});
    Value extOp = nullptr;
    // Integer types
    if (resultVectorSize == 256 && srcVectorSize == 512) {
      extOp = xllvm::ExtI256I512IntrOp::create(
          rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({16}, rewriter.getI32Type()),
               rewriter.getI32Type()}));
    } else if (resultVectorSize == 512 && srcVectorSize == 1024) {
      extOp = xllvm::ExtI512I1024IntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({32}, rewriter.getI32Type()),
               rewriter.getI32Type()}));
    } else if (resultVectorSize == 256 && srcVectorSize == 1024) {
      extOp = xllvm::ExtI256I1024IntrOp::create(
          rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({32}, rewriter.getI32Type()),
               rewriter.getI32Type()}));
    } else if (resultVectorSize == 128 && srcVectorSize == 512) {
      // For a nonzero index, rotate the wanted 128-bit slice down to lane 0
      // first; the extract intrinsic below can only take the low 128 bits.
      auto shiftOp = adaptor.getSource();
      if (op.getIndex() > 0) {
        auto undefOp = xllvm::UndefV16I32IntrOp::create(
            rewriter, loc, VectorType::get({16}, rewriter.getI32Type()));
        auto stepCst =
            LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(0));
        auto shiftCst = LLVM::ConstantOp::create(
            rewriter, loc, rewriter.getI32Type(),
            rewriter.getI32IntegerAttr(op.getIndex() * 16));
        SmallVector<Value> shiftOperands{adaptor.getSource(), undefOp, stepCst,
                                         shiftCst};
        // Right shift the source vector in index * 16 bytes (i.e. in index *
        // 128 bits). The integer index is expected to be 0 to 3.
        shiftOp = xllvm::VectorShiftI512I512IntrOp::create(
            rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
            forceCastOperandsToSignature(
                rewriter, loc, shiftOperands,
                {VectorType::get({16}, rewriter.getI32Type()),
                 VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type(), rewriter.getI32Type()}));
      }
      // The underlying intrinsic takes a source vector and extract the lowest
      // 128-bit. i.e. it always extracts the input vector with index = 0.
      extOp = xllvm::ExtI128I512IntrOp::create(
          rewriter, loc, VectorType::get({4}, rewriter.getI32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, /*operands=*/{shiftOp},
              {VectorType::get({16}, rewriter.getI32Type())}));
    } else {
      op.emitWarning() << "aievec.ext with " << srcVectorSize
                       << "-bit source, and " << resultVectorSize
                       << "-bit result is not supported.\n";
      return failure();
    }

    // create bitcast/shape_cast for result
    auto resultVal =
        forceCastValueToType(rewriter, loc, extOp, op.getResult().getType());
    rewriter.replaceOp(op, resultVal);

    return success();
  }
};
2764
2765// AIE2p version of ExtOp conversion using vector.shuffle
2767 : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
2768public:
2769 using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;
2770
2771 LogicalResult
2772 matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor,
2773 ConversionPatternRewriter &rewriter) const override {
2774 Location loc = op.getLoc();
2775
2776 Value src = adaptor.getSource();
2777 VectorType srcType = cast<VectorType>(src.getType());
2778 VectorType resultType = cast<VectorType>(op.getResult().getType());
2779
2780 int srcLanes = getVectorLaneSize(srcType);
2781 int resultLanes = getVectorLaneSize(resultType);
2782
2783 // Verify this is extracting half the vector
2784 if (srcLanes != 2 * resultLanes) {
2785 op.emitWarning() << "aievec.ext with non-half extraction is not "
2786 "supported for AIE2p.\n";
2787 return failure();
2788 }
2789
2790 // Build shuffle mask based on index
2791 // index 0: extract lower half [0, 1, ..., resultLanes-1]
2792 // index 1: extract upper half [resultLanes, ..., srcLanes-1]
2793 SmallVector<int64_t> shuffleMask;
2794 int startIdx = op.getIndex() * resultLanes;
2795 for (int i = 0; i < resultLanes; ++i) {
2796 shuffleMask.push_back(startIdx + i);
2797 }
2798
2799 // Use vector.shuffle to extract the half
2800 auto extracted =
2801 vector::ShuffleOp::create(rewriter, loc, src, src, shuffleMask);
2802
2803 rewriter.replaceOp(op, extracted);
2804 return success();
2805 }
2806};
2807
2809 : public mlir::ConvertOpToLLVMPattern<aievec::aie1::SelectOp> {
2810public:
2811 using ConvertOpToLLVMPattern<aievec::aie1::SelectOp>::ConvertOpToLLVMPattern;
2812
2813 static std::string getIntrinsicName(aievec::aie1::SelectOp op) {
2814 auto xbuffType = cast<VectorType>(op.getXbuff().getType());
2815 std::stringstream ss;
2816 ss << "llvm.aie.prim." << getVectorTypeString(xbuffType) << ".select";
2817 return ss.str();
2818 }
2819
2820 LogicalResult
2821 matchAndRewrite(aievec::aie1::SelectOp op, OpAdaptor adaptor,
2822 ConversionPatternRewriter &rewriter) const override {
2823 auto module = op->getParentOfType<ModuleOp>();
2824 MLIRContext *context = rewriter.getContext();
2825
2826 auto selectType = IntegerType::get(context, 32);
2827 auto startType = IntegerType::get(context, 32);
2828 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
2829 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
2830
2831 // If the intrinsic declaration doesn't exist, create it
2832 std::string intrinsicName = getIntrinsicName(op);
2833 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
2834 StringAttr::get(context, intrinsicName));
2835
2836 if (!func) {
2837 OpBuilder::InsertionGuard guard(rewriter);
2838 rewriter.setInsertionPointToStart(module.getBody());
2839 func = LLVM::LLVMFuncOp::create(
2840 rewriter, rewriter.getUnknownLoc(), intrinsicName,
2841 LLVM::LLVMFunctionType::get(op.getResult().getType(),
2842 {op.getXbuff().getType(), selectType,
2843 startType, /* xstart */
2844 startType, /* ystart */
2845 offsetsType, /* xoffsets */
2846 offsetsType, /* yoffsets */
2847 confType}));
2848 }
2849
2850 // Parse the string attribute values
2851 uint32_t select = 0;
2852 BufferParams x = {};
2853 BufferParams y = {};
2854 BufferParams z = {};
2855
2856 op.getSelect().getAsInteger(0, select);
2857 op.getXstart().getAsInteger(0, x.start);
2858 op.getXoffsets().getAsInteger(0, x.offsets);
2859 op.getXoffsetsHi().getAsInteger(0, x.offsets_hi);
2860 op.getXsquare().getAsInteger(0, x.square);
2861 op.getYstart().getAsInteger(0, y.start);
2862 op.getYoffsets().getAsInteger(0, y.offsets);
2863 op.getYoffsetsHi().getAsInteger(0, y.offsets_hi);
2864 op.getYsquare().getAsInteger(0, y.square);
2865
2866 // Encode the configuration register
2867 uint32_t conf[2] = {0, 0};
2868 encodeConf(conf, x, z, false);
2869 conf[1] |= encodeSquare(y.square) << 21;
2870
2871 // Create the constants and replace the op
2872 auto selectVal = LLVM::ConstantOp::create(
2873 rewriter, op->getLoc(), selectType, rewriter.getI32IntegerAttr(select));
2874 auto xstartVal = LLVM::ConstantOp::create(
2875 rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(x.start));
2876 auto ystartVal = LLVM::ConstantOp::create(
2877 rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(y.start));
2878 auto xoffsetsVal = LLVM::ConstantOp::create(
2879 rewriter, op->getLoc(), offsetsType,
2880 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
2881 auto yoffsetsVal = LLVM::ConstantOp::create(
2882 rewriter, op->getLoc(), offsetsType,
2883 rewriter.getI32VectorAttr({(int32_t)y.offsets, (int32_t)y.offsets_hi}));
2884 auto confVal = LLVM::ConstantOp::create(
2885 rewriter, op->getLoc(), confType,
2886 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
2887 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
2888 op, func,
2889 ValueRange{op.getXbuff(), selectVal, xstartVal, ystartVal, xoffsetsVal,
2890 yoffsetsVal, confVal});
2891 return success();
2892 }
2893};
2894
2895class PackOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::PackOp> {
2896public:
2897 using ConvertOpToLLVMPattern<aievec::PackOp>::ConvertOpToLLVMPattern;
2898
2899 static std::string getIntrinsicName(aievec::PackOp op) {
2900 auto sourceType = cast<VectorType>(op.getSource().getType());
2901 std::stringstream ss;
2902 ss << "llvm.aie.pack." << getVectorTypeString(sourceType);
2903 return ss.str();
2904 }
2905
2906 LogicalResult
2907 matchAndRewrite(aievec::PackOp op, OpAdaptor adaptor,
2908 ConversionPatternRewriter &rewriter) const override {
2909 auto module = op->getParentOfType<ModuleOp>();
2910 MLIRContext *context = rewriter.getContext();
2911
2912 // If the intrinsic declaration doesn't exist, create it
2913 std::string intrinsicName = getIntrinsicName(op);
2914 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
2915 StringAttr::get(context, intrinsicName));
2916
2917 if (!func) {
2918 OpBuilder::InsertionGuard guard(rewriter);
2919 rewriter.setInsertionPointToStart(module.getBody());
2920 func = LLVM::LLVMFuncOp::create(
2921 rewriter, rewriter.getUnknownLoc(), intrinsicName,
2922 LLVM::LLVMFunctionType::get(op.getResult().getType(),
2923 {op.getSource().getType()}));
2924 }
2925
2926 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, func,
2927 ValueRange{op.getSource()});
2928 return success();
2929 }
2930};
2931
2933 : public mlir::ConvertOpToLLVMPattern<aievec::UnpackOp> {
2934public:
2935 using ConvertOpToLLVMPattern<aievec::UnpackOp>::ConvertOpToLLVMPattern;
2936
2937 LogicalResult
2938 matchAndRewrite(aievec::UnpackOp op, OpAdaptor adaptor,
2939 ConversionPatternRewriter &rewriter) const override {
2940 op.emitWarning() << "aie.unpack conversion is not implemented\n";
2941 return failure();
2942 }
2943};
2944
2946 : public mlir::ConvertOpToLLVMPattern<aievec::BroadcastOp> {
2947public:
2948 using ConvertOpToLLVMPattern<aievec::BroadcastOp>::ConvertOpToLLVMPattern;
2949
2950 LogicalResult
2951 matchAndRewrite(aievec::BroadcastOp op, OpAdaptor adaptor,
2952 ConversionPatternRewriter &rewriter) const override {
2953 op.emitWarning() << "aie.broadcast conversion is not implemented\n";
2954 return failure();
2955 }
2956};
2957
2958// Helper to pad a vector from N lanes to dstLanes using vector.shuffle
2959// with poison (-1) for the upper portion.
2960static Value padVectorWithPoison(ConversionPatternRewriter &rewriter,
2961 Location loc, Value vec, int srcLanes,
2962 int dstLanes) {
2963 SmallVector<int64_t> padMask;
2964 for (int i = 0; i < srcLanes; ++i)
2965 padMask.push_back(i);
2966 for (int i = srcLanes; i < dstLanes; ++i)
2967 padMask.push_back(-1);
2968 return vector::ShuffleOp::create(rewriter, loc, vec, vec, padMask);
2969}
2970
2971// Helper to extract the first N lanes from a wider vector using
2972// vector.shuffle.
2973static Value extractLowerLanes(ConversionPatternRewriter &rewriter,
2974 Location loc, Value vec, int lanes) {
2975 SmallVector<int64_t> extractMask;
2976 for (int i = 0; i < lanes; ++i)
2977 extractMask.push_back(i);
2978 return vector::ShuffleOp::create(rewriter, loc, vec, vec, extractMask);
2979}
2980
// Lowers aievec.max to the AIE2 vmax intrinsics. Integer element types use
// the lt8/lt16/lt32 variants (which also return a compare mask that is
// dropped here); bf16 uses the dedicated bf16 variant, padding 16-lane
// inputs to the 32-lane intrinsic width and truncating the result back.
class MaxOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
public:
  using ConvertOpToLLVMPattern<aievec::MaxOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType resultType = cast<VectorType>(op.getResult().getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    // aievec.max op has the AllTypesMatch constraint on lhs/rhs/res
    if (resultVectorSize != 512 && resultVectorSize != 256) {
      op.emitWarning() << "aievec.max conversion with " << resultVectorSize
                       << "-bit result is not supported.\n";
      return failure();
    }

    // create xllvm intrinsic
    Value maxOp = nullptr;
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // create constant for third operand `cmp`
      // Note: `cmp` is implicitly treated as `sign` to the vmax intrinsic
      auto cmpCst = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
      SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
      // Each intrinsic returns a struct {result vector, compare mask}; only
      // the vector (field 0) is used below.
      if (resultBitWidth == 8) {
        maxOp = xllvm::VectorMaxLt8IntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({2}, rewriter.getI32Type())}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({64}, rewriter.getI8Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 16) {
        maxOp = xllvm::VectorMaxLt16IntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({32}, rewriter.getI16Type()),
                 VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 32) {
        maxOp = xllvm::VectorMaxLt32IntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({16}, rewriter.getI32Type()),
                 VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}));
      }
    } else {
      // Floating point: only bf16 is handled here.
      if (resultBitWidth == 16) {
        auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
        Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();

        // Pad 16-lane to 32-lane if needed
        if (resultLanes == 16) {
          lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
          rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
        }

        maxOp = xllvm::VectorMaxLtBf16IntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()}),
            forceCastOperandsToSignature(rewriter, loc, {lhs, rhs},
                                         {v32bf16Ty, v32bf16Ty}));
      }
    }

    if (!maxOp) {
      op.emitWarning() << "aievec.max conversion fails due to unsupported "
                          "element data type.\n";
      return failure();
    }

    // Extract the vector result from the struct
    Value resultVec = LLVM::ExtractValueOp::create(rewriter, loc, maxOp,
                                                   /*position=*/0);
    // Truncate back to 16 lanes if padded
    if (resultLanes == 16 && !llvm::isa<IntegerType>(resultScaTy))
      resultVec = extractLowerLanes(rewriter, loc, resultVec, 16);

    rewriter.replaceOp(op, resultVec);

    return success();
  }
};
3086
// Lowers aievec.min to the AIE2 vmin intrinsics. Mirrors MaxOpConversion:
// integer element types use the ge8/ge16/ge32 variants; bf16 uses the
// dedicated variant with 16->32 lane padding and result truncation.
class MinOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MinOp> {
public:
  using ConvertOpToLLVMPattern<aievec::MinOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType resultType = cast<VectorType>(op.getResult().getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    // aievec.min op has the AllTypesMatch constraint on lhs/rhs/res
    if (resultVectorSize != 512 && resultVectorSize != 256) {
      op.emitWarning() << "aievec.min conversion with " << resultVectorSize
                       << "-bit result is not supported.\n";
      return failure();
    }

    // create xllvm intrinsic
    Value minOp = nullptr;
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // create constant for third operand `cmp`
      // Note: `cmp` is implicitly treated as `sign` to the vmin intrinsic
      auto cmpCst = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
      SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
      // Each intrinsic returns a struct {result vector, compare mask}; only
      // the vector (field 0) is used below.
      if (resultBitWidth == 8) {
        minOp = xllvm::VectorMinGe8IntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({2}, rewriter.getI32Type())}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({64}, rewriter.getI8Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 16) {
        minOp = xllvm::VectorMinGe16IntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({32}, rewriter.getI16Type()),
                 VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 32) {
        minOp = xllvm::VectorMinGe32IntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({16}, rewriter.getI32Type()),
                 VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}));
      }
    } else {
      // Floating point: only bf16 is handled here.
      if (resultBitWidth == 16) {
        auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
        Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();

        // Pad 16-lane to 32-lane if needed
        if (resultLanes == 16) {
          lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
          rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
        }

        minOp = xllvm::VectorMinGeBf16IntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()}),
            forceCastOperandsToSignature(rewriter, loc, {lhs, rhs},
                                         {v32bf16Ty, v32bf16Ty}));
      }
    }

    if (!minOp) {
      op.emitWarning() << "aievec.min conversion fails due to unsupported "
                          "element data type.\n";
      return failure();
    }

    // Extract the vector result from the struct
    Value resultVec = LLVM::ExtractValueOp::create(rewriter, loc, minOp,
                                                   /*position=*/0);
    // Truncate back to 16 lanes if padded
    if (resultLanes == 16 && !llvm::isa<IntegerType>(resultScaTy))
      resultVec = extractLowerLanes(rewriter, loc, resultVec, 16);

    rewriter.replaceOp(op, resultVec);

    return success();
  }
};
3192
3193// AIE2p version of MaxOp conversion
3195 : public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
3196public:
3197 using ConvertOpToLLVMPattern<aievec::MaxOp>::ConvertOpToLLVMPattern;
3198
3199 LogicalResult
3200 matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor,
3201 ConversionPatternRewriter &rewriter) const override {
3202 Location loc = op.getLoc();
3203
3204 VectorType resultType = cast<VectorType>(op.getResult().getType());
3205 Type resultScaTy = resultType.getElementType();
3206 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
3207 int resultLanes = getVectorLaneSize(resultType);
3208 int resultVectorSize = resultBitWidth * resultLanes;
3209
3210 // aievec.max op has the AllTypesMatch constraint on lhs/rhs/res
3211 if (resultVectorSize != 512 && resultVectorSize != 256) {
3212 op.emitWarning() << "aievec.max conversion with " << resultVectorSize
3213 << "-bit result is not supported.\n";
3214 return failure();
3215 }
3216
3217 // create xllvm intrinsic
3218 Value maxOp = nullptr;
3219 if (llvm::isa<IntegerType>(resultScaTy)) {
3220 // create constant for third operand `cmp`
3221 // Note: `cmp` is implicitly treated as `sign` to the vmax intrinsic
3222 auto cmpCst = LLVM::ConstantOp::create(
3223 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
3224 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
3225 if (resultBitWidth == 8) {
3226 maxOp = xllvm::VectorMaxLt8AIE2pIntrOp::create(
3227 rewriter, loc,
3228 mlir::LLVM::LLVMStructType::getLiteral(
3229 rewriter.getContext(),
3230 {VectorType::get({64}, rewriter.getI8Type()),
3231 VectorType::get({2}, rewriter.getI32Type())}),
3232 forceCastOperandsToSignature(
3233 rewriter, loc, operands,
3234 {VectorType::get({64}, rewriter.getI8Type()),
3235 VectorType::get({64}, rewriter.getI8Type()),
3236 rewriter.getI32Type()}));
3237 } else if (resultBitWidth == 16) {
3238 maxOp = xllvm::VectorMaxLt16AIE2pIntrOp::create(
3239 rewriter, loc,
3240 mlir::LLVM::LLVMStructType::getLiteral(
3241 rewriter.getContext(),
3242 {VectorType::get({32}, rewriter.getI16Type()),
3243 rewriter.getI32Type()}),
3244 forceCastOperandsToSignature(
3245 rewriter, loc, operands,
3246 {VectorType::get({32}, rewriter.getI16Type()),
3247 VectorType::get({32}, rewriter.getI16Type()),
3248 rewriter.getI32Type()}));
3249 } else if (resultBitWidth == 32) {
3250 maxOp = xllvm::VectorMaxLt32AIE2pIntrOp::create(
3251 rewriter, loc,
3252 mlir::LLVM::LLVMStructType::getLiteral(
3253 rewriter.getContext(),
3254 {VectorType::get({16}, rewriter.getI32Type()),
3255 rewriter.getI32Type()}),
3256 forceCastOperandsToSignature(
3257 rewriter, loc, operands,
3258 {VectorType::get({16}, rewriter.getI32Type()),
3259 VectorType::get({16}, rewriter.getI32Type()),
3260 rewriter.getI32Type()}));
3261 }
3262 } else {
3263 if (resultBitWidth == 16) {
3264 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
3265 Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
3266
3267 if (resultLanes == 16) {
3268 lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
3269 rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
3270 }
3271
3272 maxOp = xllvm::VectorMaxLtBf16AIE2pIntrOp::create(
3273 rewriter, loc,
3274 mlir::LLVM::LLVMStructType::getLiteral(
3275 rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()}),
3276 forceCastOperandsToSignature(rewriter, loc, {lhs, rhs},
3277 {v32bf16Ty, v32bf16Ty}));
3278 }
3279 }
3280
3281 if (!maxOp) {
3282 op.emitWarning() << "aievec.max conversion fails due to unsupported "
3283 "element data type.\n";
3284 return failure();
3285 }
3286
3287 Value resultVec = LLVM::ExtractValueOp::create(rewriter, loc, maxOp,
3288 /*position=*/0);
3289 if (resultLanes == 16 && !llvm::isa<IntegerType>(resultScaTy))
3290 resultVec = extractLowerLanes(rewriter, loc, resultVec, 16);
3291
3292 rewriter.replaceOp(op, resultVec);
3293
3294 return success();
3295 }
3296};
3297
3298// AIE2p version of MinOp conversion
3300 : public mlir::ConvertOpToLLVMPattern<aievec::MinOp> {
3301public:
3302 using ConvertOpToLLVMPattern<aievec::MinOp>::ConvertOpToLLVMPattern;
3303
  // Lower aievec.min to the AIE2p `vmin.ge` family of intrinsics.
  //
  // Each intrinsic returns a two-field LLVM struct: field 0 is the
  // element-wise minimum vector and field 1 is the comparison bitmask.
  // Only field 0 is consumed here (aievec.min has no mask result).
  LogicalResult
  matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType resultType = cast<VectorType>(op.getResult().getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    // aievec.min op has the AllTypesMatch constraint on lhs/rhs/res, so the
    // result type also describes both operands.
    if (resultVectorSize != 512 && resultVectorSize != 256) {
      op.emitWarning() << "aievec.min conversion with " << resultVectorSize
                       << "-bit result is not supported.\n";
      return failure();
    }

    // create xllvm intrinsic
    Value minOp = nullptr;
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // create constant for third operand `cmp`
      // Note: `cmp` is implicitly treated as `sign` to the vmin intrinsic
      auto cmpCst = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
      SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
      if (resultBitWidth == 8) {
        // 64 x i8 case: the mask field of the returned struct is v2i32,
        // since 64 mask bits do not fit in a single i32.
        minOp = xllvm::VectorMinGe8AIE2pIntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({2}, rewriter.getI32Type())}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({64}, rewriter.getI8Type()),
                 VectorType::get({64}, rewriter.getI8Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 16) {
        minOp = xllvm::VectorMinGe16AIE2pIntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({32}, rewriter.getI16Type()),
                 VectorType::get({32}, rewriter.getI16Type()),
                 rewriter.getI32Type()}));
      } else if (resultBitWidth == 32) {
        minOp = xllvm::VectorMinGe32AIE2pIntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(),
                {VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}),
            forceCastOperandsToSignature(
                rewriter, loc, operands,
                {VectorType::get({16}, rewriter.getI32Type()),
                 VectorType::get({16}, rewriter.getI32Type()),
                 rewriter.getI32Type()}));
      }
    } else {
      // Float path: only 16-bit (bf16) elements are handled. The intrinsic
      // operates on 32 lanes, so 16-lane inputs are padded with poison here
      // and the lower half is extracted after the call.
      if (resultBitWidth == 16) {
        auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
        Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();

        if (resultLanes == 16) {
          lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
          rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
        }

        minOp = xllvm::VectorMinGeBf16AIE2pIntrOp::create(
            rewriter, loc,
            mlir::LLVM::LLVMStructType::getLiteral(
                rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()}),
            forceCastOperandsToSignature(rewriter, loc, {lhs, rhs},
                                         {v32bf16Ty, v32bf16Ty}));
      }
    }

    // minOp stays null for any element type not matched above
    // (e.g. f32, i64).
    if (!minOp) {
      op.emitWarning() << "aievec.min conversion fails due to unsupported "
                          "element data type.\n";
      return failure();
    }

    // The minimum vector is field 0 of the returned struct.
    Value resultVec = LLVM::ExtractValueOp::create(rewriter, loc, minOp,
                                                   /*position=*/0);
    // Undo the poison padding applied to 16-lane bf16 inputs.
    if (resultLanes == 16 && !llvm::isa<IntegerType>(resultScaTy))
      resultVec = extractLowerLanes(rewriter, loc, resultVec, 16);

    rewriter.replaceOp(op, resultVec);

    return success();
  }
3401};
3402
// ----- CmpOp conversion for AIE2 -----
// Implements aievec.cmp using vmax.lt / vmin.ge intrinsics, extracting the
// comparison bitmask (field 1 of the returned struct).
//
// Predicate synthesis: the hardware only provides lt (via vmax.lt) and ge
// (via vmin.ge), so the remaining predicates are derived:
//   gt(a,b) = lt(b,a)          le(a,b) = ge(b,a)
//   eq(a,b) = ge(a,b) & ge(b,a)  ne(a,b) = lt(a,b) | lt(b,a)
//
// Template parameters select the per-target (AIE2 vs AIE2p) intrinsic ops
// for each supported element type: bf16, i32 and i16.
template <typename MaxLtBf16IntrOp, typename MinGeBf16IntrOp,
          typename MaxLt32IntrOp, typename MinGe32IntrOp,
          typename MaxLt16IntrOp, typename MinGe16IntrOp>
class CmpOpConversionBase : public mlir::ConvertOpToLLVMPattern<aievec::CmpOp> {
public:
  using ConvertOpToLLVMPattern<aievec::CmpOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::CmpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto vecTy = cast<VectorType>(op.getLhs().getType());
    auto elTy = vecTy.getElementType();
    unsigned elWidth = elTy.getIntOrFloatBitWidth();
    unsigned lanes = getVectorLaneSize(vecTy);
    auto pred = op.getPred();

    // Handle bf16 vectors (16 or 32 lanes)
    if (elWidth == 16 && isa<FloatType>(elTy)) {
      auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
      auto structTy = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()});

      // The bf16 intrinsics operate on 32 lanes; 16-lane inputs are padded
      // with poison and the extra mask bits are cleared below.
      Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
      if (lanes == 16) {
        lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
        rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
      } else if (lanes != 32) {
        return failure();
      }

      Value bitmask;
      auto castedLhs =
          forceCastOperandsToSignature(rewriter, loc, {lhs}, {v32bf16Ty});
      auto castedRhs =
          forceCastOperandsToSignature(rewriter, loc, {rhs}, {v32bf16Ty});

      if (pred == "slt" || pred == "ult") {
        auto intrOp = MaxLtBf16IntrOp::create(
            rewriter, loc, structTy, ValueRange{castedLhs[0], castedRhs[0]});
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sge" || pred == "uge") {
        auto intrOp = MinGeBf16IntrOp::create(
            rewriter, loc, structTy, ValueRange{castedLhs[0], castedRhs[0]});
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sgt" || pred == "ugt") {
        // gt(a,b) = lt(b,a): swap operands
        auto intrOp = MaxLtBf16IntrOp::create(
            rewriter, loc, structTy, ValueRange{castedRhs[0], castedLhs[0]});
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sle" || pred == "ule") {
        // le(a,b) = ge(b,a): swap operands
        auto intrOp = MinGeBf16IntrOp::create(
            rewriter, loc, structTy, ValueRange{castedRhs[0], castedLhs[0]});
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "eq") {
        // eq(a,b) = ge(a,b) AND ge(b,a)
        auto geAB = MinGeBf16IntrOp::create(
            rewriter, loc, structTy, ValueRange{castedLhs[0], castedRhs[0]});
        auto geBA = MinGeBf16IntrOp::create(
            rewriter, loc, structTy, ValueRange{castedRhs[0], castedLhs[0]});
        auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, geAB, 1);
        auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, geBA, 1);
        bitmask = LLVM::AndOp::create(rewriter, loc, maskAB, maskBA);
      } else if (pred == "ne") {
        // ne(a,b) = lt(a,b) OR lt(b,a)
        auto ltAB = MaxLtBf16IntrOp::create(
            rewriter, loc, structTy, ValueRange{castedLhs[0], castedRhs[0]});
        auto ltBA = MaxLtBf16IntrOp::create(
            rewriter, loc, structTy, ValueRange{castedRhs[0], castedLhs[0]});
        auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, ltAB, 1);
        auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, ltBA, 1);
        bitmask = LLVM::OrOp::create(rewriter, loc, maskAB, maskBA);
      } else {
        return failure();
      }

      // Mask off upper bits for 16-lane inputs
      if (lanes == 16) {
        auto mask =
            LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(0xFFFF));
        bitmask = LLVM::AndOp::create(rewriter, loc, bitmask, mask);
      }

      // Cast i32 bitmask to the op's unsigned integer result type
      rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
          op, op.getResult().getType(), bitmask);
      return success();
    }

    // Handle i32 vectors
    if (elWidth == 32 && isa<IntegerType>(elTy) && lanes == 16) {
      auto v16i32Ty = VectorType::get({16}, rewriter.getI32Type());
      auto structTy = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v16i32Ty, rewriter.getI32Type()});
      // Third intrinsic operand acts as the `sign` flag.
      // NOTE(review): it is fixed to 1 (signed) for every predicate,
      // including the unsigned ones (ult/ugt/ule/uge) — confirm this is
      // intended for the AIE2 vmax/vmin intrinsics.
      auto cmpCst = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
      Value bitmask;

      if (pred == "slt" || pred == "ult") {
        auto intrOp = MaxLt32IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
                {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sge" || pred == "uge") {
        auto intrOp = MinGe32IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
                {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sgt" || pred == "ugt") {
        // gt(a,b) = lt(b,a): swap operands
        auto intrOp = MaxLt32IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
                {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sle" || pred == "ule") {
        // le(a,b) = ge(b,a): swap operands
        auto intrOp = MinGe32IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
                {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "eq") {
        // eq(a,b) = ge(a,b) AND ge(b,a)
        auto geAB = MinGe32IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
                {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
        auto geBA = MinGe32IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
                {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
        auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, geAB, 1);
        auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, geBA, 1);
        bitmask = LLVM::AndOp::create(rewriter, loc, maskAB, maskBA);
      } else if (pred == "ne") {
        // ne(a,b) = lt(a,b) OR lt(b,a)
        auto ltAB = MaxLt32IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
                {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
        auto ltBA = MaxLt32IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
                {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
        auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, ltAB, 1);
        auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, ltBA, 1);
        bitmask = LLVM::OrOp::create(rewriter, loc, maskAB, maskBA);
      } else {
        return failure();
      }

      rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
          op, op.getResult().getType(), bitmask);
      return success();
    }

    // Handle i16 vectors
    if (elWidth == 16 && isa<IntegerType>(elTy) && lanes == 32) {
      auto v32i16Ty = VectorType::get({32}, rewriter.getI16Type());
      auto structTy = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v32i16Ty, rewriter.getI32Type()});
      // Same `sign` flag caveat as the i32 path above.
      auto cmpCst = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
      Value bitmask;

      if (pred == "slt" || pred == "ult") {
        auto intrOp = MaxLt16IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
                {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sge" || pred == "uge") {
        auto intrOp = MinGe16IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
                {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sgt" || pred == "ugt") {
        // gt(a,b) = lt(b,a): swap operands
        auto intrOp = MaxLt16IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
                {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "sle" || pred == "ule") {
        // le(a,b) = ge(b,a): swap operands
        auto intrOp = MinGe16IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
                {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
        bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
      } else if (pred == "eq") {
        // eq(a,b) = ge(a,b) AND ge(b,a)
        auto geAB = MinGe16IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
                {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
        auto geBA = MinGe16IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
                {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
        auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, geAB, 1);
        auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, geBA, 1);
        bitmask = LLVM::AndOp::create(rewriter, loc, maskAB, maskBA);
      } else if (pred == "ne") {
        // ne(a,b) = lt(a,b) OR lt(b,a)
        auto ltAB = MaxLt16IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
                {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
        auto ltBA = MaxLt16IntrOp::create(
            rewriter, loc, structTy,
            forceCastOperandsToSignature(
                rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
                {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
        auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, ltAB, 1);
        auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, ltBA, 1);
        bitmask = LLVM::OrOp::create(rewriter, loc, maskAB, maskBA);
      } else {
        return failure();
      }

      rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
          op, op.getResult().getType(), bitmask);
      return success();
    }

    // Unsupported element type / lane-count combination.
    return failure();
  }
};
3648
3650 CmpOpConversionBase<xllvm::VectorMaxLtBf16IntrOp,
3651 xllvm::VectorMinGeBf16IntrOp,
3652 xllvm::VectorMaxLt32IntrOp, xllvm::VectorMinGe32IntrOp,
3653 xllvm::VectorMaxLt16IntrOp, xllvm::VectorMinGe16IntrOp>;
3654
3656 xllvm::VectorMaxLtBf16AIE2pIntrOp, xllvm::VectorMinGeBf16AIE2pIntrOp,
3657 xllvm::VectorMaxLt32AIE2pIntrOp, xllvm::VectorMinGe32AIE2pIntrOp,
3658 xllvm::VectorMaxLt16AIE2pIntrOp, xllvm::VectorMinGe16AIE2pIntrOp>;
3659
// ----- SelOp conversion for AIE2/AIE2p -----
// Implements aievec.sel using vsel16/vsel32 intrinsics.
// For bf16 vectors, bitcasts to i16, calls vsel16, bitcasts back.
//
// Template parameters select the per-target 16-bit and 32-bit vsel
// intrinsic ops. The `sel` operand is a per-lane bitmask carried as an
// integer; it is normalized to i32 before being passed to the intrinsic.
template <typename Sel16IntrOp, typename Sel32IntrOp>
class SelOpConversionBase : public mlir::ConvertOpToLLVMPattern<aievec::SelOp> {
public:
  using ConvertOpToLLVMPattern<aievec::SelOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::SelOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resultType = cast<VectorType>(op.getResult().getType());
    auto elTy = resultType.getElementType();
    unsigned elWidth = elTy.getIntOrFloatBitWidth();
    unsigned lanes = getVectorLaneSize(resultType);
    auto i32Ty = rewriter.getI32Type();

    // Cast the sel bitmask from unsigned integer to i32
    Value selMask = adaptor.getSel();
    if (selMask.getType() != i32Ty)
      selMask =
          UnrealizedConversionCastOp::create(rewriter, loc, i32Ty, selMask)
              .getResult(0);

    // Handle bf16 vectors (16 or 32 lanes) via bitcast to i16
    if (elWidth == 16 && isa<FloatType>(elTy)) {
      auto v32i16Ty = VectorType::get({32}, rewriter.getI16Type());
      auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());

      Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
      bool needExtract = false;

      // vsel16 works on 32 lanes; pad 16-lane inputs with poison and
      // remember to extract the lower half afterwards.
      if (lanes == 16) {
        lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
        rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
        needExtract = true;
      } else if (lanes != 32) {
        return failure();
      }

      // Bitcast bf16 -> i16 (vsel is bit-pattern selection, so the
      // reinterpretation is lossless)
      auto lhsCast = forceCastValueToType(rewriter, loc, lhs, v32bf16Ty);
      auto lhsI16 = LLVM::BitcastOp::create(rewriter, loc, v32i16Ty, lhsCast);
      auto rhsCast = forceCastValueToType(rewriter, loc, rhs, v32bf16Ty);
      auto rhsI16 = LLVM::BitcastOp::create(rewriter, loc, v32i16Ty, rhsCast);

      auto selResult = Sel16IntrOp::create(
          rewriter, loc, v32i16Ty,
          forceCastOperandsToSignature(rewriter, loc, {lhsI16, rhsI16, selMask},
                                       {v32i16Ty, v32i16Ty, i32Ty}));

      // Bitcast i16 -> bf16
      Value result =
          LLVM::BitcastOp::create(rewriter, loc, v32bf16Ty, selResult);

      if (needExtract)
        result = extractLowerLanes(rewriter, loc, result, 16);

      rewriter.replaceOp(op, result);
      return success();
    }

    // Handle i32 vectors (16 lanes)
    if (elWidth == 32 && isa<IntegerType>(elTy) && lanes == 16) {
      auto v16i32Ty = VectorType::get({16}, rewriter.getI32Type());
      auto selResult = Sel32IntrOp::create(
          rewriter, loc, v16i32Ty,
          forceCastOperandsToSignature(
              rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), selMask},
              {v16i32Ty, v16i32Ty, i32Ty}));
      rewriter.replaceOp(op, selResult->getResult(0));
      return success();
    }

    // Handle i16 vectors (32 lanes)
    if (elWidth == 16 && isa<IntegerType>(elTy) && lanes == 32) {
      auto v32i16Ty = VectorType::get({32}, rewriter.getI16Type());
      auto selResult = Sel16IntrOp::create(
          rewriter, loc, v32i16Ty,
          forceCastOperandsToSignature(
              rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), selMask},
              {v32i16Ty, v32i16Ty, i32Ty}));
      rewriter.replaceOp(op, selResult->getResult(0));
      return success();
    }

    // Unsupported element type / lane-count combination.
    return failure();
  }
};
3750
3753
3754using SelOpAIE2pConversion = SelOpConversionBase<xllvm::VectorSel16AIE2pIntrOp,
3755 xllvm::VectorSel32AIE2pIntrOp>;
3756
3758 : public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
3759public:
3760 using ConvertOpToLLVMPattern<
3761 aievec::BroadcastScalarOp>::ConvertOpToLLVMPattern;
3762
  // Lower aievec.broadcast_scalar to the AIE2 vbroadcast intrinsics.
  // Only 512-bit result vectors are handled; integer sources narrower than
  // 32 bits are sign-extended first because the integer vbroadcast
  // intrinsics take an i32 scalar operand.
  LogicalResult
  matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Value result = op.getResult();
    VectorType resultType = cast<VectorType>(result.getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    if (resultVectorSize != 512) {
      op.emitWarning()
          << "aievec.broadcast_scalar conversion with result vector size "
          << resultVectorSize << " is not implemented.\n";
      return failure();
    }

    // Integer types
    if (llvm::isa<IntegerType>(resultScaTy)) {
      Value src = adaptor.getSource();
      Type srcType = src.getType();
      unsigned srcBitWidth = srcType.getIntOrFloatBitWidth();

      // Intrinsic takes an i32 scalar; widen narrower sources.
      if (srcBitWidth < 32) {
        src = LLVM::SExtOp::create(rewriter, loc, rewriter.getI32Type(), src);
      }

      if (resultBitWidth == 8) {
        rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast8I512IntrOp>(
            op, VectorType::get({64}, rewriter.getI8Type()), src);
      } else if (resultBitWidth == 16) {
        rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast16I512IntrOp>(
            op, VectorType::get({32}, rewriter.getI16Type()), src);
      } else if (resultBitWidth == 32) {
        rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast32I512IntrOp>(
            op, VectorType::get({16}, rewriter.getI32Type()), src);
      } else {
        op.emitWarning()
            << "aievec.broadcast_scalar conversion with result bitwidth "
            << resultBitWidth << " is not implemented.\n";
        return failure();
      }
    } else {
      // Float types
      if (resultBitWidth == 16) {
        rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast16BF512IntrOp>(
            op, VectorType::get({32}, rewriter.getBF16Type()),
            adaptor.getSource());
      } else if (resultBitWidth == 32) {
        // Use llvm.aie2.vbroadcast32.I512 with bitcasts (float -> i32 -> float)
        // Following the pattern: %0 = bitcast float %b to i32
        //                        %1 = tail call <16 x i32>
        //                        @llvm.aie2.vbroadcast32.I512(i32 %0) %2 =
        //                        bitcast <16 x i32> %1 to <16 x float>
        auto srcAsI32 = bitcastValueToType(rewriter, loc, adaptor.getSource(),
                                           rewriter.getI32Type());
        auto broadcastI32 = xllvm::VectorBroadcast32I512IntrOp::create(
            rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
            srcAsI32);
        auto resultF32 =
            bitcastValueToType(rewriter, loc, broadcastI32,
                               VectorType::get({16}, rewriter.getF32Type()));
        rewriter.replaceOp(op, resultF32);
      } else {
        op.emitWarning()
            << "aievec.broadcast_scalar conversion with result bitwidth "
            << resultBitWidth << " is not implemented.\n";
        return failure();
      }
    }

    return success();
  }
3838};
3839
3840// AIE2p version of BroadcastScalarOp conversion using insertelement +
3841// shufflevector
3843 : public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
3844public:
3845 using ConvertOpToLLVMPattern<
3846 aievec::BroadcastScalarOp>::ConvertOpToLLVMPattern;
3847
3848 LogicalResult
3849 matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor,
3850 ConversionPatternRewriter &rewriter) const override {
3851 Location loc = op.getLoc();
3852
3853 Value result = op.getResult();
3854 VectorType resultType = cast<VectorType>(result.getType());
3855 Type resultScaTy = resultType.getElementType();
3856 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
3857 int resultLanes = getVectorLaneSize(resultType);
3858 int resultVectorSize = resultBitWidth * resultLanes;
3859
3860 // Support both 256-bit and 512-bit vectors for AIE2p
3861 if (resultVectorSize != 256 && resultVectorSize != 512) {
3862 op.emitWarning()
3863 << "aievec.broadcast_scalar conversion with result vector size "
3864 << resultVectorSize << " is not implemented for AIE2p.\n";
3865 return failure();
3866 }
3867
3868 Value src = adaptor.getSource();
3869 Type srcType = src.getType();
3870
3871 // For integer types, extend or truncate to match result element type
3872 if (llvm::isa<IntegerType>(resultScaTy)) {
3873 unsigned srcBitWidth = srcType.getIntOrFloatBitWidth();
3874 if (srcBitWidth < resultBitWidth) {
3875 src = LLVM::SExtOp::create(rewriter, loc, resultScaTy, src);
3876 } else if (srcBitWidth > resultBitWidth) {
3877 src = LLVM::TruncOp::create(rewriter, loc, resultScaTy, src);
3878 }
3879 }
3880
3881 // Create poison vector of the result type
3882 auto poisonVec = LLVM::PoisonOp::create(rewriter, loc, resultType);
3883
3884 // Insert scalar at position 0
3885 auto idx0 = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
3886 rewriter.getI64IntegerAttr(0));
3887 auto insertedVec = LLVM::InsertElementOp::create(rewriter, loc, resultType,
3888 poisonVec, src, idx0);
3889
3890 // Create shufflevector mask with all zeros (broadcast position 0 to all
3891 // lanes)
3892 SmallVector<int64_t> broadcastMask(resultLanes, 0);
3893 auto broadcastVec = vector::ShuffleOp::create(rewriter, loc, insertedVec,
3894 insertedVec, broadcastMask);
3895
3896 rewriter.replaceOp(op, broadcastVec);
3897 return success();
3898 }
3899};
3900
// Lowers aievec.shift (512-bit vectors only) to the AIE2 vshift intrinsics:
// the result is lhs:rhs concatenated and shifted by `shift` with step 0.
class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
public:
  using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Value result = op.getResult();
    VectorType resultType = cast<VectorType>(result.getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    if (resultVectorSize != 512) {
      op.emitWarning() << "aievec.shift conversion with result vector size "
                       << resultVectorSize << " is not implemented.\n";
      return failure();
    }

    // assume step is always zero
    auto stepCst = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));

    // create xllvm intrinsic
    Value shiftOp = nullptr;
    SmallVector<Value> operands(
        {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()});
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // Integer types: all widths go through the v16i32 form; element
      // reinterpretation is handled by the operand/result casts.
      shiftOp = xllvm::VectorShiftI512I512IntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({16}, rewriter.getI32Type()),
               VectorType::get({16}, rewriter.getI32Type()),
               rewriter.getI32Type(), rewriter.getI32Type()}));
    } else {
      // Float types: the bf16 form of the intrinsic.
      shiftOp = xllvm::VectorShiftBF512BF512IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({32}, rewriter.getBF16Type()),
               VectorType::get({32}, rewriter.getBF16Type()),
               rewriter.getI32Type(), rewriter.getI32Type()}));
    }

    // create bitcast/shape_cast for result
    auto resultVal =
        forceCastValueToType(rewriter, loc, shiftOp, op.getResult().getType());
    rewriter.replaceOp(op, resultVal);

    return success();
  }
};
3959
3960// AIE2p version of ShiftOp conversion
3962 : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
3963public:
3964 using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;
3965
  // Lower aievec.shift (512-bit vectors only) to the AIE2p vshift
  // intrinsics. Mirrors ShiftOpConversion but uses the AIE2p intrinsic ops.
  LogicalResult
  matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Value result = op.getResult();
    VectorType resultType = cast<VectorType>(result.getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    if (resultVectorSize != 512) {
      op.emitWarning() << "aievec.shift conversion with result vector size "
                       << resultVectorSize << " is not implemented.\n";
      return failure();
    }

    // assume step is always zero
    auto stepCst = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));

    // create xllvm intrinsic
    Value shiftOp = nullptr;
    SmallVector<Value> operands(
        {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()});
    if (llvm::isa<IntegerType>(resultScaTy)) {
      // Integer types - use AIE2p intrinsic
      shiftOp = xllvm::VectorShiftI512I512AIE2pIntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({16}, rewriter.getI32Type()),
               VectorType::get({16}, rewriter.getI32Type()),
               rewriter.getI32Type(), rewriter.getI32Type()}));
    } else {
      // Float types - use AIE2p intrinsic
      shiftOp = xllvm::VectorShiftBF512BF512AIE2pIntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({32}, rewriter.getBF16Type()),
               VectorType::get({32}, rewriter.getBF16Type()),
               rewriter.getI32Type(), rewriter.getI32Type()}));
    }

    // create bitcast/shape_cast for result
    auto resultVal =
        forceCastValueToType(rewriter, loc, shiftOp, op.getResult().getType());
    rewriter.replaceOp(op, resultVal);

    return success();
  }
4019};
4020
4022 : public mlir::ConvertOpToLLVMPattern<aievec::ExtElemOp> {
4023public:
4024 using ConvertOpToLLVMPattern<aievec::ExtElemOp>::ConvertOpToLLVMPattern;
4025
  // Lower aievec.ext_elem to the AIE2 vextract intrinsics. The intrinsics
  // always return i32, so the value is truncated/bitcast afterwards to the
  // op's scalar result type. Only 512-bit source vectors are supported.
  LogicalResult
  matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getResult().getType();
    unsigned resultBitWidth = resultType.getIntOrFloatBitWidth();

    Value src = adaptor.getSource();
    VectorType srcType = cast<VectorType>(src.getType());
    Type srcScalarType = srcType.getElementType();
    unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
    int srcLanes = getVectorLaneSize(srcType);
    int srcVectorSize = srcBitWidth * srcLanes;

    if (srcVectorSize != 512) {
      op.emitWarning() << "aievec.ext_elem conversion with source vector size "
                       << srcVectorSize << " is not supported.\n";
      return failure();
    }

    // create constant for sign
    // (fixed to 1 — the extracted element is treated as signed)
    auto signCst = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));

    // create xllvm intrinsic
    // Dispatch on the RESULT bit width; the source vector is recast to the
    // matching lane layout by forceCastOperandsToSignature.
    Value extElemOp = nullptr;
    SmallVector<Value> operands(
        {adaptor.getSource(), adaptor.getIndex(), signCst});
    if (resultBitWidth == 8) {
      extElemOp = xllvm::VectorExtractElem8I512IntrOp::create(
          rewriter, loc, rewriter.getI32Type(),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({64}, rewriter.getI8Type()),
               rewriter.getI32Type(), rewriter.getI32Type()}));
    } else if (resultBitWidth == 16) {
      extElemOp = xllvm::VectorExtractElem16I512IntrOp::create(
          rewriter, loc, rewriter.getI32Type(),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({32}, rewriter.getI16Type()),
               rewriter.getI32Type(), rewriter.getI32Type()}));
    } else if (resultBitWidth == 32) {
      extElemOp = xllvm::VectorExtractElem32I512IntrOp::create(
          rewriter, loc, rewriter.getI32Type(),
          forceCastOperandsToSignature(
              rewriter, loc, operands,
              {VectorType::get({16}, rewriter.getI32Type()),
               rewriter.getI32Type(), rewriter.getI32Type()}));
    } else {
      op.emitWarning() << "aievec.ext_elem conversion with result bit width "
                       << resultBitWidth << " is not implemented.\n";
      return failure();
    }

    // create truncation op (and bitcast op)
    if (llvm::isa<IntegerType>(resultType)) {
      if (resultBitWidth < 32) {
        // Two-step truncation to avoid direct i32→i8 which the AIE2
        // backend cannot legalize after SLP vectorization.
        if (resultBitWidth < 16) {
          auto i16Ty = rewriter.getI16Type();
          auto trunc16 = LLVM::TruncOp::create(rewriter, loc, i16Ty, extElemOp);
          rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType,
                                                     trunc16.getResult());
        } else {
          rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType, extElemOp);
        }
      } else {
        rewriter.replaceOp(op, extElemOp);
      }
    } else {
      // Float types
      // bf16 result: narrow the i32 to i16 before reinterpreting the bits.
      if (resultBitWidth == 16) {
        extElemOp = LLVM::TruncOp::create(rewriter, loc, rewriter.getI16Type(),
                                          extElemOp);
      }
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, resultType, extElemOp);
    }

    return success();
  }
4109};
4110
4112 : public mlir::ConvertOpToLLVMPattern<aievec::FMAElemOp> {
4113public:
4114 using ConvertOpToLLVMPattern<aievec::FMAElemOp>::ConvertOpToLLVMPattern;
4115
  // Lower `aievec.fma_elem` (element-wise multiply-accumulate) to the AIE2
  // bf16 `mac.conf` intrinsic. Operands are flattened to 1-D vectors,
  // zero-padded to 32 lanes when narrower, and bit-cast to the intrinsic's
  // expected signature.
  LogicalResult
  matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = fmaOp.getLoc();
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();
    auto acc = adaptor.getAcc();
    auto lhsTy = cast<VectorType>(lhs.getType());
    auto rhsTy = cast<VectorType>(rhs.getType());
    auto accTy = cast<VectorType>(acc.getType());
    auto flatLhsTy = getFlattenedVectorType(lhsTy);
    auto flatRhsTy = getFlattenedVectorType(rhsTy);
    auto flatAccTy = getFlattenedVectorType(accTy);

    // Flatten operands, if needed
    if (lhsTy != flatLhsTy)
      lhs = vector::ShapeCastOp::create(rewriter, loc, flatLhsTy, lhs);
    if (rhsTy != flatRhsTy)
      rhs = vector::ShapeCastOp::create(rewriter, loc, flatRhsTy, rhs);
    if (accTy != flatAccTy)
      acc = vector::ShapeCastOp::create(rewriter, loc, flatAccTy, acc);

    // Build vmac configuration constant. The field values here select the
    // bf16 element-wise variant; see aiev2_vmac_compute_control for the
    // encoding of each field.
    Type i32ty = rewriter.getI32Type();
    auto confCst = LLVM::ConstantOp::create(
        rewriter, loc, i32ty,
        rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
            /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
            /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
            /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
            /*sub_mask=*/0)));

    // Pad 16-lane bf16 operands to 32-lane using set+upd intrinsics.
    // forceCastOperandsToSignature only does bitwise reinterpretation, which
    // leaves garbage in the upper lanes. The MAC intrinsic requires properly
    // zero-padded v32bf16 inputs.
    // NOTE(review): only the lhs type is checked before padding BOTH lhs and
    // rhs — this assumes lhs and rhs always have the same lane count; confirm
    // against the aievec.fma_elem op definition.
    auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
    if (flatLhsTy.getElementType().isBF16() &&
        flatLhsTy.getNumElements() < 32) {
      // Materialize a 16-lane bf16 zero vector: broadcast i32 0 across 512
      // bits, reinterpret as bf16, then extract the low 256-bit half.
      auto zero32 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
                                             rewriter.getI32IntegerAttr(0));
      auto zeros_i16 = xllvm::VectorBroadcast16I512IntrOp::create(
          rewriter, loc, VectorType::get({32}, rewriter.getI16Type()), zero32);
      auto zeros_bf16 =
          LLVM::BitcastOp::create(rewriter, loc, v32bf16Ty, zeros_i16);
      auto zeroVec = xllvm::ExtBF256BF512IntrOp::create(
          rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
          zeros_bf16, zero32);

      auto idx1 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
                                           rewriter.getI32IntegerAttr(1));

      // set: place the 256-bit operand in the low half of a 512-bit vector;
      // upd at index 1: overwrite the high half with zeros.
      auto lhsSet = xllvm::VectorSetBF512BF256IntrOp::create(
          rewriter, loc, v32bf16Ty, lhs, zero32);
      lhs = xllvm::UpdBF512BF256IntrOp::create(rewriter, loc, v32bf16Ty, lhsSet,
                                               zeroVec, idx1);

      auto rhsSet = xllvm::VectorSetBF512BF256IntrOp::create(
          rewriter, loc, v32bf16Ty, rhs, zero32);
      rhs = xllvm::UpdBF512BF256IntrOp::create(rewriter, loc, v32bf16Ty, rhsSet,
                                               zeroVec, idx1);
    }

    // Insert vmac intrinsic. The accumulator travels as <8 x i64> (a 512-bit
    // accumulator register reinterpretation).
    auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type());
    auto macIntrOp = xllvm::MacConfBF16IntrOp::create(
        rewriter, loc, v8i64Ty,
        forceCastOperandsToSignature(rewriter, loc, {lhs, rhs, acc, confCst},
                                     {v32bf16Ty, v32bf16Ty, v8i64Ty, i32ty}));

    // Recast/Reshape result back to the original accumulator type.
    auto resVal =
        forceCastValueToType(rewriter, loc, macIntrOp.getResult(), flatAccTy);
    if (flatAccTy != accTy)
      resVal = vector::ShapeCastOp::create(rewriter, loc, accTy, resVal);

    rewriter.replaceOp(fmaOp, resVal);
    return success();
  }
4195};
4196
4198 : public mlir::ConvertOpToLLVMPattern<aievec::MatMulOp> {
4199 using ConvertOpToLLVMPattern<aievec::MatMulOp>::ConvertOpToLLVMPattern;
4200
4201 struct DecodedMatMulOp {
4202 typedef enum { I32, I64, BF16 } Kind;
4203
4204 Kind kind;
4205 Value lhs;
4206 Value rhs;
4207 Value acc;
4208 int conf;
4209 };
4210
  // Classify an aievec.matmul into one of the supported AIE2 MAC intrinsic
  // shapes (BF16 / 32-bit acc / 64-bit acc), recover operand signedness by
  // peeling arith.extsi/extui feeding the operands, and precompute the vmac
  // configuration word. The shape comments on each return describe the
  // (lhs x rhs + acc) combination that branch handles.
  static DecodedMatMulOp decodeMatMulOp(OpAdaptor op) {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value acc = op.getAcc();
    auto accVecTy = cast<VectorType>(acc.getType());
    if (isa<Float32Type>(accVecTy.getElementType()))
      // <4x8xbf16> x <8x4xbf16> + <4x4xf32>
      return {DecodedMatMulOp::Kind::BF16, lhs, rhs, acc,
              aiev2_vmac_compute_control(
                  /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
                  /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
                  /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                  /*sub_mask=*/0)};

    // Helper: look through vector.shape_cast ops to find the defining op.
    // The VectorToAIEVec pass inserts shape_casts (via reshapeLeadingUnitDims)
    // between the extension ops and the matmul, which hides the signedness.
    auto lookThroughShapeCasts = [](Value v) -> Value {
      while (auto castOp = v.getDefiningOp<vector::ShapeCastOp>())
        v = castOp.getSource();
      return v;
    };

    int signX = 0, signY = 0;
    auto lhsVecTy = cast<VectorType>(lhs.getType());
    auto lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
    Value lhsOrig = lookThroughShapeCasts(lhs);
    if (auto extSIOp = lhsOrig.getDefiningOp<arith::ExtSIOp>()) {
      // Signed extension feeds lhs: absorb it and use the narrow source.
      lhs = lookThroughShapeCasts(extSIOp.getIn());
      lhsVecTy = cast<VectorType>(lhs.getType());
      lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
      signX = 1;
    } else if (auto extUIOp = lhsOrig.getDefiningOp<arith::ExtUIOp>()) {
      // Unsigned extension feeds lhs: absorb it, signX stays 0.
      lhs = lookThroughShapeCasts(extUIOp.getIn());
      lhsVecTy = cast<VectorType>(lhs.getType());
      lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
    } else {
      // Default to unsigned for lhs (activation input is typically uint8).
      // The VectorToAIEVec pass strips extsi/extui before creating
      // aievec.matmul, so sign info is not available here. Using unsigned
      // for A matches the common use case of uint8 activations × int8 weights.
      // NOTE(review): signX is already 0 here, so this branch is a no-op;
      // it is kept only for symmetry with the rhs handling below.
      if (lhsScaTy.isUnsigned())
        signX = 0;
    }
    auto lhsShape = lhsVecTy.getShape();

    auto rhsVecTy = cast<VectorType>(rhs.getType());
    auto rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
    Value rhsOrig = lookThroughShapeCasts(rhs);
    if (auto extSIOp = rhsOrig.getDefiningOp<arith::ExtSIOp>()) {
      rhs = lookThroughShapeCasts(extSIOp.getIn());
      rhsVecTy = cast<VectorType>(rhs.getType());
      rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
      signY = 1;
    } else if (auto extUIOp = rhsOrig.getDefiningOp<arith::ExtUIOp>()) {
      rhs = lookThroughShapeCasts(extUIOp.getIn());
      rhsVecTy = cast<VectorType>(rhs.getType());
      rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
    } else {
      // NOTE: We're choosing 'signed' by default
      if (!rhsScaTy.isUnsigned())
        signY = 1;
    }

    // Dispatch on accumulator width first (32- vs 64-bit acc intrinsics),
    // then on operand widths/shapes to pick amode/bmode/variant.
    unsigned lhsBitWidth = lhsScaTy.getWidth();
    unsigned rhsBitWidth = rhsScaTy.getWidth();
    auto accScaTy = cast<IntegerType>(accVecTy.getElementType());
    unsigned accBitWidth = accScaTy.getWidth();
    if (accBitWidth == 32) {
      if (lhsBitWidth == 8) {
        if (rhsBitWidth == 4) {
          // <4x16xi8> x <16x8xi4> + <4x8xi32>
          return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
                  aiev2_vmac_compute_control(
                      /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
                      /*bmode=*/0,
                      /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
                      /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                      /*sub_mask=*/0)};
        } else {
          // <4x8xi8> x <8x8xi8> + <4x8xi32>
          return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
                  aiev2_vmac_compute_control(
                      /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
                      /*bmode=*/1,
                      /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
                      /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                      /*sub_mask=*/0)};
        }
      } else {
        if (rhsBitWidth == 8) {
          // <4x4xi16> x <4x8xi8> + <4x8xi32>
          return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
                  aiev2_vmac_compute_control(
                      /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
                      /*bmode=*/2,
                      /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
                      /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                      /*sub_mask=*/0)};
        } else {
          // <4x2xi16> x <2x8xi16> + <4x8xi32>
          return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
                  aiev2_vmac_compute_control(
                      /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0,
                      /*bmode=*/3,
                      /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
                      /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                      /*sub_mask=*/0)};
        }
      }
    }

    // 64-bit accumulator cases below.
    if (lhsBitWidth == 16) {
      if (rhsBitWidth == 8) {
        if (lhsShape == ArrayRef<int64_t>({2, 8})) {
          // <2x8xi16> x <8x8xi8> + <2x8xi64>
          return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
                  aiev2_vmac_compute_control(
                      /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1,
                      /*bmode=*/2,
                      /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
                      /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                      /*sub_mask=*/0)};
        }
        // <4x8xi16> x <8x4xi8> + <4x4xi64>
        return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
                aiev2_vmac_compute_control(
                    /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/2,
                    /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
                    /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                    /*sub_mask=*/0)};
      }
      if (lhsShape == ArrayRef<int64_t>({2, 4})) {
        // <2x4xi16> x <4x8xi16> + <2x8xi64>
        return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
                aiev2_vmac_compute_control(
                    /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/3,
                    /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
                    /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                    /*sub_mask=*/0)};
      }
      // <4x4xi16> x <4x4xi16> + <4x4xi64>
      return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
              aiev2_vmac_compute_control(
                  /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/3,
                  /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
                  /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                  /*sub_mask=*/0)};
    }
    // Fallback: <4x2xi32> x <2x4xi16> + <4x4xi64>
    return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
            aiev2_vmac_compute_control(
                /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/0,
                /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0,
                /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
                /*sub_mask=*/0)};
  }
4368
4369 LogicalResult
4370 matchAndRewrite(aievec::MatMulOp op, OpAdaptor adaptor,
4371 ConversionPatternRewriter &rewriter) const override {
4372 auto decodedMatMulOp = decodeMatMulOp(adaptor);
4373
4374 Location loc = op.getLoc();
4375 // Flatten the inputs
4376 auto lhsFlattenedVecTy =
4377 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.lhs.getType()));
4378 decodedMatMulOp.lhs = vector::ShapeCastOp::create(
4379 rewriter, loc, lhsFlattenedVecTy, decodedMatMulOp.lhs);
4380 auto rhsFlattenedVecTy =
4381 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.rhs.getType()));
4382 decodedMatMulOp.rhs = vector::ShapeCastOp::create(
4383 rewriter, loc, rhsFlattenedVecTy, decodedMatMulOp.rhs);
4384 auto accFlattenedVecTy =
4385 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.acc.getType()));
4386 decodedMatMulOp.acc = vector::ShapeCastOp::create(
4387 rewriter, loc, accFlattenedVecTy, decodedMatMulOp.acc);
4388
4389 Type i32ty = rewriter.getI32Type();
4390 auto confCst = LLVM::ConstantOp::create(
4391 rewriter, loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf));
4392 SmallVector<Value> operands({decodedMatMulOp.lhs, decodedMatMulOp.rhs,
4393 decodedMatMulOp.acc, confCst});
4394 Value matMulResVal;
4395 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::BF16)
4396 matMulResVal =
4397 xllvm::MacConfBF16IntrOp::create(
4398 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
4399 forceCastOperandsToSignature(
4400 rewriter, loc, operands,
4401 {VectorType::get({32}, rewriter.getBF16Type()),
4402 VectorType::get({32}, rewriter.getBF16Type()),
4403 VectorType::get({8}, rewriter.getI64Type()), i32ty}))
4404 .getResult();
4405 else {
4406 SmallVector<Type> intrFuncSig(
4407 {VectorType::get({64}, rewriter.getI8Type()),
4408 VectorType::get({16}, i32ty),
4409 VectorType::get({16}, rewriter.getI64Type()), i32ty});
4410 VectorType v16xi64ty = VectorType::get({16}, rewriter.getI64Type());
4411 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I32)
4412 matMulResVal = xllvm::MacConfAcc32IntrOp::create(
4413 rewriter, loc, v16xi64ty,
4414 forceCastOperandsToSignature(rewriter, loc, operands,
4415 intrFuncSig))
4416 .getResult();
4417 else
4418 matMulResVal = xllvm::MacConfAcc64IntrOp::create(
4419 rewriter, loc, v16xi64ty,
4420 forceCastOperandsToSignature(rewriter, loc, operands,
4421 intrFuncSig))
4422 .getResult();
4423 }
4424
4425 auto castFromAcc =
4426 bitcastValueToType(rewriter, loc, matMulResVal, accFlattenedVecTy);
4427
4428 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
4429 castFromAcc);
4430
4431 return success();
4432 }
4433};
4434
4435// Helper function to transpose RHS in bf16 format and convert to accfloat
4436// Input: vXxbf16 (X must be 64), Output: v64accfloat
4437static Value transposeAndConvertRHS(OpBuilder &rewriter, Location loc,
4438 Type i32ty, Value rhs64bf16) {
4439 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
4440
4441 // Transpose RHS 8x8 matrix in bf16 format (more efficient)
4442 // Cast v64bf16 to v32i32 for transpose operations
4443 auto rhs64i32 = forceCastValueToType(
4444 rewriter, loc, rhs64bf16, VectorType::get({32}, rewriter.getI32Type()));
4445
4446 // Extract two <16 x i32> chunks
4447 SmallVector<int64_t> chunk0Mask, chunk1Mask;
4448 for (int i = 0; i < 16; ++i) {
4449 chunk0Mask.push_back(i);
4450 chunk1Mask.push_back(16 + i);
4451 }
4452 auto rhs16i32_0 =
4453 vector::ShuffleOp::create(rewriter, loc, rhs64i32, rhs64i32, chunk0Mask);
4454 auto rhs16i32_1 =
4455 vector::ShuffleOp::create(rewriter, loc, rhs64i32, rhs64i32, chunk1Mask);
4456
4457 // Apply vshuffle with modes 52 and 53
4458 auto shuffleMode52 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4459 rewriter.getI32IntegerAttr(52));
4460 auto shuffleMode53 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4461 rewriter.getI32IntegerAttr(53));
4462
4463 auto shuffled52 = xllvm::VectorShuffleAIE2pIntrOp::create(
4464 rewriter, loc, VectorType::get({16}, i32ty), rhs16i32_0, rhs16i32_1,
4465 shuffleMode52);
4466 auto shuffled53 = xllvm::VectorShuffleAIE2pIntrOp::create(
4467 rewriter, loc, VectorType::get({16}, i32ty), rhs16i32_0, rhs16i32_1,
4468 shuffleMode53);
4469
4470 // Concatenate to get transposed v32i32
4471 SmallVector<int64_t> transposeConcatMask;
4472 for (int i = 0; i < 32; ++i)
4473 transposeConcatMask.push_back(i);
4474 auto rhsTransposedI32 = vector::ShuffleOp::create(
4475 rewriter, loc, shuffled52, shuffled53, transposeConcatMask);
4476 auto rhsTransposedBF16 =
4477 forceCastValueToType(rewriter, loc, rhsTransposedI32,
4478 VectorType::get({64}, rewriter.getBF16Type()));
4479
4480 // Convert transposed RHS v64bfloat16 to v64accfloat (in two v32 chunks)
4481 SmallVector<int64_t> firstHalfMask, secondHalfMask;
4482 for (int i = 0; i < 32; ++i) {
4483 firstHalfMask.push_back(i);
4484 secondHalfMask.push_back(32 + i);
4485 }
4486
4487 auto rhsT32bf16_lo = vector::ShuffleOp::create(
4488 rewriter, loc, rhsTransposedBF16, rhsTransposedBF16, firstHalfMask);
4489 auto rhsT32bf16_hi = vector::ShuffleOp::create(
4490 rewriter, loc, rhsTransposedBF16, rhsTransposedBF16, secondHalfMask);
4491
4492 auto rhsT32f32_lo = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4493 rewriter, loc, v32f32Ty, rhsT32bf16_lo);
4494 auto rhsT32f32_hi = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4495 rewriter, loc, v32f32Ty, rhsT32bf16_hi);
4496
4497 // Concat to v64accfloat
4498 SmallVector<int64_t> concatMask;
4499 for (int i = 0; i < 64; ++i)
4500 concatMask.push_back(i);
4501 return vector::ShuffleOp::create(rewriter, loc, rhsT32f32_lo, rhsT32f32_hi,
4502 concatMask);
4503}
4504
4505// Helper function to perform BFP16-based 8×8 matmul via mac_8x8_8x8T_conf
4506// LHS: v64accfloat, RHS: v64accfloat (transposed 8×8), ACC: v64i32
4507// Returns: v64i32 result
4508static Value performBFP16_8x8MatMul(OpBuilder &rewriter, Location loc,
4509 Type i32ty, Value lhs64f32,
4510 Value rhs64f32Transposed, Value acc64i32,
4511 Value confCst) {
4512 auto v64i32Ty = VectorType::get({64}, rewriter.getI32Type());
4513
4514 // Convert both to BFP16 format
4515 auto bfpStructTy = mlir::LLVM::LLVMStructType::getLiteral(
4516 rewriter.getContext(), {VectorType::get({64}, rewriter.getI8Type()),
4517 VectorType::get({8}, rewriter.getI8Type())});
4518
4519 auto lhsBFP = xllvm::Vector64AccFloatToV64BFP16EBS8AIE2pIntrOp::create(
4520 rewriter, loc, bfpStructTy, lhs64f32);
4521 auto rhsBFP = xllvm::Vector64AccFloatToV64BFP16EBS8AIE2pIntrOp::create(
4522 rewriter, loc, bfpStructTy, rhs64f32Transposed);
4523
4524 // Extract mantissa and exponent
4525 auto lhsData = LLVM::ExtractValueOp::create(rewriter, loc, lhsBFP, 0);
4526 auto lhsExp = LLVM::ExtractValueOp::create(rewriter, loc, lhsBFP, 1);
4527 auto rhsData = LLVM::ExtractValueOp::create(rewriter, loc, rhsBFP, 0);
4528 auto rhsExp = LLVM::ExtractValueOp::create(rewriter, loc, rhsBFP, 1);
4529
4530 // Perform BFP16 matmul
4531 return xllvm::MacConfBFP576ACC2048AIE2pIntrOp::create(
4532 rewriter, loc, v64i32Ty, lhsData, lhsExp, rhsData, rhsExp, acc64i32,
4533 confCst);
4534}
4535
4536// Helper function to perform 8×8×4 BF16 matmul following mac_8x8_8x4_bf16
4537// LHS: 64 bfloat16 (8×8 matrix), RHS: 32 bfloat16 (8×4 matrix)
4538// ACC: 32 float (8×4 result)
static Value perform8x8x4MatMul(OpBuilder &rewriter, Location loc, Type i32ty,
                                Value lhs64bf16, Value rhs32bf16,
                                Value acc32f32) {
  auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
  auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());

  // Extract lower and upper halves of LHS (64 bfloat16 -> 2x 32 bfloat16)
  SmallVector<int64_t> lowerMask, upperMask;
  for (int i = 0; i < 32; ++i) {
    lowerMask.push_back(i);
    upperMask.push_back(32 + i);
  }

  auto xl =
      vector::ShuffleOp::create(rewriter, loc, lhs64bf16, lhs64bf16, lowerMask);
  auto xh =
      vector::ShuffleOp::create(rewriter, loc, lhs64bf16, lhs64bf16, upperMask);

  // Cast to v16xi32 for shuffle intrinsic
  auto xlI32 =
      forceCastValueToType(rewriter, loc, xl, VectorType::get({16}, i32ty));
  auto xhI32 =
      forceCastValueToType(rewriter, loc, xh, VectorType::get({16}, i32ty));

  // Shuffle with T16_8x8_lo (mode 52) and T16_8x8_hi (mode 53)
  auto shuffleModeLo = LLVM::ConstantOp::create(rewriter, loc, i32ty,
                                                rewriter.getI32IntegerAttr(52));
  auto xa = xllvm::VectorShuffleAIE2pIntrOp::create(
      rewriter, loc, VectorType::get({16}, i32ty), xlI32, xhI32, shuffleModeLo);

  auto shuffleModeHi = LLVM::ConstantOp::create(rewriter, loc, i32ty,
                                                rewriter.getI32IntegerAttr(53));
  auto xb = xllvm::VectorShuffleAIE2pIntrOp::create(
      rewriter, loc, VectorType::get({16}, i32ty), xlI32, xhI32, shuffleModeHi);

  // Convert back to bfloat16
  auto xaBF16 = forceCastValueToType(rewriter, loc, xa, v32bf16Ty);
  auto xbBF16 = forceCastValueToType(rewriter, loc, xb, v32bf16Ty);

  // Helper to extract and broadcast 8 elements to 32, then shuffle with T16_4x8
  auto extractBroadcastShuffle = [&](Value src, int idx) -> Value {
    // Take the 8 lanes starting at idx*8 ...
    SmallVector<int64_t> extractMask;
    int startIdx = idx * 8;
    for (int i = 0; i < 8; ++i)
      extractMask.push_back(startIdx + i);
    // Broadcast by repeating 4 times to get 32 elements
    // (the first copy above plus 3 more repetitions).
    for (int rep = 0; rep < 3; ++rep) {
      for (int i = 0; i < 8; ++i)
        extractMask.push_back(startIdx + i);
    }
    auto broadcasted =
        vector::ShuffleOp::create(rewriter, loc, src, src, extractMask);

    // Apply T16_4x8 shuffle pattern (mode 29)
    auto broadI32 = forceCastValueToType(rewriter, loc, broadcasted,
                                         VectorType::get({16}, i32ty));
    auto shuffleMode4x8 = LLVM::ConstantOp::create(
        rewriter, loc, i32ty, rewriter.getI32IntegerAttr(29));
    auto shuffled = xllvm::VectorShuffleAIE2pIntrOp::create(
        rewriter, loc, VectorType::get({16}, i32ty), broadI32, broadI32,
        shuffleMode4x8);

    return forceCastValueToType(rewriter, loc, shuffled, v32bf16Ty);
  };

  // Prepare 8 row vectors from xa and xb (4 from each shuffled half).
  SmallVector<Value> rowVectors;
  for (int i = 0; i < 4; ++i)
    rowVectors.push_back(extractBroadcastShuffle(xaBF16, i));
  for (int i = 0; i < 4; ++i)
    rowVectors.push_back(extractBroadcastShuffle(xbBF16, i));

  // Helper to extract and broadcast 4 elements to 32 (for RHS columns)
  auto extractBroadcast4 = [&](Value src, int idx) -> Value {
    SmallVector<int64_t> mask;
    int startIdx = idx * 4;
    // Repeat the 4 elements 8 times to get 32 elements
    for (int rep = 0; rep < 8; ++rep) {
      for (int i = 0; i < 4; ++i)
        mask.push_back(startIdx + i);
    }
    return vector::ShuffleOp::create(rewriter, loc, src, src, mask);
  };

  // Prepare 8 column vectors from RHS
  SmallVector<Value> colVectors;
  for (int i = 0; i < 8; ++i)
    colVectors.push_back(extractBroadcast4(rhs32bf16, i));

  // Perform 8 MAC operations with conf=60 (no zero_acc)
  // NOTE(review): the meaning of conf value 60 is taken from the surrounding
  // code's convention; verify against the aie2p mac.conf encoding.
  auto conf60 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
                                         rewriter.getI32IntegerAttr(60));

  // Chain the accumulator through all 8 rank-1 updates.
  Value acc = acc32f32;
  for (int i = 0; i < 8; ++i) {
    acc = xllvm::MacConfBF16I512ACC1024AIE2pIntrOp::create(
        rewriter, loc, v32f32Ty, rowVectors[i], colVectors[i], acc, conf60);
  }

  return acc;
}
4640
4642 : public mlir::ConvertOpToLLVMPattern<aievec::MatMulOp_AIE2P> {
4643 using ConvertOpToLLVMPattern<aievec::MatMulOp_AIE2P>::ConvertOpToLLVMPattern;
4644 struct DecodedMatMulOp {
4645 typedef enum {
4646 BF16_8x8x8_I1024_ACC2048,
4647 BF16_4x8x8_I1024_ACC1024,
4648 BF16_8x1x8_I512_ACC2048,
4649 BF16_4x8x4_I512_ACC512,
4650 BF16_8x8x4_I512_ACC1024,
4651 I8_8x8x8_I512_ACC2048,
4652 I16_8x2x8_I1024_ACC2048,
4653 UNSUPPORTED
4654 } Kind;
4655 Kind kind;
4656 Value lhs;
4657 Value rhs;
4658 Value acc;
4659 int conf;
4660 };
  // Classify an aievec.matmul_aie2p op into one of the supported AIE2p MAC
  // intrinsic variants, dispatching purely on element types and lane counts.
  // The `conf` values (776, 24, 60, 780) are precomputed configuration words
  // for the corresponding mac.conf intrinsics; their field encoding is not
  // derived here — NOTE(review): verify against the aie2p mac.conf spec.
  static DecodedMatMulOp decodeMatMulOp(OpAdaptor op) {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value acc = op.getAcc();

    auto lhsVecTy = cast<VectorType>(lhs.getType());
    auto rhsVecTy = cast<VectorType>(rhs.getType());
    auto accVecTy = cast<VectorType>(acc.getType());

    // Check for AIE2p integer matmul
    if (isa<IntegerType>(lhsVecTy.getElementType()) &&
        isa<IntegerType>(rhsVecTy.getElementType()) &&
        isa<IntegerType>(accVecTy.getElementType())) {

      auto lhsIntTy = cast<IntegerType>(lhsVecTy.getElementType());
      auto rhsIntTy = cast<IntegerType>(rhsVecTy.getElementType());
      auto accIntTy = cast<IntegerType>(accVecTy.getElementType());

      int lhsLanes = getVectorLaneSize(lhsVecTy);
      int rhsLanes = getVectorLaneSize(rhsVecTy);
      int accLanes = getVectorLaneSize(accVecTy);

      // Check for <8x8xi8> x <8x8xi8> + <8x8xi32>
      if (lhsIntTy.getWidth() == 8 && rhsIntTy.getWidth() == 8 &&
          accIntTy.getWidth() == 32 && lhsLanes == 64 && rhsLanes == 64 &&
          accLanes == 64) {
        // Uses I512.I512.ACC2048 (64 lanes of i8 -> 64 lanes of i32)
        return {DecodedMatMulOp::Kind::I8_8x8x8_I512_ACC2048, lhs, rhs, acc,
                776};
      }

      // Check for <8x2xi16> x <2x8xi16> + <8x8xi32>
      // Note: Vectors are <8x8xi16> shape, but only lower <8x2xi16> and
      // <2x8xi16> contain data
      if (lhsIntTy.getWidth() == 16 && rhsIntTy.getWidth() == 16 &&
          accIntTy.getWidth() == 32 && lhsLanes == 16 && rhsLanes == 16 &&
          accLanes == 64) {
        // Uses I1024.I1024.ACC2048 (64 lanes of i16 -> 64 lanes of i32)
        return {DecodedMatMulOp::Kind::I16_8x2x8_I1024_ACC2048, lhs, rhs, acc,
                24};
      }
    }

    // Check for AIE2p bf16 matmul
    if (isa<BFloat16Type>(lhsVecTy.getElementType()) &&
        isa<BFloat16Type>(rhsVecTy.getElementType()) &&
        isa<Float32Type>(accVecTy.getElementType())) {

      // Determine input size and accumulator size to select the right variant
      int lhsLanes = getVectorLaneSize(lhsVecTy);
      int rhsLanes = getVectorLaneSize(rhsVecTy);
      int accLanes = getVectorLaneSize(accVecTy);

      // I512 inputs (32 lanes each) with ACC512 (16 lanes)
      if (lhsLanes == 32 && rhsLanes == 32 && accLanes == 16) {
        // Uses I512.I512.ACC512 (16 lanes of f32)
        return {DecodedMatMulOp::Kind::BF16_4x8x4_I512_ACC512, lhs, rhs, acc,
                60};
      }
      // Special case for 8x8x4 matmul: <8x8xbf16> x <8x4xbf16> + <8x4xf32>
      else if (lhsLanes == 64 && rhsLanes == 32 && accLanes == 32) {
        // Uses I512.I512.ACC1024 for each MAC operation
        return {DecodedMatMulOp::Kind::BF16_8x8x4_I512_ACC1024, lhs, rhs, acc,
                60};
      }
      // Special case for 4x8x8 matmul: <4x8xbf16> x <8x8xbf16> + <4x8xf32>
      else if (lhsLanes == 32 && rhsLanes == 64 && accLanes == 32) {
        // Uses BFP16 format via mac_8x8_8x8T_conf
        return {DecodedMatMulOp::Kind::BF16_4x8x8_I1024_ACC1024, lhs, rhs, acc,
                780};
      }
      // Special case for 8x1x8 matmul: <8x1xbf16> x <1x8xbf16> + <8x8xf32>
      else if (lhsLanes == 8 && rhsLanes == 8 && accLanes == 64) {
        // Outer product: transpose+replicate LHS, replicate RHS, use
        // mac_elem_64_conf
        return {DecodedMatMulOp::Kind::BF16_8x1x8_I512_ACC2048, lhs, rhs, acc,
                60};
      }
      // I1024 inputs (64 lanes each)
      else if (lhsLanes == 64 && rhsLanes == 64 && accLanes == 64) {
        // Uses I1024.I1024.ACC2048 (64 lanes of f32)
        return {DecodedMatMulOp::Kind::BF16_8x8x8_I1024_ACC2048, lhs, rhs, acc,
                60};
      }
    }

    // No matching variant: caller emits a warning and fails the conversion.
    return {DecodedMatMulOp::Kind::UNSUPPORTED, lhs, rhs, acc, -1};
  }
4749 LogicalResult
4750 matchAndRewrite(aievec::MatMulOp_AIE2P op, OpAdaptor adaptor,
4751 ConversionPatternRewriter &rewriter) const override {
4752 auto decodedMatMulOp = decodeMatMulOp(adaptor);
4753 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::UNSUPPORTED) {
4754 op.emitWarning() << "aievec.matmul_aie2p conversion is not supported for "
4755 "this type combination.\n";
4756 return failure();
4757 }
4758 Location loc = op.getLoc();
4759
4760 // Flatten the inputs
4761 auto lhsFlattenedVecTy =
4762 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.lhs.getType()));
4763 decodedMatMulOp.lhs = vector::ShapeCastOp::create(
4764 rewriter, loc, lhsFlattenedVecTy, decodedMatMulOp.lhs);
4765 auto rhsFlattenedVecTy =
4766 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.rhs.getType()));
4767 decodedMatMulOp.rhs = vector::ShapeCastOp::create(
4768 rewriter, loc, rhsFlattenedVecTy, decodedMatMulOp.rhs);
4769 auto accFlattenedVecTy =
4770 getFlattenedVectorType(cast<VectorType>(decodedMatMulOp.acc.getType()));
4771 decodedMatMulOp.acc = vector::ShapeCastOp::create(
4772 rewriter, loc, accFlattenedVecTy, decodedMatMulOp.acc);
4773 Type i32ty = rewriter.getI32Type();
4774 auto confCst = LLVM::ConstantOp::create(
4775 rewriter, loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf));
4776
4777 SmallVector<Value> operands({decodedMatMulOp.lhs, decodedMatMulOp.rhs,
4778 decodedMatMulOp.acc, confCst});
4779
4780 Value matMulResVal;
4781
4782 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I8_8x8x8_I512_ACC2048) {
4783 // <8x8xi8> x <8x8xi8> + <8x8xi32>
4784 // Signature: <32 x i64> @llvm.aie2p.I512.I512.ACC2048.mac.conf(
4785 // <16 x i32>, <32 x i16>, <32 x i64>, i32)
4786 // Bitcast LHS <64 x i8> -> <16 x i32>
4787 // Bitcast RHS <64 x i8> -> <32 x i16>
4788 // Bitcast ACC <64 x i32> -> <32 x i64>
4789 matMulResVal =
4790 xllvm::MacConfI512ACC2048AIE2pIntrOp::create(
4791 rewriter, loc, VectorType::get({32}, rewriter.getI64Type()),
4792 forceCastOperandsToSignature(
4793 rewriter, loc, operands,
4794 {VectorType::get({16}, rewriter.getI32Type()),
4795 VectorType::get({32}, rewriter.getI16Type()),
4796 VectorType::get({32}, rewriter.getI64Type()), i32ty}))
4797 .getResult();
4798 } else if (decodedMatMulOp.kind ==
4799 DecodedMatMulOp::Kind::I16_8x2x8_I1024_ACC2048) {
4800 // <8x2xi16> x <2x8xi16> + <8x8xi32>
4801 // Input vectors are 16 lanes each, need to pad to 64 lanes for intrinsic
4802 // Signature: <32 x i64> @llvm.aie2p.I1024.I1024.ACC2048.mac.conf(
4803 // <32 x i32>, <64 x i16>, <32 x i64>, i32)
4804
4805 // Pad LHS from <16 x i16> to <64 x i16> using shuffle
4806 SmallVector<int64_t> lhsPadMask;
4807 for (int i = 0; i < 16; ++i)
4808 lhsPadMask.push_back(i);
4809 for (int i = 16; i < 64; ++i)
4810 lhsPadMask.push_back(-1); // undef/poison
4811 auto lhsPadded = vector::ShuffleOp::create(
4812 rewriter, loc, decodedMatMulOp.lhs, decodedMatMulOp.lhs, lhsPadMask);
4813
4814 // Pad RHS from <16 x i16> to <64 x i16> using shuffle
4815 SmallVector<int64_t> rhsPadMask;
4816 for (int i = 0; i < 16; ++i)
4817 rhsPadMask.push_back(i);
4818 for (int i = 16; i < 64; ++i)
4819 rhsPadMask.push_back(-1); // undef/poison
4820 auto rhsPadded = vector::ShuffleOp::create(
4821 rewriter, loc, decodedMatMulOp.rhs, decodedMatMulOp.rhs, rhsPadMask);
4822
4823 // Update operands with padded vectors
4824 SmallVector<Value> paddedOperands(
4825 {lhsPadded, rhsPadded, decodedMatMulOp.acc, confCst});
4826
4827 // Bitcast LHS <64 x i16> -> <32 x i32>
4828 // Keep RHS as <64 x i16>
4829 // Bitcast ACC <64 x i32> -> <32 x i64>
4830 matMulResVal =
4831 xllvm::MacConfI1024ACC2048AIE2pIntrOp::create(
4832 rewriter, loc, VectorType::get({32}, rewriter.getI64Type()),
4833 forceCastOperandsToSignature(
4834 rewriter, loc, paddedOperands,
4835 {VectorType::get({32}, rewriter.getI32Type()),
4836 VectorType::get({64}, rewriter.getI16Type()),
4837 VectorType::get({32}, rewriter.getI64Type()), i32ty}))
4838 .getResult();
4839 } else if (decodedMatMulOp.kind ==
4840 DecodedMatMulOp::Kind::BF16_8x8x8_I1024_ACC2048) {
4841 // <8x8xbf16> x <8x8xbf16> + <8x8xf32>
4842 // This implements the 8×8×8 BF16 matmul using BFP16 format
4843 // Following the aie_api reference implementation that converts to BFP16
4844
4845 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
4846
4847 // Step 1: Convert LHS v64bfloat16 to v64accfloat (in two v32 chunks)
4848 SmallVector<int64_t> firstHalfMask, secondHalfMask;
4849 for (int i = 0; i < 32; ++i) {
4850 firstHalfMask.push_back(i);
4851 secondHalfMask.push_back(32 + i);
4852 }
4853
4854 auto lhs32bf16_lo =
4855 vector::ShuffleOp::create(rewriter, loc, decodedMatMulOp.lhs,
4856 decodedMatMulOp.lhs, firstHalfMask);
4857 auto lhs32bf16_hi =
4858 vector::ShuffleOp::create(rewriter, loc, decodedMatMulOp.lhs,
4859 decodedMatMulOp.lhs, secondHalfMask);
4860
4861 auto lhs32f32_lo = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4862 rewriter, loc, v32f32Ty, lhs32bf16_lo);
4863 auto lhs32f32_hi = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4864 rewriter, loc, v32f32Ty, lhs32bf16_hi);
4865
4866 // Concat to v64accfloat
4867 SmallVector<int64_t> concatMask;
4868 for (int i = 0; i < 64; ++i)
4869 concatMask.push_back(i);
4870 auto lhs64f32 = vector::ShuffleOp::create(rewriter, loc, lhs32f32_lo,
4871 lhs32f32_hi, concatMask);
4872
4873 // Step 2: Transpose RHS and convert to accfloat using shared helper
4874 auto rhsTransposed =
4875 transposeAndConvertRHS(rewriter, loc, i32ty, decodedMatMulOp.rhs);
4876
4877 // Step 4: Use shared BFP16 8×8 matmul helper
4878 auto conf780 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4879 rewriter.getI32IntegerAttr(780));
4880
4881 matMulResVal = performBFP16_8x8MatMul(
4882 rewriter, loc, i32ty, lhs64f32, rhsTransposed,
4883 forceCastValueToType(rewriter, loc, decodedMatMulOp.acc,
4884 VectorType::get({64}, rewriter.getI32Type())),
4885 conf780);
4886 } else if (decodedMatMulOp.kind ==
4887 DecodedMatMulOp::Kind::BF16_4x8x8_I1024_ACC1024) {
4888 // <4x8xbf16> x <8x8xbf16> + <4x8xf32>
4889 // LHS: 32 lanes, RHS: 64 lanes, ACC: 32 lanes
4890 // Similar to 8×8×8 but only use first 32 lanes of LHS
4891
4892 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
4893
4894 // Step 1: Convert LHS v32bfloat16 to v32accfloat, then pad to v64accfloat
4895 auto lhs32f32 = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4896 rewriter, loc, v32f32Ty, decodedMatMulOp.lhs);
4897
4898 // Pad v32accfloat to v64accfloat using shuffle
4899 SmallVector<int64_t> lhsPadMask;
4900 for (int i = 0; i < 32; ++i)
4901 lhsPadMask.push_back(i);
4902 for (int i = 32; i < 64; ++i)
4903 lhsPadMask.push_back(-1); // poison
4904 auto lhs64f32 = vector::ShuffleOp::create(rewriter, loc, lhs32f32,
4905 lhs32f32, lhsPadMask);
4906
4907 // Step 2: Transpose RHS and convert to accfloat using shared helper
4908 auto rhsTransposed =
4909 transposeAndConvertRHS(rewriter, loc, i32ty, decodedMatMulOp.rhs);
4910
4911 // Step 4: Pad ACC from 32 to 64 i32
4912 SmallVector<int64_t> accPadMask;
4913 for (int i = 0; i < 32; ++i)
4914 accPadMask.push_back(i);
4915 for (int i = 32; i < 64; ++i)
4916 accPadMask.push_back(-1); // poison
4917 auto acc64i32 = vector::ShuffleOp::create(
4918 rewriter, loc,
4919 forceCastValueToType(rewriter, loc, decodedMatMulOp.acc,
4920 VectorType::get({32}, rewriter.getI32Type())),
4921 forceCastValueToType(rewriter, loc, decodedMatMulOp.acc,
4922 VectorType::get({32}, rewriter.getI32Type())),
4923 accPadMask);
4924
4925 // Step 5: Use shared BFP16 8×8 matmul helper
4926 auto result64i32 = performBFP16_8x8MatMul(
4927 rewriter, loc, i32ty, lhs64f32, rhsTransposed, acc64i32, confCst);
4928
4929 // Step 6: Extract first 32 elements
4930 SmallVector<int64_t> extractMask;
4931 for (int i = 0; i < 32; ++i)
4932 extractMask.push_back(i);
4933 matMulResVal = vector::ShuffleOp::create(rewriter, loc, result64i32,
4934 result64i32, extractMask);
4935 } else if (decodedMatMulOp.kind ==
4936 DecodedMatMulOp::Kind::BF16_8x1x8_I512_ACC2048) {
4937 // <8x1xbf16> x <1x8xbf16> + <8x8xf32>
4938 // Outer product: grow_replicate both to 64, transpose LHS, use
4939 // mac_elem_64_conf
4940
4941 auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
4942
4943 // Step 1: Replicate LHS from 8 to 64 elements (replicate 8 times)
4944 SmallVector<int64_t> lhsReplicateMask;
4945 for (int rep = 0; rep < 8; ++rep) {
4946 for (int i = 0; i < 8; ++i)
4947 lhsReplicateMask.push_back(i);
4948 }
4949 auto lhs64bf16 =
4950 vector::ShuffleOp::create(rewriter, loc, decodedMatMulOp.lhs,
4951 decodedMatMulOp.lhs, lhsReplicateMask);
4952
4953 // Step 2: Transpose LHS as 8×8 matrix
4954 SmallVector<int64_t> transposeMask;
4955 for (int c = 0; c < 8; ++c) {
4956 for (int r = 0; r < 8; ++r) {
4957 transposeMask.push_back(r * 8 + c);
4958 }
4959 }
4960 auto lhs64bf16Transposed = vector::ShuffleOp::create(
4961 rewriter, loc, lhs64bf16, lhs64bf16, transposeMask);
4962
4963 // Step 3: Replicate RHS from 8 to 64 elements (replicate 8 times)
4964 SmallVector<int64_t> rhsReplicateMask;
4965 for (int rep = 0; rep < 8; ++rep) {
4966 for (int i = 0; i < 8; ++i)
4967 rhsReplicateMask.push_back(i);
4968 }
4969 auto rhs64bf16 =
4970 vector::ShuffleOp::create(rewriter, loc, decodedMatMulOp.rhs,
4971 decodedMatMulOp.rhs, rhsReplicateMask);
4972
4973 // Step 4: Use mac_elem_64_conf (which is
4974 // MacConfBF16I512ACC2048AIE2pIntrOp)
4975 matMulResVal = xllvm::MacConfBF16I512ACC2048AIE2pIntrOp::create(
4976 rewriter, loc, v64f32Ty, lhs64bf16Transposed, rhs64bf16,
4977 decodedMatMulOp.acc, confCst);
4978 } else if (decodedMatMulOp.kind ==
4979 DecodedMatMulOp::Kind::BF16_4x8x4_I512_ACC512) {
4980 // 4×8×4 matmul: <4x8xbf16> x <8x4xbf16> + <4x4xf32>
4981 // Following the reference pattern: a.grow<64>(), b,
4982 // acc.grow<32>().extract<16>(0) We pad LHS 32→64, pad ACC 16→32, call
4983 // 8×8×4 impl, then extract 32→16
4984
4985 // Pad LHS from 32 to 64 bfloat16 using shuffle
4986 SmallVector<int64_t> lhsPadMask;
4987 for (int i = 0; i < 32; ++i)
4988 lhsPadMask.push_back(i);
4989 for (int i = 32; i < 64; ++i)
4990 lhsPadMask.push_back(-1); // poison/undef
4991 auto lhsPadded = vector::ShuffleOp::create(
4992 rewriter, loc, decodedMatMulOp.lhs, decodedMatMulOp.lhs, lhsPadMask);
4993
4994 // Pad ACC from 16 to 32 float using shuffle
4995 SmallVector<int64_t> accPadMask;
4996 for (int i = 0; i < 16; ++i)
4997 accPadMask.push_back(i);
4998 for (int i = 16; i < 32; ++i)
4999 accPadMask.push_back(-1); // poison/undef
5000 auto accPadded = vector::ShuffleOp::create(
5001 rewriter, loc, decodedMatMulOp.acc, decodedMatMulOp.acc, accPadMask);
5002
5003 // Call the shared 8×8×4 helper with padded inputs
5004 Value acc32 = perform8x8x4MatMul(rewriter, loc, i32ty, lhsPadded,
5005 decodedMatMulOp.rhs, accPadded);
5006
5007 // Extract first 16 elements from 32-element result
5008 SmallVector<int64_t> extractMask;
5009 for (int i = 0; i < 16; ++i)
5010 extractMask.push_back(i);
5011 matMulResVal =
5012 vector::ShuffleOp::create(rewriter, loc, acc32, acc32, extractMask);
5013 } else if (decodedMatMulOp.kind ==
5014 DecodedMatMulOp::Kind::BF16_8x8x4_I512_ACC1024) {
5015 // Special 8×8×4 matmul: <8x8xbf16> x <8x4xbf16> + <8x4xf32>
5016 // Uses shared helper function
5017 matMulResVal =
5018 perform8x8x4MatMul(rewriter, loc, i32ty, decodedMatMulOp.lhs,
5019 decodedMatMulOp.rhs, decodedMatMulOp.acc);
5020 }
5021
5022 // Cast from flattened result back to original accumulator shape
5023 auto castFromAcc =
5024 forceCastValueToType(rewriter, loc, matMulResVal, accFlattenedVecTy);
5025 // Reshape back to original shape
5026 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
5027 castFromAcc);
5028 return success();
5029 }
5030};
5031
5032// This pattern folds aievec.cast op. For AIE2, the accumulators are in 32/64
5033// bits, and the vectors are in 4/8/16/32 bits. Hence, we don't have to
5034// explicitly express the casting between accumulators and vectors at the LLVM
5035// dialect level. The backend LLVM compiler will decide the correct accumulator
5036// or vector registers given the ops and intrinsics.
5037class FoldAIECastOps : public mlir::ConvertOpToLLVMPattern<aievec::CastOp> {
5038 using ConvertOpToLLVMPattern<aievec::CastOp>::ConvertOpToLLVMPattern;
5039
5040 // Helper to check if a value is a constant zero
5041 static bool isConstantZero(Value val) {
5042 DenseElementsAttr denseAttr;
5043
5044 // Check for both arith.constant and llvm.mlir.constant
5045 if (auto arithConstOp = val.getDefiningOp<arith::ConstantOp>()) {
5046 denseAttr = dyn_cast<DenseElementsAttr>(arithConstOp.getValue());
5047 } else if (auto llvmConstOp = val.getDefiningOp<LLVM::ConstantOp>()) {
5048 denseAttr = dyn_cast<DenseElementsAttr>(llvmConstOp.getValue());
5049 }
5050
5051 if (!denseAttr || !denseAttr.isSplat())
5052 return false;
5053
5054 auto splatAttr = denseAttr.getSplatValue<Attribute>();
5055 if (auto floatAttr = dyn_cast<FloatAttr>(splatAttr))
5056 return floatAttr.getValue().isZero();
5057 if (auto intAttr = dyn_cast<IntegerAttr>(splatAttr))
5058 return intAttr.getValue().isZero();
5059
5060 return false;
5061 }
5062
5063 LogicalResult
5064 matchAndRewrite(aievec::CastOp castOp, OpAdaptor adaptor,
5065 ConversionPatternRewriter &rewriter) const override {
5066 // Special handling for isResAcc=true with zero constant source
5067 // The backend cannot handle zeroinitializer for accumulator types,
5068 // so we must use the vbroadcast.zero.acc1024 intrinsic instead
5069 if (!castOp.getIsResAcc() || !isConstantZero(adaptor.getSource())) {
5070 // Default behavior: fold the cast
5071 rewriter.replaceOp(castOp, adaptor.getSource());
5072 return success();
5073 }
5074
5075 Location loc = castOp.getLoc();
5076 auto srcVecType = cast<VectorType>(castOp.getSource().getType());
5077 Type srcElemType = srcVecType.getElementType();
5078 int lanes = getVectorLaneSize(srcVecType);
5079
5080 // For f32 vectors (accfloat), use vbroadcast.zero.acc1024
5081 if (srcElemType.isF32() && lanes == 16) {
5082 // Call vbroadcast.zero.acc1024 to get vector<16xi64>
5083 auto zeroAcc1024 = xllvm::VectorBroadcastZeroAcc1024IntrOp::create(
5084 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()));
5085
5086 // Extract lower 8 elements to get vector<8xi64> (512-bit accumulator)
5087 SmallVector<int64_t> extractMask = {0, 1, 2, 3, 4, 5, 6, 7};
5088 auto zeroAcc512 = vector::ShuffleOp::create(rewriter, loc, zeroAcc1024,
5089 zeroAcc1024, extractMask);
5090
5091 // Bitcast back to vector<16xf32> to match the cast result type
5092 auto result = LLVM::BitcastOp::create(
5093 rewriter, loc, VectorType::get({16}, rewriter.getF32Type()),
5094 zeroAcc512);
5095
5096 rewriter.replaceOp(castOp, result);
5097 return success();
5098 }
5099
5100 // Fallback: fold the cast (should not reach here for supported cases)
5101 rewriter.replaceOp(castOp, adaptor.getSource());
5102 return success();
5103 }
5104};
5105
5106// AIE2p version of FoldAIECastOps
5108 : public mlir::ConvertOpToLLVMPattern<aievec::CastOp> {
5109 using ConvertOpToLLVMPattern<aievec::CastOp>::ConvertOpToLLVMPattern;
5110
5111 LogicalResult
5112 matchAndRewrite(aievec::CastOp castOp, OpAdaptor adaptor,
5113 ConversionPatternRewriter &rewriter) const override {
5114 // Fold the cast.
5115 rewriter.replaceOp(castOp, adaptor.getSource());
5116 return success();
5117 }
5118};
5119
5121 : public mlir::ConvertOpToLLVMPattern<aievec::ShuffleOp> {
5122 using ConvertOpToLLVMPattern<aievec::ShuffleOp>::ConvertOpToLLVMPattern;
5123
5124 LogicalResult
5125 matchAndRewrite(aievec::ShuffleOp shuffleOp, OpAdaptor adaptor,
5126 ConversionPatternRewriter &rewriter) const override {
5127 auto loc = shuffleOp.getLoc();
5128 auto lhs = adaptor.getLhs();
5129 auto rhs = adaptor.getRhs();
5130 auto i32ty = rewriter.getI32Type();
5131 auto v16xi32ty = VectorType::get({16}, i32ty);
5132 if (!rhs)
5133 rhs = xllvm::UndefV16I32IntrOp::create(rewriter, loc, v16xi32ty);
5134
5135 auto modeAttrVal =
5136 LLVM::ConstantOp::create(rewriter, loc, i32ty,
5137 static_cast<int32_t>(shuffleOp.getMode()))
5138 .getResult();
5139 auto vShuffleVal = xllvm::VectorShuffleIntrOp::create(
5140 rewriter, loc, v16xi32ty,
5141 forceCastOperandsToSignature(
5142 rewriter, loc,
5143 /*operands=*/{lhs, rhs, modeAttrVal},
5144 /*signature=*/{v16xi32ty, v16xi32ty, i32ty}))
5145 .getResult();
5146
5147 vShuffleVal = forceCastValueToType(rewriter, loc, vShuffleVal,
5148 shuffleOp.getResult().getType());
5149
5150 rewriter.replaceOp(shuffleOp, vShuffleVal);
5151
5152 return success();
5153 }
5154};
5155
5156// Convert aievec.inv to xllvm.intr.aie2p.inv intrinsic for AIE2P
5157// Scalar f32: direct conversion to xllvm.intr.aie2p.inv
5158// Vector f32: unroll into scalar xllvm.intr.aie2p.inv operations
5160 : public mlir::ConvertOpToLLVMPattern<aievec::InvOp> {
5161public:
5162 using ConvertOpToLLVMPattern<aievec::InvOp>::ConvertOpToLLVMPattern;
5163
5164 LogicalResult
5165 matchAndRewrite(aievec::InvOp invOp, OpAdaptor adaptor,
5166 ConversionPatternRewriter &rewriter) const override {
5167 auto loc = invOp.getLoc();
5168 auto operandType = adaptor.getSource().getType();
5169
5170 // Handle scalar f32 inverse
5171 if (operandType.isF32()) {
5172 auto invResult = xllvm::InvAIE2pIntrOp::create(
5173 rewriter, loc, rewriter.getF32Type(), adaptor.getSource());
5174 rewriter.replaceOp(invOp, invResult);
5175 return success();
5176 }
5177
5178 // Handle vector<N x f32> inverse
5179 auto vecType = dyn_cast<VectorType>(operandType);
5180 if (!vecType || !vecType.getElementType().isF32())
5181 return failure();
5182
5183 // Unroll vector inverse into scalar operations
5184 int numElements = getVectorLaneSize(vecType);
5185 Value result = LLVM::PoisonOp::create(rewriter, loc, vecType);
5186
5187 for (int i = 0; i < numElements; ++i) {
5188 // Extract element i
5189 auto indexCst = LLVM::ConstantOp::create(
5190 rewriter, loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(i));
5191 auto extractedElem = LLVM::ExtractElementOp::create(
5192 rewriter, loc, adaptor.getSource(), indexCst);
5193
5194 // Call xllvm.intr.aie2p.inv on the scalar
5195 auto invResult = xllvm::InvAIE2pIntrOp::create(
5196 rewriter, loc, rewriter.getF32Type(), extractedElem);
5197
5198 // Insert result back into vector
5199 result = LLVM::InsertElementOp::create(rewriter, loc, vecType, result,
5200 invResult, indexCst);
5201 }
5202
5203 rewriter.replaceOp(invOp, result);
5204 return success();
5205 }
5206};
5207
5208// Convert aievec.exp to xllvm.exp2 intrinsic for AIE2P
5209// Uses the identity: exp(x) = exp2(x * log2(e))
5210// Supports both lane-16 and lane-32 bf16 vectors
5212 : public mlir::ConvertOpToLLVMPattern<aievec::ExpOp> {
5213public:
5214 using ConvertOpToLLVMPattern<aievec::ExpOp>::ConvertOpToLLVMPattern;
5215
5216 LogicalResult
5217 matchAndRewrite(aievec::ExpOp expOp, OpAdaptor adaptor,
5218 ConversionPatternRewriter &rewriter) const override {
5219 auto loc = expOp.getLoc();
5220 auto srcType = cast<VectorType>(adaptor.getSource().getType());
5221 auto srcElemType = srcType.getElementType();
5222 unsigned laneSize = getVectorLaneSize(srcType);
5223
5224 // Support v16bfloat16 and v32bfloat16
5225 if ((laneSize != 16 && laneSize != 32) || !srcElemType.isBF16())
5226 return expOp.emitWarning()
5227 << "aievec.exp conversion only supports v16bfloat16 and "
5228 "v32bfloat16.\n";
5229
5230 // Step 1: Create bf16 constant for log2(e) ≈ 1.442695
5231 auto log2eBF16Const = LLVM::ConstantOp::create(
5232 rewriter, loc, rewriter.getBF16Type(),
5233 rewriter.getFloatAttr(rewriter.getBF16Type(), 1.442695));
5234
5235 // Broadcast log2(e) to match input lane size
5236 SmallVector<int64_t> broadcastMask;
5237 for (unsigned i = 0; i < laneSize; ++i)
5238 broadcastMask.push_back(0);
5239
5240 auto v1bf16 = LLVM::UndefOp::create(
5241 rewriter, loc, VectorType::get({1}, rewriter.getBF16Type()));
5242 auto v1bf16Inserted = LLVM::InsertElementOp::create(
5243 rewriter, loc, v1bf16, log2eBF16Const,
5244 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0));
5245
5246 auto log2eVec = vector::ShuffleOp::create(rewriter, loc, v1bf16Inserted,
5247 v1bf16Inserted, broadcastMask);
5248
5249 // Step 2: Multiply input by log2(e) in bf16 domain using MulElemOp
5250 // For lane-16: uses I512.I512.ACC512
5251 // For lane-32: uses I512.I512.ACC1024
5252 auto resultF32Ty =
5253 VectorType::get({(int64_t)laneSize}, rewriter.getF32Type());
5254 auto mulResult = aievec::MulElemOp::create(rewriter, loc, resultF32Ty,
5255 adaptor.getSource(), log2eVec);
5256
5257 // Step 3: Call exp2 intrinsic based on lane size
5258 Value exp2Result;
5259 auto v16bf16Ty = VectorType::get({16}, rewriter.getBF16Type());
5260
5261 if (laneSize == 16) {
5262 // Lane-16: Single exp2 call
5263 // exp2 takes v16float and returns v16bfloat16
5264 exp2Result =
5265 xllvm::Exp2AIE2pIntrOp::create(rewriter, loc, v16bf16Ty, mulResult);
5266 } else {
5267 // Lane-32: Split-and-recombine pattern
5268 // Split v32float into two v16float halves
5269 SmallVector<int64_t> lowerMask, upperMask;
5270 for (int i = 0; i < 16; ++i) {
5271 lowerMask.push_back(i); // indices 0-15
5272 upperMask.push_back(16 + i); // indices 16-31
5273 }
5274
5275 auto lowerHalf = vector::ShuffleOp::create(rewriter, loc, mulResult,
5276 mulResult, lowerMask);
5277 auto upperHalf = vector::ShuffleOp::create(rewriter, loc, mulResult,
5278 mulResult, upperMask);
5279
5280 // Call exp2 on each half separately
5281 auto exp2Lower =
5282 xllvm::Exp2AIE2pIntrOp::create(rewriter, loc, v16bf16Ty, lowerHalf);
5283 auto exp2Upper =
5284 xllvm::Exp2AIE2pIntrOp::create(rewriter, loc, v16bf16Ty, upperHalf);
5285
5286 // Recombine the two v16bfloat16 results into v32bfloat16
5287 SmallVector<int64_t> combineMask;
5288 for (int i = 0; i < 32; ++i)
5289 combineMask.push_back(i);
5290
5291 exp2Result = vector::ShuffleOp::create(rewriter, loc, exp2Lower,
5292 exp2Upper, combineMask);
5293 }
5294
5295 rewriter.replaceOp(expOp, exp2Result);
5296
5297 return success();
5298 }
5299};
5300
5301// Convert aievec.tanh to xllvm.tanh intrinsic for AIE2P
5302// Supports both lane-16 and lane-32 bf16 vectors
5304 : public mlir::ConvertOpToLLVMPattern<aievec::TanhOp> {
5305public:
5306 using ConvertOpToLLVMPattern<aievec::TanhOp>::ConvertOpToLLVMPattern;
5307
5308 LogicalResult
5309 matchAndRewrite(aievec::TanhOp tanhOp, OpAdaptor adaptor,
5310 ConversionPatternRewriter &rewriter) const override {
5311 auto loc = tanhOp.getLoc();
5312 auto srcType = cast<VectorType>(adaptor.getSource().getType());
5313 auto srcElemType = srcType.getElementType();
5314 unsigned laneSize = getVectorLaneSize(srcType);
5315
5316 // Support v16bfloat16 and v32bfloat16
5317 if ((laneSize != 16 && laneSize != 32) || !srcElemType.isBF16())
5318 return tanhOp.emitWarning()
5319 << "aievec.tanh conversion only supports v16bfloat16 and "
5320 "v32bfloat16.\n";
5321
5322 // Step 1: Convert bf16 input to f32 using the dedicated UPS intrinsic
5323 auto v16bf16Ty = VectorType::get({16}, rewriter.getBF16Type());
5324 auto v16f32Ty = VectorType::get({16}, rewriter.getF32Type());
5325
5326 // Step 2: Call tanh intrinsic based on lane size
5327 Value tanhResult;
5328
5329 if (laneSize == 16) {
5330 // Lane-16: Convert bf16->f32 then call tanh
5331 auto inputF32 = xllvm::Vector16BF16ToV16AccFloatAIE2pIntrOp::create(
5332 rewriter, loc, v16f32Ty, adaptor.getSource());
5333 tanhResult =
5334 xllvm::TanhAIE2pIntrOp::create(rewriter, loc, v16bf16Ty, inputF32);
5335 } else {
5336 // Lane-32: Split into two v16bf16, convert each, tanh each, recombine
5337 SmallVector<int64_t> lowerMask, upperMask;
5338 for (int i = 0; i < 16; ++i) {
5339 lowerMask.push_back(i);
5340 upperMask.push_back(16 + i);
5341 }
5342
5343 auto lowerBf16 = vector::ShuffleOp::create(
5344 rewriter, loc, adaptor.getSource(), adaptor.getSource(), lowerMask);
5345 auto upperBf16 = vector::ShuffleOp::create(
5346 rewriter, loc, adaptor.getSource(), adaptor.getSource(), upperMask);
5347
5348 auto lowerF32 = xllvm::Vector16BF16ToV16AccFloatAIE2pIntrOp::create(
5349 rewriter, loc, v16f32Ty, lowerBf16);
5350 auto upperF32 = xllvm::Vector16BF16ToV16AccFloatAIE2pIntrOp::create(
5351 rewriter, loc, v16f32Ty, upperBf16);
5352
5353 auto tanhLower =
5354 xllvm::TanhAIE2pIntrOp::create(rewriter, loc, v16bf16Ty, lowerF32);
5355 auto tanhUpper =
5356 xllvm::TanhAIE2pIntrOp::create(rewriter, loc, v16bf16Ty, upperF32);
5357
5358 SmallVector<int64_t> combineMask;
5359 for (int i = 0; i < 32; ++i)
5360 combineMask.push_back(i);
5361
5362 tanhResult = vector::ShuffleOp::create(rewriter, loc, tanhLower,
5363 tanhUpper, combineMask);
5364 }
5365
5366 rewriter.replaceOp(tanhOp, tanhResult);
5367
5368 return success();
5369 }
5370};
5371
5372// Convert math.rsqrt (scalar f32 or vector f32) to xllvm.intr.aie2p.invsqrt
5373// Scalar f32: direct conversion to xllvm.intr.aie2p.invsqrt
5374// Vector f32: unroll into scalar xllvm.intr.aie2p.invsqrt operations
5376 : public mlir::ConvertOpToLLVMPattern<math::RsqrtOp> {
5377public:
5378 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
5379
5380 LogicalResult
5381 matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor,
5382 ConversionPatternRewriter &rewriter) const override {
5383 auto loc = rsqrtOp.getLoc();
5384 auto operandType = adaptor.getOperand().getType();
5385
5386 // Handle scalar f32 rsqrt
5387 if (operandType.isF32()) {
5388 auto rsqrtResult = xllvm::InvsqrtAIE2pIntrOp::create(
5389 rewriter, loc, rewriter.getF32Type(), adaptor.getOperand());
5390 rewriter.replaceOp(rsqrtOp, rsqrtResult);
5391 return success();
5392 }
5393
5394 // Handle vector<N x f32> rsqrt
5395 auto vecType = dyn_cast<VectorType>(operandType);
5396 if (!vecType || !vecType.getElementType().isF32())
5397 return failure();
5398
5399 // Unroll vector rsqrt into scalar operations
5400 int numElements = getVectorLaneSize(vecType);
5401 Value result = LLVM::PoisonOp::create(rewriter, loc, vecType);
5402
5403 for (int i = 0; i < numElements; ++i) {
5404 // Extract element i
5405 auto indexCst = LLVM::ConstantOp::create(
5406 rewriter, loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(i));
5407 auto extractedElem = LLVM::ExtractElementOp::create(
5408 rewriter, loc, adaptor.getOperand(), indexCst);
5409
5410 // Call xllvm.intr.aie2p.invsqrt on the scalar
5411 auto rsqrtResult = xllvm::InvsqrtAIE2pIntrOp::create(
5412 rewriter, loc, rewriter.getF32Type(), extractedElem);
5413
5414 // Insert result back into vector
5415 result = LLVM::InsertElementOp::create(rewriter, loc, vecType, result,
5416 rsqrtResult, indexCst);
5417 }
5418
5419 rewriter.replaceOp(rsqrtOp, result);
5420 return success();
5421 }
5422};
5423
// Convert arith.divf for vector<N x f32> to unrolled scalar divisions
// Uses a noinline helper function call as a barrier to prevent LLVM
// re-vectorization. Scalar f32 divisions are handled by downstream passes.
class FdivOpConversion : public mlir::ConvertOpToLLVMPattern<arith::DivFOp> {
public:
  using ConvertOpToLLVMPattern<arith::DivFOp>::ConvertOpToLLVMPattern;

  // `device` selects the body of the generated scalar helper: "aie2p" uses
  // the hardware reciprocal intrinsic, anything else a plain scalar fdiv.
  FdivOpConversion(const LLVMTypeConverter &typeConverter, StringRef device)
      : ConvertOpToLLVMPattern(typeConverter), deviceName(device.str()) {}

  // Target name captured at pattern-registration time ("aie2" or "aie2p").
  std::string deviceName;

  LogicalResult
  matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = divOp.getLoc();
    auto lhsType = adaptor.getLhs().getType();

    // Only handle vector<N x f32> fdiv
    // Scalar f32 fdiv is handled by downstream passes
    auto vecType = dyn_cast<VectorType>(lhsType);
    if (!vecType || !vecType.getElementType().isF32())
      return failure();

    // Both sides must have the identical vector type; anything else (e.g.
    // mixed shapes) is not handled by this pattern.
    auto rhsType = adaptor.getRhs().getType();
    auto rhsVecType = dyn_cast<VectorType>(rhsType);
    if (!rhsVecType || rhsVecType != vecType)
      return failure();

    // For AIE2P, implement a/b as a * inv(b) using the hardware reciprocal
    // intrinsic. For AIE2, use scalar fdiv helper function.
    auto module = divOp->getParentOfType<ModuleOp>();
    auto f32Ty = rewriter.getF32Type();

    // Select device-specific body for the scalar helper function.
    // AIE2P: use inv(b) * a (hardware reciprocal intrinsic is reliable).
    // AIE2: use scalar fdiv directly.
    std::function<void(OpBuilder &, Location, ValueRange)> bodyBuilder;
    if (deviceName == "aie2p") {
      bodyBuilder = [](OpBuilder &builder, Location loc, ValueRange args) {
        // args[0] = numerator, args[1] = denominator.
        auto invResult = xllvm::InvAIE2pIntrOp::create(
            builder, loc, builder.getF32Type(), args[1]);
        auto mulResult =
            arith::MulFOp::create(builder, loc, args[0], invResult);
        LLVM::ReturnOp::create(builder, loc, ValueRange{mulResult});
      };
    } else {
      bodyBuilder = [](OpBuilder &builder, Location loc, ValueRange args) {
        auto divResult = arith::DivFOp::create(builder, loc, args[0], args[1]);
        LLVM::ReturnOp::create(builder, loc, ValueRange{divResult});
      };
    }

    // Get or create the noinline scalar helper (acts as barrier to prevent
    // LLVM from re-vectorizing the scalar ops).
    auto helperFunc =
        getOrCreateScalarHelperFunc(module, rewriter, "fdiv", deviceName,
                                    /*argTypes=*/{f32Ty, f32Ty},
                                    /*resultType=*/f32Ty, bodyBuilder);

    // Unroll vector fdiv into scalar helper function calls.
    int numElements = getVectorLaneSize(vecType);
    Value result = LLVM::PoisonOp::create(rewriter, loc, vecType);

    for (int i = 0; i < numElements; ++i) {
      // Extract lane i from both operands, divide via the helper, and insert
      // the quotient back into the result vector.
      auto indexCst = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(i));
      auto lhsElem = LLVM::ExtractElementOp::create(rewriter, loc,
                                                    adaptor.getLhs(), indexCst);
      auto rhsElem = LLVM::ExtractElementOp::create(rewriter, loc,
                                                    adaptor.getRhs(), indexCst);

      auto divResult = LLVM::CallOp::create(rewriter, loc, helperFunc,
                                            ValueRange{lhsElem, rhsElem})
                           ->getResult(0);

      result = LLVM::InsertElementOp::create(rewriter, loc, vecType, result,
                                             divResult, indexCst);
    }

    rewriter.replaceOp(divOp, result);
    return success();
  }
};
5508
    mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
  // NOTE(review): the doxygen extraction this listing was recovered from
  // dropped this function's declaration line (orig. line 5509) and the middle
  // of the pattern list (orig. lines 5514-5523). The elided pattern names
  // cannot be reconstructed from this chunk alone — restore them from the
  // repository before building.
  // clang-format off
  // Patterns that work for all backends (AIE1, AIE2, AIE2p)
  patterns.add<AddOpConversion,
               ShuffleOpConversion>(converter);
  // clang-format on
}
5527
    mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
    Aie2Fp32Emulation aie2Fp32EmulationOption) {
  // NOTE(review): the doxygen extraction dropped this function's declaration
  // line (orig. line 5528). Judging by the AIE2-only pattern set registered
  // below, this is presumably the AIE2 populate entry point — confirm the
  // exact name against the repository.
  // Patterns specific to AIE2 backend
  patterns.add<AddElemOpAIE2Conversion, SubElemOpAIE2Conversion>(converter);
  // MulElem lowering depends on the selected fp32 emulation strategy.
  patterns.add<MulElemOpConversion>(converter, aie2Fp32EmulationOption);
  patterns.add<UPSOpAIE2Conversion, SRSOpAIE2Conversion>(converter);
  patterns.add<ShiftOpConversion>(converter);
  patterns.add<MaxOpConversion, MinOpConversion>(converter);
  patterns.add<CmpOpConversion, SelOpConversion>(converter);
  patterns.add<ExtOpConversion>(converter);
  patterns.add<ExtractElemOpConversion>(converter);
  patterns.add<ConcatOpConversion>(converter);
  patterns.add<FMAElemOpConversion>(converter);
  patterns.add<FoldAIECastOps>(converter);
  // The fdiv unrolling pattern is parameterized by the device name.
  patterns.add<FdivOpConversion>(converter, "aie2");
}
5545
5546// AIE2p version of ExtractElemOp conversion using LLVM extractelement
5548 : public mlir::ConvertOpToLLVMPattern<aievec::ExtElemOp> {
5549public:
5550 using ConvertOpToLLVMPattern<aievec::ExtElemOp>::ConvertOpToLLVMPattern;
5551
5552 LogicalResult
5553 matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor,
5554 ConversionPatternRewriter &rewriter) const override {
5555 Location loc = op.getLoc();
5556
5557 // AIE2p doesn't have dedicated vextract intrinsics, so use LLVM
5558 // extractelement
5559 Value extracted = LLVM::ExtractElementOp::create(
5560 rewriter, loc, adaptor.getSource(), adaptor.getIndex());
5561
5562 rewriter.replaceOp(op, extracted);
5563 return success();
5564 }
5565};
5566
5567// AIE2p version of ConcatOp conversion using vector.shuffle
5569 : public mlir::ConvertOpToLLVMPattern<aievec::ConcatOp> {
5570public:
5571 using ConvertOpToLLVMPattern<aievec::ConcatOp>::ConvertOpToLLVMPattern;
5572
5573 LogicalResult
5574 matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor,
5575 ConversionPatternRewriter &rewriter) const override {
5576 Location loc = op.getLoc();
5577
5578 SmallVector<Value> sources = adaptor.getSources();
5579
5580 if (sources.empty()) {
5581 op.emitWarning() << "aievec.concat with no sources is not supported.\n";
5582 return failure();
5583 }
5584
5585 // AIE2p doesn't have dedicated concat intrinsics, use vector.shuffle
5586 Value result = sources[0];
5587
5588 // Build shuffle mask that concatenates all sources
5589 auto srcType = cast<VectorType>(sources[0].getType());
5590 int64_t srcLanes = getVectorLaneSize(srcType);
5591
5592 if (sources.size() == 2) {
5593 // Concatenate two vectors using shuffle
5594 SmallVector<int64_t> mask;
5595 for (int64_t i = 0; i < srcLanes * 2; ++i)
5596 mask.push_back(i);
5597
5598 result = vector::ShuffleOp::create(rewriter, loc, sources[0], sources[1],
5599 mask);
5600 } else if (sources.size() == 4) {
5601 // Concatenate four vectors: first concat pairs, then concat results
5602 SmallVector<int64_t> pairMask;
5603 for (int64_t i = 0; i < srcLanes * 2; ++i)
5604 pairMask.push_back(i);
5605
5606 auto pair0 = vector::ShuffleOp::create(rewriter, loc, sources[0],
5607 sources[1], pairMask);
5608 auto pair1 = vector::ShuffleOp::create(rewriter, loc, sources[2],
5609 sources[3], pairMask);
5610
5611 SmallVector<int64_t> finalMask;
5612 for (int64_t i = 0; i < srcLanes * 4; ++i)
5613 finalMask.push_back(i);
5614
5615 result =
5616 vector::ShuffleOp::create(rewriter, loc, pair0, pair1, finalMask);
5617 } else {
5618 op.emitWarning() << "aievec.concat with " << sources.size()
5619 << " operands is not supported for AIE2p.\n";
5620 return failure();
5621 }
5622
5623 rewriter.replaceOp(op, result);
5624 return success();
5625 }
5626};
5627
    mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
  // NOTE(review): the doxygen extraction dropped this function's declaration
  // line (orig. line 5628). Judging by the AIE2p-only pattern set registered
  // below, this is presumably the AIE2p populate entry point — confirm the
  // exact name against the repository.
  // Patterns specific to AIE2p backend
  patterns.add<AddElemOpAIE2pConversion, SubElemOpAIE2pConversion>(converter);
  patterns.add<MulElemOpAIE2pConversion>(converter);
  patterns.add<FMAElemOpAIE2pConversion>(converter);
  patterns.add<UPSOpAIE2pConversion, SRSOpAIE2pConversion>(converter);
  patterns.add<MatMulOpAIE2pConversion>(converter);
  patterns.add<ShiftOpAIE2pConversion>(converter);
  patterns.add<MaxOpAIE2pConversion, MinOpAIE2pConversion>(converter);
  patterns.add<CmpOpAIE2pConversion, SelOpAIE2pConversion>(converter);
  patterns.add<ExtOpAIE2pConversion>(converter);
  patterns.add<ExtractElemOpAIE2pConversion>(converter);
  patterns.add<ConcatOpAIE2pConversion>(converter);
  patterns.add<ExpOpAIE2pConversion>(converter);
  patterns.add<TanhOpAIE2pConversion>(converter);
  patterns.add<InvOpAIE2pConversion>(converter);
  patterns.add<BroadcastScalarOpAIE2pConversion>(converter);
  patterns.add<RsqrtOpAIE2pConversion>(converter);
  // The fdiv unrolling pattern is parameterized by the device name.
  patterns.add<FdivOpConversion>(converter, "aie2p");
  patterns.add<FoldAIECastOpsAIE2p>(converter);
}
5650
    mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
    Aie2Fp32Emulation aie2Fp32EmulationOption, StringRef aieTarget) {
  // NOTE(review): the doxygen extraction dropped this function's declaration
  // line (orig. line 5651) and the two callee-invocation lines (orig. lines
  // 5656 and 5658): the target-dispatching calls under the `if`/`else` —
  // presumably to the AIE2p and AIE2 populate functions above — are missing,
  // leaving this body syntactically incomplete. Restore from the repository.
  if (aieTarget == "aie2p")
  else
                                         aie2Fp32EmulationOption);
}
5661
5662// Configure legalization rules shared by AIE2 and AIE2p
5663static void configureAIEVecToLLVMLegalizations(LLVMConversionTarget &target) {
5664 // Vector f32 divf is illegal (needs unrolling to scalar divf)
5665 // Scalar f32 divf is legal (handled by downstream passes)
5666 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divOp) {
5667 auto resultType = divOp.getType();
5668 if (auto vecType = dyn_cast<VectorType>(resultType)) {
5669 // Vector f32 divf is illegal and needs conversion
5670 return !vecType.getElementType().isF32();
5671 }
5672 // Scalar divf is legal
5673 return true;
5674 });
5675}
5676
5678 : xilinx::impl::ConvertAIEVecToLLVMBase<ConvertAIEVecToLLVMPass> {
5680 ConvertAIEVecToLLVMPass(const xilinx::ConvertAIEVecToLLVMOptions &options) {
5681 aieTarget = options.aieTarget;
5682 aie2Fp32Emulation = options.aie2Fp32Emulation;
5683 }
5684
5685 void runOnOperation() override {
5686 RewritePatternSet patterns(&getContext());
5687 LLVMTypeConverter converter(&getContext());
5688
5689 // Don't convert vector types, we want to handle multi-dimensional
5690 // vector on our own.
5691 converter.addConversion(
5692 [&](VectorType type) -> std::optional<Type> { return type; });
5693
5694 populateAIEVecToLLVMConversionPatterns(converter, patterns,
5695 aie2Fp32Emulation, aieTarget);
5696
5697 LLVMConversionTarget target(getContext());
5698 target.addIllegalDialect<xilinx::aievec::AIEVecDialect,
5699 xilinx::aievec::aie1::AIEVecAIE1Dialect>();
5700 target.addLegalDialect<arith::ArithDialect, vector::VectorDialect,
5701 xilinx::xllvm::XLLVMDialect, ub::UBDialect>();
5702
5703 // Configure legalizations for AIE2/AIE2p
5704 configureAIEVecToLLVMLegalizations(target);
5705
5706 if (failed(applyPartialConversion(getOperation(), target,
5707 std::move(patterns))))
5708 signalPassFailure();
5709 }
5710};
5711
5712std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
5714 return std::make_unique<ConvertAIEVecToLLVMPass>();
5715}
5716
5717std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
5719 const xilinx::ConvertAIEVecToLLVMOptions &options) {
5720 return std::make_unique<ConvertAIEVecToLLVMPass>(options);
5721}
5722
5723} // namespace xilinx::aievec
static DecodedAddElemOp decodeAddElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::AddElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::AddElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedAddElemOp decodeAddElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::aie1::AddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::CmpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FdivOpConversion(const LLVMTypeConverter &typeConverter, StringRef device)
LogicalResult matchAndRewrite(aievec::InvOp invOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedMulElemOp decodeMulElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult convertToEmulatedFP32MulElem(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult convertToEmulatedI32MulElem(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
MulElemOpConversion(const LLVMTypeConverter &typeConverter, Aie2Fp32Emulation aie2Fp32EmulationOption)
LogicalResult matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedMulElemOp decodeMulElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::aie1::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::PackOp op)
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::SelOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::aie1::SelectOp op)
LogicalResult matchAndRewrite(aievec::aie1::SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::SubElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedSubElemOp decodeSubElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::SubElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedSubElemOp decodeSubElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::aie1::SubOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::UPDOp op, int loadSize)
LogicalResult matchAndRewrite(aievec::UPDOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UnpackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
PathEndPoint src
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
int32_t getVectorSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:66
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55
uint32_t encodeSquare(uint32_t square)
void populateAIEVecToLLVMCommonConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns)
void encodeConf(uint32_t conf[2], const BufferParams &x, const BufferParams &z, bool sub)
void populateAIEVecToLLVMAIE2ConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns)
void populateAIEVecToLLVMAIE2pConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns)
std::string getMulOrFMAIntrinsicName(Operation *op)
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createConvertAIEVecToLLVMPass()
std::string getVectorTypeString(VectorType type, bool abbrev=false, bool acc=false)
void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns, Aie2Fp32Emulation aie2Fp32EmulationOption, llvm::StringRef aieTarget)
ConvertAIEVecToLLVMPass(const xilinx::ConvertAIEVecToLLVMOptions &options)