21#include "mlir/Dialect/Affine/IR/AffineOps.h"
22#include "mlir/Dialect/EmitC/IR/EmitC.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/IR/SymbolTable.h"
29#include "mlir/IR/TypeUtilities.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "mlir/Transforms/Passes.h"
33#include "llvm/ADT/SmallSet.h"
// Debug-output tag for this file (used by LLVM's debug-only filtering).
38#define DEBUG_TYPE "lower-vector-to-aievec"
// Brings the MLIR vector dialect namespace into scope. NOTE(review): other
// namespaces used unqualified below (e.g. arith predicates) are presumably
// opened on lines outside this view — TODO confirm.
43using namespace vector;
// Reports whether `op` is a narrowing conversion. The visible checks match
// arith.truncf / arith.trunci, and an aievec.srs whose source is defined by an
// aievec.ups or aievec.cast; the corresponding return statements fall on
// lines outside this view.
51static bool isNarrowingOp(Operation *op) {
52 if (isa<arith::TruncFOp, arith::TruncIOp>(op))
55 if (
auto srsOp = dyn_cast<aievec::SRSOp>(op)) {
56 auto *srsOpSrcOp = srsOp.getSource().getDefiningOp();
// The SRS source can be a block argument (function argument or region
// argument), in which case getDefiningOp() returns null. Plain isa<> asserts
// on a null pointer; the null-tolerant form simply reports "no match".
57 if (isa_and_nonnull<aievec::UPSOp, aievec::CastOp>(srsOpSrcOp))
// Looks through a widening conversion and returns the narrower value it was
// applied to. Recognized producers of `src`:
//   * arith.extsi / arith.extui / arith.extf -> that op's input operand;
//   * aievec.srs whose source is defined by aievec.ups -> the ups source;
//   * aievec.cast whose source is defined by aievec.ups -> the ups source.
// Returns an empty optional when none of these patterns match.
66static std::optional<Value> getSourceOfWideningOp(Value src) {
67 if (
auto extSIOp =
src.getDefiningOp<arith::ExtSIOp>())
68 return extSIOp.getIn();
69 if (
auto extUIOp =
src.getDefiningOp<arith::ExtUIOp>())
70 return extUIOp.getIn();
71 if (
auto extFOp =
src.getDefiningOp<arith::ExtFOp>())
72 return extFOp.getIn();
// ups followed by srs: the value entering the ups is the narrow source.
73 if (
auto srsOp =
src.getDefiningOp<aievec::SRSOp>()) {
77 auto srsSource = srsOp.getSource();
79 if (
auto upsOp = srsSource.getDefiningOp<aievec::UPSOp>())
80 return upsOp.getSource();
// ups followed by cast: same look-through as the srs case above.
82 if (
auto castOp =
src.getDefiningOp<aievec::CastOp>()) {
86 auto castSource = castOp.getSource();
88 if (
auto upsOp = castSource.getDefiningOp<aievec::UPSOp>())
89 return upsOp.getSource();
// No recognized widening producer.
91 return std::optional<Value>();
// Given the two operands of an add, looks for a multiply feeding one of them
// and returns (mul lhs, mul rhs, acc) as built by the make_tuple calls below.
// Two producer shapes are tried, lhs operand first, then rhs:
//   * an arith.muli defining the operand;
//   * an aievec.srs whose source is defined by aievec.aie1.mul.
// NOTE(review): `acc` is assigned on lines outside this view — presumably the
// add operand that is not the multiply; TODO confirm. The null/empty result
// path is also outside this view.
97static std::optional<std::tuple<Value, Value, Value>>
98extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
99 auto *lhsDefOp = addLhs.getDefiningOp();
100 auto *rhsDefOp = addRhs.getDefiningOp();
101 arith::MulIOp mulOp =
nullptr;
// NOTE(review): dyn_cast<> asserts on null; this is presumably guarded by a
// `lhsDefOp` check on a line outside this view (the rhs case below has one).
104 mulOp = dyn_cast<arith::MulIOp>(lhsDefOp);
107 if (!mulOp && rhsDefOp) {
108 mulOp = dyn_cast<arith::MulIOp>(rhsDefOp);
112 return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc);
// Second shape: the multiply already lowered to aievec.aie1.mul and moved
// out of the accumulator with aievec.srs.
115 auto lhsSrsOp = addLhs.getDefiningOp<aievec::SRSOp>();
116 auto rhsSrsOp = addRhs.getDefiningOp<aievec::SRSOp>();
117 aievec::aie1::MulOp aieMulOp =
nullptr;
119 aieMulOp = lhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
122 if (!aieMulOp && rhsSrsOp) {
123 aieMulOp = rhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
127 return std::make_tuple(aieMulOp.getLhs(), aieMulOp.getRhs(), acc);
// Emits aievec ops that convert `inputVal` (a vector) to `tgtType`.
// Cases visible here:
//   * identical types: handled on a line outside this view;
//   * same element type, lane count 16->32 (float) or 32->64 (integer):
//     the input is concatenated with a zero vector, so it occupies the low
//     half and the upper half is zero;
//   * same lane count, integer element widening:
//       - i16 -> i32, 16 lanes: aievec.ups into an accumulator, then
//         aievec.cast to tgtType;
//       - i8 -> i32, 16 lanes: concat the input with itself, ups, cast, then
//         aievec.ext part 0 to tgtType;
//       - i8 -> i16, 32 lanes: aievec.unpack.
// The result for unsupported combinations is produced outside this view
// (presumably an empty optional — TODO confirm).
134static std::optional<Value>
135convertValueToTargetTypeAIE2(ConversionPatternRewriter &rewriter, Location loc,
136 Value inputVal, VectorType tgtType) {
137 auto srcType = cast<VectorType>(inputVal.getType());
138 auto srcElemType = srcType.getElementType();
139 unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
142 auto tgtElemType = tgtType.getElementType();
143 unsigned tgtBitWidth = tgtElemType.getIntOrFloatBitWidth();
146 if (srcType == tgtType)
149 if ((srcElemType == tgtElemType) && (srcLaneSize != tgtLaneSize)) {
// Lane doubling: build a srcType-sized zero vector (broadcast 0 into tgtType,
// ext part 0) and concatenate it after the input.
151 if ((srcLaneSize == 16 && tgtLaneSize == 32 &&
152 isa<FloatType>(srcElemType)) ||
153 (srcLaneSize == 32 && tgtLaneSize == 64 &&
154 isa<IntegerType>(srcElemType))) {
155 auto zeroConstOp = rewriter.create<arith::ConstantOp>(
156 loc, srcType.getElementType(),
157 rewriter.getZeroAttr(srcType.getElementType()));
158 auto broadcastZeroOp = rewriter.create<aievec::BroadcastScalarOp>(
159 loc, tgtType, zeroConstOp->getResult(0));
160 auto extOp = rewriter.create<aievec::ExtOp>(
161 loc, srcType, broadcastZeroOp->getResult(0), 0);
163 SmallVector<Value> inputSources = {inputVal, extOp->getResult(0)};
165 rewriter.create<aievec::ConcatOp>(loc, tgtType, inputSources);
167 return concatOp.getResult();
169 }
else if ((srcElemType != tgtElemType) && (srcLaneSize == tgtLaneSize) &&
170 isa<IntegerType>(srcElemType) && isa<IntegerType>(tgtElemType)) {
// i16 -> i32 x16: widen through the accumulator (accType defined outside
// this view).
171 if (srcBitWidth == 16 && tgtBitWidth == 32 && srcLaneSize == 16) {
175 auto upsOp = rewriter.create<aievec::UPSOp>(loc, accType, inputVal);
176 auto castOp = rewriter.create<aievec::CastOp>(
177 loc, tgtType, upsOp.getResult(),
false);
178 return castOp.getResult();
// i8 -> i32 x16: duplicate the input to reach the concat width, widen, then
// keep part 0 (concatOutType / accType / castType defined outside this view).
181 if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) {
185 auto concatOp = rewriter.create<aievec::ConcatOp>(
186 loc, concatOutType, SmallVector<Value>({inputVal, inputVal}));
189 rewriter.create<aievec::UPSOp>(loc, accType, concatOp.getResult());
191 auto castOp = rewriter.create<aievec::CastOp>(
192 loc, castType, upsOp.getResult(),
false);
194 rewriter.create<aievec::ExtOp>(loc, tgtType, castOp.getResult(), 0);
195 return extOp.getResult();
// i8 -> i16 x32: a single unpack.
198 if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) {
200 auto unpackOp = rewriter.create<aievec::UnpackOp>(loc, tgtType, inputVal);
201 return unpackOp.getResult();
// Builds the string-valued named attributes (select, x/y offsets, square,
// start) for a rotation implemented with an aievec select-style op. The list
// depends on the element width: `width` is taken from the integer element type
// (its initial value and the `rotation` parameter are declared outside this
// view). Three attribute sets are visible:
//   * a 9-entry set with select "0x11111111", xstart = rotation + 1 and
//     ystart = rotation - 1;
//   * a 9-entry set with select "0" and xstart = rotation;
//   * a 7-entry set without *_hi offsets, xoffsets "0x76543210", and
//     xstart = rotation.
// The width -> set mapping is on lines outside this view; any other width
// reaches report_fatal_error.
211static SmallVector<NamedAttribute>
212buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy,
215 auto elemTy = vTy.getElementType();
216 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
217 width = intTy.getWidth();
// Attribute values and names reused across the sets below.
218 StringAttr attr0 = rewriter.getStringAttr(
"0");
219 StringAttr attr0x06040200 = rewriter.getStringAttr(
"0x06040200");
220 StringAttr attr0x0e0c0a08 = rewriter.getStringAttr(
"0x0e0c0a08");
221 StringAttr attr0x2103 = rewriter.getStringAttr(
"0x2103");
222 StringAttr attr0x3210 = rewriter.getStringAttr(
"0x3210");
223 StringAttr selectAttrName = rewriter.getStringAttr(
"select");
224 StringAttr xoffsetsAttrName = rewriter.getStringAttr(
"xoffsets");
225 StringAttr xoffsetsHiAttrName = rewriter.getStringAttr(
"xoffsets_hi");
226 StringAttr xsquareAttrName = rewriter.getStringAttr(
"xsquare");
227 StringAttr xstartAttrName = rewriter.getStringAttr(
"xstart");
228 StringAttr yoffsetsAttrName = rewriter.getStringAttr(
"yoffsets");
229 StringAttr yoffsetsHiAttrName = rewriter.getStringAttr(
"yoffsets_hi");
230 StringAttr ysquareAttrName = rewriter.getStringAttr(
"ysquare");
231 StringAttr ystartAttrName = rewriter.getStringAttr(
"ystart");
// Set 1: both x and y operand selectors are used, with starts offset by +1/-1
// from `rotation`.
236 int64_t xstart = rotation + 1;
237 int64_t ystart = rotation - 1;
238 return SmallVector<NamedAttribute, 9>(
239 {{selectAttrName, rewriter.getStringAttr(
"0x11111111")},
240 {xoffsetsAttrName, attr0x06040200},
241 {xoffsetsHiAttrName, attr0x0e0c0a08},
242 {xsquareAttrName, attr0x2103},
243 {xstartAttrName, rewriter.getStringAttr(std::to_string(xstart))},
244 {yoffsetsAttrName, rewriter.getStringAttr(
"0x0503010f")},
245 {yoffsetsHiAttrName, rewriter.getStringAttr(
"0x0d0b0907")},
246 {ysquareAttrName, attr0x2103},
247 {ystartAttrName, rewriter.getStringAttr(std::to_string(ystart))}});
// Set 2: x selector only (y attributes zeroed), starting at `rotation`.
249 return SmallVector<NamedAttribute, 9>(
250 {{selectAttrName, attr0},
251 {xoffsetsAttrName, attr0x06040200},
252 {xoffsetsHiAttrName, attr0x0e0c0a08},
253 {xsquareAttrName, attr0x3210},
254 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
255 {yoffsetsAttrName, attr0},
256 {yoffsetsHiAttrName, attr0},
257 {ysquareAttrName, attr0},
258 {ystartAttrName, attr0}});
// Set 3: no *_hi offset attributes, x selector only, starting at `rotation`.
261 return SmallVector<NamedAttribute, 7>(
262 {{selectAttrName, attr0},
263 {xoffsetsAttrName, rewriter.getStringAttr(
"0x76543210")},
264 {xsquareAttrName, attr0x3210},
265 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
266 {yoffsetsAttrName, attr0},
267 {ysquareAttrName, attr0},
268 {ystartAttrName, attr0}});
270 llvm::report_fatal_error(
"Unexpected width!");
// Builds the 11 named attributes for an aievec.aie1.fma whose z operand is a
// splat of lane `bcastPos` (parameters declared on lines outside this view;
// the function name line is also outside this view). `width` is taken from
// the integer element type of the FMA lhs. Two attribute sets are visible:
//   * xoffsets "0x73727170"/"0x77767574", xsquare "0x3120", zoffsets "0",
//     zstep = `step`;
//   * xoffsets "0x76543210", zoffsets "0x00000000", other x/z
//     attributes copied from `fmaOp`.
// In both sets zstart = bcastPos, and xstep, zsquare and fmsub are copied from
// `fmaOp`. The width -> set mapping is outside this view; any other width
// reaches report_fatal_error.
// NOTE(review): unlike its neighbors this helper is not declared `static`.
278SmallVector<NamedAttribute>
282 auto elemTy = fmaOp.getLhs().getType().getElementType();
283 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
284 width = intTy.getWidth();
285 auto *ctx = fmaOp.getContext();
305 return SmallVector<NamedAttribute, 11>(
306 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx,
"0")},
307 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx,
"0x73727170")},
308 {fmaOp.getXoffsetsHiAttrName(), StringAttr::get(ctx,
"0x77767574")},
309 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
310 {fmaOp.getXsquareAttrName(), StringAttr::get(ctx,
"0x3120")},
311 {fmaOp.getZstartAttrName(),
312 StringAttr::get(ctx, std::to_string(bcastPos))},
313 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx,
"0")},
314 {fmaOp.getZoffsetsHiAttrName(), StringAttr::get(ctx,
"0")},
315 {fmaOp.getZstepAttrName(), StringAttr::get(ctx, std::to_string(step))},
316 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
317 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
319 return SmallVector<NamedAttribute, 11>(
320 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx,
"0")},
321 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx,
"0x76543210")},
322 {fmaOp.getXoffsetsHiAttrName(), fmaOp.getXoffsetsHiAttr()},
323 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
324 {fmaOp.getXsquareAttrName(), fmaOp.getXsquareAttr()},
325 {fmaOp.getZstartAttrName(),
326 StringAttr::get(ctx, std::to_string(bcastPos))},
327 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx,
"0x00000000")},
328 {fmaOp.getZoffsetsHiAttrName(), fmaOp.getZoffsetsHiAttr()},
329 {fmaOp.getZstepAttrName(), fmaOp.getZstepAttr()},
330 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
331 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
333 llvm::report_fatal_error(
"Unexpected width!");
// Lowers an elementwise binary op `srcOp` (SrcOpTy) to AIEv2ElemOp:
// 1. Both operands are aievec.cast to `srcType`.
// 2. AIEv2ElemOp is applied to them.
// 3. `srcOp` is replaced by an aievec.cast of the result back to srcOp's type.
// The trailing `false` on the final cast is passed through as-is (presumably
// the cast's "result is accumulator" flag — TODO confirm). Parameter `srcOp`
// and the LogicalResult returned are on lines outside this view.
341template <
typename SrcOpTy,
typename AIEv2ElemOp>
342static LogicalResult genAddElemAIE2(ConversionPatternRewriter &rewriter,
343 Value lval, Value rval, VectorType srcType,
345 auto lCastOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), srcType, lval,
347 auto rCastOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), srcType, rval,
349 auto elemOp = rewriter.create<AIEv2ElemOp>(
350 srcOp.getLoc(), lCastOp->getResult(0).getType(), lCastOp->getResult(0),
351 rCastOp->getResult(0));
352 rewriter.replaceOpWithNewOp<aievec::CastOp>(
353 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// Maps an arith float comparison predicate to an integer predicate:
//   * UEQ/OEQ -> eq, UNE/ONE -> ne (ordering ignored);
//   * unordered relational (UGT/UGE/ULT/ULE) -> unsigned (ugt/uge/ult/ule);
//   * ordered relational (OGT/OGE/OLT/OLE) -> signed (sgt/sge/slt/sle).
// Any other predicate reaches report_fatal_error (the switch header and its
// closing brace are on lines outside this view).
// NOTE(review): this mapping is only meaningful if callers compare values
// whose integer ordering matches the intended float ordering — confirm
// against call sites.
357static arith::CmpIPredicate
358convertToIntegerPredicate(arith::CmpFPredicate pred) {
360 case CmpFPredicate::UEQ:
361 case CmpFPredicate::OEQ:
362 return CmpIPredicate::eq;
363 case CmpFPredicate::UGT:
364 return CmpIPredicate::ugt;
365 case CmpFPredicate::OGT:
366 return CmpIPredicate::sgt;
367 case CmpFPredicate::UGE:
368 return CmpIPredicate::uge;
369 case CmpFPredicate::OGE:
370 return CmpIPredicate::sge;
371 case CmpFPredicate::ULT:
372 return CmpIPredicate::ult;
373 case CmpFPredicate::OLT:
374 return CmpIPredicate::slt;
375 case CmpFPredicate::ULE:
376 return CmpIPredicate::ule;
377 case CmpFPredicate::OLE:
378 return CmpIPredicate::sle;
379 case CmpFPredicate::UNE:
380 case CmpFPredicate::ONE:
381 return CmpIPredicate::ne;
383 llvm::report_fatal_error(
"Unexpected predicate!");
// Integer-predicate overload, so templated callers can convert either
// predicate kind with the same call. Its body is outside this view
// (presumably returns `pred` unchanged — TODO confirm).
387static arith::CmpIPredicate
388convertToIntegerPredicate(arith::CmpIPredicate pred) {
// Creates an aievec.cmp of `lhs` and `rhs` with result type `type`. The
// comparison kind is passed as a string mnemonic that spells the arith
// integer predicate ("eq", "ne", "slt", "ult", "sle", "ule", "sgt", "ugt",
// "sge", "uge"). Handling of any other predicate value is outside this view.
392static aievec::CmpOp createCmpOpAIE2(ConversionPatternRewriter &rewriter,
393 CmpIPredicate pred, Location loc,
394 Type type, Value lhs, Value rhs) {
396 case CmpIPredicate::eq:
397 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"eq");
398 case CmpIPredicate::ne:
399 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"ne");
400 case CmpIPredicate::slt:
401 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"slt");
402 case CmpIPredicate::ult:
403 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"ult");
404 case CmpIPredicate::sle:
405 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"sle");
406 case CmpIPredicate::ule:
407 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"ule");
408 case CmpIPredicate::sgt:
409 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"sgt");
410 case CmpIPredicate::ugt:
411 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"ugt");
412 case CmpIPredicate::sge:
413 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"sge");
414 case CmpIPredicate::uge:
415 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"uge");
// Lowers a vector.reduction by repeated shift-and-combine. For
// id = shiftIndex, shiftIndex/2, ..., 1:
//   * `curValue` is shifted by `id` elements, expressed as
//     id * elementBits / 8 bytes;
//   * the shifted vector is combined with the unshifted one via DstOpTy;
//   * the result becomes the next `curValue`.
// `srcOp` is finally replaced by an aievec.ext_elem at index 0 of the last
// combined vector.
// Preconditions:
//   * shiftIndex is a positive power of two (asserted);
//   * curValue has vector type (enforced by cast<> below).
420template <
typename DstOpTy>
421static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
422 vector::ReductionOp srcOp,
423 int shiftIndex, Value curValue) {
424 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
425 "shiftIndex must be power of 2");
427 Location loc = srcOp.getLoc();
// The element type is read unconditionally on the next line, so a non-vector
// value is a caller bug. cast<> asserts with a clear message instead of
// dereferencing the null an unchecked dyn_cast<> would return.
428 auto vType = cast<VectorType>(curValue.getType());
429 Type scalarType = vType.getElementType();
430 Type vecType = curValue.getType();
431 DstOpTy curOp =
nullptr;
432 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
434 for (
int id = shiftIndex;
id > 0;
id /= 2) {
// Shift amount in bytes for `id` elements of width elWidth bits.
435 auto constOp = rewriter.create<arith::ConstantOp>(
436 loc, rewriter.getI32IntegerAttr(
id * elWidth / 8));
438 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
439 loc, vecType, curValue, curValue, constOp.getResult());
441 curOp = rewriter.create<DstOpTy>(loc, vecType, curValue,
442 shiftBytesOp.getResult());
444 curValue = curOp.getResult();
448 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
449 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
450 zeroConstOp.getResult());
// Looks up the func.func named `funcName` in `parentModuleOp`'s symbol table.
// The visible creation path builds a declaration with:
//   * visibility sym_visibility = "private";
//   * type (inTypes) -> (outTypes);
//   * position at the start of the module body.
// An existing function is presumably returned by the branch whose body is
// outside this view — TODO confirm. The InsertionGuard restores the
// rewriter's insertion point when this function returns.
// NOTE(review): constructing a SymbolTable scans the module's symbols on
// every call; consider caching one in the caller if this runs per-op.
453static func::FuncOp getOrInsertFuncDecl(ConversionPatternRewriter &rewriter,
454 mlir::ModuleOp parentModuleOp,
455 StringRef funcName, TypeRange inTypes,
456 TypeRange outTypes) {
458 mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
459 rewriter.setInsertionPointToStart(
460 &parentModuleOp.getRegion().getBlocks().front());
461 SymbolTable st = SymbolTable(parentModuleOp);
462 func::FuncOp fnOpLookup = st.lookup<func::FuncOp>(funcName);
// Op handles test for null through their explicit bool conversion. Comparing
// them against the C macro NULL is non-idiomatic.
466 if (fnOpLookup) {
469 StringAttr t1 = rewriter.getStringAttr(
"sym_visibility");
470 StringAttr t2 = rewriter.getStringAttr(
"private");
471 NamedAttribute funcAccess = NamedAttribute(t1, t2);
472 FunctionType fnType =
473 mlir::FunctionType::get(rewriter.getContext(), inTypes, outTypes);
474 fnOp = rewriter.create<func::FuncOp>(parentModuleOp.getLoc(), funcName,
// Predicate for the math.exp lowering. The visible return requires:
//   * a float element type;
//   * 16 lanes;
//   * a 16-bit element width.
// The null check on `srcType` and the computation of `laneSize` are on lines
// outside this view.
480static bool matchExpOpForLUT(math::ExpOp::Adaptor adaptor) {
481 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
486 Type scalarType = srcType.getElementType();
487 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
489 return isa<FloatType>(scalarType) && laneSize == 16 && elWidth == 16;
// Pattern body (struct header outside this view): rewrites a splat whose
// scalar input comes from vector.extract into an aievec.broadcast of the
// extracted lane of the source vector.
500 using OpConversionPattern::OpConversionPattern;
504 ConversionPatternRewriter &rewriter)
const override {
506 auto extOp = adaptor.getInput().getDefiningOp<vector::ExtractOp>();
511 auto src = extOp.getVector();
512 auto pos = extOp.getStaticPosition();
513 int64_t posVal = pos[0];
514 auto srcVecType = cast<VectorType>(src.getType());
515 auto resultType = cast<VectorType>(splatOp.getResult().getType());
// A source exactly twice as wide as the result is narrowed first:
// 1. aievec.ext extracts the half that holds lane `posVal`.
// 2. `posVal` is rebased into that half.
// Other width mismatches are handled on a line outside this view.
516 if (srcVecType != resultType) {
517 if (srcVecType.getNumElements() != 2 * resultType.getNumElements())
519 auto half =
static_cast<int8_t
>(posVal / resultType.getNumElements());
520 posVal -= half * resultType.getNumElements();
522 .create<aievec::ExtOp>(extOp.getLoc(), resultType, src,
523 rewriter.getI8IntegerAttr(half))
// Dispatch on the total result width in bits (lanes x element width).
527 unsigned elWidth = resultType.getElementType().getIntOrFloatBitWidth();
530 laneSize * elWidth == 512) {
// 512 bits: broadcast from `src` directly.
532 rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(splatOp, resultType, src,
534 }
else if (laneSize * elWidth == 256) {
// 256 bits: concat `src` with itself, broadcast lane `posVal`, then keep
// part 0 of the broadcast.
536 VectorType aievecBcastType =
538 auto concatOp = rewriter.create<aievec::ConcatOp>(
539 splatOp.getLoc(), aievecBcastType, SmallVector<Value>({src, src}));
540 auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
541 splatOp.getLoc(), aievecBcastType, concatOp.getResult(), posVal);
542 rewriter.replaceOpWithNewOp<aievec::ExtOp>(splatOp, resultType,
543 aieBcastOp.getResult(), 0);
544 }
else if (laneSize * elWidth == 1024) {
// 1024 bits:
// 1. ext the half of `src` that holds the lane and rebase `posVal`.
// 2. Broadcast that lane.
// 3. Concat two copies of the broadcast to form the result.
546 VectorType aievecBcastType =
548 auto half =
static_cast<int8_t
>(posVal / resultType.getNumElements());
549 posVal -= half * resultType.getNumElements();
551 rewriter.create<aievec::ExtOp>(splatOp.getLoc(), aievecBcastType, src,
552 rewriter.getI8IntegerAttr(half));
553 auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
554 splatOp.getLoc(), aievecBcastType, extOp.getResult(), posVal);
555 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
557 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
// Pattern body (struct header outside this view): lowers a splat of a scalar
// to aievec.broadcast_scalar. Splats fed by vector.extract are rejected here
// (the return follows the check, outside this view). The result is built at
// the flattened vector type (`flatResultType`, defined outside this view)
// and vector.shape_cast back when the original result type differs.
567 using OpConversionPattern::OpConversionPattern;
571 ConversionPatternRewriter &rewriter)
const override {
573 if (adaptor.getInput().getDefiningOp<vector::ExtractOp>())
576 auto resultType = cast<VectorType>(splatOp.getResult().getType());
578 Type scalarType = resultType.getElementType();
579 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
581 auto src = splatOp.getInput();
// 512 bits: a single broadcast_scalar at the flat result type.
583 if (laneSize * elWidth == 512) {
584 Value newOp = rewriter.create<aievec::BroadcastScalarOp>(
585 splatOp.getLoc(), flatResultType, src);
586 if (resultType != flatResultType)
587 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
589 rewriter.replaceOp(splatOp, newOp);
// 256 bits: broadcast into `vecType` (defined outside this view), then keep
// part 0.
593 if (laneSize * elWidth == 256) {
595 auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
596 splatOp.getLoc(), vecType, src);
597 Value newOp = rewriter.create<aievec::ExtOp>(
598 splatOp.getLoc(), flatResultType, aieBcastOp.getResult(), 0);
599 if (resultType != flatResultType)
600 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
602 rewriter.replaceOp(splatOp, newOp);
// 1024 bits: broadcast into `vecType`, then concat two copies.
606 if (laneSize * elWidth == 1024) {
608 auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
609 splatOp.getLoc(), vecType, src);
610 Value newOp = rewriter.create<aievec::ConcatOp>(
611 splatOp.getLoc(), flatResultType,
612 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
613 if (resultType != flatResultType)
614 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
616 rewriter.replaceOp(splatOp, newOp);
// Pattern body (struct header outside this view): fuses an add whose operand
// is a multiply (found by extractMACOperandsFromAddOperands) into:
// 1. aievec.ups of the accumulator operand;
// 2. aievec.mac_elem (FMAElemOp) of the multiply operands;
// 3. aievec.srs back to the result type with `shiftParam`.
// Only 32 x 16-bit and 16 x 32-bit results pass the shape check below.
628 using OpConversionPattern::OpConversionPattern;
636 ConversionPatternRewriter &rewriter)
const override {
638 auto resultType = dyn_cast<VectorType>(addOp.getType());
644 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
647 auto [lhs, rhs, acc] = *res;
650 unsigned resultElWidth =
651 resultType.getElementType().getIntOrFloatBitWidth();
// Supported shapes: 32 lanes x 16 bits, or 16 lanes x 32 bits.
654 if ((laneSize != 32 || resultElWidth != 16) &&
655 (laneSize != 16 || resultElWidth != 32))
660 auto upsOp = rewriter.create<aievec::UPSOp>(addOp.getLoc(), accType, acc,
662 auto fmaElemOp = rewriter.create<aievec::FMAElemOp>(
663 addOp.getLoc(), accType, lhs, rhs, upsOp.getResult(),
666 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
667 addOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
668 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
669 addOp, resultType, fmaElemOp.getResult(), shiftParamOp.getResult());
// Pattern body (struct header outside this view): lowers vector.fma with 16
// lanes of f32 or bf16 to aievec FMAElemOp.
//   * bf16 result: the accumulator is first aievec.ups'd to vector<16xf32>,
//     and the FMA result is aievec.srs'd back to the bf16 result type.
//   * f32 result: both multiplicands must be widened (arith.extf) from bf16;
//     the bf16 sources are fed to the FMA directly.
// Match failures carry diagnostic messages for each rejected case.
685 using OpConversionPattern::OpConversionPattern;
693 ConversionPatternRewriter &rewriter)
const override {
695 auto resVecTy = cast<VectorType>(fmaOp.getType());
696 auto resElemTy = resVecTy.getElementType();
699 if (numElems != 16 || (!resElemTy.isF32() && !resElemTy.isBF16()))
700 return rewriter.notifyMatchFailure(
701 fmaOp,
"Unsupported operand types in vector.fma lowering.");
703 Value lhs = adaptor.getLhs();
704 Value rhs = adaptor.getRhs();
705 Value acc = adaptor.getAcc();
706 if (resElemTy.isBF16())
707 acc = rewriter.create<aievec::UPSOp>(
708 fmaOp.getLoc(), VectorType::get({16}, rewriter.getF32Type()), acc,
// Presumably only on the f32-result path (branch header outside this view):
// look through the widening ops; a missing widening op yields null.
711 lhs = getSourceOfWideningOp(lhs).value_or(
nullptr);
712 rhs = getSourceOfWideningOp(rhs).value_or(
nullptr);
714 return rewriter.notifyMatchFailure(
715 fmaOp,
"vector.fma operands are f32, and they don't come from "
716 "arith.extf on bf16; can't lower to aievec.");
717 if (!cast<VectorType>(lhs.getType()).getElementType().isBF16() ||
718 !cast<VectorType>(rhs.getType()).getElementType().isBF16())
719 return rewriter.notifyMatchFailure(
720 fmaOp,
"vector.fma operands come from arith.extf, but the source "
721 "of the widening op is not bf16; can't lower to aievec.");
// The FMA accumulates at the accumulator's type.
723 Value newOp = rewriter.create<aievec::FMAElemOp>(
724 fmaOp.getLoc(), acc.getType(), lhs, rhs, acc,
false);
// bf16 result: move the f32 accumulator back to bf16 with srs.
726 if (resElemTy.isBF16()) {
727 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
728 fmaOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
729 newOp = rewriter.create<aievec::SRSOp>(fmaOp.getLoc(), resVecTy, newOp,
733 rewriter.replaceOp(fmaOp, newOp);
// Pattern body (struct header outside this view): lowers a float multiply
// with 16 lanes of 16- or 32-bit results to aievec MulElemOp. Steps:
// 1. A mul whose single use is an arith.addf is skipped (the early return
//    is outside this view), leaving it for add-fusion patterns.
// 2. Operands are looked through widening ops.
// 3. Operands are converted to a common input type with
//    convertValueToTargetTypeAIE2.
// 4. The accumulator result is aievec.cast (same width) or aievec.srs'd
//    (wider) to the result type.
// NOTE(review): much of this body is duplicated in the integer multiply
// pattern below; a shared helper would keep the two in sync.
745 using OpConversionPattern::OpConversionPattern;
753 ConversionPatternRewriter &rewriter)
const override {
755 auto resultType = dyn_cast<VectorType>(mulOp.getType());
760 auto isAddOp = [&](Operation *op) {
return isa<arith::AddFOp>(op); };
761 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
764 unsigned resultElWidth =
765 resultType.getElementType().getIntOrFloatBitWidth();
770 if (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32))
774 auto lval = adaptor.getLhs();
775 auto rval = adaptor.getRhs();
// Prefer the narrow source of a widening op; keep the operand otherwise.
776 lval = getSourceOfWideningOp(lval).value_or(lval);
777 rval = getSourceOfWideningOp(rval).value_or(rval);
778 auto lSrcType = cast<VectorType>(lval.getType());
779 auto rSrcType = cast<VectorType>(rval.getType());
780 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
781 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
783 if (rBitWidth > lBitWidth) {
787 if (lSrcType != rSrcType) {
// The wider operand's element type decides the multiplier input type.
792 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
793 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
794 : lSrcType.getElementType();
795 unsigned numLanes = 0;
796 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
798 }
else if (isa<IntegerType>(srcElemType) &&
799 (bitWidth == 8 || bitWidth == 16)) {
801 }
else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
807 if (targetInputType != lSrcType) {
808 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
812 if (targetInputType != rSrcType) {
813 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
822 rewriter.create<aievec::MulElemOp>(mulOp.getLoc(), accType, lval, rval);
825 auto mulElemResultType = mulElemOp.getType();
826 auto mulElemResultElWidth =
827 mulElemResultType.getElementType().getIntOrFloatBitWidth();
// Same element width: reinterpret with cast; wider accumulator: srs down.
829 if (mulElemResultElWidth == resultElWidth) {
830 rewriter.replaceOpWithNewOp<aievec::CastOp>(
831 mulOp, resultType, mulElemOp.getResult(),
false);
832 }
else if (mulElemResultElWidth > resultElWidth) {
833 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
834 mulOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
835 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
836 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
// Pattern body (struct header outside this view): lowers an integer multiply
// to aievec MulElemOp. Supported results: 32 lanes of 8/16-bit, or 16/32 lanes
// of 32-bit. Steps:
// 1. A mul whose single use is an arith.addi is skipped (the early return
//    is outside this view), leaving it for add-fusion patterns.
// 2. Operands are looked through widening ops.
// 3. Operands are converted to a common input type with
//    convertValueToTargetTypeAIE2.
// 4. The accumulator result is aievec.cast (same width) or aievec.srs'd
//    (wider) to the result type.
851 using OpConversionPattern::OpConversionPattern;
859 ConversionPatternRewriter &rewriter)
const override {
861 auto resultType = dyn_cast<VectorType>(mulOp.getType());
866 auto isAddOp = [&](Operation *op) {
return isa<arith::AddIOp>(op); };
867 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
871 unsigned resultElWidth =
872 resultType.getElementType().getIntOrFloatBitWidth();
// Supported shapes: 32 x {8,16}-bit, or {16,32} x 32-bit.
875 if ((laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
876 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32))
880 auto lval = adaptor.getLhs();
881 auto rval = adaptor.getRhs();
// Prefer the narrow source of a widening op; keep the operand otherwise.
883 lval = getSourceOfWideningOp(lval).value_or(lval);
884 rval = getSourceOfWideningOp(rval).value_or(rval);
886 auto lSrcType = cast<VectorType>(lval.getType());
887 auto rSrcType = cast<VectorType>(rval.getType());
888 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
889 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
891 if (rBitWidth > lBitWidth) {
// The wider operand's element type decides the multiplier input type.
896 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
897 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
898 : lSrcType.getElementType();
899 unsigned numLanes = 0;
900 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
902 }
else if (isa<IntegerType>(srcElemType) &&
903 (bitWidth == 8 || bitWidth == 16)) {
905 }
else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
911 if (targetInputType != lSrcType) {
912 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
916 if (targetInputType != rSrcType) {
917 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
926 rewriter.create<aievec::MulElemOp>(mulOp.getLoc(), accType, lval, rval);
929 auto mulElemResultType = mulElemOp.getType();
930 auto mulElemResultElWidth =
931 mulElemResultType.getElementType().getIntOrFloatBitWidth();
// Same element width: reinterpret with cast; wider accumulator: srs down.
933 if (mulElemResultElWidth == resultElWidth) {
934 rewriter.replaceOpWithNewOp<aievec::CastOp>(
935 mulOp, resultType, mulElemOp.getResult(),
false);
936 }
else if (mulElemResultElWidth > resultElWidth) {
937 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
938 mulOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
939 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
940 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
// Pattern body (struct header outside this view): folds a vector.splat that
// feeds an aievec.aie1.fma into the FMA itself. Matching requirements:
//   * the FMA lhs must be an aievec.concat;
//   * a vector.splat must be either the first concat source or the FMA rhs;
//     the other value becomes the new `lhs`;
//   * the splat's scalar must come from vector.extract at a static position.
// The replacement FMA takes:
//   * `lhs` concatenated with a zero vector;
//   * the extract's source vector as rhs;
//   * the original accumulator.
// The extract position (`zstart`) is presumably used to build `fmaOpAttr` on
// a line outside this view.
//
// Every producer lookup uses a null-tolerant form. An operand with no
// defining op (a block argument) then yields a null handle, the same result
// a kind mismatch already produces. The existing "no match" checks (outside
// this view) must handle that null anyway, because dyn_cast<> returns null on
// a mismatch. Plain dyn_cast<> would instead assert on a null input.
954 using OpConversionPattern::OpConversionPattern;
958 ConversionPatternRewriter &rewriter)
const override {
960 adaptor.getLhs().getDefiningOp<aievec::ConcatOp>();
963 vector::SplatOp splatOp =
nullptr;
964 auto *concatDefOp = concatOp.getSources()[0].getDefiningOp();
966 splatOp = dyn_cast_or_null<vector::SplatOp>(concatDefOp);
967 Value lhs = adaptor.getRhs();
969 splatOp = adaptor.getRhs().getDefiningOp<vector::SplatOp>();
972 lhs = concatOp.getSources()[0];
975 splatOp.getInput().getDefiningOp<vector::ExtractOp>();
979 auto rhs = extOp.getVector();
980 auto concatVecType = cast<VectorType>(concatOp.getResult().getType());
// Zero vector used to pad `lhs` up to the concat width.
981 auto zvec = rewriter.create<arith::ConstantOp>(
982 concatOp.getLoc(), lhs.getType(), rewriter.getZeroAttr(lhs.getType()));
985 .create<aievec::ConcatOp>(concatOp.getLoc(), concatVecType,
986 SmallVector<Value, 2>({lhs, zvec}))
989 auto pos = extOp.getStaticPosition();
990 int64_t zstart = pos[0];
992 rewriter.replaceOpWithNewOp<aievec::aie1::FMAOp>(
993 fmaOp, TypeRange({fmaOp.getResult().getType()}),
994 ValueRange({lhsX2, rhs, adaptor.getAcc()}), fmaOpAttr);
// Pattern body (struct header outside this view): fuses add(mul(a, b), c),
// found by extractMACOperandsFromAddOperands, into aievec.aie1.fma. Steps:
// 1. The lhs multiplicand is concatenated with itself, doubling the innermost
//    dimension.
// 2. The accumulator is aievec.ups'd.
// 3. The FMA result is aievec.srs'd back to the add's vector type with a
//    shift of 0.
1002 using OpConversionPattern::OpConversionPattern;
1006 ConversionPatternRewriter &rewriter)
const override {
1007 auto vecType = cast<VectorType>(addOp.getType());
1010 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
1013 auto [lhs, rhs, acc] = *res;
// Concat type: same shape as the add, innermost dimension doubled.
1015 SmallVector<int64_t, 4> concatVecShape(vecType.getShape().begin(),
1016 vecType.getShape().end());
1017 concatVecShape[vecType.getRank() - 1] *= 2;
1018 auto concatVecType =
1019 VectorType::get(concatVecShape, vecType.getElementType());
1022 auto lhsX2 = rewriter
1023 .create<aievec::ConcatOp>(addOp.getLoc(), concatVecType,
1024 SmallVector<Value, 2>(2, lhs))
1026 auto upsOp = rewriter.create<aievec::UPSOp>(addOp.getLoc(), accType, acc);
1027 auto fmaOp = rewriter.create<aievec::aie1::FMAOp>(
1028 addOp.getLoc(), accType, lhsX2, rhs, upsOp.getResult(),
1032 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1033 addOp.getLoc(), rewriter.getI32IntegerAttr(0));
1034 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1035 addOp, vecType, fmaOp.getResult(), shiftParamOp.getResult());
// Pattern body (struct header outside this view): lowers
// vector.transfer_read to aievec.upd.
//   * Masked reads are a hard error: an error diagnostic is emitted, not just
//     a match failure.
//   * Permutation maps that are not minor-identity, and constant maps, are
//     rejected (their returns are outside this view).
// `vSize` is the vector size in bits. When it exceeds `minVectorSize` and the
// low 8 bits of `multiplicity` (defined outside this view) are not exactly
// one set bit, a path outside this view is taken.
// The read is emitted as one aievec.upd, optionally followed by a second upd
// chained on the first (the condition and remaining operands are outside this
// view).
1045 using OpConversionPattern::OpConversionPattern;
1056 ConversionPatternRewriter &rewriter)
const override {
1058 if (readOp.getMask())
1059 return readOp.emitError() <<
"AIE doesn't support masked loads.";
1062 AffineMap map = readOp.getPermutationMap();
1063 if (!map.isMinorIdentity())
1067 if (map.isConstant())
1071 auto vType = readOp.getVectorType();
1081 int64_t vSize = vType.getNumElements() * vType.getElementTypeBitWidth();
// bitset<8>(x).count() == 1 tests "x is a power of two" within 8 bits.
1090 if ((vSize >
minVectorSize) && std::bitset<8>(multiplicity).count() != 1)
// First upd: offset 0, index 0, no prior vector to update.
1093 auto updOp = rewriter.create<xilinx::aievec::UPDOp>(
1094 readOp.getLoc(), vType, adaptor.getSource(), adaptor.getIndices(), 0, 0,
1095 TypedValue<VectorType>(
nullptr));
1097 updOp = rewriter.create<xilinx::aievec::UPDOp>(
1098 readOp.getLoc(), vType, adaptor.getSource(), adaptor.getIndices(),
1101 rewriter.replaceOp(readOp, updOp.getResult());
// Generic pattern (struct header outside this view): replaces a binary
// SrcOpTy with a DstOpTy that keeps the same result type and lhs/rhs
// operands. Any remaining builder arguments are on lines outside this view.
1111template <
typename SrcOpTy,
typename DstOpTy>
1118 ConversionPatternRewriter &rewriter)
const override {
1119 rewriter.replaceOpWithNewOp<DstOpTy>(
1120 srcOp, srcOp.getResult().getType(), adaptor.getLhs(), adaptor.getRhs(),
// Pattern body (struct header outside this view): lowers a vector integer
// add to aievec.aie1.add. Two cases are rejected (returns outside this view):
//   * non-vector results;
//   * adds with an arith.muli operand, left for multiply-accumulate fusion.
1128 using OpConversionPattern::OpConversionPattern;
1132 ConversionPatternRewriter &rewriter)
const override {
1133 auto resType = addOp.getType();
1134 if (!isa<VectorType>(resType))
1137 auto lhs = adaptor.getLhs();
1138 auto rhs = adaptor.getRhs();
1139 auto *lhsDefOp = lhs.getDefiningOp();
1140 auto *rhsDefOp = rhs.getDefiningOp();
// Block-argument operands have no defining op, hence the null-tolerant isa.
1141 if ((isa_and_nonnull<arith::MulIOp>(lhsDefOp)) ||
1142 (isa_and_nonnull<arith::MulIOp>(rhsDefOp)))
1145 rewriter.replaceOpWithNewOp<aievec::aie1::AddOp>(
1146 addOp, resType, lhs, rhs,
// Pattern body (struct header outside this view): lowers a vector integer
// multiply to aievec.aie1.mul into an accumulator type (`accTy`, defined
// outside this view). The accumulator is then moved back to the result type
// with aievec.srs using a shift of 0.
1163 using OpConversionPattern::OpConversionPattern;
1166 ConversionPatternRewriter &rewriter)
const override {
1167 auto resTy = dyn_cast<VectorType>(mulOp.getType());
1171 auto newMulOp = rewriter.create<aievec::aie1::MulOp>(
1172 mulOp.getLoc(), accTy, adaptor.getLhs(), adaptor.getRhs());
1173 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1174 mulOp.getLoc(), rewriter.getI32IntegerAttr(0));
1175 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1176 mulOp, resTy, newMulOp.getResult(), shiftParamOp.getResult());
// Pattern body (struct header outside this view; body continues past this
// view): lowers an elementwise add/sub (SrcOpTy) to DstOpTy.
//   * Integer results must have one of the (lanes, bits) shapes listed below.
//   * Operations with a mulf/muli operand are rejected (returns outside this
//     view), leaving them for fusion patterns.
//   * With no defining ops on either side: 512-bit results become DstOpTy
//     directly; other sizes go through genAddElemAIE2.
//   * 32-bit results whose operands are both widening ops: each narrow
//     source is aievec.ups'd, combined with DstOpTy, and aievec.cast back to
//     the result type.
1181template <
typename SrcOpTy,
typename DstOpTy>
1189 ConversionPatternRewriter &rewriter)
const override {
1190 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
// Supported integer (lanes, element bits) combinations.
1196 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1197 laneSizeElWidthPairSet.insert({64, 8});
1198 laneSizeElWidthPairSet.insert({32, 16});
1199 laneSizeElWidthPairSet.insert({16, 32});
1200 laneSizeElWidthPairSet.insert({32, 32});
1202 auto lhs = adaptor.getLhs();
1203 auto rhs = adaptor.getRhs();
1204 auto lhsDefOp = lhs.getDefiningOp();
1205 auto rhsDefOp = rhs.getDefiningOp();
1206 if ((lhsDefOp && isa<arith::MulIOp>(lhsDefOp)) ||
1207 (rhsDefOp && isa<arith::MulIOp>(rhsDefOp)) ||
1208 (lhsDefOp && isa<arith::MulFOp>(lhsDefOp)) ||
1209 (rhsDefOp && isa<arith::MulFOp>(rhsDefOp)))
1212 Type scalarType = resultType.getElementType();
1213 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1217 if (isa<IntegerType>(scalarType)) {
1218 if (!laneSizeElWidthPairSet.count(
1219 std::make_pair(laneSize, resultElWidth)))
// Both operands are block arguments (no defining ops).
1225 if (!lhsDefOp && !rhsDefOp) {
1226 if (laneSize * resultElWidth == 512) {
1227 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1231 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs, resultType,
1236 if (resultElWidth == 32) {
1237 auto lhsExt = getSourceOfWideningOp(lhs).value_or(
nullptr);
1238 auto rhsExt = getSourceOfWideningOp(rhs).value_or(
nullptr);
1240 if (!lhsExt && !rhsExt) {
1241 if (laneSize * resultElWidth == 512) {
1242 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1246 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
// Both operands widened: compute on the ups'd narrow sources instead.
1250 if (lhsExt && rhsExt) {
1253 VectorType lSrcType = cast<VectorType>(lval.getType());
1257 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1259 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1260 auto elemOp = rewriter.create<DstOpTy>(
1261 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1262 lUpsOp->getResult(0), rUpsOp->getResult(0));
1263 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1264 srcOp, srcOp.getType(), elemOp.getResult(),
false);
1268 if (!lhsExt || !rhsExt) {
1269 auto lval = lhsExt ? lhsExt : lhs;
1270 auto rval = rhsExt ? rhsExt : rhs;
1271 auto extVal = lhsExt ? lval : rval;
1272 VectorType vType = cast<VectorType>(extVal.getType());
1273 unsigned bitWidth = vType.getElementType().getIntOrFloatBitWidth();
1275 if (bitWidth != 8 && bitWidth != 16) {
1276 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1280 if (bitWidth * laneSize != 256) {
1281 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1285 Type accType =
nullptr;
1287 if (bitWidth == 8) {
1289 Value valToUps = lhsExt ? lval : rval;
1290 Value valToCast = lhsExt ? rval : lval;
1291 auto upsOp = rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType,
1293 auto castOp = rewriter.create<aievec::CastOp>(
1294 srcOp.getLoc(), resultType, valToCast,
true);
1296 lhsExt ? upsOp->getResult(0) : castOp->getResult(0);
1298 lhsExt ? castOp->getResult(0) : upsOp->getResult(0);
1299 auto elemOp = rewriter.create<DstOpTy>(
1300 srcOp.getLoc(), upsOp->getResult(0).getType(), lhsToElemOp,
1302 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1303 srcOp, srcOp.getType(), elemOp.getResult(),
false);
1307 if (bitWidth == 16) {
1310 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1312 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1314 auto elemOp = rewriter.create<DstOpTy>(
1315 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1316 lUpsOp->getResult(0), rUpsOp->getResult(0));
1318 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1319 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1320 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1321 srcOp, srcOp.getType(), elemOp.getResult(),
1322 shiftParamOp.getResult());
1327 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs, rhs);
1337 if (resultElWidth == 32) {
1338 if (!lhsDefOp && !rhsDefOp) {
1339 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1343 auto lhsExt = getSourceOfWideningOp(lhs).value_or(
nullptr);
1344 auto rhsExt = getSourceOfWideningOp(rhs).value_or(
nullptr);
1346 if (!lhsExt && !rhsExt) {
1347 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1352 if (lhsExt && rhsExt) {
1355 VectorType vType = cast<VectorType>(lval.getType());
1359 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1361 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1362 auto elemOp = rewriter.create<DstOpTy>(
1363 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1364 lUpsOp->getResult(0), rUpsOp->getResult(0));
1365 rewriter.replaceOpWithNewOp<aievec::CastOp>(srcOp, srcOp.getType(),
1366 elemOp.getResult());
1371 if (!lhsExt || !rhsExt) {
1372 auto lval = lhsExt ? lhsExt : lhs;
1373 auto rval = rhsExt ? rhsExt : rhs;
1374 auto extVal = lhsExt ? lval : rval;
1375 VectorType vType = cast<VectorType>(extVal.getType());
1378 aievec::UPSOp upsOp;
1379 aievec::CastOp castOp;
1382 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1383 castOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), resultType,
1388 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1389 castOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), resultType,
1394 auto elemOp = rewriter.create<DstOpTy>(
1395 srcOp.getLoc(), upsOp->getResult(0).getType(),
1396 upsOp->getResult(0), castOp->getResult(0));
1398 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1399 srcOp, srcOp.getType(), elemOp.getResult(),
false);
1408 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lhs);
1410 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rhs);
1411 auto elemOp = rewriter.create<DstOpTy>(
1412 srcOp.getLoc(), lUpsOp->getResult(0).getType(), lUpsOp->getResult(0),
1413 rUpsOp->getResult(0));
1414 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1415 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1416 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1417 srcOp, srcOp.getType(), elemOp.getResult(), shiftParamOp.getResult());
1439template <
typename SrcOpTy,
typename DstOpTy>
1446 ConversionPatternRewriter &rewriter)
const override {
1447 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
1452 llvm::SmallSet<unsigned, 16> elWidthSet;
1453 elWidthSet.insert(8);
1454 elWidthSet.insert(16);
1455 elWidthSet.insert(32);
1457 Type scalarType = resultType.getElementType();
1458 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1461 if (!elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512)
1464 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(),
1465 adaptor.getLhs(), adaptor.getRhs());
1479template <
typename SrcOpTy,
typename CmpTy>
1486 ConversionPatternRewriter &rewriter)
const override {
1487 VectorType lhsType = dyn_cast<VectorType>(srcOp.getLhs().getType());
1491 llvm::SmallSet<unsigned, 16> elWidthSet;
1492 elWidthSet.insert(8);
1493 elWidthSet.insert(16);
1494 elWidthSet.insert(32);
1496 Type scalarType = lhsType.getElementType();
1497 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1500 if (!elWidthSet.count(elWidth) || laneSize * elWidth != 512)
1505 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
1506 mlir::IntegerType::Unsigned);
1508 Location loc = srcOp.getLoc();
1509 Value lhs = srcOp.getLhs();
1510 Value rhs = srcOp.getRhs();
1511 CmpTy pred = srcOp.getPredicate();
1513 arith::CmpIPredicate ipred = convertToIntegerPredicate(pred);
1515 aievec::CmpOp aieCmpOp =
1516 createCmpOpAIE2(rewriter, ipred, loc, type, lhs, rhs);
1521 VectorType resultType = dyn_cast<VectorType>(srcOp.getResult().getType());
1524 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
1525 srcOp, resultType, aieCmpOp.getResult());
1537 using OpConversionPattern::OpConversionPattern;
1541 ConversionPatternRewriter &rewriter)
const override {
1542 auto resultType = dyn_cast<VectorType>(srcOp.getType());
1546 llvm::SmallSet<unsigned, 16> elWidthSet;
1547 elWidthSet.insert(8);
1548 elWidthSet.insert(16);
1549 elWidthSet.insert(32);
1551 Type scalarType = resultType.getElementType();
1552 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1555 if (!elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512)
1559 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
1560 mlir::IntegerType::Unsigned);
1562 auto convertOp = rewriter.create<UnrealizedConversionCastOp>(
1563 srcOp.getLoc(), type, adaptor.getCondition());
1565 rewriter.replaceOpWithNewOp<aievec::SelOp>(
1566 srcOp, srcOp.getResult().getType(), srcOp.getTrueValue(),
1567 srcOp.getFalseValue(), convertOp.getResult(0));
1574 using OpConversionPattern::OpConversionPattern;
1578 ConversionPatternRewriter &rewriter)
const override {
1579 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::MINSI &&
1580 kind != vector::CombiningKind::MINUI &&
1581 kind != vector::CombiningKind::MINIMUMF)
1584 auto vType = cast<VectorType>(srcOp.getVector().getType());
1585 Type scalarType = vType.getElementType();
1586 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1589 if (laneSize * elWidth != 512)
1592 int shiftIndex = laneSize / 2;
1593 generateAIEVecOpsForReductionOp<aievec::MinOp>(rewriter, srcOp, shiftIndex,
1600 using OpConversionPattern::OpConversionPattern;
1604 ConversionPatternRewriter &rewriter)
const override {
1605 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::MAXSI &&
1606 kind != vector::CombiningKind::MAXUI &&
1607 kind != vector::CombiningKind::MAXIMUMF)
1610 auto vType = cast<VectorType>(srcOp.getVector().getType());
1611 Type scalarType = vType.getElementType();
1612 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1615 if (laneSize * elWidth != 512)
1618 int shiftIndex = laneSize / 2;
1619 generateAIEVecOpsForReductionOp<aievec::MaxOp>(rewriter, srcOp, shiftIndex,
1626 using OpConversionPattern::OpConversionPattern;
1630 ConversionPatternRewriter &rewriter)
const override {
1631 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1634 auto vType = cast<VectorType>(srcOp.getVector().getType());
1635 Type scalarType = vType.getElementType();
1636 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1638 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1639 laneSizeElWidthPairSet.insert({64, 8});
1640 laneSizeElWidthPairSet.insert({32, 16});
1641 laneSizeElWidthPairSet.insert({32, 32});
1642 laneSizeElWidthPairSet.insert({16, 32});
1644 if (!isa<IntegerType>(scalarType) ||
1645 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
1648 int shiftIndex = laneSize / 2;
1649 if (laneSize == 32 && elWidth == 32) {
1650 Location loc = srcOp.getLoc();
1654 rewriter.create<aievec::ExtOp>(loc, vecType, srcOp.getVector(), 0);
1656 rewriter.create<aievec::ExtOp>(loc, vecType, srcOp.getVector(), 1);
1657 auto addElemOp = rewriter.create<aievec::AddElemOp>(
1658 loc, lExtOp.getResult().getType(), lExtOp.getResult(),
1659 rExtOp.getResult());
1661 generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1662 rewriter, srcOp, shiftIndex, addElemOp.getResult());
1664 generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1665 rewriter, srcOp, shiftIndex, srcOp.getVector());
1673 using OpConversionPattern::OpConversionPattern;
1677 ConversionPatternRewriter &rewriter)
const override {
1678 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1681 auto vType = cast<VectorType>(srcOp.getVector().getType());
1682 Type scalarType = vType.getElementType();
1683 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1686 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 32)
1689 int shiftIndex = laneSize / 2;
1690 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
1691 "shiftIndex must be power of 2");
1693 Location loc = srcOp.getLoc();
1694 Value curValue = srcOp.getVector();
1695 aievec::CastOp curOp =
nullptr;
1697 for (
int id = shiftIndex;
id > 0;
id /= 2) {
1698 auto constOp = rewriter.create<arith::ConstantOp>(
1699 loc, rewriter.getI32IntegerAttr(
id * elWidth / 8));
1701 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
1702 loc, vType, curValue, curValue, constOp.getResult());
1704 auto lCastOp = rewriter.create<aievec::CastOp>(loc, vType, curValue,
1707 rewriter.create<aievec::CastOp>(loc, vType, shiftBytesOp.getResult(),
1709 auto elemOp = rewriter.create<aievec::AddElemOp>(
1710 loc, lCastOp.getResult().getType(), lCastOp.getResult(),
1711 rCastOp.getResult());
1712 curOp = rewriter.create<aievec::CastOp>(loc, vType, elemOp.getResult(),
1714 curValue = curOp.getResult();
1718 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1719 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
1720 zeroConstOp.getResult());
1727 using OpConversionPattern::OpConversionPattern;
1731 ConversionPatternRewriter &rewriter)
const override {
1732 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1735 auto vType = cast<VectorType>(srcOp.getVector().getType());
1736 Type scalarType = vType.getElementType();
1737 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1740 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
1743 int shiftIndex = laneSize / 2;
1744 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
1745 "shiftIndex must be power of 2");
1747 Value curValue = srcOp.getVector();
1748 Location loc = srcOp.getLoc();
1751 dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();
1754 rewriter.create<aievec::UPSOp>(loc, accType, srcOp.getVector());
1756 curValue = upsOp.getResult();
1759 aievec::AddElemOp curOp =
nullptr;
1761 for (
int id = shiftIndex;
id > 0;
id /= 2) {
1762 auto constOp = rewriter.create<arith::ConstantOp>(
1763 loc, rewriter.getI32IntegerAttr(
id * accWidth / 8));
1764 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
1765 loc, accType, curValue, curValue, constOp,
true);
1766 curOp = rewriter.create<aievec::AddElemOp>(loc, accType, curValue,
1767 shiftBytesOp.getResult());
1768 curValue = curOp.getResult();
1771 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1772 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1773 auto srsOp = rewriter.create<aievec::SRSOp>(loc, vType, curOp.getResult(),
1774 shiftParamOp.getResult());
1775 SmallVector<Value> concatSources = {srsOp.getResult(), srsOp.getResult()};
1777 rewriter.create<aievec::ConcatOp>(loc, vecType, concatSources);
1780 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1781 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, concatOp,
1782 zeroConstOp.getResult());
1791 using OpConversionPattern::OpConversionPattern;
1795 ConversionPatternRewriter &rewriter)
const override {
1796 auto vType = extractOp.getSourceVectorType();
1797 if (vType.getRank() != 1)
1800 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
1806 return extractOp.emitError()
1807 <<
"AIEv1 doesn't support select ops on int8 types";
1811 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
1812 if (vType.getNumElements() != 2 * size)
1815 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
1816 auto selectOp = rewriter.create<aievec::aie1::SelectOp>(
1817 extractOp.getLoc(), vType, adaptor.getVector(),
1818 buildAttributeListForRotationSelectOp(rewriter, vType, offset));
1819 rewriter.replaceOpWithNewOp<aievec::aie1::ExtOp>(
1820 extractOp, extractOp.getType(), selectOp.getResult(),
1821 rewriter.getI8IntegerAttr(0));
1830 using OpConversionPattern::OpConversionPattern;
1834 ConversionPatternRewriter &rewriter)
const override {
1835 auto vType = cast<VectorType>(adaptor.getVector().getType());
1836 if (vType.getRank() != 1)
1839 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
1845 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
1846 if (vType.getNumElements() != 2 * size)
1849 auto shortVecType = cast<VectorType>(extractOp.getResult().getType());
1850 auto bottomHalf = rewriter
1851 .create<aievec::ExtOp>(
1852 extractOp.getLoc(), shortVecType,
1853 adaptor.getVector(), rewriter.getI8IntegerAttr(0))
1855 auto topHalf = rewriter
1856 .create<aievec::ExtOp>(extractOp.getLoc(), shortVecType,
1857 adaptor.getVector(),
1858 rewriter.getI8IntegerAttr(1))
1860 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
1862 auto shiftBytesConstOp = rewriter.create<arith::ConstantOp>(
1863 extractOp.getLoc(), rewriter.getIntegerType(32),
1864 rewriter.getI32IntegerAttr(shiftBytes));
1865 rewriter.replaceOpWithNewOp<aievec::ShiftOp>(
1866 extractOp, shortVecType, bottomHalf, topHalf, shiftBytesConstOp);
1875 using OpConversionPattern::OpConversionPattern;
1882 ConversionPatternRewriter &rewriter)
const override {
1884 if (updOp->hasOneUse() && isa<aievec::ExtOp>(*updOp->getUsers().begin()))
1887 auto vecType = cast<VectorType>(updOp.getType());
1888 SmallVector<int64_t, 4> vecShape(vecType.getShape().begin(),
1889 vecType.getShape().end());
1890 vecShape[vecType.getRank() - 1] *= 2;
1891 auto longVecType = VectorType::get(vecShape, vecType.getElementType());
1892 auto newUpdOp = rewriter.create<aievec::UPDOp>(
1893 updOp.getLoc(), longVecType, adaptor.getSource(), adaptor.getIndices(),
1894 adaptor.getOffset(), adaptor.getIndex(), adaptor.getVector());
1895 rewriter.replaceOpWithNewOp<aievec::ExtOp>(
1896 updOp, vecType, newUpdOp.getResult(), rewriter.getI8IntegerAttr(0));
1905 using OpConversionPattern::OpConversionPattern;
1911 ConversionPatternRewriter &rewriter)
const override {
1913 if (extOp.getIndex() != 0)
1916 auto updOp = dyn_cast<aievec::UPDOp>(extOp.getSource().getDefiningOp());
1921 if (!updOp->hasOneUse())
1924 rewriter.replaceOpWithNewOp<aievec::UPDOp>(
1925 extOp, extOp.getType(), updOp.getSource(), updOp.getIndices(),
1926 updOp.getOffset(), updOp.getIndex(), updOp.getVector());
1933 using OpConversionPattern::OpConversionPattern;
1937 ConversionPatternRewriter &rewriter)
const override {
1939 if (!matchExpOpForLUT(adaptor))
1942 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
1943 StringRef funcName =
"getExpBf16";
1944 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
1946 VectorType v16bf16Ty = mlir::VectorType::get({16}, rewriter.getBF16Type());
1947 VectorType v8i64Ty = mlir::VectorType::get({8}, rewriter.getI64Type());
1948 func::FuncOp fnOp = getOrInsertFuncDecl(
1949 rewriter, moduleOp, funcName, TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
1951 SmallVector<Value> expOperands = {adaptor.getOperand()};
1955 rewriter.create<func::CallOp>(expOp.getLoc(), fnOp, expOperands);
1956 auto resCastOp = rewriter.create<vector::BitCastOp>(
1957 expOp.getLoc(), accTypeNative, callOp.getResults());
1958 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1959 expOp.getLoc(), rewriter.getI32IntegerAttr(0));
1960 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1961 expOp, srcType, resCastOp.getResult(), shiftParamOp.getResult());
1968 using OpConversionPattern::OpConversionPattern;
1972 ConversionPatternRewriter &rewriter)
const override {
1973 if (!matchExpOpForLUT(adaptor))
1975 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
1976 StringRef includeName =
"lut_based_ops.h";
1977 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
1978 rewriter.setInsertionPointToStart(
1979 &moduleOp.getRegion().getBlocks().front());
1980 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
1982 rewriter.setInsertionPoint(expOp);
1984 auto v16bf16OpaqueTy =
1985 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
1986 auto opaquedOperand =
1988 .create<UnrealizedConversionCastOp>(expOp.getLoc(), v16bf16OpaqueTy,
1989 adaptor.getOperand())
1991 SmallVector<Value> expOperands = {opaquedOperand};
1994 Type v16accf32OpaqueTy =
1995 emitc::OpaqueType::get(rewriter.getContext(),
"v16accfloat");
1996 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
1997 expOp.getLoc(), TypeRange{v16accf32OpaqueTy},
"getExpBf16",
nullptr,
1998 nullptr, expOperands);
1999 auto resCastOp = rewriter.create<UnrealizedConversionCastOp>(
2000 expOp.getLoc(), accTypeNative, callOp.getResults());
2001 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2002 expOp.getLoc(), rewriter.getI32IntegerAttr(0));
2003 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2004 expOp, srcType, resCastOp.getResult(0), shiftParamOp.getResult());
2018 using OpConversionPattern::OpConversionPattern;
2022 ConversionPatternRewriter &rewriter)
const override {
2023 Type srcType = adaptor.getLhs().getType();
2024 if (!divOp->hasOneUse() || isa<VectorType>(srcType) ||
2025 !isa<FloatType>(srcType))
2028 if (!isNarrowingOp(*divOp->getUsers().begin()))
2031 auto fType = cast<FloatType>(srcType);
2032 if (fType.getWidth() != 32)
2035 auto constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
2037 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
2041 StringRef includeName =
"lut_based_ops.h";
2042 auto moduleOp = divOp->getParentOfType<mlir::ModuleOp>();
2043 rewriter.setInsertionPointToStart(
2044 &moduleOp.getRegion().getBlocks().front());
2045 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2047 auto truncOp = cast<arith::TruncFOp>(*divOp->getUsers().begin());
2049 rewriter.setInsertionPoint(truncOp);
2051 emitc::OpaqueType::get(rewriter.getContext(),
"bfloat16");
2052 SmallVector<Value> invOperands = {adaptor.getRhs()};
2053 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2054 truncOp.getLoc(), bf16OpaqueTy,
"getInvBf16",
nullptr,
nullptr,
2056 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2057 truncOp, TypeRange{truncOp.getResult().getType()}, callOp.getResults());
2058 rewriter.eraseOp(divOp);
2066 using OpConversionPattern::OpConversionPattern;
2070 ConversionPatternRewriter &rewriter)
const override {
2071 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
2075 Type scalarType = srcType.getElementType();
2076 if (!isa<FloatType>(scalarType))
2080 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2081 if (elWidth != 16 || laneSize != 16)
2084 StringRef includeName =
"lut_based_ops.h";
2085 auto moduleOp = tanhOp->getParentOfType<mlir::ModuleOp>();
2086 rewriter.setInsertionPointToStart(
2087 &moduleOp.getRegion().getBlocks().front());
2088 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2090 rewriter.setInsertionPoint(tanhOp);
2091 Type v16bf16OpaqueTy =
2092 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2093 auto opaquedOperand =
2095 .create<UnrealizedConversionCastOp>(
2096 tanhOp.getLoc(), v16bf16OpaqueTy, adaptor.getOperand())
2098 SmallVector<Value> tanhOperands = {opaquedOperand};
2099 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2100 tanhOp.getLoc(), v16bf16OpaqueTy,
"getTanhBf16",
nullptr,
nullptr,
2102 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2103 tanhOp, TypeRange{tanhOp.getResult().getType()}, callOp.getResults());
2112 using OpConversionPattern::OpConversionPattern;
2116 ConversionPatternRewriter &rewriter)
const override {
2117 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
2121 Type scalarType = srcType.getElementType();
2122 if (!isa<FloatType>(scalarType))
2126 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2127 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2130 StringRef includeName =
"vec_math.h";
2131 auto moduleOp = sqrtOp->getParentOfType<mlir::ModuleOp>();
2132 rewriter.setInsertionPointToStart(
2133 &moduleOp.getRegion().getBlocks().front());
2134 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2136 rewriter.setInsertionPoint(sqrtOp);
2137 Type vLNbf16OpaqueTy;
2140 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2143 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2144 auto opaquedOperand =
2146 .create<UnrealizedConversionCastOp>(
2147 sqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
2149 SmallVector<Value> sqrtOperands = {opaquedOperand};
2150 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2151 sqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getSqrtBf16",
nullptr,
2152 nullptr, sqrtOperands);
2153 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2154 sqrtOp, TypeRange{sqrtOp.getResult().getType()}, callOp.getResults());
2163 using OpConversionPattern::OpConversionPattern;
2167 ConversionPatternRewriter &rewriter)
const override {
2168 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
2172 Type scalarType = srcType.getElementType();
2173 if (!isa<FloatType>(scalarType))
2177 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2178 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2181 StringRef includeName =
"vec_math.h";
2182 auto moduleOp = rsqrtOp->getParentOfType<mlir::ModuleOp>();
2183 rewriter.setInsertionPointToStart(
2184 &moduleOp.getRegion().getBlocks().front());
2185 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2187 rewriter.setInsertionPoint(rsqrtOp);
2188 Type vLNbf16OpaqueTy;
2191 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2194 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2195 auto opaquedOperand =
2197 .create<UnrealizedConversionCastOp>(
2198 rsqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
2200 SmallVector<Value> rsqrtOperands = {opaquedOperand};
2201 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2202 rsqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getRsqrtBf16",
nullptr,
2203 nullptr, rsqrtOperands);
2204 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2205 rsqrtOp, TypeRange{rsqrtOp.getResult().getType()}, callOp.getResults());
2214 using OpConversionPattern::OpConversionPattern;
2218 ConversionPatternRewriter &rewriter)
const override {
2219 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
2223 Type scalarType = srcType.getElementType();
2224 if (!isa<FloatType>(scalarType))
2228 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2229 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2232 StringRef includeName =
"vec_math.h";
2233 auto moduleOp = erfOp->getParentOfType<mlir::ModuleOp>();
2234 rewriter.setInsertionPointToStart(
2235 &moduleOp.getRegion().getBlocks().front());
2236 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2238 rewriter.setInsertionPoint(erfOp);
2239 Type vLNbf16OpaqueTy;
2242 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2245 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2246 auto opaquedOperand =
2248 .create<UnrealizedConversionCastOp>(erfOp.getLoc(), vLNbf16OpaqueTy,
2249 adaptor.getOperand())
2251 SmallVector<Value> erfOperands = {opaquedOperand};
2252 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2253 erfOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getErfBf16",
nullptr,
2254 nullptr, erfOperands);
2255 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2256 erfOp, TypeRange{erfOp.getResult().getType()}, callOp.getResults());
2264template <
typename SrcOpTy>
2271 ConversionPatternRewriter &rewriter)
const override {
2272 auto vecTy = dyn_cast<VectorType>(absOp.getOperand().getType());
2276 Type elemTy = vecTy.getElementType();
2279 unsigned elWidth = elemTy.getIntOrFloatBitWidth();
2281 StringRef includeName =
"vec_math.h";
2282 auto moduleOp = absOp->template getParentOfType<mlir::ModuleOp>();
2283 rewriter.setInsertionPointToStart(
2284 &moduleOp.getRegion().getBlocks().front());
2285 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2287 rewriter.setInsertionPoint(absOp);
2288 std::ostringstream typeName;
2289 typeName <<
"v" << laneSize;
2290 if (isa<FloatType>(elemTy)) {
2292 typeName <<
"bfloat16";
2294 typeName <<
"float";
2296 typeName <<
"int" << elWidth;
2298 emitc::OpaqueType::get(rewriter.getContext(), typeName.str());
2299 auto opaquedOperand =
2301 .create<UnrealizedConversionCastOp>(absOp.getLoc(), vecOpaqueTy,
2302 adaptor.getOperand())
2304 SmallVector<Value> absOperands = {opaquedOperand};
2305 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2306 absOp.getLoc(), TypeRange{vecOpaqueTy},
"getAbs",
nullptr,
nullptr,
2308 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2309 absOp, TypeRange{absOp.getResult().getType()}, callOp.getResults());
2318template <
typename SrcOpTy>
2325 ConversionPatternRewriter &rewriter)
const override {
2326 VectorType srcType = dyn_cast<VectorType>(extOp.getIn().getType());
2327 VectorType dstType = dyn_cast<VectorType>(extOp.getOut().getType());
2331 rewriter.create<aievec::UPSOp>(extOp.getLoc(), accType, extOp.getIn());
2333 if (dstType.getElementType().getIntOrFloatBitWidth() == 16) {
2334 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2335 extOp.getLoc(), rewriter.getI32IntegerAttr(0));
2336 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2337 extOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
2339 rewriter.replaceOpWithNewOp<aievec::CastOp>(
2340 extOp, dstType, upsOp.getResult(),
false);
2349template <
typename SrcOpTy>
2356 ConversionPatternRewriter &rewriter)
const override {
2357 VectorType srcType = dyn_cast<VectorType>(truncOp.getIn().getType());
2358 VectorType dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
2359 Type scalarType = srcType.getElementType();
2360 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2363 auto accType = isa<IntegerType>(scalarType) && (elWidth == 32)
2367 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2368 truncOp.getLoc(), rewriter.getI32IntegerAttr(0));
2369 if (elWidth == 16) {
2370 auto upsOp = rewriter.create<aievec::UPSOp>(truncOp.getLoc(), accType,
2372 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2373 truncOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
2375 auto castOp = rewriter.create<aievec::CastOp>(truncOp.getLoc(), accType,
2376 truncOp.getIn(),
true);
2377 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2378 truncOp, dstType, castOp.getResult(), shiftParamOp.getResult());
2393static std::optional<Value>
2394getUnOpaquedOperandOfEmitCOpaqueCallOp(Operation *op, StringRef funcName) {
2395 auto uccOp = dyn_cast<UnrealizedConversionCastOp>(op);
2399 auto inVal = uccOp.getInputs()[0];
2400 if (!isa<emitc::OpaqueType>(inVal.getType()))
2403 auto callOp = inVal.getDefiningOp<emitc::CallOpaqueOp>();
2404 if (callOp.getCallee() != funcName)
2407 auto callOperandsUccOp =
2408 callOp.getOperands()[0].getDefiningOp<UnrealizedConversionCastOp>();
2409 if (!callOperandsUccOp)
2412 return callOperandsUccOp.getInputs()[0];
2428template <
typename DivFOpTy>
2429static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) {
2430 auto constOp = dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
2434 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
2438 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
2441 Operation *addLvalOp;
2442 Operation *addRvalOp;
2448 auto addOp = dyn_cast<arith::AddFOp>(divfOp.getRhs().getDefiningOp());
2450 auto srsOp = dyn_cast<aievec::SRSOp>(divfOp.getRhs().getDefiningOp());
2455 dyn_cast<aievec::AddElemOp>(srsOp.getSource().getDefiningOp());
2459 auto lUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getLhs().getDefiningOp());
2460 auto rUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getRhs().getDefiningOp());
2461 if (!lUpsOp || !rUpsOp)
2464 addLvalOp = lUpsOp.getSource().getDefiningOp();
2465 addRvalOp = rUpsOp.getSource().getDefiningOp();
2468 auto addDefOp = isa<arith::ConstantOp>(addLvalOp)
2469 ? dyn_cast<aievec::SRSOp>(addRvalOp)
2470 : dyn_cast<aievec::SRSOp>(addLvalOp);
2472 addLvalOp = isa<arith::ConstantOp>(addLvalOp)
2473 ? dyn_cast<math::ExpOp>(addRvalOp)
2474 : dyn_cast<math::ExpOp>(addLvalOp);
2476 addLvalOp = addDefOp.getSource().getDefiningOp();
2478 addRvalOp = isa<arith::ConstantOp>(addLvalOp)
2479 ? lUpsOp.getSource().getDefiningOp()
2480 : rUpsOp.getSource().getDefiningOp();
2482 addLvalOp = addOp.getLhs().getDefiningOp();
2483 addRvalOp = addOp.getRhs().getDefiningOp();
2486 if (!addLvalOp || !addRvalOp)
2489 auto addLvalExpOp = dyn_cast<math::ExpOp>(addLvalOp);
2490 auto addRvalExpOp = dyn_cast<math::ExpOp>(addRvalOp);
2491 auto addLvalExpOpIn =
2492 getUnOpaquedOperandOfEmitCOpaqueCallOp(addLvalOp,
"getExpBf16")
2494 auto addRvalExpOpIn =
2495 getUnOpaquedOperandOfEmitCOpaqueCallOp(addRvalOp,
"getExpBf16")
2497 if (!addLvalExpOpIn && addLvalExpOp)
2498 addLvalExpOpIn = addLvalExpOp.getOperand();
2499 if (!addRvalExpOpIn && addRvalExpOp)
2500 addRvalExpOpIn = addRvalExpOp.getOperand();
2502 if (!((addLvalExpOpIn && isa<arith::ConstantOp>(addRvalOp)) ||
2503 (addRvalExpOpIn && isa<arith::ConstantOp>(addLvalOp))))
2506 constOp = isa<arith::ConstantOp>(addLvalOp)
2507 ? cast<arith::ConstantOp>(addLvalOp)
2508 : cast<arith::ConstantOp>(addRvalOp);
2510 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
2513 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
2516 auto expOperand = addLvalExpOpIn ? addLvalExpOpIn : addRvalExpOpIn;
2518 negOp = expOperand.getDefiningOp<arith::NegFOp>();
2520 return negOp !=
nullptr;
2537 using OpConversionPattern::OpConversionPattern;
2541 ConversionPatternRewriter &rewriter)
const override {
2542 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
2546 Type scalarType = srcType.getElementType();
2547 if (!isa<FloatType>(scalarType))
2551 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2552 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2555 arith::NegFOp negOp =
nullptr;
2556 if (!hasSigmoidComputationChain(adaptor, negOp))
2559 StringRef includeName =
"vec_math.h";
2560 auto moduleOp = divfOp->getParentOfType<mlir::ModuleOp>();
2561 rewriter.setInsertionPointToStart(
2562 &moduleOp.getRegion().getBlocks().front());
2563 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2565 rewriter.setInsertionPoint(divfOp);
2569 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2572 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2573 auto opaquedOperand =
2575 .create<UnrealizedConversionCastOp>(divfOp.getLoc(), vecOpaqueTy,
2578 SmallVector<Value> sigmoidOperands = {opaquedOperand};
2579 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2580 divfOp.getLoc(), TypeRange{vecOpaqueTy},
"getSigmoidBf16",
nullptr,
2581 nullptr, sigmoidOperands);
2582 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2583 divfOp, TypeRange{adaptor.getLhs().getType()}, callOp.getResults());
2591 using OpConversionPattern::OpConversionPattern;
2595 ConversionPatternRewriter &rewriter)
const override {
2596 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
2600 Type scalarType = srcType.getElementType();
2601 if (!isa<FloatType>(scalarType))
2605 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2606 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2609 StringRef includeName =
"vec_math.h";
2610 auto moduleOp = ceilOp->getParentOfType<mlir::ModuleOp>();
2611 rewriter.setInsertionPointToStart(
2612 &moduleOp.getRegion().getBlocks().front());
2613 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2615 rewriter.setInsertionPoint(ceilOp);
2619 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2622 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2623 auto opaquedOperand =
2625 .create<UnrealizedConversionCastOp>(ceilOp.getLoc(), vecOpaqueTy,
2626 adaptor.getOperand())
2628 SmallVector<Value> ceilOperands = {opaquedOperand};
2629 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2630 ceilOp.getLoc(), TypeRange{vecOpaqueTy},
"getCeilBf16",
nullptr,
2631 nullptr, ceilOperands);
2632 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2633 ceilOp, TypeRange{ceilOp.getResult().getType()}, callOp.getResults());
2641 using OpConversionPattern::OpConversionPattern;
2645 ConversionPatternRewriter &rewriter)
const override {
2646 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
2650 Type scalarType = srcType.getElementType();
2651 if (!isa<FloatType>(scalarType))
2655 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2656 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2659 StringRef includeName =
"vec_math.h";
2660 auto moduleOp = floorOp->getParentOfType<mlir::ModuleOp>();
2661 rewriter.setInsertionPointToStart(
2662 &moduleOp.getRegion().getBlocks().front());
2663 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2665 rewriter.setInsertionPoint(floorOp);
2669 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2672 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2673 auto opaquedOperand =
2675 .create<UnrealizedConversionCastOp>(floorOp.getLoc(), vecOpaqueTy,
2676 adaptor.getOperand())
2678 SmallVector<Value> floorOperands = {opaquedOperand};
2679 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2680 floorOp.getLoc(), TypeRange{vecOpaqueTy},
"getFloorBf16",
nullptr,
2681 nullptr, floorOperands);
2682 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2683 floorOp, TypeRange{floorOp.getResult().getType()}, callOp.getResults());
2692 using OpConversionPattern::OpConversionPattern;
2696 ConversionPatternRewriter &rewriter)
const override {
2697 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
2701 Type scalarType = srcType.getElementType();
2702 if (!isa<FloatType>(scalarType))
2708 Location loc = negOp.getLoc();
2711 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2712 if (elWidth == 16) {
2714 rewriter.create<aievec::UPSOp>(loc, accType, adaptor.getOperand());
2716 rewriter.create<aievec::NegOp>(loc, accType, upsOp.getResult());
2717 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2718 negOp.getLoc(), rewriter.getI32IntegerAttr(0));
2719 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2720 negOp, srcType, aieNegOp.getResult(), shiftParamOp.getResult());
2722 auto castOp = rewriter.create<aievec::CastOp>(
2723 loc, accType, adaptor.getOperand(),
true);
2725 rewriter.create<aievec::NegOp>(loc, accType, castOp.getResult());
2726 rewriter.replaceOpWithNewOp<aievec::CastOp>(
2727 negOp, srcType, aieNegOp.getResult(),
false);
2736static bool hasConstNegOneValue(arith::ConstantOp constOp,
unsigned elWidth) {
2740 auto cstDense = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
2745 return cstDense.getSplatValue<int32_t>() == -1;
2747 return cstDense.getSplatValue<int16_t>() == -1;
2749 return cstDense.getSplatValue<int8_t>() == -1;
2756 using OpConversionPattern::OpConversionPattern;
2760 ConversionPatternRewriter &rewriter)
const override {
2761 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
2765 Type scalarType = srcType.getElementType();
2766 if (!isa<IntegerType>(scalarType))
2770 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2771 if (laneSize * elWidth != 512)
2775 dyn_cast<arith::ConstantOp>(xorOp.getLhs().getDefiningOp());
2777 dyn_cast<arith::ConstantOp>(xorOp.getRhs().getDefiningOp());
2781 if ((lhsConstOp && hasConstNegOneValue(lhsConstOp, elWidth)) ||
2782 (rhsConstOp && hasConstNegOneValue(rhsConstOp, elWidth))) {
2783 Value val = hasConstNegOneValue(lhsConstOp, elWidth) ? adaptor.getRhs()
2785 rewriter.replaceOpWithNewOp<aievec::BnegOp>(xorOp, srcType, val);
2787 rewriter.replaceOpWithNewOp<aievec::BxorOp>(
2788 xorOp, srcType, adaptor.getLhs(), adaptor.getRhs());
2794template <
typename SrcOpTy,
typename DstOpTy>
2801 ConversionPatternRewriter &rewriter)
const override {
2802 VectorType srcType = dyn_cast<VectorType>(srcOp.getLhs().getType());
2806 Type scalarType = srcType.getElementType();
2807 if (!isa<IntegerType>(scalarType))
2811 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2812 if (laneSize * elWidth != 512)
2815 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getResult().getType(),
2816 adaptor.getLhs(), adaptor.getRhs());
2832 using OpConversionPattern::OpConversionPattern;
2836 ConversionPatternRewriter &rewriter)
const override {
2837 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
2841 Type scalarType = srcType.getElementType();
2843 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2844 if (laneSize * elWidth != 512)
2848 dyn_cast<aievec::BroadcastOp>(adaptor.getRhs().getDefiningOp());
2852 auto constOp = rewriter.create<arith::ConstantOp>(
2853 bcastOp.getLoc(), rewriter.getI32IntegerAttr(bcastOp.getIdx()));
2854 auto extElemOp = rewriter.create<aievec::ExtElemOp>(
2855 bcastOp.getLoc(), scalarType, bcastOp, constOp.getResult());
2856 Location loc = rsOp.getLoc();
2863 rewriter.create<aievec::ExtOp>(loc, halfSrcType, adaptor.getLhs(), 0);
2865 rewriter.create<aievec::ExtOp>(loc, halfSrcType, adaptor.getLhs(), 1);
2868 rewriter.create<aievec::UPSOp>(loc, accType, rsOpLow.getResult());
2869 auto srsOpLow = rewriter.create<aievec::SRSOp>(
2870 loc, halfSrcType, upsOpLow.getResult(), extElemOp.getResult());
2872 rewriter.create<aievec::UPSOp>(loc, accType, rsOpHigh.getResult());
2873 auto srsOpHigh = rewriter.create<aievec::SRSOp>(
2874 loc, halfSrcType, upsOpHigh.getResult(), extElemOp.getResult());
2875 SmallVector<Value> inputSources = {srsOpLow.getResult(),
2876 srsOpHigh.getResult()};
2877 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(rsOp, srcType,
2882 rewriter.create<aievec::UPSOp>(loc, accType, adaptor.getLhs());
2883 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2884 rsOp, srcType, upsOp.getResult(), extElemOp.getResult());
2894 using OpConversionPattern::OpConversionPattern;
2901 auto vecTy = dyn_cast<VectorType>(v.getType());
2904 auto vecShape = vecTy.getShape();
2906 size_t numLeadUnitDims = 0;
2907 while (numLeadUnitDims < vecShape.size() && vecShape[numLeadUnitDims] == 1)
2910 if (!numLeadUnitDims)
2913 SmallVector<int64_t> newShape(vecShape.begin() + numLeadUnitDims,
2915 auto newVecTy = VectorType::get(newShape, vecTy.getElementType());
2916 return b.create<vector::ShapeCastOp>(v.getLoc(), newVecTy, v).getResult();
2921 ConversionPatternRewriter &rewriter)
const override {
2925 bool bReshapedAcc = (acc != adaptor.getAcc());
2928 acc = rewriter.create<aievec::CastOp>(contractOp.getLoc(), acc.getType(),
2931 auto matmulOp = rewriter.create<aievec::MatMulOp>(
2932 contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
2936 ScopedDiagnosticHandler diagHandler(
2937 contractOp.getContext(), [](Diagnostic &) { return success(); });
2938 if (failed(matmulOp.verifyInvariants())) {
2939 rewriter.eraseOp(matmulOp);
2943 lhs = adaptor.getLhs();
2944 auto wideLhsValue = getSourceOfWideningOp(lhs).value_or(
nullptr);
2948 rhs = adaptor.getRhs();
2949 auto wideRhsValue = getSourceOfWideningOp(rhs).value_or(
nullptr);
2953 matmulOp = rewriter.create<aievec::MatMulOp>(
2954 contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
2955 if (failed(matmulOp.verifyInvariants()))
2960 Value result = matmulOp.getResult();
2962 result = rewriter.create<aievec::CastOp>(contractOp.getLoc(),
2963 acc.getType(), matmulOp,
false);
2965 result = rewriter.create<vector::ShapeCastOp>(
2966 contractOp.getLoc(), adaptor.getAcc().getType(), result);
2967 rewriter.replaceOp(contractOp, result);
2978 using OpConversionPattern::OpConversionPattern;
2981 ConversionPatternRewriter &rewriter)
const override {
2982 auto resTy = transpOp.getResultVectorType();
2983 auto resShape = resTy.getShape();
2984 auto elemTyBitWidth = resTy.getElementTypeBitWidth();
2985 auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
2986 elemTyBitWidth, std::multiplies<>());
2987 if (vBitWidth != 512)
2990 if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
2994 for (int64_t i = 0; i < static_cast<int64_t>(resShape.size() - 2); ++i)
2995 if (resShape[i] != 1)
2999 ArrayRef<int64_t> perm = transpOp.getPermutation();
3000 for (int64_t i = 0; i < static_cast<int64_t>(perm.size() - 2); ++i)
3003 if (perm.back() !=
static_cast<int64_t
>(perm.size() - 2))
3006 auto shuffleMode = aievec::ShuffleMode::T32_4X4;
3007 if (elemTyBitWidth == 8) {
3008 switch (resShape.back()) {
3010 shuffleMode = aievec::ShuffleMode::T8_4X16;
3013 shuffleMode = aievec::ShuffleMode::T8_8X8;
3016 shuffleMode = aievec::ShuffleMode::T8_16X4;
3021 }
else if (elemTyBitWidth == 16) {
3022 switch (resShape.back()) {
3024 shuffleMode = aievec::ShuffleMode::T16_2X16;
3027 shuffleMode = aievec::ShuffleMode::T16_4X8;
3030 shuffleMode = aievec::ShuffleMode::T16_8X4;
3033 shuffleMode = aievec::ShuffleMode::T16_16X2;
3038 }
else if (resShape.back() != 4)
3042 VectorType::get({512 / elemTyBitWidth}, resTy.getElementType());
3043 auto loc = transpOp.getLoc();
3044 auto flatInput = rewriter.create<vector::ShapeCastOp>(loc, flatVecTy,
3045 adaptor.getVector());
3046 auto shuffOp = rewriter.create<aievec::ShuffleOp>(loc, flatVecTy, flatInput,
3047 nullptr, shuffleMode);
3048 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resTy, shuffOp);
3058static void populateAIEVecCommonConversionPatterns(RewritePatternSet &patterns,
3068static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns,
3085static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
3089 if (backend == TargetBackend::CPP) {
3092 >(patterns.getContext(), 128, 1024, 256, 1024);
3099 >(patterns.getContext());
3100 }
else if (backend == TargetBackend::LLVMIR){
3103 >(patterns.getContext());
3141 >(patterns.getContext());
3143 >(patterns.getContext(), backend == TargetBackend::CPP);
3153static bool isInSigmoidOperationChain(math::ExpOp expOp) {
3154 if (!expOp.getOperand().getDefiningOp<arith::NegFOp>())
3157 arith::AddFOp addOp =
nullptr;
3158 for (Operation *user : expOp->getUsers()) {
3159 addOp = dyn_cast<arith::AddFOp>(user);
3167 auto *addLvalOp = addOp.getLhs().getDefiningOp();
3168 auto *addRvalOp = addOp.getRhs().getDefiningOp();
3169 if (!((isa<math::ExpOp>(addLvalOp) && isa<arith::ConstantOp>(addRvalOp)) ||
3170 (isa<math::ExpOp>(addRvalOp) && isa<arith::ConstantOp>(addLvalOp))))
3173 auto constOp = isa<arith::ConstantOp>(addLvalOp)
3174 ? cast<arith::ConstantOp>(addLvalOp)
3175 : cast<arith::ConstantOp>(addRvalOp);
3177 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3181 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
3184 arith::DivFOp divOp =
nullptr;
3185 for (Operation *user : addOp->getUsers()) {
3186 divOp = dyn_cast<arith::DivFOp>(user);
3194 constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
3197 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3200 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
3206static void configureAIEVecCommonLegalizations(ConversionTarget &target,
3208 target.addLegalDialect<xilinx::aievec::aie1::AIEVecAIE1Dialect,
3209 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
3210 emitc::EmitCDialect, func::FuncDialect>();
3211 if (backend == TargetBackend::CPP) {
3212 target.addIllegalOp<vector::TransferReadOp>();
3214 target.addIllegalOp<vector::ExtractStridedSliceOp>();
3215 target.addLegalOp<vector::BitCastOp>();
3217 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
3218 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
3219 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
3220 if (!srcType || !dstType)
3223 Type srcScalarType = srcType.getElementType();
3224 Type dstScalarType = dstType.getElementType();
3225 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
3230 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3231 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3232 return srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 ||
3236 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
3237 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
3238 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
3239 if (!srcType || !dstType)
3242 Type srcScalarType = srcType.getElementType();
3243 Type dstScalarType = dstType.getElementType();
3244 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
3249 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3250 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3251 return srcLaneSize != 32 || (dstElWidth <= srcElWidth) ||
3252 (dstLaneSize != srcLaneSize);
3255 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
3256 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
3257 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
3258 if (!srcType || !dstType)
3261 Type srcScalarType = srcType.getElementType();
3262 Type dstScalarType = dstType.getElementType();
3263 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
3268 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3269 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3270 return srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 ||
3274 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
3275 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
3276 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
3277 if (!srcType || !dstType)
3280 Type srcScalarType = srcType.getElementType();
3281 Type dstScalarType = dstType.getElementType();
3282 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
3287 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3288 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3290 return srcLaneSize != 32 || (dstElWidth >= srcElWidth) ||
3291 (dstLaneSize != srcLaneSize);
3294 target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
3295 auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
3299 Type scalarType = srcType.getElementType();
3300 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3302 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
3304 if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp))
3310 target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
3311 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
3315 Type scalarType = srcType.getElementType();
3316 if (!isa<FloatType>(scalarType))
3320 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3321 return elWidth != 16 || laneSize != 16;
3324 target.addDynamicallyLegalOp<math::SqrtOp>([](math::SqrtOp sqrtOp) {
3325 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
3329 Type scalarType = srcType.getElementType();
3330 if (!isa<FloatType>(scalarType))
3334 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3335 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3338 target.addDynamicallyLegalOp<math::RsqrtOp>([](math::RsqrtOp rsqrtOp) {
3339 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
3340 Type scalarType = srcType.getElementType();
3341 if (!srcType || !isa<FloatType>(scalarType))
3345 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3346 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3349 target.addDynamicallyLegalOp<math::ErfOp>([](math::ErfOp erfOp) {
3350 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
3354 Type scalarType = srcType.getElementType();
3355 if (!isa<FloatType>(scalarType))
3359 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3360 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3363 target.addDynamicallyLegalOp<math::AbsFOp>([](math::AbsFOp absfOp) {
3364 auto srcType = dyn_cast<VectorType>(absfOp.getOperand().getType());
3368 Type scalarType = srcType.getElementType();
3370 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3371 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
3374 target.addDynamicallyLegalOp<math::AbsIOp>([](math::AbsIOp absiOp) {
3375 auto srcType = dyn_cast<VectorType>(absiOp.getOperand().getType());
3379 Type scalarType = srcType.getElementType();
3381 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3382 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
3385 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divfOp) {
3386 if (
auto srcType = dyn_cast<VectorType>(divfOp.getLhs().getType());
3388 Type scalarType = divfOp.getLhs().getType();
3389 if (!divfOp->hasOneUse() || !isa<FloatType>(scalarType))
3391 if (!isNarrowingOp(*divfOp->getUsers().begin()))
3394 auto fType = cast<FloatType>(scalarType);
3395 if (fType.getWidth() != 32)
3399 dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
3401 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
3405 Type scalarType = srcType.getElementType();
3406 if (!isa<FloatType>(scalarType))
3410 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3412 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3415 arith::NegFOp negOp =
nullptr;
3416 if (!hasSigmoidComputationChain(divfOp, negOp))
3423 target.addDynamicallyLegalOp<math::CeilOp>([](math::CeilOp ceilOp) {
3424 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
3427 Type scalarType = srcType.getElementType();
3428 if (!isa<FloatType>(scalarType))
3432 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3433 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3436 target.addDynamicallyLegalOp<math::FloorOp>([](math::FloorOp floorOp) {
3437 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
3440 Type scalarType = srcType.getElementType();
3441 if (!isa<FloatType>(scalarType))
3445 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3446 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3449 target.addDynamicallyLegalOp<arith::NegFOp>([](arith::NegFOp negOp) {
3450 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
3453 if (Type scalarType = srcType.getElementType(); !isa<FloatType>(scalarType))
3457 return laneSize != 16;
3460 target.addDynamicallyLegalOp<arith::XOrIOp>([](arith::XOrIOp xorOp) {
3461 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
3464 Type scalarType = srcType.getElementType();
3465 if (!isa<IntegerType>(scalarType))
3469 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3471 return laneSize * elWidth != 512;
3474 target.addDynamicallyLegalOp<arith::OrIOp>([](arith::OrIOp orOp) {
3475 auto srcType = dyn_cast<VectorType>(orOp.getLhs().getType());
3478 Type scalarType = srcType.getElementType();
3479 if (!isa<IntegerType>(scalarType))
3483 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3485 return laneSize * elWidth != 512;
3488 target.addDynamicallyLegalOp<arith::ShRSIOp>([](arith::ShRSIOp rsOp) {
3489 auto srcType = dyn_cast<VectorType>(rsOp.getLhs().getType());
3492 Type scalarType = srcType.getElementType();
3495 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3497 return laneSize * elWidth != 512;
3500 target.addDynamicallyLegalOp<arith::AndIOp>([](arith::AndIOp andOp) {
3501 auto srcType = dyn_cast<VectorType>(andOp.getLhs().getType());
3504 Type scalarType = srcType.getElementType();
3505 if (!isa<IntegerType>(scalarType))
3509 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3511 return laneSize * elWidth != 512;
3514 if (backend == TargetBackend::CPP) {
3515 target.addDynamicallyLegalOp<arith::AddIOp>(
3516 [](arith::AddIOp op) {
return !isa<VectorType>(op.getType()); });
3518 target.addDynamicallyLegalOp<arith::AddFOp>(
3519 [](arith::AddFOp op) {
return !isa<VectorType>(op.getType()); });
3520 target.addDynamicallyLegalOp<arith::SubIOp>(
3521 [](arith::SubIOp op) {
return !isa<VectorType>(op.getType()); });
3522 target.addDynamicallyLegalOp<arith::SubFOp>(
3523 [](arith::SubFOp op) {
return !isa<VectorType>(op.getType()); });
3526static void configureAIEVecV1Legalizations(ConversionTarget &target,
3528 target.addDynamicallyLegalOp<arith::MulIOp>(
3529 [](arith::MulIOp op) {
return !isa<VectorType>(op.getType()); });
3530 target.addDynamicallyLegalOp<arith::MulFOp>(
3531 [](arith::MulFOp op) {
return !isa<VectorType>(op.getType()); });
3532 target.addDynamicallyLegalOp<aievec::aie1::FMAOp>(
3533 [](xilinx::aievec::aie1::FMAOp op) {
3534 auto *lhsDefOp = op.getLhs().getDefiningOp();
3535 aievec::ConcatOp concatOp =
nullptr;
3537 concatOp = dyn_cast<aievec::ConcatOp>(op.getLhs().getDefiningOp());
3541 vector::SplatOp srcSplat =
nullptr;
3542 if (
auto *lhsOp = concatOp.getSources()[0].getDefiningOp())
3543 srcSplat = dyn_cast<vector::SplatOp>(lhsOp);
3545 auto *rhsOp = op.getRhs().getDefiningOp();
3548 srcSplat = dyn_cast<vector::SplatOp>(rhsOp);
3552 if (
auto *srcOp = srcSplat.getInput().getDefiningOp())
3553 return !isa<vector::ExtractOp>(srcOp);
3558 target.addDynamicallyLegalOp<aievec::aie1::AddOp>([](aievec::aie1::AddOp op) {
3559 auto lSrsOp = op.getLhs().getDefiningOp<aievec::SRSOp>();
3560 auto rSrsOp = op.getRhs().getDefiningOp<aievec::SRSOp>();
3562 !lSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>()) &&
3564 !rSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>());
3566 target.addLegalDialect<memref::MemRefDialect>();
3569static void configureAIEVecV2Legalizations(ConversionTarget &target,
3571 target.addLegalOp<UnrealizedConversionCastOp>();
3572 target.addLegalOp<vector::ShapeCastOp>();
3575 llvm::SmallSet<std::pair<unsigned, unsigned>, 16> laneSizeElWidthPairSet;
3576 laneSizeElWidthPairSet.insert({64, 8});
3577 laneSizeElWidthPairSet.insert({32, 16});
3578 laneSizeElWidthPairSet.insert({16, 32});
3579 laneSizeElWidthPairSet.insert({32, 32});
3582 llvm::SmallSet<unsigned, 16> elWidthSet;
3583 elWidthSet.insert(8);
3584 elWidthSet.insert(16);
3585 elWidthSet.insert(32);
3587 if (backend == TargetBackend::CPP) {
3588 target.addDynamicallyLegalOp<arith::AddIOp>([=](arith::AddIOp op) {
3589 auto resultType = dyn_cast<VectorType>(op.getType());
3593 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3596 return !laneSizeElWidthPairSet.count(
3597 std::make_pair(laneSize, resultElWidth));
3601 target.addDynamicallyLegalOp<arith::SubIOp>([=](arith::SubIOp op) {
3602 auto resultType = dyn_cast<VectorType>(op.getType());
3605 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3608 return !laneSizeElWidthPairSet.count(
3609 std::make_pair(laneSize, resultElWidth));
3612 target.addDynamicallyLegalOp<arith::AddFOp>([](arith::AddFOp op) {
3613 auto resultType = dyn_cast<VectorType>(op.getType());
3618 return laneSize != 16;
3621 target.addDynamicallyLegalOp<arith::SubFOp>([](arith::SubFOp op) {
3622 auto resultType = dyn_cast<VectorType>(op.getType());
3627 return laneSize != 16;
3630 target.addDynamicallyLegalOp<arith::MulIOp>([](arith::MulIOp op) {
3631 auto resultType = dyn_cast<VectorType>(op.getType());
3634 auto isAddOp = [&](Operation *op) {
return isa<arith::AddIOp>(op); };
3636 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
3639 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3642 return (laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
3643 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32);
3646 target.addDynamicallyLegalOp<arith::MulFOp>([](arith::MulFOp op) {
3647 auto resultType = dyn_cast<VectorType>(op.getType());
3651 auto isAddOp = [&](Operation *op) {
return isa<arith::AddFOp>(op); };
3653 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
3656 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3659 return laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32);
3662 target.addDynamicallyLegalOp<arith::MinSIOp>([=](arith::MinSIOp op) {
3663 auto resultType = dyn_cast<VectorType>(op.getType());
3667 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3670 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3673 target.addDynamicallyLegalOp<arith::MaxSIOp>([=](arith::MaxSIOp op) {
3674 auto resultType = dyn_cast<VectorType>(op.getType());
3678 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3681 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3684 target.addDynamicallyLegalOp<arith::MinimumFOp>([=](arith::MinimumFOp op) {
3685 auto resultType = dyn_cast<VectorType>(op.getType());
3689 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3692 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3695 target.addDynamicallyLegalOp<arith::MaximumFOp>([=](arith::MaximumFOp op) {
3696 auto resultType = dyn_cast<VectorType>(op.getType());
3700 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3703 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3706 target.addDynamicallyLegalOp<arith::CmpIOp>([=](arith::CmpIOp op) {
3707 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
3711 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
3714 return !elWidthSet.count(lhsElWidth) || laneSize * lhsElWidth != 512;
3717 target.addDynamicallyLegalOp<arith::CmpFOp>([=](arith::CmpFOp op) {
3718 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
3722 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
3725 return !elWidthSet.count(lhsElWidth) || laneSize * lhsElWidth != 512;
3728 target.addDynamicallyLegalOp<arith::SelectOp>([=](arith::SelectOp op) {
3729 auto resultType = dyn_cast<VectorType>(op.getType());
3733 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3736 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3739 target.addDynamicallyLegalOp<vector::ReductionOp>(
3740 [=](vector::ReductionOp op) {
3741 if (
auto kind = op.getKind(); kind != vector::CombiningKind::ADD &&
3742 kind != vector::CombiningKind::MINSI &&
3743 kind != vector::CombiningKind::MINUI &&
3744 kind != vector::CombiningKind::MINIMUMF &&
3745 kind != vector::CombiningKind::MAXSI &&
3746 kind != vector::CombiningKind::MAXUI &&
3747 kind != vector::CombiningKind::MAXIMUMF)
3750 auto vType = dyn_cast<VectorType>(op.getVector().getType());
3754 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
3755 laneSizeElWidthPairSet.insert({64, 8});
3756 laneSizeElWidthPairSet.insert({32, 16});
3757 laneSizeElWidthPairSet.insert({32, 32});
3758 laneSizeElWidthPairSet.insert({16, 32});
3760 Type scalarType = vType.getElementType();
3761 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3764 if (isa<IntegerType>(scalarType) &&
3765 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
3768 if (isa<FloatType>(scalarType) && laneSize != 16 && laneSize != 32)
3774 target.addIllegalOp<vector::ContractionOp, vector::TransposeOp,
3798 StringRef
getArgument() const final {
return "test-lower-vector-to-aievec"; }
3800 return "Lower vector operations to AIE vector intrinsics";
3804 .insert<affine::AffineDialect, xilinx::aievec::aie1::AIEVecAIE1Dialect,
3805 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
3806 memref::MemRefDialect, scf::SCFDialect, vector::VectorDialect,
3807 emitc::EmitCDialect>();
3811 *
this,
"aie-target",
3812 llvm::cl::desc(
"Select AIE version: \"aie\" or \"aie2\". This will "
3813 "determine the vector size and available operations."),
3814 llvm::cl::init(
"aie")};
3817 *
this,
"target-backend",
3818 llvm::cl::desc(
"Select translation backend: \"cpp\" or \"llvmir\". This "
3819 "will determine the aievec operations used to convert "
3820 "from vector dialect."),
3821 llvm::cl::init(
"cpp")};
3824 auto *op = getOperation();
3825 MLIRContext *context = &getContext();
3826 RewritePatternSet patterns(context);
3827 ConversionTarget target(*context);
3828 auto aieVersion = AIEArch::AIE;
3831 if (target ==
"aieml" || target ==
"aie2")
3832 aieVersion = AIEArch::AIE2;
3833 else if (target !=
"aie") {
3834 op->emitError() <<
"unknown AIE target '" <<
aieTarget <<
"'";
3835 return signalPassFailure();
3842 if (backendStr ==
"llvmir") {
3843 backend = TargetBackend::LLVMIR;
3844 if (aieVersion == AIEArch::AIE) {
3845 op->emitError() <<
"targetting LLVM IR is not supported for AIEv1";
3846 signalPassFailure();
3849 }
else if (backendStr !=
"cpp") {
3850 op->emitError() <<
"unknown target backend'" <<
targetBackend <<
"'";
3851 signalPassFailure();
3856 populateAIEVecCommonConversionPatterns(patterns, backend);
3857 configureAIEVecCommonLegalizations(target, backend);
3858 if (aieVersion == AIEArch::AIE) {
3859 populateAIEVecV1ConversionPatterns(patterns, backend);
3860 configureAIEVecV1Legalizations(target, backend);
3862 populateAIEVecV2ConversionPatterns(patterns, backend);
3863 configureAIEVecV2Legalizations(target, backend);
3866 if (failed(applyPartialConversion(op, target, std::move(patterns))))
3867 return signalPassFailure();
3871static std::unique_ptr<Pass>
3873 return std::make_unique<LowerVectorToAIEVec>(options);
3887 MLIRContext *context = &getContext();
3888 RewritePatternSet patterns(context);
3889 ConversionTarget target(*context);
3891 target.addLegalDialect<aievec::AIEVecDialect>();
3892 target.addDynamicallyLegalOp<aievec::UPDOp>([](aievec::UPDOp op) {
3893 return op.getVector() ||
3894 (op->hasOneUse() && isa<aievec::UPDOp>(*op->getUsers().begin())) ||
3895 llvm::all_of(op->getUsers(),
3896 [](Operation *op) {
return isa<aievec::ExtOp>(op); });
3899 if (
auto *op = getOperation();
3900 failed(applyPartialConversion(op, target, std::move(patterns)))) {
3901 return signalPassFailure();
3914 MLIRContext *context = &getContext();
3915 RewritePatternSet patterns(context);
3916 ConversionTarget target(*context);
3918 target.addLegalDialect<aievec::AIEVecDialect>();
3919 target.addDynamicallyLegalOp<aievec::ExtOp>([](aievec::ExtOp op) {
3920 auto *defOp = op.getSource().getDefiningOp();
3921 return !defOp || !isa<aievec::UPDOp>(defOp) || !defOp->hasOneUse() ||
3925 if (
auto *op = getOperation();
3926 failed(applyPartialConversion(op, target, std::move(patterns)))) {
3927 return signalPassFailure();
3939 pm.addPass(createLowerVectorToAIEVec(options));
3940 pm.addPass(createCanonicalizerPass());
3943 pm.addPass(std::make_unique<ExtendUPDOpsPass>());
3944 pm.addPass(createCSEPass());
3945 pm.addPass(std::make_unique<SimplifyUPDOpsPass>());
3946 pm.addPass(createCanonicalizerPass());
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaximumFOp, aievec::MaxOp > LowerVectorMaximumFOpToAIEVecMaxOp
ComputeBandAndBorOpPattern< arith::OrIOp, aievec::BorOp > ComputeBorOpPattern
ComputeBandAndBorOpPattern< arith::AndIOp, aievec::BandOp > ComputeBandOpPattern
OneToOneVectorOpToAIEVecOpPattern< arith::SubFOp, aievec::aie1::SubOp > LowerVectorSubFOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsIOp > ComputeAbsIOpPattern
LowerTruncOpPattern< arith::TruncFOp > LowerTruncFOpPattern
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddFOp, aievec::AddElemOp > LowerVectorAddFOpToAIEVecAddElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaxSIOp, aievec::MaxOp > LowerVectorMaxSIOpToAIEVecMaxOp
LowerExtOpPattern< arith::ExtFOp > LowerExtFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpFOp, CmpFPredicate > LowerVectorCmpFOpToAIEVecCmpOp
OneToOneVectorOpToAIEVecOpPattern< arith::SubIOp, aievec::aie1::SubOp > LowerVectorSubIOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsFOp > ComputeAbsFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpIOp, CmpIPredicate > LowerVectorCmpIOpToAIEVecCmpOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::SubFOp, aievec::SubElemOp > LowerVectorSubFOpToAIEVecSubElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinSIOp, aievec::MinOp > LowerVectorMinSIOpToAIEVecMinOp
OneToOneVectorOpToAIEVecOpPattern< arith::AddFOp, aievec::aie1::AddOp > LowerVectorAddFOpToAIEVecAddOp
OneToOneVectorOpToAIEVecOpPattern< arith::MulFOp, aievec::aie1::MulOp > LowerVectorMulFOpToAIEVecMulOp
LowerExtOpPattern< arith::ExtSIOp > LowerExtSIOpPattern
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinimumFOp, aievec::MinOp > LowerVectorMinimumFOpToAIEVecMinOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddIOp, aievec::AddElemOp > LowerVectorAddIOpToAIEVecAddElemOp
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
unsigned getVectorLaneSize(mlir::VectorType type)
SmallVector< NamedAttribute > buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos, int64_t step=1)
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
mlir::VectorType createVectorType(unsigned lanes, mlir::Type elementType)
int32_t getElementSizeInBits(mlir::VectorType type)
void buildLowerVectorToAIEVec(mlir::OpPassManager &pm, const LowerVectorToAIEVecOptions &options)
mlir::VectorType getVectorOpDestType(mlir::VectorType type, bool AIE2)
LogicalResult matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulAddToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulFToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulIToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertVectorFMAOpToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
ExpandUPDToUPDAndExtPattern(MLIRContext *context)
LogicalResult matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FuseExtIntoUPDPattern(MLIRContext *context)
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LowerVectorContractionOpToAIEVecMatMulPattern(MLIRContext *context, bool matMoveToAcc=true)
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Value reshapeLeadingUnitDims(OpBuilder &b, Value v) const
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Lower incoming vector operations into their corresponding AIE vector intrinsics.
void getDependentDialects(DialectRegistry &registry) const override
LowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options)
void runOnOperation() override
StringRef getDescription() const final
Option< std::string > aieTarget
Option< std::string > targetBackend
StringRef getArgument() const final
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize, int64_t maxVectorSize, int64_t alignment, int64_t maxLoadSize)
LogicalResult matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
Options for the "lower-vector-to-aievec" pipeline.
PassOptions::Option< std::string > aieTarget
PassOptions::Option< std::string > targetBackend