12#include "../PassDetail.h"
20#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21#include "mlir/Conversion/LLVMCommon/Pattern.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/Dialect/Math/IR/Math.h"
24#include "mlir/Dialect/UB/IR/UBOps.h"
25#include "mlir/IR/TypeUtilities.h"
30#define GEN_PASS_DEF_CONVERTAIEVECTOLLVM
31#include "aie/Conversion/Passes.h.inc"
38inline static Value bitcastValueToType(OpBuilder &builder, Location loc,
39 Value val, Type dstTy) {
40 return LLVM::BitcastOp::create(builder, loc, dstTy, val).getResult();
// Helper that emits an `xllvm::VectorSetI512I128IntrOp` for `val`.
// The intrinsic is created with a result type of vector<16 x i32>, and the
// input is first bitcast to vector<4 x i32> via `bitcastValueToType`.
// NOTE(review): the remainder of the signature and the end of the return
// expression are not visible in this view; what the intrinsic places in the
// lanes beyond the input is not established here — confirm against the
// intrinsic definition.
46inline static Value widen128bVectorValueTo512b(OpBuilder &builder, Location loc,
48 return xllvm::VectorSetI512I128IntrOp::create(
49 builder, loc, VectorType::get({16}, builder.getI32Type()),
// Operand: the source value viewed as four 32-bit integers.
50 bitcastValueToType(builder, loc, val,
51 VectorType::get({4}, builder.getI32Type())))
// Helper that emits an `xllvm::VectorSetI512I256IntrOp` for `val`.
// The intrinsic is created with a result type of vector<16 x i32>, and the
// input is first bitcast to vector<8 x i32> via `bitcastValueToType`.
// NOTE(review): the rest of the signature, the name bound to the constant
// below, and the tail of the intrinsic call are not visible in this view.
59inline static Value widen256bVectorValueTo512b(OpBuilder &builder, Location loc,
// i32 constant 0; presumably passed as the trailing operand of the intrinsic
// (the line that uses it is not shown) — confirm.
62 LLVM::ConstantOp::create(builder, loc, builder.getI32Type(), (int32_t)0);
63 return xllvm::VectorSetI512I256IntrOp::create(
64 builder, loc, VectorType::get({16}, builder.getI32Type()),
// Operand: the source value viewed as eight 32-bit integers.
65 bitcastValueToType(builder, loc, val,
66 VectorType::get({8}, builder.getI32Type())),
// Forces `val` into the requested type, emitting whatever shape casts,
// widening intrinsics and bitcasts are needed along the way.
//
// Visible behaviour:
//  * Vector source: the destination must also be a vector (assert). The value
//    is shape-cast to `flatSrcVecTy` if it is not already of that type,
//    widened to 512 bits when the total bit lengths differ, bitcast to
//    `flatDstVecTy`, and finally shape-cast to the requested vector type.
//  * Non-vector source: the destination must not be a vector (assert) and the
//    value is simply bitcast.
//
// NOTE(review): the rest of the signature, the definitions of `flatSrcVecTy`
// / `flatDstVecTy`, the `else` lines and the closing braces are not visible
// in this view. The unsupported-cast checks are plain `assert`s, so they only
// fire in builds where assertions are enabled.
74static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
76 auto valTy = val.getType();
// dyn_cast yields a null type when the corresponding type is not a vector.
79 auto srcVecTy = dyn_cast<VectorType>(valTy);
80 auto dstVecTy = dyn_cast<VectorType>(type);
83 assert(dstVecTy &&
"vector values cannot be forced into a non-vector type");
// Bring the source to its flat form before any bit-level reinterpretation.
87 if (srcVecTy != flatSrcVecTy)
88 val = vector::ShapeCastOp::create(builder, loc, flatSrcVecTy, val);
// Total size in bits = element bit width * element count; dimension 0 is
// used as the element count of the flat types.
93 int64_t dstVecLength =
94 flatDstVecTy.getElementTypeBitWidth() * flatDstVecTy.getShape()[0];
95 int64_t srcVecLength =
96 flatSrcVecTy.getElementTypeBitWidth() * flatSrcVecTy.getShape()[0];
// Sizes differ: only 128b -> 512b and 256b -> 512b widening is accepted.
97 if (srcVecLength != dstVecLength) {
98 assert(srcVecLength < dstVecLength &&
99 "only widening forced casts are supported");
100 assert(dstVecLength == 512 &&
101 (srcVecLength == 128 || srcVecLength == 256) &&
102 "only 128b to 512b and 256b to 512b forced casts are supported");
// Pick the widening helper that matches the source size.
103 if (srcVecLength == 128)
104 val = widen128bVectorValueTo512b(builder, loc, val);
106 val = widen256bVectorValueTo512b(builder, loc, val);
// Reinterpret the bits as the flat destination vector type.
110 val = bitcastValueToType(builder, loc, val, flatDstVecTy);
// Restore the requested destination shape when it differs from the flat one.
113 if (flatDstVecTy != dstVecTy)
114 val = vector::ShapeCastOp::create(builder, loc, dstVecTy, val);
// Scalar source: a scalar cannot be forced into a vector type.
120 assert(!dstVecTy &&
"cannot force cast scalar to vector type");
121 return bitcastValueToType(builder, loc, val, type);
// Casts a list of operands so that each one matches the corresponding type in
// `signature`, returning the converted values as a SmallVector.
// Operands and signature types are paired positionally with
// `llvm::zip_equal`, which expects both ranges to have the same length
// (checked only in assertion-enabled builds, per LLVM's STLExtras docs —
// confirm), and every pair is run through `forceCastValueToType`.
// NOTE(review): two parameter lines and the tail of the lambda are not
// visible in this view.
128static SmallVector<Value> forceCastOperandsToSignature(OpBuilder &builder,
131 TypeRange signature) {
132 return llvm::to_vector(llvm::map_range(
133 llvm::zip_equal(operands, signature), [&](
auto &&vt) -> Value {
// std::get<0>(vt) is the operand half of the (operand, type) pair.
134 return forceCastValueToType(builder, loc, std::get<0>(vt),
// Looks up, or creates, an LLVM dialect function in `module` that wraps a
// scalar operation.
//
// Parameters:
//   module       Module that is searched for the helper and receives it.
//   rewriter     Builder used for all new IR; its insertion point is saved
//                and restored through InsertionGuards.
//   opName       Operation name, used as the suffix of the symbol name.
//   device       Device name, used in the prefix of the symbol name.
//   argTypes     Argument types of the helper function.
//   resultType   Result type of the helper function.
//   bodyBuilder  Callback that fills in the function body; it receives the
//                builder, an unknown location and the entry block arguments.
//
// NOTE(review): the lines handling an already-existing helper, the final
// return and the closing brace are not visible in this view; presumably an
// existing function is returned as-is — confirm.
157static LLVM::LLVMFuncOp getOrCreateScalarHelperFunc(
158 ModuleOp module, OpBuilder &rewriter, StringRef opName, StringRef device,
159 TypeRange argTypes, Type resultType,
160 std::function<
void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
// Symbol name has the form "__<device>_scalar_<opName>".
163 std::string funcName =
"__" + device.str() +
"_scalar_" + opName.str();
// Check whether a function with that symbol name already exists.
166 auto helperFunc =
module.lookupSymbol<LLVM::LLVMFuncOp>(funcName);
// New functions are inserted at the very start of the module body.
171 OpBuilder::InsertionGuard guard(rewriter);
172 rewriter.setInsertionPointToStart(module.getBody());
// Materialize the TypeRange as a vector for LLVMFunctionType::get.
175 SmallVector<Type> argTypesVec(argTypes.begin(), argTypes.end());
177 helperFunc = LLVM::LLVMFuncOp::create(
178 rewriter, rewriter.getUnknownLoc(), funcName,
179 LLVM::LLVMFunctionType::get(resultType, argTypesVec));
// Attach "noinline" through the function's `passthrough` attribute array.
182 helperFunc->setAttr(
"passthrough", rewriter.getArrayAttr(
183 {rewriter.getStringAttr(
"noinline")}));
// Create the entry block and build the body inside it.
186 auto *entryBlock = helperFunc.addEntryBlock(rewriter);
187 OpBuilder::InsertionGuard bodyGuard(rewriter);
188 rewriter.setInsertionPointToStart(entryBlock);
// Collect one entry-block argument per declared argument type.
191 SmallVector<Value>
args;
192 for (
unsigned i = 0; i < argTypes.size(); ++i)
193 args.push_back(entryBlock->getArgument(i));
// Let the caller emit the actual body using those arguments.
196 bodyBuilder(rewriter, rewriter.getUnknownLoc(),
args);
/// Packs the individual mul/mac configuration fields into the single integer
/// control word that is passed as the `conf` operand of the mul/mac
/// intrinsics.
///
/// Each argument is converted to `unsigned`, shifted to a fixed bit offset and
/// OR-ed into the result:
///
///   bit  0 : zero_acc      bit  9 : sgn_x
///   bit  1 : amode         bit 10 : shift16
///   bit  3 : bmode         bit 11 : sub_mul
///   bit  5 : variant       bit 12 : sub_acc1
///   bit  8 : sgn_y         bit 13 : sub_acc2
///                          bit 16 : sub_mask
///
/// The arguments are NOT masked to a field width: a value that is too large
/// for the gap to the next field overlaps the neighbouring field, so callers
/// must pass in-range values.
///
/// \returns The packed control word, converted back to `int`.
static inline constexpr int aiev2_vmac_compute_control(
    int sgn_x, int sgn_y, int amode, int bmode, int variant, int zero_acc,
    int shift16, int sub_mul, int sub_acc1, int sub_acc2, int sub_mask) {
  // Bit offset of every field inside the control word.
  constexpr unsigned kZeroAccBit = 0;
  constexpr unsigned kAModeBit = 1;
  constexpr unsigned kBModeBit = 3;
  constexpr unsigned kVariantBit = 5;
  constexpr unsigned kSgnYBit = 8;
  constexpr unsigned kSgnXBit = 9;
  constexpr unsigned kShift16Bit = 10;
  constexpr unsigned kSubMulBit = 11;
  constexpr unsigned kSubAcc1Bit = 12;
  constexpr unsigned kSubAcc2Bit = 13;
  constexpr unsigned kSubMaskBit = 16;

  // All shifting is done on unsigned values so that it is well defined even
  // when a caller hands in a negative int.
  const unsigned control =
      (static_cast<unsigned>(zero_acc) << kZeroAccBit) |
      (static_cast<unsigned>(amode) << kAModeBit) |
      (static_cast<unsigned>(bmode) << kBModeBit) |
      (static_cast<unsigned>(variant) << kVariantBit) |
      (static_cast<unsigned>(sgn_y) << kSgnYBit) |
      (static_cast<unsigned>(sgn_x) << kSgnXBit) |
      (static_cast<unsigned>(shift16) << kShift16Bit) |
      (static_cast<unsigned>(sub_mul) << kSubMulBit) |
      (static_cast<unsigned>(sub_acc1) << kSubAcc1Bit) |
      (static_cast<unsigned>(sub_acc2) << kSubAcc2Bit) |
      (static_cast<unsigned>(sub_mask) << kSubMaskBit);

  // Explicit narrowing back to the int return type callers expect.
  return static_cast<int>(control);
}
238 std::stringstream ss;
241 if (
auto intType = dyn_cast<IntegerType>(type.getElementType())) {
242 ss << (acc ?
"acc" : abbrev ?
"i" :
"int") << intType.getWidth();
243 }
else if (dyn_cast<FloatType>(type.getElementType())) {
244 ss << (abbrev ?
"f" :
"float");
250 std::string baseName;
252 if (
auto mulOp = dyn_cast<aievec::aie1::MulOp>(op)) {
254 lhs = mulOp.getLhs();
255 result = mulOp.getResult();
256 }
else if (
auto fmaOp = dyn_cast<aievec::aie1::FMAOp>(op)) {
258 lhs = fmaOp.getLhs();
259 result = fmaOp.getResult();
261 VectorType resultType = cast<VectorType>(result.getType());
263 std::stringstream ss;
265 if (dyn_cast<IntegerType>(resultType.getElementType())) {
267 ss << resultSize <<
"."
269 }
else if (dyn_cast<FloatType>(resultType.getElementType())) {
270 ss <<
"vfp" << baseName;
279 out |= ((square >> 0) & 0x3) << 0;
280 out |= ((square >> 4) & 0x3) << 2;
281 out |= ((square >> 8) & 0x3) << 4;
282 out |= ((square >> 12) & 0x3) << 6;
290 conf[0] |= ((x.
step & 0x3F) << 0) | ((z.
step & 0x3F) << 8);
292 conf[1] |= sub << 17;
296 :
public mlir::ConvertOpToLLVMPattern<aievec::aie1::AddOp> {
298 using ConvertOpToLLVMPattern<aievec::aie1::AddOp>::ConvertOpToLLVMPattern;
302 ConversionPatternRewriter &rewriter)
const override {
303 op.emitWarning() <<
"aie.add conversion is not implemented\n";
310 :
public mlir::ConvertOpToLLVMPattern<aievec::AddElemOp> {
312 using ConvertOpToLLVMPattern<aievec::AddElemOp>::ConvertOpToLLVMPattern;
321 auto lhs = op.getLhs();
322 auto lhsVecTy = cast<VectorType>(lhs.getType());
323 auto lhsScaTy = lhsVecTy.getElementType();
324 unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
327 if (llvm::isa<IntegerType>(lhsScaTy)) {
331 if (lhsBitWidth == 32) {
341 ConversionPatternRewriter &rewriter)
const override {
342 Location loc = op.getLoc();
346 op.emitWarning() <<
"aievec.add_elem conversion is not supported.\n";
351 if (decodedAddElemOp.kind ==
353 auto confCst = LLVM::ConstantOp::create(
354 rewriter, loc, rewriter.getI32Type(),
355 rewriter.getI32IntegerAttr(decodedAddElemOp.conf));
356 SmallVector<Value> operands(
357 {adaptor.getLhs(), adaptor.getRhs(), confCst});
359 auto addElemOp = xllvm::AddAccFloatAIE2IntrOp::create(
360 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
361 forceCastOperandsToSignature(
362 rewriter, loc, operands,
363 {VectorType::get({8}, rewriter.getI64Type()),
364 VectorType::get({8}, rewriter.getI64Type()),
365 rewriter.getI32Type()}));
368 auto resultVal = forceCastValueToType(rewriter, loc, addElemOp,
369 op.getResult().getType());
370 rewriter.replaceOp(op, resultVal);
374 op.emitWarning() <<
"aievec.add_elem conversion is not supported.\n";
381 :
public mlir::ConvertOpToLLVMPattern<aievec::SubElemOp> {
383 using ConvertOpToLLVMPattern<aievec::SubElemOp>::ConvertOpToLLVMPattern;
392 auto lhs = op.getLhs();
393 auto lhsVecTy = cast<VectorType>(lhs.getType());
394 auto lhsScaTy = lhsVecTy.getElementType();
395 unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
398 if (llvm::isa<IntegerType>(lhsScaTy)) {
402 if (lhsBitWidth == 32) {
412 ConversionPatternRewriter &rewriter)
const override {
413 Location loc = op.getLoc();
417 op.emitWarning() <<
"aievec.sub_elem conversion is not supported.\n";
422 if (decodedSubElemOp.kind ==
424 auto confCst = LLVM::ConstantOp::create(
425 rewriter, loc, rewriter.getI32Type(),
426 rewriter.getI32IntegerAttr(decodedSubElemOp.conf));
427 SmallVector<Value> operands(
428 {adaptor.getLhs(), adaptor.getRhs(), confCst});
430 auto subElemOp = xllvm::SubAccFloatAIE2IntrOp::create(
431 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
432 forceCastOperandsToSignature(
433 rewriter, loc, operands,
434 {VectorType::get({8}, rewriter.getI64Type()),
435 VectorType::get({8}, rewriter.getI64Type()),
436 rewriter.getI32Type()}));
439 auto resultVal = forceCastValueToType(rewriter, loc, subElemOp,
440 op.getResult().getType());
441 rewriter.replaceOp(op, resultVal);
445 op.emitWarning() <<
"aievec.sub_elem conversion is not supported.\n";
452 :
public mlir::ConvertOpToLLVMPattern<aievec::AddElemOp> {
454 using ConvertOpToLLVMPattern<aievec::AddElemOp>::ConvertOpToLLVMPattern;
467 auto lhs = op.getLhs();
468 auto lhsVecTy = cast<VectorType>(lhs.getType());
469 auto lhsScaTy = lhsVecTy.getElementType();
470 unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
474 if (llvm::isa<IntegerType>(lhsScaTy)) {
478 if (lhsBitWidth == 32) {
480 if (laneSize == 16) {
482 }
else if (laneSize == 32) {
492 ConversionPatternRewriter &rewriter)
const override {
493 Location loc = op.getLoc();
497 op.emitWarning() <<
"aievec.add_elem conversion is not supported.\n";
503 if (decodedAddElemOp.kind ==
506 auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type());
508 LLVM::BitcastOp::create(rewriter, loc, v8i64Ty, adaptor.getLhs());
510 LLVM::BitcastOp::create(rewriter, loc, v8i64Ty, adaptor.getRhs());
513 auto v32i64Ty = VectorType::get({32}, rewriter.getI64Type());
514 SmallVector<int64_t> expandMask = {0, 1, 2, 3, 4, 5, 6, 7};
515 for (
int i = 8; i < 32; ++i)
516 expandMask.push_back(-1);
519 vector::ShuffleOp::create(rewriter, loc, lhsI64, lhsI64, expandMask);
521 vector::ShuffleOp::create(rewriter, loc, rhsI64, rhsI64, expandMask);
524 auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
526 LLVM::BitcastOp::create(rewriter, loc, v64f32Ty, lhsExpanded);
528 LLVM::BitcastOp::create(rewriter, loc, v64f32Ty, rhsExpanded);
531 auto confCst = LLVM::ConstantOp::create(
532 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(60));
535 auto addResult = xllvm::AddACC2048AccFloatAIE2pIntrOp::create(
536 rewriter, loc, v64f32Ty, lhsF32, rhsF32, confCst);
540 LLVM::BitcastOp::create(rewriter, loc, v32i64Ty, addResult);
543 SmallVector<int64_t> extractMask = {0, 1, 2, 3, 4, 5, 6, 7};
544 auto resultExtracted = vector::ShuffleOp::create(rewriter, loc, resultI64,
545 resultI64, extractMask);
548 auto v16f32Ty = VectorType::get({16}, rewriter.getF32Type());
550 LLVM::BitcastOp::create(rewriter, loc, v16f32Ty, resultExtracted);
552 rewriter.replaceOp(op, finalResult);
558 if (decodedAddElemOp.kind ==
561 SmallVector<int64_t> padMask;
562 for (
int i = 0; i < 32; ++i)
563 padMask.push_back(i);
564 for (
int i = 32; i < 64; ++i)
565 padMask.push_back(-1);
567 auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
568 auto lhsPadded = vector::ShuffleOp::create(
569 rewriter, loc, adaptor.getLhs(), adaptor.getLhs(), padMask);
570 auto rhsPadded = vector::ShuffleOp::create(
571 rewriter, loc, adaptor.getRhs(), adaptor.getRhs(), padMask);
574 auto confCst = LLVM::ConstantOp::create(
575 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(60));
576 auto addResult = xllvm::AddACC2048AccFloatAIE2pIntrOp::create(
577 rewriter, loc, v64f32Ty, lhsPadded, rhsPadded, confCst);
580 SmallVector<int64_t> extractMask;
581 for (
int i = 0; i < 32; ++i)
582 extractMask.push_back(i);
583 auto finalResult = vector::ShuffleOp::create(rewriter, loc, addResult,
584 addResult, extractMask);
586 rewriter.replaceOp(op, finalResult);
590 op.emitWarning() <<
"aievec.add_elem conversion is not supported.\n";
597 :
public mlir::ConvertOpToLLVMPattern<aievec::SubElemOp> {
599 using ConvertOpToLLVMPattern<aievec::SubElemOp>::ConvertOpToLLVMPattern;
612 auto lhs = op.getLhs();
613 auto lhsVecTy = cast<VectorType>(lhs.getType());
614 auto lhsScaTy = lhsVecTy.getElementType();
615 unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
619 if (llvm::isa<IntegerType>(lhsScaTy)) {
623 if (lhsBitWidth == 32) {
625 if (laneSize == 16) {
627 }
else if (laneSize == 32) {
637 ConversionPatternRewriter &rewriter)
const override {
638 Location loc = op.getLoc();
642 op.emitWarning() <<
"aievec.sub_elem conversion is not supported.\n";
648 if (decodedSubElemOp.kind ==
651 auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type());
653 LLVM::BitcastOp::create(rewriter, loc, v8i64Ty, adaptor.getLhs());
655 LLVM::BitcastOp::create(rewriter, loc, v8i64Ty, adaptor.getRhs());
658 auto v32i64Ty = VectorType::get({32}, rewriter.getI64Type());
659 SmallVector<int64_t> expandMask = {0, 1, 2, 3, 4, 5, 6, 7};
660 for (
int i = 8; i < 32; ++i)
661 expandMask.push_back(-1);
664 vector::ShuffleOp::create(rewriter, loc, lhsI64, lhsI64, expandMask);
666 vector::ShuffleOp::create(rewriter, loc, rhsI64, rhsI64, expandMask);
669 auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
671 LLVM::BitcastOp::create(rewriter, loc, v64f32Ty, lhsExpanded);
673 LLVM::BitcastOp::create(rewriter, loc, v64f32Ty, rhsExpanded);
676 auto confCst = LLVM::ConstantOp::create(
677 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(60));
680 auto subResult = xllvm::SubACC2048AccFloatAIE2pIntrOp::create(
681 rewriter, loc, v64f32Ty, lhsF32, rhsF32, confCst);
685 LLVM::BitcastOp::create(rewriter, loc, v32i64Ty, subResult);
688 SmallVector<int64_t> extractMask = {0, 1, 2, 3, 4, 5, 6, 7};
689 auto resultExtracted = vector::ShuffleOp::create(rewriter, loc, resultI64,
690 resultI64, extractMask);
693 auto v16f32Ty = VectorType::get({16}, rewriter.getF32Type());
695 LLVM::BitcastOp::create(rewriter, loc, v16f32Ty, resultExtracted);
697 rewriter.replaceOp(op, finalResult);
703 if (decodedSubElemOp.kind ==
706 SmallVector<int64_t> padMask;
707 for (
int i = 0; i < 32; ++i)
708 padMask.push_back(i);
709 for (
int i = 32; i < 64; ++i)
710 padMask.push_back(-1);
712 auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
713 auto lhsPadded = vector::ShuffleOp::create(
714 rewriter, loc, adaptor.getLhs(), adaptor.getLhs(), padMask);
715 auto rhsPadded = vector::ShuffleOp::create(
716 rewriter, loc, adaptor.getRhs(), adaptor.getRhs(), padMask);
719 auto confCst = LLVM::ConstantOp::create(
720 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(60));
721 auto subResult = xllvm::SubACC2048AccFloatAIE2pIntrOp::create(
722 rewriter, loc, v64f32Ty, lhsPadded, rhsPadded, confCst);
725 SmallVector<int64_t> extractMask;
726 for (
int i = 0; i < 32; ++i)
727 extractMask.push_back(i);
728 auto finalResult = vector::ShuffleOp::create(rewriter, loc, subResult,
729 subResult, extractMask);
731 rewriter.replaceOp(op, finalResult);
735 op.emitWarning() <<
"aievec.sub_elem conversion is not supported.\n";
741 :
public mlir::ConvertOpToLLVMPattern<aievec::aie1::SubOp> {
743 using ConvertOpToLLVMPattern<aievec::aie1::SubOp>::ConvertOpToLLVMPattern;
747 ConversionPatternRewriter &rewriter)
const override {
748 op.emitWarning() <<
"aie.sub conversion is not implemented\n";
754 :
public mlir::ConvertOpToLLVMPattern<aievec::aie1::FMAOp> {
756 using ConvertOpToLLVMPattern<aievec::aie1::FMAOp>::ConvertOpToLLVMPattern;
760 ConversionPatternRewriter &rewriter)
const override {
761 auto module = op->getParentOfType<ModuleOp>();
762 MLIRContext *context = rewriter.getContext();
764 auto startType = IntegerType::get(context, 32);
765 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
766 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
770 auto func =
module.lookupSymbol<LLVM::LLVMFuncOp>(
771 StringAttr::get(context, intrinsicName));
774 OpBuilder::InsertionGuard guard(rewriter);
775 rewriter.setInsertionPointToStart(module.getBody());
776 func = LLVM::LLVMFuncOp::create(
777 rewriter, rewriter.getUnknownLoc(), intrinsicName,
778 LLVM::LLVMFunctionType::get(
779 op.getResult().getType(),
780 {op.getLhs().getType(), op.getRhs().getType(),
781 op.getAcc().getType(), startType,
792 op.getXstart().getAsInteger(0, x.
start);
793 op.getXoffsets().getAsInteger(0, x.
offsets);
794 op.getXoffsetsHi().getAsInteger(0, x.
offsets_hi);
795 op.getXstep().getAsInteger(0, x.
step);
796 op.getXsquare().getAsInteger(0, x.
square);
797 op.getZstart().getAsInteger(0, z.
start);
798 op.getZoffsets().getAsInteger(0, z.
offsets);
799 op.getZoffsetsHi().getAsInteger(0, z.
offsets_hi);
800 op.getZstep().getAsInteger(0, z.
step);
801 op.getZsquare().getAsInteger(0, z.
square);
804 uint32_t conf[2] = {0, 0};
808 auto xstartVal = LLVM::ConstantOp::create(
809 rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(x.
start));
810 auto ystartVal = LLVM::ConstantOp::create(rewriter, op->getLoc(), startType,
811 rewriter.getI32IntegerAttr(0));
812 auto zstartVal = LLVM::ConstantOp::create(
813 rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(z.
start));
814 auto xoffsetsVal = LLVM::ConstantOp::create(
815 rewriter, op->getLoc(), offsetsType,
816 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
817 auto zoffsetsVal = LLVM::ConstantOp::create(
818 rewriter, op->getLoc(), offsetsType,
819 rewriter.getI32VectorAttr({(int32_t)z.offsets, (int32_t)z.offsets_hi}));
820 auto confVal = LLVM::ConstantOp::create(
821 rewriter, op->getLoc(), confType,
822 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
823 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
825 ValueRange{op.getLhs(), op.getRhs(), op.getAcc(), xstartVal, ystartVal,
826 zstartVal, xoffsetsVal, zoffsetsVal, confVal});
832 :
public mlir::ConvertOpToLLVMPattern<aievec::aie1::MulOp> {
834 using ConvertOpToLLVMPattern<aievec::aie1::MulOp>::ConvertOpToLLVMPattern;
838 ConversionPatternRewriter &rewriter)
const override {
839 auto module = op->getParentOfType<ModuleOp>();
840 MLIRContext *context = rewriter.getContext();
842 auto startType = IntegerType::get(context, 32);
843 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
844 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
848 auto func =
module.lookupSymbol<LLVM::LLVMFuncOp>(
849 StringAttr::get(context, intrinsicName));
852 OpBuilder::InsertionGuard guard(rewriter);
853 rewriter.setInsertionPointToStart(module.getBody());
854 func = LLVM::LLVMFuncOp::create(
855 rewriter, rewriter.getUnknownLoc(), intrinsicName,
856 LLVM::LLVMFunctionType::get(op.getResult().getType(),
857 {op.getLhs().getType(),
858 op.getRhs().getType(),
870 op.getXstart().getAsInteger(0, x.
start);
871 op.getXoffsets().getAsInteger(0, x.
offsets);
872 op.getXoffsetsHi().getAsInteger(0, x.
offsets_hi);
873 op.getXstep().getAsInteger(0, x.
step);
874 op.getXsquare().getAsInteger(0, x.
square);
875 op.getZstart().getAsInteger(0, z.
start);
876 op.getZoffsets().getAsInteger(0, z.
offsets);
877 op.getZoffsetsHi().getAsInteger(0, z.
offsets_hi);
878 op.getZstep().getAsInteger(0, z.
step);
879 op.getZsquare().getAsInteger(0, z.
square);
882 uint32_t conf[2] = {0, 0};
886 auto xstartVal = LLVM::ConstantOp::create(
887 rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(x.
start));
888 auto ystartVal = LLVM::ConstantOp::create(rewriter, op->getLoc(), startType,
889 rewriter.getI32IntegerAttr(0));
890 auto zstartVal = LLVM::ConstantOp::create(
891 rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(z.
start));
892 auto xoffsetsVal = LLVM::ConstantOp::create(
893 rewriter, op->getLoc(), offsetsType,
894 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
895 auto zoffsetsVal = LLVM::ConstantOp::create(
896 rewriter, op->getLoc(), offsetsType,
897 rewriter.getI32VectorAttr({(int32_t)z.offsets, (int32_t)z.offsets_hi}));
898 auto confVal = LLVM::ConstantOp::create(
899 rewriter, op->getLoc(), confType,
900 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
901 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
903 ValueRange{op.getLhs(), op.getRhs(), xstartVal, ystartVal, zstartVal,
904 xoffsetsVal, zoffsetsVal, confVal});
910 :
public mlir::ConvertOpToLLVMPattern<aievec::MulElemOp> {
912 using ConvertOpToLLVMPattern<aievec::MulElemOp>::ConvertOpToLLVMPattern;
916 : ConvertOpToLLVMPattern(typeConverter),
938 auto lhs = op.getLhs();
939 auto lhsVecTy = cast<VectorType>(lhs.getType());
940 auto lhsScaTy = lhsVecTy.getElementType();
941 unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
944 if (llvm::isa<IntegerType>(lhsScaTy)) {
945 if (lhsBitWidth == 8) {
947 aiev2_vmac_compute_control(
952 }
else if (lhsBitWidth == 16) {
954 aiev2_vmac_compute_control(
959 }
else if (lhsBitWidth == 32) {
965 if (lhsBitWidth == 16) {
967 aiev2_vmac_compute_control(
972 }
else if (lhsBitWidth == 32) {
1013 ConversionPatternRewriter &rewriter)
const {
1015 Location loc = op.getLoc();
1016 auto zeroCst = LLVM::ConstantOp::create(
1017 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1018 auto a0 = adaptor.getLhs();
1019 auto a1 = xllvm::VectorBroadcast32I512IntrOp::create(
1020 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), zeroCst);
1021 auto b0 = adaptor.getRhs();
1022 auto b1 = xllvm::UndefV16I32IntrOp::create(
1023 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()));
1026 auto a_lo = xllvm::VectorShuffleIntrOp::create(
1027 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
1028 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1029 rewriter.getI32IntegerAttr(2)));
1030 auto a_hi = xllvm::VectorShuffleIntrOp::create(
1031 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
1032 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1033 rewriter.getI32IntegerAttr(3)));
1034 auto b_lo = xllvm::VectorShuffleIntrOp::create(
1035 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
1036 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1037 rewriter.getI32IntegerAttr(2)));
1038 auto b_hi = xllvm::VectorShuffleIntrOp::create(
1039 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
1040 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1041 rewriter.getI32IntegerAttr(3)));
1043 auto mulConfCst = LLVM::ConstantOp::create(
1044 rewriter, loc, rewriter.getI32Type(),
1045 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
1049 auto mulConfOp = xllvm::MulConfAcc64IntrOp::create(
1050 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1051 forceCastOperandsToSignature(
1053 {a_hi, b_hi, mulConfCst},
1055 {VectorType::get({64}, rewriter.getI8Type()),
1056 VectorType::get({16}, rewriter.getI32Type()),
1057 rewriter.getI32Type()}));
1059 auto createMacConfOp = [&](SmallVector<Value> operands,
1060 int macConf) -> Value {
1062 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1063 rewriter.getI32IntegerAttr(macConf)));
1064 return xllvm::MacConfAcc64IntrOp::create(
1065 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1066 forceCastOperandsToSignature(
1070 {VectorType::get({64}, rewriter.getI8Type()),
1071 VectorType::get({16}, rewriter.getI32Type()),
1072 VectorType::get({16}, rewriter.getI64Type()),
1073 rewriter.getI32Type()}))
1076 auto acc64Val = mulConfOp.getResult();
1077 acc64Val = createMacConfOp(
1078 SmallVector<Value>{a_hi, b_lo, acc64Val},
1079 aiev2_vmac_compute_control(
1083 acc64Val = createMacConfOp(
1084 SmallVector<Value>{a_lo, b_hi, acc64Val},
1085 aiev2_vmac_compute_control(
1089 acc64Val = createMacConfOp(
1090 SmallVector<Value>{a_lo, b_lo, acc64Val},
1091 aiev2_vmac_compute_control(
1098 forceCastValueToType(rewriter, loc, acc64Val, op.getResult().getType());
1099 rewriter.replaceOp(op, resultVal);
1141 ConversionPatternRewriter &rewriter)
const {
1142 Location loc = op.getLoc();
1144 LLVM::ConstantOp::create(rewriter, loc, rewriter.getBF16Type(),
1145 rewriter.getZeroAttr(rewriter.getBF16Type()));
1146 auto aZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
1147 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
1148 auto bZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
1149 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
1150 auto cZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
1151 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
1152 auto dZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
1153 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
1154 auto eZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
1155 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
1156 auto fZeros = xllvm::VectorBroadcast16BF512IntrOp::create(
1157 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), zeroCst);
1159 LLVM::ConstantOp::create(rewriter, loc, rewriter.getBF16Type(),
1160 rewriter.getOneAttr(rewriter.getBF16Type()));
1161 auto dummy0 = xllvm::VectorBroadcast16BF512IntrOp::create(
1162 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), oneCst);
1163 auto zeroCstI32 = LLVM::ConstantOp::create(
1164 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1165 auto mscMacMulConfCst = LLVM::ConstantOp::create(
1166 rewriter, loc, rewriter.getI32Type(),
1167 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
1172 auto extractV16FP32ToThreeV16BF16 =
1173 [&](Value inputV16FP32, Value aZeros, Value bZeros,
1174 Value cZeros) -> std::tuple<Value, Value, Value> {
1176 auto inputBitCasted =
1177 forceCastValueToType(rewriter, loc, inputV16FP32,
1178 VectorType::get({8}, rewriter.getI64Type()));
1179 auto v1ToBF16 = xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
1180 rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
1182 auto a = xllvm::UpdBF512BF256IntrOp::create(
1183 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), aZeros,
1184 v1ToBF16, zeroCstI32);
1187 auto acc0 = xllvm::MscConfBF16IntrOp::create(
1188 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), a, dummy0,
1189 inputBitCasted, mscMacMulConfCst);
1192 auto acc0ToBF16 = xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
1193 rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()), acc0);
1194 auto b = xllvm::UpdBF512BF256IntrOp::create(
1195 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), bZeros,
1196 acc0ToBF16, zeroCstI32);
1199 auto acc0Mscb = xllvm::MscConfBF16IntrOp::create(
1200 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), b, dummy0,
1201 acc0, mscMacMulConfCst);
1202 auto acc0MscbToBF16 = xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
1203 rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
1205 auto c = xllvm::UpdBF512BF256IntrOp::create(
1206 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), cZeros,
1207 acc0MscbToBF16, zeroCstI32);
1208 return std::make_tuple(a.getResult(), b.getResult(), c.getResult());
1213 extractV16FP32ToThreeV16BF16(adaptor.getLhs(), aZeros, bZeros, cZeros);
1216 extractV16FP32ToThreeV16BF16(adaptor.getRhs(), dZeros, eZeros, fZeros);
1219 auto createMacOps = [&](Value lhs, Value rhs, Value acc) -> Value {
1220 return xllvm::MacConfBF16IntrOp::create(
1221 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1222 lhs, rhs, acc, mscMacMulConfCst)
1236 auto afMul = xllvm::MulConfBF16IntrOp::create(
1237 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), a, f,
1239 finalMacVal = createMacOps(
1244 createMacOps(d, c, createMacOps(b, e, afMul)))));
1254 auto bdMul = xllvm::MulConfBF16IntrOp::create(
1255 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), b, d,
1257 finalMacVal = createMacOps(a, d, createMacOps(a, e, bdMul));
1263 auto cfMul = xllvm::MulConfBF16IntrOp::create(
1264 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), c, f,
1266 finalMacVal = createMacOps(
1279 createMacOps(c, e, cfMul))))))));
1283 auto resultVal = forceCastValueToType(rewriter, loc, finalMacVal,
1284 op.getResult().getType());
1285 rewriter.replaceOp(op, resultVal);
1291 ConversionPatternRewriter &rewriter)
const override {
1292 Location loc = op.getLoc();
1296 op.emitWarning() <<
"aievec.mul_elem conversion is not supported.\n";
1303 }
else if (decodedMulElemOp.kind ==
1309 auto confCst = LLVM::ConstantOp::create(
1310 rewriter, loc, rewriter.getI32Type(),
1311 rewriter.getI32IntegerAttr(decodedMulElemOp.conf));
1312 Value mulElemOp =
nullptr;
1313 SmallVector<Value> operands({adaptor.getLhs(), adaptor.getRhs(), confCst});
1318 mulElemOp = xllvm::MulConfAcc32IntrOp::create(
1319 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1320 forceCastOperandsToSignature(
1321 rewriter, loc, operands,
1322 {VectorType::get({64}, rewriter.getI8Type()),
1323 VectorType::get({16}, rewriter.getI32Type()),
1324 rewriter.getI32Type()}));
1325 }
else if (decodedMulElemOp.kind ==
1329 auto zero32 = LLVM::ConstantOp::create(
1330 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
1331 auto zeros_i16 = xllvm::VectorBroadcast16I512IntrOp::create(
1332 rewriter, loc, VectorType::get({32}, rewriter.getI16Type()), zero32);
1333 auto zeros_bf16 = LLVM::BitcastOp::create(
1334 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
1336 auto zeroVec = xllvm::ExtBF256BF512IntrOp::create(
1337 rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
1338 zeros_bf16, zero32);
1341 auto idx1 = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1342 rewriter.getI32IntegerAttr(1));
1345 auto lhsSet = xllvm::VectorSetBF512BF256IntrOp::create(
1346 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
1347 adaptor.getLhs(), zero32);
1348 auto lhsConcat = xllvm::UpdBF512BF256IntrOp::create(
1349 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), lhsSet,
1353 auto rhsSet = xllvm::VectorSetBF512BF256IntrOp::create(
1354 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
1355 adaptor.getRhs(), zero32);
1356 auto rhsConcat = xllvm::UpdBF512BF256IntrOp::create(
1357 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()), rhsSet,
1361 mulElemOp = xllvm::MulConfBF16IntrOp::create(
1362 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()), lhsConcat,
1363 rhsConcat, confCst);
1367 auto resultVal = forceCastValueToType(rewriter, loc, mulElemOp,
1368 op.getResult().getType());
1369 rewriter.replaceOp(op, resultVal);
// Conversion pattern for aievec::MulElemOp, derived from
// ConvertOpToLLVMPattern. Every intrinsic created below carries the AIE2p
// suffix.
1376 :
public mlir::ConvertOpToLLVMPattern<aievec::MulElemOp> {
// Reuse the base-class constructors.
1378 using ConvertOpToLLVMPattern<aievec::MulElemOp>::ConvertOpToLLVMPattern;
// Classification of the op: read the lhs vector type, its element type and
// the element bit width, then test for an integer element type, a 16-bit
// element width, and a lane count of 16, 32 or 64.
1392 auto lhs = op.getLhs();
1393 auto lhsVecTy = cast<VectorType>(lhs.getType());
1394 auto lhsScaTy = lhsVecTy.getElementType();
1395 unsigned lhsBitWidth = lhsScaTy.getIntOrFloatBitWidth();
1399 if (llvm::isa<IntegerType>(lhsScaTy)) {
1403 if (lhsBitWidth == 16) {
1405 if (lhsLanes == 16) {
1408 }
else if (lhsLanes == 32) {
1411 }
else if (lhsLanes == 64) {
// Rewrite entry point (declared `const override`).
1422 ConversionPatternRewriter &rewriter)
const override {
1423 Location loc = op.getLoc();
// Warning attached to the op; its text reports an unsupported conversion.
1427 op.emitWarning() <<
"aievec.mul_elem conversion is not supported for "
// i32 constant holding decodedMulElemOp.conf; it is the last operand of
// every intrinsic created below.
1433 auto confCst = LLVM::ConstantOp::create(
1434 rewriter, loc, rewriter.getI32Type(),
1435 rewriter.getI32IntegerAttr(decodedMulElemOp.conf));
// Receives the result of whichever intrinsic is selected; starts out null.
1437 Value mulElemOp =
nullptr;
// Dispatch on decodedMulElemOp.kind.
1440 if (decodedMulElemOp.kind ==
// 32-entry shuffle mask: entries 0..15 select lanes 0..15, the remaining
// sixteen entries are -1.
// NOTE(review): -1 is presumably the vector.shuffle "poison" index, i.e. the
// upper lanes are left unspecified - confirm against the MLIR version in use.
1446 SmallVector<int64_t> padMask;
1447 for (
int i = 0; i < 16; ++i)
1448 padMask.push_back(i);
1449 for (
int i = 16; i < 32; ++i)
1450 padMask.push_back(-1);
// Each operand is shuffled with itself through padMask.
1452 auto lhsPadded = vector::ShuffleOp::create(
1453 rewriter, loc, adaptor.getLhs(), adaptor.getLhs(), padMask);
1454 auto rhsPadded = vector::ShuffleOp::create(
1455 rewriter, loc, adaptor.getRhs(), adaptor.getRhs(), padMask);
1457 SmallVector<Value> operands({lhsPadded, rhsPadded, confCst});
// Intrinsic result type is vector<16xf32>; forceCastOperandsToSignature
// coerces the operands to (vector<32xbf16>, vector<32xbf16>, i32).
1460 mulElemOp = xllvm::MulConfBF16I512ACC512AIE2pIntrOp::create(
1461 rewriter, loc, VectorType::get({16}, rewriter.getF32Type()),
1462 forceCastOperandsToSignature(
1463 rewriter, loc, operands,
1464 {VectorType::get({32}, rewriter.getBF16Type()),
1465 VectorType::get({32}, rewriter.getBF16Type()),
1466 rewriter.getI32Type()}));
1467 }
else if (decodedMulElemOp.kind ==
// No padding here: the adaptor operands go straight into the intrinsic,
// which takes two vector<32xbf16> inputs and returns vector<32xf32>.
1470 SmallVector<Value> operands(
1471 {adaptor.getLhs(), adaptor.getRhs(), confCst});
1472 mulElemOp = xllvm::MulConfBF16I512ACC1024AIE2pIntrOp::create(
1473 rewriter, loc, VectorType::get({32}, rewriter.getF32Type()),
1474 forceCastOperandsToSignature(
1475 rewriter, loc, operands,
1476 {VectorType::get({32}, rewriter.getBF16Type()),
1477 VectorType::get({32}, rewriter.getBF16Type()),
1478 rewriter.getI32Type()}));
1479 }
else if (decodedMulElemOp.kind ==
// Widest variant: two vector<64xbf16> inputs, vector<64xf32> result.
1482 SmallVector<Value> operands(
1483 {adaptor.getLhs(), adaptor.getRhs(), confCst});
1484 mulElemOp = xllvm::MulConfBF16I1024ACC2048AIE2pIntrOp::create(
1485 rewriter, loc, VectorType::get({64}, rewriter.getF32Type()),
1486 forceCastOperandsToSignature(
1487 rewriter, loc, operands,
1488 {VectorType::get({64}, rewriter.getBF16Type()),
1489 VectorType::get({64}, rewriter.getBF16Type()),
1490 rewriter.getI32Type()}));
// Cast the intrinsic result to the op's declared result type, then replace
// the op with it.
1494 auto resultVal = forceCastValueToType(rewriter, loc, mulElemOp,
1495 op.getResult().getType());
1496 rewriter.replaceOp(op, resultVal);
// Conversion pattern for aievec::FMAElemOp, derived from
// ConvertOpToLLVMPattern. The intrinsics created below carry the AIE2p
// suffix.
1504 :
public mlir::ConvertOpToLLVMPattern<aievec::FMAElemOp> {
// Reuse the base-class constructors.
1506 using ConvertOpToLLVMPattern<aievec::FMAElemOp>::ConvertOpToLLVMPattern;
// Rewrite entry point (declared `const override`).
1510 ConversionPatternRewriter &rewriter)
const override {
1511 auto loc = fmaOp.getLoc();
// Operands are taken from the adaptor, not from the original op.
1512 auto lhs = adaptor.getLhs();
1513 auto rhs = adaptor.getRhs();
1514 auto acc = adaptor.getAcc();
1515 auto lhsTy = cast<VectorType>(lhs.getType());
1516 auto accTy = cast<VectorType>(acc.getType());
// Shape-cast each operand only when its type differs from the flat type;
// rhs is compared against (and cast to) the same flat type as lhs.
1521 if (lhsTy != flatLhsTy)
1522 lhs = vector::ShapeCastOp::create(rewriter, loc, flatLhsTy, lhs);
1523 if (cast<VectorType>(rhs.getType()) != flatLhsTy)
1524 rhs = vector::ShapeCastOp::create(rewriter, loc, flatLhsTy, rhs);
1525 if (accTy != flatAccTy)
1526 acc = vector::ShapeCastOp::create(rewriter, loc, flatAccTy, acc);
// Anything other than bf16 elements takes the diagnostic path below.
1528 if (!flatLhsTy.getElementType().isBF16()) {
1530 <<
"aievec.mac_elem AIE2p conversion only supports bf16 inputs.\n";
// i32 configuration constant computed by aiev2_vmac_compute_control; it is
// the last operand of both intrinsics below.
1534 Type i32ty = rewriter.getI32Type();
1535 auto confCst = LLVM::ConstantOp::create(
1536 rewriter, loc, i32ty,
1537 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
// The lane count of the flattened lhs selects the intrinsic.
1543 unsigned lhsLanes = flatLhsTy.getNumElements();
1544 Value macIntrOp =
nullptr;
1546 if (lhsLanes == 16) {
// 32-entry shuffle mask: entries 0..15 select lanes 0..15, the remaining
// sixteen entries are -1.
// NOTE(review): -1 is presumably the vector.shuffle "poison" index - confirm
// against the MLIR version in use.
1548 SmallVector<int64_t> padMask;
1549 for (
int i = 0; i < 16; ++i)
1550 padMask.push_back(i);
1551 for (
int i = 16; i < 32; ++i)
1552 padMask.push_back(-1);
// lhs and rhs are each shuffled with themselves through padMask.
1555 vector::ShuffleOp::create(rewriter, loc, lhs, lhs, padMask);
1557 vector::ShuffleOp::create(rewriter, loc, rhs, rhs, padMask);
// Signature: (vector<32xbf16>, vector<32xbf16>, vector<16xf32>, i32),
// returning vector<16xf32>.
1559 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
1560 auto v16f32Ty = VectorType::get({16}, rewriter.getF32Type());
1561 macIntrOp = xllvm::MacConfBF16I512ACC512AIE2pIntrOp::create(
1562 rewriter, loc, v16f32Ty,
1563 forceCastOperandsToSignature(
1564 rewriter, loc, {lhsPadded, rhsPadded, acc, confCst},
1565 {v32bf16Ty, v32bf16Ty, v16f32Ty, i32ty}));
1566 }
else if (lhsLanes == 32) {
// No padding: signature is (vector<32xbf16>, vector<32xbf16>,
// vector<32xf32>, i32), returning vector<32xf32>.
1568 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
1569 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
1570 macIntrOp = xllvm::MacConfBF16I512ACC1024AIE2pIntrOp::create(
1571 rewriter, loc, v32f32Ty,
1572 forceCastOperandsToSignature(
1573 rewriter, loc, {lhs, rhs, acc, confCst},
1574 {v32bf16Ty, v32bf16Ty, v32f32Ty, i32ty}));
// Diagnostic text for any other lane count; the count is streamed into it.
1577 <<
"aievec.mac_elem AIE2p conversion: unsupported lane count "
1578 << lhsLanes <<
".\n";
// Cast the intrinsic result to the flat accumulator type, restore the
// original accumulator shape if it differs, then replace the op.
1583 auto resVal = forceCastValueToType(rewriter, loc, macIntrOp, flatAccTy);
1584 if (flatAccTy != accTy)
1585 resVal = vector::ShapeCastOp::create(rewriter, loc, accTy, resVal);
1587 rewriter.replaceOp(fmaOp, resVal);
// Reuse the constructors of ConvertOpToLLVMPattern<aievec::UPSOp>. The
// intrinsics created in this pattern carry the AIE2 suffix.
1600 using ConvertOpToLLVMPattern<aievec::UPSOp>::ConvertOpToLLVMPattern;
// Rewrite entry point (declared `const override`).
1604 ConversionPatternRewriter &rewriter)
const override {
1605 Location loc = op.getLoc();
// Describe the result: element type, element bit width, and
// resultVectorSize = element bit width * resultLanes.
1607 Value result = op.getResult();
1608 VectorType resultType = cast<VectorType>(result.getType());
1610 Type resultScaTy = resultType.getElementType();
1611 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1613 int resultVectorSize = resultBitWidth * resultLanes;
// The source comes from the adaptor and is shape-cast to fltSrcVecTy only
// when its type differs from it.
1615 Value opSrcVal = adaptor.getSource();
1616 auto srcVecTy = cast<VectorType>(opSrcVal.getType());
1618 if (srcVecTy != fltSrcVecTy)
1619 opSrcVal = vector::ShapeCastOp::create(rewriter, op.getLoc(), fltSrcVecTy,
// Receives the result of whichever intrinsic is selected; starts out null.
1625 Value upsIntrOp =
nullptr;
// Integer result element types.
1626 if (llvm::isa<IntegerType>(resultScaTy)) {
// The sign operand is the i32 constant 1; the shift operand is an i32
// constant built from op.getShift().
1628 auto signCst = LLVM::ConstantOp::create(
1629 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1631 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1632 rewriter.getI32IntegerAttr(op.getShift()));
1634 SmallVector<Value> operands({opSrcVal, shiftCst, signCst});
// resultVectorSize 512: both variants return vector<8xi64>.
1635 if (resultVectorSize == 512) {
1636 if (resultBitWidth == 32) {
// Source coerced to vector<16xi16>.
1638 upsIntrOp = xllvm::Acc32V16I256UpsAIE2IntrOp::create(
1639 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1640 forceCastOperandsToSignature(
1641 rewriter, loc, operands,
1642 {VectorType::get({16}, rewriter.getI16Type()),
1643 rewriter.getI32Type(), rewriter.getI32Type()}));
1644 }
else if (resultBitWidth == 64) {
// Source coerced to vector<8xi32>.
1646 upsIntrOp = xllvm::Acc64V8I256UpsAIE2IntrOp::create(
1647 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1648 forceCastOperandsToSignature(
1649 rewriter, loc, operands,
1650 {VectorType::get({8}, rewriter.getI32Type()),
1651 rewriter.getI32Type(), rewriter.getI32Type()}));
1653 }
else if (resultVectorSize == 1024) {
// resultVectorSize 1024: the variant is chosen from the pair
// (result element width, source element width); all return vector<16xi64>.
1654 Value src = opSrcVal;
1655 VectorType srcType = cast<VectorType>(src.getType());
1656 Type srcScaType = srcType.getElementType();
1657 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
1659 if (resultBitWidth == 32 && srcBitWidth == 16) {
// Source coerced to vector<32xi16>.
1661 upsIntrOp = xllvm::Acc32V32I512UpsAIE2IntrOp::create(
1662 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1663 forceCastOperandsToSignature(
1664 rewriter, loc, operands,
1665 {VectorType::get({32}, rewriter.getI16Type()),
1666 rewriter.getI32Type(), rewriter.getI32Type()}));
1667 }
else if (resultBitWidth == 64 && srcBitWidth == 32) {
// Source coerced to vector<16xi32>.
1669 upsIntrOp = xllvm::Acc64V16I512UpsAIE2IntrOp::create(
1670 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1671 forceCastOperandsToSignature(
1672 rewriter, loc, operands,
1673 {VectorType::get({16}, rewriter.getI32Type()),
1674 rewriter.getI32Type(), rewriter.getI32Type()}));
1675 }
else if (resultBitWidth == 64 && srcBitWidth == 16) {
// Source coerced to vector<16xi16>.
1677 upsIntrOp = xllvm::Acc64V16I256UpsAIE2IntrOp::create(
1678 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1679 forceCastOperandsToSignature(
1680 rewriter, loc, operands,
1681 {VectorType::get({16}, rewriter.getI16Type()),
1682 rewriter.getI32Type(), rewriter.getI32Type()}));
1683 }
else if (resultBitWidth == 32 && srcBitWidth == 8) {
// Source coerced to vector<32xi8>.
1685 upsIntrOp = xllvm::Acc32V32I256UpsAIE2IntrOp::create(
1686 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1687 forceCastOperandsToSignature(
1688 rewriter, loc, operands,
1689 {VectorType::get({32}, rewriter.getI8Type()),
1690 rewriter.getI32Type(), rewriter.getI32Type()}));
// bf16 -> accumulator-float conversion. For resultVectorSize 512 a single
// intrinsic takes a vector<16xbf16> source and returns vector<8xi64>.
1696 if (resultVectorSize == 512) {
1698 upsIntrOp = xllvm::Vector16BF16ToV16AccFloatAIE2IntrOp::create(
1699 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1700 forceCastOperandsToSignature(
1701 rewriter, loc, {opSrcVal},
1702 {VectorType::get({16}, rewriter.getBF16Type())}));
1703 }
else if (resultVectorSize == 1024) {
// resultVectorSize 1024: the source is processed in two pieces. The i32
// constants 0 and 1 are the indices handed to the extract intrinsic.
1712 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1713 rewriter.getI32IntegerAttr(0));
1715 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1716 rewriter.getI32IntegerAttr(1));
// Helper: ExtI256I512IntrOp takes a vector<16xi32> view of `source` plus
// `index` and yields vector<8xi32>; that piece is then converted with the
// same 16-lane bf16 -> accfloat intrinsic used above.
1717 auto extractUps = [&](Value source, Value index) -> Value {
1718 auto extOp = xllvm::ExtI256I512IntrOp::create(
1719 rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
1720 forceCastOperandsToSignature(
1721 rewriter, loc, {source, index},
1722 {VectorType::get({16}, rewriter.getI32Type()),
1723 rewriter.getI32Type()}));
1724 return xllvm::Vector16BF16ToV16AccFloatAIE2IntrOp::create(
1725 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1726 forceCastOperandsToSignature(
1727 rewriter, loc, {extOp},
1728 {VectorType::get({16}, rewriter.getBF16Type())}));
1730 auto resLo = extractUps(opSrcVal, indexZeroCst);
1731 auto resHi = extractUps(opSrcVal, indexOneCst);
// Join the two converted pieces (each coerced to vector<16xi32>) into one
// vector<32xi32>.
1734 upsIntrOp = xllvm::ConcatI1024I512IntrOp::create(
1735 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
1736 forceCastOperandsToSignature(
1737 rewriter, loc, {resLo, resHi},
1738 {VectorType::get({16}, rewriter.getI32Type()),
1739 VectorType::get({16}, rewriter.getI32Type())}));
// Warning attached to the op for cases that are not handled.
1744 op.emitWarning() <<
"aievec.ups is not supported.\n";
// Bitcast the intrinsic result to flatResTy when the types differ, restore
// the original result shape when flatResTy differs from it, then replace
// the op.
1749 if (flatResTy != upsIntrOp.getType())
1750 upsIntrOp = LLVM::BitcastOp::create(rewriter, loc, flatResTy, upsIntrOp);
1752 if (flatResTy != resultType)
1754 vector::ShapeCastOp::create(rewriter, loc, resultType, upsIntrOp);
1756 rewriter.replaceOp(op, upsIntrOp);
// Conversion pattern for aievec::UPSOp, derived from ConvertOpToLLVMPattern.
// The intrinsics created below carry the AIE2p suffix.
1764 :
public mlir::ConvertOpToLLVMPattern<aievec::UPSOp> {
// Reuse the base-class constructors.
1766 using ConvertOpToLLVMPattern<aievec::UPSOp>::ConvertOpToLLVMPattern;
// Rewrite entry point (declared `const override`).
1770 ConversionPatternRewriter &rewriter)
const override {
1771 Location loc = op.getLoc();
// Describe the result: element type, element bit width, and
// resultVectorSize = element bit width * resultLanes.
1773 Value result = op.getResult();
1774 VectorType resultType = cast<VectorType>(result.getType());
1776 Type resultScaTy = resultType.getElementType();
1777 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
1779 int resultVectorSize = resultBitWidth * resultLanes;
// The source comes from the adaptor and is shape-cast to fltSrcVecTy only
// when its type differs from it.
1781 Value opSrcVal = adaptor.getSource();
1782 auto srcVecTy = cast<VectorType>(opSrcVal.getType());
1784 if (srcVecTy != fltSrcVecTy)
1785 opSrcVal = vector::ShapeCastOp::create(rewriter, op.getLoc(), fltSrcVecTy,
// Receives the result of whichever lowering is selected; starts out null.
1791 Value upsIntrOp =
nullptr;
// Integer result element types.
1792 if (llvm::isa<IntegerType>(resultScaTy)) {
// The sign operand is the i32 constant 1; the shift operand is an i32
// constant built from op.getShift().
1794 auto signCst = LLVM::ConstantOp::create(
1795 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
1797 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1798 rewriter.getI32IntegerAttr(op.getShift()));
1800 SmallVector<Value> operands({opSrcVal, shiftCst, signCst});
// resultVectorSize 512: chosen by result element width.
1801 if (resultVectorSize == 512) {
1802 if (resultBitWidth == 32) {
// vector<16xi16> source, vector<16xi32> result.
1804 upsIntrOp = xllvm::Acc32V16I256UpsAIE2pIntrOp::create(
1805 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
1806 forceCastOperandsToSignature(
1807 rewriter, loc, operands,
1808 {VectorType::get({16}, rewriter.getI16Type()),
1809 rewriter.getI32Type(), rewriter.getI32Type()}));
1810 }
else if (resultBitWidth == 64) {
// vector<8xi32> source, vector<8xi64> result.
1812 upsIntrOp = xllvm::Acc64V8I256UpsAIE2pIntrOp::create(
1813 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
1814 forceCastOperandsToSignature(
1815 rewriter, loc, operands,
1816 {VectorType::get({8}, rewriter.getI32Type()),
1817 rewriter.getI32Type(), rewriter.getI32Type()}));
1819 }
else if (resultVectorSize == 1024) {
// resultVectorSize 1024: chosen by the triple (result element width,
// source element width, srcVectorSize), where
// srcVectorSize = source element width * srcLanes.
1820 Value src = opSrcVal;
1821 VectorType srcType = cast<VectorType>(src.getType());
1822 Type srcScaType = srcType.getElementType();
1823 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
1825 int srcVectorSize = srcBitWidth * srcLanes;
1827 if (resultBitWidth == 32 && srcBitWidth == 16 && srcVectorSize == 512) {
// vector<32xi16> source, vector<32xi32> result.
1829 upsIntrOp = xllvm::Acc32V32I512UpsAIE2pIntrOp::create(
1830 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
1831 forceCastOperandsToSignature(
1832 rewriter, loc, operands,
1833 {VectorType::get({32}, rewriter.getI16Type()),
1834 rewriter.getI32Type(), rewriter.getI32Type()}));
1835 }
else if (resultBitWidth == 64 && srcBitWidth == 32 &&
1836 srcVectorSize == 512) {
// vector<16xi32> source, vector<16xi64> result.
1838 upsIntrOp = xllvm::Acc64V16I512UpsAIE2pIntrOp::create(
1839 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1840 forceCastOperandsToSignature(
1841 rewriter, loc, operands,
1842 {VectorType::get({16}, rewriter.getI32Type()),
1843 rewriter.getI32Type(), rewriter.getI32Type()}));
1844 }
else if (resultBitWidth == 64 && srcBitWidth == 16 &&
1845 srcVectorSize == 256) {
// vector<16xi16> source, vector<16xi64> result.
1847 upsIntrOp = xllvm::Acc64V16I256UpsAIE2pIntrOp::create(
1848 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()),
1849 forceCastOperandsToSignature(
1850 rewriter, loc, operands,
1851 {VectorType::get({16}, rewriter.getI16Type()),
1852 rewriter.getI32Type(), rewriter.getI32Type()}));
1853 }
else if (resultBitWidth == 32 && srcBitWidth == 8 &&
1854 srcVectorSize == 256) {
// vector<32xi8> source, vector<32xi32> result.
1856 upsIntrOp = xllvm::Acc32V32I256UpsAIE2pIntrOp::create(
1857 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
1858 forceCastOperandsToSignature(
1859 rewriter, loc, operands,
1860 {VectorType::get({32}, rewriter.getI8Type()),
1861 rewriter.getI32Type(), rewriter.getI32Type()}));
1863 }
else if (resultVectorSize == 2048) {
// resultVectorSize 2048: same three-way key as the 1024 case.
1864 Value src = opSrcVal;
1865 VectorType srcType = cast<VectorType>(src.getType());
1866 Type srcScaType = srcType.getElementType();
1867 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
1869 int srcVectorSize = srcBitWidth * srcLanes;
1871 if (resultBitWidth == 32 && srcBitWidth == 8 && srcVectorSize == 512) {
// vector<64xi8> source, vector<64xi32> result.
1873 upsIntrOp = xllvm::Acc32V64I512UpsAIE2pIntrOp::create(
1874 rewriter, loc, VectorType::get({64}, rewriter.getI32Type()),
1875 forceCastOperandsToSignature(
1876 rewriter, loc, operands,
1877 {VectorType::get({64}, rewriter.getI8Type()),
1878 rewriter.getI32Type(), rewriter.getI32Type()}));
1879 }
else if (resultBitWidth == 64 && srcBitWidth == 16 &&
1880 srcVectorSize == 512) {
// vector<32xi16> source, vector<32xi64> result.
1882 upsIntrOp = xllvm::Acc64V32I512UpsAIE2pIntrOp::create(
1883 rewriter, loc, VectorType::get({32}, rewriter.getI64Type()),
1884 forceCastOperandsToSignature(
1885 rewriter, loc, operands,
1886 {VectorType::get({32}, rewriter.getI16Type()),
1887 rewriter.getI32Type(), rewriter.getI32Type()}));
1888 }
else if (resultBitWidth == 32 && srcBitWidth == 16 &&
1889 srcVectorSize == 1024) {
// 1024-bit source: processed as two pieces, each through the 32-lane
// intrinsic. The i32 constants 0 and 1 identify the piece.
1893 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1894 rewriter.getI32IntegerAttr(0));
1896 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1897 rewriter.getI32IntegerAttr(1));
// Helper: view `source` as vector<32xi32>, pick 16 of those lanes with a
// shuffle, and run the picked lanes through Acc32V32I512UpsAIE2pIntrOp.
1899 auto extractUps2048 = [&](Value source, Value index, Value shiftCst,
1900 Value signCst) -> Value {
1903 auto v32i32Source = forceCastValueToType(
1904 rewriter, loc, source,
1905 VectorType::get({32}, rewriter.getI32Type()));
// When `index` is produced by an LLVM::ConstantOp its integer value is read
// at rewrite time and the mask covers lanes [idxVal*16, idxVal*16 + 16).
1910 SmallVector<int64_t> shuffleMask;
1911 if (
auto constIndex = index.getDefiningOp<LLVM::ConstantOp>()) {
1912 auto indexAttr = cast<IntegerAttr>(constIndex.getValue());
1913 int64_t idxVal = indexAttr.getInt();
1914 int startIdx = idxVal * 16;
1915 for (
int i = 0; i < 16; ++i) {
1916 shuffleMask.push_back(startIdx + i);
// NOTE(review): this second loop (lanes 0..15) looks like the fallback for
// a non-constant index - confirm.
1920 for (
int i = 0; i < 16; ++i) {
1921 shuffleMask.push_back(i);
1925 auto extOp = vector::ShuffleOp::create(rewriter, loc, v32i32Source,
1926 v32i32Source, shuffleMask);
// The selected lanes are coerced to vector<32xi16>; result is vector<32xi32>.
1928 return xllvm::Acc32V32I512UpsAIE2pIntrOp::create(
1929 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
1930 forceCastOperandsToSignature(
1931 rewriter, loc, {extOp, shiftCst, signCst},
1932 {VectorType::get({32}, rewriter.getI16Type()),
1933 rewriter.getI32Type(), rewriter.getI32Type()}));
1936 auto res0 = extractUps2048(opSrcVal, index0Cst, shiftCst, signCst);
1937 auto res1 = extractUps2048(opSrcVal, index1Cst, shiftCst, signCst);
// Identity mask 0..63 over (res0, res1): the shuffle concatenates the two
// results.
1941 SmallVector<int64_t> concatMask;
1942 for (
int i = 0; i < 64; ++i) {
1943 concatMask.push_back(i);
1946 vector::ShuffleOp::create(rewriter, loc, res0, res1, concatMask);
// bf16 -> accumulator-float conversion. resultVectorSize 512 uses the
// 16-lane intrinsic (vector<16xbf16> -> vector<16xf32>).
1952 if (resultVectorSize == 512) {
1954 upsIntrOp = xllvm::Vector16BF16ToV16AccFloatAIE2pIntrOp::create(
1955 rewriter, loc, VectorType::get({16}, rewriter.getF32Type()),
1956 forceCastOperandsToSignature(
1957 rewriter, loc, {opSrcVal},
1958 {VectorType::get({16}, rewriter.getBF16Type())}));
1959 }
else if (resultVectorSize == 1024) {
// 32-lane intrinsic (vector<32xbf16> -> vector<32xf32>).
1961 upsIntrOp = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
1962 rewriter, loc, VectorType::get({32}, rewriter.getF32Type()),
1963 forceCastOperandsToSignature(
1964 rewriter, loc, {opSrcVal},
1965 {VectorType::get({32}, rewriter.getBF16Type())}));
1966 }
else if (resultVectorSize == 2048) {
// resultVectorSize 2048: two pieces, each through the 32-lane intrinsic.
// The i32 constants 0 and 1 identify the piece.
1970 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1971 rewriter.getI32IntegerAttr(0));
1973 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1974 rewriter.getI32IntegerAttr(1));
// Helper: view `source` as vector<32xi32>, pick 16 of those lanes with a
// shuffle, and convert them with Vector32BF16ToV32AccFloatAIE2pIntrOp.
1976 auto extractUps2048 = [&](Value source, Value index) -> Value {
1979 auto v32i32Source = forceCastValueToType(
1980 rewriter, loc, source,
1981 VectorType::get({32}, rewriter.getI32Type()));
// Constant index: mask covers lanes [idxVal*16, idxVal*16 + 16).
1986 SmallVector<int64_t> shuffleMask;
1987 if (
auto constIndex = index.getDefiningOp<LLVM::ConstantOp>()) {
1988 auto indexAttr = cast<IntegerAttr>(constIndex.getValue());
1989 int64_t idxVal = indexAttr.getInt();
1990 int startIdx = idxVal * 16;
1991 for (
int i = 0; i < 16; ++i) {
1992 shuffleMask.push_back(startIdx + i);
// NOTE(review): this second loop (lanes 0..15) looks like the fallback for
// a non-constant index - confirm.
1996 for (
int i = 0; i < 16; ++i) {
1997 shuffleMask.push_back(i);
2001 auto extOp = vector::ShuffleOp::create(rewriter, loc, v32i32Source,
2002 v32i32Source, shuffleMask);
// The selected lanes are coerced to vector<32xbf16>; result is
// vector<32xf32>.
2004 return xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
2005 rewriter, loc, VectorType::get({32}, rewriter.getF32Type()),
2006 forceCastOperandsToSignature(
2007 rewriter, loc, {extOp},
2008 {VectorType::get({32}, rewriter.getBF16Type())}));
2011 auto res0 = extractUps2048(opSrcVal, index0Cst);
2012 auto res1 = extractUps2048(opSrcVal, index1Cst);
// Both partial results are cast to vector<32xi32> before being joined.
2016 auto v32i32Res0 = forceCastValueToType(
2017 rewriter, loc, res0, VectorType::get({32}, rewriter.getI32Type()));
2018 auto v32i32Res1 = forceCastValueToType(
2019 rewriter, loc, res1, VectorType::get({32}, rewriter.getI32Type()));
// Identity mask 0..63: the shuffle concatenates the two 32-lane values.
2021 SmallVector<int64_t> concatMask;
2022 for (
int i = 0; i < 64; ++i) {
2023 concatMask.push_back(i);
2025 upsIntrOp = vector::ShuffleOp::create(rewriter, loc, v32i32Res0,
2026 v32i32Res1, concatMask);
// Warning attached to the op for cases that are not handled.
2031 op.emitWarning() <<
"aievec.ups is not supported.\n";
// Bitcast the result to flatResTy when the types differ, restore the
// original result shape when flatResTy differs from it, then replace the op.
2036 if (flatResTy != upsIntrOp.getType())
2037 upsIntrOp = LLVM::BitcastOp::create(rewriter, loc, flatResTy, upsIntrOp);
2039 if (flatResTy != resultType)
2041 vector::ShapeCastOp::create(rewriter, loc, resultType, upsIntrOp);
2043 rewriter.replaceOp(op, upsIntrOp);
// Reuse the constructors of ConvertOpToLLVMPattern<aievec::SRSOp>. The
// intrinsics created in this pattern carry the AIE2 suffix.
2051 using ConvertOpToLLVMPattern<aievec::SRSOp>::ConvertOpToLLVMPattern;
// Rewrite entry point (declared `const override`).
2055 ConversionPatternRewriter &rewriter)
const override {
2056 Location loc = op.getLoc();
// Describe the result: element type, element bit width, and
// resultVectorSize = element bit width * resultLanes.
2058 Value result = op.getResult();
2059 VectorType resultType = cast<VectorType>(result.getType());
2060 Type resultScaTy = resultType.getElementType();
2061 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
2063 int resultVectorSize = resultBitWidth * resultLanes;
// Receives the result of whichever intrinsic is selected; starts out null.
2066 Value srsIntrOp =
nullptr;
// Integer result element types.
2067 if (llvm::isa<IntegerType>(resultScaTy)) {
// i32 constant built from op.getSign(); used as the sign operand.
2070 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2071 rewriter.getI32IntegerAttr(op.getSign()));
// Intrinsic operands: source and shift from the adaptor, then the sign.
2074 SmallVector<Value> operands(
2075 {adaptor.getSource(), adaptor.getShift(), signCst});
// resultVectorSize 512: chosen by result element width; the source is
// coerced to vector<16xi64> in both variants.
2076 if (resultVectorSize == 512) {
2077 if (resultBitWidth == 16) {
2078 srsIntrOp = xllvm::I512V32Acc32SrsAIE2IntrOp::create(
2079 rewriter, loc, VectorType::get({32}, rewriter.getI16Type()),
2080 forceCastOperandsToSignature(
2081 rewriter, loc, operands,
2082 {VectorType::get({16}, rewriter.getI64Type()),
2083 rewriter.getI32Type(), rewriter.getI32Type()}));
2084 }
else if (resultBitWidth == 32) {
2085 srsIntrOp = xllvm::I512V16Acc64SrsAIE2IntrOp::create(
2086 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
2087 forceCastOperandsToSignature(
2088 rewriter, loc, operands,
2089 {VectorType::get({16}, rewriter.getI64Type()),
2090 rewriter.getI32Type(), rewriter.getI32Type()}));
2092 }
else if (resultVectorSize == 256) {
// resultVectorSize 256: chosen by the pair (result element width, source
// element width).
2093 Value src = adaptor.getSource();
2094 VectorType srcType = cast<VectorType>(src.getType());
2095 Type srcScaType = srcType.getElementType();
2096 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
2098 if (resultBitWidth == 16 && srcBitWidth == 32) {
// Source coerced to vector<8xi64>, result vector<16xi16>.
2099 srsIntrOp = xllvm::I256V16Acc32SrsAIE2IntrOp::create(
2100 rewriter, loc, VectorType::get({16}, rewriter.getI16Type()),
2101 forceCastOperandsToSignature(
2102 rewriter, loc, operands,
2103 {VectorType::get({8}, rewriter.getI64Type()),
2104 rewriter.getI32Type(), rewriter.getI32Type()}));
2105 }
else if (resultBitWidth == 8 && srcBitWidth == 32) {
// Source coerced to vector<16xi64>, result vector<32xi8>.
2106 srsIntrOp = xllvm::I256V32Acc32SrsAIE2IntrOp::create(
2107 rewriter, loc, VectorType::get({32}, rewriter.getI8Type()),
2108 forceCastOperandsToSignature(
2109 rewriter, loc, operands,
2110 {VectorType::get({16}, rewriter.getI64Type()),
2111 rewriter.getI32Type(), rewriter.getI32Type()}));
2112 }
else if (resultBitWidth == 16 && srcBitWidth == 64) {
// Source coerced to vector<16xi64>, result vector<16xi16>.
2113 srsIntrOp = xllvm::I256V16Acc64SrsAIE2IntrOp::create(
2114 rewriter, loc, VectorType::get({16}, rewriter.getI16Type()),
2115 forceCastOperandsToSignature(
2116 rewriter, loc, operands,
2117 {VectorType::get({16}, rewriter.getI64Type()),
2118 rewriter.getI32Type(), rewriter.getI32Type()}));
2119 }
else if (resultBitWidth == 32 && srcBitWidth == 64) {
// Source coerced to vector<8xi64>, result vector<8xi32>.
2120 srsIntrOp = xllvm::I256V8Acc64SrsAIE2IntrOp::create(
2121 rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
2122 forceCastOperandsToSignature(
2123 rewriter, loc, operands,
2124 {VectorType::get({8}, rewriter.getI64Type()),
2125 rewriter.getI32Type(), rewriter.getI32Type()}));
// Accumulator-float -> bf16 conversion. For resultVectorSize 256 a single
// intrinsic takes the source coerced to vector<8xi64> and returns
// vector<16xbf16>.
2130 if (resultVectorSize == 256) {
2131 srsIntrOp = xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
2132 rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
2133 forceCastOperandsToSignature(
2134 rewriter, loc, {adaptor.getSource()},
2135 {VectorType::get({8}, rewriter.getI64Type())}));
2136 }
else if (resultVectorSize == 512) {
// resultVectorSize 512: the source is processed in two pieces. The i32
// constants 0 and 1 are the indices handed to the extract intrinsic.
2145 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2146 rewriter.getI32IntegerAttr(0));
2148 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2149 rewriter.getI32IntegerAttr(1));
// Helper: ExtI512I1024IntrOp takes a vector<32xi32> view of `source` plus
// `index` and yields vector<16xi32>; that piece is then converted with the
// same 16-lane accfloat -> bf16 intrinsic used above.
2150 auto extractSrs = [&](Value source, Value index) -> Value {
2151 auto extOp = xllvm::ExtI512I1024IntrOp::create(
2152 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
2153 forceCastOperandsToSignature(
2154 rewriter, loc, {source, index},
2155 {VectorType::get({32}, rewriter.getI32Type()),
2156 rewriter.getI32Type()}));
2157 return xllvm::Vector16AccFloatToV16BF16AIE2IntrOp::create(
2158 rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
2159 forceCastOperandsToSignature(
2160 rewriter, loc, {extOp},
2161 {VectorType::get({8}, rewriter.getI64Type())}));
2163 auto resLo = extractSrs(adaptor.getSource(), indexZeroCst);
2164 auto resHi = extractSrs(adaptor.getSource(), indexOneCst);
// Join the two converted pieces (each coerced to vector<8xi32>) into one
// vector<16xi32>.
2167 srsIntrOp = xllvm::ConcatI512I256IntrOp::create(
2168 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
2169 forceCastOperandsToSignature(
2170 rewriter, loc, {resLo, resHi},
2171 {VectorType::get({8}, rewriter.getI32Type()),
2172 VectorType::get({8}, rewriter.getI32Type())}));
// Warning attached to the op for cases that are not handled.
2177 op.emitWarning() <<
"aievec.srs is not supported.\n";
// Cast the intrinsic result to the op's declared result type, then replace
// the op with it.
2182 auto resultVal = forceCastValueToType(rewriter, loc, srsIntrOp,
2183 op.getResult().getType());
2184 rewriter.replaceOp(op, resultVal);
// Conversion pattern for aievec::SRSOp, derived from ConvertOpToLLVMPattern.
// The intrinsics created below carry the AIE2p suffix.
2192 :
public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
// Reuse the base-class constructors.
2194 using ConvertOpToLLVMPattern<aievec::SRSOp>::ConvertOpToLLVMPattern;
// Rewrite entry point (declared `const override`).
2198 ConversionPatternRewriter &rewriter)
const override {
2199 Location loc = op.getLoc();
// Describe the result: element type, element bit width, and
// resultVectorSize = element bit width * resultLanes.
2201 Value result = op.getResult();
2202 VectorType resultType = cast<VectorType>(result.getType());
2203 Type resultScaTy = resultType.getElementType();
2204 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
2206 int resultVectorSize = resultBitWidth * resultLanes;
// Receives the result of whichever lowering is selected; starts out null.
2209 Value srsIntrOp =
nullptr;
// Integer result element types.
2210 if (llvm::isa<IntegerType>(resultScaTy)) {
// i32 constant built from op.getSign(); used as the sign operand.
2213 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2214 rewriter.getI32IntegerAttr(op.getSign()));
// Intrinsic operands: source and shift from the adaptor, then the sign.
2217 SmallVector<Value> operands(
2218 {adaptor.getSource(), adaptor.getShift(), signCst});
// resultVectorSize 512: chosen by the pair (result element width, source
// element width).
2219 if (resultVectorSize == 512) {
2220 Value src = adaptor.getSource();
2221 VectorType srcType = cast<VectorType>(src.getType());
2222 Type srcScaType = srcType.getElementType();
2223 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
2225 if (resultBitWidth == 16 && srcBitWidth == 32) {
// vector<32xi32> source, vector<32xi16> result.
2227 srsIntrOp = xllvm::I512V32Acc32SrsAIE2pIntrOp::create(
2228 rewriter, loc, VectorType::get({32}, rewriter.getI16Type()),
2229 forceCastOperandsToSignature(
2230 rewriter, loc, operands,
2231 {VectorType::get({32}, rewriter.getI32Type()),
2232 rewriter.getI32Type(), rewriter.getI32Type()}));
2233 }
else if (resultBitWidth == 16 && srcBitWidth == 64) {
// vector<32xi64> source, vector<32xi16> result.
2235 srsIntrOp = xllvm::I512V32Acc64SrsAIE2pIntrOp::create(
2236 rewriter, loc, VectorType::get({32}, rewriter.getI16Type()),
2237 forceCastOperandsToSignature(
2238 rewriter, loc, operands,
2239 {VectorType::get({32}, rewriter.getI64Type()),
2240 rewriter.getI32Type(), rewriter.getI32Type()}));
2241 }
else if (resultBitWidth == 32 && srcBitWidth == 64) {
// vector<16xi64> source, vector<16xi32> result.
2243 srsIntrOp = xllvm::I512V16Acc64SrsAIE2pIntrOp::create(
2244 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
2245 forceCastOperandsToSignature(
2246 rewriter, loc, operands,
2247 {VectorType::get({16}, rewriter.getI64Type()),
2248 rewriter.getI32Type(), rewriter.getI32Type()}));
2249 }
else if (resultBitWidth == 8 && srcBitWidth == 32) {
// vector<64xi32> source, vector<64xi8> result.
2251 srsIntrOp = xllvm::I512V64Acc32SrsAIE2pIntrOp::create(
2252 rewriter, loc, VectorType::get({64}, rewriter.getI8Type()),
2253 forceCastOperandsToSignature(
2254 rewriter, loc, operands,
2255 {VectorType::get({64}, rewriter.getI32Type()),
2256 rewriter.getI32Type(), rewriter.getI32Type()}));
2258 }
else if (resultVectorSize == 256) {
// resultVectorSize 256: same two-way key as the 512 case.
2259 Value src = adaptor.getSource();
2260 VectorType srcType = cast<VectorType>(src.getType());
2261 Type srcScaType = srcType.getElementType();
2262 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
2264 if (resultBitWidth == 16 && srcBitWidth == 32) {
// vector<16xi32> source, vector<16xi16> result.
2266 srsIntrOp = xllvm::I256V16Acc32SrsAIE2pIntrOp::create(
2267 rewriter, loc, VectorType::get({16}, rewriter.getI16Type()),
2268 forceCastOperandsToSignature(
2269 rewriter, loc, operands,
2270 {VectorType::get({16}, rewriter.getI32Type()),
2271 rewriter.getI32Type(), rewriter.getI32Type()}));
2272 }
else if (resultBitWidth == 8 && srcBitWidth == 32) {
// vector<32xi32> source, vector<32xi8> result.
2274 srsIntrOp = xllvm::I256V32Acc32SrsAIE2pIntrOp::create(
2275 rewriter, loc, VectorType::get({32}, rewriter.getI8Type()),
2276 forceCastOperandsToSignature(
2277 rewriter, loc, operands,
2278 {VectorType::get({32}, rewriter.getI32Type()),
2279 rewriter.getI32Type(), rewriter.getI32Type()}));
2280 }
else if (resultBitWidth == 16 && srcBitWidth == 64) {
// vector<16xi64> source, vector<16xi16> result.
2282 srsIntrOp = xllvm::I256V16Acc64SrsAIE2pIntrOp::create(
2283 rewriter, loc, VectorType::get({16}, rewriter.getI16Type()),
2284 forceCastOperandsToSignature(
2285 rewriter, loc, operands,
2286 {VectorType::get({16}, rewriter.getI64Type()),
2287 rewriter.getI32Type(), rewriter.getI32Type()}));
2288 }
else if (resultBitWidth == 32 && srcBitWidth == 64) {
// vector<8xi64> source, vector<8xi32> result.
2290 srsIntrOp = xllvm::I256V8Acc64SrsAIE2pIntrOp::create(
2291 rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
2292 forceCastOperandsToSignature(
2293 rewriter, loc, operands,
2294 {VectorType::get({8}, rewriter.getI64Type()),
2295 rewriter.getI32Type(), rewriter.getI32Type()}));
2297 }
else if (resultVectorSize == 1024) {
// resultVectorSize 1024: only the (16-bit result, 32-bit source) pair is
// visible here; the source is processed as two pieces.
2298 Value src = adaptor.getSource();
2299 VectorType srcType = cast<VectorType>(src.getType());
2300 Type srcScaType = srcType.getElementType();
2301 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
2303 if (resultBitWidth == 16 && srcBitWidth == 32) {
// The i32 constants 0 and 1 identify the piece.
2307 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2308 rewriter.getI32IntegerAttr(0));
2310 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2311 rewriter.getI32IntegerAttr(1));
// Helper: view `source` as vector<64xi32>, pick 32 of those lanes with a
// shuffle, and run the picked lanes through I512V32Acc32SrsAIE2pIntrOp.
2313 auto extractSrs1024 = [&](Value source, Value index, Value shiftCst,
2314 Value signCst) -> Value {
2317 auto v64i32Source = forceCastValueToType(
2318 rewriter, loc, source,
2319 VectorType::get({64}, rewriter.getI32Type()));
// When `index` is produced by an LLVM::ConstantOp its integer value is read
// at rewrite time and the mask covers lanes [idxVal*32, idxVal*32 + 32).
2324 SmallVector<int64_t> shuffleMask;
2325 if (
auto constIndex = index.getDefiningOp<LLVM::ConstantOp>()) {
2326 auto indexAttr = cast<IntegerAttr>(constIndex.getValue());
2327 int64_t idxVal = indexAttr.getInt();
2328 int startIdx = idxVal * 32;
2329 for (
int i = 0; i < 32; ++i) {
2330 shuffleMask.push_back(startIdx + i);
// NOTE(review): this second loop (lanes 0..31) looks like the fallback for
// a non-constant index - confirm.
2334 for (
int i = 0; i < 32; ++i) {
2335 shuffleMask.push_back(i);
2339 auto extOp = vector::ShuffleOp::create(rewriter, loc, v64i32Source,
2340 v64i32Source, shuffleMask);
// vector<32xi32> input, vector<32xi16> result.
2342 return xllvm::I512V32Acc32SrsAIE2pIntrOp::create(
2343 rewriter, loc, VectorType::get({32}, rewriter.getI16Type()),
2344 forceCastOperandsToSignature(
2345 rewriter, loc, {extOp, shiftCst, signCst},
2346 {VectorType::get({32}, rewriter.getI32Type()),
2347 rewriter.getI32Type(), rewriter.getI32Type()}));
// The helper runs once per piece, with the adaptor's shift and the sign
// constant.
2351 extractSrs1024(src, index0Cst, adaptor.getShift(), signCst);
2353 extractSrs1024(src, index1Cst, adaptor.getShift(), signCst);
// Both partial results are cast to vector<16xi32> before being joined.
2357 auto v16i32Res0 = forceCastValueToType(
2358 rewriter, loc, res0,
2359 VectorType::get({16}, rewriter.getI32Type()));
2360 auto v16i32Res1 = forceCastValueToType(
2361 rewriter, loc, res1,
2362 VectorType::get({16}, rewriter.getI32Type()));
// Identity mask 0..31: the shuffle concatenates the two 16-lane values.
2364 SmallVector<int64_t> concatMask;
2365 for (
int i = 0; i < 32; ++i) {
2366 concatMask.push_back(i);
2368 srsIntrOp = vector::ShuffleOp::create(rewriter, loc, v16i32Res0,
2369 v16i32Res1, concatMask);
// Accumulator-float -> bf16 conversion. resultVectorSize 256 uses the
// 16-lane intrinsic (vector<16xf32> -> vector<16xbf16>).
2375 if (resultVectorSize == 256) {
2377 srsIntrOp = xllvm::Vector16AccFloatToV16BF16AIE2pIntrOp::create(
2378 rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
2379 forceCastOperandsToSignature(
2380 rewriter, loc, {adaptor.getSource()},
2381 {VectorType::get({16}, rewriter.getF32Type())}));
2382 }
else if (resultVectorSize == 512) {
// 32-lane intrinsic (vector<32xf32> -> vector<32xbf16>).
2384 srsIntrOp = xllvm::Vector32AccFloatToV32BF16AIE2pIntrOp::create(
2385 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
2386 forceCastOperandsToSignature(
2387 rewriter, loc, {adaptor.getSource()},
2388 {VectorType::get({32}, rewriter.getF32Type())}));
2389 }
else if (resultVectorSize == 1024) {
// resultVectorSize 1024: two pieces, each through the 32-lane intrinsic.
// The i32 constants 0 and 1 identify the piece.
2393 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2394 rewriter.getI32IntegerAttr(0));
2396 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2397 rewriter.getI32IntegerAttr(1));
// Helper: view `source` as vector<64xi32>, pick 32 of those lanes with a
// shuffle, and convert them with Vector32AccFloatToV32BF16AIE2pIntrOp.
2399 auto extractSrs1024 = [&](Value source, Value index) -> Value {
2402 auto v64i32Source = forceCastValueToType(
2403 rewriter, loc, source,
2404 VectorType::get({64}, rewriter.getI32Type()));
// Constant index: mask covers lanes [idxVal*32, idxVal*32 + 32).
2409 SmallVector<int64_t> shuffleMask;
2410 if (
auto constIndex = index.getDefiningOp<LLVM::ConstantOp>()) {
2411 auto indexAttr = cast<IntegerAttr>(constIndex.getValue());
2412 int64_t idxVal = indexAttr.getInt();
2413 int startIdx = idxVal * 32;
2414 for (
int i = 0; i < 32; ++i) {
2415 shuffleMask.push_back(startIdx + i);
// NOTE(review): this second loop (lanes 0..31) looks like the fallback for
// a non-constant index - confirm.
2419 for (
int i = 0; i < 32; ++i) {
2420 shuffleMask.push_back(i);
2424 auto extOp = vector::ShuffleOp::create(rewriter, loc, v64i32Source,
2425 v64i32Source, shuffleMask);
// The selected lanes are coerced to vector<32xf32>; result is
// vector<32xbf16>.
2427 return xllvm::Vector32AccFloatToV32BF16AIE2pIntrOp::create(
2428 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
2429 forceCastOperandsToSignature(
2430 rewriter, loc, {extOp},
2431 {VectorType::get({32}, rewriter.getF32Type())}));
2434 auto res0 = extractSrs1024(adaptor.getSource(), index0Cst);
2435 auto res1 = extractSrs1024(adaptor.getSource(), index1Cst);
// Both partial results are cast to vector<16xi32> before being joined.
2439 auto v16i32Res0 = forceCastValueToType(
2440 rewriter, loc, res0, VectorType::get({16}, rewriter.getI32Type()));
2441 auto v16i32Res1 = forceCastValueToType(
2442 rewriter, loc, res1, VectorType::get({16}, rewriter.getI32Type()));
// Identity mask 0..31: the shuffle concatenates the two 16-lane values.
2444 SmallVector<int64_t> concatMask;
2445 for (
int i = 0; i < 32; ++i) {
2446 concatMask.push_back(i);
2448 srsIntrOp = vector::ShuffleOp::create(rewriter, loc, v16i32Res0,
2449 v16i32Res1, concatMask);
// Warning attached to the op for cases that are not handled.
2454 op.emitWarning() <<
"aievec.srs is not supported.\n";
// Cast the result to the op's declared result type, then replace the op
// with it.
2459 auto resultVal = forceCastValueToType(rewriter, loc, srsIntrOp,
2460 op.getResult().getType());
2461 rewriter.replaceOp(op, resultVal);
// Lowering of aievec::UPDOp. Results of at most 256 bits become a plain
// LLVM load; wider results load a half-size vector and pass it, together with
// a destination vector, to an "llvm.aie.upd.*" function call.
2469 using ConvertOpToLLVMPattern<aievec::UPDOp>::ConvertOpToLLVMPattern;
// Intrinsic-name construction: "llvm.aie.upd." + size letter + "." + ... +
// "lo"/"hi".
2472 auto resultType = cast<VectorType>(op.getResult().getType());
2473 std::stringstream ss;
2474 ss <<
"llvm.aie.upd.";
// Size letter: 'v' for a 128-bit load, 'w' for a 256-bit load, 'x' otherwise.
2475 ss << (loadSize == 128 ?
'v' : loadSize == 256 ?
'w' :
'x') <<
".";
// Index 0 selects the "lo" variant; any other index selects "hi".
2478 ss << (op.getIndex() == 0 ?
"lo" :
"hi");
2484 ConversionPatternRewriter &rewriter)
const override {
2485 auto module = op->getParentOfType<ModuleOp>();
2486 MLIRContext *context = rewriter.getContext();
// Address of the element selected by the op's indices in the source memref.
2494 auto ptr = this->getStridedElementPtr(
2495 rewriter, op->getLoc(), cast<MemRefType>(op.getSource().getType()),
2496 adaptor.getSource(), adaptor.getIndices());
// Small vectors: cast the pointer (keeping the memref's memory space) and
// replace the op with a direct load of the result type.
2500 if (vecSizeInBits <= 256) {
2504 auto vectorPtrType = LLVM::LLVMPointerType::get(
2506 cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
2508 LLVM::BitcastOp::create(rewriter, op->getLoc(), vectorPtrType, ptr);
2509 auto vecType = cast<VectorType>(op.getResult().getType());
// NOTE(review): the trailing 1 is presumably the load alignment — confirm
// against the LLVM::LoadOp builder.
2510 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, vecType, castedPtr, 1);
// Larger vectors: the loaded chunk is 128 bits for a 256-bit result and
// 256 bits for a 512-bit result.
2518 int loadSize = vecSizeInBits == 256 ? 128
2519 : vecSizeInBits == 512 ? 256
2524 auto resultType = cast<VectorType>(op.getResult().getType());
// The loaded vector has half as many lanes as the result.
2527 VectorType::get({(int64_t)lanes / 2}, resultType.getElementType());
2530 auto vectorPtrType = LLVM::LLVMPointerType::get(
2532 cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
2534 LLVM::BitcastOp::create(rewriter, op->getLoc(), vectorPtrType, ptr);
2536 LLVM::LoadOp::create(rewriter, op->getLoc(), loadType, castedPtr, 1);
// Look up the upd function in the module; the declaration created below is
// inserted at the start of the module body with signature
// resultType(resultType, loadType).
2542 auto func =
module.lookupSymbol<LLVM::LLVMFuncOp>(
2543 StringAttr::get(context, intrinsicName));
2546 OpBuilder::InsertionGuard guard(rewriter);
2547 rewriter.setInsertionPointToStart(module.getBody());
2548 func = LLVM::LLVMFuncOp::create(
2549 rewriter, rewriter.getUnknownLoc(), intrinsicName,
2550 LLVM::LLVMFunctionType::get(resultType, {resultType, loadType}));
// Destination operand: the op's incoming vector when one is provided.
2555 if (adaptor.getVector()) {
2557 destValue = adaptor.getVector();
// Otherwise a second, zero-argument function returning resultType is looked
// up / declared and called.
// NOTE(review): presumably this yields an undefined placeholder vector — the
// name written into `ss` is not visible here; confirm.
2566 std::stringstream ss;
2568 std::string intrinsicName = ss.str();
2570 auto func =
module.lookupSymbol<LLVM::LLVMFuncOp>(
2571 StringAttr::get(rewriter.getContext(), intrinsicName));
2574 OpBuilder::InsertionGuard guard(rewriter);
2575 rewriter.setInsertionPointToStart(module.getBody());
2576 func = LLVM::LLVMFuncOp::create(
2577 rewriter, rewriter.getUnknownLoc(), intrinsicName,
2578 LLVM::LLVMFunctionType::get(resultType, {}));
2581 LLVM::CallOp::create(rewriter, op->getLoc(), func, ValueRange{})
// Replace the op with the call upd(destValue, loadValue).
2586 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
2587 op, func, ValueRange{destValue, loadValue});
2595 :
public mlir::ConvertOpToLLVMPattern<aievec::ConcatOp> {
// Lowering of aievec::ConcatOp to the xllvm concat intrinsics. Supported
// shapes (source bits -> result bits): 256 -> 512, 256 -> 1024, 512 -> 1024.
// Anything else emits a warning.
2597 using ConvertOpToLLVMPattern<aievec::ConcatOp>::ConvertOpToLLVMPattern;
2601 ConversionPatternRewriter &rewriter)
const override {
2602 Location loc = op.getLoc();
// The first source is used as the representative for the source type/size.
2604 SmallVector<Value> sources = adaptor.getSources();
2605 Value src = sources.front();
2606 VectorType srcType = cast<VectorType>(src.getType());
2607 Type srcScalarType = srcType.getElementType();
2608 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
// Total source size in bits (element width times lane count).
2610 int srcVectorSize = srcBitWidth * srcLanes;
2612 Value result = op.getResult();
2613 VectorType resultType = cast<VectorType>(result.getType());
2614 Type resultScaTy = resultType.getElementType();
2615 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
// Total result size in bits.
2617 int resultVectorSize = resultBitWidth * resultLanes;
// Only two- and four-operand concatenations are handled.
2619 if (sources.size() != 2 && sources.size() != 4) {
2620 op.emitWarning() <<
"aievec.concat with " << sources.size()
2621 <<
" operands is not supported.\n";
2626 Value concatOp =
nullptr;
// Operands are force-cast to i32 vectors matching each intrinsic signature.
// 2 x 256-bit -> 512-bit.
2627 if (srcVectorSize == 256 && resultVectorSize == 512) {
2628 concatOp = xllvm::ConcatI512I256IntrOp::create(
2629 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
2630 forceCastOperandsToSignature(
2631 rewriter, loc, adaptor.getSources(),
2632 {VectorType::get({8}, rewriter.getI32Type()),
2633 VectorType::get({8}, rewriter.getI32Type())}));
2634 }
// 4 x 256-bit -> 1024-bit.
else if (srcVectorSize == 256 && resultVectorSize == 1024) {
2635 concatOp = xllvm::ConcatI1024I256IntrOp::create(
2636 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
2637 forceCastOperandsToSignature(
2638 rewriter, loc, adaptor.getSources(),
2639 {VectorType::get({8}, rewriter.getI32Type()),
2640 VectorType::get({8}, rewriter.getI32Type()),
2641 VectorType::get({8}, rewriter.getI32Type()),
2642 VectorType::get({8}, rewriter.getI32Type())}));
2643 }
// 2 x 512-bit -> 1024-bit.
else if (srcVectorSize == 512 && resultVectorSize == 1024) {
2644 concatOp = xllvm::ConcatI1024I512IntrOp::create(
2645 rewriter, loc, VectorType::get({32}, rewriter.getI32Type()),
2646 forceCastOperandsToSignature(
2647 rewriter, loc, adaptor.getSources(),
2648 {VectorType::get({16}, rewriter.getI32Type()),
2649 VectorType::get({16}, rewriter.getI32Type())}));
2651 op.emitWarning() <<
"aievec.concat with " << srcVectorSize
2652 <<
"-bit operands, and " << resultVectorSize
2653 <<
"-bit result is not supported.\n";
// Cast the i32-vector intrinsic result back to the op's declared result type.
2659 forceCastValueToType(rewriter, loc, concatOp, op.getResult().getType());
2660 rewriter.replaceOp(op, resultVal);
// Lowering of aievec::ExtOp to the xllvm extract intrinsics. Supported shapes
// (source bits -> result bits): 512 -> 256, 1024 -> 512, 1024 -> 256 and
// 512 -> 128. Anything else emits a warning.
2668 using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;
2672 ConversionPatternRewriter &rewriter)
const override {
2673 Location loc = op.getLoc();
2675 Value src = adaptor.getSource();
2676 VectorType srcType = cast<VectorType>(src.getType());
2677 Type srcScalarType = srcType.getElementType();
2678 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
// Total source size in bits.
2680 int srcVectorSize = srcBitWidth * srcLanes;
2682 Value result = op.getResult();
2683 VectorType resultType = cast<VectorType>(result.getType());
2684 Type resultScaTy = resultType.getElementType();
2685 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
// Total result size in bits.
2687 int resultVectorSize = resultBitWidth * resultLanes;
// The op's index attribute is materialized as an i32 constant operand.
2691 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2692 rewriter.getI32IntegerAttr(op.getIndex()));
2695 SmallVector<Value> operands({adaptor.getSource(), indexCst});
2696 Value extOp =
nullptr;
// 512-bit -> 256-bit.
2698 if (resultVectorSize == 256 && srcVectorSize == 512) {
2699 extOp = xllvm::ExtI256I512IntrOp::create(
2700 rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
2701 forceCastOperandsToSignature(
2702 rewriter, loc, operands,
2703 {VectorType::get({16}, rewriter.getI32Type()),
2704 rewriter.getI32Type()}));
2705 }
// 1024-bit -> 512-bit.
else if (resultVectorSize == 512 && srcVectorSize == 1024) {
2706 extOp = xllvm::ExtI512I1024IntrOp::create(
2707 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
2708 forceCastOperandsToSignature(
2709 rewriter, loc, operands,
2710 {VectorType::get({32}, rewriter.getI32Type()),
2711 rewriter.getI32Type()}));
2712 }
// 1024-bit -> 256-bit.
else if (resultVectorSize == 256 && srcVectorSize == 1024) {
2713 extOp = xllvm::ExtI256I1024IntrOp::create(
2714 rewriter, loc, VectorType::get({8}, rewriter.getI32Type()),
2715 forceCastOperandsToSignature(
2716 rewriter, loc, operands,
2717 {VectorType::get({32}, rewriter.getI32Type()),
2718 rewriter.getI32Type()}));
2719 }
// 512-bit -> 128-bit: the 128-bit extract intrinsic takes no index operand,
// so for a non-zero index the source is first shifted.
else if (resultVectorSize == 128 && srcVectorSize == 512) {
2720 auto shiftOp = adaptor.getSource();
2721 if (op.getIndex() > 0) {
2722 auto undefOp = xllvm::UndefV16I32IntrOp::create(
2723 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()));
2725 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
2726 rewriter.getI32IntegerAttr(0));
// NOTE(review): index * 16 is presumably a byte shift (128 bits = 16 bytes)
// — confirm against the shift intrinsic's definition.
2727 auto shiftCst = LLVM::ConstantOp::create(
2728 rewriter, loc, rewriter.getI32Type(),
2729 rewriter.getI32IntegerAttr(op.getIndex() * 16));
2730 SmallVector<Value> shiftOperands{adaptor.getSource(), undefOp, stepCst,
2734 shiftOp = xllvm::VectorShiftI512I512IntrOp::create(
2735 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
2736 forceCastOperandsToSignature(
2737 rewriter, loc, shiftOperands,
2738 {VectorType::get({16}, rewriter.getI32Type()),
2739 VectorType::get({16}, rewriter.getI32Type()),
2740 rewriter.getI32Type(), rewriter.getI32Type()}));
// Extract 128 bits from the (possibly shifted) 512-bit vector.
2744 extOp = xllvm::ExtI128I512IntrOp::create(
2745 rewriter, loc, VectorType::get({4}, rewriter.getI32Type()),
2746 forceCastOperandsToSignature(
2747 rewriter, loc, {shiftOp},
2748 {VectorType::get({16}, rewriter.getI32Type())}));
2750 op.emitWarning() <<
"aievec.ext with " << srcVectorSize
2751 <<
"-bit source, and " << resultVectorSize
2752 <<
"-bit result is not supported.\n";
// Cast the i32-vector intrinsic result back to the op's declared result type.
2758 forceCastValueToType(rewriter, loc, extOp, op.getResult().getType());
2759 rewriter.replaceOp(op, resultVal);
2767 :
public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
// AIE2p lowering of aievec::ExtOp: extracts one half of the source vector
// with a vector.shuffle instead of an intrinsic call.
2769 using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;
2773 ConversionPatternRewriter &rewriter)
const override {
2774 Location loc = op.getLoc();
2776 Value src = adaptor.getSource();
2777 VectorType srcType = cast<VectorType>(src.getType());
2778 VectorType resultType = cast<VectorType>(op.getResult().getType());
// Only half-vector extraction (source has twice the result's lanes) is
// handled; other ratios emit a warning.
2784 if (srcLanes != 2 * resultLanes) {
2785 op.emitWarning() <<
"aievec.ext with non-half extraction is not "
2786 "supported for AIE2p.\n";
// Mask of resultLanes consecutive lane indices starting at
// index * resultLanes.
2793 SmallVector<int64_t> shuffleMask;
2794 int startIdx = op.getIndex() * resultLanes;
2795 for (
int i = 0; i < resultLanes; ++i) {
2796 shuffleMask.push_back(startIdx + i);
// The source is passed as both shuffle inputs; the mask built above only
// holds indices below srcLanes when index is 0 or 1.
2801 vector::ShuffleOp::create(rewriter, loc, src, src, shuffleMask);
2803 rewriter.replaceOp(op, extracted);
2809 :
public mlir::ConvertOpToLLVMPattern<aievec::aie1::SelectOp> {
// Lowering of aievec::aie1::SelectOp to a call of a named LLVM function. The
// op's select/start/offsets/square attributes are parsed to integers and
// passed as constant operands.
2811 using ConvertOpToLLVMPattern<aievec::aie1::SelectOp>::ConvertOpToLLVMPattern;
2814 auto xbuffType = cast<VectorType>(op.getXbuff().getType());
2815 std::stringstream ss;
2822 ConversionPatternRewriter &rewriter)
const override {
2823 auto module = op->getParentOfType<ModuleOp>();
2824 MLIRContext *context = rewriter.getContext();
// Operand types of the call: i32 select, i32 starts, <2 x i32> offsets and
// <2 x i32> configuration.
2826 auto selectType = IntegerType::get(context, 32);
2827 auto startType = IntegerType::get(context, 32);
2828 auto offsetsType = VectorType::get({2}, IntegerType::get(context, 32));
2829 auto confType = VectorType::get({2}, IntegerType::get(context, 32));
// Look up the function by name; the declaration created below is inserted at
// the start of the module body.
2832 std::string intrinsicName = getIntrinsicName(op);
2833 auto func =
module.lookupSymbol<LLVM::LLVMFuncOp>(
2834 StringAttr::get(context, intrinsicName));
2837 OpBuilder::InsertionGuard guard(rewriter);
2838 rewriter.setInsertionPointToStart(module.getBody());
2839 func = LLVM::LLVMFuncOp::create(
2840 rewriter, rewriter.getUnknownLoc(), intrinsicName,
2841 LLVM::LLVMFunctionType::get(op.getResult().getType(),
2842 {op.getXbuff().getType(), selectType,
2851 uint32_t select = 0;
// Parse each attribute into its integer field.
// NOTE(review): assumes these accessors return llvm::StringRef, where radix 0
// auto-detects the base; the getAsInteger status results are ignored —
// confirm that is intended.
2856 op.getSelect().getAsInteger(0, select);
2857 op.getXstart().getAsInteger(0, x.
start);
2858 op.getXoffsets().getAsInteger(0, x.
offsets);
2859 op.getXoffsetsHi().getAsInteger(0, x.
offsets_hi);
2860 op.getXsquare().getAsInteger(0, x.
square);
2861 op.getYstart().getAsInteger(0, y.
start);
2862 op.getYoffsets().getAsInteger(0, y.
offsets);
2863 op.getYoffsetsHi().getAsInteger(0, y.
offsets_hi);
2864 op.getYsquare().getAsInteger(0, y.
square);
// Two-word configuration, zero-initialized here.
2867 uint32_t conf[2] = {0, 0};
// Materialize the parsed values as LLVM constants.
2872 auto selectVal = LLVM::ConstantOp::create(
2873 rewriter, op->getLoc(), selectType, rewriter.getI32IntegerAttr(select));
2874 auto xstartVal = LLVM::ConstantOp::create(
2875 rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(x.
start));
2876 auto ystartVal = LLVM::ConstantOp::create(
2877 rewriter, op->getLoc(), startType, rewriter.getI32IntegerAttr(y.
start));
// Offsets are packed as {offsets, offsets_hi} in a two-element i32 vector.
2878 auto xoffsetsVal = LLVM::ConstantOp::create(
2879 rewriter, op->getLoc(), offsetsType,
2880 rewriter.getI32VectorAttr({(int32_t)x.offsets, (int32_t)x.offsets_hi}));
2881 auto yoffsetsVal = LLVM::ConstantOp::create(
2882 rewriter, op->getLoc(), offsetsType,
2883 rewriter.getI32VectorAttr({(int32_t)y.offsets, (int32_t)y.offsets_hi}));
2884 auto confVal = LLVM::ConstantOp::create(
2885 rewriter, op->getLoc(), confType,
2886 rewriter.getI32VectorAttr({(int32_t)conf[0], (int32_t)conf[1]}));
// NOTE(review): the call uses op.getXbuff() (the original operand) rather
// than an adaptor value — confirm this is intended.
2887 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
2889 ValueRange{op.getXbuff(), selectVal, xstartVal, ystartVal, xoffsetsVal,
2890 yoffsetsVal, confVal});
// Lowering of aievec::PackOp to a call of a named LLVM function whose
// signature is resultType(sourceType).
2897 using ConvertOpToLLVMPattern<aievec::PackOp>::ConvertOpToLLVMPattern;
2900 auto sourceType = cast<VectorType>(op.getSource().getType());
2901 std::stringstream ss;
2908 ConversionPatternRewriter &rewriter)
const override {
2909 auto module = op->getParentOfType<ModuleOp>();
2910 MLIRContext *context = rewriter.getContext();
// Look up the function by name; the declaration created below is inserted at
// the start of the module body.
2913 std::string intrinsicName = getIntrinsicName(op);
2914 auto func =
module.lookupSymbol<LLVM::LLVMFuncOp>(
2915 StringAttr::get(context, intrinsicName));
2918 OpBuilder::InsertionGuard guard(rewriter);
2919 rewriter.setInsertionPointToStart(module.getBody());
2920 func = LLVM::LLVMFuncOp::create(
2921 rewriter, rewriter.getUnknownLoc(), intrinsicName,
2922 LLVM::LLVMFunctionType::get(op.getResult().getType(),
2923 {op.getSource().getType()}));
// NOTE(review): the call uses op.getSource() (the original operand) rather
// than an adaptor value — confirm this is intended.
2926 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, func,
2927 ValueRange{op.getSource()});
2933 :
public mlir::ConvertOpToLLVMPattern<aievec::UnpackOp> {
// Placeholder pattern for aievec::UnpackOp: the visible body only emits a
// "not implemented" warning on the op.
2935 using ConvertOpToLLVMPattern<aievec::UnpackOp>::ConvertOpToLLVMPattern;
2939 ConversionPatternRewriter &rewriter)
const override {
2940 op.emitWarning() <<
"aie.unpack conversion is not implemented\n";
2946 :
public mlir::ConvertOpToLLVMPattern<aievec::BroadcastOp> {
// Placeholder pattern for aievec::BroadcastOp: the visible body only emits a
// "not implemented" warning on the op.
2948 using ConvertOpToLLVMPattern<aievec::BroadcastOp>::ConvertOpToLLVMPattern;
2952 ConversionPatternRewriter &rewriter)
const override {
2953 op.emitWarning() <<
"aie.broadcast conversion is not implemented\n";
// Widens `vec` from srcLanes to dstLanes lanes with a vector.shuffle. The
// first srcLanes mask entries select lanes 0..srcLanes-1 of `vec`; the
// remaining entries are -1 (the poison marker in a vector.shuffle mask, as the
// function name indicates), so the upper lanes carry no defined value.
2960static Value padVectorWithPoison(ConversionPatternRewriter &rewriter,
2961 Location loc, Value vec,
int srcLanes,
2963 SmallVector<int64_t> padMask;
// Identity part: keep the original lanes in place.
2964 for (
int i = 0; i < srcLanes; ++i)
2965 padMask.push_back(i);
// Padding part: one -1 entry per extra lane.
2966 for (
int i = srcLanes; i < dstLanes; ++i)
2967 padMask.push_back(-1);
// `vec` is passed as both shuffle inputs; only indices below srcLanes are
// referenced.
2968 return vector::ShuffleOp::create(rewriter, loc, vec, vec, padMask);
// Returns a vector.shuffle of `vec` that keeps only its first `lanes` lanes
// (mask 0..lanes-1).
2973static Value extractLowerLanes(ConversionPatternRewriter &rewriter,
2974 Location loc, Value vec,
int lanes) {
2975 SmallVector<int64_t> extractMask;
2976 for (
int i = 0; i < lanes; ++i)
2977 extractMask.push_back(i);
// `vec` is passed as both shuffle inputs; only indices below `lanes` are
// referenced.
2978 return vector::ShuffleOp::create(rewriter, loc, vec, vec, extractMask);
// Lowering of aievec::MaxOp to the xllvm "max/lt" intrinsics. Each intrinsic
// returns a struct whose first member is a vector and whose second member is
// an integer (or <2 x i32> for 8-bit elements); one member is extracted as
// the op's replacement value.
2983 using ConvertOpToLLVMPattern<aievec::MaxOp>::ConvertOpToLLVMPattern;
2987 ConversionPatternRewriter &rewriter)
const override {
2988 Location loc = op.getLoc();
2990 VectorType resultType = cast<VectorType>(op.getResult().getType());
2991 Type resultScaTy = resultType.getElementType();
2992 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
// Total result size in bits.
2994 int resultVectorSize = resultBitWidth * resultLanes;
// Only 512-bit and 256-bit results are handled.
2997 if (resultVectorSize != 512 && resultVectorSize != 256) {
2998 op.emitWarning() <<
"aievec.max conversion with " << resultVectorSize
2999 <<
"-bit result is not supported.\n";
3004 Value maxOp =
nullptr;
// Integer element types: dispatch on element width (8, 16 or 32 bits).
3005 if (llvm::isa<IntegerType>(resultScaTy)) {
// NOTE(review): the constant 1 is passed as the third intrinsic operand —
// presumably a signedness/compare-mode flag; confirm against the intrinsic
// definition.
3008 auto cmpCst = LLVM::ConstantOp::create(
3009 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
3010 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
3011 if (resultBitWidth == 8) {
3012 maxOp = xllvm::VectorMaxLt8IntrOp::create(
3014 mlir::LLVM::LLVMStructType::getLiteral(
3015 rewriter.getContext(),
3016 {VectorType::get({64}, rewriter.getI8Type()),
3017 VectorType::get({2}, rewriter.getI32Type())}),
3018 forceCastOperandsToSignature(
3019 rewriter, loc, operands,
3020 {VectorType::get({64}, rewriter.getI8Type()),
3021 VectorType::get({64}, rewriter.getI8Type()),
3022 rewriter.getI32Type()}));
3023 }
else if (resultBitWidth == 16) {
3024 maxOp = xllvm::VectorMaxLt16IntrOp::create(
3026 mlir::LLVM::LLVMStructType::getLiteral(
3027 rewriter.getContext(),
3028 {VectorType::get({32}, rewriter.getI16Type()),
3029 rewriter.getI32Type()}),
3030 forceCastOperandsToSignature(
3031 rewriter, loc, operands,
3032 {VectorType::get({32}, rewriter.getI16Type()),
3033 VectorType::get({32}, rewriter.getI16Type()),
3034 rewriter.getI32Type()}));
3035 }
else if (resultBitWidth == 32) {
3036 maxOp = xllvm::VectorMaxLt32IntrOp::create(
3038 mlir::LLVM::LLVMStructType::getLiteral(
3039 rewriter.getContext(),
3040 {VectorType::get({16}, rewriter.getI32Type()),
3041 rewriter.getI32Type()}),
3042 forceCastOperandsToSignature(
3043 rewriter, loc, operands,
3044 {VectorType::get({16}, rewriter.getI32Type()),
3045 VectorType::get({16}, rewriter.getI32Type()),
3046 rewriter.getI32Type()}));
// 16-bit elements on this path are handled as bf16 with a fixed 32-lane
// intrinsic.
3049 if (resultBitWidth == 16) {
3050 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
3051 Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
// 16-lane inputs are padded to 32 lanes before the call.
3054 if (resultLanes == 16) {
3055 lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
3056 rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
3059 maxOp = xllvm::VectorMaxLtBf16IntrOp::create(
3061 mlir::LLVM::LLVMStructType::getLiteral(
3062 rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()}),
3063 forceCastOperandsToSignature(rewriter, loc, {lhs, rhs},
3064 {v32bf16Ty, v32bf16Ty}));
3069 op.emitWarning() <<
"aievec.max conversion fails due to unsupported "
3070 "element data type.\n";
// Pull the result out of the returned struct.
3075 Value resultVec = LLVM::ExtractValueOp::create(rewriter, loc, maxOp,
// Undo the 16 -> 32 lane padding applied on the non-integer path.
3078 if (resultLanes == 16 && !llvm::isa<IntegerType>(resultScaTy))
3079 resultVec = extractLowerLanes(rewriter, loc, resultVec, 16);
3081 rewriter.replaceOp(op, resultVec);
// Lowering of aievec::MinOp to the xllvm "min/ge" intrinsics. Each intrinsic
// returns a struct whose first member is a vector and whose second member is
// an integer (or <2 x i32> for 8-bit elements); one member is extracted as
// the op's replacement value.
3089 using ConvertOpToLLVMPattern<aievec::MinOp>::ConvertOpToLLVMPattern;
3093 ConversionPatternRewriter &rewriter)
const override {
3094 Location loc = op.getLoc();
3096 VectorType resultType = cast<VectorType>(op.getResult().getType());
3097 Type resultScaTy = resultType.getElementType();
3098 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
// Total result size in bits.
3100 int resultVectorSize = resultBitWidth * resultLanes;
// Only 512-bit and 256-bit results are handled.
3103 if (resultVectorSize != 512 && resultVectorSize != 256) {
3104 op.emitWarning() <<
"aievec.min conversion with " << resultVectorSize
3105 <<
"-bit result is not supported.\n";
3110 Value minOp =
nullptr;
// Integer element types: dispatch on element width (8, 16 or 32 bits).
3111 if (llvm::isa<IntegerType>(resultScaTy)) {
// NOTE(review): the constant 1 is passed as the third intrinsic operand —
// presumably a signedness/compare-mode flag; confirm against the intrinsic
// definition.
3114 auto cmpCst = LLVM::ConstantOp::create(
3115 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
3116 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
3117 if (resultBitWidth == 8) {
3118 minOp = xllvm::VectorMinGe8IntrOp::create(
3120 mlir::LLVM::LLVMStructType::getLiteral(
3121 rewriter.getContext(),
3122 {VectorType::get({64}, rewriter.getI8Type()),
3123 VectorType::get({2}, rewriter.getI32Type())}),
3124 forceCastOperandsToSignature(
3125 rewriter, loc, operands,
3126 {VectorType::get({64}, rewriter.getI8Type()),
3127 VectorType::get({64}, rewriter.getI8Type()),
3128 rewriter.getI32Type()}));
3129 }
else if (resultBitWidth == 16) {
3130 minOp = xllvm::VectorMinGe16IntrOp::create(
3132 mlir::LLVM::LLVMStructType::getLiteral(
3133 rewriter.getContext(),
3134 {VectorType::get({32}, rewriter.getI16Type()),
3135 rewriter.getI32Type()}),
3136 forceCastOperandsToSignature(
3137 rewriter, loc, operands,
3138 {VectorType::get({32}, rewriter.getI16Type()),
3139 VectorType::get({32}, rewriter.getI16Type()),
3140 rewriter.getI32Type()}));
3141 }
else if (resultBitWidth == 32) {
3142 minOp = xllvm::VectorMinGe32IntrOp::create(
3144 mlir::LLVM::LLVMStructType::getLiteral(
3145 rewriter.getContext(),
3146 {VectorType::get({16}, rewriter.getI32Type()),
3147 rewriter.getI32Type()}),
3148 forceCastOperandsToSignature(
3149 rewriter, loc, operands,
3150 {VectorType::get({16}, rewriter.getI32Type()),
3151 VectorType::get({16}, rewriter.getI32Type()),
3152 rewriter.getI32Type()}));
// 16-bit elements on this path are handled as bf16 with a fixed 32-lane
// intrinsic.
3155 if (resultBitWidth == 16) {
3156 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
3157 Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
// 16-lane inputs are padded to 32 lanes before the call.
3160 if (resultLanes == 16) {
3161 lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
3162 rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
3165 minOp = xllvm::VectorMinGeBf16IntrOp::create(
3167 mlir::LLVM::LLVMStructType::getLiteral(
3168 rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()}),
3169 forceCastOperandsToSignature(rewriter, loc, {lhs, rhs},
3170 {v32bf16Ty, v32bf16Ty}));
3175 op.emitWarning() <<
"aievec.min conversion fails due to unsupported "
3176 "element data type.\n";
// Pull the result out of the returned struct.
3181 Value resultVec = LLVM::ExtractValueOp::create(rewriter, loc, minOp,
// Undo the 16 -> 32 lane padding applied on the non-integer path.
3184 if (resultLanes == 16 && !llvm::isa<IntegerType>(resultScaTy))
3185 resultVec = extractLowerLanes(rewriter, loc, resultVec, 16);
3187 rewriter.replaceOp(op, resultVec);
3195 :
public mlir::ConvertOpToLLVMPattern<aievec::MaxOp> {
// AIE2p lowering of aievec::MaxOp. It has the same structure as the lowering
// that uses the non-AIE2p intrinsics, but calls the *AIE2pIntrOp variants.
// Each intrinsic returns a struct of a vector and an integer (or <2 x i32>
// for 8-bit elements).
3197 using ConvertOpToLLVMPattern<aievec::MaxOp>::ConvertOpToLLVMPattern;
3201 ConversionPatternRewriter &rewriter)
const override {
3202 Location loc = op.getLoc();
3204 VectorType resultType = cast<VectorType>(op.getResult().getType());
3205 Type resultScaTy = resultType.getElementType();
3206 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
// Total result size in bits.
3208 int resultVectorSize = resultBitWidth * resultLanes;
// Only 512-bit and 256-bit results are handled.
3211 if (resultVectorSize != 512 && resultVectorSize != 256) {
3212 op.emitWarning() <<
"aievec.max conversion with " << resultVectorSize
3213 <<
"-bit result is not supported.\n";
3218 Value maxOp =
nullptr;
// Integer element types: dispatch on element width (8, 16 or 32 bits).
3219 if (llvm::isa<IntegerType>(resultScaTy)) {
// NOTE(review): the constant 1 is passed as the third intrinsic operand —
// presumably a signedness/compare-mode flag; confirm against the intrinsic
// definition.
3222 auto cmpCst = LLVM::ConstantOp::create(
3223 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
3224 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
3225 if (resultBitWidth == 8) {
3226 maxOp = xllvm::VectorMaxLt8AIE2pIntrOp::create(
3228 mlir::LLVM::LLVMStructType::getLiteral(
3229 rewriter.getContext(),
3230 {VectorType::get({64}, rewriter.getI8Type()),
3231 VectorType::get({2}, rewriter.getI32Type())}),
3232 forceCastOperandsToSignature(
3233 rewriter, loc, operands,
3234 {VectorType::get({64}, rewriter.getI8Type()),
3235 VectorType::get({64}, rewriter.getI8Type()),
3236 rewriter.getI32Type()}));
3237 }
else if (resultBitWidth == 16) {
3238 maxOp = xllvm::VectorMaxLt16AIE2pIntrOp::create(
3240 mlir::LLVM::LLVMStructType::getLiteral(
3241 rewriter.getContext(),
3242 {VectorType::get({32}, rewriter.getI16Type()),
3243 rewriter.getI32Type()}),
3244 forceCastOperandsToSignature(
3245 rewriter, loc, operands,
3246 {VectorType::get({32}, rewriter.getI16Type()),
3247 VectorType::get({32}, rewriter.getI16Type()),
3248 rewriter.getI32Type()}));
3249 }
else if (resultBitWidth == 32) {
3250 maxOp = xllvm::VectorMaxLt32AIE2pIntrOp::create(
3252 mlir::LLVM::LLVMStructType::getLiteral(
3253 rewriter.getContext(),
3254 {VectorType::get({16}, rewriter.getI32Type()),
3255 rewriter.getI32Type()}),
3256 forceCastOperandsToSignature(
3257 rewriter, loc, operands,
3258 {VectorType::get({16}, rewriter.getI32Type()),
3259 VectorType::get({16}, rewriter.getI32Type()),
3260 rewriter.getI32Type()}));
// 16-bit elements on this path are handled as bf16 with a fixed 32-lane
// intrinsic.
3263 if (resultBitWidth == 16) {
3264 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
3265 Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
// 16-lane inputs are padded to 32 lanes before the call.
3267 if (resultLanes == 16) {
3268 lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
3269 rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
3272 maxOp = xllvm::VectorMaxLtBf16AIE2pIntrOp::create(
3274 mlir::LLVM::LLVMStructType::getLiteral(
3275 rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()}),
3276 forceCastOperandsToSignature(rewriter, loc, {lhs, rhs},
3277 {v32bf16Ty, v32bf16Ty}));
3282 op.emitWarning() <<
"aievec.max conversion fails due to unsupported "
3283 "element data type.\n";
// Pull the result out of the returned struct.
3287 Value resultVec = LLVM::ExtractValueOp::create(rewriter, loc, maxOp,
// Undo the 16 -> 32 lane padding applied on the non-integer path.
3289 if (resultLanes == 16 && !llvm::isa<IntegerType>(resultScaTy))
3290 resultVec = extractLowerLanes(rewriter, loc, resultVec, 16);
3292 rewriter.replaceOp(op, resultVec);
3300 :
public mlir::ConvertOpToLLVMPattern<aievec::MinOp> {
// AIE2p lowering of aievec::MinOp. It has the same structure as the lowering
// that uses the non-AIE2p intrinsics, but calls the *AIE2pIntrOp variants.
// Each intrinsic returns a struct of a vector and an integer (or <2 x i32>
// for 8-bit elements).
3302 using ConvertOpToLLVMPattern<aievec::MinOp>::ConvertOpToLLVMPattern;
3306 ConversionPatternRewriter &rewriter)
const override {
3307 Location loc = op.getLoc();
3309 VectorType resultType = cast<VectorType>(op.getResult().getType());
3310 Type resultScaTy = resultType.getElementType();
3311 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
// Total result size in bits.
3313 int resultVectorSize = resultBitWidth * resultLanes;
// Only 512-bit and 256-bit results are handled.
3316 if (resultVectorSize != 512 && resultVectorSize != 256) {
3317 op.emitWarning() <<
"aievec.min conversion with " << resultVectorSize
3318 <<
"-bit result is not supported.\n";
3323 Value minOp =
nullptr;
// Integer element types: dispatch on element width (8, 16 or 32 bits).
3324 if (llvm::isa<IntegerType>(resultScaTy)) {
// NOTE(review): the constant 1 is passed as the third intrinsic operand —
// presumably a signedness/compare-mode flag; confirm against the intrinsic
// definition.
3327 auto cmpCst = LLVM::ConstantOp::create(
3328 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
3329 SmallVector<Value> operands{adaptor.getLhs(), adaptor.getRhs(), cmpCst};
3330 if (resultBitWidth == 8) {
3331 minOp = xllvm::VectorMinGe8AIE2pIntrOp::create(
3333 mlir::LLVM::LLVMStructType::getLiteral(
3334 rewriter.getContext(),
3335 {VectorType::get({64}, rewriter.getI8Type()),
3336 VectorType::get({2}, rewriter.getI32Type())}),
3337 forceCastOperandsToSignature(
3338 rewriter, loc, operands,
3339 {VectorType::get({64}, rewriter.getI8Type()),
3340 VectorType::get({64}, rewriter.getI8Type()),
3341 rewriter.getI32Type()}));
3342 }
else if (resultBitWidth == 16) {
3343 minOp = xllvm::VectorMinGe16AIE2pIntrOp::create(
3345 mlir::LLVM::LLVMStructType::getLiteral(
3346 rewriter.getContext(),
3347 {VectorType::get({32}, rewriter.getI16Type()),
3348 rewriter.getI32Type()}),
3349 forceCastOperandsToSignature(
3350 rewriter, loc, operands,
3351 {VectorType::get({32}, rewriter.getI16Type()),
3352 VectorType::get({32}, rewriter.getI16Type()),
3353 rewriter.getI32Type()}));
3354 }
else if (resultBitWidth == 32) {
3355 minOp = xllvm::VectorMinGe32AIE2pIntrOp::create(
3357 mlir::LLVM::LLVMStructType::getLiteral(
3358 rewriter.getContext(),
3359 {VectorType::get({16}, rewriter.getI32Type()),
3360 rewriter.getI32Type()}),
3361 forceCastOperandsToSignature(
3362 rewriter, loc, operands,
3363 {VectorType::get({16}, rewriter.getI32Type()),
3364 VectorType::get({16}, rewriter.getI32Type()),
3365 rewriter.getI32Type()}));
// 16-bit elements on this path are handled as bf16 with a fixed 32-lane
// intrinsic.
3368 if (resultBitWidth == 16) {
3369 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
3370 Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
// 16-lane inputs are padded to 32 lanes before the call.
3372 if (resultLanes == 16) {
3373 lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
3374 rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
3377 minOp = xllvm::VectorMinGeBf16AIE2pIntrOp::create(
3379 mlir::LLVM::LLVMStructType::getLiteral(
3380 rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()}),
3381 forceCastOperandsToSignature(rewriter, loc, {lhs, rhs},
3382 {v32bf16Ty, v32bf16Ty}));
3387 op.emitWarning() <<
"aievec.min conversion fails due to unsupported "
3388 "element data type.\n";
// Pull the result out of the returned struct.
3392 Value resultVec = LLVM::ExtractValueOp::create(rewriter, loc, minOp,
// Undo the 16 -> 32 lane padding applied on the non-integer path.
3394 if (resultLanes == 16 && !llvm::isa<IntegerType>(resultScaTy))
3395 resultVec = extractLowerLanes(rewriter, loc, resultVec, 16);
3397 rewriter.replaceOp(op, resultVec);
// Conversion pattern for aievec::CmpOp, parameterized over the max-lt and
// min-ge intrinsic op types used for bf16, 32-bit integer and 16-bit integer
// element types. Each intrinsic is created with a {vector, i32} struct result
// type; field 1 of that struct is extracted and used as the comparison
// bitmask, which replaces the op through an UnrealizedConversionCastOp.
// NOTE(review): parts of this definition (class header, method signature,
// declarations of `lanes`, `bitmask`, `castedLhs`, `castedRhs`, closing
// braces) are not visible in this view.
3406template <
typename MaxLtBf16IntrOp,
typename MinGeBf16IntrOp,
3407 typename MaxLt32IntrOp,
typename MinGe32IntrOp,
3408 typename MaxLt16IntrOp,
typename MinGe16IntrOp>
3411 using ConvertOpToLLVMPattern<aievec::CmpOp>::ConvertOpToLLVMPattern;
3415 ConversionPatternRewriter &rewriter)
const override {
3416 Location loc = op.getLoc();
3417 auto vecTy = cast<VectorType>(op.getLhs().getType());
3418 auto elTy = vecTy.getElementType();
3419 unsigned elWidth = elTy.getIntOrFloatBitWidth();
3421 auto pred = op.getPred();
// --- 16-bit floating-point elements: operands are brought to v32bf16. ---
3424 if (elWidth == 16 && isa<FloatType>(elTy)) {
3425 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
3426 auto structTy = LLVM::LLVMStructType::getLiteral(
3427 rewriter.getContext(), {v32bf16Ty, rewriter.getI32Type()});
3429 Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
// presumably guarded by a 16-lane check (not visible): pad 16 -> 32 lanes.
3431 lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
3432 rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
3433 }
else if (lanes != 32) {
3439 forceCastOperandsToSignature(rewriter, loc, {lhs}, {v32bf16Ty});
3441 forceCastOperandsToSignature(rewriter, loc, {rhs}, {v32bf16Ty});
// Predicate dispatch. "lt"/"ge" use (lhs, rhs); "gt"/"le" reuse the same
// intrinsics with the operands swapped; "eq" is ge(a,b) AND ge(b,a);
// "ne" is lt(a,b) OR lt(b,a).
// NOTE(review): unsigned predicate names take the same branch as the signed
// ones — confirm this is intended.
3443 if (pred ==
"slt" || pred ==
"ult") {
3444 auto intrOp = MaxLtBf16IntrOp::create(
3445 rewriter, loc, structTy, ValueRange{castedLhs[0], castedRhs[0]});
3446 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3447 }
else if (pred ==
"sge" || pred ==
"uge") {
3448 auto intrOp = MinGeBf16IntrOp::create(
3449 rewriter, loc, structTy, ValueRange{castedLhs[0], castedRhs[0]});
3450 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3451 }
else if (pred ==
"sgt" || pred ==
"ugt") {
3453 auto intrOp = MaxLtBf16IntrOp::create(
3454 rewriter, loc, structTy, ValueRange{castedRhs[0], castedLhs[0]});
3455 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3456 }
else if (pred ==
"sle" || pred ==
"ule") {
3458 auto intrOp = MinGeBf16IntrOp::create(
3459 rewriter, loc, structTy, ValueRange{castedRhs[0], castedLhs[0]});
3460 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3461 }
else if (pred ==
"eq") {
3463 auto geAB = MinGeBf16IntrOp::create(
3464 rewriter, loc, structTy, ValueRange{castedLhs[0], castedRhs[0]});
3465 auto geBA = MinGeBf16IntrOp::create(
3466 rewriter, loc, structTy, ValueRange{castedRhs[0], castedLhs[0]});
3467 auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, geAB, 1);
3468 auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, geBA, 1);
3469 bitmask = LLVM::AndOp::create(rewriter, loc, maskAB, maskBA);
3470 }
else if (pred ==
"ne") {
3472 auto ltAB = MaxLtBf16IntrOp::create(
3473 rewriter, loc, structTy, ValueRange{castedLhs[0], castedRhs[0]});
3474 auto ltBA = MaxLtBf16IntrOp::create(
3475 rewriter, loc, structTy, ValueRange{castedRhs[0], castedLhs[0]});
3476 auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, ltAB, 1);
3477 auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, ltBA, 1);
3478 bitmask = LLVM::OrOp::create(rewriter, loc, maskAB, maskBA);
// The bitmask is ANDed with 0xFFFF; the guarding condition is not visible —
// presumably this keeps only the bits of the 16 real lanes after padding.
3486 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
3487 rewriter.getI32IntegerAttr(0xFFFF));
3488 bitmask = LLVM::AndOp::create(rewriter, loc, bitmask, mask);
3492 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3493 op, op.getResult().getType(), bitmask);
// --- 16 x i32: same predicate scheme, with an extra i32 constant (1) passed
// as the third intrinsic operand. ---
3498 if (elWidth == 32 && isa<IntegerType>(elTy) && lanes == 16) {
3499 auto v16i32Ty = VectorType::get({16}, rewriter.getI32Type());
3500 auto structTy = LLVM::LLVMStructType::getLiteral(
3501 rewriter.getContext(), {v16i32Ty, rewriter.getI32Type()});
3502 auto cmpCst = LLVM::ConstantOp::create(
3503 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
3506 if (pred ==
"slt" || pred ==
"ult") {
3507 auto intrOp = MaxLt32IntrOp::create(
3508 rewriter, loc, structTy,
3509 forceCastOperandsToSignature(
3510 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
3511 {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
3512 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3513 }
else if (pred ==
"sge" || pred ==
"uge") {
3514 auto intrOp = MinGe32IntrOp::create(
3515 rewriter, loc, structTy,
3516 forceCastOperandsToSignature(
3517 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
3518 {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
3519 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3520 }
else if (pred ==
"sgt" || pred ==
"ugt") {
3521 auto intrOp = MaxLt32IntrOp::create(
3522 rewriter, loc, structTy,
3523 forceCastOperandsToSignature(
3524 rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
3525 {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
3526 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3527 }
else if (pred ==
"sle" || pred ==
"ule") {
3528 auto intrOp = MinGe32IntrOp::create(
3529 rewriter, loc, structTy,
3530 forceCastOperandsToSignature(
3531 rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
3532 {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
3533 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3534 }
else if (pred ==
"eq") {
3535 auto geAB = MinGe32IntrOp::create(
3536 rewriter, loc, structTy,
3537 forceCastOperandsToSignature(
3538 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
3539 {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
3540 auto geBA = MinGe32IntrOp::create(
3541 rewriter, loc, structTy,
3542 forceCastOperandsToSignature(
3543 rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
3544 {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
3545 auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, geAB, 1);
3546 auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, geBA, 1);
3547 bitmask = LLVM::AndOp::create(rewriter, loc, maskAB, maskBA);
3548 }
else if (pred ==
"ne") {
3549 auto ltAB = MaxLt32IntrOp::create(
3550 rewriter, loc, structTy,
3551 forceCastOperandsToSignature(
3552 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
3553 {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
3554 auto ltBA = MaxLt32IntrOp::create(
3555 rewriter, loc, structTy,
3556 forceCastOperandsToSignature(
3557 rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
3558 {v16i32Ty, v16i32Ty, rewriter.getI32Type()}));
3559 auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, ltAB, 1);
3560 auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, ltBA, 1);
3561 bitmask = LLVM::OrOp::create(rewriter, loc, maskAB, maskBA);
3566 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3567 op, op.getResult().getType(), bitmask);
// --- 32 x i16: identical structure to the i32 case, using the 16-bit
// intrinsic op types. ---
3572 if (elWidth == 16 && isa<IntegerType>(elTy) && lanes == 32) {
3573 auto v32i16Ty = VectorType::get({32}, rewriter.getI16Type());
3574 auto structTy = LLVM::LLVMStructType::getLiteral(
3575 rewriter.getContext(), {v32i16Ty, rewriter.getI32Type()});
3576 auto cmpCst = LLVM::ConstantOp::create(
3577 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
3580 if (pred ==
"slt" || pred ==
"ult") {
3581 auto intrOp = MaxLt16IntrOp::create(
3582 rewriter, loc, structTy,
3583 forceCastOperandsToSignature(
3584 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
3585 {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
3586 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3587 }
else if (pred ==
"sge" || pred ==
"uge") {
3588 auto intrOp = MinGe16IntrOp::create(
3589 rewriter, loc, structTy,
3590 forceCastOperandsToSignature(
3591 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
3592 {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
3593 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3594 }
else if (pred ==
"sgt" || pred ==
"ugt") {
3595 auto intrOp = MaxLt16IntrOp::create(
3596 rewriter, loc, structTy,
3597 forceCastOperandsToSignature(
3598 rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
3599 {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
3600 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3601 }
else if (pred ==
"sle" || pred ==
"ule") {
3602 auto intrOp = MinGe16IntrOp::create(
3603 rewriter, loc, structTy,
3604 forceCastOperandsToSignature(
3605 rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
3606 {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
3607 bitmask = LLVM::ExtractValueOp::create(rewriter, loc, intrOp, 1);
3608 }
else if (pred ==
"eq") {
3609 auto geAB = MinGe16IntrOp::create(
3610 rewriter, loc, structTy,
3611 forceCastOperandsToSignature(
3612 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
3613 {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
3614 auto geBA = MinGe16IntrOp::create(
3615 rewriter, loc, structTy,
3616 forceCastOperandsToSignature(
3617 rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
3618 {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
3619 auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, geAB, 1);
3620 auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, geBA, 1);
3621 bitmask = LLVM::AndOp::create(rewriter, loc, maskAB, maskBA);
3622 }
else if (pred ==
"ne") {
3623 auto ltAB = MaxLt16IntrOp::create(
3624 rewriter, loc, structTy,
3625 forceCastOperandsToSignature(
3626 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), cmpCst},
3627 {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
3628 auto ltBA = MaxLt16IntrOp::create(
3629 rewriter, loc, structTy,
3630 forceCastOperandsToSignature(
3631 rewriter, loc, {adaptor.getRhs(), adaptor.getLhs(), cmpCst},
3632 {v32i16Ty, v32i16Ty, rewriter.getI32Type()}));
3633 auto maskAB = LLVM::ExtractValueOp::create(rewriter, loc, ltAB, 1);
3634 auto maskBA = LLVM::ExtractValueOp::create(rewriter, loc, ltBA, 1);
3635 bitmask = LLVM::OrOp::create(rewriter, loc, maskAB, maskBA);
3640 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3641 op, op.getResult().getType(), bitmask);
// Template-argument lists of two instantiations of the CmpOp conversion
// template: the first uses the plain xllvm intrinsic ops, the second the
// *AIE2p* variants. NOTE(review): the alias names and the opening of each
// declaration are not visible in this view.
3651 xllvm::VectorMinGeBf16IntrOp,
3652 xllvm::VectorMaxLt32IntrOp, xllvm::VectorMinGe32IntrOp,
3653 xllvm::VectorMaxLt16IntrOp, xllvm::VectorMinGe16IntrOp>;
3656 xllvm::VectorMaxLtBf16AIE2pIntrOp, xllvm::VectorMinGeBf16AIE2pIntrOp,
3657 xllvm::VectorMaxLt32AIE2pIntrOp, xllvm::VectorMinGe32AIE2pIntrOp,
3658 xllvm::VectorMaxLt16AIE2pIntrOp, xllvm::VectorMinGe16AIE2pIntrOp>;
// Conversion pattern for aievec::SelOp, parameterized over the 16-bit and
// 32-bit select intrinsic op types. The select mask is brought to i32 and
// passed as the third intrinsic operand.
// NOTE(review): the class header, method signature, the declarations of
// `lanes` and `result`, and several guards are not visible in this view.
3663template <
typename Sel16IntrOp,
typename Sel32IntrOp>
3666 using ConvertOpToLLVMPattern<aievec::SelOp>::ConvertOpToLLVMPattern;
3670 ConversionPatternRewriter &rewriter)
const override {
3671 Location loc = op.getLoc();
3672 auto resultType = cast<VectorType>(op.getResult().getType());
3673 auto elTy = resultType.getElementType();
3674 unsigned elWidth = elTy.getIntOrFloatBitWidth();
3676 auto i32Ty = rewriter.getI32Type();
// If the converted mask is not already i32, wrap it in an
// UnrealizedConversionCastOp to i32.
3679 Value selMask = adaptor.getSel();
3680 if (selMask.getType() != i32Ty)
3682 UnrealizedConversionCastOp::create(rewriter, loc, i32Ty, selMask)
// --- 16-bit float elements: operands are cast to v32bf16, bitcast to
// v32i16, selected with Sel16IntrOp, and bitcast back to v32bf16. ---
3686 if (elWidth == 16 && isa<FloatType>(elTy)) {
3687 auto v32i16Ty = VectorType::get({32}, rewriter.getI16Type());
3688 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
3690 Value lhs = adaptor.getLhs(), rhs = adaptor.getRhs();
3691 bool needExtract =
false;
// presumably guarded by a 16-lane check (not visible): pad 16 -> 32 lanes.
3694 lhs = padVectorWithPoison(rewriter, loc, lhs, 16, 32);
3695 rhs = padVectorWithPoison(rewriter, loc, rhs, 16, 32);
3697 }
else if (lanes != 32) {
3702 auto lhsCast = forceCastValueToType(rewriter, loc, lhs, v32bf16Ty);
3703 auto lhsI16 = LLVM::BitcastOp::create(rewriter, loc, v32i16Ty, lhsCast);
3704 auto rhsCast = forceCastValueToType(rewriter, loc, rhs, v32bf16Ty);
3705 auto rhsI16 = LLVM::BitcastOp::create(rewriter, loc, v32i16Ty, rhsCast);
3707 auto selResult = Sel16IntrOp::create(
3708 rewriter, loc, v32i16Ty,
3709 forceCastOperandsToSignature(rewriter, loc, {lhsI16, rhsI16, selMask},
3710 {v32i16Ty, v32i16Ty, i32Ty}));
3714 LLVM::BitcastOp::create(rewriter, loc, v32bf16Ty, selResult);
// presumably executed only when needExtract is set (guard not visible).
3717 result = extractLowerLanes(rewriter, loc, result, 16);
3719 rewriter.replaceOp(op, result);
// --- 16 x i32: direct Sel32IntrOp. ---
3724 if (elWidth == 32 && isa<IntegerType>(elTy) && lanes == 16) {
3725 auto v16i32Ty = VectorType::get({16}, rewriter.getI32Type());
3726 auto selResult = Sel32IntrOp::create(
3727 rewriter, loc, v16i32Ty,
3728 forceCastOperandsToSignature(
3729 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), selMask},
3730 {v16i32Ty, v16i32Ty, i32Ty}));
3731 rewriter.replaceOp(op, selResult->getResult(0));
// --- 32 x i16: direct Sel16IntrOp. ---
3736 if (elWidth == 16 && isa<IntegerType>(elTy) && lanes == 32) {
3737 auto v32i16Ty = VectorType::get({32}, rewriter.getI16Type());
3738 auto selResult = Sel16IntrOp::create(
3739 rewriter, loc, v32i16Ty,
3740 forceCastOperandsToSignature(
3741 rewriter, loc, {adaptor.getLhs(), adaptor.getRhs(), selMask},
3742 {v32i16Ty, v32i16Ty, i32Ty}));
3743 rewriter.replaceOp(op, selResult->getResult(0));
// Closing template argument of an instantiation that uses the AIE2p 32-bit
// select intrinsic. NOTE(review): the start of this declaration is not
// visible in this view — presumably a SelOp conversion alias.
3755 xllvm::VectorSel32AIE2pIntrOp>;
// Conversion pattern for aievec::BroadcastScalarOp using the 512-bit
// broadcast intrinsics. Only results whose total size is 512 bits are
// handled; other sizes reach the "not implemented" diagnostic below.
// NOTE(review): the class name, method signature, the declaration of
// `resultLanes` and the failure returns are not visible in this view.
3758 :
public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
3760 using ConvertOpToLLVMPattern<
3761 aievec::BroadcastScalarOp>::ConvertOpToLLVMPattern;
3765 ConversionPatternRewriter &rewriter)
const override {
3766 Location loc = op.getLoc();
3768 Value result = op.getResult();
3769 VectorType resultType = cast<VectorType>(result.getType());
3770 Type resultScaTy = resultType.getElementType();
3771 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
3773 int resultVectorSize = resultBitWidth * resultLanes;
3775 if (resultVectorSize != 512) {
3777 <<
"aievec.broadcast_scalar conversion with result vector size "
3778 << resultVectorSize <<
" is not implemented.\n";
// Integer elements: sources narrower than 32 bits are sign-extended to i32,
// then the intrinsic is chosen by the result element width (8/16/32).
3783 if (llvm::isa<IntegerType>(resultScaTy)) {
3784 Value src = adaptor.getSource();
3785 Type srcType = src.getType();
3786 unsigned srcBitWidth = srcType.getIntOrFloatBitWidth();
3788 if (srcBitWidth < 32) {
3789 src = LLVM::SExtOp::create(rewriter, loc, rewriter.getI32Type(), src);
3792 if (resultBitWidth == 8) {
3793 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast8I512IntrOp>(
3794 op, VectorType::get({64}, rewriter.getI8Type()), src);
3795 }
else if (resultBitWidth == 16) {
3796 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast16I512IntrOp>(
3797 op, VectorType::get({32}, rewriter.getI16Type()), src);
3798 }
else if (resultBitWidth == 32) {
3799 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast32I512IntrOp>(
3800 op, VectorType::get({16}, rewriter.getI32Type()), src);
3803 <<
"aievec.broadcast_scalar conversion with result bitwidth "
3804 << resultBitWidth <<
" is not implemented.\n";
// Remaining (presumably floating-point) elements: 16-bit uses the bf16
// broadcast intrinsic directly; 32-bit is bitcast to i32, broadcast with the
// i32 intrinsic, and the vector is bitcast to 16 x f32.
3809 if (resultBitWidth == 16) {
3810 rewriter.replaceOpWithNewOp<xllvm::VectorBroadcast16BF512IntrOp>(
3811 op, VectorType::get({32}, rewriter.getBF16Type()),
3812 adaptor.getSource());
3813 }
else if (resultBitWidth == 32) {
3819 auto srcAsI32 = bitcastValueToType(rewriter, loc, adaptor.getSource(),
3820 rewriter.getI32Type());
3821 auto broadcastI32 = xllvm::VectorBroadcast32I512IntrOp::create(
3822 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
3825 bitcastValueToType(rewriter, loc, broadcastI32,
3826 VectorType::get({16}, rewriter.getF32Type()));
3827 rewriter.replaceOp(op, resultF32);
3830 <<
"aievec.broadcast_scalar conversion with result bitwidth "
3831 << resultBitWidth <<
" is not implemented.\n";
// AIE2p conversion pattern for aievec::BroadcastScalarOp. Instead of a
// broadcast intrinsic it inserts the scalar into lane 0 of a poison vector
// and replicates that lane with a vector shuffle. Result sizes of 256 and
// 512 bits are accepted.
// NOTE(review): the class name, method signature and the declaration of
// `resultLanes` are not visible in this view.
3843 :
public mlir::ConvertOpToLLVMPattern<aievec::BroadcastScalarOp> {
3845 using ConvertOpToLLVMPattern<
3846 aievec::BroadcastScalarOp>::ConvertOpToLLVMPattern;
3850 ConversionPatternRewriter &rewriter)
const override {
3851 Location loc = op.getLoc();
3853 Value result = op.getResult();
3854 VectorType resultType = cast<VectorType>(result.getType());
3855 Type resultScaTy = resultType.getElementType();
3856 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
3858 int resultVectorSize = resultBitWidth * resultLanes;
3861 if (resultVectorSize != 256 && resultVectorSize != 512) {
3863 <<
"aievec.broadcast_scalar conversion with result vector size "
3864 << resultVectorSize <<
" is not implemented for AIE2p.\n";
3868 Value src = adaptor.getSource();
3869 Type srcType = src.getType();
// Integer sources are sign-extended or truncated so that their width matches
// the result element type.
3872 if (llvm::isa<IntegerType>(resultScaTy)) {
3873 unsigned srcBitWidth = srcType.getIntOrFloatBitWidth();
3874 if (srcBitWidth < resultBitWidth) {
3875 src = LLVM::SExtOp::create(rewriter, loc, resultScaTy, src);
3876 }
else if (srcBitWidth > resultBitWidth) {
3877 src = LLVM::TruncOp::create(rewriter, loc, resultScaTy, src);
3882 auto poisonVec = LLVM::PoisonOp::create(rewriter, loc, resultType);
3885 auto idx0 = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
3886 rewriter.getI64IntegerAttr(0));
3887 auto insertedVec = LLVM::InsertElementOp::create(rewriter, loc, resultType,
3888 poisonVec, src, idx0);
// An all-zero shuffle mask selects lane 0 for every result lane.
3892 SmallVector<int64_t> broadcastMask(resultLanes, 0);
3893 auto broadcastVec = vector::ShuffleOp::create(rewriter, loc, insertedVec,
3894 insertedVec, broadcastMask);
3896 rewriter.replaceOp(op, broadcastVec);
// Conversion pattern for aievec::ShiftOp using the 512-bit vector shift
// intrinsics. Only 512-bit results are handled; other sizes emit a warning.
// NOTE(review): the class header, method signature, the declaration of
// `resultLanes` and the else-branch header are not visible in this view.
3903 using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;
3907 ConversionPatternRewriter &rewriter)
const override {
3908 Location loc = op.getLoc();
3910 Value result = op.getResult();
3911 VectorType resultType = cast<VectorType>(result.getType());
3912 Type resultScaTy = resultType.getElementType();
3913 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
3915 int resultVectorSize = resultBitWidth * resultLanes;
3917 if (resultVectorSize != 512) {
3918 op.emitWarning() <<
"aievec.shift conversion with result vector size "
3919 << resultVectorSize <<
" is not implemented.\n";
// The intrinsic takes (lhs, rhs, step, shift); step is the constant 0.
3924 auto stepCst = LLVM::ConstantOp::create(
3925 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
3928 Value shiftOp =
nullptr;
3929 SmallVector<Value> operands(
3930 {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()});
// Integer elements: operands are force-cast to 16 x i32 for the intrinsic,
// regardless of the original element width.
3931 if (llvm::isa<IntegerType>(resultScaTy)) {
3933 shiftOp = xllvm::VectorShiftI512I512IntrOp::create(
3934 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
3935 forceCastOperandsToSignature(
3936 rewriter, loc, operands,
3937 {VectorType::get({16}, rewriter.getI32Type()),
3938 VectorType::get({16}, rewriter.getI32Type()),
3939 rewriter.getI32Type(), rewriter.getI32Type()}));
// presumably the non-integer branch: 32 x bf16 shift intrinsic.
3942 shiftOp = xllvm::VectorShiftBF512BF512IntrOp::create(
3943 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
3944 forceCastOperandsToSignature(
3945 rewriter, loc, operands,
3946 {VectorType::get({32}, rewriter.getBF16Type()),
3947 VectorType::get({32}, rewriter.getBF16Type()),
3948 rewriter.getI32Type(), rewriter.getI32Type()}));
// Cast the intrinsic result back to the op's declared result type.
3953 forceCastValueToType(rewriter, loc, shiftOp, op.getResult().getType());
3954 rewriter.replaceOp(op, resultVal);
// AIE2p conversion pattern for aievec::ShiftOp. Same structure as the
// non-AIE2p shift lowering visible in this file, but creating the *AIE2p*
// shift intrinsic ops. Only 512-bit results are handled.
// NOTE(review): the class name, method signature, the declaration of
// `resultLanes` and the else-branch header are not visible in this view.
3962 :
public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
3964 using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;
3968 ConversionPatternRewriter &rewriter)
const override {
3969 Location loc = op.getLoc();
3971 Value result = op.getResult();
3972 VectorType resultType = cast<VectorType>(result.getType());
3973 Type resultScaTy = resultType.getElementType();
3974 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
3976 int resultVectorSize = resultBitWidth * resultLanes;
3978 if (resultVectorSize != 512) {
3979 op.emitWarning() <<
"aievec.shift conversion with result vector size "
3980 << resultVectorSize <<
" is not implemented.\n";
// The intrinsic takes (lhs, rhs, step, shift); step is the constant 0.
3985 auto stepCst = LLVM::ConstantOp::create(
3986 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
3989 Value shiftOp =
nullptr;
3990 SmallVector<Value> operands(
3991 {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()});
// Integer elements: operands are force-cast to 16 x i32 for the intrinsic.
3992 if (llvm::isa<IntegerType>(resultScaTy)) {
3994 shiftOp = xllvm::VectorShiftI512I512AIE2pIntrOp::create(
3995 rewriter, loc, VectorType::get({16}, rewriter.getI32Type()),
3996 forceCastOperandsToSignature(
3997 rewriter, loc, operands,
3998 {VectorType::get({16}, rewriter.getI32Type()),
3999 VectorType::get({16}, rewriter.getI32Type()),
4000 rewriter.getI32Type(), rewriter.getI32Type()}));
// presumably the non-integer branch: 32 x bf16 shift intrinsic.
4003 shiftOp = xllvm::VectorShiftBF512BF512AIE2pIntrOp::create(
4004 rewriter, loc, VectorType::get({32}, rewriter.getBF16Type()),
4005 forceCastOperandsToSignature(
4006 rewriter, loc, operands,
4007 {VectorType::get({32}, rewriter.getBF16Type()),
4008 VectorType::get({32}, rewriter.getBF16Type()),
4009 rewriter.getI32Type(), rewriter.getI32Type()}));
// Cast the intrinsic result back to the op's declared result type.
4014 forceCastValueToType(rewriter, loc, shiftOp, op.getResult().getType());
4015 rewriter.replaceOp(op, resultVal);
// Conversion pattern for aievec::ExtElemOp. The element is extracted from a
// 512-bit source vector with an intrinsic that always returns an i32; that
// i32 is then narrowed and/or bitcast to the op's scalar result type.
// NOTE(review): the class name, method signature, the declaration of
// `srcLanes` and several else-branch headers are not visible in this view.
4022 :
public mlir::ConvertOpToLLVMPattern<aievec::ExtElemOp> {
4024 using ConvertOpToLLVMPattern<aievec::ExtElemOp>::ConvertOpToLLVMPattern;
4028 ConversionPatternRewriter &rewriter)
const override {
4029 Location loc = op.getLoc();
4031 Type resultType = op.getResult().getType();
4032 unsigned resultBitWidth = resultType.getIntOrFloatBitWidth();
4034 Value src = adaptor.getSource();
4035 VectorType srcType = cast<VectorType>(src.getType());
4036 Type srcScalarType = srcType.getElementType();
4037 unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
4039 int srcVectorSize = srcBitWidth * srcLanes;
4041 if (srcVectorSize != 512) {
4042 op.emitWarning() <<
"aievec.ext_elem conversion with source vector size "
4043 << srcVectorSize <<
" is not supported.\n";
// Third intrinsic operand is the constant 1 (named signCst; presumably
// requests signed extraction — TODO confirm against the intrinsic).
4048 auto signCst = LLVM::ConstantOp::create(
4049 rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
4052 Value extElemOp =
nullptr;
4053 SmallVector<Value> operands(
4054 {adaptor.getSource(), adaptor.getIndex(), signCst});
// Pick the extract intrinsic by result bit width; the source is force-cast
// to the matching integer vector type (64 x i8, 32 x i16 or 16 x i32).
4055 if (resultBitWidth == 8) {
4056 extElemOp = xllvm::VectorExtractElem8I512IntrOp::create(
4057 rewriter, loc, rewriter.getI32Type(),
4058 forceCastOperandsToSignature(
4059 rewriter, loc, operands,
4060 {VectorType::get({64}, rewriter.getI8Type()),
4061 rewriter.getI32Type(), rewriter.getI32Type()}));
4062 }
else if (resultBitWidth == 16) {
4063 extElemOp = xllvm::VectorExtractElem16I512IntrOp::create(
4064 rewriter, loc, rewriter.getI32Type(),
4065 forceCastOperandsToSignature(
4066 rewriter, loc, operands,
4067 {VectorType::get({32}, rewriter.getI16Type()),
4068 rewriter.getI32Type(), rewriter.getI32Type()}));
4069 }
else if (resultBitWidth == 32) {
4070 extElemOp = xllvm::VectorExtractElem32I512IntrOp::create(
4071 rewriter, loc, rewriter.getI32Type(),
4072 forceCastOperandsToSignature(
4073 rewriter, loc, operands,
4074 {VectorType::get({16}, rewriter.getI32Type()),
4075 rewriter.getI32Type(), rewriter.getI32Type()}));
4077 op.emitWarning() <<
"aievec.ext_elem conversion with result bit width "
4078 << resultBitWidth <<
" is not implemented.\n";
// Integer results: widths below 32 are truncated from the i32; widths below
// 16 go through an intermediate i16 truncation first. Other integer widths
// use the i32 value as is.
4083 if (llvm::isa<IntegerType>(resultType)) {
4084 if (resultBitWidth < 32) {
4087 if (resultBitWidth < 16) {
4088 auto i16Ty = rewriter.getI16Type();
4089 auto trunc16 = LLVM::TruncOp::create(rewriter, loc, i16Ty, extElemOp);
4090 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType,
4091 trunc16.getResult());
4093 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType, extElemOp);
4096 rewriter.replaceOp(op, extElemOp);
// presumably the non-integer branch: 16-bit results are truncated to i16
// before the bitcast to the result type.
4100 if (resultBitWidth == 16) {
4101 extElemOp = LLVM::TruncOp::create(rewriter, loc, rewriter.getI16Type(),
4104 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, resultType, extElemOp);
// Conversion pattern for aievec::FMAElemOp. Operands are flattened, widened
// to 32 x bf16 when needed, fed to MacConfBF16IntrOp together with a
// configuration constant, and the result is cast back to the accumulator's
// original type.
// NOTE(review): the class name, method signature, the definitions of
// flatLhsTy/flatRhsTy/flatAccTy, the arguments of
// aiev2_vmac_compute_control and a few trailing call arguments are not
// visible in this view.
4112 :
public mlir::ConvertOpToLLVMPattern<aievec::FMAElemOp> {
4114 using ConvertOpToLLVMPattern<aievec::FMAElemOp>::ConvertOpToLLVMPattern;
4118 ConversionPatternRewriter &rewriter)
const override {
4119 auto loc = fmaOp.getLoc();
4120 auto lhs = adaptor.getLhs();
4121 auto rhs = adaptor.getRhs();
4122 auto acc = adaptor.getAcc();
4123 auto lhsTy = cast<VectorType>(lhs.getType());
4124 auto rhsTy = cast<VectorType>(rhs.getType());
4125 auto accTy = cast<VectorType>(acc.getType());
// Shape-cast each operand to its flattened type when it differs.
4131 if (lhsTy != flatLhsTy)
4132 lhs = vector::ShapeCastOp::create(rewriter, loc, flatLhsTy, lhs);
4133 if (rhsTy != flatRhsTy)
4134 rhs = vector::ShapeCastOp::create(rewriter, loc, flatRhsTy, rhs);
4135 if (accTy != flatAccTy)
4136 acc = vector::ShapeCastOp::create(rewriter, loc, flatAccTy, acc);
4139 Type i32ty = rewriter.getI32Type();
4140 auto confCst = LLVM::ConstantOp::create(
4141 rewriter, loc, i32ty,
4142 rewriter.getI32IntegerAttr(aiev2_vmac_compute_control(
// bf16 inputs with fewer than 32 elements: build a 16 x bf16 zero vector
// (broadcast of 0, bitcast to bf16, ext intrinsic), then widen lhs and rhs
// to 32 x bf16 with the set/upd intrinsics — presumably so the upper half is
// zero rather than undefined (remaining upd arguments not visible).
4152 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
4153 if (flatLhsTy.getElementType().isBF16() &&
4154 flatLhsTy.getNumElements() < 32) {
4155 auto zero32 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4156 rewriter.getI32IntegerAttr(0));
4157 auto zeros_i16 = xllvm::VectorBroadcast16I512IntrOp::create(
4158 rewriter, loc, VectorType::get({32}, rewriter.getI16Type()), zero32);
4160 LLVM::BitcastOp::create(rewriter, loc, v32bf16Ty, zeros_i16);
4161 auto zeroVec = xllvm::ExtBF256BF512IntrOp::create(
4162 rewriter, loc, VectorType::get({16}, rewriter.getBF16Type()),
4163 zeros_bf16, zero32);
4165 auto idx1 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4166 rewriter.getI32IntegerAttr(1));
4168 auto lhsSet = xllvm::VectorSetBF512BF256IntrOp::create(
4169 rewriter, loc, v32bf16Ty, lhs, zero32);
4170 lhs = xllvm::UpdBF512BF256IntrOp::create(rewriter, loc, v32bf16Ty, lhsSet,
4173 auto rhsSet = xllvm::VectorSetBF512BF256IntrOp::create(
4174 rewriter, loc, v32bf16Ty, rhs, zero32);
4175 rhs = xllvm::UpdBF512BF256IntrOp::create(rewriter, loc, v32bf16Ty, rhsSet,
// The MAC intrinsic signature is (v32bf16, v32bf16, v8i64, i32) -> v8i64.
4180 auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type());
4181 auto macIntrOp = xllvm::MacConfBF16IntrOp::create(
4182 rewriter, loc, v8i64Ty,
4183 forceCastOperandsToSignature(rewriter, loc, {lhs, rhs, acc, confCst},
4184 {v32bf16Ty, v32bf16Ty, v8i64Ty, i32ty}));
// Cast back to the flattened accumulator type, then restore its shape.
4188 forceCastValueToType(rewriter, loc, macIntrOp.getResult(), flatAccTy);
4189 if (flatAccTy != accTy)
4190 resVal = vector::ShapeCastOp::create(rewriter, loc, accTy, resVal);
4192 rewriter.replaceOp(fmaOp, resVal);
// Conversion pattern for aievec::MatMulOp. decodeMatMulOp classifies the op
// (Kind I32, I64 or BF16) and picks a configuration word; matchAndRewrite
// flattens the operands and creates the matching MAC intrinsic.
// NOTE(review): the class name, the DecodedMatMulOp fields, the arguments of
// every aiev2_vmac_compute_control call and many closing/else lines are not
// visible in this view.
4198 :
public mlir::ConvertOpToLLVMPattern<aievec::MatMulOp> {
4199 using ConvertOpToLLVMPattern<aievec::MatMulOp>::ConvertOpToLLVMPattern;
4201 struct DecodedMatMulOp {
4202 typedef enum { I32, I64, BF16 } Kind;
// Classifies a matmul from its (adapted) operands and returns the kind, the
// operands to feed the intrinsic, and the configuration value.
4211 static DecodedMatMulOp decodeMatMulOp(OpAdaptor op) {
4212 Value lhs = op.getLhs();
4213 Value rhs = op.getRhs();
4214 Value acc = op.getAcc();
4215 auto accVecTy = cast<VectorType>(acc.getType());
// An f32 accumulator selects the BF16 kind.
4216 if (isa<Float32Type>(accVecTy.getElementType()))
4218 return {DecodedMatMulOp::Kind::BF16, lhs, rhs, acc,
4219 aiev2_vmac_compute_control(
// Walks up through any chain of vector.shape_cast ops.
4228 auto lookThroughShapeCasts = [](Value v) -> Value {
4229 while (
auto castOp = v.getDefiningOp<vector::ShapeCastOp>())
4230 v = castOp.getSource();
// If lhs is produced by arith.extsi / arith.extui (possibly behind shape
// casts), use the pre-extension value and its element type instead.
// presumably signX records the signedness (assignments not visible).
4234 int signX = 0, signY = 0;
4235 auto lhsVecTy = cast<VectorType>(lhs.getType());
4236 auto lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
4237 Value lhsOrig = lookThroughShapeCasts(lhs);
4238 if (
auto extSIOp = lhsOrig.getDefiningOp<arith::ExtSIOp>()) {
4239 lhs = lookThroughShapeCasts(extSIOp.getIn());
4240 lhsVecTy = cast<VectorType>(lhs.getType());
4241 lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
4243 }
else if (
auto extUIOp = lhsOrig.getDefiningOp<arith::ExtUIOp>()) {
4244 lhs = lookThroughShapeCasts(extUIOp.getIn());
4245 lhsVecTy = cast<VectorType>(lhs.getType());
4246 lhsScaTy = cast<IntegerType>(lhsVecTy.getElementType());
4252 if (lhsScaTy.isUnsigned())
4255 auto lhsShape = lhsVecTy.getShape();
// Same treatment for rhs.
// NOTE(review): the lhs check reads `isUnsigned()` while the rhs check reads
// `!isUnsigned()`; the guarded statements are not visible — confirm the
// asymmetry is intended.
4257 auto rhsVecTy = cast<VectorType>(rhs.getType());
4258 auto rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
4259 Value rhsOrig = lookThroughShapeCasts(rhs);
4260 if (
auto extSIOp = rhsOrig.getDefiningOp<arith::ExtSIOp>()) {
4261 rhs = lookThroughShapeCasts(extSIOp.getIn());
4262 rhsVecTy = cast<VectorType>(rhs.getType());
4263 rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
4265 }
else if (
auto extUIOp = rhsOrig.getDefiningOp<arith::ExtUIOp>()) {
4266 rhs = lookThroughShapeCasts(extUIOp.getIn());
4267 rhsVecTy = cast<VectorType>(rhs.getType());
4268 rhsScaTy = cast<IntegerType>(rhsVecTy.getElementType());
4271 if (!rhsScaTy.isUnsigned())
// Select kind and configuration from accumulator / lhs / rhs bit widths
// (and, for 16-bit lhs, from the lhs shape).
4275 unsigned lhsBitWidth = lhsScaTy.getWidth();
4276 unsigned rhsBitWidth = rhsScaTy.getWidth();
4277 auto accScaTy = cast<IntegerType>(accVecTy.getElementType());
4278 unsigned accBitWidth = accScaTy.getWidth();
4279 if (accBitWidth == 32) {
4280 if (lhsBitWidth == 8) {
4281 if (rhsBitWidth == 4) {
4283 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
4284 aiev2_vmac_compute_control(
4292 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
4293 aiev2_vmac_compute_control(
4301 if (rhsBitWidth == 8) {
4303 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
4304 aiev2_vmac_compute_control(
4312 return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc,
4313 aiev2_vmac_compute_control(
4323 if (lhsBitWidth == 16) {
4324 if (rhsBitWidth == 8) {
4325 if (lhsShape == ArrayRef<int64_t>({2, 8})) {
4327 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
4328 aiev2_vmac_compute_control(
4336 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
4337 aiev2_vmac_compute_control(
4343 if (lhsShape == ArrayRef<int64_t>({2, 4})) {
4345 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
4346 aiev2_vmac_compute_control(
4353 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
4354 aiev2_vmac_compute_control(
4361 return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc,
4362 aiev2_vmac_compute_control(
// Rewrites the op: shape-casts lhs/rhs/acc to flattened vector types, builds
// the i32 configuration constant, creates the MAC intrinsic selected by the
// decoded kind, and shape-casts the result back to the op's type.
4370 matchAndRewrite(aievec::MatMulOp op, OpAdaptor adaptor,
4371 ConversionPatternRewriter &rewriter)
const override {
4372 auto decodedMatMulOp = decodeMatMulOp(adaptor);
4374 Location loc = op.getLoc();
4376 auto lhsFlattenedVecTy =
4378 decodedMatMulOp.lhs = vector::ShapeCastOp::create(
4379 rewriter, loc, lhsFlattenedVecTy, decodedMatMulOp.lhs);
4380 auto rhsFlattenedVecTy =
4382 decodedMatMulOp.rhs = vector::ShapeCastOp::create(
4383 rewriter, loc, rhsFlattenedVecTy, decodedMatMulOp.rhs);
4384 auto accFlattenedVecTy =
4386 decodedMatMulOp.acc = vector::ShapeCastOp::create(
4387 rewriter, loc, accFlattenedVecTy, decodedMatMulOp.acc);
4389 Type i32ty = rewriter.getI32Type();
4390 auto confCst = LLVM::ConstantOp::create(
4391 rewriter, loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf));
4392 SmallVector<Value> operands({decodedMatMulOp.lhs, decodedMatMulOp.rhs,
4393 decodedMatMulOp.acc, confCst});
// BF16 kind: (v32bf16, v32bf16, v8i64, i32) -> v8i64.
4395 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::BF16)
4397 xllvm::MacConfBF16IntrOp::create(
4398 rewriter, loc, VectorType::get({8}, rewriter.getI64Type()),
4399 forceCastOperandsToSignature(
4400 rewriter, loc, operands,
4401 {VectorType::get({32}, rewriter.getBF16Type()),
4402 VectorType::get({32}, rewriter.getBF16Type()),
4403 VectorType::get({8}, rewriter.getI64Type()), i32ty}))
// Integer kinds share the signature (v64i8, v16i32, v16i64, i32) -> v16i64;
// the I32 kind uses the Acc32 intrinsic, the other case the Acc64 intrinsic.
4406 SmallVector<Type> intrFuncSig(
4407 {VectorType::get({64}, rewriter.getI8Type()),
4408 VectorType::get({16}, i32ty),
4409 VectorType::get({16}, rewriter.getI64Type()), i32ty});
4410 VectorType v16xi64ty = VectorType::get({16}, rewriter.getI64Type());
4411 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I32)
4412 matMulResVal = xllvm::MacConfAcc32IntrOp::create(
4413 rewriter, loc, v16xi64ty,
4414 forceCastOperandsToSignature(rewriter, loc, operands,
4418 matMulResVal = xllvm::MacConfAcc64IntrOp::create(
4419 rewriter, loc, v16xi64ty,
4420 forceCastOperandsToSignature(rewriter, loc, operands,
4426 bitcastValueToType(rewriter, loc, matMulResVal, accFlattenedVecTy);
4428 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
// Reorders a 64 x bf16 RHS value with two AIE2p vector-shuffle intrinsics and
// converts it to a 64 x f32 value.
//   rewriter  - builder used to create the ops.
//   loc       - location attached to the created ops.
//   i32ty     - i32 type used for shuffle-mode constants and vector types.
//   rhs64bf16 - input value; force-cast to 32 x i32 before shuffling.
// Returns the concatenation (via vector.shuffle) of the two converted
// 32 x f32 halves.
// NOTE(review): that shuffle modes 52/53 implement a transpose is taken from
// the function and variable names — confirm against the intrinsic docs.
4437static Value transposeAndConvertRHS(OpBuilder &rewriter, Location loc,
4438 Type i32ty, Value rhs64bf16) {
4439 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
4443 auto rhs64i32 = forceCastValueToType(
4444 rewriter, loc, rhs64bf16, VectorType::get({32}, rewriter.getI32Type()));
// Split the 32 x i32 view into lanes 0..15 and lanes 16..31.
4447 SmallVector<int64_t> chunk0Mask, chunk1Mask;
4448 for (
int i = 0; i < 16; ++i) {
4449 chunk0Mask.push_back(i);
4450 chunk1Mask.push_back(16 + i);
4453 vector::ShuffleOp::create(rewriter, loc, rhs64i32, rhs64i32, chunk0Mask);
4455 vector::ShuffleOp::create(rewriter, loc, rhs64i32, rhs64i32, chunk1Mask);
// Run the shuffle intrinsic twice on the same two chunks, once per mode.
4458 auto shuffleMode52 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4459 rewriter.getI32IntegerAttr(52));
4460 auto shuffleMode53 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4461 rewriter.getI32IntegerAttr(53));
4463 auto shuffled52 = xllvm::VectorShuffleAIE2pIntrOp::create(
4464 rewriter, loc, VectorType::get({16}, i32ty), rhs16i32_0, rhs16i32_1,
4466 auto shuffled53 = xllvm::VectorShuffleAIE2pIntrOp::create(
4467 rewriter, loc, VectorType::get({16}, i32ty), rhs16i32_0, rhs16i32_1,
// Identity mask 0..31 concatenates the two 16-lane results.
4471 SmallVector<int64_t> transposeConcatMask;
4472 for (
int i = 0; i < 32; ++i)
4473 transposeConcatMask.push_back(i);
4474 auto rhsTransposedI32 = vector::ShuffleOp::create(
4475 rewriter, loc, shuffled52, shuffled53, transposeConcatMask);
4476 auto rhsTransposedBF16 =
4477 forceCastValueToType(rewriter, loc, rhsTransposedI32,
4478 VectorType::get({64}, rewriter.getBF16Type()));
// Split the 64 x bf16 value into halves; each half is converted to 32 x f32.
4481 SmallVector<int64_t> firstHalfMask, secondHalfMask;
4482 for (
int i = 0; i < 32; ++i) {
4483 firstHalfMask.push_back(i);
4484 secondHalfMask.push_back(32 + i);
4487 auto rhsT32bf16_lo = vector::ShuffleOp::create(
4488 rewriter, loc, rhsTransposedBF16, rhsTransposedBF16, firstHalfMask);
4489 auto rhsT32bf16_hi = vector::ShuffleOp::create(
4490 rewriter, loc, rhsTransposedBF16, rhsTransposedBF16, secondHalfMask);
4492 auto rhsT32f32_lo = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4493 rewriter, loc, v32f32Ty, rhsT32bf16_lo);
4494 auto rhsT32f32_hi = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4495 rewriter, loc, v32f32Ty, rhsT32bf16_hi);
// Identity mask 0..63 concatenates the low and high f32 halves.
4498 SmallVector<int64_t> concatMask;
4499 for (
int i = 0; i < 64; ++i)
4500 concatMask.push_back(i);
4501 return vector::ShuffleOp::create(rewriter, loc, rhsT32f32_lo, rhsT32f32_hi,
// Emits a block-floating-point multiply-accumulate: both f32 inputs are
// converted with the Vector64AccFloatToV64BFP16EBS8 AIE2p intrinsic, the two
// fields of each resulting struct are extracted, and everything is fed to the
// MacConfBFP576ACC2048 AIE2p intrinsic together with `acc64i32`. The result
// type of that intrinsic is <64 x i32>.
// NOTE(review): the last parameter of this function and the trailing operand
// of the MAC intrinsic are on lines not visible in this excerpt.
4508static Value performBFP16_8x8MatMul(OpBuilder &rewriter, Location loc,
4509 Type i32ty, Value lhs64f32,
4510 Value rhs64f32Transposed, Value acc64i32,
4512 auto v64i32Ty = VectorType::get({64}, rewriter.getI32Type());
// Literal LLVM struct { <64 x i8>, <8 x i8> } returned by the conversion
// intrinsic; field 0 is named "data" and field 1 "exp" by the locals below.
4515 auto bfpStructTy = mlir::LLVM::LLVMStructType::getLiteral(
4516 rewriter.getContext(), {VectorType::get({64}, rewriter.getI8Type()),
4517 VectorType::get({8}, rewriter.getI8Type())});
4519 auto lhsBFP = xllvm::Vector64AccFloatToV64BFP16EBS8AIE2pIntrOp::create(
4520 rewriter, loc, bfpStructTy, lhs64f32);
4521 auto rhsBFP = xllvm::Vector64AccFloatToV64BFP16EBS8AIE2pIntrOp::create(
4522 rewriter, loc, bfpStructTy, rhs64f32Transposed);
// Unpack struct fields 0 and 1 of each converted operand.
4525 auto lhsData = LLVM::ExtractValueOp::create(rewriter, loc, lhsBFP, 0);
4526 auto lhsExp = LLVM::ExtractValueOp::create(rewriter, loc, lhsBFP, 1);
4527 auto rhsData = LLVM::ExtractValueOp::create(rewriter, loc, rhsBFP, 0);
4528 auto rhsExp = LLVM::ExtractValueOp::create(rewriter, loc, rhsBFP, 1);
4531 return xllvm::MacConfBFP576ACC2048AIE2pIntrOp::create(
4532 rewriter, loc, v64i32Ty, lhsData, lhsExp, rhsData, rhsExp, acc64i32,
// Emits a chain of eight MacConfBF16I512ACC1024 AIE2p intrinsic calls.
// Each call i consumes rowVectors[i] (derived from `lhs64bf16`), colVectors[i]
// (derived from `rhs32bf16`), the running accumulator and the constant 60;
// the accumulator starts as `acc32f32` and has type <32 x f32>.
// NOTE(review): the accumulator parameter declaration, a few assignment
// left-hand sides, and the final return are on lines not visible here.
4539static Value perform8x8x4MatMul(OpBuilder &rewriter, Location loc, Type i32ty,
4540 Value lhs64bf16, Value rhs32bf16,
4542 auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
4543 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
// lowerMask = lanes 0..31, upperMask = lanes 32..63 of the 64-lane LHS.
4546 SmallVector<int64_t> lowerMask, upperMask;
4547 for (
int i = 0; i < 32; ++i) {
4548 lowerMask.push_back(i);
4549 upperMask.push_back(32 + i);
4553 vector::ShuffleOp::create(rewriter, loc, lhs64bf16, lhs64bf16, lowerMask);
4555 vector::ShuffleOp::create(rewriter, loc, lhs64bf16, lhs64bf16, upperMask);
// View both halves (xl, xh) as 16 x i32 for the shuffle intrinsic.
4559 forceCastValueToType(rewriter, loc, xl, VectorType::get({16}, i32ty));
4561 forceCastValueToType(rewriter, loc, xh, VectorType::get({16}, i32ty));
// xa: shuffle of (xlI32, xhI32) with mode 52.
4564 auto shuffleModeLo = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4565 rewriter.getI32IntegerAttr(52));
4566 auto xa = xllvm::VectorShuffleAIE2pIntrOp::create(
4567 rewriter, loc, VectorType::get({16}, i32ty), xlI32, xhI32, shuffleModeLo);
// xb: shuffle of the same operands with mode 53.
4569 auto shuffleModeHi = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4570 rewriter.getI32IntegerAttr(53));
4571 auto xb = xllvm::VectorShuffleAIE2pIntrOp::create(
4572 rewriter, loc, VectorType::get({16}, i32ty), xlI32, xhI32, shuffleModeHi);
// Back to 32 x bf16 views of the shuffled data.
4575 auto xaBF16 = forceCastValueToType(rewriter, loc, xa, v32bf16Ty);
4576 auto xbBF16 = forceCastValueToType(rewriter, loc, xb, v32bf16Ty);
// Lambda: take the 8 lanes starting at idx * 8 from `src`, repeat that group
// so the mask has 4 copies (32 entries), shuffle, then pass the result through
// the shuffle intrinsic with mode 29 and return it as 32 x bf16.
4579 auto extractBroadcastShuffle = [&](Value
src,
int idx) -> Value {
4580 SmallVector<int64_t> extractMask;
4581 int startIdx = idx * 8;
4582 for (
int i = 0; i < 8; ++i)
4583 extractMask.push_back(startIdx + i);
// Append the same 8-lane group three more times.
4585 for (
int rep = 0; rep < 3; ++rep) {
4586 for (
int i = 0; i < 8; ++i)
4587 extractMask.push_back(startIdx + i);
4590 vector::ShuffleOp::create(rewriter, loc, src, src, extractMask);
4593 auto broadI32 = forceCastValueToType(rewriter, loc, broadcasted,
4594 VectorType::get({16}, i32ty));
4595 auto shuffleMode4x8 = LLVM::ConstantOp::create(
4596 rewriter, loc, i32ty, rewriter.getI32IntegerAttr(29));
4597 auto shuffled = xllvm::VectorShuffleAIE2pIntrOp::create(
4598 rewriter, loc, VectorType::get({16}, i32ty), broadI32, broadI32,
4601 return forceCastValueToType(rewriter, loc, shuffled, v32bf16Ty);
// Eight "row" operands: groups 0..3 of xaBF16 followed by groups 0..3 of
// xbBF16.
4605 SmallVector<Value> rowVectors;
4606 for (
int i = 0; i < 4; ++i)
4607 rowVectors.push_back(extractBroadcastShuffle(xaBF16, i));
4608 for (
int i = 0; i < 4; ++i)
4609 rowVectors.push_back(extractBroadcastShuffle(xbBF16, i));
// Lambda: take the 4 lanes starting at idx * 4 from `src` and repeat them
// 8 times (32-entry mask).
4612 auto extractBroadcast4 = [&](Value
src,
int idx) -> Value {
4613 SmallVector<int64_t> mask;
4614 int startIdx = idx * 4;
4616 for (
int rep = 0; rep < 8; ++rep) {
4617 for (
int i = 0; i < 4; ++i)
4618 mask.push_back(startIdx + i);
4620 return vector::ShuffleOp::create(rewriter, loc, src, src, mask);
// Eight "column" operands: groups 0..7 of the 32-lane RHS.
4624 SmallVector<Value> colVectors;
4625 for (
int i = 0; i < 8; ++i)
4626 colVectors.push_back(extractBroadcast4(rhs32bf16, i));
// Configuration word 60 passed to every MAC call below.
4629 auto conf60 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4630 rewriter.getI32IntegerAttr(60));
// Accumulate the eight row/column products into `acc`.
4632 Value acc = acc32f32;
4633 for (
int i = 0; i < 8; ++i) {
4634 acc = xllvm::MacConfBF16I512ACC1024AIE2pIntrOp::create(
4635 rewriter, loc, v32f32Ty, rowVectors[i], colVectors[i], acc, conf60);
// Conversion pattern for aievec::MatMulOp_AIE2P, built on
// ConvertOpToLLVMPattern (the line carrying the class name is not visible in
// this excerpt). Constructors are inherited from the base pattern.
4642 :
public mlir::ConvertOpToLLVMPattern<aievec::MatMulOp_AIE2P> {
4643 using ConvertOpToLLVMPattern<aievec::MatMulOp_AIE2P>::ConvertOpToLLVMPattern;
// Result of classifying a matmul op. decodeMatMulOp fills it as
// {kind, lhs, rhs, acc, conf}, and matchAndRewrite reads those fields.
4644 struct DecodedMatMulOp {
// Supported variants; the enum header and the remaining enumerators (such as
// the UNSUPPORTED value used by decodeMatMulOp) are on lines not visible here.
4646 BF16_8x8x8_I1024_ACC2048,
4647 BF16_4x8x8_I1024_ACC1024,
4648 BF16_8x1x8_I512_ACC2048,
4649 BF16_4x8x4_I512_ACC512,
4650 BF16_8x8x4_I512_ACC1024,
4651 I8_8x8x8_I512_ACC2048,
4652 I16_8x2x8_I1024_ACC2048,
// Classifies a matmul by the element types and lane counts of its lhs, rhs
// and acc operands, and returns the matching DecodedMatMulOp. When nothing
// matches, the result has Kind::UNSUPPORTED and conf -1.
// NOTE(review): the definitions of lhsLanes/rhsLanes/accLanes, the accLanes
// part of the integer conditions, and the conf values of the successful
// returns are on lines not visible in this excerpt.
4661 static DecodedMatMulOp decodeMatMulOp(OpAdaptor op) {
4662 Value lhs = op.getLhs();
4663 Value rhs = op.getRhs();
4664 Value acc = op.getAcc();
// All three operands are required to be vectors (cast<> asserts otherwise).
4666 auto lhsVecTy = cast<VectorType>(lhs.getType());
4667 auto rhsVecTy = cast<VectorType>(rhs.getType());
4668 auto accVecTy = cast<VectorType>(acc.getType());
// Integer variants: every operand has an integer element type.
4671 if (isa<IntegerType>(lhsVecTy.getElementType()) &&
4672 isa<IntegerType>(rhsVecTy.getElementType()) &&
4673 isa<IntegerType>(accVecTy.getElementType())) {
4675 auto lhsIntTy = cast<IntegerType>(lhsVecTy.getElementType());
4676 auto rhsIntTy = cast<IntegerType>(rhsVecTy.getElementType());
4677 auto accIntTy = cast<IntegerType>(accVecTy.getElementType());
// i8 x i8 -> i32 with 64 lhs lanes and 64 rhs lanes.
4684 if (lhsIntTy.getWidth() == 8 && rhsIntTy.getWidth() == 8 &&
4685 accIntTy.getWidth() == 32 && lhsLanes == 64 && rhsLanes == 64 &&
4688 return {DecodedMatMulOp::Kind::I8_8x8x8_I512_ACC2048, lhs, rhs, acc,
// i16 x i16 -> i32 with 16 lhs lanes and 16 rhs lanes.
4695 if (lhsIntTy.getWidth() == 16 && rhsIntTy.getWidth() == 16 &&
4696 accIntTy.getWidth() == 32 && lhsLanes == 16 && rhsLanes == 16 &&
4699 return {DecodedMatMulOp::Kind::I16_8x2x8_I1024_ACC2048, lhs, rhs, acc,
// Floating-point variants: bf16 inputs accumulating into f32. The variant is
// chosen purely by the (lhsLanes, rhsLanes, accLanes) triple.
4705 if (isa<BFloat16Type>(lhsVecTy.getElementType()) &&
4706 isa<BFloat16Type>(rhsVecTy.getElementType()) &&
4707 isa<Float32Type>(accVecTy.getElementType())) {
4715 if (lhsLanes == 32 && rhsLanes == 32 && accLanes == 16) {
4717 return {DecodedMatMulOp::Kind::BF16_4x8x4_I512_ACC512, lhs, rhs, acc,
4721 else if (lhsLanes == 64 && rhsLanes == 32 && accLanes == 32) {
4723 return {DecodedMatMulOp::Kind::BF16_8x8x4_I512_ACC1024, lhs, rhs, acc,
4727 else if (lhsLanes == 32 && rhsLanes == 64 && accLanes == 32) {
4729 return {DecodedMatMulOp::Kind::BF16_4x8x8_I1024_ACC1024, lhs, rhs, acc,
4733 else if (lhsLanes == 8 && rhsLanes == 8 && accLanes == 64) {
4736 return {DecodedMatMulOp::Kind::BF16_8x1x8_I512_ACC2048, lhs, rhs, acc,
4740 else if (lhsLanes == 64 && rhsLanes == 64 && accLanes == 64) {
4742 return {DecodedMatMulOp::Kind::BF16_8x8x8_I1024_ACC2048, lhs, rhs, acc,
// Fallback: no supported combination matched.
4747 return {DecodedMatMulOp::Kind::UNSUPPORTED, lhs, rhs, acc, -1};
// Lowers aievec::MatMulOp_AIE2P. The op is classified with decodeMatMulOp;
// unsupported combinations emit a warning. Otherwise lhs/rhs/acc are
// shape-cast to flattened vector types, a variant-specific sequence of
// shuffles and XLLVM intrinsics produces `matMulResVal`, and the op is
// replaced by a vector.shape_cast back to the op's result type.
// NOTE(review): many lines of this function are not visible in this excerpt
// (return statements, the flattened type expressions, several assignment
// left-hand sides); comments below describe only what the visible lines show.
4750 matchAndRewrite(aievec::MatMulOp_AIE2P op, OpAdaptor adaptor,
4751 ConversionPatternRewriter &rewriter)
const override {
4752 auto decodedMatMulOp = decodeMatMulOp(adaptor);
4753 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::UNSUPPORTED) {
4754 op.emitWarning() <<
"aievec.matmul_aie2p conversion is not supported for "
4755 "this type combination.\n";
4758 Location loc = op.getLoc();
// Replace each operand by a shape_cast to its flattened vector type.
4761 auto lhsFlattenedVecTy =
4763 decodedMatMulOp.lhs = vector::ShapeCastOp::create(
4764 rewriter, loc, lhsFlattenedVecTy, decodedMatMulOp.lhs);
4765 auto rhsFlattenedVecTy =
4767 decodedMatMulOp.rhs = vector::ShapeCastOp::create(
4768 rewriter, loc, rhsFlattenedVecTy, decodedMatMulOp.rhs);
4769 auto accFlattenedVecTy =
4771 decodedMatMulOp.acc = vector::ShapeCastOp::create(
4772 rewriter, loc, accFlattenedVecTy, decodedMatMulOp.acc);
// i32 constant carrying the decoded configuration word.
4773 Type i32ty = rewriter.getI32Type();
4774 auto confCst = LLVM::ConstantOp::create(
4775 rewriter, loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf));
4777 SmallVector<Value> operands({decodedMatMulOp.lhs, decodedMatMulOp.rhs,
4778 decodedMatMulOp.acc, confCst});
4782 if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I8_8x8x8_I512_ACC2048) {
// i8 variant: operands are force-cast to the intrinsic signature
// (<16 x i32>, <32 x i16>, <32 x i64>, i32); the result type is <32 x i64>.
4790 xllvm::MacConfI512ACC2048AIE2pIntrOp::create(
4791 rewriter, loc, VectorType::get({32}, rewriter.getI64Type()),
4792 forceCastOperandsToSignature(
4793 rewriter, loc, operands,
4794 {VectorType::get({16}, rewriter.getI32Type()),
4795 VectorType::get({32}, rewriter.getI16Type()),
4796 VectorType::get({32}, rewriter.getI64Type()), i32ty}))
4798 }
else if (decodedMatMulOp.kind ==
4799 DecodedMatMulOp::Kind::I16_8x2x8_I1024_ACC2048) {
// i16 variant: lhs is widened from 16 to 64 lanes. Mask entries 0..15 pick
// the original lanes and the remaining 48 entries are -1.
4806 SmallVector<int64_t> lhsPadMask;
4807 for (
int i = 0; i < 16; ++i)
4808 lhsPadMask.push_back(i);
4809 for (
int i = 16; i < 64; ++i)
4810 lhsPadMask.push_back(-1);
4811 auto lhsPadded = vector::ShuffleOp::create(
4812 rewriter, loc, decodedMatMulOp.lhs, decodedMatMulOp.lhs, lhsPadMask);
// rhs is widened the same way.
4815 SmallVector<int64_t> rhsPadMask;
4816 for (
int i = 0; i < 16; ++i)
4817 rhsPadMask.push_back(i);
4818 for (
int i = 16; i < 64; ++i)
4819 rhsPadMask.push_back(-1);
4820 auto rhsPadded = vector::ShuffleOp::create(
4821 rewriter, loc, decodedMatMulOp.rhs, decodedMatMulOp.rhs, rhsPadMask);
4824 SmallVector<Value> paddedOperands(
4825 {lhsPadded, rhsPadded, decodedMatMulOp.acc, confCst});
// Padded operands are force-cast to (<32 x i32>, <64 x i16>, <32 x i64>, i32).
4831 xllvm::MacConfI1024ACC2048AIE2pIntrOp::create(
4832 rewriter, loc, VectorType::get({32}, rewriter.getI64Type()),
4833 forceCastOperandsToSignature(
4834 rewriter, loc, paddedOperands,
4835 {VectorType::get({32}, rewriter.getI32Type()),
4836 VectorType::get({64}, rewriter.getI16Type()),
4837 VectorType::get({32}, rewriter.getI64Type()), i32ty}))
4839 }
else if (decodedMatMulOp.kind ==
4840 DecodedMatMulOp::Kind::BF16_8x8x8_I1024_ACC2048) {
// bf16 8x8x8 variant: convert the 64-lane lhs to f32 in two 32-lane halves,
// rejoin, transpose/convert the rhs, then use the BFP16 matmul helper.
4845 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
4848 SmallVector<int64_t> firstHalfMask, secondHalfMask;
4849 for (
int i = 0; i < 32; ++i) {
4850 firstHalfMask.push_back(i);
4851 secondHalfMask.push_back(32 + i);
4855 vector::ShuffleOp::create(rewriter, loc, decodedMatMulOp.lhs,
4856 decodedMatMulOp.lhs, firstHalfMask);
4858 vector::ShuffleOp::create(rewriter, loc, decodedMatMulOp.lhs,
4859 decodedMatMulOp.lhs, secondHalfMask);
4861 auto lhs32f32_lo = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4862 rewriter, loc, v32f32Ty, lhs32bf16_lo);
4863 auto lhs32f32_hi = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4864 rewriter, loc, v32f32Ty, lhs32bf16_hi);
// Identity mask 0..63 joins the two converted halves.
4867 SmallVector<int64_t> concatMask;
4868 for (
int i = 0; i < 64; ++i)
4869 concatMask.push_back(i);
4870 auto lhs64f32 = vector::ShuffleOp::create(rewriter, loc, lhs32f32_lo,
4871 lhs32f32_hi, concatMask);
4874 auto rhsTransposed =
4875 transposeAndConvertRHS(rewriter, loc, i32ty, decodedMatMulOp.rhs);
// A dedicated constant 780 is created here; presumably it is the trailing
// argument of the helper call below (that line is not visible) — confirm.
4878 auto conf780 = LLVM::ConstantOp::create(rewriter, loc, i32ty,
4879 rewriter.getI32IntegerAttr(780));
// The accumulator is passed as a <64 x i32> view.
4881 matMulResVal = performBFP16_8x8MatMul(
4882 rewriter, loc, i32ty, lhs64f32, rhsTransposed,
4883 forceCastValueToType(rewriter, loc, decodedMatMulOp.acc,
4884 VectorType::get({64}, rewriter.getI32Type())),
4886 }
else if (decodedMatMulOp.kind ==
4887 DecodedMatMulOp::Kind::BF16_4x8x8_I1024_ACC1024) {
// bf16 4x8x8 variant: lhs and acc have 32 lanes and are widened to 64 lanes
// (upper 32 mask entries are -1) so the 8x8 BFP16 helper can be reused; only
// the first 32 result lanes are kept.
4892 auto v32f32Ty = VectorType::get({32}, rewriter.getF32Type());
4895 auto lhs32f32 = xllvm::Vector32BF16ToV32AccFloatAIE2pIntrOp::create(
4896 rewriter, loc, v32f32Ty, decodedMatMulOp.lhs);
4899 SmallVector<int64_t> lhsPadMask;
4900 for (
int i = 0; i < 32; ++i)
4901 lhsPadMask.push_back(i);
4902 for (
int i = 32; i < 64; ++i)
4903 lhsPadMask.push_back(-1);
4904 auto lhs64f32 = vector::ShuffleOp::create(rewriter, loc, lhs32f32,
4905 lhs32f32, lhsPadMask);
4908 auto rhsTransposed =
4909 transposeAndConvertRHS(rewriter, loc, i32ty, decodedMatMulOp.rhs);
4912 SmallVector<int64_t> accPadMask;
4913 for (
int i = 0; i < 32; ++i)
4914 accPadMask.push_back(i);
4915 for (
int i = 32; i < 64; ++i)
4916 accPadMask.push_back(-1);
// Both shuffle inputs are the accumulator viewed as <32 x i32>.
4917 auto acc64i32 = vector::ShuffleOp::create(
4919 forceCastValueToType(rewriter, loc, decodedMatMulOp.acc,
4920 VectorType::get({32}, rewriter.getI32Type())),
4921 forceCastValueToType(rewriter, loc, decodedMatMulOp.acc,
4922 VectorType::get({32}, rewriter.getI32Type())),
// Unlike the 8x8x8 branch, this call passes the decoded confCst.
4926 auto result64i32 = performBFP16_8x8MatMul(
4927 rewriter, loc, i32ty, lhs64f32, rhsTransposed, acc64i32, confCst);
// Keep result lanes 0..31.
4930 SmallVector<int64_t> extractMask;
4931 for (
int i = 0; i < 32; ++i)
4932 extractMask.push_back(i);
4933 matMulResVal = vector::ShuffleOp::create(rewriter, loc, result64i32,
4934 result64i32, extractMask);
4935 }
else if (decodedMatMulOp.kind ==
4936 DecodedMatMulOp::Kind::BF16_8x1x8_I512_ACC2048) {
// bf16 8x1x8 variant: both 8-lane inputs are replicated 8 times to 64 lanes;
// the replicated lhs is additionally permuted with an 8x8 transpose mask.
4941 auto v64f32Ty = VectorType::get({64}, rewriter.getF32Type());
4944 SmallVector<int64_t> lhsReplicateMask;
4945 for (
int rep = 0; rep < 8; ++rep) {
4946 for (
int i = 0; i < 8; ++i)
4947 lhsReplicateMask.push_back(i);
4950 vector::ShuffleOp::create(rewriter, loc, decodedMatMulOp.lhs,
4951 decodedMatMulOp.lhs, lhsReplicateMask);
// Output position c * 8 + r reads input position r * 8 + c.
4954 SmallVector<int64_t> transposeMask;
4955 for (
int c = 0; c < 8; ++c) {
4956 for (
int r = 0; r < 8; ++r) {
4957 transposeMask.push_back(r * 8 + c);
4960 auto lhs64bf16Transposed = vector::ShuffleOp::create(
4961 rewriter, loc, lhs64bf16, lhs64bf16, transposeMask);
4964 SmallVector<int64_t> rhsReplicateMask;
4965 for (
int rep = 0; rep < 8; ++rep) {
4966 for (
int i = 0; i < 8; ++i)
4967 rhsReplicateMask.push_back(i);
4970 vector::ShuffleOp::create(rewriter, loc, decodedMatMulOp.rhs,
4971 decodedMatMulOp.rhs, rhsReplicateMask);
// Single MAC intrinsic producing <64 x f32>; acc is used without a cast.
4975 matMulResVal = xllvm::MacConfBF16I512ACC2048AIE2pIntrOp::create(
4976 rewriter, loc, v64f32Ty, lhs64bf16Transposed, rhs64bf16,
4977 decodedMatMulOp.acc, confCst);
4978 }
else if (decodedMatMulOp.kind ==
4979 DecodedMatMulOp::Kind::BF16_4x8x4_I512_ACC512) {
// bf16 4x8x4 variant: lhs is widened 32 -> 64 lanes and acc 16 -> 32 lanes
// (extra mask entries are -1), the 8x8x4 helper is reused, and result lanes
// 0..15 are extracted.
4986 SmallVector<int64_t> lhsPadMask;
4987 for (
int i = 0; i < 32; ++i)
4988 lhsPadMask.push_back(i);
4989 for (
int i = 32; i < 64; ++i)
4990 lhsPadMask.push_back(-1);
4991 auto lhsPadded = vector::ShuffleOp::create(
4992 rewriter, loc, decodedMatMulOp.lhs, decodedMatMulOp.lhs, lhsPadMask);
4995 SmallVector<int64_t> accPadMask;
4996 for (
int i = 0; i < 16; ++i)
4997 accPadMask.push_back(i);
4998 for (
int i = 16; i < 32; ++i)
4999 accPadMask.push_back(-1);
5000 auto accPadded = vector::ShuffleOp::create(
5001 rewriter, loc, decodedMatMulOp.acc, decodedMatMulOp.acc, accPadMask);
5004 Value acc32 = perform8x8x4MatMul(rewriter, loc, i32ty, lhsPadded,
5005 decodedMatMulOp.rhs, accPadded);
5008 SmallVector<int64_t> extractMask;
5009 for (
int i = 0; i < 16; ++i)
5010 extractMask.push_back(i);
5012 vector::ShuffleOp::create(rewriter, loc, acc32, acc32, extractMask);
5013 }
else if (decodedMatMulOp.kind ==
5014 DecodedMatMulOp::Kind::BF16_8x8x4_I512_ACC1024) {
// bf16 8x8x4 variant: operands go to the helper unchanged.
5018 perform8x8x4MatMul(rewriter, loc, i32ty, decodedMatMulOp.lhs,
5019 decodedMatMulOp.rhs, decodedMatMulOp.acc);
// Common tail: force the result to the flattened accumulator type, then
// replace the op with a shape_cast to the op's own result type.
5024 forceCastValueToType(rewriter, loc, matMulResVal, accFlattenedVecTy);
5026 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
// Inherit the base-pattern constructors for this aievec::CastOp conversion
// (the class header lines are not visible in this excerpt).
5038 using ConvertOpToLLVMPattern<aievec::CastOp>::ConvertOpToLLVMPattern;
// Checks whether `val` is defined by an arith.constant or llvm.mlir.constant
// whose value is a splat DenseElementsAttr of zero (float or integer).
// NOTE(review): the return statements for the non-matching paths are on lines
// not visible here; presumably they return false.
5041 static bool isConstantZero(Value val) {
5042 DenseElementsAttr denseAttr;
// Pull the constant attribute from either kind of defining constant op;
// denseAttr stays null when the value attribute is not a DenseElementsAttr.
5045 if (
auto arithConstOp = val.getDefiningOp<arith::ConstantOp>()) {
5046 denseAttr = dyn_cast<DenseElementsAttr>(arithConstOp.getValue());
5047 }
else if (
auto llvmConstOp = val.getDefiningOp<LLVM::ConstantOp>()) {
5048 denseAttr = dyn_cast<DenseElementsAttr>(llvmConstOp.getValue());
// Only splat constants are considered.
5051 if (!denseAttr || !denseAttr.isSplat())
// Inspect the single splat element: float or integer zero qualifies.
5054 auto splatAttr = denseAttr.getSplatValue<Attribute>();
5055 if (
auto floatAttr = dyn_cast<FloatAttr>(splatAttr))
5056 return floatAttr.getValue().isZero();
5057 if (
auto intAttr = dyn_cast<IntegerAttr>(splatAttr))
5058 return intAttr.getValue().isZero();
// Lowers aievec::CastOp. When the op is not an accumulator cast or its source
// is not a constant zero, the op is replaced by its (converted) source.
// For a 16-lane f32 constant zero, a zeroed accumulator is materialised with
// an XLLVM intrinsic instead. Any other case also forwards the source.
// NOTE(review): the return statements and the definition of `lanes` are on
// lines not visible in this excerpt.
5064 matchAndRewrite(aievec::CastOp castOp, OpAdaptor adaptor,
5065 ConversionPatternRewriter &rewriter)
const override {
5069 if (!castOp.getIsResAcc() || !isConstantZero(adaptor.getSource())) {
5071 rewriter.replaceOp(castOp, adaptor.getSource());
5075 Location loc = castOp.getLoc();
5076 auto srcVecType = cast<VectorType>(castOp.getSource().getType());
5077 Type srcElemType = srcVecType.getElementType();
5081 if (srcElemType.isF32() && lanes == 16) {
// Zero accumulator produced as <16 x i64> by the broadcast-zero intrinsic.
5083 auto zeroAcc1024 = xllvm::VectorBroadcastZeroAcc1024IntrOp::create(
5084 rewriter, loc, VectorType::get({16}, rewriter.getI64Type()));
// Keep lanes 0..7 of it (8 x i64).
5087 SmallVector<int64_t> extractMask = {0, 1, 2, 3, 4, 5, 6, 7};
5088 auto zeroAcc512 = vector::ShuffleOp::create(rewriter, loc, zeroAcc1024,
5089 zeroAcc1024, extractMask);
// Reinterpret the bits as <16 x f32>; the bitcast operand is on a line not
// visible here.
5092 auto result = LLVM::BitcastOp::create(
5093 rewriter, loc, VectorType::get({16}, rewriter.getF32Type()),
5096 rewriter.replaceOp(castOp, result);
// Fallback for other element types / lane counts: forward the source.
5101 rewriter.replaceOp(castOp, adaptor.getSource());
// Second aievec::CastOp conversion pattern (class name line not visible):
// it unconditionally replaces the cast with its converted source operand.
5108 :
public mlir::ConvertOpToLLVMPattern<aievec::CastOp> {
5109 using ConvertOpToLLVMPattern<aievec::CastOp>::ConvertOpToLLVMPattern;
5112 matchAndRewrite(aievec::CastOp castOp, OpAdaptor adaptor,
5113 ConversionPatternRewriter &rewriter)
const override {
// The cast produces no new IR; users are rewired to the source value.
5115 rewriter.replaceOp(castOp, adaptor.getSource());
// Conversion pattern for aievec::ShuffleOp (class name line not visible).
// Lowers the op to xllvm::VectorShuffleIntrOp on <16 x i32> operands and
// force-casts the intrinsic result back to the op's result type.
5121 :
public mlir::ConvertOpToLLVMPattern<aievec::ShuffleOp> {
5122 using ConvertOpToLLVMPattern<aievec::ShuffleOp>::ConvertOpToLLVMPattern;
5125 matchAndRewrite(aievec::ShuffleOp shuffleOp, OpAdaptor adaptor,
5126 ConversionPatternRewriter &rewriter)
const override {
5127 auto loc = shuffleOp.getLoc();
5128 auto lhs = adaptor.getLhs();
5129 auto rhs = adaptor.getRhs();
5130 auto i32ty = rewriter.getI32Type();
5131 auto v16xi32ty = VectorType::get({16}, i32ty);
// rhs is replaced by an undef <16 x i32>; the guarding condition is on a line
// not visible here (presumably "rhs is absent" — confirm).
5133 rhs = xllvm::UndefV16I32IntrOp::create(rewriter, loc, v16xi32ty);
// The op's mode is materialised as an i32 constant.
5136 LLVM::ConstantOp::create(rewriter, loc, i32ty,
5137 static_cast<int32_t
>(shuffleOp.getMode()))
// Operands are force-cast to (<16 x i32>, <16 x i32>, i32) for the intrinsic.
5139 auto vShuffleVal = xllvm::VectorShuffleIntrOp::create(
5140 rewriter, loc, v16xi32ty,
5141 forceCastOperandsToSignature(
5143 {lhs, rhs, modeAttrVal},
5144 {v16xi32ty, v16xi32ty, i32ty}))
// Cast the <16 x i32> result back to the type the original op produced.
5147 vShuffleVal = forceCastValueToType(rewriter, loc, vShuffleVal,
5148 shuffleOp.getResult().getType());
5150 rewriter.replaceOp(shuffleOp, vShuffleVal);
// Conversion pattern for aievec::InvOp (class name line not visible).
// A scalar f32 operand maps directly to xllvm::InvAIE2pIntrOp. A vector of
// f32 is unrolled: each element is extracted, passed through the scalar
// intrinsic and inserted into a result vector that starts as poison.
// NOTE(review): the return statements and the definition of `numElements`
// are on lines not visible in this excerpt.
5160 :
public mlir::ConvertOpToLLVMPattern<aievec::InvOp> {
5162 using ConvertOpToLLVMPattern<aievec::InvOp>::ConvertOpToLLVMPattern;
5166 ConversionPatternRewriter &rewriter)
const override {
5167 auto loc = invOp.getLoc();
5168 auto operandType = adaptor.getSource().getType();
// Scalar case.
5171 if (operandType.isF32()) {
5172 auto invResult = xllvm::InvAIE2pIntrOp::create(
5173 rewriter, loc, rewriter.getF32Type(), adaptor.getSource());
5174 rewriter.replaceOp(invOp, invResult);
// Anything that is not a vector of f32 is rejected here.
5179 auto vecType = dyn_cast<VectorType>(operandType);
5180 if (!vecType || !vecType.getElementType().isF32())
// Vector case: build the result one element at a time.
5185 Value result = LLVM::PoisonOp::create(rewriter, loc, vecType);
5187 for (
int i = 0; i < numElements; ++i) {
// i64 index constant shared by the extract and the insert.
5189 auto indexCst = LLVM::ConstantOp::create(
5190 rewriter, loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(i));
5191 auto extractedElem = LLVM::ExtractElementOp::create(
5192 rewriter, loc, adaptor.getSource(), indexCst);
5195 auto invResult = xllvm::InvAIE2pIntrOp::create(
5196 rewriter, loc, rewriter.getF32Type(), extractedElem);
5199 result = LLVM::InsertElementOp::create(rewriter, loc, vecType, result,
5200 invResult, indexCst);
5203 rewriter.replaceOp(invOp, result);
// Conversion pattern for aievec::ExpOp (class name line not visible).
// Accepts bf16 vectors with 16 or 32 lanes; other inputs emit a warning.
// The input is multiplied element-wise by a broadcast bf16 constant 1.442695
// (named log2e in the code) and the product is fed to xllvm::Exp2AIE2pIntrOp,
// i.e. exp(x) is computed as exp2(x * log2(e)). The 32-lane case is split into
// two 16-lane intrinsic calls whose results are concatenated.
// NOTE(review): the definitions of `laneSize`, `resultF32Ty`, `exp2Result`,
// `exp2Lower` and `exp2Upper` are on lines not visible in this excerpt.
5212 :
public mlir::ConvertOpToLLVMPattern<aievec::ExpOp> {
5214 using ConvertOpToLLVMPattern<aievec::ExpOp>::ConvertOpToLLVMPattern;
5218 ConversionPatternRewriter &rewriter)
const override {
5219 auto loc = expOp.getLoc();
5220 auto srcType = cast<VectorType>(adaptor.getSource().getType());
5221 auto srcElemType = srcType.getElementType();
5225 if ((laneSize != 16 && laneSize != 32) || !srcElemType.isBF16())
5226 return expOp.emitWarning()
5227 <<
"aievec.exp conversion only supports v16bfloat16 and "
// Scalar bf16 constant 1.442695.
5231 auto log2eBF16Const = LLVM::ConstantOp::create(
5232 rewriter, loc, rewriter.getBF16Type(),
5233 rewriter.getFloatAttr(rewriter.getBF16Type(), 1.442695));
// All-zero shuffle mask of length laneSize: every output lane reads lane 0.
5236 SmallVector<int64_t> broadcastMask;
5237 for (
unsigned i = 0; i < laneSize; ++i)
5238 broadcastMask.push_back(0);
// Put the scalar into a 1-element vector, then broadcast it via the shuffle.
5240 auto v1bf16 = LLVM::UndefOp::create(
5241 rewriter, loc, VectorType::get({1}, rewriter.getBF16Type()));
5242 auto v1bf16Inserted = LLVM::InsertElementOp::create(
5243 rewriter, loc, v1bf16, log2eBF16Const,
5244 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0));
5246 auto log2eVec = vector::ShuffleOp::create(rewriter, loc, v1bf16Inserted,
5247 v1bf16Inserted, broadcastMask);
// Element-wise multiply producing an f32 vector with laneSize lanes. Note that
// this creates another aievec op (MulElemOp), not an LLVM-level op.
5253 VectorType::get({(int64_t)laneSize}, rewriter.getF32Type());
5254 auto mulResult = aievec::MulElemOp::create(rewriter, loc, resultF32Ty,
5255 adaptor.getSource(), log2eVec);
5259 auto v16bf16Ty = VectorType::get({16}, rewriter.getBF16Type());
5261 if (laneSize == 16) {
// 16 lanes: a single exp2 intrinsic call.
5265 xllvm::Exp2AIE2pIntrOp::create(rewriter, loc, v16bf16Ty, mulResult);
// 32 lanes: split into lanes 0..15 and 16..31.
5269 SmallVector<int64_t> lowerMask, upperMask;
5270 for (
int i = 0; i < 16; ++i) {
5271 lowerMask.push_back(i);
5272 upperMask.push_back(16 + i);
5275 auto lowerHalf = vector::ShuffleOp::create(rewriter, loc, mulResult,
5276 mulResult, lowerMask);
5277 auto upperHalf = vector::ShuffleOp::create(rewriter, loc, mulResult,
5278 mulResult, upperMask);
5282 xllvm::Exp2AIE2pIntrOp::create(rewriter, loc, v16bf16Ty, lowerHalf);
5284 xllvm::Exp2AIE2pIntrOp::create(rewriter, loc, v16bf16Ty, upperHalf);
// Identity mask 0..31 concatenates the two 16-lane results.
5287 SmallVector<int64_t> combineMask;
5288 for (
int i = 0; i < 32; ++i)
5289 combineMask.push_back(i);
5291 exp2Result = vector::ShuffleOp::create(rewriter, loc, exp2Lower,
5292 exp2Upper, combineMask);
5295 rewriter.replaceOp(expOp, exp2Result);
// Conversion pattern for aievec::TanhOp (class name line not visible).
// Accepts bf16 vectors with 16 or 32 lanes; other inputs emit a warning.
// Each 16-lane group is converted to f32 with the bf16 -> accfloat intrinsic
// and passed to xllvm::TanhAIE2pIntrOp, which yields 16 x bf16. The 32-lane
// case processes two halves and concatenates the results.
// NOTE(review): the definitions of `laneSize`, `tanhResult`, `tanhLower` and
// `tanhUpper` are on lines not visible in this excerpt.
5304 :
public mlir::ConvertOpToLLVMPattern<aievec::TanhOp> {
5306 using ConvertOpToLLVMPattern<aievec::TanhOp>::ConvertOpToLLVMPattern;
5310 ConversionPatternRewriter &rewriter)
const override {
5311 auto loc = tanhOp.getLoc();
5312 auto srcType = cast<VectorType>(adaptor.getSource().getType());
5313 auto srcElemType = srcType.getElementType();
5317 if ((laneSize != 16 && laneSize != 32) || !srcElemType.isBF16())
5318 return tanhOp.emitWarning()
5319 <<
"aievec.tanh conversion only supports v16bfloat16 and "
5323 auto v16bf16Ty = VectorType::get({16}, rewriter.getBF16Type());
5324 auto v16f32Ty = VectorType::get({16}, rewriter.getF32Type());
5329 if (laneSize == 16) {
// 16 lanes: convert to f32, then one tanh intrinsic call.
5331 auto inputF32 = xllvm::Vector16BF16ToV16AccFloatAIE2pIntrOp::create(
5332 rewriter, loc, v16f32Ty, adaptor.getSource());
5334 xllvm::TanhAIE2pIntrOp::create(rewriter, loc, v16bf16Ty, inputF32);
// 32 lanes: split the source into lanes 0..15 and 16..31.
5337 SmallVector<int64_t> lowerMask, upperMask;
5338 for (
int i = 0; i < 16; ++i) {
5339 lowerMask.push_back(i);
5340 upperMask.push_back(16 + i);
5343 auto lowerBf16 = vector::ShuffleOp::create(
5344 rewriter, loc, adaptor.getSource(), adaptor.getSource(), lowerMask);
5345 auto upperBf16 = vector::ShuffleOp::create(
5346 rewriter, loc, adaptor.getSource(), adaptor.getSource(), upperMask);
5348 auto lowerF32 = xllvm::Vector16BF16ToV16AccFloatAIE2pIntrOp::create(
5349 rewriter, loc, v16f32Ty, lowerBf16);
5350 auto upperF32 = xllvm::Vector16BF16ToV16AccFloatAIE2pIntrOp::create(
5351 rewriter, loc, v16f32Ty, upperBf16);
5354 xllvm::TanhAIE2pIntrOp::create(rewriter, loc, v16bf16Ty, lowerF32);
5356 xllvm::TanhAIE2pIntrOp::create(rewriter, loc, v16bf16Ty, upperF32);
// Identity mask 0..31 concatenates the two 16-lane results.
5358 SmallVector<int64_t> combineMask;
5359 for (
int i = 0; i < 32; ++i)
5360 combineMask.push_back(i);
5362 tanhResult = vector::ShuffleOp::create(rewriter, loc, tanhLower,
5363 tanhUpper, combineMask);
5366 rewriter.replaceOp(tanhOp, tanhResult);
// Conversion pattern for math::RsqrtOp (class name line not visible).
// A scalar f32 operand maps directly to xllvm::InvsqrtAIE2pIntrOp. A vector of
// f32 is unrolled element by element through the same scalar intrinsic, with
// the result vector starting as poison.
// NOTE(review): the return statements and the definition of `numElements`
// are on lines not visible in this excerpt.
5376 :
public mlir::ConvertOpToLLVMPattern<math::RsqrtOp> {
5378 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
5382 ConversionPatternRewriter &rewriter)
const override {
5383 auto loc = rsqrtOp.getLoc();
5384 auto operandType = adaptor.getOperand().getType();
// Scalar case.
5387 if (operandType.isF32()) {
5388 auto rsqrtResult = xllvm::InvsqrtAIE2pIntrOp::create(
5389 rewriter, loc, rewriter.getF32Type(), adaptor.getOperand());
5390 rewriter.replaceOp(rsqrtOp, rsqrtResult);
// Anything that is not a vector of f32 is rejected here.
5395 auto vecType = dyn_cast<VectorType>(operandType);
5396 if (!vecType || !vecType.getElementType().isF32())
// Vector case: extract, apply the scalar intrinsic, insert.
5401 Value result = LLVM::PoisonOp::create(rewriter, loc, vecType);
5403 for (
int i = 0; i < numElements; ++i) {
5405 auto indexCst = LLVM::ConstantOp::create(
5406 rewriter, loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(i));
5407 auto extractedElem = LLVM::ExtractElementOp::create(
5408 rewriter, loc, adaptor.getOperand(), indexCst);
5411 auto rsqrtResult = xllvm::InvsqrtAIE2pIntrOp::create(
5412 rewriter, loc, rewriter.getF32Type(), extractedElem);
5415 result = LLVM::InsertElementOp::create(rewriter, loc, vecType, result,
5416 rsqrtResult, indexCst);
5419 rewriter.replaceOp(rsqrtOp, result);
// Conversion pattern for arith::DivFOp on vectors of f32. The division is
// unrolled: each pair of elements is passed to a scalar helper function
// obtained from getOrCreateScalarHelperFunc, and the results are inserted
// into a vector that starts as poison.
// NOTE(review): the class header, the constructor signature, the failure
// returns and the definitions of `numElements`/`helperFunc` are on lines not
// visible in this excerpt.
5429 using ConvertOpToLLVMPattern<arith::DivFOp>::ConvertOpToLLVMPattern;
// Constructor tail: stores the device name as a std::string copy.
5432 : ConvertOpToLLVMPattern(typeConverter), deviceName(device.str()) {}
5438 ConversionPatternRewriter &rewriter)
const override {
5439 auto loc = divOp.getLoc();
5440 auto lhsType = adaptor.getLhs().getType();
// Only vectors of f32 are handled by this pattern.
5444 auto vecType = dyn_cast<VectorType>(lhsType);
5445 if (!vecType || !vecType.getElementType().isF32())
// Both operands must have exactly the same vector type.
5448 auto rhsType = adaptor.getRhs().getType();
5449 auto rhsVecType = dyn_cast<VectorType>(rhsType);
5450 if (!rhsVecType || rhsVecType != vecType)
5455 auto module = divOp->getParentOfType<ModuleOp>();
5456 auto f32Ty = rewriter.getF32Type();
// Callback that emits the body of the scalar helper; it depends on the device.
5461 std::function<void(OpBuilder &, Location, ValueRange)> bodyBuilder;
5462 if (deviceName ==
"aie2p") {
// aie2p: a / b is emitted as a * inv(b) using the InvAIE2p intrinsic.
5463 bodyBuilder = [](OpBuilder &builder, Location loc, ValueRange
args) {
5464 auto invResult = xllvm::InvAIE2pIntrOp::create(
5465 builder, loc, builder.getF32Type(),
args[1]);
5467 arith::MulFOp::create(builder, loc,
args[0], invResult);
5468 LLVM::ReturnOp::create(builder, loc, ValueRange{mulResult});
// Other devices: the helper body is a plain scalar arith.divf.
5471 bodyBuilder = [](OpBuilder &builder, Location loc, ValueRange
args) {
5472 auto divResult = arith::DivFOp::create(builder, loc,
args[0],
args[1]);
5473 LLVM::ReturnOp::create(builder, loc, ValueRange{divResult});
// Look up or create the helper named from "fdiv" and the device name.
5480 getOrCreateScalarHelperFunc(module, rewriter,
"fdiv", deviceName,
5482 f32Ty, bodyBuilder);
// Unrolled per-element division through calls to the helper.
5486 Value result = LLVM::PoisonOp::create(rewriter, loc, vecType);
5488 for (
int i = 0; i < numElements; ++i) {
5489 auto indexCst = LLVM::ConstantOp::create(
5490 rewriter, loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(i));
5491 auto lhsElem = LLVM::ExtractElementOp::create(rewriter, loc,
5492 adaptor.getLhs(), indexCst);
5493 auto rhsElem = LLVM::ExtractElementOp::create(rewriter, loc,
5494 adaptor.getRhs(), indexCst);
5496 auto divResult = LLVM::CallOp::create(rewriter, loc, helperFunc,
5497 ValueRange{lhsElem, rhsElem})
5500 result = LLVM::InsertElementOp::create(rewriter, loc, vecType, result,
5501 divResult, indexCst);
5504 rewriter.replaceOp(divOp, result);
5510 mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
5529 mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
5530 Aie2Fp32Emulation aie2Fp32EmulationOption) {
// Conversion pattern for aievec::ExtElemOp (class name line not visible).
// The op is lowered to a single llvm.extractelement on the converted source
// vector and index operands.
5548 :
public mlir::ConvertOpToLLVMPattern<aievec::ExtElemOp> {
5550 using ConvertOpToLLVMPattern<aievec::ExtElemOp>::ConvertOpToLLVMPattern;
5554 ConversionPatternRewriter &rewriter)
const override {
5555 Location loc = op.getLoc();
// The index operand is used as-is (no constant folding or range check here).
5559 Value extracted = LLVM::ExtractElementOp::create(
5560 rewriter, loc, adaptor.getSource(), adaptor.getIndex());
5562 rewriter.replaceOp(op, extracted);
// Conversion pattern for aievec::ConcatOp (class name line not visible).
// Two sources are joined with one vector.shuffle; four sources are joined
// pairwise and the two pairs are then joined again. An empty source list or
// any other operand count emits a warning.
// NOTE(review): the return statements, the definition of `srcLanes`, the
// body of the 2-source mask loop and some assignment left-hand sides are on
// lines not visible in this excerpt.
5569 :
public mlir::ConvertOpToLLVMPattern<aievec::ConcatOp> {
5571 using ConvertOpToLLVMPattern<aievec::ConcatOp>::ConvertOpToLLVMPattern;
5575 ConversionPatternRewriter &rewriter)
const override {
5576 Location loc = op.getLoc();
5578 SmallVector<Value> sources = adaptor.getSources();
5580 if (sources.empty()) {
5581 op.emitWarning() <<
"aievec.concat with no sources is not supported.\n";
// Default result is the first source; it is overwritten in the branches below.
5586 Value result = sources[0];
5589 auto srcType = cast<VectorType>(sources[0].getType());
5592 if (sources.size() == 2) {
// Two sources: mask of length srcLanes * 2 over (sources[0], sources[1]).
5594 SmallVector<int64_t> mask;
5595 for (int64_t i = 0; i < srcLanes * 2; ++i)
5598 result = vector::ShuffleOp::create(rewriter, loc, sources[0], sources[1],
5600 }
else if (sources.size() == 4) {
// Four sources: identity mask 0..2*srcLanes-1 joins each pair.
5602 SmallVector<int64_t> pairMask;
5603 for (int64_t i = 0; i < srcLanes * 2; ++i)
5604 pairMask.push_back(i);
5606 auto pair0 = vector::ShuffleOp::create(rewriter, loc, sources[0],
5607 sources[1], pairMask);
5608 auto pair1 = vector::ShuffleOp::create(rewriter, loc, sources[2],
5609 sources[3], pairMask);
// Identity mask 0..4*srcLanes-1 joins the two pairs.
5611 SmallVector<int64_t> finalMask;
5612 for (int64_t i = 0; i < srcLanes * 4; ++i)
5613 finalMask.push_back(i);
5616 vector::ShuffleOp::create(rewriter, loc, pair0, pair1, finalMask);
5618 op.emitWarning() <<
"aievec.concat with " << sources.size()
5619 <<
" operands is not supported for AIE2p.\n";
5623 rewriter.replaceOp(op, result);
5629 mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
5652 mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
5653 Aie2Fp32Emulation aie2Fp32EmulationOption, StringRef aieTarget) {
5655 if (aieTarget ==
"aie2p")
5659 aie2Fp32EmulationOption);
// Adds dynamic legality rules to the conversion target. arith::DivFOp with a
// vector result is legal only when its element type is not f32, so f32 vector
// divisions are left for the conversion patterns to rewrite.
// NOTE(review): the handling of non-vector results is on lines not visible in
// this excerpt.
5663static void configureAIEVecToLLVMLegalizations(LLVMConversionTarget &target) {
5666 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divOp) {
5667 auto resultType = divOp.getType();
5668 if (
auto vecType = dyn_cast<VectorType>(resultType)) {
// Returning false marks the op illegal, i.e. it must be converted.
5670 return !vecType.getElementType().isF32();
// Fragment of the pass implementation (the enclosing declarations are not
// visible in this excerpt). First, the two pass options are copied from
// `options` into members of the same name.
5681 aieTarget = options.aieTarget;
5682 aie2Fp32Emulation = options.aie2Fp32Emulation;
// Conversion driver: build the pattern set and the LLVM type converter.
5686 RewritePatternSet patterns(&getContext());
5687 LLVMTypeConverter converter(&getContext());
// Vector types are mapped to themselves by the converter.
5691 converter.addConversion(
5692 [&](VectorType type) -> std::optional<Type> {
return type; });
// Tail of the pattern-population call, parameterised by the two options.
5695 aie2Fp32Emulation, aieTarget);
// Target: both AIEVec dialects are illegal; arith, vector, XLLVM and UB are
// legal, with the extra dynamic rules from configureAIEVecToLLVMLegalizations.
5697 LLVMConversionTarget target(getContext());
5698 target.addIllegalDialect<xilinx::aievec::AIEVecDialect,
5699 xilinx::aievec::aie1::AIEVecAIE1Dialect>();
5700 target.addLegalDialect<arith::ArithDialect, vector::VectorDialect,
5701 xilinx::xllvm::XLLVMDialect, ub::UBDialect>();
5704 configureAIEVecToLLVMLegalizations(target);
// Partial conversion: ops not marked illegal may remain; a failure fails the
// pass.
5706 if (failed(applyPartialConversion(getOperation(), target,
5707 std::move(patterns))))
5708 signalPassFailure();
5712std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
5714 return std::make_unique<ConvertAIEVecToLLVMPass>();
5717std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
5719 const xilinx::ConvertAIEVecToLLVMOptions &options) {
5720 return std::make_unique<ConvertAIEVecToLLVMPass>(options);
static DecodedAddElemOp decodeAddElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::AddElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::AddElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedAddElemOp decodeAddElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::aie1::AddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::BroadcastScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::CmpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::FMAElemOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FdivOpConversion(const LLVMTypeConverter &typeConverter, StringRef device)
LogicalResult matchAndRewrite(aievec::InvOp invOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::MinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedMulElemOp decodeMulElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult convertToEmulatedFP32MulElem(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult convertToEmulatedI32MulElem(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
MulElemOpConversion(const LLVMTypeConverter &typeConverter, Aie2Fp32Emulation aie2Fp32EmulationOption)
Aie2Fp32Emulation aie2Fp32EmulationOption
LogicalResult matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedMulElemOp decodeMulElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::aie1::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::PackOp op)
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::SelOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::aie1::SelectOp op)
LogicalResult matchAndRewrite(aievec::aie1::SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::SubElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedSubElemOp decodeSubElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::SubElemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static DecodedSubElemOp decodeSubElemOp(OpAdaptor op)
LogicalResult matchAndRewrite(aievec::aie1::SubOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::string getIntrinsicName(aievec::UPDOp op, int loadSize)
LogicalResult matchAndRewrite(aievec::UPDOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::UnpackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
int32_t getVectorSizeInBits(mlir::VectorType type)
unsigned getVectorLaneSize(mlir::VectorType type)
uint32_t encodeSquare(uint32_t square)
void populateAIEVecToLLVMCommonConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns)
void encodeConf(uint32_t conf[2], const BufferParams &x, const BufferParams &z, bool sub)
void populateAIEVecToLLVMAIE2ConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns)
void populateAIEVecToLLVMAIE2pConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns)
std::string getMulOrFMAIntrinsicName(Operation *op)
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createConvertAIEVecToLLVMPass()
std::string getVectorTypeString(VectorType type, bool abbrev=false, bool acc=false)
void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns, Aie2Fp32Emulation aie2Fp32EmulationOption, llvm::StringRef aieTarget)
@ FP32_FP32_FP32_16x1x1x1
@ FP32_FP32_FP32_16x1x1x1
@ FP32_FP32_FP32_32x1x1x1
ConvertAIEVecToLLVMPass(const xilinx::ConvertAIEVecToLLVMOptions &options)
void runOnOperation() override
ConvertAIEVecToLLVMPass()=default
@ BF16_BF16_FP32_32x1x2x1
@ BF16_BF16_FP32_64x1x2x1
@ BF16_BF16_FP32_16x1x1x1
@ FP32_FP32_FP32_16x1x1x1
@ BF16_BF16_FP32_16x1x2x1
@ FP32_FP32_FP32_16x1x1x1
@ FP32_FP32_FP32_16x1x1x1
@ FP32_FP32_FP32_32x1x1x1