21#include "mlir/Dialect/Affine/IR/AffineOps.h"
22#include "mlir/Dialect/EmitC/IR/EmitC.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/UB/IR/UBOps.h"
28#include "mlir/IR/PatternMatch.h"
29#include "mlir/IR/SymbolTable.h"
30#include "mlir/IR/TypeUtilities.h"
31#include "mlir/Pass/PassManager.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "mlir/Transforms/Passes.h"
34#include "llvm/ADT/SmallSet.h"
39#define DEBUG_TYPE "lower-vector-to-aievec"
44using namespace vector;
// Predicate over `op`. The visible checks recognise arith.truncf and
// arith.trunci directly, and an aievec.srs whose source value is defined by an
// aievec.ups or an aievec.cast.
// NOTE(review): getDefiningOp() is null when the source is a block argument,
// and the result is handed straight to isa<>; confirm a null source op cannot
// reach this point (or use a null-tolerant isa variant).
52static bool isNarrowingOp(Operation *op) {
53 if (isa<arith::TruncFOp>(op) || isa<arith::TruncIOp>(op))
56 if (
auto srsOp = dyn_cast<aievec::SRSOp>(op)) {
57 auto *srsOpSrcOp = srsOp.getSource().getDefiningOp();
58 if (isa<aievec::UPSOp>(srsOpSrcOp) || isa<aievec::CastOp>(srsOpSrcOp))
// Inspects the producer chain feeding the arith.trunci input. The shapes
// tested, always following the lhs operand of each clamp op, are in order:
//   1. shrsi
//   2. minsi(maxsi(shrsi))
//   3. maxsi(minsi(shrsi))
//   4. minsi(maxsi)
//   5. maxsi(minsi)
// NOTE(review): presumably each match reports the trunci as a candidate;
// confirm against the return statements of the individual branches.
72static bool isSRSCompoundCandidate(arith::TruncIOp trunciOp) {
73 Value source = trunciOp.getIn();
// Shape 1: the input comes straight from an arithmetic shift right.
76 if (source.getDefiningOp<arith::ShRSIOp>())
// Shape 2: min(max(shift)).
80 if (
auto minsiOp = source.getDefiningOp<arith::MinSIOp>()) {
81 if (
auto maxsiOp = minsiOp.getLhs().getDefiningOp<arith::MaxSIOp>()) {
82 if (maxsiOp.getLhs().getDefiningOp<arith::ShRSIOp>())
// Shape 3: max(min(shift)).
88 if (
auto maxsiOp = source.getDefiningOp<arith::MaxSIOp>()) {
89 if (
auto minsiOp = maxsiOp.getLhs().getDefiningOp<arith::MinSIOp>()) {
90 if (minsiOp.getLhs().getDefiningOp<arith::ShRSIOp>())
// Shape 4: min(max) with no shift requirement.
97 if (
auto minsiOp = source.getDefiningOp<arith::MinSIOp>()) {
98 if (minsiOp.getLhs().getDefiningOp<arith::MaxSIOp>())
// Shape 5: max(min) with no shift requirement.
103 if (
auto maxsiOp = source.getDefiningOp<arith::MaxSIOp>()) {
104 if (maxsiOp.getLhs().getDefiningOp<arith::MinSIOp>())
// Walks the users of an arith.shrsi looking for an arith.trunci accepted by
// isSRSCompoundCandidate. The search covers three depths: a direct trunci
// user, a trunci behind one arith.maxsi / arith.minsi, and a trunci behind two
// such clamp ops.
113static bool shrsiUsedByCompoundSRS(arith::ShRSIOp rsOp) {
114 for (Operation *user : rsOp->getUsers()) {
// Depth 1: shrsi -> trunci.
116 if (
auto truncOp = dyn_cast<arith::TruncIOp>(user))
117 if (isSRSCompoundCandidate(truncOp))
// Depth 2: shrsi -> max/min -> trunci.
121 if (isa<arith::MaxSIOp, arith::MinSIOp>(user)) {
122 for (Operation *user2 : user->getUsers()) {
123 if (
auto truncOp2 = dyn_cast<arith::TruncIOp>(user2))
124 if (isSRSCompoundCandidate(truncOp2))
// Depth 3: shrsi -> max/min -> max/min -> trunci.
126 if (isa<arith::MaxSIOp, arith::MinSIOp>(user2)) {
127 for (Operation *user3 : user2->getUsers()) {
128 if (
auto truncOp3 = dyn_cast<arith::TruncIOp>(user3))
129 if (isSRSCompoundCandidate(truncOp3))
// Considers only arith.maxsi / arith.minsi ops. After a test on whether the
// result type is a VectorType, the users are scanned for an arith.trunci
// accepted by isSRSCompoundCandidate, either as a direct user or one further
// arith.maxsi / arith.minsi away.
// NOTE(review): going by the name, a vector-typed result presumably makes this
// return false; confirm.
142static bool scalarClampInCompoundSRS(Operation *op) {
143 if (!isa<arith::MaxSIOp, arith::MinSIOp>(op))
146 if (isa<VectorType>(op->getResult(0).getType()))
148 for (Operation *user : op->getUsers()) {
// Clamp feeding the trunci directly.
149 if (
auto truncOp = dyn_cast<arith::TruncIOp>(user)) {
150 if (isSRSCompoundCandidate(truncOp))
// Clamp feeding a second clamp, which in turn feeds the trunci.
153 if (isa<arith::MaxSIOp, arith::MinSIOp>(user)) {
154 for (Operation *user2 : user->getUsers()) {
155 if (
auto truncOp2 = dyn_cast<arith::TruncIOp>(user2)) {
156 if (isSRSCompoundCandidate(truncOp2))
// Looks through a widening producer of `src`:
//   - arith.extsi / arith.extui / arith.extf: yields the op's input;
//   - aievec.srs whose source is produced by aievec.ups: yields the ups source;
//   - aievec.cast whose source is produced by aievec.ups: yields the ups source.
// When none of the returns above is taken, an empty optional is returned.
168static std::optional<Value> getSourceOfWideningOp(Value src) {
169 if (
auto extSIOp =
src.getDefiningOp<arith::ExtSIOp>())
170 return extSIOp.getIn();
171 if (
auto extUIOp =
src.getDefiningOp<arith::ExtUIOp>())
172 return extUIOp.getIn();
173 if (
auto extFOp =
src.getDefiningOp<arith::ExtFOp>())
174 return extFOp.getIn();
// srs(ups(x)) round trip: hand back x.
175 if (
auto srsOp =
src.getDefiningOp<aievec::SRSOp>()) {
179 auto srsSource = srsOp.getSource();
181 if (
auto upsOp = srsSource.getDefiningOp<aievec::UPSOp>())
182 return upsOp.getSource();
// cast(ups(x)) round trip: hand back x.
184 if (
auto castOp =
src.getDefiningOp<aievec::CastOp>()) {
188 auto castSource = castOp.getSource();
190 if (
auto upsOp = castSource.getDefiningOp<aievec::UPSOp>())
191 return upsOp.getSource();
193 return std::optional<Value>();
// Looks through a narrowing producer of `src`: for arith.truncf or
// arith.trunci the op's input is returned; for any other producer (or none)
// an empty optional is returned.
198static std::optional<Value> getSourceOfNarrowingOp(Value src) {
199 if (
auto truncFOp =
src.getDefiningOp<arith::TruncFOp>())
200 return truncFOp.getIn();
201 if (
auto truncIOp =
src.getDefiningOp<arith::TruncIOp>())
202 return truncIOp.getIn();
203 return std::optional<Value>();
// Produces a `targetType` version of `val`:
//   - `val` comes from arith.truncf / arith.trunci (see getSourceOfNarrowingOp)
//     and the pre-narrowing value already has `targetType`: presumably that
//     value is reused instead of widening again - confirm;
//   - `val` already has `targetType`: presumably returned unchanged - confirm;
//   - otherwise an arith.extf to `targetType` is created and returned.
212static Value widenValueWithNarrowingCheck(Value val, Type targetType,
214 ConversionPatternRewriter &rewriter) {
216 if (
auto narrowedSrc = getSourceOfNarrowingOp(val)) {
217 if (narrowedSrc->getType() == targetType)
222 if (val.getType() == targetType)
225 return arith::ExtFOp::create(rewriter, loc, targetType, val);
238narrowValueWithWideningCheck(Operation *srcOp, Value val, Type targetType,
240 ConversionPatternRewriter &rewriter) {
// Counterpart of widenValueWithNarrowingCheck: when `srcOp` has exactly one
// use and that user is an arith.extf, the single user is examined (presumably
// so that a truncf immediately undone by the extf can be skipped - confirm).
// An arith.truncf of `val` to `targetType` is created further down.
247 if (srcOp->hasOneUse()) {
248 Operation *user = *srcOp->getUsers().begin();
249 if (
auto extfOp = dyn_cast<arith::ExtFOp>(user)) {
259 arith::TruncFOp::create(rewriter, loc, targetType, val);
269performBF16BinaryOpInF32(Value lhs, Value rhs, Operation *srcOp, Location loc,
270 ConversionPatternRewriter &rewriter,
271 std::function<Value(Value, Value)> opBuilder) {
// Runs a binary op in f32 on behalf of `srcOp`:
//   1. both operands are brought to f32 with widenValueWithNarrowingCheck;
//   2. `opBuilder` is invoked on the widened operands to build the f32 result;
//   3. the result is narrowed back to the type of `lhs` with
//      narrowValueWithWideningCheck.
// If narrowing is reported as skippable, the widening user is replaced with
// the f32 result and `srcOp` is erased; otherwise `srcOp` is replaced with
// the narrowed value.
272 Type f32Type = rewriter.getF32Type();
275 Value lhsF32 = widenValueWithNarrowingCheck(lhs, f32Type, loc, rewriter);
276 Value rhsF32 = widenValueWithNarrowingCheck(rhs, f32Type, loc, rewriter);
279 Value resultF32 = opBuilder(lhsF32, rhsF32);
282 auto narrowResult = narrowValueWithWideningCheck(
283 srcOp, resultF32, lhs.getType(), loc, rewriter);
285 if (narrowResult.skipNarrowing) {
287 rewriter.replaceOp(narrowResult.wideningUser, resultF32);
288 rewriter.eraseOp(srcOp);
290 rewriter.replaceOp(srcOp, narrowResult.narrowedValue);
// Given the two operands of an add, searches for a multiply feeding one of
// them and returns the tuple (mul lhs, mul rhs, acc).
// Two forms are recognised:
//   - an arith.muli defining the lhs operand, or failing that the rhs operand;
//   - an aievec.srs on either operand whose source is an aievec::aie1::MulOp.
// NOTE(review): `acc` is presumably the add operand that is not the multiply;
// confirm at its assignment.
297static std::optional<std::tuple<Value, Value, Value>>
298extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
299 auto *lhsDefOp = addLhs.getDefiningOp();
300 auto *rhsDefOp = addRhs.getDefiningOp();
301 arith::MulIOp mulOp =
nullptr;
304 mulOp = dyn_cast<arith::MulIOp>(lhsDefOp);
// Fall back to the rhs operand only when the lhs did not yield a multiply.
307 if (!mulOp && rhsDefOp) {
308 mulOp = dyn_cast<arith::MulIOp>(rhsDefOp);
312 return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc);
// Second form: the multiply is hidden behind an aievec.srs.
315 auto lhsSrsOp = addLhs.getDefiningOp<aievec::SRSOp>();
316 auto rhsSrsOp = addRhs.getDefiningOp<aievec::SRSOp>();
317 aievec::aie1::MulOp aieMulOp =
nullptr;
319 aieMulOp = lhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
322 if (!aieMulOp && rhsSrsOp) {
323 aieMulOp = rhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
327 return std::make_tuple(aieMulOp.getLhs(), aieMulOp.getRhs(), acc);
// Floating-point counterpart of extractMACOperandsFromAddOperands: searches
// for an arith.mulf defining the lhs add operand, or failing that the rhs
// operand, and returns the tuple (mul lhs, mul rhs, acc).
// NOTE(review): `acc` is presumably the add operand that is not the multiply;
// confirm at its assignment.
334static std::optional<std::tuple<Value, Value, Value>>
335extractFMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
336 auto *lhsDefOp = addLhs.getDefiningOp();
337 auto *rhsDefOp = addRhs.getDefiningOp();
338 arith::MulFOp mulOp =
nullptr;
341 mulOp = dyn_cast<arith::MulFOp>(lhsDefOp);
// Fall back to the rhs operand only when the lhs did not yield a multiply.
344 if (!mulOp && rhsDefOp) {
345 mulOp = dyn_cast<arith::MulFOp>(rhsDefOp);
349 return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc);
// Converts `inputVal` to vector type `tgtType` by emitting aievec ops.
// Conversions visible below:
//   - same element type, different lane count (16 -> 32 lanes of a float
//     type, or 32 -> 64 lanes of an integer type): a zero scalar is broadcast
//     to tgtType, a srcType-sized piece is extracted from it at index 0, and
//     that piece is concatenated after inputVal;
//   - different integer element types with equal lane count:
//       * 16-bit -> 32-bit, 16 lanes: aievec.ups followed by aievec.cast;
//       * 8-bit  -> 32-bit, 16 lanes: concat(inputVal, inputVal), ups, cast,
//         then ext at index 0;
//       * 8-bit  -> 16-bit, 32 lanes: aievec.unpack.
// `inputVal` must be vector-typed: cast<VectorType> asserts otherwise.
356static std::optional<Value>
357convertValueToTargetTypeAIE2(ConversionPatternRewriter &rewriter, Location loc,
358 Value inputVal, VectorType tgtType) {
359 auto srcType = cast<VectorType>(inputVal.getType());
360 auto srcElemType = srcType.getElementType();
361 unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
364 auto tgtElemType = tgtType.getElementType();
365 unsigned tgtBitWidth = tgtElemType.getIntOrFloatBitWidth();
368 if (srcType == tgtType)
// Lane-count change only: pad the input with zeros up to the target size.
371 if ((srcElemType == tgtElemType) && (srcLaneSize != tgtLaneSize)) {
373 if ((srcLaneSize == 16 && tgtLaneSize == 32 &&
374 isa<FloatType>(srcElemType)) ||
375 (srcLaneSize == 32 && tgtLaneSize == 64 &&
376 isa<IntegerType>(srcElemType))) {
377 auto zeroConstOp = arith::ConstantOp::create(
378 rewriter, loc, srcType.getElementType(),
379 rewriter.getZeroAttr(srcType.getElementType()));
380 auto broadcastZeroOp = aievec::BroadcastScalarOp::create(
381 rewriter, loc, tgtType, zeroConstOp->getResult(0));
382 auto extOp = aievec::ExtOp::create(rewriter, loc, srcType,
383 broadcastZeroOp->getResult(0), 0);
385 SmallVector<Value> inputSources = {inputVal, extOp->getResult(0)};
387 aievec::ConcatOp::create(rewriter, loc, tgtType, inputSources);
389 return concatOp.getResult();
391 }
// Element-width change only, integer to integer.
else if ((srcElemType != tgtElemType) && (srcLaneSize == tgtLaneSize) &&
392 isa<IntegerType>(srcElemType) && isa<IntegerType>(tgtElemType)) {
393 if (srcBitWidth == 16 && tgtBitWidth == 32 && srcLaneSize == 16) {
397 auto upsOp = aievec::UPSOp::create(rewriter, loc, accType, inputVal);
398 auto castOp = aievec::CastOp::create(
399 rewriter, loc, tgtType, upsOp.getResult(),
false);
400 return castOp.getResult();
403 if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) {
// The input is concatenated with itself before the ups; only the part
// extracted at index 0 is returned.
408 aievec::ConcatOp::create(rewriter, loc, concatOutType,
409 SmallVector<Value>({inputVal, inputVal}));
412 aievec::UPSOp::create(rewriter, loc, accType, concatOp.getResult());
414 auto castOp = aievec::CastOp::create(
415 rewriter, loc, castType, upsOp.getResult(),
false);
417 aievec::ExtOp::create(rewriter, loc, tgtType, castOp.getResult(), 0);
418 return extOp.getResult();
421 if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) {
424 aievec::UnpackOp::create(rewriter, loc, tgtType, inputVal);
425 return unpackOp.getResult();
// Assembles the list of string-valued named attributes (select, xoffsets,
// xsquare, xstart, yoffsets, ysquare, ystart and, in the two 9-entry
// variants, xoffsets_hi / yoffsets_hi) for a vector of type `vTy`.
// For integer element types `width` is taken from the element type. Three
// different lists are returned below.
// NOTE(review): which list applies presumably depends on `width` and, for the
// two 9-entry variants, on a property of `rotation`; confirm at the branch
// conditions. Falling through every case ends in a fatal error.
435static SmallVector<NamedAttribute>
436buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy,
439 auto elemTy = vTy.getElementType();
440 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
441 width = intTy.getWidth();
// Attribute values shared by more than one of the lists below.
442 StringAttr attr0 = rewriter.getStringAttr(
"0");
443 StringAttr attr0x06040200 = rewriter.getStringAttr(
"0x06040200");
444 StringAttr attr0x0e0c0a08 = rewriter.getStringAttr(
"0x0e0c0a08");
445 StringAttr attr0x2103 = rewriter.getStringAttr(
"0x2103");
446 StringAttr attr0x3210 = rewriter.getStringAttr(
"0x3210");
// Attribute names.
447 StringAttr selectAttrName = rewriter.getStringAttr(
"select");
448 StringAttr xoffsetsAttrName = rewriter.getStringAttr(
"xoffsets");
449 StringAttr xoffsetsHiAttrName = rewriter.getStringAttr(
"xoffsets_hi");
450 StringAttr xsquareAttrName = rewriter.getStringAttr(
"xsquare");
451 StringAttr xstartAttrName = rewriter.getStringAttr(
"xstart");
452 StringAttr yoffsetsAttrName = rewriter.getStringAttr(
"yoffsets");
453 StringAttr yoffsetsHiAttrName = rewriter.getStringAttr(
"yoffsets_hi");
454 StringAttr ysquareAttrName = rewriter.getStringAttr(
"ysquare");
455 StringAttr ystartAttrName = rewriter.getStringAttr(
"ystart");
// First variant: both the x and y sides are used, starting one element
// after and one element before `rotation` respectively.
460 int64_t xstart = rotation + 1;
461 int64_t ystart = rotation - 1;
462 return SmallVector<NamedAttribute, 9>(
463 {{selectAttrName, rewriter.getStringAttr(
"0x11111111")},
464 {xoffsetsAttrName, attr0x06040200},
465 {xoffsetsHiAttrName, attr0x0e0c0a08},
466 {xsquareAttrName, attr0x2103},
467 {xstartAttrName, rewriter.getStringAttr(std::to_string(xstart))},
468 {yoffsetsAttrName, rewriter.getStringAttr(
"0x0503010f")},
469 {yoffsetsHiAttrName, rewriter.getStringAttr(
"0x0d0b0907")},
470 {ysquareAttrName, attr0x2103},
471 {ystartAttrName, rewriter.getStringAttr(std::to_string(ystart))}});
// Second variant: only the x side is used; xstart is `rotation` itself and
// every y attribute is "0".
473 return SmallVector<NamedAttribute, 9>(
474 {{selectAttrName, attr0},
475 {xoffsetsAttrName, attr0x06040200},
476 {xoffsetsHiAttrName, attr0x0e0c0a08},
477 {xsquareAttrName, attr0x3210},
478 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
479 {yoffsetsAttrName, attr0},
480 {yoffsetsHiAttrName, attr0},
481 {ysquareAttrName, attr0},
482 {ystartAttrName, attr0}});
// Third variant: 7 entries, without the *_hi offsets.
485 return SmallVector<NamedAttribute, 7>(
486 {{selectAttrName, attr0},
487 {xoffsetsAttrName, rewriter.getStringAttr(
"0x76543210")},
488 {xsquareAttrName, attr0x3210},
489 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
490 {yoffsetsAttrName, attr0},
491 {ysquareAttrName, attr0},
492 {ystartAttrName, attr0}});
494 llvm::report_fatal_error(
"Unexpected width!");
// Builds an 11-entry attribute list for `fmaOp` in which the z start is set
// from `bcastPos`. Two variants are returned below; both set xstart to "0"
// and carry xstep, zsquare and fmsub over from `fmaOp`.
//   - The first variant hard-codes the x offsets, x square and z offsets and
//     sets zstep from `step`.
//   - The second variant hard-codes only xoffsets and zoffsets and carries
//     the remaining attributes over from `fmaOp`.
// NOTE(review): the variant is presumably selected by the element `width`
// read below; confirm at the branch conditions. Falling through every case
// ends in a fatal error.
502SmallVector<NamedAttribute>
506 auto elemTy = fmaOp.getLhs().getType().getElementType();
507 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
508 width = intTy.getWidth();
509 auto *ctx = fmaOp.getContext();
529 return SmallVector<NamedAttribute, 11>(
530 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx,
"0")},
531 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx,
"0x73727170")},
532 {fmaOp.getXoffsetsHiAttrName(), StringAttr::get(ctx,
"0x77767574")},
533 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
534 {fmaOp.getXsquareAttrName(), StringAttr::get(ctx,
"0x3120")},
535 {fmaOp.getZstartAttrName(),
536 StringAttr::get(ctx, std::to_string(bcastPos))},
537 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx,
"0")},
538 {fmaOp.getZoffsetsHiAttrName(), StringAttr::get(ctx,
"0")},
539 {fmaOp.getZstepAttrName(), StringAttr::get(ctx, std::to_string(step))},
540 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
541 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
543 return SmallVector<NamedAttribute, 11>(
544 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx,
"0")},
545 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx,
"0x76543210")},
546 {fmaOp.getXoffsetsHiAttrName(), fmaOp.getXoffsetsHiAttr()},
547 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
548 {fmaOp.getXsquareAttrName(), fmaOp.getXsquareAttr()},
549 {fmaOp.getZstartAttrName(),
550 StringAttr::get(ctx, std::to_string(bcastPos))},
551 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx,
"0x00000000")},
552 {fmaOp.getZoffsetsHiAttrName(), fmaOp.getZoffsetsHiAttr()},
553 {fmaOp.getZstepAttrName(), fmaOp.getZstepAttr()},
554 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
555 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
557 llvm::report_fatal_error(
"Unexpected width!");
// Emits an element-wise op of kind `AIEv2ElemOp` for `srcOp`:
//   1. `lval` and `rval` are each wrapped in an aievec.cast to `srcType`;
//   2. `AIEv2ElemOp` is created on the two cast results, using the cast
//      result type as its own result type;
//   3. `srcOp` is replaced by an aievec.cast of the element-wise result back
//      to the original type of `srcOp`.
565template <
typename SrcOpTy,
typename AIEv2ElemOp>
566static LogicalResult genAddElemAIE2(ConversionPatternRewriter &rewriter,
567 Value lval, Value rval, VectorType srcType,
569 auto lCastOp = aievec::CastOp::create(rewriter, srcOp.getLoc(), srcType, lval,
571 auto rCastOp = aievec::CastOp::create(rewriter, srcOp.getLoc(), srcType, rval,
573 auto elemOp = AIEv2ElemOp::create(
574 rewriter, srcOp.getLoc(), lCastOp->getResult(0).getType(),
575 lCastOp->getResult(0), rCastOp->getResult(0));
576 rewriter.replaceOpWithNewOp<aievec::CastOp>(
577 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// Maps a floating-point comparison predicate onto an integer predicate:
//   UEQ / OEQ -> eq        UNE / ONE -> ne
//   UGT -> ugt, OGT -> sgt  UGE -> uge, OGE -> sge
//   ULT -> ult, OLT -> slt  ULE -> ule, OLE -> sle
// i.e. the unordered variants map to the unsigned integer predicates and the
// ordered variants to the signed ones. A predicate that is not listed ends in
// a fatal error.
581static arith::CmpIPredicate
582convertToIntegerPredicate(arith::CmpFPredicate pred) {
584 case CmpFPredicate::UEQ:
585 case CmpFPredicate::OEQ:
586 return CmpIPredicate::eq;
587 case CmpFPredicate::UGT:
588 return CmpIPredicate::ugt;
589 case CmpFPredicate::OGT:
590 return CmpIPredicate::sgt;
591 case CmpFPredicate::UGE:
592 return CmpIPredicate::uge;
593 case CmpFPredicate::OGE:
594 return CmpIPredicate::sge;
595 case CmpFPredicate::ULT:
596 return CmpIPredicate::ult;
597 case CmpFPredicate::OLT:
598 return CmpIPredicate::slt;
599 case CmpFPredicate::ULE:
600 return CmpIPredicate::ule;
601 case CmpFPredicate::OLE:
602 return CmpIPredicate::sle;
603 case CmpFPredicate::UNE:
604 case CmpFPredicate::ONE:
605 return CmpIPredicate::ne;
607 llvm::report_fatal_error(
"Unexpected predicate!");
// Overload taking an integer predicate, so callers can use the same name for
// both arith.cmpi and arith.cmpf predicates.
// NOTE(review): presumably returns `pred` unchanged; confirm in the body.
611static arith::CmpIPredicate
612convertToIntegerPredicate(arith::CmpIPredicate pred) {
// Creates an aievec.cmp of `lhs` and `rhs` with result type `type` at `loc`.
// The integer predicate is translated into the string passed to the op: each
// of eq, ne, slt, ult, sle, ule, sgt, ugt, sge and uge maps to the string of
// the same spelling.
616static aievec::CmpOp createCmpOpAIE2(ConversionPatternRewriter &rewriter,
617 CmpIPredicate pred, Location loc,
618 Type type, Value lhs, Value rhs) {
620 case CmpIPredicate::eq:
621 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"eq");
622 case CmpIPredicate::ne:
623 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"ne");
624 case CmpIPredicate::slt:
625 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"slt");
626 case CmpIPredicate::ult:
627 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"ult");
628 case CmpIPredicate::sle:
629 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"sle");
630 case CmpIPredicate::ule:
631 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"ule");
632 case CmpIPredicate::sgt:
633 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"sgt");
634 case CmpIPredicate::ugt:
635 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"ugt");
636 case CmpIPredicate::sge:
637 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"sge");
638 case CmpIPredicate::uge:
639 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs,
"uge");
// Emits a shift-and-combine sequence for a vector.reduction and returns the
// aievec.ext_elem that reads the final scalar.
//   - `DstOpTy` is the binary aievec op used to combine two vectors.
//   - `shiftIndex` is the first shift distance in elements; it must be a
//     positive power of two (asserted).
// Starting at `shiftIndex` and halving each round, the current vector is
// shifted by `id` elements (expressed in bytes), combined with the unshifted
// vector through `DstOpTy`, and the result becomes the current vector.
// Element 0 of the last combine is then extracted.
// NOTE(review): `vType` comes from dyn_cast and is dereferenced without a
// null check; confirm `curValue` is always vector-typed (cast<> would make
// that assumption explicit).
644template <
typename DstOpTy>
645static aievec::ExtElemOp
646generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
647 vector::ReductionOp srcOp,
int shiftIndex,
649 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
650 "shiftIndex must be power of 2");
652 Location loc = srcOp.getLoc();
653 auto vType = dyn_cast<VectorType>(curValue.getType());
654 Type scalarType = vType.getElementType();
655 Type vecType = curValue.getType();
656 DstOpTy curOp =
nullptr;
657 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
659 for (
int id = shiftIndex;
id > 0;
id /= 2) {
// Shift amount in bytes: `id` elements of `elWidth` bits each.
660 auto constOp = arith::ConstantOp::create(
661 rewriter, loc, rewriter.getI32IntegerAttr(
id * elWidth / 8));
663 auto shiftBytesOp = aievec::ShiftOp::create(
664 rewriter, loc, vecType, curValue, curValue, constOp.getResult());
666 curOp = DstOpTy::create(rewriter, loc, vecType, curValue,
667 shiftBytesOp.getResult());
669 curValue = curOp.getResult();
// Read the scalar at index 0 of the last combine.
673 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
674 return aievec::ExtElemOp::create(rewriter, loc, scalarType, curOp,
675 zeroConstOp.getResult());
// Pads `inputVec` with infinities: an infinity constant of `scalarType` is
// broadcast to the v32 type, the piece at index 1 is extracted as the v16
// type, and that piece is concatenated after `inputVec`.
// Returns the padded vector paired with the lane count 32.
// NOTE(review): `negativeInf` presumably selects -inf over +inf; confirm at
// the APFloat::getInf call. `scalarType` must be a FloatType (cast<> asserts).
681static std::pair<Value, unsigned>
682padV16ToV32WithInfinity(ConversionPatternRewriter &rewriter, Location loc,
683 Value inputVec, Type scalarType,
bool negativeInf) {
688 auto infAttr = rewriter.getFloatAttr(
690 APFloat::getInf(cast<FloatType>(scalarType).getFloatSemantics(),
692 auto splatInf = arith::ConstantOp::create(rewriter, loc, infAttr).getResult();
696 aievec::BroadcastScalarOp::create(rewriter, loc, v32bf16Type, splatInf);
698 aievec::ExtOp::create(rewriter, loc, v16bf16Type, infVec, 1);
702 aievec::ConcatOp::create(rewriter, loc, v32bf16Type,
703 ValueRange{inputVec, infUpperHalf.getResult()});
705 return {paddedVec, 32};
// Pads `inputVec` with zeros: a zero constant of the v16 type is created and
// concatenated after `inputVec`, yielding a value of the v32 type.
710static Value padV16ToV32WithZeros(ConversionPatternRewriter &rewriter,
711 Location loc, Value inputVec,
715 auto zeroAttr = rewriter.getZeroAttr(v16Type);
716 auto zeroVec = arith::ConstantOp::create(rewriter, loc, zeroAttr);
717 return aievec::ConcatOp::create(rewriter, loc, v32Type,
718 ValueRange{inputVec, zeroVec.getResult()});
// Looks up func.func `funcName` in the symbol table of `parentSymbolTableOp`.
// If it is not found, a declaration with function type
// (inTypes) -> (outTypes) and sym_visibility = "private" is created at the
// start of the parent op's first block.
// The rewriter's insertion point is restored on exit by the InsertionGuard.
// NOTE(review): the lookup result is compared against NULL; `nullptr` or a
// plain boolean test is the conventional spelling.
721static func::FuncOp getOrInsertFuncDecl(ConversionPatternRewriter &rewriter,
722 Operation *parentSymbolTableOp,
723 StringRef funcName, TypeRange inTypes,
724 TypeRange outTypes) {
726 mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
727 rewriter.setInsertionPointToStart(
728 &parentSymbolTableOp->getRegions().front().getBlocks().front());
729 SymbolTable st = SymbolTable(parentSymbolTableOp);
730 func::FuncOp fnOpLookup = st.lookup<func::FuncOp>(funcName);
734 if (fnOpLookup != NULL) {
// Attribute marking the new declaration as private.
737 StringAttr t1 = rewriter.getStringAttr(
"sym_visibility");
738 StringAttr t2 = rewriter.getStringAttr(
"private");
739 NamedAttribute funcAccess = NamedAttribute(t1, t2);
740 FunctionType fnType =
741 mlir::FunctionType::get(rewriter.getContext(), inTypes, outTypes);
742 fnOp = func::FuncOp::create(rewriter, parentSymbolTableOp->getLoc(),
743 funcName, fnType, funcAccess);
// Generic driver for lowering an op on a wide vector as two half-width ops.
//   - Every value in `wideInputs` is split with aievec.ext into the pieces at
//     index 0 and index 1, both of type `halfType`.
//   - `processHalves(halfInputs, loc, rewriter)` receives the (index 0,
//     index 1) pairs in input order and returns a (low result, high result)
//     pair.
//   - `srcOp` is replaced with an aievec.concat of the two results, typed
//     `wideType`.
773template <
typename SrcOpTy,
typename Func>
774static void splitWideVectorOp(SrcOpTy srcOp, ArrayRef<Value> wideInputs,
775 VectorType halfType, VectorType wideType,
776 ConversionPatternRewriter &rewriter,
777 Func &&processHalves) {
779 Location loc = srcOp.getLoc();
782 SmallVector<std::pair<Value, Value>> halfInputs;
783 halfInputs.reserve(wideInputs.size());
784 for (Value wideInput : wideInputs) {
786 aievec::ExtOp::create(rewriter, loc, halfType, wideInput, 0);
788 aievec::ExtOp::create(rewriter, loc, halfType, wideInput, 1);
789 halfInputs.emplace_back(lowerHalf.getResult(), upperHalf.getResult());
793 auto [lowResult, highResult] = processHalves(halfInputs, loc, rewriter);
796 SmallVector<Value> concatSources = {lowResult, highResult};
797 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(srcOp, wideType, concatSources);
// Unary convenience wrapper around splitWideVectorOp: `processHalf` is called
// once on the lower half and once on the upper half of `wideInput`, and the
// two results are concatenated to replace `srcOp`.
801template <
typename SrcOpTy>
802static void splitWideUnaryVectorOp(
803 SrcOpTy srcOp, Value wideInput, VectorType halfType, VectorType wideType,
804 ConversionPatternRewriter &rewriter,
805 std::function<Value(Value, Location, ConversionPatternRewriter &)>
808 splitWideVectorOp<SrcOpTy>(
809 srcOp, {wideInput}, halfType, wideType, rewriter,
810 [&processHalf](ArrayRef<std::pair<Value, Value>> halfInputs, Location loc,
811 ConversionPatternRewriter &rewriter) {
// Only one input was passed, so only halfInputs[0] is populated.
812 auto [lowerHalf, upperHalf] = halfInputs[0];
813 Value lowResult = processHalf(lowerHalf, loc, rewriter);
814 Value highResult = processHalf(upperHalf, loc, rewriter);
815 return std::make_pair(lowResult, highResult);
// Binary convenience wrapper around splitWideVectorOp: `processHalf` is
// called on (lhs low, rhs low) and on (lhs high, rhs high), and the two
// results are concatenated to replace `srcOp`.
820template <
typename SrcOpTy>
821static void splitWideBinaryVectorOp(
822 SrcOpTy srcOp, Value lhs, Value rhs, VectorType halfType,
823 VectorType wideType, ConversionPatternRewriter &rewriter,
824 std::function<Value(Value, Value, Location, ConversionPatternRewriter &)>
827 splitWideVectorOp<SrcOpTy>(
828 srcOp, {lhs, rhs}, halfType, wideType, rewriter,
829 [&processHalf](ArrayRef<std::pair<Value, Value>> halfInputs, Location loc,
830 ConversionPatternRewriter &rewriter) {
// Inputs were passed as {lhs, rhs}, so index 0 is lhs and index 1 is rhs.
831 auto [lhsLow, lhsHigh] = halfInputs[0];
832 auto [rhsLow, rhsHigh] = halfInputs[1];
833 Value lowResult = processHalf(lhsLow, rhsLow, loc, rewriter);
834 Value highResult = processHalf(lhsHigh, rhsHigh, loc, rewriter);
835 return std::make_pair(lowResult, highResult);
// Legality check for a math.exp operand. The visible part of the final
// condition requires a float element type and 16 or 32 lanes.
// NOTE(review): the condition continues past what is shown here (presumably
// an element-width test using `elWidth`); confirm.
844static bool matchExpOpForAIE2LUT(math::ExpOp::Adaptor adaptor) {
845 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
850 Type scalarType = srcType.getElementType();
851 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
854 return isa<FloatType>(scalarType) && (laneSize == 16 || laneSize == 32) &&
// Legality check for a math.exp operand. The visible part of the final
// condition requires a bf16 element type and 16 or 32 lanes.
// NOTE(review): the condition continues past what is shown here (presumably
// an element-width test using `elWidth`); confirm.
859static bool matchExpOpForAIE2P(math::ExpOp::Adaptor adaptor) {
860 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
865 Type scalarType = srcType.getElementType();
866 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
869 return scalarType.isBF16() && (laneSize == 16 || laneSize == 32) &&
881 using OpConversionPattern::OpConversionPattern;
885 ConversionPatternRewriter &rewriter)
const override {
// Handles a broadcast whose source is produced by vector.extract: the
// broadcast is rewritten as an aievec.broadcast of the extracted position,
// adapting to the total result size in bits (lanes * element width):
//   - 512: a single aievec.broadcast replaces the op;
//   - 256: the source is concatenated with itself, broadcast at the wider
//     type, and the piece at index 0 is extracted;
//   - 1024: the half containing the position is extracted, broadcast, and
//     the broadcast result is concatenated with itself.
887 auto extOp = adaptor.getSource().getDefiningOp<vector::ExtractOp>();
892 auto src = extOp.getSource();
893 auto pos = extOp.getStaticPosition();
894 int64_t posVal = pos[0];
895 auto srcVecType = cast<VectorType>(src.getType());
896 auto resultType = cast<VectorType>(bcastOp.getResult().getType());
// The extract source may be twice as wide as the result; in that case pick
// the half holding the element and rebase the position into that half.
897 if (srcVecType != resultType) {
898 if (srcVecType.getNumElements() != 2 * resultType.getNumElements())
900 auto half =
static_cast<int8_t
>(posVal / resultType.getNumElements());
901 posVal -= half * resultType.getNumElements();
902 src = aievec::ExtOp::create(rewriter, extOp.getLoc(), resultType, src,
903 rewriter.getI8IntegerAttr(half))
907 unsigned elWidth = resultType.getElementType().getIntOrFloatBitWidth();
910 laneSize * elWidth == 512) {
912 rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(bcastOp, resultType, src,
914 }
else if (laneSize * elWidth == 256) {
916 VectorType aievecBcastType =
919 aievec::ConcatOp::create(rewriter, bcastOp.getLoc(), aievecBcastType,
920 SmallVector<Value>({src, src}));
921 auto aieBcastOp = aievec::BroadcastOp::create(
922 rewriter, bcastOp.getLoc(), aievecBcastType, concatOp.getResult(),
924 rewriter.replaceOpWithNewOp<aievec::ExtOp>(bcastOp, resultType,
925 aieBcastOp.getResult(), 0);
926 }
else if (laneSize * elWidth == 1024) {
928 VectorType aievecBcastType =
930 auto half =
static_cast<int8_t
>(posVal / resultType.getNumElements());
931 posVal -= half * resultType.getNumElements();
933 aievec::ExtOp::create(rewriter, bcastOp.getLoc(), aievecBcastType,
934 src, rewriter.getI8IntegerAttr(half));
935 auto aieBcastOp = aievec::BroadcastOp::create(rewriter, bcastOp.getLoc(),
937 extOp.getResult(), posVal);
938 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
940 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
950 using OpConversionPattern::OpConversionPattern;
954 ConversionPatternRewriter &rewriter)
const override {
// Handles a broadcast of a scalar. Sources produced by vector.extract are
// tested for first (presumably rejected here so that another pattern
// handles them - confirm). The scalar is broadcast with
// aievec.broadcast_scalar, adapting to the total result size in bits:
//   - 512: broadcast directly at the flat result type;
//   - 256: broadcast at a wider type and extract the piece at index 0;
//   - 1024: broadcast at a narrower type and concatenate it with itself.
// In every case a vector.shape_cast is added when the result type differs
// from its flat form.
956 if (adaptor.getSource().getDefiningOp<vector::ExtractOp>())
959 auto resultType = cast<VectorType>(bcastOp.getResult().getType());
961 Type scalarType = resultType.getElementType();
962 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
964 auto src = bcastOp.getSource();
966 if (laneSize * elWidth == 512) {
967 Value newOp = aievec::BroadcastScalarOp::create(
968 rewriter, bcastOp.getLoc(), flatResultType, src);
969 if (resultType != flatResultType)
970 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
972 rewriter.replaceOp(bcastOp, newOp);
976 if (laneSize * elWidth == 256) {
978 auto aieBcastOp = aievec::BroadcastScalarOp::create(
979 rewriter, bcastOp.getLoc(), vecType, src);
981 aievec::ExtOp::create(rewriter, bcastOp.getLoc(), flatResultType,
982 aieBcastOp.getResult(), 0);
983 if (resultType != flatResultType)
984 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
986 rewriter.replaceOp(bcastOp, newOp);
990 if (laneSize * elWidth == 1024) {
992 auto aieBcastOp = aievec::BroadcastScalarOp::create(
993 rewriter, bcastOp.getLoc(), vecType, src);
994 Value newOp = aievec::ConcatOp::create(
995 rewriter, bcastOp.getLoc(), flatResultType,
996 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
997 if (resultType != flatResultType)
998 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
1000 rewriter.replaceOp(bcastOp, newOp);
1012 using OpConversionPattern::OpConversionPattern;
1020 ConversionPatternRewriter &rewriter)
const override {
// Fuses an integer add with a multiply feeding it into a multiply-accumulate:
// the (lhs, rhs, acc) triple comes from extractMACOperandsFromAddOperands,
// `acc` is moved to the accumulator type with aievec.ups, an
// aievec.mac_elem-style FMAElemOp is created, and the add is replaced by an
// aievec.srs of that result back to the original vector type.
// The visible shape test lets through only 32 lanes x 16 bit and
// 16 lanes x 32 bit results.
1022 auto resultType = dyn_cast<VectorType>(addOp.getType());
1028 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
1031 auto [lhs, rhs, acc] = *res;
1034 unsigned resultElWidth =
1035 resultType.getElementType().getIntOrFloatBitWidth();
1038 if ((laneSize != 32 || resultElWidth != 16) &&
1039 (laneSize != 16 || resultElWidth != 32))
1044 auto upsOp = aievec::UPSOp::create(rewriter, addOp.getLoc(), accType, acc,
1046 auto fmaElemOp = aievec::FMAElemOp::create(
1047 rewriter, addOp.getLoc(), accType, lhs, rhs, upsOp.getResult(),
1050 auto shiftParamOp = arith::ConstantOp::create(
1051 rewriter, addOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
1052 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1053 addOp, resultType, fmaElemOp.getResult(), shiftParamOp.getResult());
// Emits the FMA sequence for one half of a split vector and returns its
// result:
//   1. `acc` is moved to a 16-lane f32 accumulator with aievec.ups, using
//      `shiftParam`;
//   2. an FMAElemOp on `lhs`, `rhs` and that accumulator is created;
//   3. the result is brought back to the vector type of `lhs` with
//      aievec.srs, shifting by an i32 constant equal to `shiftParam`.
1062static Value lowerBF16FMAHalf(Value lhs, Value rhs, Value acc,
1063 unsigned shiftParam, Location loc,
1064 ConversionPatternRewriter &rewriter) {
1065 auto f32AccType = VectorType::get({16}, rewriter.getF32Type());
1067 aievec::UPSOp::create(rewriter, loc, f32AccType, acc, shiftParam);
1068 auto fmaElemOp = aievec::FMAElemOp::create(rewriter, loc, f32AccType, lhs,
1069 rhs, upsOp.getResult(),
1071 auto shiftParamOp = arith::ConstantOp::create(
1072 rewriter, loc, rewriter.getI32IntegerAttr(shiftParam));
1074 aievec::SRSOp::create(rewriter, loc, cast<VectorType>(lhs.getType()),
1075 fmaElemOp.getResult(), shiftParamOp.getResult());
1076 return srsOp.getResult();
1088 using OpConversionPattern::OpConversionPattern;
1096 ConversionPatternRewriter &rewriter)
const override {
// Lowers vector.fma to aievec. Accepted results are f32 or bf16 with 16
// lanes, plus bf16 with 32 lanes; anything else is a match failure.
//   - 32-lane bf16: operands are split into halves with splitWideVectorOp and
//     each half goes through lowerBF16FMAHalf.
//   - 16 lanes: a bf16 accumulator is first moved to f32 with aievec.ups;
//     lhs / rhs are replaced by the bf16 sources of their widening ops;
//     an FMAElemOp is created; for bf16 results an aievec.srs follows.
1098 auto resVecTy = cast<VectorType>(fmaOp.getType());
1099 auto resElemTy = resVecTy.getElementType();
1103 if ((!resElemTy.isF32() && !resElemTy.isBF16()) ||
1104 (numElems != 16 && !(resElemTy.isBF16() && numElems == 32)))
1105 return rewriter.notifyMatchFailure(
1106 fmaOp,
"Unsupported operand types in vector.fma lowering.");
1108 Value lhs = adaptor.getLhs();
1109 Value rhs = adaptor.getRhs();
1110 Value acc = adaptor.getAcc();
1113 if (numElems == 32 && resElemTy.isBF16()) {
1117 splitWideVectorOp<vector::FMAOp>(
1118 fmaOp, {lhs, rhs, acc}, halfType, resVecTy, rewriter,
1119 [localShiftParam](ArrayRef<std::pair<Value, Value>> halfInputs,
1120 Location loc, ConversionPatternRewriter &rewriter) {
// Inputs were passed as {lhs, rhs, acc}, in that order.
1121 auto [lhsLow, lhsHigh] = halfInputs[0];
1122 auto [rhsLow, rhsHigh] = halfInputs[1];
1123 auto [accLow, accHigh] = halfInputs[2];
1125 Value lowResult = lowerBF16FMAHalf(lhsLow, rhsLow, accLow,
1126 localShiftParam, loc, rewriter);
1127 Value highResult = lowerBF16FMAHalf(lhsHigh, rhsHigh, accHigh,
1128 localShiftParam, loc, rewriter);
1129 return std::make_pair(lowResult, highResult);
1134 if (resElemTy.isBF16())
1135 acc = aievec::UPSOp::create(rewriter, fmaOp.getLoc(),
1136 VectorType::get({16}, rewriter.getF32Type()),
// Look through the widening op on each operand; a null Value marks an
// operand that does not come from one.
1139 lhs = getSourceOfWideningOp(lhs).value_or(
nullptr);
1140 rhs = getSourceOfWideningOp(rhs).value_or(
nullptr);
1142 return rewriter.notifyMatchFailure(
1143 fmaOp,
"vector.fma operands are f32, and they don't come from "
1144 "arith.extf on bf16; can't lower to aievec.");
1145 if (!cast<VectorType>(lhs.getType()).getElementType().isBF16() ||
1146 !cast<VectorType>(rhs.getType()).getElementType().isBF16())
1147 return rewriter.notifyMatchFailure(
1148 fmaOp,
"vector.fma operands come from arith.extf, but the source "
1149 "of the widening op is not bf16; can't lower to aievec.");
1152 aievec::FMAElemOp::create(rewriter, fmaOp.getLoc(), acc.getType(), lhs,
// bf16 results need the f32 accumulator narrowed back with aievec.srs.
1155 if (resElemTy.isBF16()) {
1156 auto shiftParamOp = arith::ConstantOp::create(
1157 rewriter, fmaOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
1158 newOp = aievec::SRSOp::create(rewriter, fmaOp.getLoc(), resVecTy, newOp,
1162 rewriter.replaceOp(fmaOp, newOp);
1174 using OpConversionPattern::OpConversionPattern;
1182 ConversionPatternRewriter &rewriter)
const override {
// Fuses a bf16 arith.addf with an arith.mulf feeding it into an FMA. The
// (lhs, rhs, acc) triple comes from extractFMACOperandsFromAddOperands.
//   - 32 lanes: operands are split into halves with splitWideVectorOp and
//     each half goes through lowerBF16FMAHalf.
//   - otherwise: `acc` is moved to a 16-lane f32 accumulator with
//     aievec.ups, an FMAElemOp is created, and the add is replaced by an
//     aievec.srs back to the original result type.
1184 auto resultType = dyn_cast<VectorType>(addOp.getType());
1189 auto elemType = resultType.getElementType();
1190 if (!elemType.isBF16())
1195 extractFMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
1198 auto [lhs, rhs, acc] = *res;
1203 if (laneSize == 32) {
1207 splitWideVectorOp<arith::AddFOp>(
1208 addOp, {lhs, rhs, acc}, halfType, resultType, rewriter,
1209 [localShiftParam](ArrayRef<std::pair<Value, Value>> halfInputs,
1210 Location loc, ConversionPatternRewriter &rewriter) {
// Inputs were passed as {lhs, rhs, acc}, in that order.
1211 auto [lhsLow, lhsHigh] = halfInputs[0];
1212 auto [rhsLow, rhsHigh] = halfInputs[1];
1213 auto [accLow, accHigh] = halfInputs[2];
1215 Value lowResult = lowerBF16FMAHalf(lhsLow, rhsLow, accLow,
1216 localShiftParam, loc, rewriter);
1217 Value highResult = lowerBF16FMAHalf(lhsHigh, rhsHigh, accHigh,
1218 localShiftParam, loc, rewriter);
1219 return std::make_pair(lowResult, highResult);
1228 auto f32AccType = VectorType::get({16}, rewriter.getF32Type());
1229 auto upsOp = aievec::UPSOp::create(rewriter, addOp.getLoc(), f32AccType,
1231 auto fmaElemOp = aievec::FMAElemOp::create(
1232 rewriter, addOp.getLoc(), f32AccType, lhs, rhs, upsOp.getResult(),
1235 auto shiftParamOp = arith::ConstantOp::create(
1236 rewriter, addOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
1237 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1238 addOp, resultType, fmaElemOp.getResult(), shiftParamOp.getResult());
// Fragment of an OpConversionPattern whose rewrite hook works on `mulOp` and
// emits aievec::MulElemOp.
// NOTE(review): the struct header and several statements (including the
// bodies of most `if`s) are absent from this listing.
1250 using OpConversionPattern::OpConversionPattern;
1258 ConversionPatternRewriter &rewriter)
const override {
1260 auto resultType = dyn_cast<VectorType>(mulOp.getType());
// Detects a bf16 multiply whose single use is an arith::AddFOp; the action
// taken for that case is not visible here.
1268 auto isAddOp = [&](Operation *op) {
return isa<arith::AddFOp>(op); };
1269 if (resultType.getElementType().isBF16() && mulOp->hasOneUse() &&
1270 llvm::any_of(mulOp->getUsers(), isAddOp))
1273 unsigned resultElWidth =
1274 resultType.getElementType().getIntOrFloatBitWidth();
// 32 lanes of 16-bit elements: splitWideBinaryVectorOp gets a per-half
// callback that builds MulElemOp followed by SRSOp (shift = localShiftParam).
1279 if (laneSize == 32 && resultElWidth == 16) {
1283 splitWideBinaryVectorOp<arith::MulFOp>(
1284 mulOp, adaptor.getLhs(), adaptor.getRhs(), halfType, resultType,
1286 [localShiftParam](Value lhsHalf, Value rhsHalf, Location loc,
1287 ConversionPatternRewriter &rewriter) -> Value {
1288 Type accType = getVectorOpDestType(
1289 cast<VectorType>(lhsHalf.getType()), true);
1290 auto mulElemOp = aievec::MulElemOp::create(rewriter, loc, accType,
1292 auto shiftParamOp = arith::ConstantOp::create(
1293 rewriter, loc, rewriter.getI32IntegerAttr(localShiftParam));
1294 auto srsOp = aievec::SRSOp::create(
1295 rewriter, loc, cast<VectorType>(lhsHalf.getType()),
1296 mulElemOp.getResult(), shiftParamOp.getResult());
1297 return srsOp.getResult();
1303 if (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32))
// Look through widening ops on both operands; an operand is kept as-is when
// getSourceOfWideningOp yields no value.
1307 auto lval = adaptor.getLhs();
1308 auto rval = adaptor.getRhs();
1309 lval = getSourceOfWideningOp(lval).value_or(lval);
1310 rval = getSourceOfWideningOp(rval).value_or(rval);
1311 auto lSrcType = cast<VectorType>(lval.getType());
1312 auto rSrcType = cast<VectorType>(rval.getType());
1313 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
1314 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
1316 if (rBitWidth > lBitWidth) {
1320 if (lSrcType != rSrcType) {
// The wider of the two source element types drives the lane-count selection
// below (the values assigned to numLanes are not visible here).
1325 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
1326 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
1327 : lSrcType.getElementType();
1328 unsigned numLanes = 0;
1329 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
1331 }
else if (isa<IntegerType>(srcElemType) &&
1332 (bitWidth == 8 || bitWidth == 16)) {
1334 }
else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
// Operands whose type differs from targetInputType are converted with
// convertValueToTargetTypeAIE2.
1340 if (targetInputType != lSrcType) {
1341 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
1345 if (targetInputType != rSrcType) {
1346 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
1354 auto mulElemOp = aievec::MulElemOp::create(rewriter, mulOp.getLoc(),
1355 accType, lval, rval);
// Equal element widths -> replace with CastOp; wider MulElemOp result ->
// replace with SRSOp using shiftParam.
1358 auto mulElemResultType = mulElemOp.getType();
1359 auto mulElemResultElWidth =
1360 mulElemResultType.getElementType().getIntOrFloatBitWidth();
1362 if (mulElemResultElWidth == resultElWidth) {
1363 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1364 mulOp, resultType, mulElemOp.getResult(),
false);
1365 }
else if (mulElemResultElWidth > resultElWidth) {
1366 auto shiftParamOp = arith::ConstantOp::create(
1367 rewriter, mulOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
1368 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1369 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
// Fragment of an OpConversionPattern whose rewrite hook works on `mulOp` and
// emits aievec::MulElemOp; the single-use check here looks for arith::AddIOp.
// NOTE(review): the struct header and several statements (including the
// bodies of most `if`s) are absent from this listing.
1384 using OpConversionPattern::OpConversionPattern;
1392 ConversionPatternRewriter &rewriter)
const override {
1394 auto resultType = dyn_cast<VectorType>(mulOp.getType());
// Detects a multiply whose single use is an arith::AddIOp; the action taken
// for that case is not visible here.
1399 auto isAddOp = [&](Operation *op) {
return isa<arith::AddIOp>(op); };
1400 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
1404 unsigned resultElWidth =
1405 resultType.getElementType().getIntOrFloatBitWidth();
// True when the shape is neither (32 lanes, 8- or 16-bit) nor
// (16 or 32 lanes, 32-bit).
1408 if ((laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
1409 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32))
1413 auto lval = adaptor.getLhs();
1414 auto rval = adaptor.getRhs();
// Look through widening ops on both operands; an operand is kept as-is when
// getSourceOfWideningOp yields no value.
1416 lval = getSourceOfWideningOp(lval).value_or(lval);
1417 rval = getSourceOfWideningOp(rval).value_or(rval);
1419 auto lSrcType = cast<VectorType>(lval.getType());
1420 auto rSrcType = cast<VectorType>(rval.getType());
1421 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
1422 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
1424 if (rBitWidth > lBitWidth) {
// The wider of the two source element types drives the lane-count selection
// below (the values assigned to numLanes are not visible here).
1429 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
1430 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
1431 : lSrcType.getElementType();
1432 unsigned numLanes = 0;
1433 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
1435 }
else if (isa<IntegerType>(srcElemType) &&
1436 (bitWidth == 8 || bitWidth == 16)) {
1438 }
else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
// Operands whose type differs from targetInputType are converted with
// convertValueToTargetTypeAIE2.
1444 if (targetInputType != lSrcType) {
1445 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
1449 if (targetInputType != rSrcType) {
1450 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
1458 auto mulElemOp = aievec::MulElemOp::create(rewriter, mulOp.getLoc(),
1459 accType, lval, rval);
// Equal element widths -> replace with CastOp; wider MulElemOp result ->
// replace with SRSOp using shiftParam.
1462 auto mulElemResultType = mulElemOp.getType();
1463 auto mulElemResultElWidth =
1464 mulElemResultType.getElementType().getIntOrFloatBitWidth();
1466 if (mulElemResultElWidth == resultElWidth) {
1467 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1468 mulOp, resultType, mulElemOp.getResult(),
false);
1469 }
else if (mulElemResultElWidth > resultElWidth) {
1470 auto shiftParamOp = arith::ConstantOp::create(
1471 rewriter, mulOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
1472 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1473 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
// Fragment of an OpConversionPattern whose rewrite hook replaces `fmaOp` with
// aievec::aie1::FMAOp.
// NOTE(review): the struct header and several statements are absent from
// this listing.
// NOTE(review): getDefiningOp() returns null for block arguments and LLVM's
// dyn_cast asserts on a null pointer; confirm the lines not shown guard the
// dyn_cast calls below, or that dyn_cast_or_null is what is intended.
1487 using OpConversionPattern::OpConversionPattern;
1491 ConversionPatternRewriter &rewriter)
const override {
1493 dyn_cast<aievec::ConcatOp>(adaptor.getLhs().getDefiningOp());
// The broadcast is looked for either behind the concat's first source or
// behind the rhs operand; `lhs` is assigned accordingly.
1496 vector::BroadcastOp bcastOp =
nullptr;
1497 auto *concatDefOp = concatOp.getSources()[0].getDefiningOp();
1499 bcastOp = dyn_cast<vector::BroadcastOp>(concatDefOp);
1500 Value lhs = adaptor.getRhs();
1502 bcastOp = dyn_cast<vector::BroadcastOp>(adaptor.getRhs().getDefiningOp());
1505 lhs = concatOp.getSources()[0];
1508 dyn_cast<vector::ExtractOp>(bcastOp.getSource().getDefiningOp());
1512 auto rhs = extOp.getSource();
1513 auto concatVecType = cast<VectorType>(concatOp.getResult().getType());
// A zero constant of lhs's type is concatenated after lhs.
1515 arith::ConstantOp::create(rewriter, concatOp.getLoc(), lhs.getType(),
1516 rewriter.getZeroAttr(lhs.getType()));
1518 aievec::ConcatOp::create(rewriter, concatOp.getLoc(), concatVecType,
1519 SmallVector<Value, 2>({lhs, zvec}))
// zstart is the first static position of the extract op.
1522 auto pos = extOp.getStaticPosition();
1523 int64_t zstart = pos[0];
1525 rewriter.replaceOpWithNewOp<aievec::aie1::FMAOp>(
1526 fmaOp, TypeRange({fmaOp.getResult().getType()}),
1527 ValueRange({lhsX2, rhs, adaptor.getAcc()}), fmaOpAttr);
// Fragment of an OpConversionPattern whose rewrite hook turns `addOp` into a
// UPSOp + aievec::aie1::FMAOp + SRSOp sequence.
// NOTE(review): the struct header and several statements are absent from
// this listing.
1535 using OpConversionPattern::OpConversionPattern;
1539 ConversionPatternRewriter &rewriter)
const override {
1540 auto vecType = cast<VectorType>(addOp.getType());
1543 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
1546 auto [lhs, rhs, acc] = *res;
// The concat type is the result type with its innermost dimension doubled.
1548 SmallVector<int64_t, 4> concatVecShape(vecType.getShape().begin(),
1549 vecType.getShape().end());
1550 concatVecShape[vecType.getRank() - 1] *= 2;
1551 auto concatVecType =
1552 VectorType::get(concatVecShape, vecType.getElementType());
// lhs is concatenated with itself (two copies).
1556 aievec::ConcatOp::create(rewriter, addOp.getLoc(), concatVecType,
1557 SmallVector<Value, 2>(2, lhs))
1559 auto upsOp = aievec::UPSOp::create(rewriter, addOp.getLoc(), accType, acc);
1560 auto fmaOp = aievec::aie1::FMAOp::create(
1561 rewriter, addOp.getLoc(), accType, lhsX2, rhs, upsOp.getResult(),
// The final SRSOp uses a constant shift of 0.
1565 auto shiftParamOp = arith::ConstantOp::create(
1566 rewriter, addOp.getLoc(), rewriter.getI32IntegerAttr(0));
1567 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1568 addOp, vecType, fmaOp.getResult(), shiftParamOp.getResult());
// Fragment of an OpConversionPattern whose rewrite hook replaces `readOp`
// with the result of xilinx::aievec::UPDOp.
// NOTE(review): the struct header and several statements are absent from
// this listing.
1578 using OpConversionPattern::OpConversionPattern;
1589 ConversionPatternRewriter &rewriter)
const override {
// Masked reads are reported as an error on the op.
1591 if (readOp.getMask())
1592 return readOp.emitError() <<
"AIE doesn't support masked loads.";
// The permutation map is tested for minor-identity and for being constant;
// the actions taken on those tests are not visible here.
1595 AffineMap map = readOp.getPermutationMap();
1596 if (!map.isMinorIdentity())
1600 if (map.isConstant())
1604 auto vType = readOp.getVectorType();
// vSize is the vector size in bits (element count times element bit width).
1614 int64_t vSize = vType.getNumElements() * vType.getElementTypeBitWidth();
// For vectors larger than minVectorSize, tests whether `multiplicity` has
// exactly one bit set (i.e. is a power of two).
1623 if ((vSize >
minVectorSize) && std::bitset<8>(multiplicity).count() != 1)
// First UPDOp is created with a null vector operand; the second one, where
// emitted, takes the first one's result as its vector operand.
1626 auto updOp = xilinx::aievec::UPDOp::create(
1627 rewriter, readOp.getLoc(), vType, adaptor.getBase(),
1628 adaptor.getIndices(), 0, 0, TypedValue<VectorType>(
nullptr));
1630 updOp = xilinx::aievec::UPDOp::create(
1631 rewriter, readOp.getLoc(), vType, adaptor.getBase(),
1632 adaptor.getIndices(),
maxLoadSize, 1, updOp.getResult());
1634 rewriter.replaceOp(readOp, updOp.getResult());
// Template pattern fragment: replaces `srcOp` (SrcOpTy) with a DstOpTy that
// has the same result type and takes the adaptor's lhs and rhs operands.
// NOTE(review): the struct header and the trailing arguments of the
// replaceOpWithNewOp call are absent from this listing.
1644template <
typename SrcOpTy,
typename DstOpTy>
1651 ConversionPatternRewriter &rewriter)
const override {
1652 rewriter.replaceOpWithNewOp<DstOpTy>(
1653 srcOp, srcOp.getResult().getType(), adaptor.getLhs(), adaptor.getRhs(),
// Fragment of an OpConversionPattern whose rewrite hook replaces `addOp` with
// aievec::aie1::AddOp.
// NOTE(review): the struct header and several statements are absent from
// this listing.
1661 using OpConversionPattern::OpConversionPattern;
1665 ConversionPatternRewriter &rewriter)
const override {
1666 auto resType = addOp.getType();
1667 if (!isa<VectorType>(resType))
1670 auto lhs = adaptor.getLhs();
1671 auto rhs = adaptor.getRhs();
1672 auto *lhsDefOp = lhs.getDefiningOp();
1673 auto *rhsDefOp = rhs.getDefiningOp();
// Tests whether either operand is produced by an arith::MulIOp (null-safe via
// isa_and_nonnull); the action taken for that case is not visible here.
1674 if ((isa_and_nonnull<arith::MulIOp>(lhsDefOp)) ||
1675 (isa_and_nonnull<arith::MulIOp>(rhsDefOp)))
1678 rewriter.replaceOpWithNewOp<aievec::aie1::AddOp>(
1679 addOp, resType, lhs, rhs,
// Fragment of an OpConversionPattern whose rewrite hook lowers `mulOp` to
// aievec::aie1::MulOp (result type accTy) followed by an SRSOp with a
// constant shift of 0 that produces the original vector result type.
// NOTE(review): the struct header and several statements are absent from
// this listing.
1696 using OpConversionPattern::OpConversionPattern;
1699 ConversionPatternRewriter &rewriter)
const override {
1700 auto resTy = dyn_cast<VectorType>(mulOp.getType());
1704 auto newMulOp = aievec::aie1::MulOp::create(
1705 rewriter, mulOp.getLoc(), accTy, adaptor.getLhs(), adaptor.getRhs());
1706 auto shiftParamOp = arith::ConstantOp::create(
1707 rewriter, mulOp.getLoc(), rewriter.getI32IntegerAttr(0));
1708 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1709 mulOp, resTy, newMulOp.getResult(), shiftParamOp.getResult());
// Template pattern fragment: lowers an element-wise binary `srcOp` (SrcOpTy)
// to DstOpTy, choosing between a direct replacement, genAddElemAIE2, a
// UPS/Cast + DstOpTy + Cast/SRS sequence, or a split into two halves,
// depending on lane count, element width and whether operands come from
// widening ops.
// NOTE(review): the struct header and many statements (declaration left-hand
// sides, returns, closing braces) are absent from this listing; the comments
// below describe only the visible lines.
1714template <
typename SrcOpTy,
typename DstOpTy>
1722 ConversionPatternRewriter &rewriter)
const override {
1723 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
// (laneSize, element width) pairs checked for integer element types below.
1729 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1730 laneSizeElWidthPairSet.insert({64, 8});
1731 laneSizeElWidthPairSet.insert({32, 16});
1732 laneSizeElWidthPairSet.insert({16, 32});
1733 laneSizeElWidthPairSet.insert({32, 32});
1735 auto lhs = adaptor.getLhs();
1736 auto rhs = adaptor.getRhs();
1737 auto lhsDefOp = lhs.getDefiningOp();
1738 auto rhsDefOp = rhs.getDefiningOp();
// Classify each operand's producer: multiply (integer or float) or constant.
1743 bool lhsIsMul = lhsDefOp && (isa<arith::MulIOp>(lhsDefOp) ||
1744 isa<arith::MulFOp>(lhsDefOp));
1745 bool rhsIsMul = rhsDefOp && (isa<arith::MulIOp>(rhsDefOp) ||
1746 isa<arith::MulFOp>(rhsDefOp));
1747 bool lhsIsConst = lhsDefOp && isa<arith::ConstantOp>(lhsDefOp);
1748 bool rhsIsConst = rhsDefOp && isa<arith::ConstantOp>(rhsDefOp);
// Non-f32 case where one operand is a multiply and the other is not a
// constant; the action taken is not visible here.
1753 if (!resultType.getElementType().isF32() &&
1754 ((lhsIsMul && !rhsIsConst) || (rhsIsMul && !lhsIsConst)))
1757 Type scalarType = resultType.getElementType();
1758 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
// Integer element types.
1762 if (isa<IntegerType>(scalarType)) {
1763 if (!laneSizeElWidthPairSet.count(
1764 std::make_pair(laneSize, resultElWidth)))
// Neither operand has a defining op: 512-bit vectors are replaced directly,
// other sizes go through genAddElemAIE2.
1770 if (!lhsDefOp && !rhsDefOp) {
1771 if (laneSize * resultElWidth == 512) {
1772 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1776 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs, resultType,
// 32-bit results: lhsExt / rhsExt hold the source of a widening op on each
// operand, or a null Value when there is none.
1781 if (resultElWidth == 32) {
1782 auto lhsExt = getSourceOfWideningOp(lhs).value_or(
nullptr);
1783 auto rhsExt = getSourceOfWideningOp(rhs).value_or(
nullptr);
1785 if (!lhsExt && !rhsExt) {
1786 if (laneSize * resultElWidth == 512) {
1787 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1791 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
// Both operands widened: UPS each source, apply DstOpTy, cast to the result.
1795 if (lhsExt && rhsExt) {
1798 VectorType lSrcType = cast<VectorType>(lval.getType());
1802 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lval);
1804 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rval);
1805 auto elemOp = DstOpTy::create(
1806 rewriter, srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1807 lUpsOp->getResult(0), rUpsOp->getResult(0));
1808 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1809 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// At most one operand widened: extVal is the widened source and bitWidth its
// element width.
1813 if (!lhsExt || !rhsExt) {
1814 auto lval = lhsExt ? lhsExt : lhs;
1815 auto rval = rhsExt ? rhsExt : rhs;
1816 auto extVal = lhsExt ? lval : rval;
1817 VectorType vType = cast<VectorType>(extVal.getType());
1818 unsigned bitWidth = vType.getElementType().getIntOrFloatBitWidth();
// Source widths other than 8/16 bits, or sources that are not 256 bits wide,
// fall back to genAddElemAIE2.
1820 if (bitWidth != 8 && bitWidth != 16) {
1821 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1825 if (bitWidth * laneSize != 256) {
1826 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1830 Type accType =
nullptr;
// 8-bit source: the widened operand goes through UPSOp, the other through
// CastOp, and operand order is preserved when building DstOpTy.
1832 if (bitWidth == 8) {
1834 Value valToUps = lhsExt ? lval : rval;
1835 Value valToCast = lhsExt ? rval : lval;
1836 auto upsOp = aievec::UPSOp::create(rewriter, srcOp.getLoc(),
1839 aievec::CastOp::create(rewriter, srcOp.getLoc(), resultType,
1842 lhsExt ? upsOp->getResult(0) : castOp->getResult(0);
1844 lhsExt ? castOp->getResult(0) : upsOp->getResult(0);
1845 auto elemOp = DstOpTy::create(rewriter, srcOp.getLoc(),
1846 upsOp->getResult(0).getType(),
1847 lhsToElemOp, rhsToElemOp);
1848 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1849 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// 16-bit source: both operands go through UPSOp and the result through SRSOp
// with a constant shift of 0.
1853 if (bitWidth == 16) {
1856 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lval);
1858 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rval);
1860 auto elemOp = DstOpTy::create(
1861 rewriter, srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1862 lUpsOp->getResult(0), rUpsOp->getResult(0));
1864 auto shiftParamOp = arith::ConstantOp::create(
1865 rewriter, srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1866 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1867 srcOp, srcOp.getType(), elemOp.getResult(),
1868 shiftParamOp.getResult());
1873 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs, rhs);
// NOTE(review): the code from here on appears to handle non-integer element
// types; the enclosing braces are not visible, so confirm against the file.
1879 if (laneSize != 16 && laneSize != 32)
// 32 lanes x 32 bits: split in halves; each half is cast, combined with
// DstOpTy, and cast back to the half vector type.
1883 if (laneSize == 32 && resultElWidth == 32) {
1886 splitWideBinaryVectorOp<SrcOpTy>(
1887 srcOp, lhs, rhs, halfType, resultType, rewriter,
1888 [](Value lhsHalf, Value rhsHalf, Location loc,
1889 ConversionPatternRewriter &rewriter) -> Value {
1890 VectorType halfVecType = cast<VectorType>(lhsHalf.getType());
1892 auto lCastOp = aievec::CastOp::create(rewriter, loc, halfVecType,
1894 auto rCastOp = aievec::CastOp::create(rewriter, loc, halfVecType,
1896 auto elemOp = DstOpTy::create(
1897 rewriter, loc, lCastOp->getResult(0).getType(),
1898 lCastOp->getResult(0), rCastOp->getResult(0));
1899 auto resCastOp = aievec::CastOp::create(
1900 rewriter, loc, halfVecType, elemOp.getResult(),
1902 return resCastOp.getResult();
// 32 lanes x 16 bits: split in halves; each half is UPS'd, combined with
// DstOpTy, and narrowed by SRSOp with a constant shift of 0.
1908 if (laneSize == 32 && resultElWidth == 16) {
1911 splitWideBinaryVectorOp<SrcOpTy>(
1912 srcOp, lhs, rhs, halfType, resultType, rewriter,
1913 [](Value lhsHalf, Value rhsHalf, Location loc,
1914 ConversionPatternRewriter &rewriter) -> Value {
1915 VectorType halfVecType = cast<VectorType>(lhsHalf.getType());
1918 aievec::UPSOp::create(rewriter, loc, accType, lhsHalf);
1920 aievec::UPSOp::create(rewriter, loc, accType, rhsHalf);
1922 DstOpTy::create(rewriter, loc, lUpsOp->getResult(0).getType(),
1923 lUpsOp->getResult(0), rUpsOp->getResult(0));
1924 auto shiftParamOp = arith::ConstantOp::create(
1925 rewriter, loc, rewriter.getI32IntegerAttr(0));
1926 auto srsOp = aievec::SRSOp::create(rewriter, loc, halfVecType,
1928 shiftParamOp.getResult());
1929 return srsOp.getResult();
// 32-bit results: same widening-source analysis as in the integer branch.
1936 if (resultElWidth == 32) {
1937 if (!lhsDefOp && !rhsDefOp) {
1938 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1942 auto lhsExt = getSourceOfWideningOp(lhs).value_or(
nullptr);
1943 auto rhsExt = getSourceOfWideningOp(rhs).value_or(
nullptr);
1945 if (!lhsExt && !rhsExt) {
1946 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
// Both operands widened: UPS each source, apply DstOpTy, cast to the result.
1951 if (lhsExt && rhsExt) {
1954 VectorType vType = cast<VectorType>(lval.getType());
1958 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lval);
1960 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rval);
1961 auto elemOp = DstOpTy::create(
1962 rewriter, srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1963 lUpsOp->getResult(0), rUpsOp->getResult(0));
1964 rewriter.replaceOpWithNewOp<aievec::CastOp>(srcOp, srcOp.getType(),
1965 elemOp.getResult());
// At most one operand widened: one side is UPS'd and the other cast.
// NOTE(review): DstOpTy is built below as (upsOp, castOp) regardless of which
// side was widened, unlike the integer 8-bit branch which preserves operand
// order; the conditions selecting lval/rval are not visible, so confirm this
// is safe for non-commutative DstOpTy.
1970 if (!lhsExt || !rhsExt) {
1971 auto lval = lhsExt ? lhsExt : lhs;
1972 auto rval = rhsExt ? rhsExt : rhs;
1973 auto extVal = lhsExt ? lval : rval;
1974 VectorType vType = cast<VectorType>(extVal.getType());
1977 aievec::UPSOp upsOp;
1978 aievec::CastOp castOp;
1981 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lval);
1982 castOp = aievec::CastOp::create(rewriter, srcOp.getLoc(),
1987 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rval);
1988 castOp = aievec::CastOp::create(rewriter, srcOp.getLoc(),
1993 auto elemOp = DstOpTy::create(
1994 rewriter, srcOp.getLoc(), upsOp->getResult(0).getType(),
1995 upsOp->getResult(0), castOp->getResult(0));
1997 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1998 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// Last visible path: UPS both operands, apply DstOpTy, and narrow with SRSOp
// using a constant shift of 0.
2007 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lhs);
2009 aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rhs);
2010 auto elemOp = DstOpTy::create(rewriter, srcOp.getLoc(),
2011 lUpsOp->getResult(0).getType(),
2012 lUpsOp->getResult(0), rUpsOp->getResult(0));
2013 auto shiftParamOp = arith::ConstantOp::create(
2014 rewriter, srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
2015 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2016 srcOp, srcOp.getType(), elemOp.getResult(), shiftParamOp.getResult());
// Template pattern fragment: replaces a binary vector `srcOp` (SrcOpTy) with
// DstOpTy for 8/16/32-bit elements when the vector is 512 bits wide, or
// 256 bits wide with 16-bit elements (handled by padding).
// NOTE(review): the struct header and several statements are absent from
// this listing.
2038template <
typename SrcOpTy,
typename DstOpTy>
2045 ConversionPatternRewriter &rewriter)
const override {
2046 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
// Element widths tested in the shape check below.
2051 llvm::SmallSet<unsigned, 16> elWidthSet;
2052 elWidthSet.insert(8);
2053 elWidthSet.insert(16);
2054 elWidthSet.insert(32);
2056 Type scalarType = resultType.getElementType();
2057 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
2060 unsigned totalBits = laneSize * resultElWidth;
2061 if (!elWidthSet.count(resultElWidth) ||
2062 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16)))
// 256-bit / 16-bit case: both operands are padded with
// padV16ToV32WithZeros, DstOpTy runs on the wide type, and ExtOp with index 0
// yields the original result type.
2065 if (totalBits == 256 && resultElWidth == 16) {
2067 Location loc = srcOp.getLoc();
2070 padV16ToV32WithZeros(rewriter, loc, adaptor.getLhs(), scalarType);
2072 padV16ToV32WithZeros(rewriter, loc, adaptor.getRhs(), scalarType);
2073 auto wideOp = DstOpTy::create(rewriter, loc, wideType, lhsPad, rhsPad);
2074 rewriter.replaceOpWithNewOp<aievec::ExtOp>(srcOp, resultType,
2075 wideOp.getResult(), 0);
// Otherwise the op is replaced one-to-one.
2079 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(),
2080 adaptor.getLhs(), adaptor.getRhs());
// Template pattern fragment: lowers a `srcOp` (SrcOpTy) with an integer
// result by broadcasting both operands to a vector, applying DstOpTy on the
// vectors, and extracting element 0 of the result.
// NOTE(review): the struct header and several statements (including the
// operands passed to the broadcasts) are absent from this listing.
2091template <
typename SrcOpTy,
typename DstOpTy>
2098 ConversionPatternRewriter &rewriter)
const override {
2100 Type resultType = srcOp.getType();
2101 if (isa<VectorType>(resultType))
2104 auto intType = dyn_cast<IntegerType>(resultType);
2108 unsigned elWidth = intType.getWidth();
2109 if (elWidth != 8 && elWidth != 16 && elWidth != 32)
// Lane count that makes the vector 512 bits wide.
2112 unsigned numLanes = 512 / elWidth;
2114 Location loc = srcOp.getLoc();
2117 auto lhsBcast = aievec::BroadcastScalarOp::create(rewriter, loc, vecType,
2119 auto rhsBcast = aievec::BroadcastScalarOp::create(rewriter, loc, vecType,
2123 auto vecOp = DstOpTy::create(rewriter, loc, vecType, lhsBcast.getResult(),
2124 rhsBcast.getResult());
// ExtElemOp is given an i32 constant 0 as its index.
2128 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
2129 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(
2130 srcOp, intType, vecOp.getResult(), zeroIdx.getResult());
// Template pattern fragment: lowers a vector comparison `srcOp` (SrcOpTy,
// predicate type CmpTy) to an aievec::CmpOp built by createCmpOpAIE2, then
// bridges its result to the original result type with an
// UnrealizedConversionCastOp.
// NOTE(review): the struct header and several statements are absent from
// this listing.
2147template <
typename SrcOpTy,
typename CmpTy>
2154 ConversionPatternRewriter &rewriter)
const override {
2155 VectorType lhsType = dyn_cast<VectorType>(srcOp.getLhs().getType());
// Element widths tested in the shape check below.
2159 llvm::SmallSet<unsigned, 16> elWidthSet;
2160 elWidthSet.insert(8);
2161 elWidthSet.insert(16);
2162 elWidthSet.insert(32);
2164 Type scalarType = lhsType.getElementType();
2165 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2168 unsigned totalBits = laneSize * elWidth;
2169 if (!elWidthSet.count(elWidth) ||
2170 (totalBits != 512 && !(totalBits == 256 && elWidth == 16)))
2173 Location loc = srcOp.getLoc();
2174 Value lhs = srcOp.getLhs();
2175 Value rhs = srcOp.getRhs();
2176 unsigned effectiveLaneSize = laneSize;
// 256-bit / 16-bit case: operands are padded with padV16ToV32WithZeros and
// the comparison is performed on 32 lanes.
2178 if (totalBits == 256 && elWidth == 16) {
2179 lhs = padV16ToV32WithZeros(rewriter, loc, lhs, scalarType);
2180 rhs = padV16ToV32WithZeros(rewriter, loc, rhs, scalarType);
2181 effectiveLaneSize = 32;
// The compare result is an unsigned integer: 32 bits wide for up to 32 lanes,
// 64 bits otherwise.
2185 Type type = mlir::IntegerType::get(srcOp.getContext(),
2186 effectiveLaneSize <= 32 ? 32 : 64,
2187 mlir::IntegerType::Unsigned);
2189 CmpTy pred = srcOp.getPredicate();
2191 arith::CmpIPredicate ipred = convertToIntegerPredicate(pred);
2193 aievec::CmpOp aieCmpOp =
2194 createCmpOpAIE2(rewriter, ipred, loc, type, lhs, rhs);
2199 VectorType resultType = dyn_cast<VectorType>(srcOp.getResult().getType());
2202 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2203 srcOp, resultType, aieCmpOp.getResult());
// Fragment of an OpConversionPattern whose rewrite hook lowers a select-like
// `srcOp` (it exposes getCondition/getTrueValue/getFalseValue) to
// aievec::SelOp.
// NOTE(review): the struct header and several statements are absent from
// this listing.
2215 using OpConversionPattern::OpConversionPattern;
2219 ConversionPatternRewriter &rewriter)
const override {
2220 auto resultType = dyn_cast<VectorType>(srcOp.getType());
// Element widths tested in the shape check below.
2224 llvm::SmallSet<unsigned, 16> elWidthSet;
2225 elWidthSet.insert(8);
2226 elWidthSet.insert(16);
2227 elWidthSet.insert(32);
2229 Type scalarType = resultType.getElementType();
2230 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
2233 unsigned totalBits = laneSize * resultElWidth;
2234 if (!elWidthSet.count(resultElWidth) ||
2235 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16)))
// 256-bit / 16-bit case: both value operands are padded with
// padV16ToV32WithZeros, the condition is cast to an unsigned 32-bit integer,
// SelOp runs on the wide type, and ExtOp with index 0 yields the result.
2238 if (totalBits == 256 && resultElWidth == 16) {
2240 Location loc = srcOp.getLoc();
2247 Value falsePad = padV16ToV32WithZeros(rewriter, loc,
2248 srcOp.getFalseValue(), scalarType);
2250 padV16ToV32WithZeros(rewriter, loc, srcOp.getTrueValue(), scalarType);
2254 Type type = mlir::IntegerType::get(srcOp.getContext(), 32,
2255 mlir::IntegerType::Unsigned);
2256 auto convertOp = UnrealizedConversionCastOp::create(
2257 rewriter, loc, type, adaptor.getCondition());
2259 auto wideSelOp = aievec::SelOp::create(rewriter, loc, wideType, falsePad,
2260 truePad, convertOp.getResult(0));
2262 rewriter.replaceOpWithNewOp<aievec::ExtOp>(srcOp, resultType,
2263 wideSelOp.getResult(), 0);
// Other sizes: the condition is cast to an unsigned integer (32 bits for up
// to 32 lanes, 64 bits otherwise). SelOp receives the false value before the
// true value, matching the padded path above.
2268 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
2269 mlir::IntegerType::Unsigned);
2271 auto convertOp = UnrealizedConversionCastOp::create(
2272 rewriter, srcOp.getLoc(), type, adaptor.getCondition());
2278 rewriter.replaceOpWithNewOp<aievec::SelOp>(
2279 srcOp, srcOp.getResult().getType(), srcOp.getFalseValue(),
2280 srcOp.getTrueValue(), convertOp.getResult(0));
// Fragment of an OpConversionPattern whose rewrite hook lowers a minimum
// reduction `srcOp` through generateAIEVecOpsForReductionOp<aievec::MinOp>.
// NOTE(review): the struct header and several statements are absent from
// this listing.
2287 using OpConversionPattern::OpConversionPattern;
2291 ConversionPatternRewriter &rewriter)
const override {
// True when the combining kind is none of MINSI, MINUI, MINIMUMF, MINNUMF.
2292 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::MINSI &&
2293 kind != vector::CombiningKind::MINUI &&
2294 kind != vector::CombiningKind::MINIMUMF &&
2295 kind != vector::CombiningKind::MINNUMF)
2298 auto vType = cast<VectorType>(srcOp.getVector().getType());
2299 Type scalarType = vType.getElementType();
2300 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2302 unsigned vectorSize = laneSize * elWidth;
// True unless the vector is 512 bits wide, or 256 bits wide with bf16
// elements.
2305 if (vectorSize != 512 && !(vectorSize == 256 && scalarType.isBF16()))
2308 Location loc = srcOp.getLoc();
2309 Value inputVec = srcOp.getVector();
// 256-bit input is padded by padV16ToV32WithInfinity (last argument false),
// which also returns the new lane count.
2312 if (vectorSize == 256) {
2313 std::tie(inputVec, laneSize) = padV16ToV32WithInfinity(
2314 rewriter, loc, srcOp.getVector(), scalarType,
false);
2317 int shiftIndex = laneSize / 2;
2318 auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::MinOp>(
2319 rewriter, srcOp, shiftIndex, inputVec);
// With an accumulator operand, the reduced value is combined with it using an
// ordered less-than float compare plus select; a bf16 accumulator routes that
// through performBF16BinaryOpInF32.
// NOTE(review): arith::CmpFOp is a float compare, while the accepted kinds
// include MINSI/MINUI; confirm the lines not shown keep integer reductions
// away from this path.
2321 if (srcOp.getAcc()) {
2322 Value reduceResult = reduceResultOp.getResult();
2323 Value acc = srcOp.getAcc();
2326 if (acc.getType().isBF16()) {
2328 auto minOpBuilder = [&](Value lhs, Value rhs) -> Value {
2329 auto cmpOp = arith::CmpFOp::create(
2330 rewriter, srcOp.getLoc(), arith::CmpFPredicate::OLT, lhs, rhs);
2331 return arith::SelectOp::create(rewriter, srcOp.getLoc(), cmpOp, lhs,
2337 performBF16BinaryOpInF32(reduceResult, acc, srcOp, srcOp.getLoc(),
2338 rewriter, minOpBuilder);
2342 arith::CmpFOp::create(rewriter, srcOp.getLoc(),
2343 arith::CmpFPredicate::OLT, reduceResult, acc);
2344 rewriter.replaceOpWithNewOp<arith::SelectOp>(srcOp, cmpOp, reduceResult,
2348 rewriter.replaceOp(srcOp, reduceResultOp);
// Fragment of an OpConversionPattern whose rewrite hook lowers a maximum
// reduction `srcOp` through generateAIEVecOpsForReductionOp<aievec::MaxOp>.
// NOTE(review): the struct header and several statements are absent from
// this listing.
2355 using OpConversionPattern::OpConversionPattern;
2359 ConversionPatternRewriter &rewriter)
const override {
// True when the combining kind is none of MAXSI, MAXUI, MAXIMUMF, MAXNUMF.
2360 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::MAXSI &&
2361 kind != vector::CombiningKind::MAXUI &&
2362 kind != vector::CombiningKind::MAXIMUMF &&
2363 kind != vector::CombiningKind::MAXNUMF)
2366 auto vType = cast<VectorType>(srcOp.getVector().getType());
2367 Type scalarType = vType.getElementType();
2368 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2370 unsigned vectorSize = laneSize * elWidth;
// True unless the vector is 512 bits wide, or 256 bits wide with bf16
// elements.
2374 if (vectorSize != 512 && !(vectorSize == 256 && scalarType.isBF16()))
2377 Location loc = srcOp.getLoc();
2378 Value inputVec = srcOp.getVector();
// 256-bit input is padded by padV16ToV32WithInfinity (last argument true),
// which also returns the new lane count.
2381 if (vectorSize == 256) {
2382 std::tie(inputVec, laneSize) = padV16ToV32WithInfinity(
2383 rewriter, loc, srcOp.getVector(), scalarType,
true);
2386 int shiftIndex = laneSize / 2;
2387 auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::MaxOp>(
2388 rewriter, srcOp, shiftIndex, inputVec);
// With an accumulator operand, the reduced value is combined with it using an
// ordered greater-than float compare plus select; a bf16 accumulator routes
// that through performBF16BinaryOpInF32.
// NOTE(review): arith::CmpFOp is a float compare, while the accepted kinds
// include MAXSI/MAXUI; confirm the lines not shown keep integer reductions
// away from this path.
2390 if (srcOp.getAcc()) {
2391 Value reduceResult = reduceResultOp.getResult();
2392 Value acc = srcOp.getAcc();
2395 if (acc.getType().isBF16()) {
2397 auto maxOpBuilder = [&](Value lhs, Value rhs) -> Value {
2398 auto cmpOp = arith::CmpFOp::create(
2399 rewriter, srcOp.getLoc(), arith::CmpFPredicate::OGT, lhs, rhs);
2400 return arith::SelectOp::create(rewriter, srcOp.getLoc(), cmpOp, lhs,
2406 performBF16BinaryOpInF32(reduceResult, acc, srcOp, srcOp.getLoc(),
2407 rewriter, maxOpBuilder);
2411 arith::CmpFOp::create(rewriter, srcOp.getLoc(),
2412 arith::CmpFPredicate::OGT, reduceResult, acc);
2413 rewriter.replaceOpWithNewOp<arith::SelectOp>(srcOp, cmpOp, reduceResult,
2417 rewriter.replaceOp(srcOp, reduceResultOp);
// Fragment of an OpConversionPattern whose rewrite hook lowers an integer ADD
// reduction `srcOp` through
// generateAIEVecOpsForReductionOp<aievec::AddElemOp>.
// NOTE(review): the struct header and several statements (including the
// conditions choosing between the two replacement forms) are absent from
// this listing.
2424 using OpConversionPattern::OpConversionPattern;
2428 ConversionPatternRewriter &rewriter)
const override {
2429 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
2432 auto vType = cast<VectorType>(srcOp.getVector().getType());
2433 Type scalarType = vType.getElementType();
2434 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// (laneSize, element width) pairs tested in the check below.
2436 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
2437 laneSizeElWidthPairSet.insert({64, 8});
2438 laneSizeElWidthPairSet.insert({32, 16});
2439 laneSizeElWidthPairSet.insert({32, 32});
2440 laneSizeElWidthPairSet.insert({16, 32});
2442 if (!isa<IntegerType>(scalarType) ||
2443 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
2446 int shiftIndex = laneSize / 2;
// 32 lanes x 32 bits: the two halves (ExtOp index 0 and 1) are added with
// AddElemOp first, and the reduction then runs on that sum.
2447 if (laneSize == 32 && elWidth == 32) {
2448 Location loc = srcOp.getLoc();
2452 aievec::ExtOp::create(rewriter, loc, vecType, srcOp.getVector(), 0);
2454 aievec::ExtOp::create(rewriter, loc, vecType, srcOp.getVector(), 1);
2456 aievec::AddElemOp::create(rewriter, loc, lExtOp.getResult().getType(),
2457 lExtOp.getResult(), rExtOp.getResult());
2459 auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
2460 rewriter, srcOp, shiftIndex, addElemOp.getResult());
// Two replacement forms are visible: an arith::AddIOp of the reduced value
// and the accumulator, or the reduced value alone.
2462 rewriter.replaceOpWithNewOp<arith::AddIOp>(
2463 srcOp, reduceResultOp.getResult(), srcOp.getAcc());
2465 rewriter.replaceOp(srcOp, reduceResultOp);
// Other accepted shapes: the reduction runs directly on the source vector.
2467 auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
2468 rewriter, srcOp, shiftIndex, srcOp.getVector());
2470 rewriter.replaceOpWithNewOp<arith::AddIOp>(
2471 srcOp, reduceResultOp.getResult(), srcOp.getAcc());
2473 rewriter.replaceOp(srcOp, reduceResultOp);
// Fragment of an OpConversionPattern whose rewrite hook lowers a float ADD
// reduction `srcOp` with a shift-and-add loop over the vector.
// NOTE(review): the struct header and several statements are absent from
// this listing.
2482 using OpConversionPattern::OpConversionPattern;
2486 ConversionPatternRewriter &rewriter)
const override {
2487 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
2490 auto vType = cast<VectorType>(srcOp.getVector().getType());
2491 Type scalarType = vType.getElementType();
2492 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// True unless the input is a 16-lane vector of 32-bit float elements.
2495 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 32)
2498 int shiftIndex = laneSize / 2;
2499 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
2500 "shiftIndex must be power of 2");
2502 Location loc = srcOp.getLoc();
2503 Value curValue = srcOp.getVector();
2504 aievec::CastOp curOp =
nullptr;
// Each iteration shifts the current vector by `id` elements (expressed in
// bytes as id * elWidth / 8), adds the shifted and unshifted vectors with
// AddElemOp, and halves `id`.
2506 for (
int id = shiftIndex;
id > 0;
id /= 2) {
2507 auto constOp = arith::ConstantOp::create(
2508 rewriter, loc, rewriter.getI32IntegerAttr(
id * elWidth / 8));
2510 auto shiftBytesOp = aievec::ShiftOp::create(
2511 rewriter, loc, vType, curValue, curValue, constOp.getResult());
2513 auto lCastOp = aievec::CastOp::create(rewriter, loc, vType, curValue,
2516 aievec::CastOp::create(rewriter, loc, vType, shiftBytesOp.getResult(),
2518 auto elemOp = aievec::AddElemOp::create(
2519 rewriter, loc, lCastOp.getResult().getType(), lCastOp.getResult(),
2520 rCastOp.getResult());
2521 curOp = aievec::CastOp::create(rewriter, loc, vType, elemOp.getResult(),
2523 curValue = curOp.getResult();
// Element 0 of the final vector is extracted as the scalar result.
2527 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
2528 auto reduceResultOp = aievec::ExtElemOp::create(
2529 rewriter, srcOp.getLoc(), scalarType, curOp, zeroConstOp.getResult());
// Two replacement forms are visible: an arith::AddFOp of the reduced value
// and the accumulator, or the reduced value alone.
2532 rewriter.replaceOpWithNewOp<arith::AddFOp>(
2533 srcOp, reduceResultOp.getResult(), srcOp.getAcc());
2535 rewriter.replaceOp(srcOp, reduceResultOp);
// Fragment of an OpConversionPattern whose rewrite hook lowers a float ADD
// reduction `srcOp` on 16 or 32 lanes by reducing in an accumulator type
// (UPSOp in, SRSOp out) with a shift-and-add loop.
// NOTE(review): the struct header and several statements (including the rest
// of the shape check and the definitions of accType, shiftIndex and accWidth)
// are absent from this listing.
2544 using OpConversionPattern::OpConversionPattern;
2548 ConversionPatternRewriter &rewriter)
const override {
2550 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD) {
2554 auto vType = cast<VectorType>(srcOp.getVector().getType());
2555 Type scalarType = vType.getElementType();
2556 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2560 if (!isa<FloatType>(scalarType) || (laneSize != 16 && laneSize != 32) ||
2565 Location loc = srcOp.getLoc();
2566 Value curValue = srcOp.getVector();
2567 VectorType currentVType = vType;
// 32 lanes: the two halves (ExtOp index 0 and 1) are UPS'd, added with
// AddElemOp and narrowed by SRSOp (shift 0); the reduction continues on that
// half-width vector.
2570 if (laneSize == 32) {
2573 aievec::ExtOp::create(rewriter, loc, halfType, srcOp.getVector(), 0);
2575 aievec::ExtOp::create(rewriter, loc, halfType, srcOp.getVector(), 1);
2579 aievec::UPSOp::create(rewriter, loc, accType, lowerHalf.getResult());
2581 aievec::UPSOp::create(rewriter, loc, accType, upperHalf.getResult());
2582 auto addElemOp = aievec::AddElemOp::create(
2583 rewriter, loc, accType, lUpsOp.getResult(), rUpsOp.getResult());
2584 auto shiftParamOp = arith::ConstantOp::create(
2585 rewriter, loc, rewriter.getI32IntegerAttr(0));
2587 aievec::SRSOp::create(rewriter, loc, halfType, addElemOp.getResult(),
2588 shiftParamOp.getResult());
2589 curValue = srsOp.getResult();
2590 currentVType = halfType;
2597 dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();
2599 auto upsOp = aievec::UPSOp::create(rewriter, loc, accType, curValue);
2600 curValue = upsOp.getResult();
2602 aievec::AddElemOp curOp =
nullptr;
// Each iteration shifts the accumulator vector by `id` elements (expressed in
// bytes as id * accWidth / 8), adds it to the unshifted vector, and halves
// `id`.
2604 for (
int id = shiftIndex;
id > 0;
id /= 2) {
2605 auto constOp = arith::ConstantOp::create(
2606 rewriter, loc, rewriter.getI32IntegerAttr(
id * accWidth / 8));
2607 auto shiftBytesOp = aievec::ShiftOp::create(
2608 rewriter, loc, accType, curValue, curValue, constOp,
true);
2609 curOp = aievec::AddElemOp::create(rewriter, loc, accType, curValue,
2610 shiftBytesOp.getResult());
2611 curValue = curOp.getResult();
// The accumulated vector is narrowed to currentVType with SRSOp (shift 0).
2614 auto shiftParamOp = arith::ConstantOp::create(
2615 rewriter, srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
2618 aievec::SRSOp::create(rewriter, loc, currentVType, curOp.getResult(),
2619 shiftParamOp.getResult());
// The SRS result is concatenated with itself before element 0 is extracted
// as the scalar result.
2623 SmallVector<Value> concatSources = {srsOp.getResult(), srsOp.getResult()};
2625 aievec::ConcatOp::create(rewriter, loc, vecType, concatSources);
2628 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
2629 auto reduceResultOp =
2630 aievec::ExtElemOp::create(rewriter, srcOp.getLoc(), scalarType,
2631 concatOp, zeroConstOp.getResult());
// Two replacement forms are visible: an arith::AddFOp of the reduced value
// and the accumulator, or the reduced value alone.
2634 rewriter.replaceOpWithNewOp<arith::AddFOp>(
2635 srcOp, reduceResultOp.getResult(), srcOp.getAcc());
2637 rewriter.replaceOp(srcOp, reduceResultOp);
2645 using OpConversionPattern::OpConversionPattern;
2649 ConversionPatternRewriter &rewriter)
const override {
2650 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
2653 auto vType = cast<VectorType>(srcOp.getVector().getType());
2654 Type scalarType = vType.getElementType();
2655 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2659 if (!isa<FloatType>(scalarType) || (laneSize != 16 && laneSize != 32) ||
2663 Location loc = srcOp.getLoc();
2664 int shiftIndex = laneSize / 2;
2665 Value inputToReduce = srcOp.getVector();
2668 if (laneSize == 32) {
2673 aievec::ExtOp::create(rewriter, loc, halfType, srcOp.getVector(), 0);
2675 aievec::ExtOp::create(rewriter, loc, halfType, srcOp.getVector(), 1);
2680 aievec::UPSOp::create(rewriter, loc, accType, lowerHalf.getResult());
2682 aievec::UPSOp::create(rewriter, loc, accType, upperHalf.getResult());
2683 auto addElemOp = aievec::AddElemOp::create(
2684 rewriter, loc, accType, lUpsOp.getResult(), rUpsOp.getResult());
2685 auto shiftParamOp = arith::ConstantOp::create(
2686 rewriter, loc, rewriter.getI32IntegerAttr(0));
2688 aievec::SRSOp::create(rewriter, loc, halfType, addElemOp.getResult(),
2689 shiftParamOp.getResult());
2691 inputToReduce = srsOp.getResult();
2697 cast<VectorType>(inputToReduce.getType()),
true);
2699 dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();
2701 auto upsOp = aievec::UPSOp::create(rewriter, loc, accType, inputToReduce);
2702 Value curValue = upsOp.getResult();
2704 aievec::AddElemOp curOp =
nullptr;
2705 for (
int id = shiftIndex;
id > 0;
id /= 2) {
2706 auto constOp = arith::ConstantOp::create(
2707 rewriter, loc, rewriter.getI32IntegerAttr(
id * accWidth / 8));
2708 auto shiftBytesOp = aievec::ShiftOp::create(
2709 rewriter, loc, accType, curValue, curValue, constOp,
true);
2710 curOp = aievec::AddElemOp::create(rewriter, loc, accType, curValue,
2711 shiftBytesOp.getResult());
2712 curValue = curOp.getResult();
2719 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
2720 auto extractedF32 = aievec::ExtElemOp::create(
2721 rewriter, srcOp.getLoc(), rewriter.getF32Type(), curOp.getResult(),
2722 zeroConstOp.getResult());
2725 auto reduceResultBF16 = arith::TruncFOp::create(
2726 rewriter, srcOp.getLoc(), scalarType, extractedF32.getResult());
2729 rewriter.replaceOpWithNewOp<arith::AddFOp>(srcOp, reduceResultBF16,
2732 rewriter.replaceOp(srcOp, reduceResultBF16);
2741 using OpConversionPattern::OpConversionPattern;
2745 ConversionPatternRewriter &rewriter)
const override {
2746 auto vType = extractOp.getSourceVectorType();
2747 if (vType.getRank() != 1)
2750 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
2756 return extractOp.emitError()
2757 <<
"AIEv1 doesn't support select ops on int8 types";
2761 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
2762 if (vType.getNumElements() != 2 * size)
2765 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
2766 auto selectOp = aievec::aie1::SelectOp::create(
2767 rewriter, extractOp.getLoc(), vType, adaptor.getSource(),
2768 buildAttributeListForRotationSelectOp(rewriter, vType, offset));
2769 rewriter.replaceOpWithNewOp<aievec::aie1::ExtOp>(
2770 extractOp, extractOp.getType(), selectOp.getResult(),
2771 rewriter.getI8IntegerAttr(0));
2780 using OpConversionPattern::OpConversionPattern;
2784 ConversionPatternRewriter &rewriter)
const override {
2785 auto vType = cast<VectorType>(adaptor.getSource().getType());
2786 if (vType.getRank() != 1)
2789 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
2795 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
2796 if (vType.getNumElements() != 2 * size)
2799 auto shortVecType = cast<VectorType>(extractOp.getResult().getType());
2801 aievec::ExtOp::create(rewriter, extractOp.getLoc(), shortVecType,
2802 adaptor.getSource(), rewriter.getI8IntegerAttr(0))
2805 aievec::ExtOp::create(rewriter, extractOp.getLoc(), shortVecType,
2806 adaptor.getSource(), rewriter.getI8IntegerAttr(1))
2808 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
2810 auto shiftBytesConstOp = arith::ConstantOp::create(
2811 rewriter, extractOp.getLoc(), rewriter.getIntegerType(32),
2812 rewriter.getI32IntegerAttr(shiftBytes));
2813 rewriter.replaceOpWithNewOp<aievec::ShiftOp>(
2814 extractOp, shortVecType, bottomHalf, topHalf, shiftBytesConstOp);
2823 using OpConversionPattern::OpConversionPattern;
2830 ConversionPatternRewriter &rewriter)
const override {
2832 if (updOp->hasOneUse() && isa<aievec::ExtOp>(*updOp->getUsers().begin()))
2835 auto vecType = cast<VectorType>(updOp.getType());
2836 SmallVector<int64_t, 4> vecShape(vecType.getShape().begin(),
2837 vecType.getShape().end());
2838 vecShape[vecType.getRank() - 1] *= 2;
2839 auto longVecType = VectorType::get(vecShape, vecType.getElementType());
2840 auto newUpdOp = aievec::UPDOp::create(
2841 rewriter, updOp.getLoc(), longVecType, adaptor.getSource(),
2842 adaptor.getIndices(), adaptor.getOffset(), adaptor.getIndex(),
2843 adaptor.getVector());
2844 rewriter.replaceOpWithNewOp<aievec::ExtOp>(
2845 updOp, vecType, newUpdOp.getResult(), rewriter.getI8IntegerAttr(0));
2854 using OpConversionPattern::OpConversionPattern;
2860 ConversionPatternRewriter &rewriter)
const override {
// Rewrites an aievec.ext that is fed by an aievec.upd into a single
// aievec.upd that produces the ext's result type from the same source,
// indices, offset, index and vector operands.
2862 if (extOp.getIndex() != 0)
// Look the producer up with getDefiningOp<UPDOp>(): it yields a null op when
// the source is a block argument or comes from a different op, whereas
// dyn_cast<UPDOp>() on a null Operation* trips an assertion.
2865 auto updOp = extOp.getSource().getDefiningOp<aievec::UPDOp>();
// The null test keeps this guard safe on its own when there is no upd
// producer.
2870 if (!updOp || !updOp->hasOneUse())
2873 rewriter.replaceOpWithNewOp<aievec::UPDOp>(
2874 extOp, extOp.getType(), updOp.getSource(), updOp.getIndices(),
2875 updOp.getOffset(), updOp.getIndex(), updOp.getVector());
2884 using OpConversionPattern::OpConversionPattern;
2888 ConversionPatternRewriter &rewriter)
const override {
2889 if (!matchExpOpForAIE2P(adaptor))
2892 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
2893 rewriter.replaceOpWithNewOp<aievec::ExpOp>(expOp, srcType,
2894 adaptor.getOperand());
2903 using OpConversionPattern::OpConversionPattern;
2907 ConversionPatternRewriter &rewriter)
const override {
2908 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
2912 Type scalarType = srcType.getElementType();
2913 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2916 if (!scalarType.isBF16() || (laneSize != 16 && laneSize != 32) ||
2920 rewriter.replaceOpWithNewOp<aievec::TanhOp>(tanhOp, srcType,
2921 adaptor.getOperand());
2927 using OpConversionPattern::OpConversionPattern;
2931 ConversionPatternRewriter &rewriter)
const override {
2933 if (!matchExpOpForAIE2LUT(adaptor))
2936 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
2938 Location loc = expOp.getLoc();
2939 StringRef funcName =
"getExpBf16";
2941 VectorType v16bf16Ty = mlir::VectorType::get({16}, rewriter.getBF16Type());
2942 VectorType v8i64Ty = mlir::VectorType::get({8}, rewriter.getI64Type());
2943 func::FuncOp fnOp = getOrInsertFuncDecl(
2944 rewriter, expOp->getParentWithTrait<OpTrait::SymbolTable>(), funcName,
2945 TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
2948 if (laneSize == 32) {
2949 splitWideUnaryVectorOp<math::ExpOp>(
2950 expOp, adaptor.getOperand(), v16bf16Ty, srcType, rewriter,
2951 [&fnOp](Value halfInput, Location loc,
2952 ConversionPatternRewriter &rewriter) -> Value {
2953 VectorType v16bf16Ty =
2954 mlir::VectorType::get({16}, rewriter.getBF16Type());
2955 auto callOp = func::CallOp::create(rewriter, loc, fnOp,
2956 SmallVector<Value>{halfInput});
2958 auto resCastOp = vector::BitCastOp::create(rewriter, loc, accType,
2959 callOp.getResults());
2960 auto shiftParamOp = arith::ConstantOp::create(
2961 rewriter, loc, rewriter.getI32IntegerAttr(0));
2962 auto srsOp = aievec::SRSOp::create(rewriter, loc, v16bf16Ty,
2963 resCastOp.getResult(),
2964 shiftParamOp.getResult());
2965 return srsOp.getResult();
2971 SmallVector<Value> expOperands = {adaptor.getOperand()};
2974 auto callOp = func::CallOp::create(rewriter, loc, fnOp, expOperands);
2975 auto resCastOp = vector::BitCastOp::create(rewriter, loc, accTypeNative,
2976 callOp.getResults());
2978 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
2979 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2980 expOp, srcType, resCastOp.getResult(), shiftParamOp.getResult());
2987 using OpConversionPattern::OpConversionPattern;
2991 ConversionPatternRewriter &rewriter)
const override {
2992 if (!matchExpOpForAIE2LUT(adaptor))
2994 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
2995 StringRef includeName =
"lut_based_ops.h";
2996 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
2997 rewriter.setInsertionPointToStart(
2998 &moduleOp.getRegion().getBlocks().front());
2999 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3001 rewriter.setInsertionPoint(expOp);
3003 auto v16bf16OpaqueTy =
3004 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
3005 auto opaquedOperand =
3006 UnrealizedConversionCastOp::create(
3007 rewriter, expOp.getLoc(), v16bf16OpaqueTy, adaptor.getOperand())
3009 SmallVector<Value> expOperands = {opaquedOperand};
3012 Type v16accf32OpaqueTy =
3013 emitc::OpaqueType::get(rewriter.getContext(),
"v16accfloat");
3014 auto callOp = emitc::CallOpaqueOp::create(
3015 rewriter, expOp.getLoc(), TypeRange{v16accf32OpaqueTy},
"getExpBf16",
3016 nullptr,
nullptr, expOperands);
3017 auto resCastOp = UnrealizedConversionCastOp::create(
3018 rewriter, expOp.getLoc(), accTypeNative, callOp.getResults());
3019 auto shiftParamOp = arith::ConstantOp::create(
3020 rewriter, expOp.getLoc(), rewriter.getI32IntegerAttr(0));
3021 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3022 expOp, srcType, resCastOp.getResult(0), shiftParamOp.getResult());
3036 using OpConversionPattern::OpConversionPattern;
3040 ConversionPatternRewriter &rewriter)
const override {
// Handles a scalar (non-vector) 32-bit float division that has exactly one
// user, that user being a narrowing op. The division and its user are
// replaced by an emitc.call_opaque to "getInvBf16" on the divisor, and an
// include of "lut_based_ops.h" is added at the start of the module.
3041 Type srcType = adaptor.getLhs().getType();
3042 if (!divOp->hasOneUse() || isa<VectorType>(srcType) ||
3043 !isa<FloatType>(srcType))
3046 if (!isNarrowingOp(*divOp->getUsers().begin()))
3049 auto fType = cast<FloatType>(srcType);
3050 if (fType.getWidth() != 32)
// The dividend can be a block argument, for which getDefiningOp() is null;
// getDefiningOp<ConstantOp>() returns a null op in that case instead of
// asserting the way dyn_cast<ConstantOp>() does on a null Operation*.
3053 auto constOp = divOp.getLhs().getDefiningOp<arith::ConstantOp>();
3055 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
3059 StringRef includeName =
"lut_based_ops.h";
3060 auto moduleOp = divOp->getParentOfType<mlir::ModuleOp>();
3061 rewriter.setInsertionPointToStart(
3062 &moduleOp.getRegion().getBlocks().front());
3063 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
// NOTE(review): isNarrowingOp also accepts ops other than arith.truncf, while
// this cast asserts on anything else -- confirm only truncf can reach here.
3065 auto truncOp = cast<arith::TruncFOp>(*divOp->getUsers().begin());
3067 rewriter.setInsertionPoint(truncOp);
3069 emitc::OpaqueType::get(rewriter.getContext(),
"bfloat16");
3070 SmallVector<Value> invOperands = {adaptor.getRhs()};
3071 auto callOp = emitc::CallOpaqueOp::create(rewriter, truncOp.getLoc(),
3072 bf16OpaqueTy,
"getInvBf16",
3073 nullptr,
nullptr, invOperands);
// The truncation is replaced by a cast of the call result back to its
// original type; the division itself is then erased.
3074 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3075 truncOp, TypeRange{truncOp.getResult().getType()}, callOp.getResults());
3076 rewriter.eraseOp(divOp);
3095 using OpConversionPattern::OpConversionPattern;
3099 ConversionPatternRewriter &rewriter)
const override {
3100 Type srcType = adaptor.getLhs().getType();
3103 auto *defOp = divOp.getLhs().getDefiningOp();
3107 auto constOp = dyn_cast<arith::ConstantOp>(defOp);
3112 if (
auto fType = dyn_cast<FloatType>(srcType)) {
3113 if (fType.getWidth() != 32)
3116 auto floatAttr = dyn_cast<FloatAttr>(constOp.getValue());
3117 if (!floatAttr || !floatAttr.getValue().isExactlyValue(1.0))
3120 rewriter.replaceOpWithNewOp<aievec::InvOp>(divOp, srcType,
3126 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
3127 auto elemType = vecType.getElementType();
3128 if (!elemType.isF32())
3133 if (laneSize != 16 && laneSize != 32)
3137 auto denseAttr = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3138 if (!denseAttr || !denseAttr.isSplat())
3141 if (!denseAttr.getSplatValue<APFloat>().isExactlyValue(1.0))
3144 rewriter.replaceOpWithNewOp<aievec::InvOp>(divOp, vecType,
3155 using OpConversionPattern::OpConversionPattern;
3159 ConversionPatternRewriter &rewriter)
const override {
3160 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
3164 Type scalarType = srcType.getElementType();
3165 if (!isa<FloatType>(scalarType))
3169 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3170 if (elWidth != 16 || laneSize != 16)
3173 StringRef includeName =
"lut_based_ops.h";
3174 auto moduleOp = tanhOp->getParentOfType<mlir::ModuleOp>();
3175 rewriter.setInsertionPointToStart(
3176 &moduleOp.getRegion().getBlocks().front());
3177 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3179 rewriter.setInsertionPoint(tanhOp);
3180 Type v16bf16OpaqueTy =
3181 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
3182 auto opaquedOperand =
3183 UnrealizedConversionCastOp::create(
3184 rewriter, tanhOp.getLoc(), v16bf16OpaqueTy, adaptor.getOperand())
3186 SmallVector<Value> tanhOperands = {opaquedOperand};
3187 auto callOp = emitc::CallOpaqueOp::create(rewriter, tanhOp.getLoc(),
3188 v16bf16OpaqueTy,
"getTanhBf16",
3189 nullptr,
nullptr, tanhOperands);
3190 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3191 tanhOp, TypeRange{tanhOp.getResult().getType()}, callOp.getResults());
3200 using OpConversionPattern::OpConversionPattern;
3204 ConversionPatternRewriter &rewriter)
const override {
3205 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
3209 Type scalarType = srcType.getElementType();
3210 if (!isa<FloatType>(scalarType))
3214 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3215 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3218 StringRef includeName =
"vec_math.h";
3219 auto moduleOp = sqrtOp->getParentOfType<mlir::ModuleOp>();
3220 rewriter.setInsertionPointToStart(
3221 &moduleOp.getRegion().getBlocks().front());
3222 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3224 rewriter.setInsertionPoint(sqrtOp);
3225 Type vLNbf16OpaqueTy;
3228 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
3231 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
3232 auto opaquedOperand =
3233 UnrealizedConversionCastOp::create(
3234 rewriter, sqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
3236 SmallVector<Value> sqrtOperands = {opaquedOperand};
3237 auto callOp = emitc::CallOpaqueOp::create(
3238 rewriter, sqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getSqrtBf16",
3239 nullptr,
nullptr, sqrtOperands);
3240 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3241 sqrtOp, TypeRange{sqrtOp.getResult().getType()}, callOp.getResults());
3248 using OpConversionPattern::OpConversionPattern;
3252 ConversionPatternRewriter &rewriter)
const override {
3253 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
3257 Type scalarType = srcType.getElementType();
3258 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3262 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
3265 StringRef funcName =
"getRsqrtBf16";
3267 VectorType v16bf16Ty = mlir::VectorType::get({16}, rewriter.getBF16Type());
3268 VectorType v8i64Ty = mlir::VectorType::get({8}, rewriter.getI64Type());
3269 func::FuncOp fnOp = getOrInsertFuncDecl(
3270 rewriter, rsqrtOp->getParentWithTrait<OpTrait::SymbolTable>(), funcName,
3271 TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
3273 SmallVector<Value> rsqrtOperands = {adaptor.getOperand()};
3277 func::CallOp::create(rewriter, rsqrtOp.getLoc(), fnOp, rsqrtOperands);
3278 auto resCastOp = vector::BitCastOp::create(
3279 rewriter, rsqrtOp.getLoc(), accTypeNative, callOp.getResults());
3280 auto shiftParamOp = arith::ConstantOp::create(
3281 rewriter, rsqrtOp.getLoc(), rewriter.getI32IntegerAttr(0));
3282 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3283 rsqrtOp, srcType, resCastOp.getResult(), shiftParamOp.getResult());
3292 using OpConversionPattern::OpConversionPattern;
3296 ConversionPatternRewriter &rewriter)
const override {
3297 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
3301 Type scalarType = srcType.getElementType();
3302 if (!isa<FloatType>(scalarType))
3306 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3307 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3310 StringRef includeName =
"vec_math.h";
3311 auto moduleOp = rsqrtOp->getParentOfType<mlir::ModuleOp>();
3312 rewriter.setInsertionPointToStart(
3313 &moduleOp.getRegion().getBlocks().front());
3314 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3316 rewriter.setInsertionPoint(rsqrtOp);
3317 Type vLNbf16OpaqueTy;
3320 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
3323 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
3324 auto opaquedOperand =
3325 UnrealizedConversionCastOp::create(
3326 rewriter, rsqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
3328 SmallVector<Value> rsqrtOperands = {opaquedOperand};
3329 auto callOp = emitc::CallOpaqueOp::create(
3330 rewriter, rsqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getRsqrtBf16",
3331 nullptr,
nullptr, rsqrtOperands);
3332 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3333 rsqrtOp, TypeRange{rsqrtOp.getResult().getType()}, callOp.getResults());
3342 using OpConversionPattern::OpConversionPattern;
3346 ConversionPatternRewriter &rewriter)
const override {
3347 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
3351 Type scalarType = srcType.getElementType();
3352 if (!isa<FloatType>(scalarType))
3356 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3357 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3360 StringRef includeName =
"vec_math.h";
3361 auto moduleOp = erfOp->getParentOfType<mlir::ModuleOp>();
3362 rewriter.setInsertionPointToStart(
3363 &moduleOp.getRegion().getBlocks().front());
3364 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3366 rewriter.setInsertionPoint(erfOp);
3367 Type vLNbf16OpaqueTy;
3370 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
3373 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
3374 auto opaquedOperand =
3375 UnrealizedConversionCastOp::create(
3376 rewriter, erfOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
3378 SmallVector<Value> erfOperands = {opaquedOperand};
3379 auto callOp = emitc::CallOpaqueOp::create(
3380 rewriter, erfOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getErfBf16",
3381 nullptr,
nullptr, erfOperands);
3382 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3383 erfOp, TypeRange{erfOp.getResult().getType()}, callOp.getResults());
3391template <
typename SrcOpTy>
3398 ConversionPatternRewriter &rewriter)
const override {
3399 auto vecTy = dyn_cast<VectorType>(absOp.getOperand().getType());
3403 Type elemTy = vecTy.getElementType();
3406 unsigned elWidth = elemTy.getIntOrFloatBitWidth();
3408 StringRef includeName =
"vec_math.h";
3409 auto moduleOp = absOp->template getParentOfType<mlir::ModuleOp>();
3410 rewriter.setInsertionPointToStart(
3411 &moduleOp.getRegion().getBlocks().front());
3412 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3414 rewriter.setInsertionPoint(absOp);
3415 std::ostringstream typeName;
3416 typeName <<
"v" << laneSize;
3417 if (isa<FloatType>(elemTy)) {
3419 typeName <<
"bfloat16";
3421 typeName <<
"float";
3423 typeName <<
"int" << elWidth;
3425 emitc::OpaqueType::get(rewriter.getContext(), typeName.str());
3426 auto opaquedOperand =
3427 UnrealizedConversionCastOp::create(rewriter, absOp.getLoc(),
3428 vecOpaqueTy, adaptor.getOperand())
3430 SmallVector<Value> absOperands = {opaquedOperand};
3431 auto callOp = emitc::CallOpaqueOp::create(rewriter, absOp.getLoc(),
3432 TypeRange{vecOpaqueTy},
"getAbs",
3433 nullptr,
nullptr, absOperands);
3434 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3435 absOp, TypeRange{absOp.getResult().getType()}, callOp.getResults());
3444template <
typename SrcOpTy>
3451 ConversionPatternRewriter &rewriter)
const override {
3452 VectorType srcType = dyn_cast<VectorType>(extOp.getIn().getType());
3453 VectorType dstType = dyn_cast<VectorType>(extOp.getOut().getType());
3455 Type scalarType = dstType.getElementType();
3456 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3458 isa<IntegerType>(scalarType) && (elWidth == 32 || elWidth == 64)
3462 aievec::UPSOp::create(rewriter, extOp.getLoc(), accType, extOp.getIn());
3464 if (dstType.getElementType().getIntOrFloatBitWidth() == 16) {
3465 auto shiftParamOp = arith::ConstantOp::create(
3466 rewriter, extOp.getLoc(), rewriter.getI32IntegerAttr(0));
3467 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3468 extOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
3470 rewriter.replaceOpWithNewOp<aievec::CastOp>(
3471 extOp, dstType, upsOp.getResult(),
false);
3480template <
typename SrcOpTy>
3487 ConversionPatternRewriter &rewriter)
const override {
3488 VectorType srcType = dyn_cast<VectorType>(truncOp.getIn().getType());
3489 VectorType dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
3490 Type scalarType = srcType.getElementType();
3491 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3493 isa<IntegerType>(scalarType) && (elWidth == 32 || elWidth == 64)
3497 auto shiftParamOp = arith::ConstantOp::create(
3498 rewriter, truncOp.getLoc(), rewriter.getI32IntegerAttr(0));
3499 if (elWidth == 16) {
3500 auto upsOp = aievec::UPSOp::create(rewriter, truncOp.getLoc(), accType,
3502 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3503 truncOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
3505 auto castOp = aievec::CastOp::create(rewriter, truncOp.getLoc(), accType,
3506 truncOp.getIn(),
true);
3507 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3508 truncOp, dstType, castOp.getResult(), shiftParamOp.getResult());
// Looks through the chain
//   unrealized_conversion_cast(
//       emitc.call_opaque "<funcName>"(unrealized_conversion_cast(x)))
// rooted at `op` and returns x, i.e. the first input of the cast that feeds
// the call's first operand.
//
// `op`       operation expected to be the outer unrealized_conversion_cast.
// `funcName` callee name the emitc.call_opaque must carry.
3523static std::optional<Value>
3524getUnOpaquedOperandOfEmitCOpaqueCallOp(Operation *op, StringRef funcName) {
3525 auto uccOp = dyn_cast<UnrealizedConversionCastOp>(op);
3529 auto inVal = uccOp.getInputs()[0];
3530 if (!isa<emitc::OpaqueType>(inVal.getType()))
3533 auto callOp = inVal.getDefiningOp<emitc::CallOpaqueOp>();
// An opaque-typed value does not have to come from emitc.call_opaque (it may
// be a block argument or the result of another op); callOp is null then, so
// it must be tested before the callee is read.
3534 if (!callOp || callOp.getCallee() != funcName)
3537 auto callOperandsUccOp =
3538 callOp.getOperands()[0].getDefiningOp<UnrealizedConversionCastOp>();
3539 if (!callOperandsUccOp)
3542 return callOperandsUccOp.getInputs()[0];
// Checks whether `divfOp` is the final division of a sigmoid written as
// 1 / (1 + exp(-x)).
//
// Visible requirements:
//  - the dividend is defined by an arith.constant whose splat value is 1.0;
//  - the divisor is defined by an arith.addf, or by an aievec.srs of an
//    aievec.add_elem whose two operands are aievec.ups ops;
//  - one addend is an arith.constant with splat value 1.0 and the other is an
//    exponential: a math.exp, or the emitc "getExpBf16" call chain recognised
//    by getUnOpaquedOperandOfEmitCOpaqueCallOp;
//  - the exponential's input is defined by an arith.negf.
//
// On success `negOp` is set to that arith.negf; the return value is
// `negOp != nullptr`.
// NOTE(review): the branching between the addf form and the srs/add_elem
// form is only partly visible here -- confirm before relying on the
// summary above.
3558template <
typename DivFOpTy>
3559static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) {
3560 auto *lhsDefOp = divfOp.getLhs().getDefiningOp();
3563 auto constOp = dyn_cast<arith::ConstantOp>(lhsDefOp);
3567 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3571 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
3574 Operation *addLvalOp;
3575 Operation *addRvalOp;
3581 auto *rhsDefOp = divfOp.getRhs().getDefiningOp();
3584 auto addOp = dyn_cast<arith::AddFOp>(rhsDefOp);
3586 auto srsOp = dyn_cast<aievec::SRSOp>(rhsDefOp);
// getDefiningOp() is null for block arguments. The *_if_present and
// *_and_present casting helpers treat a null pointer as "no match" instead
// of asserting, so such values fall through to the null tests below.
3591 llvm::dyn_cast_if_present<aievec::AddElemOp>(srsOp.getSource().getDefiningOp());
3595 auto lUpsOp = llvm::dyn_cast_if_present<aievec::UPSOp>(addElemOp.getLhs().getDefiningOp());
3596 auto rUpsOp = llvm::dyn_cast_if_present<aievec::UPSOp>(addElemOp.getRhs().getDefiningOp());
3597 if (!lUpsOp || !rUpsOp)
3600 addLvalOp = lUpsOp.getSource().getDefiningOp();
3601 addRvalOp = rUpsOp.getSource().getDefiningOp();
// Either addend producer may still be null at this point; the combined
// `!addLvalOp || !addRvalOp` test only comes later.
3604 auto addDefOp = llvm::isa_and_present<arith::ConstantOp>(addLvalOp)
3605 ? llvm::dyn_cast_if_present<aievec::SRSOp>(addRvalOp)
3606 : llvm::dyn_cast_if_present<aievec::SRSOp>(addLvalOp);
3608 addLvalOp = llvm::isa_and_present<arith::ConstantOp>(addLvalOp)
3609 ? llvm::dyn_cast_if_present<math::ExpOp>(addRvalOp)
3610 : llvm::dyn_cast_if_present<math::ExpOp>(addLvalOp);
3612 addLvalOp = addDefOp.getSource().getDefiningOp();
3614 addRvalOp = llvm::isa_and_present<arith::ConstantOp>(addLvalOp)
3615 ? lUpsOp.getSource().getDefiningOp()
3616 : rUpsOp.getSource().getDefiningOp();
3618 addLvalOp = addOp.getLhs().getDefiningOp();
3619 addRvalOp = addOp.getRhs().getDefiningOp();
3622 if (!addLvalOp || !addRvalOp)
3625 auto addLvalExpOp = dyn_cast<math::ExpOp>(addLvalOp);
3626 auto addRvalExpOp = dyn_cast<math::ExpOp>(addRvalOp);
3627 auto addLvalExpOpIn =
3628 getUnOpaquedOperandOfEmitCOpaqueCallOp(addLvalOp,
"getExpBf16")
3630 auto addRvalExpOpIn =
3631 getUnOpaquedOperandOfEmitCOpaqueCallOp(addRvalOp,
"getExpBf16")
// Fall back to a plain math.exp operand when the emitc call chain did not
// match.
3633 if (!addLvalExpOpIn && addLvalExpOp)
3634 addLvalExpOpIn = addLvalExpOp.getOperand();
3635 if (!addRvalExpOpIn && addRvalExpOp)
3636 addRvalExpOpIn = addRvalExpOp.getOperand();
// Exactly the shape "exp on one side, constant on the other" is accepted.
3638 if (!((addLvalExpOpIn && isa<arith::ConstantOp>(addRvalOp)) ||
3639 (addRvalExpOpIn && isa<arith::ConstantOp>(addLvalOp))))
3642 constOp = isa<arith::ConstantOp>(addLvalOp)
3643 ? cast<arith::ConstantOp>(addLvalOp)
3644 : cast<arith::ConstantOp>(addRvalOp);
3646 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3649 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
3652 auto expOperand = addLvalExpOpIn ? addLvalExpOpIn : addRvalExpOpIn;
3654 negOp = expOperand.getDefiningOp<arith::NegFOp>();
3656 return negOp !=
nullptr;
3673 using OpConversionPattern::OpConversionPattern;
3677 ConversionPatternRewriter &rewriter)
const override {
3678 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
3682 Type scalarType = srcType.getElementType();
3683 if (!isa<FloatType>(scalarType))
3687 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3688 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3691 arith::NegFOp negOp =
nullptr;
3692 if (!hasSigmoidComputationChain(adaptor, negOp))
3695 StringRef includeName =
"vec_math.h";
3696 auto moduleOp = divfOp->getParentOfType<mlir::ModuleOp>();
3697 rewriter.setInsertionPointToStart(
3698 &moduleOp.getRegion().getBlocks().front());
3699 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3701 rewriter.setInsertionPoint(divfOp);
3705 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
3708 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
3709 auto opaquedOperand =
3710 UnrealizedConversionCastOp::create(rewriter, divfOp.getLoc(),
3711 vecOpaqueTy, negOp.getOperand())
3713 SmallVector<Value> sigmoidOperands = {opaquedOperand};
3714 auto callOp = emitc::CallOpaqueOp::create(
3715 rewriter, divfOp.getLoc(), TypeRange{vecOpaqueTy},
"getSigmoidBf16",
3716 nullptr,
nullptr, sigmoidOperands);
3717 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3718 divfOp, TypeRange{adaptor.getLhs().getType()}, callOp.getResults());
3726 using OpConversionPattern::OpConversionPattern;
3730 ConversionPatternRewriter &rewriter)
const override {
3731 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
3735 Type scalarType = srcType.getElementType();
3736 if (!isa<FloatType>(scalarType))
3740 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3741 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3744 StringRef includeName =
"vec_math.h";
3745 auto moduleOp = ceilOp->getParentOfType<mlir::ModuleOp>();
3746 rewriter.setInsertionPointToStart(
3747 &moduleOp.getRegion().getBlocks().front());
3748 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3750 rewriter.setInsertionPoint(ceilOp);
3754 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
3757 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
3758 auto opaquedOperand =
3759 UnrealizedConversionCastOp::create(rewriter, ceilOp.getLoc(),
3760 vecOpaqueTy, adaptor.getOperand())
3762 SmallVector<Value> ceilOperands = {opaquedOperand};
3763 auto callOp = emitc::CallOpaqueOp::create(
3764 rewriter, ceilOp.getLoc(), TypeRange{vecOpaqueTy},
"getCeilBf16",
3765 nullptr,
nullptr, ceilOperands);
3766 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3767 ceilOp, TypeRange{ceilOp.getResult().getType()}, callOp.getResults());
3775 using OpConversionPattern::OpConversionPattern;
3779 ConversionPatternRewriter &rewriter)
const override {
3780 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
3784 Type scalarType = srcType.getElementType();
3785 if (!isa<FloatType>(scalarType))
3789 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3790 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3793 StringRef includeName =
"vec_math.h";
3794 auto moduleOp = floorOp->getParentOfType<mlir::ModuleOp>();
3795 rewriter.setInsertionPointToStart(
3796 &moduleOp.getRegion().getBlocks().front());
3797 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName,
false);
3799 rewriter.setInsertionPoint(floorOp);
3803 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
3806 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
3807 auto opaquedOperand =
3808 UnrealizedConversionCastOp::create(rewriter, floorOp.getLoc(),
3809 vecOpaqueTy, adaptor.getOperand())
3811 SmallVector<Value> floorOperands = {opaquedOperand};
3812 auto callOp = emitc::CallOpaqueOp::create(
3813 rewriter, floorOp.getLoc(), TypeRange{vecOpaqueTy},
"getFloorBf16",
3814 nullptr,
nullptr, floorOperands);
3815 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3816 floorOp, TypeRange{floorOp.getResult().getType()}, callOp.getResults());
3825 using OpConversionPattern::OpConversionPattern;
3829 ConversionPatternRewriter &rewriter)
const override {
3830 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
3834 Type scalarType = srcType.getElementType();
3835 if (!isa<FloatType>(scalarType))
3841 Location loc = negOp.getLoc();
3844 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3845 if (elWidth == 16) {
3847 aievec::UPSOp::create(rewriter, loc, accType, adaptor.getOperand());
3849 aievec::NegOp::create(rewriter, loc, accType, upsOp.getResult());
3850 auto shiftParamOp = arith::ConstantOp::create(
3851 rewriter, negOp.getLoc(), rewriter.getI32IntegerAttr(0));
3852 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3853 negOp, srcType, aieNegOp.getResult(), shiftParamOp.getResult());
3855 auto castOp = aievec::CastOp::create(
3856 rewriter, loc, accType, adaptor.getOperand(),
true);
3858 aievec::NegOp::create(rewriter, loc, accType, castOp.getResult());
3859 rewriter.replaceOpWithNewOp<aievec::CastOp>(
3860 negOp, srcType, aieNegOp.getResult(),
false);
// Reports whether the dense integer attribute held by `constOp` has a splat
// value equal to -1, with the value read as a 32-, 16- or 8-bit signed
// integer.
// NOTE(review): the three reads presumably correspond to `elWidth` being 32,
// 16 and 8 -- confirm. getSplatValue is only meaningful for splat
// attributes; verify that non-splat constants are rejected before it runs.
3869static bool hasConstNegOneValue(arith::ConstantOp constOp,
unsigned elWidth) {
3873 auto cstDense = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
3878 return cstDense.getSplatValue<int32_t>() == -1;
3880 return cstDense.getSplatValue<int16_t>() == -1;
3882 return cstDense.getSplatValue<int8_t>() == -1;
3889 using OpConversionPattern::OpConversionPattern;
3893 ConversionPatternRewriter &rewriter)
const override {
// Lowers a vector integer xor whose lane count times element width is 512
// bits. When one operand is a constant splat of -1 the xor becomes an
// aievec bneg of the other operand; otherwise it becomes an aievec bxor of
// both operands.
3894 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
3898 Type scalarType = srcType.getElementType();
3899 if (!isa<IntegerType>(scalarType))
3903 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3904 if (laneSize * elWidth != 512)
// Operands may be block arguments, for which getDefiningOp() is null;
// getDefiningOp<ConstantOp>() returns a null op there, whereas
// dyn_cast<ConstantOp>() on a null Operation* asserts.
3908 xorOp.getLhs().getDefiningOp<arith::ConstantOp>();
3910 xorOp.getRhs().getDefiningOp<arith::ConstantOp>();
3914 if ((lhsConstOp && hasConstNegOneValue(lhsConstOp, elWidth)) ||
3915 (rhsConstOp && hasConstNegOneValue(rhsConstOp, elWidth))) {
// Only the right-hand constant may have matched, so lhsConstOp can be null
// here; test it before handing it to the helper.
3916 Value val = (lhsConstOp && hasConstNegOneValue(lhsConstOp, elWidth)) ? adaptor.getRhs()
3918 rewriter.replaceOpWithNewOp<aievec::BnegOp>(xorOp, srcType, val);
3920 rewriter.replaceOpWithNewOp<aievec::BxorOp>(
3921 xorOp, srcType, adaptor.getLhs(), adaptor.getRhs());
3927template <
typename SrcOpTy,
typename DstOpTy>
3934 ConversionPatternRewriter &rewriter)
const override {
3935 VectorType srcType = dyn_cast<VectorType>(srcOp.getLhs().getType());
3939 Type scalarType = srcType.getElementType();
3940 if (!isa<IntegerType>(scalarType))
3944 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3945 if (laneSize * elWidth != 512)
3948 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getResult().getType(),
3949 adaptor.getLhs(), adaptor.getRhs());
3965 using OpConversionPattern::OpConversionPattern;
3969 ConversionPatternRewriter &rewriter)
const override {
// Lowers a vector right shift (512 bits wide: lane count times element
// width) whose shift amount is produced by an aievec.broadcast. The
// broadcast element is extracted as a scalar and used as the shift operand
// of aievec.srs applied to an aievec.ups of the shifted vector. One path
// splits the input into two halves (ext 0 / ext 1), shifts each half and
// concatenates the results; the other shifts the whole vector at once.
// NOTE(review): the condition choosing between the two paths is not visible
// here.
3970 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
3974 Type scalarType = srcType.getElementType();
3976 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3977 if (laneSize * elWidth != 512)
// The shift amount may be a block argument, for which getDefiningOp() is
// null; getDefiningOp<BroadcastOp>() returns a null op in that case instead
// of asserting the way dyn_cast<BroadcastOp>() does on a null Operation*.
3981 adaptor.getRhs().getDefiningOp<aievec::BroadcastOp>();
3986 arith::ConstantOp::create(rewriter, bcastOp.getLoc(),
3987 rewriter.getI32IntegerAttr(bcastOp.getIdx()));
3988 auto extElemOp = aievec::ExtElemOp::create(
3989 rewriter, bcastOp.getLoc(), scalarType, bcastOp, constOp.getResult());
3990 Location loc = rsOp.getLoc();
3996 auto rsOpLow = aievec::ExtOp::create(rewriter, loc, halfSrcType,
3997 adaptor.getLhs(), 0);
3998 auto rsOpHigh = aievec::ExtOp::create(rewriter, loc, halfSrcType,
3999 adaptor.getLhs(), 1);
4002 aievec::UPSOp::create(rewriter, loc, accType, rsOpLow.getResult());
4004 aievec::SRSOp::create(rewriter, loc, halfSrcType,
4005 upsOpLow.getResult(), extElemOp.getResult());
4007 aievec::UPSOp::create(rewriter, loc, accType, rsOpHigh.getResult());
4009 aievec::SRSOp::create(rewriter, loc, halfSrcType,
4010 upsOpHigh.getResult(), extElemOp.getResult());
4011 SmallVector<Value> inputSources = {srsOpLow.getResult(),
4012 srsOpHigh.getResult()};
4013 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(rsOp, srcType,
4018 aievec::UPSOp::create(rewriter, loc, accType, adaptor.getLhs());
4019 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
4020 rsOp, srcType, upsOp.getResult(), extElemOp.getResult());
// Pattern folding a vector `shrsi [+ min/max clamp] + trunci` chain into a
// single aievec::SRSOp.
// Visible pieces:
//  * a helper returning the sign-extended splat value of a dense integer
//    arith::ConstantOp, or std::nullopt;
//  * a helper (called as getShiftValue below) producing the shift amount as a
//    scalar Value, either from a splat constant or from an
//    aievec::BroadcastOp via aievec::ExtElemOp, or std::nullopt;
//  * matchAndRewrite, which peels an optional min/max pair, checks the clamp
//    bounds against the destination width, builds the SRS (padding the lane
//    count up to a multiple of 16 when needed) and erases the now-unused ops.
4043 using OpConversionPattern::OpConversionPattern;
4051 auto defOp = val.getDefiningOp<arith::ConstantOp>();
4053 return std::nullopt;
4054 auto denseAttr = dyn_cast<DenseIntElementsAttr>(defOp.getValue());
4055 if (!denseAttr || !denseAttr.isSplat())
4056 return std::nullopt;
4057 return denseAttr.getSplatValue<APInt>().getSExtValue();
4063 static std::optional<Value>
// Case 1: the shift operand is a splat integer constant; re-emit it as an
// i32 scalar constant.
4066 if (
auto constOp = rhs.getDefiningOp<arith::ConstantOp>()) {
4067 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
4068 if (denseAttr && denseAttr.isSplat()) {
4069 int64_t shiftVal = denseAttr.getSplatValue<APInt>().getSExtValue();
4070 return arith::ConstantOp::create(rewriter, loc,
4071 rewriter.getI32IntegerAttr(shiftVal))
// Case 2: the shift operand is an aievec broadcast; extract the broadcast
// lane as an i32 scalar.
// NOTE(review): rhs.getDefiningOp() is null for block arguments and
// dyn_cast<> does not accept null; rhs.getDefiningOp<aievec::BroadcastOp>()
// is the null-safe form — confirm.
4076 if (
auto bcastOp = dyn_cast<aievec::BroadcastOp>(rhs.getDefiningOp())) {
4077 auto constOp = arith::ConstantOp::create(
4078 rewriter, bcastOp.getLoc(),
4079 rewriter.getI32IntegerAttr(bcastOp.getIdx()));
4080 return aievec::ExtElemOp::create(rewriter, bcastOp.getLoc(),
4081 rewriter.getI32Type(), bcastOp,
4082 constOp.getResult())
4085 return std::nullopt;
4090 ConversionPatternRewriter &rewriter)
const override {
4091 auto dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
4095 Type dstScalarType = dstType.getElementType();
4096 if (!isa<IntegerType>(dstScalarType))
4100 Value source = adaptor.getIn();
4104 arith::MinSIOp minOp =
nullptr;
4105 arith::MaxSIOp maxOp =
nullptr;
// Look through min(max(x, ..), ..) or max(min(x, ..), ..); `source` is moved
// to the lhs of the inner op.
4107 if (
auto minsiOp = source.getDefiningOp<arith::MinSIOp>()) {
4108 if (
auto maxsiOp = minsiOp.getLhs().getDefiningOp<arith::MaxSIOp>()) {
4111 source = maxOp.getLhs();
4113 }
else if (
auto maxsiOp = source.getDefiningOp<arith::MaxSIOp>()) {
4114 if (
auto minsiOp = maxsiOp.getLhs().getDefiningOp<arith::MinSIOp>()) {
4117 source = minOp.getLhs();
4122 if (minOp && maxOp) {
4125 if (!loVal || !hiVal)
4128 unsigned dstBits = dstScalarType.getIntOrFloatBitWidth();
// dstBits is range-checked before the 64-bit shifts below.
4130 if (dstBits == 0 || dstBits > 63)
// Candidate clamp ranges for a dstBits-wide result:
// unsigned [0, 2^dstBits - 1], signed [-2^(dstBits-1), 2^(dstBits-1) - 1].
4132 uint64_t one = 1ULL;
4133 int64_t unsignedLo = 0;
4134 int64_t unsignedHi =
static_cast<int64_t
>((one << dstBits) - 1);
4135 int64_t signedLo = -
static_cast<int64_t
>(one << (dstBits - 1));
4136 int64_t signedHi =
static_cast<int64_t
>((one << (dstBits - 1)) - 1);
4138 if (*loVal == unsignedLo && *hiVal == unsignedHi) {
4140 }
else if (*loVal == signedLo && *hiVal == signedHi) {
4149 auto shrsiOp = source.getDefiningOp<arith::ShRSIOp>();
4153 auto srcType = dyn_cast<VectorType>(shrsiOp.getLhs().getType());
4157 Type srcScalarType = srcType.getElementType();
4158 if (!isa<IntegerType>(srcScalarType))
4161 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4162 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4163 if (dstElWidth >= srcElWidth)
4166 Location loc = truncOp.getLoc();
4169 auto shiftVal =
getShiftValue(shrsiOp.getRhs(), rewriter, loc);
4174 Value wideInput = shrsiOp.getLhs();
// Lane counts that are not a multiple of 16 are rounded up; the input is
// inserted at offset 0 of a zero vector of the padded type.
4177 bool needsPadding = (laneSize % 16 != 0);
4179 VectorType paddedSrcType = srcType;
4180 VectorType paddedDstType = dstType;
4181 unsigned paddedLanes = laneSize;
4185 paddedLanes = ((laneSize + 15) / 16) * 16;
4190 auto zeroAttr = rewriter.getZeroAttr(paddedSrcType);
4192 arith::ConstantOp::create(rewriter, loc, zeroAttr).getResult();
4193 SmallVector<int64_t> offsets(1, 0);
4194 SmallVector<int64_t> strides(1, 1);
4195 wideInput = vector::InsertStridedSliceOp::create(
4196 rewriter, loc, wideInput, zeroPad, offsets, strides)
// 16-bit sources go through UPS; the other visible path uses aievec::CastOp.
4204 Type accScalarType = paddedSrcType.getElementType();
4205 unsigned accElWidth = accScalarType.getIntOrFloatBitWidth();
4207 if (accElWidth == 16) {
4210 aievec::UPSOp::create(rewriter, loc, accType, wideInput).getResult();
4213 accValue = aievec::CastOp::create(rewriter, loc, paddedSrcType, wideInput,
4217 auto srsOp = aievec::SRSOp::create(rewriter, loc, paddedDstType, accValue,
4220 Value result = srsOp.getResult();
// Slice the first `laneSize` lanes back out of the SRS result.
4224 SmallVector<int64_t> offsets(1, 0);
4225 SmallVector<int64_t> sizes = {
static_cast<int64_t
>(laneSize)};
4226 SmallVector<int64_t> strides(1, 1);
4227 result = vector::ExtractStridedSliceOp::create(rewriter, loc, result,
4228 offsets, sizes, strides)
4232 rewriter.replaceOp(truncOp, result);
// Erase the matched clamp/shift ops, but only those left without uses.
4237 SmallVector<Operation *, 3> opsToErase;
4238 if (minOp && minOp->use_empty())
4239 opsToErase.push_back(minOp);
4240 if (maxOp && maxOp->use_empty())
4241 opsToErase.push_back(maxOp);
4242 if (shrsiOp->use_empty())
4243 opsToErase.push_back(shrsiOp);
4244 for (Operation *op : opsToErase)
4245 rewriter.eraseOp(op);
// Scalar counterpart of the compound `shrsi [+ clamp] + trunci` -> SRS
// lowering.
// Visible flow: the destination must be a non-vector IntegerType of 8 or 16
// bits and the source a 32-bit IntegerType. An optional min/max pair is peeled
// and its constant bounds compared with the signed/unsigned range of the
// destination width. The value is broadcast into a vector, converted to an
// accumulator, passed through aievec::SRSOp, and lane 0 is extracted as the
// scalar result. Matched ops left without uses are erased.
4258 using OpConversionPattern::OpConversionPattern;
// Helper: integer value of a scalar arith::ConstantOp, or std::nullopt.
4262 auto defOp = val.getDefiningOp<arith::ConstantOp>();
4264 return std::nullopt;
4265 auto intAttr = dyn_cast<IntegerAttr>(defOp.getValue());
4267 return std::nullopt;
4268 return intAttr.getInt();
4273 ConversionPatternRewriter &rewriter)
const override {
4275 Type dstType = truncOp.getOut().getType();
4276 if (isa<VectorType>(dstType))
4279 auto dstIntType = dyn_cast<IntegerType>(dstType);
4283 unsigned dstBits = dstIntType.getWidth();
4284 if (dstBits != 8 && dstBits != 16)
4287 auto srcIntType = dyn_cast<IntegerType>(truncOp.getIn().getType());
4288 if (!srcIntType || srcIntType.getWidth() != 32)
4292 Value source = truncOp.getIn();
4295 arith::MinSIOp minOp =
nullptr;
4296 arith::MaxSIOp maxOp =
nullptr;
// Look through min(max(x, ..), ..) or max(min(x, ..), ..).
4298 if (
auto minsiOp = source.getDefiningOp<arith::MinSIOp>()) {
4299 if (
auto maxsiOp = minsiOp.getLhs().getDefiningOp<arith::MaxSIOp>()) {
4302 source = maxOp.getLhs();
4304 }
else if (
auto maxsiOp = source.getDefiningOp<arith::MaxSIOp>()) {
4305 if (
auto minsiOp = maxsiOp.getLhs().getDefiningOp<arith::MinSIOp>()) {
4308 source = minOp.getLhs();
4313 if (minOp && maxOp) {
4316 if (!loVal || !hiVal)
// NOTE(review): dstBits was already compared against 8 and 16 above, so this
// range check looks redundant if that earlier check rejects other widths —
// confirm against the hidden lines.
4319 if (dstBits == 0 || dstBits > 63)
// Candidate clamp ranges for a dstBits-wide result:
// unsigned [0, 2^dstBits - 1], signed [-2^(dstBits-1), 2^(dstBits-1) - 1].
4321 uint64_t one = 1ULL;
4322 int64_t unsignedLo = 0;
4323 int64_t unsignedHi =
static_cast<int64_t
>((one << dstBits) - 1);
4324 int64_t signedLo = -
static_cast<int64_t
>(one << (dstBits - 1));
4325 int64_t signedHi =
static_cast<int64_t
>((one << (dstBits - 1)) - 1);
4327 if (*loVal == unsignedLo && *hiVal == unsignedHi) {
4329 }
else if (*loVal == signedLo && *hiVal == signedHi) {
4337 Location loc = truncOp.getLoc();
// With a shrsi producer its operands supply value and shift; the other
// visible path uses `source` directly with a constant shift of 0.
4340 arith::ShRSIOp shrsiOp = source.getDefiningOp<arith::ShRSIOp>();
4343 if (!isa<IntegerType>(shrsiOp.getLhs().getType()))
4345 preShiftVal = shrsiOp.getLhs();
4346 shiftVal = shrsiOp.getRhs();
4351 preShiftVal = source;
4352 shiftVal = arith::ConstantOp::create(rewriter, loc,
4353 rewriter.getI32IntegerAttr(0));
// Lane count of a 512-bit vector of the source element width.
4357 unsigned srcLanes = 512 / srcIntType.getWidth();
4361 auto bcast = aievec::BroadcastScalarOp::create(rewriter, loc, bcastVecType,
4365 VectorType srsOutType;
// One visible path doubles the lane count by concatenating the broadcast
// with itself before the accumulator cast; the other casts the broadcast
// directly.
4371 unsigned accLanes = srcLanes * 2;
4372 VectorType accVecType =
4374 auto concatSrc = aievec::ConcatOp::create(
4375 rewriter, loc, accVecType,
4376 SmallVector<Value>({bcast.getResult(), bcast.getResult()}));
4378 aievec::CastOp::create(rewriter, loc, accVecType,
4379 concatSrc.getResult(),
true)
4386 accValue = aievec::CastOp::create(rewriter, loc, bcastVecType,
4387 bcast.getResult(),
true)
4393 auto srsOp = aievec::SRSOp::create(rewriter, loc, srsOutType, accValue,
// The SRS result is concatenated with itself and lane 0 is extracted as the
// scalar replacement for truncOp.
4397 unsigned extLanes = 512 / dstBits;
4399 auto concatForExt = aievec::ConcatOp::create(
4400 rewriter, loc, extVecType,
4401 SmallVector<Value>({srsOp.getResult(), srsOp.getResult()}));
4405 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
4406 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(
4407 truncOp, dstIntType, concatForExt.getResult(), zeroIdx.getResult());
// Erase the matched clamp/shift ops, but only those left without uses.
4410 SmallVector<Operation *, 3> opsToErase;
4411 if (minOp && minOp->use_empty())
4412 opsToErase.push_back(minOp);
4413 if (maxOp && maxOp->use_empty())
4414 opsToErase.push_back(maxOp);
4415 if (shrsiOp && shrsiOp->use_empty())
4416 opsToErase.push_back(shrsiOp);
4417 for (Operation *op : opsToErase)
4418 rewriter.eraseOp(op);
// Conversion of a scalar right shift (`rsOp`) through the vector SRS path.
// Visible flow: the result type is checked for being a non-vector IntegerType
// of width 32; a scalar is broadcast into `vecType`, widened with
// aievec::UPSOp, shifted with aievec::SRSOp using the adapted rhs as shift
// amount, and lane 0 is extracted to replace the original op.
4428 using OpConversionPattern::OpConversionPattern;
4432 ConversionPatternRewriter &rewriter)
const override {
4434 Type resultType = rsOp.getType();
4435 if (isa<VectorType>(resultType))
4438 auto intType = dyn_cast<IntegerType>(resultType);
4439 if (!intType || intType.getWidth() != 32)
4442 Location loc = rsOp.getLoc();
4446 auto lhsBcast = aievec::BroadcastScalarOp::create(rewriter, loc, vecType,
4452 aievec::UPSOp::create(rewriter, loc, accType, lhsBcast.getResult());
4455 auto srsOp = aievec::SRSOp::create(rewriter, loc, vecType,
4456 upsOp.getResult(), adaptor.getRhs());
// Index 0 selects the lane that is returned as the scalar result.
4460 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
4461 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(
4462 rsOp, intType, srsOp.getResult(), zeroIdx.getResult());
// Conversion of `contractOp` into the matmul op given by `MatMulOpTy`.
// Visible pieces:
//  * a helper that drops leading unit dimensions of a vector value by
//    emitting a vector::ShapeCastOp (the value is presumably returned
//    unchanged when there are none — the early return is not visible);
//  * matchAndRewrite, which builds a MatMulOpTy, checks it with
//    verifyInvariants(), and on failure erases it and retries with operands
//    obtained through getSourceOfWideningOp(); the result is cast/shape-cast
//    back to the type of the original accumulator.
4469template <
typename MatMulOpTy>
4472 using OpConversionPattern::OpConversionPattern;
4479 auto vecTy = dyn_cast<VectorType>(v.getType());
4482 auto vecShape = vecTy.getShape();
// Count how many leading dimensions have extent 1.
4484 size_t numLeadUnitDims = 0;
4485 while (numLeadUnitDims < vecShape.size() && vecShape[numLeadUnitDims] == 1)
4488 if (!numLeadUnitDims)
4491 SmallVector<int64_t> newShape(vecShape.begin() + numLeadUnitDims,
4493 auto newVecTy = VectorType::get(newShape, vecTy.getElementType());
4494 return vector::ShapeCastOp::create(b, v.getLoc(), newVecTy, v).getResult();
4499 ConversionPatternRewriter &rewriter)
const override {
// True when the accumulator used here differs from the adapted one.
4503 bool bReshapedAcc = (acc != adaptor.getAcc());
4506 acc = aievec::CastOp::create(rewriter, contractOp.getLoc(), acc.getType(),
4509 auto matmulOp = MatMulOpTy::create(rewriter, contractOp.getLoc(),
4510 acc.getType(), lhs, rhs, acc);
// The scoped handler reports every diagnostic as handled, so a failing
// verifyInvariants() during the trial construction stays silent.
4515 ScopedDiagnosticHandler diagHandler(
4516 contractOp.getContext(), [](Diagnostic &) { return success(); });
4517 if (failed(matmulOp.verifyInvariants())) {
4518 rewriter.eraseOp(matmulOp);
// Retry using the pre-widening sources of lhs / rhs where available.
4522 lhs = adaptor.getLhs();
4523 auto wideLhsValue = getSourceOfWideningOp(lhs).value_or(
nullptr);
4527 rhs = adaptor.getRhs();
4528 auto wideRhsValue = getSourceOfWideningOp(rhs).value_or(
nullptr);
4532 matmulOp = MatMulOpTy::create(rewriter, contractOp.getLoc(),
4533 acc.getType(), lhs, rhs, acc);
4534 if (failed(matmulOp.verifyInvariants()))
4538 result = matmulOp.getResult();
4541 result = aievec::CastOp::create(rewriter, contractOp.getLoc(),
4542 acc.getType(), result,
false);
4544 result = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
4545 adaptor.getAcc().getType(), result);
4546 rewriter.replaceOp(contractOp, result);
// Conversion of a vector transpose (`transpOp`) into an aievec::ShuffleOp on
// the flattened vector.
// Visible flow: the total bit width (product of the result shape times the
// element bit width) is compared against 512 and the element width against
// 8 / 16 / 32; all but the last two result dimensions are compared against 1;
// the permutation is inspected; a shuffle mode is chosen from the element
// width and the innermost result dimension; finally the input is shape-cast
// to 1-D, shuffled, and shape-cast back to the result type.
4562 using OpConversionPattern::OpConversionPattern;
4565 ConversionPatternRewriter &rewriter)
const override {
4566 auto resTy = transpOp.getResultVectorType();
4567 auto resShape = resTy.getShape();
4568 auto elemTyBitWidth = resTy.getElementTypeBitWidth();
// Seeding the product with the element width yields the total bit count.
4569 auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
4570 elemTyBitWidth, std::multiplies<>());
4571 if (vBitWidth != 512)
4574 if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
4578 for (int64_t i = 0; i < static_cast<int64_t>(resShape.size() - 2); ++i)
4579 if (resShape[i] != 1)
4583 ArrayRef<int64_t> perm = transpOp.getPermutation();
4584 for (int64_t i = 0; i < static_cast<int64_t>(perm.size() - 2); ++i)
4587 if (perm.back() !=
static_cast<int64_t
>(perm.size() - 2))
// Default mode; it is overridden below for 8- and 16-bit elements.
4590 auto shuffleMode = aievec::ShuffleMode::T32_4X4;
4591 if (elemTyBitWidth == 8) {
4592 switch (resShape.back()) {
4594 shuffleMode = aievec::ShuffleMode::T8_4X16;
4597 shuffleMode = aievec::ShuffleMode::T8_8X8;
4600 shuffleMode = aievec::ShuffleMode::T8_16X4;
4605 }
else if (elemTyBitWidth == 16) {
4606 switch (resShape.back()) {
4608 shuffleMode = aievec::ShuffleMode::T16_2X16;
4611 shuffleMode = aievec::ShuffleMode::T16_4X8;
4614 shuffleMode = aievec::ShuffleMode::T16_8X4;
4617 shuffleMode = aievec::ShuffleMode::T16_16X2;
4622 }
else if (resShape.back() != 4)
// 1-D vector type holding all 512 bits.
4626 VectorType::get({512 / elemTyBitWidth}, resTy.getElementType());
4627 auto loc = transpOp.getLoc();
4628 auto flatInput = vector::ShapeCastOp::create(rewriter, loc, flatVecTy,
4629 adaptor.getVector());
4630 auto shuffOp = aievec::ShuffleOp::create(rewriter, loc, flatVecTy,
4631 flatInput,
nullptr, shuffleMode);
4632 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resTy, shuffOp);
4642static void populateAIEVecCommonConversionPatterns(RewritePatternSet &patterns,
4652static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns,
// Registers conversion patterns into `patterns`, with separate groups for the
// CPP and LLVMIR backends. One CPP-only group is constructed with the extra
// arguments (128, 1024, 256, 1024).
// NOTE(review): the pattern lists themselves are not visible in this excerpt;
// the last registration appears to sit outside both backend branches —
// confirm.
4671populateAIEVecV2CommonConversionPatterns(RewritePatternSet &patterns,
4675 if (backend == TargetBackend::CPP) {
4678 >(patterns.getContext(), 128, 1024, 256, 1024);
4687 >(patterns.getContext());
4688 }
else if (backend == TargetBackend::LLVMIR){
4692 >(patterns.getContext());
4739 >(patterns.getContext());
// Registers the shared V2 patterns via
// populateAIEVecV2CommonConversionPatterns(), then adds a pattern group that
// receives `backend == TargetBackend::CPP` as a constructor flag, plus an
// extra group only for the LLVMIR backend.
// NOTE(review): the pattern names are not visible in this excerpt.
4743static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
4745 populateAIEVecV2CommonConversionPatterns(patterns, backend);
4747 patterns.getContext(), backend == TargetBackend::CPP);
4750 if (backend == TargetBackend::LLVMIR) {
4752 patterns.getContext());
// Conversion of a scalar-to-vector broadcast (`bcastOp`) into
// aievec::BroadcastScalarOp.
// Visible flow: a source produced by vector::ExtractOp is special-cased
// (outcome not visible). For total widths of 512 or 256 bits a single
// BroadcastScalarOp on the flat type is emitted; for 1024 bits one
// BroadcastScalarOp of `vecType` is concatenated with itself. In both cases a
// vector::ShapeCastOp restores the original result type when it differs from
// the flat type.
4760 using OpConversionPattern::OpConversionPattern;
4764 ConversionPatternRewriter &rewriter)
const override {
4766 if (adaptor.getSource().getDefiningOp<vector::ExtractOp>())
4769 auto resultType = cast<VectorType>(bcastOp.getResult().getType());
4771 Type scalarType = resultType.getElementType();
4772 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
4774 auto src = bcastOp.getSource();
4777 if (laneSize * elWidth == 512 || laneSize * elWidth == 256) {
4778 Value newOp = aievec::BroadcastScalarOp::create(
4779 rewriter, bcastOp.getLoc(), flatResultType, src);
4780 if (resultType != flatResultType)
4781 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
4783 rewriter.replaceOp(bcastOp, newOp);
// 1024-bit case: broadcast once and concatenate the result with itself.
4787 if (laneSize * elWidth == 1024) {
4789 auto aieBcastOp = aievec::BroadcastScalarOp::create(
4790 rewriter, bcastOp.getLoc(), vecType, src);
4791 Value newOp = aievec::ConcatOp::create(
4792 rewriter, bcastOp.getLoc(), flatResultType,
4793 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
4794 if (resultType != flatResultType)
4795 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
4797 rewriter.replaceOp(bcastOp, newOp);
// Registers the shared V2 patterns via
// populateAIEVecV2CommonConversionPatterns(), then adds a pattern group that
// receives `backend == TargetBackend::CPP` as a constructor flag, followed by
// LLVMIR-only additions.
// NOTE(review): the pattern names are not visible in this excerpt.
4805static void populateAIEVecV2PConversionPatterns(RewritePatternSet &patterns,
4807 populateAIEVecV2CommonConversionPatterns(patterns, backend);
4809 patterns.getContext(), backend == TargetBackend::CPP);
4815 if (backend == TargetBackend::LLVMIR) {
// Checks whether `expOp` sits inside the chain negf -> exp -> addf(1.0) ->
// divf(1.0, ...).
// Visible checks: the exp operand is produced by arith::NegFOp; a user of the
// exp is an arith::AddFOp whose operands are a math::ExpOp and an
// arith::ConstantOp with dense float splat 1.0; a user of that add is an
// arith::DivFOp whose lhs is a constant with dense float splat 1.0.
4831static bool isInSigmoidOperationChain(math::ExpOp expOp) {
4832 if (!expOp.getOperand().getDefiningOp<arith::NegFOp>())
4835 arith::AddFOp addOp =
nullptr;
4836 for (Operation *user : expOp->getUsers()) {
4837 addOp = dyn_cast<arith::AddFOp>(user);
// NOTE(review): getDefiningOp() is null for block arguments, and isa<> /
// dyn_cast<> do not accept null pointers. Unless hidden lines guard against
// it, an add whose other operand is a block argument would trip the isa<>
// below — confirm.
4845 auto *addLvalOp = addOp.getLhs().getDefiningOp();
4846 auto *addRvalOp = addOp.getRhs().getDefiningOp();
4847 if (!((isa<math::ExpOp>(addLvalOp) && isa<arith::ConstantOp>(addRvalOp)) ||
4848 (isa<math::ExpOp>(addRvalOp) && isa<arith::ConstantOp>(addLvalOp))))
// Pick whichever add operand is the constant.
4851 auto constOp = isa<arith::ConstantOp>(addLvalOp)
4852 ? cast<arith::ConstantOp>(addLvalOp)
4853 : cast<arith::ConstantOp>(addRvalOp);
4855 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
4859 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
4862 arith::DivFOp divOp =
nullptr;
4863 for (Operation *user : addOp->getUsers()) {
4864 divOp = dyn_cast<arith::DivFOp>(user);
// NOTE(review): same null concern for a divf whose lhs is a block argument;
// getDefiningOp<arith::ConstantOp>() would be the null-safe form — confirm.
4872 constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
4875 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
4878 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
// Sets up the legality rules shared by the AIEVec conversion targets.
// Visible content: the AIEVec (incl. aie1), arith, ub, emitc and func dialects
// are marked legal; for the CPP backend vector.transfer_read is illegal;
// vector.extract_strided_slice is illegal and vector.bitcast legal; a series
// of arith/math ops receive dynamic-legality callbacks based on element type,
// element width and lane count. In MLIR a dynamic-legality callback returns
// true when the op is already legal (i.e. is left untouched).
// NOTE(review): many early-return lines of the callbacks are not visible in
// this excerpt, so the per-op comments below only describe the visible
// final conditions.
4884static void configureAIEVecCommonLegalizations(ConversionTarget &target,
4887 .addLegalDialect<xilinx::aievec::aie1::AIEVecAIE1Dialect,
4888 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
4889 ub::UBDialect, emitc::EmitCDialect, func::FuncDialect>();
4890 if (backend == TargetBackend::CPP) {
4891 target.addIllegalOp<vector::TransferReadOp>();
4893 target.addIllegalOp<vector::ExtractStridedSliceOp>();
4894 target.addLegalOp<vector::BitCastOp>();
// extf: the visible condition involves 16-bit source elements, 16 source
// lanes and 32-bit destination elements.
4896 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
4897 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
4898 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
4899 if (!srcType || !dstType)
4902 Type srcScalarType = srcType.getElementType();
4903 Type dstScalarType = dstType.getElementType();
4904 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
4909 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4910 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4911 return srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 ||
// extsi: legal unless 32 source lanes, a strictly wider destination element
// and matching lane counts.
4915 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
4916 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
4917 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
4918 if (!srcType || !dstType)
4921 Type srcScalarType = srcType.getElementType();
4922 Type dstScalarType = dstType.getElementType();
4923 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
4928 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4929 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4930 return srcLaneSize != 32 || (dstElWidth <= srcElWidth) ||
4931 (dstLaneSize != srcLaneSize);
// truncf: the visible condition involves 32-bit source elements, 16 source
// lanes and 16-bit destination elements.
4934 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
4935 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
4936 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
4937 if (!srcType || !dstType)
4940 Type srcScalarType = srcType.getElementType();
4941 Type dstScalarType = dstType.getElementType();
4942 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
4947 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4948 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4949 return srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 ||
// trunci: scalar and vector truncations are both checked with
// isSRSCompoundCandidate(); otherwise the vector rule mirrors extsi with a
// strictly narrower destination element.
4953 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
4954 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
4955 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
4956 if (!srcType || !dstType) {
4959 if (!srcType && !dstType && isSRSCompoundCandidate(trunciOp))
4964 Type srcScalarType = srcType.getElementType();
4965 Type dstScalarType = dstType.getElementType();
4966 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
4970 if (isSRSCompoundCandidate(trunciOp))
4975 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4976 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4978 return srcLaneSize != 32 || (dstElWidth >= srcElWidth) ||
4979 (dstLaneSize != srcLaneSize);
// tanh: legal unless 16-bit float elements with 16 lanes.
4982 target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
4983 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
4987 Type scalarType = srcType.getElementType();
4988 if (!isa<FloatType>(scalarType))
4992 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
4993 return elWidth != 16 || laneSize != 16;
// sqrt / erf: legal unless 16-bit float elements with 16 or 32 lanes.
4996 target.addDynamicallyLegalOp<math::SqrtOp>([](math::SqrtOp sqrtOp) {
4997 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
5001 Type scalarType = srcType.getElementType();
5002 if (!isa<FloatType>(scalarType))
5006 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5007 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
5010 target.addDynamicallyLegalOp<math::ErfOp>([](math::ErfOp erfOp) {
5011 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
5015 Type scalarType = srcType.getElementType();
5016 if (!isa<FloatType>(scalarType))
5020 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5021 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
// absf / absi: legal unless the vector is 512 or 256 bits wide.
5024 target.addDynamicallyLegalOp<math::AbsFOp>([](math::AbsFOp absfOp) {
5025 auto srcType = dyn_cast<VectorType>(absfOp.getOperand().getType());
5029 Type scalarType = srcType.getElementType();
5031 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5032 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
5035 target.addDynamicallyLegalOp<math::AbsIOp>([](math::AbsIOp absiOp) {
5036 auto srcType = dyn_cast<VectorType>(absiOp.getOperand().getType());
5040 Type scalarType = srcType.getElementType();
5042 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5043 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
// divf (CPP backend only): the scalar path looks at single-use f32 divisions
// feeding a narrowing op with a constant lhs; the vector path looks at 16-bit
// float vectors of 16 or 32 lanes that form a sigmoid chain.
5048 if (backend == TargetBackend::CPP) {
5049 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divfOp) {
5050 if (
auto srcType = dyn_cast<VectorType>(divfOp.getLhs().getType());
5052 Type scalarType = divfOp.getLhs().getType();
5053 if (!divfOp->hasOneUse() || !isa<FloatType>(scalarType))
5055 if (!isNarrowingOp(*divfOp->getUsers().begin()))
5058 auto fType = cast<FloatType>(scalarType);
5059 if (fType.getWidth() != 32)
// NOTE(review): getDefiningOp() is null for block arguments and dyn_cast<>
// does not accept null; getDefiningOp<arith::ConstantOp>() is the null-safe
// form — confirm.
5063 dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
5065 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
5069 Type scalarType = srcType.getElementType();
5070 if (!isa<FloatType>(scalarType))
5074 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5076 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
5079 arith::NegFOp negOp =
nullptr;
5080 if (!hasSigmoidComputationChain(divfOp, negOp))
// ceil / floor: legal unless 16-bit float elements with 16 or 32 lanes.
5088 target.addDynamicallyLegalOp<math::CeilOp>([](math::CeilOp ceilOp) {
5089 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
5092 Type scalarType = srcType.getElementType();
5093 if (!isa<FloatType>(scalarType))
5097 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5098 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
5101 target.addDynamicallyLegalOp<math::FloorOp>([](math::FloorOp floorOp) {
5102 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
5105 Type scalarType = srcType.getElementType();
5106 if (!isa<FloatType>(scalarType))
5110 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5111 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
// negf: legal unless a float vector with 16 lanes.
5114 target.addDynamicallyLegalOp<arith::NegFOp>([](arith::NegFOp negOp) {
5115 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
5118 if (Type scalarType = srcType.getElementType(); !isa<FloatType>(scalarType))
5122 return laneSize != 16;
// xori / ori / andi: legal unless an integer vector of 512 bits.
5125 target.addDynamicallyLegalOp<arith::XOrIOp>([](arith::XOrIOp xorOp) {
5126 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
5129 Type scalarType = srcType.getElementType();
5130 if (!isa<IntegerType>(scalarType))
5134 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5136 return laneSize * elWidth != 512;
5139 target.addDynamicallyLegalOp<arith::OrIOp>([](arith::OrIOp orOp) {
5140 auto srcType = dyn_cast<VectorType>(orOp.getLhs().getType());
5143 Type scalarType = srcType.getElementType();
5144 if (!isa<IntegerType>(scalarType))
5148 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5150 return laneSize * elWidth != 512;
// shrsi: 32-bit scalars and vectors are both checked with
// shrsiUsedByCompoundSRS(); the final vector rule is the 512-bit width test.
5153 target.addDynamicallyLegalOp<arith::ShRSIOp>([](arith::ShRSIOp rsOp) {
5154 auto srcType = dyn_cast<VectorType>(rsOp.getLhs().getType());
5158 if (
auto intType = dyn_cast<IntegerType>(rsOp.getLhs().getType()))
5159 if (intType.getWidth() == 32) {
5160 if (shrsiUsedByCompoundSRS(rsOp))
5169 if (shrsiUsedByCompoundSRS(rsOp))
5172 Type scalarType = srcType.getElementType();
5174 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5176 return laneSize * elWidth != 512;
5179 target.addDynamicallyLegalOp<arith::AndIOp>([](arith::AndIOp andOp) {
5180 auto srcType = dyn_cast<VectorType>(andOp.getLhs().getType());
5183 Type scalarType = srcType.getElementType();
5184 if (!isa<IntegerType>(scalarType))
5188 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5190 return laneSize * elWidth != 512;
// CPP backend: add/sub are legal only when their result is not a vector.
5193 if (backend == TargetBackend::CPP) {
5194 target.addDynamicallyLegalOp<arith::AddIOp>(
5195 [](arith::AddIOp op) {
return !isa<VectorType>(op.getType()); });
5197 target.addDynamicallyLegalOp<arith::AddFOp>(
5198 [](arith::AddFOp op) {
return !isa<VectorType>(op.getType()); });
5199 target.addDynamicallyLegalOp<arith::SubIOp>(
5200 [](arith::SubIOp op) {
return !isa<VectorType>(op.getType()); });
5201 target.addDynamicallyLegalOp<arith::SubFOp>(
5202 [](arith::SubFOp op) {
return !isa<VectorType>(op.getType()); });
// Legality rules specific to the AIE1 target.
// Visible content: muli / mulf are legal only for non-vector results;
// aie1::FMAOp legality depends on whether a vector::BroadcastOp feeding it
// (through an aievec::ConcatOp on the lhs, or directly on the rhs) takes its
// source from a vector::ExtractOp; aie1::AddOp legality depends on whether
// its operands are aievec::SRSOp results fed by aie1::MulOp; the memref
// dialect is marked legal.
// NOTE(review): several guard lines of the callbacks are not visible in this
// excerpt.
5205static void configureAIEVecV1Legalizations(ConversionTarget &target,
5207 target.addDynamicallyLegalOp<arith::MulIOp>(
5208 [](arith::MulIOp op) {
return !isa<VectorType>(op.getType()); });
5209 target.addDynamicallyLegalOp<arith::MulFOp>(
5210 [](arith::MulFOp op) {
return !isa<VectorType>(op.getType()); });
5211 target.addDynamicallyLegalOp<aievec::aie1::FMAOp>(
5212 [](xilinx::aievec::aie1::FMAOp op) {
5213 auto *lhsDefOp = op.getLhs().getDefiningOp();
5214 aievec::ConcatOp concatOp =
nullptr;
// NOTE(review): dyn_cast<> does not accept a null defining op; the visible
// `lhsDefOp` is presumably tested in a hidden line before this — confirm.
5216 concatOp = dyn_cast<aievec::ConcatOp>(op.getLhs().getDefiningOp());
5220 vector::BroadcastOp srcBcast =
nullptr;
// Look for the broadcast first on the concat's first source, then on the rhs.
5221 if (
auto *lhsOp = concatOp.getSources()[0].getDefiningOp())
5222 srcBcast = dyn_cast<vector::BroadcastOp>(lhsOp);
5224 auto *rhsOp = op.getRhs().getDefiningOp();
5227 srcBcast = dyn_cast<vector::BroadcastOp>(rhsOp);
// The FMA is legal unless the broadcast source comes from vector::ExtractOp.
5231 if (
auto *srcOp = srcBcast.getSource().getDefiningOp())
5232 return !isa<vector::ExtractOp>(srcOp);
5237 target.addDynamicallyLegalOp<aievec::aie1::AddOp>([](aievec::aie1::AddOp op) {
5238 auto lSrsOp = op.getLhs().getDefiningOp<aievec::SRSOp>();
5239 auto rSrsOp = op.getRhs().getDefiningOp<aievec::SRSOp>();
5241 !lSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>()) &&
5243 !rSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>());
5245 target.addLegalDialect<memref::MemRefDialect>();
// Legality rules specific to the AIE2P target.
// Visible content: for the LLVMIR backend, dynamic-legality callbacks are
// installed for rsqrt, exp, tanh, divf, extf, truncf, extsi, trunci, addf and
// subf. The callbacks inspect bf16 / f32 element types, lane counts of 16 or
// 32 (or multiples of 16 for the ext/trunc family), splat-1.0 constant
// numerators for divf, and the compound-SRS / sigmoid-chain helpers.
// NOTE(review): most return statements of the callbacks are not visible in
// this excerpt, so which side of each condition is "legal" cannot be stated
// here; where this function's scope ends relative to the LLVMIR branch is
// also not visible.
5248static void configureAIEVecV2PLegalizations(ConversionTarget &target,
5253 if (backend == TargetBackend::LLVMIR) {
5254 target.addDynamicallyLegalOp<math::RsqrtOp>([](math::RsqrtOp rsqrtOp) {
5255 auto vecType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
5257 if (vecType && vecType.getElementType().isBF16())
5265 target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
5266 auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
5270 Type scalarType = srcType.getElementType();
5271 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5274 if (!scalarType.isBF16() || (laneSize != 16 && laneSize != 32) ||
// A single-use exp that is part of a sigmoid chain is treated separately.
5277 if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp))
5285 target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
5286 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
5290 Type scalarType = srcType.getElementType();
5291 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5294 if (!scalarType.isBF16() || (laneSize != 16 && laneSize != 32) ||
// divf: looks for a constant numerator equal to exactly 1.0, for scalar f32
// and for f32 vectors of 16 or 32 lanes. dyn_cast_or_null tolerates a null
// defining op (block-argument lhs).
5303 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divfOp) {
5304 Type srcType = divfOp.getLhs().getType();
5308 dyn_cast_or_null<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
5313 if (srcType.isF32()) {
5314 auto floatAttr = dyn_cast<FloatAttr>(constOp.getValue());
5315 if (floatAttr && floatAttr.getValue().isExactlyValue(1.0))
5321 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
5322 if (vecType.getElementType().isF32()) {
5324 if (laneSize == 16 || laneSize == 32) {
5325 auto denseAttr = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
5326 if (denseAttr && denseAttr.isSplat() &&
5327 denseAttr.getSplatValue<APFloat>().isExactlyValue(1.0))
// extf / truncf / extsi / trunci: vector ops whose source and destination
// lane counts are both multiples of 16 are singled out.
5339 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
5340 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
5341 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
5342 if (!srcType || !dstType)
5345 Type srcScalarType = srcType.getElementType();
5346 Type dstScalarType = dstType.getElementType();
5347 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
5352 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
5359 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
5360 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
5361 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
5362 if (!srcType || !dstType)
5365 Type srcScalarType = srcType.getElementType();
5366 Type dstScalarType = dstType.getElementType();
5367 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
5372 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
5379 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
5380 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
5381 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
5382 if (!srcType || !dstType)
5385 Type srcScalarType = srcType.getElementType();
5386 Type dstScalarType = dstType.getElementType();
5387 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
5392 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
// trunci additionally consults isSRSCompoundCandidate() for both the scalar
// and the vector form.
5399 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
5400 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
5401 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
5402 if (!srcType || !dstType) {
5404 if (!srcType && !dstType && isSRSCompoundCandidate(trunciOp))
5408 Type srcScalarType = srcType.getElementType();
5409 Type dstScalarType = dstType.getElementType();
5410 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
5415 if (isSRSCompoundCandidate(trunciOp))
5420 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
// addf / subf: float vectors are legal unless they have 16 or 32 lanes; the
// remaining visible return uses 16 lanes only.
5428 target.addDynamicallyLegalOp<arith::AddFOp>([](arith::AddFOp op) {
5429 auto resultType = dyn_cast<VectorType>(op.getType());
5433 Type scalarType = resultType.getElementType();
5437 if (isa<FloatType>(scalarType))
5438 return laneSize != 16 && laneSize != 32;
5441 return laneSize != 16;
5446 target.addDynamicallyLegalOp<arith::SubFOp>([](arith::SubFOp op) {
5447 auto resultType = dyn_cast<VectorType>(op.getType());
5451 Type scalarType = resultType.getElementType();
5455 if (isa<FloatType>(scalarType))
5456 return laneSize != 16 && laneSize != 32;
5459 return laneSize != 16;
5463static void configureAIEVecV2Legalizations(ConversionTarget &target,
5465 target.addLegalOp<UnrealizedConversionCastOp>();
5466 target.addLegalOp<vector::ShapeCastOp>();
5469 llvm::SmallSet<std::pair<unsigned, unsigned>, 16> laneSizeElWidthPairSet;
5470 laneSizeElWidthPairSet.insert({64, 8});
5471 laneSizeElWidthPairSet.insert({32, 16});
5472 laneSizeElWidthPairSet.insert({16, 32});
5473 laneSizeElWidthPairSet.insert({32, 32});
5476 llvm::SmallSet<unsigned, 16> elWidthSet;
5477 elWidthSet.insert(8);
5478 elWidthSet.insert(16);
5479 elWidthSet.insert(32);
5481 if (backend == TargetBackend::CPP) {
5482 target.addDynamicallyLegalOp<arith::AddIOp>([=](arith::AddIOp op) {
5483 auto resultType = dyn_cast<VectorType>(op.getType());
5487 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5490 return !laneSizeElWidthPairSet.count(
5491 std::make_pair(laneSize, resultElWidth));
5495 target.addDynamicallyLegalOp<arith::SubIOp>([=](arith::SubIOp op) {
5496 auto resultType = dyn_cast<VectorType>(op.getType());
5499 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5502 return !laneSizeElWidthPairSet.count(
5503 std::make_pair(laneSize, resultElWidth));
5506 target.addDynamicallyLegalOp<arith::AddFOp>([](arith::AddFOp op) {
5507 auto resultType = dyn_cast<VectorType>(op.getType());
5511 Type scalarType = resultType.getElementType();
5513 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
5519 if (laneSize == 32 && resultElWidth == 16)
5522 if (laneSize == 32 && resultElWidth == 32)
5528 target.addDynamicallyLegalOp<arith::SubFOp>([](arith::SubFOp op) {
5529 auto resultType = dyn_cast<VectorType>(op.getType());
5533 Type scalarType = resultType.getElementType();
5535 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
5541 if (laneSize == 32 && resultElWidth == 16)
5544 if (laneSize == 32 && resultElWidth == 32)
5550 target.addDynamicallyLegalOp<arith::MulIOp>([](arith::MulIOp op) {
5551 auto resultType = dyn_cast<VectorType>(op.getType());
5554 auto isAddOp = [&](Operation *op) {
return isa<arith::AddIOp>(op); };
5556 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
5559 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5562 return (laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
5563 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32);
5566 target.addDynamicallyLegalOp<arith::MulFOp>([](arith::MulFOp op) {
5567 auto resultType = dyn_cast<VectorType>(op.getType());
5571 auto isAddOp = [&](Operation *op) {
return isa<arith::AddFOp>(op); };
5573 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
5576 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5580 if (laneSize == 16 && (resultElWidth == 16 || resultElWidth == 32))
5582 if (laneSize == 32 && resultElWidth == 16)
// Dynamic legality callback for arith::MinSIOp. The lambda captures by copy
// ([=]); elWidthSet is defined outside this excerpt.
5588 target.addDynamicallyLegalOp<arith::MinSIOp>([=](arith::MinSIOp op) {
5589 auto resultType = dyn_cast<VectorType>(op.getType());
// Scalar case: when the op's type is an IntegerType of width 8, 16 or 32,
// scalarClampInCompoundSRS(op) is consulted.
// NOTE(review): the enclosing guard and the values returned on this path are
// not shown in this listing.
5592 if (
auto intType = dyn_cast<IntegerType>(op.getType())) {
5593 unsigned w = intType.getWidth();
5594 if (w == 8 || w == 16 || w == 32) {
5595 if (scalarClampInCompoundSRS(op))
5603 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
// Vector case: reported legal unless the element width is in elWidthSet and
// laneSize * element width equals 512.
5606 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
// Dynamic legality callback for arith::MaxSIOp; the visible statements match
// the arith::MinSIOp callback directly above.
5609 target.addDynamicallyLegalOp<arith::MaxSIOp>([=](arith::MaxSIOp op) {
5610 auto resultType = dyn_cast<VectorType>(op.getType());
// Scalar case: IntegerType of width 8, 16 or 32 checked with
// scalarClampInCompoundSRS(op) (outcome not shown in this listing).
5613 if (
auto intType = dyn_cast<IntegerType>(op.getType())) {
5614 unsigned w = intType.getWidth();
5615 if (w == 8 || w == 16 || w == 32) {
5616 if (scalarClampInCompoundSRS(op))
5624 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
// Vector case: reported legal unless the element width is in elWidthSet and
// laneSize * element width equals 512.
5627 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
// Dynamic legality callbacks for the four floating-point min/max ops
// (arith::MinimumFOp, arith::MaximumFOp, arith::MaxNumFOp, arith::MinNumFOp).
// All four share the same visible rule: with totalBits = laneSize * element
// width, the op is reported legal unless the element width is in elWidthSet
// AND totalBits is either 512, or 256 with 16-bit elements.
// NOTE(review): laneSize and any early exit for non-vector results are on
// lines not shown in this listing.
5630 target.addDynamicallyLegalOp<arith::MinimumFOp>([=](arith::MinimumFOp op) {
5631 auto resultType = dyn_cast<VectorType>(op.getType());
5635 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5637 unsigned totalBits = laneSize * resultElWidth;
5639 return !elWidthSet.count(resultElWidth) ||
5640 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
// Same rule, applied to arith::MaximumFOp.
5643 target.addDynamicallyLegalOp<arith::MaximumFOp>([=](arith::MaximumFOp op) {
5644 auto resultType = dyn_cast<VectorType>(op.getType());
5648 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5650 unsigned totalBits = laneSize * resultElWidth;
5652 return !elWidthSet.count(resultElWidth) ||
5653 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
// Same rule, applied to arith::MaxNumFOp.
5656 target.addDynamicallyLegalOp<arith::MaxNumFOp>([=](arith::MaxNumFOp op) {
5657 auto resultType = dyn_cast<VectorType>(op.getType());
5661 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5663 unsigned totalBits = laneSize * resultElWidth;
5665 return !elWidthSet.count(resultElWidth) ||
5666 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
// Same rule, applied to arith::MinNumFOp.
5669 target.addDynamicallyLegalOp<arith::MinNumFOp>([=](arith::MinNumFOp op) {
5670 auto resultType = dyn_cast<VectorType>(op.getType());
5674 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5676 unsigned totalBits = laneSize * resultElWidth;
5678 return !elWidthSet.count(resultElWidth) ||
5679 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
// Dynamic legality callback for arith::CmpIOp. The decision is based on the
// type of the left-hand operand rather than the result type. With
// totalBits = laneSize * element width, the op is reported legal unless the
// element width is in elWidthSet AND totalBits is 512, or 256 with 16-bit
// elements.
5682 target.addDynamicallyLegalOp<arith::CmpIOp>([=](arith::CmpIOp op) {
5683 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
5687 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
5689 unsigned totalBits = laneSize * lhsElWidth;
5691 return !elWidthSet.count(lhsElWidth) ||
5692 (totalBits != 512 && !(totalBits == 256 && lhsElWidth == 16));
// Dynamic legality callback for arith::CmpFOp; same rule on the left-hand
// operand type as the arith::CmpIOp callback directly above.
5695 target.addDynamicallyLegalOp<arith::CmpFOp>([=](arith::CmpFOp op) {
5696 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
5700 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
5702 unsigned totalBits = laneSize * lhsElWidth;
5704 return !elWidthSet.count(lhsElWidth) ||
5705 (totalBits != 512 && !(totalBits == 256 && lhsElWidth == 16));
// Dynamic legality callback for arith::SelectOp; same rule, evaluated on the
// op's result type.
5708 target.addDynamicallyLegalOp<arith::SelectOp>([=](arith::SelectOp op) {
5709 auto resultType = dyn_cast<VectorType>(op.getType());
5713 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5715 unsigned totalBits = laneSize * resultElWidth;
5717 return !elWidthSet.count(resultElWidth) ||
5718 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
// Dynamic legality callback for vector::ReductionOp.
5721 target.addDynamicallyLegalOp<vector::ReductionOp>(
5722 [=](vector::ReductionOp op) {
// First filter: the combining kind is compared against ADD and the signed,
// unsigned and floating-point MIN/MAX kinds listed below; the guarded branch
// is taken when the kind is none of them.
// NOTE(review): the branch body is not shown in this listing.
5723 if (
auto kind = op.getKind(); kind != vector::CombiningKind::ADD &&
5724 kind != vector::CombiningKind::MINSI &&
5725 kind != vector::CombiningKind::MINUI &&
5726 kind != vector::CombiningKind::MINIMUMF &&
5727 kind != vector::CombiningKind::MINNUMF &&
5728 kind != vector::CombiningKind::MAXSI &&
5729 kind != vector::CombiningKind::MAXUI &&
5730 kind != vector::CombiningKind::MAXIMUMF &&
5731 kind != vector::CombiningKind::MAXNUMF)
5734 auto vType = dyn_cast<VectorType>(op.getVector().getType());
// (laneSize, element width) pairs checked for integer reductions:
// 64 x 8, 32 x 16, 32 x 32 and 16 x 32.
5738 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
5739 laneSizeElWidthPairSet.insert({64, 8});
5740 laneSizeElWidthPairSet.insert({32, 16});
5741 laneSizeElWidthPairSet.insert({32, 32});
5742 laneSizeElWidthPairSet.insert({16, 32});
5744 Type scalarType = vType.getElementType();
5745 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// Integer element type whose (laneSize, elWidth) pair is NOT in the set above.
5748 if (isa<IntegerType>(scalarType) &&
5749 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
// Floating-point element type whose laneSize is neither 16 nor 32.
// NOTE(review): the statements guarded by these two conditions, and the
// callback's final return, are not shown in this listing.
5752 if (isa<FloatType>(scalarType) && laneSize != 16 && laneSize != 32)
// Dynamic legality callback for arith::ExtFOp. Visible checks, in order:
//   1. source or destination type is not a VectorType;
//   2. source or destination element type is not a FloatType;
//   3. both srcLaneSize and dstLaneSize are multiples of 16.
// NOTE(review): the value returned for each check, and the definitions of
// srcLaneSize/dstLaneSize, are on lines not shown in this listing.
5759 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
5760 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
5761 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
5762 if (!srcType || !dstType)
5765 Type srcScalarType = srcType.getElementType();
5766 Type dstScalarType = dstType.getElementType();
5767 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
5772 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
// Dynamic legality callback for arith::TruncFOp; same three visible checks as
// the arith::ExtFOp callback directly above.
5779 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
5780 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
5781 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
5782 if (!srcType || !dstType)
5785 Type srcScalarType = srcType.getElementType();
5786 Type dstScalarType = dstType.getElementType();
5787 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
5792 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
// Dynamic legality callback for arith::ExtSIOp; same structure, but the
// element-type check requires IntegerType on both sides.
5799 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
5800 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
5801 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
5802 if (!srcType || !dstType)
5805 Type srcScalarType = srcType.getElementType();
5806 Type dstScalarType = dstType.getElementType();
5807 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
5812 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
// Dynamic legality callback for arith::TruncIOp.
5819 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
5820 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
5821 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
// Non-vector path: when neither side is a VectorType, the op is additionally
// tested with isSRSCompoundCandidate().
// NOTE(review): the values returned on this path are not shown in this
// listing.
5822 if (!srcType || !dstType) {
5824 if (!srcType && !dstType && isSRSCompoundCandidate(trunciOp))
5828 Type srcScalarType = srcType.getElementType();
5829 Type dstScalarType = dstType.getElementType();
// Vector path: both element types must be IntegerType to proceed past this
// check.
5830 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
// Vector truncations are also tested with isSRSCompoundCandidate() before the
// lane-size check (outcome not shown in this listing).
5835 if (isSRSCompoundCandidate(trunciOp))
// Both lane sizes are multiples of 16 (guarded statement not shown).
5840 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
5846 target.addIllegalOp<vector::ContractionOp, vector::TransposeOp,
// Dynamic legality callback for math::ExpOp, decided from the operand's
// vector type.
5851 target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
5852 auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
5856 Type scalarType = srcType.getElementType();
5857 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// Condition starts with: element type is not a FloatType, or laneSize is
// neither 16 nor 32.
// NOTE(review): the remainder of this condition and its guarded statement are
// not shown in this listing.
5860 if (!isa<FloatType>(scalarType) || (laneSize != 16 && laneSize != 32) ||
// Special case: an exp with exactly one use for which
// isInSigmoidOperationChain() holds (guarded statement not shown).
5863 if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp))
// Dynamic legality callback for math::RsqrtOp, decided from the operand's
// vector type.
5869 target.addDynamicallyLegalOp<math::RsqrtOp>([](math::RsqrtOp rsqrtOp) {
5870 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
5874 Type scalarType = srcType.getElementType();
// Non-float element types take a separate branch (body not shown).
5875 if (!isa<FloatType>(scalarType))
5879 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// Reported legal unless the element width is 16 and laneSize is 16 or 32.
5880 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
// Returns the fixed argument string of this pass,
// "test-lower-vector-to-aievec".
5904 StringRef
getArgument() const final {
return "test-lower-vector-to-aievec"; }
5906 return "Lower vector operations to AIE vector intrinsics";
5910 .insert<affine::AffineDialect, xilinx::aievec::aie1::AIEVecAIE1Dialect,
5911 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
5912 memref::MemRefDialect, scf::SCFDialect, vector::VectorDialect,
5913 emitc::EmitCDialect>();
5917 *
this,
"aie-target",
5919 "Select AIE version: \"aie\", \"aie2\", or \"aie2p\". This will "
5920 "determine the vector size and available operations."),
5921 llvm::cl::init(
"aie")};
5924 *
this,
"target-backend",
5925 llvm::cl::desc(
"Select translation backend: \"cpp\" or \"llvmir\". This "
5926 "will determine the aievec operations used to convert "
5927 "from vector dialect."),
5928 llvm::cl::init(
"cpp")};
// Body of the pass entry point: selects the AIE architecture and the
// translation backend from the pass options, registers the matching rewrite
// patterns and legality rules, then runs a partial dialect conversion on the
// current operation.
// NOTE(review): the definitions of targetStr, backendStr and backend are on
// lines not shown in this listing.
5931 auto *op = getOperation();
5932 MLIRContext *context = &getContext();
5933 RewritePatternSet patterns(context);
5934 ConversionTarget target(*context);
// AIEArch::AIE is the starting value; it is kept when the target string is
// "aie".
5935 auto aieVersion = AIEArch::AIE;
// "aieml" and "aie2" both select AIEArch::AIE2; "aie2p" selects
// AIEArch::AIE2P; anything other than "aie" is reported as an error and
// fails the pass.
// NOTE(review): "aieml" is accepted here but is not one of the values named
// in the "aie-target" option description — confirm whether the help text
// should mention it.
5938 if (targetStr ==
"aieml" || targetStr ==
"aie2")
5939 aieVersion = AIEArch::AIE2;
5940 else if (targetStr ==
"aie2p")
5941 aieVersion = AIEArch::AIE2P;
5942 else if (targetStr !=
"aie") {
5943 op->emitError() <<
"unknown AIE target '" <<
aieTarget <<
"'";
5944 return signalPassFailure();
// Backend selection: "llvmir" selects TargetBackend::LLVMIR, which is
// rejected with an error when the architecture is AIEArch::AIE; any value
// other than "llvmir" or "cpp" is reported as an unknown backend.
// NOTE(review): the message below spells "targetting"; the usual spelling is
// "targeting". Left untouched here because it is a runtime string.
5951 if (backendStr ==
"llvmir") {
5952 backend = TargetBackend::LLVMIR;
5953 if (aieVersion == AIEArch::AIE) {
5954 op->emitError() <<
"targetting LLVM IR is not supported for AIEv1";
5955 signalPassFailure();
5958 }
else if (backendStr !=
"cpp") {
5959 op->emitError() <<
"unknown target backend '" <<
targetBackend <<
"'";
5960 signalPassFailure();
// Patterns and legality rules shared by every architecture are registered
// first, followed by the architecture-specific ones.
5965 populateAIEVecCommonConversionPatterns(patterns, backend);
5966 configureAIEVecCommonLegalizations(target, backend);
5967 if (aieVersion == AIEArch::AIE) {
5968 populateAIEVecV1ConversionPatterns(patterns, backend);
5969 configureAIEVecV1Legalizations(target, backend);
5970 }
else if (aieVersion == AIEArch::AIE2) {
5971 populateAIEVecV2ConversionPatterns(patterns, backend);
5972 configureAIEVecV2Legalizations(target, backend);
5973 }
else if (aieVersion == AIEArch::AIE2P) {
// AIE2P registers its own pattern set but applies the V2 legality rules
// before adding the V2P-specific ones.
5974 populateAIEVecV2PConversionPatterns(patterns, backend);
5975 configureAIEVecV2Legalizations(target, backend);
5976 configureAIEVecV2PLegalizations(target, backend);
5978 llvm_unreachable(
"AIE version is misconfigured");
// A failed partial conversion fails the pass.
5981 if (failed(applyPartialConversion(op, target, std::move(patterns))))
5982 return signalPassFailure();
5986static std::unique_ptr<Pass>
5988 return std::make_unique<LowerVectorToAIEVec>(options);
// Pass body fragment that sets up a conversion target for aievec::UPDOp and
// runs a partial conversion.
// NOTE(review): the enclosing pass and the line that fills `patterns` are not
// shown in this listing — presumably this is the UPD-extension pass; confirm.
6002 MLIRContext *context = &getContext();
6003 RewritePatternSet patterns(context);
6004 ConversionTarget target(*context);
// Every op of the AIEVec dialect is legal by default; aievec::UPDOp is
// overridden with the dynamic rule below.
6006 target.addLegalDialect<aievec::AIEVecDialect>();
// A UPDOp is reported legal when any of the following holds:
//   - getVector() yields a value;
//   - it has exactly one use and that user is another aievec::UPDOp;
//   - every user is an aievec::ExtOp.
6007 target.addDynamicallyLegalOp<aievec::UPDOp>([](aievec::UPDOp op) {
6008 return op.getVector() ||
6009 (op->hasOneUse() && isa<aievec::UPDOp>(*op->getUsers().begin())) ||
6010 llvm::all_of(op->getUsers(),
6011 [](Operation *op) {
return isa<aievec::ExtOp>(op); });
// A failed partial conversion fails the pass.
6014 if (
auto *op = getOperation();
6015 failed(applyPartialConversion(op, target, std::move(patterns)))) {
6016 return signalPassFailure();
// Pass body fragment that sets up a conversion target for aievec::ExtOp and
// runs a partial conversion.
// NOTE(review): the enclosing pass and the line that fills `patterns` are not
// shown in this listing — presumably this is the UPD-simplification pass;
// confirm.
6029 MLIRContext *context = &getContext();
6030 RewritePatternSet patterns(context);
6031 ConversionTarget target(*context);
// Every op of the AIEVec dialect is legal by default; aievec::ExtOp is
// overridden with the dynamic rule below.
6033 target.addLegalDialect<aievec::AIEVecDialect>();
// An ExtOp is reported legal when its source has no defining op, when that
// defining op is not an aievec::UPDOp, or when the defining op does not have
// exactly one use.
// NOTE(review): the expression continues on lines not shown in this listing,
// so further legal cases may exist.
6034 target.addDynamicallyLegalOp<aievec::ExtOp>([](aievec::ExtOp op) {
6035 auto *defOp = op.getSource().getDefiningOp();
6036 return !defOp || !isa<aievec::UPDOp>(defOp) || !defOp->hasOneUse() ||
// A failed partial conversion fails the pass.
6040 if (
auto *op = getOperation();
6041 failed(applyPartialConversion(op, target, std::move(patterns)))) {
6042 return signalPassFailure();
// Pipeline assembly: passes are appended to `pm` in the order listed.
// NOTE(review): the enclosing function header is not shown in this listing.
// Main vector-to-AIEVec lowering, followed by canonicalization.
6054 pm.addPass(createLowerVectorToAIEVec(options));
6055 pm.addPass(createCanonicalizerPass());
// UPD post-processing: ExtendUPDOpsPass, then CSE, then SimplifyUPDOpsPass,
// with a final canonicalization.
6058 pm.addPass(std::make_unique<ExtendUPDOpsPass>());
6059 pm.addPass(createCSEPass());
6060 pm.addPass(std::make_unique<SimplifyUPDOpsPass>());
6061 pm.addPass(createCanonicalizerPass());
LowerScalarMinMaxToAIEVecMinMaxOp< arith::MaxSIOp, aievec::MaxOp > LowerScalarMaxSIOpToAIEVecMaxOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaximumFOp, aievec::MaxOp > LowerVectorMaximumFOpToAIEVecMaxOp
ComputeBandAndBorOpPattern< arith::OrIOp, aievec::BorOp > ComputeBorOpPattern
ComputeBandAndBorOpPattern< arith::AndIOp, aievec::BandOp > ComputeBandOpPattern
OneToOneVectorOpToAIEVecOpPattern< arith::SubFOp, aievec::aie1::SubOp > LowerVectorSubFOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsIOp > ComputeAbsIOpPattern
LowerTruncOpPattern< arith::TruncFOp > LowerTruncFOpPattern
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddFOp, aievec::AddElemOp > LowerVectorAddFOpToAIEVecAddElemOp
LowerScalarMinMaxToAIEVecMinMaxOp< arith::MinSIOp, aievec::MinOp > LowerScalarMinSIOpToAIEVecMinOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaxSIOp, aievec::MaxOp > LowerVectorMaxSIOpToAIEVecMaxOp
LowerExtOpPattern< arith::ExtFOp > LowerExtFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpFOp, CmpFPredicate > LowerVectorCmpFOpToAIEVecCmpOp
OneToOneVectorOpToAIEVecOpPattern< arith::SubIOp, aievec::aie1::SubOp > LowerVectorSubIOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsFOp > ComputeAbsFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpIOp, CmpIPredicate > LowerVectorCmpIOpToAIEVecCmpOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::SubFOp, aievec::SubElemOp > LowerVectorSubFOpToAIEVecSubElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinSIOp, aievec::MinOp > LowerVectorMinSIOpToAIEVecMinOp
OneToOneVectorOpToAIEVecOpPattern< arith::AddFOp, aievec::aie1::AddOp > LowerVectorAddFOpToAIEVecAddOp
OneToOneVectorOpToAIEVecOpPattern< arith::MulFOp, aievec::aie1::MulOp > LowerVectorMulFOpToAIEVecMulOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaxNumFOp, aievec::MaxOp > LowerVectorMaxNumFFOpToAIEVecMaxOp
LowerExtOpPattern< arith::ExtSIOp > LowerExtSIOpPattern
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinimumFOp, aievec::MinOp > LowerVectorMinimumFOpToAIEVecMinOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddIOp, aievec::AddElemOp > LowerVectorAddIOpToAIEVecAddElemOp
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
unsigned getVectorLaneSize(mlir::VectorType type)
SmallVector< NamedAttribute > buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos, int64_t step=1)
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
mlir::VectorType createVectorType(unsigned lanes, mlir::Type elementType)
int32_t getElementSizeInBits(mlir::VectorType type)
void buildLowerVectorToAIEVec(mlir::OpPassManager &pm, const LowerVectorToAIEVecOptions &options)
mlir::VectorType getVectorOpDestType(mlir::VectorType type, bool AIE2)
LogicalResult matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::AddFOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulAddFToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulAddToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulFToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulIToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertVectorFMAOpToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
ExpandUPDToUPDAndExtPattern(MLIRContext *context)
LogicalResult matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FuseExtIntoUPDPattern(MLIRContext *context)
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::TruncIOp truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::optional< int64_t > getScalarConstantValue(Value val)
LogicalResult matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorContractionOpToAIEVecMatMulPattern(MLIRContext *context, bool matMoveToAcc=true)
Value reshapeLeadingUnitDims(OpBuilder &b, Value v) const
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Lower incoming vector operations into their corresponding AIE vector intrinsics.
void getDependentDialects(DialectRegistry &registry) const override
LowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options)
void runOnOperation() override
StringRef getDescription() const final
Option< std::string > aieTarget
Option< std::string > targetBackend
StringRef getArgument() const final
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize, int64_t maxVectorSize, int64_t alignment, int64_t maxLoadSize)
LogicalResult matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ShiftClampTruncToSRSPattern(MLIRContext *context, PatternBenefit benefit=2)
static std::optional< Value > getShiftValue(Value rhs, ConversionPatternRewriter &rewriter, Location loc)
static std::optional< int64_t > getConstantSplatValue(Value val)
LogicalResult matchAndRewrite(arith::TruncIOp truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
Options for the "lower-vector-to-aievec" pipeline.
PassOptions::Option< std::string > aieTarget
PassOptions::Option< std::string > targetBackend