21#include "mlir/Dialect/Affine/IR/AffineOps.h"
22#include "mlir/Dialect/EmitC/IR/EmitC.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/UB/IR/UBOps.h"
28#include "mlir/IR/PatternMatch.h"
29#include "mlir/IR/SymbolTable.h"
30#include "mlir/IR/TypeUtilities.h"
31#include "mlir/Pass/PassManager.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "mlir/Transforms/Passes.h"
34#include "llvm/ADT/SmallSet.h"
39#define DEBUG_TYPE "lower-vector-to-aievec"
44using namespace vector;
// Predicate over `op` for narrowing operations. Visible checks:
//   * arith::TruncFOp / arith::TruncIOp
//   * an aievec::SRSOp whose source is defined by an aievec::UPSOp or an
//     aievec::CastOp
// NOTE(review): the return statements are not part of this excerpt.
52static bool isNarrowingOp(Operation *op) {
53 if (isa<arith::TruncFOp>(op) || isa<arith::TruncIOp>(op))
56 if (
auto srsOp = dyn_cast<aievec::SRSOp>(op)) {
57 auto *srsOpSrcOp = srsOp.getSource().getDefiningOp();
// NOTE(review): getDefiningOp() returns null for block arguments, and isa<>
// asserts on a null pointer. isa_and_nonnull<> (already used elsewhere in this
// file) may be what is wanted here -- confirm.
58 if (isa<aievec::UPSOp>(srsOpSrcOp) || isa<aievec::CastOp>(srsOpSrcOp))
// Looks through the op defining `src` and, when it is a recognised widening
// op, returns the value that was widened:
//   * arith::ExtSIOp / arith::ExtUIOp / arith::ExtFOp -> their input
//   * aievec::SRSOp whose source comes from an aievec::UPSOp -> the UPS source
//   * aievec::CastOp whose source comes from an aievec::UPSOp -> the UPS source
// Returns an empty optional when none of the visible cases match.
67static std::optional<Value> getSourceOfWideningOp(Value src) {
68 if (
auto extSIOp =
src.getDefiningOp<arith::ExtSIOp>())
69 return extSIOp.getIn();
70 if (
auto extUIOp =
src.getDefiningOp<arith::ExtUIOp>())
71 return extUIOp.getIn();
72 if (
auto extFOp =
src.getDefiningOp<arith::ExtFOp>())
73 return extFOp.getIn();
74 if (
auto srsOp =
src.getDefiningOp<aievec::SRSOp>()) {
78 auto srsSource = srsOp.getSource();
80 if (
auto upsOp = srsSource.getDefiningOp<aievec::UPSOp>())
81 return upsOp.getSource();
83 if (
auto castOp =
src.getDefiningOp<aievec::CastOp>()) {
87 auto castSource = castOp.getSource();
89 if (
auto upsOp = castSource.getDefiningOp<aievec::UPSOp>())
90 return upsOp.getSource();
// No widening op recognised.
92 return std::optional<Value>();
// Given the two operands of an add, tries to find a multiply feeding one of
// them and returns the tuple (mul lhs, mul rhs, acc).
// Two shapes are visible:
//   1. an arith::MulIOp defining either operand (lhs is tried first);
//   2. an aievec::SRSOp on either operand whose source is defined by an
//      aievec::aie1::MulOp (lhs is tried first).
// NOTE(review): the assignments to `acc`, the guards around the lhs probes and
// the final "no match" return are not part of this excerpt.
98static std::optional<std::tuple<Value, Value, Value>>
99extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
100 auto *lhsDefOp = addLhs.getDefiningOp();
101 auto *rhsDefOp = addRhs.getDefiningOp();
102 arith::MulIOp mulOp =
nullptr;
// NOTE(review): dyn_cast<> asserts on null; the rhs probe below is guarded by
// `rhsDefOp`, so presumably an omitted line guards this one with `lhsDefOp`
// -- confirm.
105 mulOp = dyn_cast<arith::MulIOp>(lhsDefOp);
108 if (!mulOp && rhsDefOp) {
109 mulOp = dyn_cast<arith::MulIOp>(rhsDefOp);
113 return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc);
// Second shape: the multiply is hidden behind an aievec::SRSOp.
116 auto lhsSrsOp = addLhs.getDefiningOp<aievec::SRSOp>();
117 auto rhsSrsOp = addRhs.getDefiningOp<aievec::SRSOp>();
118 aievec::aie1::MulOp aieMulOp =
nullptr;
120 aieMulOp = lhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
123 if (!aieMulOp && rhsSrsOp) {
124 aieMulOp = rhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
128 return std::make_tuple(aieMulOp.getLhs(), aieMulOp.getRhs(), acc);
// Emits aievec ops that convert `inputVal` to the vector type `tgtType` and
// returns the converted value.
// Visible cases:
//   * same element type, different lane count (16 -> 32 lanes of a float type,
//     or 32 -> 64 lanes of an integer type): concatenates the input with a
//     source-typed extract of a zero broadcast;
//   * integer element types differ, same lane count:
//       - 16-bit -> 32-bit, 16 lanes: UPS followed by Cast;
//       - 8-bit  -> 32-bit, 16 lanes: Concat(input, input), UPS, Cast, Ext 0;
//       - 8-bit  -> 16-bit, 32 lanes: Unpack.
// NOTE(review): the definitions of srcLaneSize, tgtLaneSize, accType,
// concatOutType and castType, the result for `srcType == tgtType`, and the
// fall-through return are not part of this excerpt.
135static std::optional<Value>
136convertValueToTargetTypeAIE2(ConversionPatternRewriter &rewriter, Location loc,
137 Value inputVal, VectorType tgtType) {
138 auto srcType = cast<VectorType>(inputVal.getType());
139 auto srcElemType = srcType.getElementType();
140 unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
143 auto tgtElemType = tgtType.getElementType();
144 unsigned tgtBitWidth = tgtElemType.getIntOrFloatBitWidth();
147 if (srcType == tgtType)
// Lane-count change with an unchanged element type.
150 if ((srcElemType == tgtElemType) && (srcLaneSize != tgtLaneSize)) {
152 if ((srcLaneSize == 16 && tgtLaneSize == 32 &&
153 isa<FloatType>(srcElemType)) ||
154 (srcLaneSize == 32 && tgtLaneSize == 64 &&
155 isa<IntegerType>(srcElemType))) {
// Build a zero vector of the target type, take a source-typed piece of it
// (index 0), and concatenate it after the input.
156 auto zeroConstOp = rewriter.create<arith::ConstantOp>(
157 loc, srcType.getElementType(),
158 rewriter.getZeroAttr(srcType.getElementType()));
159 auto broadcastZeroOp = rewriter.create<aievec::BroadcastScalarOp>(
160 loc, tgtType, zeroConstOp->getResult(0));
161 auto extOp = rewriter.create<aievec::ExtOp>(
162 loc, srcType, broadcastZeroOp->getResult(0), 0);
164 SmallVector<Value> inputSources = {inputVal, extOp->getResult(0)};
166 rewriter.create<aievec::ConcatOp>(loc, tgtType, inputSources);
168 return concatOp.getResult();
170 }
else if ((srcElemType != tgtElemType) && (srcLaneSize == tgtLaneSize) &&
171 isa<IntegerType>(srcElemType) && isa<IntegerType>(tgtElemType)) {
// Integer element-width change with an unchanged lane count.
172 if (srcBitWidth == 16 && tgtBitWidth == 32 && srcLaneSize == 16) {
176 auto upsOp = rewriter.create<aievec::UPSOp>(loc, accType, inputVal);
177 auto castOp = rewriter.create<aievec::CastOp>(
178 loc, tgtType, upsOp.getResult(),
false);
179 return castOp.getResult();
182 if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) {
// The input is duplicated by the concat; only piece 0 of the cast result is
// kept by the final ExtOp.
186 auto concatOp = rewriter.create<aievec::ConcatOp>(
187 loc, concatOutType, SmallVector<Value>({inputVal, inputVal}));
190 rewriter.create<aievec::UPSOp>(loc, accType, concatOp.getResult());
192 auto castOp = rewriter.create<aievec::CastOp>(
193 loc, castType, upsOp.getResult(),
false);
195 rewriter.create<aievec::ExtOp>(loc, tgtType, castOp.getResult(), 0);
196 return extOp.getResult();
199 if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) {
201 auto unpackOp = rewriter.create<aievec::UnpackOp>(loc, tgtType, inputVal);
202 return unpackOp.getResult();
// Builds the list of named string attributes for a rotation "select" op on a
// vector of type `vTy`. The element width is read from the element type when
// it is an IntegerType.
// Three attribute lists are visible:
//   * 9 attributes with select = "0x11111111", xstart = rotation + 1 and
//     ystart = rotation - 1;
//   * 9 attributes with select = "0", xstart = rotation, all y* fields "0";
//   * 7 attributes (no *_hi offsets) with xoffsets = "0x76543210" and
//     xstart = rotation.
// Any other width aborts via llvm::report_fatal_error.
// NOTE(review): the remaining parameters, the declaration/default of `width`,
// and the conditions that select among the three lists are not part of this
// excerpt.
212static SmallVector<NamedAttribute>
213buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy,
216 auto elemTy = vTy.getElementType();
217 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
218 width = intTy.getWidth();
// Attribute values shared by more than one of the lists below.
219 StringAttr attr0 = rewriter.getStringAttr(
"0");
220 StringAttr attr0x06040200 = rewriter.getStringAttr(
"0x06040200");
221 StringAttr attr0x0e0c0a08 = rewriter.getStringAttr(
"0x0e0c0a08");
222 StringAttr attr0x2103 = rewriter.getStringAttr(
"0x2103");
223 StringAttr attr0x3210 = rewriter.getStringAttr(
"0x3210");
// Attribute names.
224 StringAttr selectAttrName = rewriter.getStringAttr(
"select");
225 StringAttr xoffsetsAttrName = rewriter.getStringAttr(
"xoffsets");
226 StringAttr xoffsetsHiAttrName = rewriter.getStringAttr(
"xoffsets_hi");
227 StringAttr xsquareAttrName = rewriter.getStringAttr(
"xsquare");
228 StringAttr xstartAttrName = rewriter.getStringAttr(
"xstart");
229 StringAttr yoffsetsAttrName = rewriter.getStringAttr(
"yoffsets");
230 StringAttr yoffsetsHiAttrName = rewriter.getStringAttr(
"yoffsets_hi");
231 StringAttr ysquareAttrName = rewriter.getStringAttr(
"ysquare");
232 StringAttr ystartAttrName = rewriter.getStringAttr(
"ystart");
// First list: x and y start positions straddle the rotation amount.
237 int64_t xstart = rotation + 1;
238 int64_t ystart = rotation - 1;
239 return SmallVector<NamedAttribute, 9>(
240 {{selectAttrName, rewriter.getStringAttr(
"0x11111111")},
241 {xoffsetsAttrName, attr0x06040200},
242 {xoffsetsHiAttrName, attr0x0e0c0a08},
243 {xsquareAttrName, attr0x2103},
244 {xstartAttrName, rewriter.getStringAttr(std::to_string(xstart))},
245 {yoffsetsAttrName, rewriter.getStringAttr(
"0x0503010f")},
246 {yoffsetsHiAttrName, rewriter.getStringAttr(
"0x0d0b0907")},
247 {ysquareAttrName, attr0x2103},
248 {ystartAttrName, rewriter.getStringAttr(std::to_string(ystart))}});
// Second list: only the x side carries the rotation; y fields are all "0".
250 return SmallVector<NamedAttribute, 9>(
251 {{selectAttrName, attr0},
252 {xoffsetsAttrName, attr0x06040200},
253 {xoffsetsHiAttrName, attr0x0e0c0a08},
254 {xsquareAttrName, attr0x3210},
255 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
256 {yoffsetsAttrName, attr0},
257 {yoffsetsHiAttrName, attr0},
258 {ysquareAttrName, attr0},
259 {ystartAttrName, attr0}});
// Third list: same idea without the *_hi offset attributes.
262 return SmallVector<NamedAttribute, 7>(
263 {{selectAttrName, attr0},
264 {xoffsetsAttrName, rewriter.getStringAttr(
"0x76543210")},
265 {xsquareAttrName, attr0x3210},
266 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
267 {yoffsetsAttrName, attr0},
268 {ysquareAttrName, attr0},
269 {ystartAttrName, attr0}});
271 llvm::report_fatal_error(
"Unexpected width!");
// Builds an 11-entry attribute list for an FMA op (`fmaOp`) in which the z
// operand start position is set to `bcastPos`. The element width is read from
// the lhs operand's element type when it is an IntegerType.
// Two lists are visible:
//   * one that overrides the x offsets/square and z offsets/step with fixed
//     strings (zstep comes from `step`);
//   * one that overrides only xstart, xoffsets, zstart and zoffsets and
//     forwards the remaining attributes from `fmaOp` unchanged.
// Any other width aborts via llvm::report_fatal_error.
// NOTE(review): the function name, parameter list, the declaration of `width`
// and the conditions selecting each list are not part of this excerpt.
279SmallVector<NamedAttribute>
283 auto elemTy = fmaOp.getLhs().getType().getElementType();
284 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
285 width = intTy.getWidth();
286 auto *ctx = fmaOp.getContext();
306 return SmallVector<NamedAttribute, 11>(
307 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx,
"0")},
308 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx,
"0x73727170")},
309 {fmaOp.getXoffsetsHiAttrName(), StringAttr::get(ctx,
"0x77767574")},
310 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
311 {fmaOp.getXsquareAttrName(), StringAttr::get(ctx,
"0x3120")},
312 {fmaOp.getZstartAttrName(),
313 StringAttr::get(ctx, std::to_string(bcastPos))},
314 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx,
"0")},
315 {fmaOp.getZoffsetsHiAttrName(), StringAttr::get(ctx,
"0")},
316 {fmaOp.getZstepAttrName(), StringAttr::get(ctx, std::to_string(step))},
317 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
318 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
// Second list: most attributes are forwarded from the original op.
320 return SmallVector<NamedAttribute, 11>(
321 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx,
"0")},
322 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx,
"0x76543210")},
323 {fmaOp.getXoffsetsHiAttrName(), fmaOp.getXoffsetsHiAttr()},
324 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
325 {fmaOp.getXsquareAttrName(), fmaOp.getXsquareAttr()},
326 {fmaOp.getZstartAttrName(),
327 StringAttr::get(ctx, std::to_string(bcastPos))},
328 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx,
"0x00000000")},
329 {fmaOp.getZoffsetsHiAttrName(), fmaOp.getZoffsetsHiAttr()},
330 {fmaOp.getZstepAttrName(), fmaOp.getZstepAttr()},
331 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
332 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
334 llvm::report_fatal_error(
"Unexpected width!");
// Replaces `srcOp` with an element-wise AIEv2ElemOp wrapped in casts:
// both operands go through aievec::CastOp (result type `srcType`), the
// element-wise op is created on the cast results (typed like the lhs cast
// result), and `srcOp` is replaced by an aievec::CastOp back to its own type.
// NOTE(review): the trailing parameters (including `srcOp`), the last argument
// of the two operand casts, and the return statement are not part of this
// excerpt.
342template <
typename SrcOpTy,
typename AIEv2ElemOp>
343static LogicalResult genAddElemAIE2(ConversionPatternRewriter &rewriter,
344 Value lval, Value rval, VectorType srcType,
346 auto lCastOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), srcType, lval,
348 auto rCastOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), srcType, rval,
350 auto elemOp = rewriter.create<AIEv2ElemOp>(
351 srcOp.getLoc(), lCastOp->getResult(0).getType(), lCastOp->getResult(0),
352 rCastOp->getResult(0));
353 rewriter.replaceOpWithNewOp<aievec::CastOp>(
354 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// Maps a floating-point comparison predicate onto an integer one.
// As written: for GT/GE/LT/LE the ordered (O*) variants map to the signed
// integer predicates and the unordered (U*) variants map to the unsigned ones;
// both EQ variants map to `eq` and both NE variants map to `ne`.
// Predicates not listed reach llvm::report_fatal_error.
// NOTE(review): the `switch` header line is not part of this excerpt.
358static arith::CmpIPredicate
359convertToIntegerPredicate(arith::CmpFPredicate pred) {
361 case CmpFPredicate::UEQ:
362 case CmpFPredicate::OEQ:
363 return CmpIPredicate::eq;
364 case CmpFPredicate::UGT:
365 return CmpIPredicate::ugt;
366 case CmpFPredicate::OGT:
367 return CmpIPredicate::sgt;
368 case CmpFPredicate::UGE:
369 return CmpIPredicate::uge;
370 case CmpFPredicate::OGE:
371 return CmpIPredicate::sge;
372 case CmpFPredicate::ULT:
373 return CmpIPredicate::ult;
374 case CmpFPredicate::OLT:
375 return CmpIPredicate::slt;
376 case CmpFPredicate::ULE:
377 return CmpIPredicate::ule;
378 case CmpFPredicate::OLE:
379 return CmpIPredicate::sle;
380 case CmpFPredicate::UNE:
381 case CmpFPredicate::ONE:
382 return CmpIPredicate::ne;
384 llvm::report_fatal_error(
"Unexpected predicate!");
// Overload for predicates that are already integer predicates, so callers can
// use the same function name for both predicate kinds.
// NOTE(review): the body is not part of this excerpt; presumably it returns
// `pred` unchanged -- confirm.
388static arith::CmpIPredicate
389convertToIntegerPredicate(arith::CmpIPredicate pred) {
// Creates an aievec::CmpOp at `loc` comparing `lhs` and `rhs` with result
// type `type`. The integer predicate is passed to the op as its lower-case
// mnemonic string ("eq", "ne", "slt", "ult", "sle", "ule", "sgt", "ugt",
// "sge", "uge").
// NOTE(review): the `switch` header and whatever follows the last case are not
// part of this excerpt.
393static aievec::CmpOp createCmpOpAIE2(ConversionPatternRewriter &rewriter,
394 CmpIPredicate pred, Location loc,
395 Type type, Value lhs, Value rhs) {
397 case CmpIPredicate::eq:
398 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"eq");
399 case CmpIPredicate::ne:
400 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"ne");
401 case CmpIPredicate::slt:
402 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"slt");
403 case CmpIPredicate::ult:
404 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"ult");
405 case CmpIPredicate::sle:
406 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"sle");
407 case CmpIPredicate::ule:
408 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"ule");
409 case CmpIPredicate::sgt:
410 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"sgt");
411 case CmpIPredicate::ugt:
412 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"ugt");
413 case CmpIPredicate::sge:
414 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"sge");
415 case CmpIPredicate::uge:
416 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs,
"uge");
// Replaces the vector.reduction `srcOp` with a chain of shift + DstOpTy steps.
// Starting at `shiftIndex` lanes and halving each round, the current vector is
// shifted by `id * elWidth / 8` bytes and combined with itself through
// DstOpTy. `srcOp` is finally replaced by an aievec::ExtElemOp reading the
// last combined vector with the zero constant created just before it.
// `shiftIndex` must be a positive power of two (asserted), so the loop body
// runs at least once and `curOp` is set before it is used.
421template <
typename DstOpTy>
422static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
423 vector::ReductionOp srcOp,
424 int shiftIndex, Value curValue) {
425 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
426 "shiftIndex must be power of 2");
428 Location loc = srcOp.getLoc();
// NOTE(review): the dyn_cast<> result is dereferenced on the next line without
// a null check; if `curValue` is always a vector, cast<> would state that
// intent -- confirm.
429 auto vType = dyn_cast<VectorType>(curValue.getType());
430 Type scalarType = vType.getElementType();
431 Type vecType = curValue.getType();
432 DstOpTy curOp =
nullptr;
433 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
435 for (
int id = shiftIndex;
id > 0;
id /= 2) {
// Shift amount in bytes for this round.
436 auto constOp = rewriter.create<arith::ConstantOp>(
437 loc, rewriter.getI32IntegerAttr(
id * elWidth / 8));
439 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
440 loc, vecType, curValue, curValue, constOp.getResult());
442 curOp = rewriter.create<DstOpTy>(loc, vecType, curValue,
443 shiftBytesOp.getResult());
445 curValue = curOp.getResult();
449 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
450 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
451 zeroConstOp.getResult());
// Looks up the func::FuncOp named `funcName` in `parentModuleOp`'s symbol
// table; the visible code also builds a func::FuncOp with the function type
// (inTypes -> outTypes) and a ("sym_visibility", "private") attribute.
// The insertion point is moved to the start of the module's first block for
// the duration of the call and restored by the InsertionGuard.
// NOTE(review): the declaration of `fnOp`, the branch structure around the
// lookup result, and the return statement are not part of this excerpt.
454static func::FuncOp getOrInsertFuncDecl(ConversionPatternRewriter &rewriter,
455 mlir::ModuleOp parentModuleOp,
456 StringRef funcName, TypeRange inTypes,
457 TypeRange outTypes) {
459 mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
460 rewriter.setInsertionPointToStart(
461 &parentModuleOp.getRegion().getBlocks().front());
462 SymbolTable st = SymbolTable(parentModuleOp);
463 func::FuncOp fnOpLookup = st.lookup<func::FuncOp>(funcName);
// NOTE(review): comparing an op handle against NULL; `if (fnOpLookup)` or
// `nullptr` is the usual spelling for this check -- consider updating.
467 if (fnOpLookup != NULL) {
470 StringAttr t1 = rewriter.getStringAttr(
"sym_visibility");
471 StringAttr t2 = rewriter.getStringAttr(
"private");
472 NamedAttribute funcAccess = NamedAttribute(t1, t2);
473 FunctionType fnType =
474 mlir::FunctionType::get(rewriter.getContext(), inTypes, outTypes);
475 fnOp = rewriter.create<func::FuncOp>(parentModuleOp.getLoc(), funcName,
// Returns true when the math.exp operand is a vector of a 16-bit float
// element type with `laneSize == 16`.
// NOTE(review): the computation of `laneSize` and any null check on `srcType`
// (a dyn_cast<> result that is dereferenced below) are not part of this
// excerpt -- confirm a guard exists.
481static bool matchExpOpForLUT(math::ExpOp::Adaptor adaptor) {
482 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
487 Type scalarType = srcType.getElementType();
488 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
490 return isa<FloatType>(scalarType) && laneSize == 16 && elWidth == 16;
// Interior of a conversion pattern for a splat (`splatOp`) whose input is
// defined by a vector.extract. The extracted-from vector `src` and the static
// position `pos[0]` are turned into aievec broadcast ops, by total bit size
// (laneSize * elWidth):
//   * 512:  aievec::BroadcastOp directly;
//   * 256:  Concat(src, src) -> BroadcastOp -> ExtOp 0;
//   * 1024: ExtOp (selected half) -> BroadcastOp -> Concat of two copies.
// When the source vector type differs from the result type, the source must
// have exactly twice as many elements; the half containing `posVal` is
// extracted first and `posVal` is rebased into that half.
// NOTE(review): the struct header, the start of the matchAndRewrite signature,
// `laneSize`, and the failure returns are not part of this excerpt.
501 using OpConversionPattern::OpConversionPattern;
505 ConversionPatternRewriter &rewriter)
const override {
507 auto extOp = adaptor.getInput().getDefiningOp<vector::ExtractOp>();
512 auto src = extOp.getVector();
513 auto pos = extOp.getStaticPosition();
514 int64_t posVal = pos[0];
515 auto srcVecType = cast<VectorType>(src.getType());
516 auto resultType = cast<VectorType>(splatOp.getResult().getType());
517 if (srcVecType != resultType) {
518 if (srcVecType.getNumElements() != 2 * resultType.getNumElements())
// Which half of the source holds the element, and its index in that half.
520 auto half =
static_cast<int8_t
>(posVal / resultType.getNumElements());
521 posVal -= half * resultType.getNumElements();
523 .create<aievec::ExtOp>(extOp.getLoc(), resultType, src,
524 rewriter.getI8IntegerAttr(half))
528 unsigned elWidth = resultType.getElementType().getIntOrFloatBitWidth();
531 laneSize * elWidth == 512) {
533 rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(splatOp, resultType, src,
535 }
else if (laneSize * elWidth == 256) {
537 VectorType aievecBcastType =
539 auto concatOp = rewriter.create<aievec::ConcatOp>(
540 splatOp.getLoc(), aievecBcastType, SmallVector<Value>({src, src}));
541 auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
542 splatOp.getLoc(), aievecBcastType, concatOp.getResult(), posVal);
543 rewriter.replaceOpWithNewOp<aievec::ExtOp>(splatOp, resultType,
544 aieBcastOp.getResult(), 0);
545 }
else if (laneSize * elWidth == 1024) {
547 VectorType aievecBcastType =
549 auto half =
static_cast<int8_t
>(posVal / resultType.getNumElements());
550 posVal -= half * resultType.getNumElements();
552 rewriter.create<aievec::ExtOp>(splatOp.getLoc(), aievecBcastType, src,
553 rewriter.getI8IntegerAttr(half));
554 auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
555 splatOp.getLoc(), aievecBcastType, extOp.getResult(), posVal);
556 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
558 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
// Interior of a conversion pattern that lowers a splat of a scalar to
// aievec::BroadcastScalarOp, by total bit size (laneSize * elWidth):
//   * 512:  BroadcastScalarOp to the flat result type;
//   * 256:  BroadcastScalarOp to `vecType`, then ExtOp 0 to the flat type;
//   * 1024: BroadcastScalarOp to `vecType`, then Concat of two copies.
// In every case a vector::ShapeCastOp is added when the result type differs
// from the flat result type.
// The first visible check tests whether the input is defined by a
// vector.extract; the action taken is not shown.
// NOTE(review): the struct header, the start of the signature, `laneSize`,
// `flatResultType`, `vecType` and the return statements are not part of this
// excerpt.
568 using OpConversionPattern::OpConversionPattern;
572 ConversionPatternRewriter &rewriter)
const override {
574 if (adaptor.getInput().getDefiningOp<vector::ExtractOp>())
577 auto resultType = cast<VectorType>(splatOp.getResult().getType());
579 Type scalarType = resultType.getElementType();
580 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
582 auto src = splatOp.getInput();
584 if (laneSize * elWidth == 512) {
585 Value newOp = rewriter.create<aievec::BroadcastScalarOp>(
586 splatOp.getLoc(), flatResultType, src);
587 if (resultType != flatResultType)
588 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
590 rewriter.replaceOp(splatOp, newOp);
594 if (laneSize * elWidth == 256) {
596 auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
597 splatOp.getLoc(), vecType, src);
598 Value newOp = rewriter.create<aievec::ExtOp>(
599 splatOp.getLoc(), flatResultType, aieBcastOp.getResult(), 0);
600 if (resultType != flatResultType)
601 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
603 rewriter.replaceOp(splatOp, newOp);
607 if (laneSize * elWidth == 1024) {
609 auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
610 splatOp.getLoc(), vecType, src);
611 Value newOp = rewriter.create<aievec::ConcatOp>(
612 splatOp.getLoc(), flatResultType,
613 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
614 if (resultType != flatResultType)
615 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
617 rewriter.replaceOp(splatOp, newOp);
// Interior of a conversion pattern that fuses an add with a multiply feeding
// it: the operands are split with extractMACOperandsFromAddOperands, the
// accumulator is widened with aievec::UPSOp, an aievec::FMAElemOp is created
// on `accType`, and `addOp` is replaced by an aievec::SRSOp back to the
// result type.
// The visible shape test passes for 32 lanes of 16-bit elements and 16 lanes
// of 32-bit elements; the action taken otherwise is not shown.
// NOTE(review): the struct header, the start of the signature, `laneSize`,
// `accType`, `shiftParam`, the null check on `resultType` (a dyn_cast<>
// result) and the return statements are not part of this excerpt.
629 using OpConversionPattern::OpConversionPattern;
637 ConversionPatternRewriter &rewriter)
const override {
639 auto resultType = dyn_cast<VectorType>(addOp.getType());
645 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
648 auto [lhs, rhs, acc] = *res;
651 unsigned resultElWidth =
652 resultType.getElementType().getIntOrFloatBitWidth();
655 if ((laneSize != 32 || resultElWidth != 16) &&
656 (laneSize != 16 || resultElWidth != 32))
661 auto upsOp = rewriter.create<aievec::UPSOp>(addOp.getLoc(), accType, acc,
663 auto fmaElemOp = rewriter.create<aievec::FMAElemOp>(
664 addOp.getLoc(), accType, lhs, rhs, upsOp.getResult(),
667 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
668 addOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
669 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
670 addOp, resultType, fmaElemOp.getResult(), shiftParamOp.getResult());
// Interior of a conversion pattern lowering vector.fma (`fmaOp`) to
// aievec::FMAElemOp.
// Visible behaviour:
//   * fails the match unless `numElems == 16` and the element type is f32 or
//     bf16;
//   * for a bf16 result the accumulator is widened with aievec::UPSOp to
//     vector<16xf32>, and the FMA result is narrowed back with aievec::SRSOp;
//   * lhs/rhs are replaced by the source of their widening op (null when there
//     is none), and the match fails when those sources are not bf16 vectors.
// NOTE(review): the struct header, the start of the signature, `numElems`,
// `shiftParam`, the branch structure around the operand rewrites and the final
// return are not part of this excerpt.
686 using OpConversionPattern::OpConversionPattern;
694 ConversionPatternRewriter &rewriter)
const override {
696 auto resVecTy = cast<VectorType>(fmaOp.getType());
697 auto resElemTy = resVecTy.getElementType();
700 if (numElems != 16 || (!resElemTy.isF32() && !resElemTy.isBF16()))
701 return rewriter.notifyMatchFailure(
702 fmaOp,
"Unsupported operand types in vector.fma lowering.");
704 Value lhs = adaptor.getLhs();
705 Value rhs = adaptor.getRhs();
706 Value acc = adaptor.getAcc();
707 if (resElemTy.isBF16())
708 acc = rewriter.create<aievec::UPSOp>(
709 fmaOp.getLoc(), VectorType::get({16}, rewriter.getF32Type()), acc,
712 lhs = getSourceOfWideningOp(lhs).value_or(
nullptr);
713 rhs = getSourceOfWideningOp(rhs).value_or(
nullptr);
715 return rewriter.notifyMatchFailure(
716 fmaOp,
"vector.fma operands are f32, and they don't come from "
717 "arith.extf on bf16; can't lower to aievec.");
718 if (!cast<VectorType>(lhs.getType()).getElementType().isBF16() ||
719 !cast<VectorType>(rhs.getType()).getElementType().isBF16())
720 return rewriter.notifyMatchFailure(
721 fmaOp,
"vector.fma operands come from arith.extf, but the source "
722 "of the widening op is not bf16; can't lower to aievec.");
// The FMA is created on the accumulator's type.
724 Value newOp = rewriter.create<aievec::FMAElemOp>(
725 fmaOp.getLoc(), acc.getType(), lhs, rhs, acc,
false);
727 if (resElemTy.isBF16()) {
728 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
729 fmaOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
730 newOp = rewriter.create<aievec::SRSOp>(fmaOp.getLoc(), resVecTy, newOp,
734 rewriter.replaceOp(fmaOp, newOp);
// Interior of a conversion pattern lowering a multiply (`mulOp`) to
// aievec::MulElemOp.
// Visible steps:
//   * tests whether the multiply has a single use that is an arith.addf
//     (the action taken is not shown);
//   * tests `laneSize` against 16 and the result element width against 16/32;
//   * looks through widening ops on both operands (keeping the operand itself
//     when there is none);
//   * converts operands to `targetInputType` via convertValueToTargetTypeAIE2
//     when their type differs;
//   * creates the MulElemOp on `accType`, then replaces `mulOp` with an
//     aievec::CastOp when the element widths match, or an aievec::SRSOp when
//     the multiply result is wider.
// NOTE(review): the struct header, the start of the signature, `laneSize`,
// `targetInputType`, `accType`, `shiftParam`, the bodies of several branches
// and all return statements are not part of this excerpt.
746 using OpConversionPattern::OpConversionPattern;
754 ConversionPatternRewriter &rewriter)
const override {
756 auto resultType = dyn_cast<VectorType>(mulOp.getType());
761 auto isAddOp = [&](Operation *op) {
return isa<arith::AddFOp>(op); };
762 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
765 unsigned resultElWidth =
766 resultType.getElementType().getIntOrFloatBitWidth();
771 if (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32))
775 auto lval = adaptor.getLhs();
776 auto rval = adaptor.getRhs();
777 lval = getSourceOfWideningOp(lval).value_or(lval);
778 rval = getSourceOfWideningOp(rval).value_or(rval);
779 auto lSrcType = cast<VectorType>(lval.getType());
780 auto rSrcType = cast<VectorType>(rval.getType());
781 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
782 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
784 if (rBitWidth > lBitWidth) {
788 if (lSrcType != rSrcType) {
// The wider of the two operand element types drives the lane selection.
793 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
794 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
795 : lSrcType.getElementType();
796 unsigned numLanes = 0;
797 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
799 }
else if (isa<IntegerType>(srcElemType) &&
800 (bitWidth == 8 || bitWidth == 16)) {
802 }
else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
808 if (targetInputType != lSrcType) {
809 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
813 if (targetInputType != rSrcType) {
814 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
823 rewriter.create<aievec::MulElemOp>(mulOp.getLoc(), accType, lval, rval);
826 auto mulElemResultType = mulElemOp.getType();
827 auto mulElemResultElWidth =
828 mulElemResultType.getElementType().getIntOrFloatBitWidth();
830 if (mulElemResultElWidth == resultElWidth) {
831 rewriter.replaceOpWithNewOp<aievec::CastOp>(
832 mulOp, resultType, mulElemOp.getResult(),
false);
833 }
else if (mulElemResultElWidth > resultElWidth) {
834 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
835 mulOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
836 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
837 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
// Interior of a conversion pattern lowering a multiply (`mulOp`) to
// aievec::MulElemOp; this variant checks for a single arith.addi user.
// Visible steps:
//   * tests whether the multiply has a single use that is an arith.addi
//     (the action taken is not shown);
//   * the shape test passes for 32 lanes of 8- or 16-bit elements and for
//     16 or 32 lanes of 32-bit elements;
//   * looks through widening ops on both operands (keeping the operand itself
//     when there is none);
//   * converts operands to `targetInputType` via convertValueToTargetTypeAIE2
//     when their type differs;
//   * creates the MulElemOp on `accType`, then replaces `mulOp` with an
//     aievec::CastOp when the element widths match, or an aievec::SRSOp when
//     the multiply result is wider.
// NOTE(review): the struct header, the start of the signature, `laneSize`,
// `targetInputType`, `accType`, `shiftParam`, the bodies of several branches
// and all return statements are not part of this excerpt.
852 using OpConversionPattern::OpConversionPattern;
860 ConversionPatternRewriter &rewriter)
const override {
862 auto resultType = dyn_cast<VectorType>(mulOp.getType());
867 auto isAddOp = [&](Operation *op) {
return isa<arith::AddIOp>(op); };
868 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
872 unsigned resultElWidth =
873 resultType.getElementType().getIntOrFloatBitWidth();
876 if ((laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
877 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32))
881 auto lval = adaptor.getLhs();
882 auto rval = adaptor.getRhs();
884 lval = getSourceOfWideningOp(lval).value_or(lval);
885 rval = getSourceOfWideningOp(rval).value_or(rval);
887 auto lSrcType = cast<VectorType>(lval.getType());
888 auto rSrcType = cast<VectorType>(rval.getType());
889 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
890 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
892 if (rBitWidth > lBitWidth) {
// The wider of the two operand element types drives the lane selection.
897 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
898 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
899 : lSrcType.getElementType();
900 unsigned numLanes = 0;
// NOTE(review): a FloatType branch in the arith.addi-flavoured variant looks
// like it mirrors the sibling float pattern; confirm it is reachable here.
901 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
903 }
else if (isa<IntegerType>(srcElemType) &&
904 (bitWidth == 8 || bitWidth == 16)) {
906 }
else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
912 if (targetInputType != lSrcType) {
913 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
917 if (targetInputType != rSrcType) {
918 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
927 rewriter.create<aievec::MulElemOp>(mulOp.getLoc(), accType, lval, rval);
930 auto mulElemResultType = mulElemOp.getType();
931 auto mulElemResultElWidth =
932 mulElemResultType.getElementType().getIntOrFloatBitWidth();
934 if (mulElemResultElWidth == resultElWidth) {
935 rewriter.replaceOpWithNewOp<aievec::CastOp>(
936 mulOp, resultType, mulElemOp.getResult(),
false);
937 }
else if (mulElemResultElWidth > resultElWidth) {
938 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
939 mulOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
940 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
941 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
// Interior of a conversion pattern that rewrites an aievec::aie1::FMAOp
// (`fmaOp`) whose operands involve an aievec::ConcatOp and a vector.splat of
// a vector.extract. The replacement FMA takes Concat(lhs, zero vector) as its
// first operand, the vector the splat element was extracted from as its
// second, and the original accumulator; `zstart` is the static extract
// position.
// NOTE(review): the struct header, the start of the signature, the
// declarations receiving several of the dyn_cast<> results, `lhsX2`,
// `fmaOpAttr`, and all early returns are not part of this excerpt.
// NOTE(review): several dyn_cast<> calls below are applied directly to
// getDefiningOp(), which is null for block arguments; dyn_cast<> asserts on
// null. dyn_cast_or_null<> or getDefiningOp<OpTy>() may be intended -- confirm.
955 using OpConversionPattern::OpConversionPattern;
959 ConversionPatternRewriter &rewriter)
const override {
961 dyn_cast<aievec::ConcatOp>(adaptor.getLhs().getDefiningOp());
964 vector::SplatOp splatOp =
nullptr;
965 auto *concatDefOp = concatOp.getSources()[0].getDefiningOp();
967 splatOp = dyn_cast<vector::SplatOp>(concatDefOp);
968 Value lhs = adaptor.getRhs();
970 splatOp = dyn_cast<vector::SplatOp>(adaptor.getRhs().getDefiningOp());
973 lhs = concatOp.getSources()[0];
976 dyn_cast<vector::ExtractOp>(splatOp.getInput().getDefiningOp());
980 auto rhs = extOp.getVector();
981 auto concatVecType = cast<VectorType>(concatOp.getResult().getType());
// Zero vector of the same type as `lhs`, used as the second concat source.
982 auto zvec = rewriter.create<arith::ConstantOp>(
983 concatOp.getLoc(), lhs.getType(), rewriter.getZeroAttr(lhs.getType()));
986 .create<aievec::ConcatOp>(concatOp.getLoc(), concatVecType,
987 SmallVector<Value, 2>({lhs, zvec}))
990 auto pos = extOp.getStaticPosition();
991 int64_t zstart = pos[0];
993 rewriter.replaceOpWithNewOp<aievec::aie1::FMAOp>(
994 fmaOp, TypeRange({fmaOp.getResult().getType()}),
995 ValueRange({lhsX2, rhs, adaptor.getAcc()}), fmaOpAttr);
// Interior of a conversion pattern that turns an add fed by a multiply into
// an aievec::aie1::FMAOp: operands come from
// extractMACOperandsFromAddOperands, the lhs is duplicated into a vector with
// twice as many elements in the innermost dimension (Concat(lhs, lhs)), the
// accumulator is widened with aievec::UPSOp, and `addOp` is replaced by an
// aievec::SRSOp of the FMA result with a shift of 0.
// NOTE(review): the struct header, the start of the signature, `res`,
// `accType`, the remaining FMA arguments and the return statements are not
// part of this excerpt.
1003 using OpConversionPattern::OpConversionPattern;
1007 ConversionPatternRewriter &rewriter)
const override {
1008 auto vecType = cast<VectorType>(addOp.getType());
1011 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
1014 auto [lhs, rhs, acc] = *res;
// Same shape as the result, with the innermost dimension doubled.
1016 SmallVector<int64_t, 4> concatVecShape(vecType.getShape().begin(),
1017 vecType.getShape().end());
1018 concatVecShape[vecType.getRank() - 1] *= 2;
1019 auto concatVecType =
1020 VectorType::get(concatVecShape, vecType.getElementType());
1023 auto lhsX2 = rewriter
1024 .create<aievec::ConcatOp>(addOp.getLoc(), concatVecType,
1025 SmallVector<Value, 2>(2, lhs))
1027 auto upsOp = rewriter.create<aievec::UPSOp>(addOp.getLoc(), accType, acc);
1028 auto fmaOp = rewriter.create<aievec::aie1::FMAOp>(
1029 addOp.getLoc(), accType, lhsX2, rhs, upsOp.getResult(),
1033 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1034 addOp.getLoc(), rewriter.getI32IntegerAttr(0));
1035 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1036 addOp, vecType, fmaOp.getResult(), shiftParamOp.getResult());
// Interior of a conversion pattern that lowers a vector transfer read
// (`readOp`) to xilinx::aievec::UPDOp.
// Visible checks:
//   * masked reads emit the error "AIE doesn't support masked loads.";
//   * the permutation map is tested for being a minor identity and for being
//     constant (actions not shown);
//   * `vSize` is the vector size in bits; for vectors larger than
//     `minVectorSize`, `multiplicity` must have exactly one bit set in its low
//     8 bits (i.e. be a power of two in that range).
// The read is replaced by the result of the last UPDOp created.
// NOTE(review): the struct header, the start of the signature,
// `minVectorSize`, `multiplicity`, the condition around the second UPDOp and
// the return statements are not part of this excerpt.
// NOTE(review): std::bitset needs <bitset>, which is not among the includes
// visible in this excerpt -- confirm it is included.
1046 using OpConversionPattern::OpConversionPattern;
1057 ConversionPatternRewriter &rewriter)
const override {
1059 if (readOp.getMask())
1060 return readOp.emitError() <<
"AIE doesn't support masked loads.";
1063 AffineMap map = readOp.getPermutationMap();
1064 if (!map.isMinorIdentity())
1068 if (map.isConstant())
1072 auto vType = readOp.getVectorType();
1082 int64_t vSize = vType.getNumElements() * vType.getElementTypeBitWidth();
1091 if ((vSize >
minVectorSize) && std::bitset<8>(multiplicity).count() != 1)
1094 auto updOp = rewriter.create<xilinx::aievec::UPDOp>(
1095 readOp.getLoc(), vType, adaptor.getBase(), adaptor.getIndices(), 0, 0,
1096 TypedValue<VectorType>(
nullptr));
1098 updOp = rewriter.create<xilinx::aievec::UPDOp>(
1099 readOp.getLoc(), vType, adaptor.getBase(), adaptor.getIndices(),
1102 rewriter.replaceOp(readOp, updOp.getResult());
// Generic one-to-one pattern: replaces `srcOp` (a SrcOpTy) with a DstOpTy
// that has the same result type and takes the adapted lhs and rhs operands.
// NOTE(review): the struct header, the start of the signature, any trailing
// arguments of the replacement op and the return statement are not part of
// this excerpt.
1112template <
typename SrcOpTy,
typename DstOpTy>
1119 ConversionPatternRewriter &rewriter)
const override {
1120 rewriter.replaceOpWithNewOp<DstOpTy>(
1121 srcOp, srcOp.getResult().getType(), adaptor.getLhs(), adaptor.getRhs(),
// Interior of a conversion pattern that replaces a vector-typed add (`addOp`)
// with aievec::aie1::AddOp.
// Two conditions are tested first (actions not shown): the result type not
// being a VectorType, and either operand being defined by an arith.muli.
// The null-safe isa_and_nonnull<> is used because getDefiningOp() may be null.
// NOTE(review): the struct header, the start of the signature, the trailing
// AddOp arguments and the return statements are not part of this excerpt.
1129 using OpConversionPattern::OpConversionPattern;
1133 ConversionPatternRewriter &rewriter)
const override {
1134 auto resType = addOp.getType();
1135 if (!isa<VectorType>(resType))
1138 auto lhs = adaptor.getLhs();
1139 auto rhs = adaptor.getRhs();
1140 auto *lhsDefOp = lhs.getDefiningOp();
1141 auto *rhsDefOp = rhs.getDefiningOp();
1142 if ((isa_and_nonnull<arith::MulIOp>(lhsDefOp)) ||
1143 (isa_and_nonnull<arith::MulIOp>(rhsDefOp)))
1146 rewriter.replaceOpWithNewOp<aievec::aie1::AddOp>(
1147 addOp, resType, lhs, rhs,
// Interior of a conversion pattern that lowers a multiply (`mulOp`) to
// aievec::aie1::MulOp producing `accTy`, followed by an aievec::SRSOp with a
// shift of 0 that brings the result back to the original result type.
// NOTE(review): the struct header, the start of the signature, `accTy`, the
// null check on `resTy` (a dyn_cast<> result) and the return statement are not
// part of this excerpt.
1164 using OpConversionPattern::OpConversionPattern;
1167 ConversionPatternRewriter &rewriter)
const override {
1168 auto resTy = dyn_cast<VectorType>(mulOp.getType());
1172 auto newMulOp = rewriter.create<aievec::aie1::MulOp>(
1173 mulOp.getLoc(), accTy, adaptor.getLhs(), adaptor.getRhs());
1174 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1175 mulOp.getLoc(), rewriter.getI32IntegerAttr(0));
1176 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1177 mulOp, resTy, newMulOp.getResult(), shiftParamOp.getResult());
// Elementwise binary-op lowering fragment templated on the source op
// (`SrcOpTy`) and the AIEVec op to create (`DstOpTy`).
//
// Visible strategy:
//  * operands produced by arith::MulIOp / arith::MulFOp are detected first
//    (branch body not shown);
//  * for integer element types the (laneSize, element width) pair must be in
//    `laneSizeElWidthPairSet`;
//  * depending on whether the operands come from widening ops (see
//    getSourceOfWideningOp), the op is emitted directly, delegated to
//    genAddElemAIE2, or built on values produced by UPS/Cast ops and then
//    converted back with a Cast or SRS op.
// NOTE(review): `laneSize`, `accType`, `lval`/`rval` definitions and all
// return statements are not shown here — confirm details against the file.
1182template <
typename SrcOpTy,
typename DstOpTy>
1190 ConversionPatternRewriter &rewriter)
const override {
1191 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
// Supported (lane count, element bit width) combinations for integer types.
1197 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1198 laneSizeElWidthPairSet.insert({64, 8});
1199 laneSizeElWidthPairSet.insert({32, 16});
1200 laneSizeElWidthPairSet.insert({16, 32});
1201 laneSizeElWidthPairSet.insert({32, 32});
1203 auto lhs = adaptor.getLhs();
1204 auto rhs = adaptor.getRhs();
// May be null when the operand is not the result of an operation.
1205 auto lhsDefOp = lhs.getDefiningOp();
1206 auto rhsDefOp = rhs.getDefiningOp();
// Either operand is the result of an integer or floating-point multiply.
1207 if ((lhsDefOp && isa<arith::MulIOp>(lhsDefOp)) ||
1208 (rhsDefOp && isa<arith::MulIOp>(rhsDefOp)) ||
1209 (lhsDefOp && isa<arith::MulFOp>(lhsDefOp)) ||
1210 (rhsDefOp && isa<arith::MulFOp>(rhsDefOp)))
1213 Type scalarType = resultType.getElementType();
1214 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
// ---- Integer element types ----
1218 if (isa<IntegerType>(scalarType)) {
1219 if (!laneSizeElWidthPairSet.count(
1220 std::make_pair(laneSize, resultElWidth)))
// Neither operand has a defining op: a 512-bit vector is replaced directly
// with DstOpTy, the other visible path delegates to genAddElemAIE2.
1226 if (!lhsDefOp && !rhsDefOp) {
1227 if (laneSize * resultElWidth == 512) {
1228 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1232 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs, resultType,
// 32-bit results: look through widening ops feeding each operand. A null
// Value means the operand is not the result of a recognized widening op.
1237 if (resultElWidth == 32) {
1238 auto lhsExt = getSourceOfWideningOp(lhs).value_or(
nullptr);
1239 auto rhsExt = getSourceOfWideningOp(rhs).value_or(
nullptr);
// No operand was widened: same handling as the "no defining op" case above.
1241 if (!lhsExt && !rhsExt) {
1242 if (laneSize * resultElWidth == 512) {
1243 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1247 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
// Both operands were widened: UPS both to `accType`, apply DstOpTy to the
// UPS results, then cast back to the source result type.
1251 if (lhsExt && rhsExt) {
1254 VectorType lSrcType = cast<VectorType>(lval.getType());
1258 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1260 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1261 auto elemOp = rewriter.create<DstOpTy>(
1262 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1263 lUpsOp->getResult(0), rUpsOp->getResult(0));
1264 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1265 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// At most one operand was widened: use the pre-widening source for that
// operand and the converted operand for the other.
1269 if (!lhsExt || !rhsExt) {
1270 auto lval = lhsExt ? lhsExt : lhs;
1271 auto rval = rhsExt ? rhsExt : rhs;
1272 auto extVal = lhsExt ? lval : rval;
1273 VectorType vType = cast<VectorType>(extVal.getType());
1274 unsigned bitWidth = vType.getElementType().getIntOrFloatBitWidth();
// Only 8- and 16-bit sources whose vector is 256 bits wide are handled here;
// everything else is delegated to genAddElemAIE2.
1276 if (bitWidth != 8 && bitWidth != 16) {
1277 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1281 if (bitWidth * laneSize != 256) {
1282 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1286 Type accType =
nullptr;
// 8-bit source: UPS the widened operand's source, Cast the other operand,
// keep lhs/rhs order for DstOpTy, then Cast the result back.
1288 if (bitWidth == 8) {
1290 Value valToUps = lhsExt ? lval : rval;
1291 Value valToCast = lhsExt ? rval : lval;
1292 auto upsOp = rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType,
1294 auto castOp = rewriter.create<aievec::CastOp>(
1295 srcOp.getLoc(), resultType, valToCast,
true);
1297 lhsExt ? upsOp->getResult(0) : castOp->getResult(0);
1299 lhsExt ? castOp->getResult(0) : upsOp->getResult(0);
1300 auto elemOp = rewriter.create<DstOpTy>(
1301 srcOp.getLoc(), upsOp->getResult(0).getType(), lhsToElemOp,
1303 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1304 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// 16-bit source: UPS both values, apply DstOpTy, then SRS with shift 0.
1308 if (bitWidth == 16) {
1311 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1313 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1315 auto elemOp = rewriter.create<DstOpTy>(
1316 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1317 lUpsOp->getResult(0), rUpsOp->getResult(0));
1319 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1320 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1321 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1322 srcOp, srcOp.getType(), elemOp.getResult(),
1323 shiftParamOp.getResult());
// Direct replacement with DstOpTy on the converted operands.
1328 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs, rhs);
// NOTE(review): presumably the non-integer (floating-point) path starts
// here — the enclosing condition is not shown; confirm.
1338 if (resultElWidth == 32) {
1339 if (!lhsDefOp && !rhsDefOp) {
1340 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1344 auto lhsExt = getSourceOfWideningOp(lhs).value_or(
nullptr);
1345 auto rhsExt = getSourceOfWideningOp(rhs).value_or(
nullptr);
1347 if (!lhsExt && !rhsExt) {
1348 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
// Both operands widened: UPS both, apply DstOpTy, Cast back.
1353 if (lhsExt && rhsExt) {
1356 VectorType vType = cast<VectorType>(lval.getType());
1360 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1362 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1363 auto elemOp = rewriter.create<DstOpTy>(
1364 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1365 lUpsOp->getResult(0), rUpsOp->getResult(0));
1366 rewriter.replaceOpWithNewOp<aievec::CastOp>(srcOp, srcOp.getType(),
1367 elemOp.getResult());
// At most one operand widened: one value goes through UPS, the other through
// Cast; which is which depends on a condition that is not shown.
1372 if (!lhsExt || !rhsExt) {
1373 auto lval = lhsExt ? lhsExt : lhs;
1374 auto rval = rhsExt ? rhsExt : rhs;
1375 auto extVal = lhsExt ? lval : rval;
1376 VectorType vType = cast<VectorType>(extVal.getType());
1379 aievec::UPSOp upsOp;
1380 aievec::CastOp castOp;
1383 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1384 castOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), resultType,
1389 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1390 castOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), resultType,
1395 auto elemOp = rewriter.create<DstOpTy>(
1396 srcOp.getLoc(), upsOp->getResult(0).getType(),
1397 upsOp->getResult(0), castOp->getResult(0));
1399 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1400 srcOp, srcOp.getType(), elemOp.getResult(),
false);
// Remaining path: UPS both converted operands, apply DstOpTy, then SRS with a
// constant shift of 0 back to the source result type.
1409 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lhs);
1411 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rhs);
1412 auto elemOp = rewriter.create<DstOpTy>(
1413 srcOp.getLoc(), lUpsOp->getResult(0).getType(), lUpsOp->getResult(0),
1414 rUpsOp->getResult(0));
1415 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1416 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1417 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1418 srcOp, srcOp.getType(), elemOp.getResult(), shiftParamOp.getResult());
// Generic elementwise lowering fragment: replaces `srcOp` with `DstOpTy` on
// the converted lhs/rhs when the element width is 8, 16 or 32 bits and
// laneSize * element width equals 512.
// NOTE(review): `laneSize` is defined on a line not shown here, and the body
// of the rejecting branch is not visible.
1440template <
typename SrcOpTy,
typename DstOpTy>
1447 ConversionPatternRewriter &rewriter)
const override {
1448 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
// Accepted element bit widths.
1453 llvm::SmallSet<unsigned, 16> elWidthSet;
1454 elWidthSet.insert(8);
1455 elWidthSet.insert(16);
1456 elWidthSet.insert(32);
1458 Type scalarType = resultType.getElementType();
1459 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1462 if (!elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512)
1465 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(),
1466 adaptor.getLhs(), adaptor.getRhs());
// Comparison lowering fragment templated on the source compare op and its
// predicate type (`CmpTy`). The predicate is mapped to an arith::CmpIPredicate
// via convertToIntegerPredicate, an aievec::CmpOp is built through
// createCmpOpAIE2, and the source op is replaced by an
// UnrealizedConversionCastOp from the compare result to the original result
// vector type.
// NOTE(review): `laneSize` and `type` are declared on lines not shown here.
1480template <
typename SrcOpTy,
typename CmpTy>
1487 ConversionPatternRewriter &rewriter)
const override {
1488 VectorType lhsType = dyn_cast<VectorType>(srcOp.getLhs().getType());
// Accepted element bit widths.
1492 llvm::SmallSet<unsigned, 16> elWidthSet;
1493 elWidthSet.insert(8);
1494 elWidthSet.insert(16);
1495 elWidthSet.insert(32);
1497 Type scalarType = lhsType.getElementType();
1498 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1501 if (!elWidthSet.count(elWidth) || laneSize * elWidth != 512)
// Unsigned integer type of 32 bits when laneSize <= 32, otherwise 64 bits.
1506 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
1507 mlir::IntegerType::Unsigned);
1509 Location loc = srcOp.getLoc();
// NOTE(review): operands are read from srcOp rather than the adaptor —
// confirm this is intended inside a conversion pattern.
1510 Value lhs = srcOp.getLhs();
1511 Value rhs = srcOp.getRhs();
1512 CmpTy pred = srcOp.getPredicate();
1514 arith::CmpIPredicate ipred = convertToIntegerPredicate(pred);
1516 aievec::CmpOp aieCmpOp =
1517 createCmpOpAIE2(rewriter, ipred, loc, type, lhs, rhs);
1522 VectorType resultType = dyn_cast<VectorType>(srcOp.getResult().getType());
1525 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
1526 srcOp, resultType, aieCmpOp.getResult());
// Select lowering fragment: the converted condition is cast (via
// UnrealizedConversionCastOp) to an unsigned integer type and the source op
// is replaced by an aievec::SelOp.
// Applies when the element width is 8, 16 or 32 bits and
// laneSize * element width equals 512 (rejecting branch body not shown).
// NOTE(review): `laneSize` and `type` are declared on lines not shown here.
1538 using OpConversionPattern::OpConversionPattern;
1542 ConversionPatternRewriter &rewriter)
const override {
1543 auto resultType = dyn_cast<VectorType>(srcOp.getType());
1547 llvm::SmallSet<unsigned, 16> elWidthSet;
1548 elWidthSet.insert(8);
1549 elWidthSet.insert(16);
1550 elWidthSet.insert(32);
1552 Type scalarType = resultType.getElementType();
1553 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1556 if (!elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512)
// Unsigned integer type of 32 bits when laneSize <= 32, otherwise 64 bits.
1560 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
1561 mlir::IntegerType::Unsigned);
1563 auto convertOp = rewriter.create<UnrealizedConversionCastOp>(
1564 srcOp.getLoc(), type, adaptor.getCondition());
// NOTE(review): true/false values are taken from srcOp, not the adaptor —
// confirm this is intended inside a conversion pattern.
1566 rewriter.replaceOpWithNewOp<aievec::SelOp>(
1567 srcOp, srcOp.getResult().getType(), srcOp.getTrueValue(),
1568 srcOp.getFalseValue(), convertOp.getResult(0));
// Min-reduction lowering fragment: handles reductions whose combining kind is
// MINSI, MINUI or MINIMUMF on vectors where laneSize * element width is 512,
// delegating to generateAIEVecOpsForReductionOp<aievec::MinOp> with an initial
// shift index of laneSize / 2.
// NOTE(review): `laneSize` and the bodies of the rejecting branches are not
// shown here.
1575 using OpConversionPattern::OpConversionPattern;
1579 ConversionPatternRewriter &rewriter)
const override {
// True when the kind is none of the supported min kinds.
1580 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::MINSI &&
1581 kind != vector::CombiningKind::MINUI &&
1582 kind != vector::CombiningKind::MINIMUMF)
1585 auto vType = cast<VectorType>(srcOp.getVector().getType());
1586 Type scalarType = vType.getElementType();
1587 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1590 if (laneSize * elWidth != 512)
1593 int shiftIndex = laneSize / 2;
1594 generateAIEVecOpsForReductionOp<aievec::MinOp>(rewriter, srcOp, shiftIndex,
// Max-reduction lowering fragment: handles reductions whose combining kind is
// MAXSI, MAXUI or MAXIMUMF on vectors where laneSize * element width is 512,
// delegating to generateAIEVecOpsForReductionOp<aievec::MaxOp> with an initial
// shift index of laneSize / 2.
// NOTE(review): `laneSize` and the bodies of the rejecting branches are not
// shown here.
1601 using OpConversionPattern::OpConversionPattern;
1605 ConversionPatternRewriter &rewriter)
const override {
// True when the kind is none of the supported max kinds.
1606 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::MAXSI &&
1607 kind != vector::CombiningKind::MAXUI &&
1608 kind != vector::CombiningKind::MAXIMUMF)
1611 auto vType = cast<VectorType>(srcOp.getVector().getType());
1612 Type scalarType = vType.getElementType();
1613 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1616 if (laneSize * elWidth != 512)
1619 int shiftIndex = laneSize / 2;
1620 generateAIEVecOpsForReductionOp<aievec::MaxOp>(rewriter, srcOp, shiftIndex,
// Integer add-reduction lowering fragment: handles CombiningKind::ADD on
// integer vectors whose (laneSize, element width) pair is in
// `laneSizeElWidthPairSet`.
// For the 32-lane x 32-bit case the two halves are first extracted with
// aievec::ExtOp (indices 0 and 1) and added with aievec::AddElemOp before the
// reduction helper runs; otherwise the helper runs on the source vector.
// NOTE(review): `laneSize`, `vecType` and the rejecting branch bodies are not
// shown here.
1627 using OpConversionPattern::OpConversionPattern;
1631 ConversionPatternRewriter &rewriter)
const override {
1632 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1635 auto vType = cast<VectorType>(srcOp.getVector().getType());
1636 Type scalarType = vType.getElementType();
1637 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// Supported (lane count, element bit width) combinations.
1639 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1640 laneSizeElWidthPairSet.insert({64, 8});
1641 laneSizeElWidthPairSet.insert({32, 16});
1642 laneSizeElWidthPairSet.insert({32, 32});
1643 laneSizeElWidthPairSet.insert({16, 32});
1645 if (!isa<IntegerType>(scalarType) ||
1646 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
1649 int shiftIndex = laneSize / 2;
// 32 x i32: pre-add the two ExtOp halves, then reduce the sum.
1650 if (laneSize == 32 && elWidth == 32) {
1651 Location loc = srcOp.getLoc();
1655 rewriter.create<aievec::ExtOp>(loc, vecType, srcOp.getVector(), 0);
1657 rewriter.create<aievec::ExtOp>(loc, vecType, srcOp.getVector(), 1);
1658 auto addElemOp = rewriter.create<aievec::AddElemOp>(
1659 loc, lExtOp.getResult().getType(), lExtOp.getResult(),
1660 rExtOp.getResult());
1662 generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1663 rewriter, srcOp, shiftIndex, addElemOp.getResult());
// Other supported shapes: reduce the source vector directly.
1665 generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1666 rewriter, srcOp, shiftIndex, srcOp.getVector());
// Floating-point add-reduction lowering fragment for 16-lane vectors of
// 32-bit floats (CombiningKind::ADD only).
// Each loop iteration shifts the running vector by `id * elWidth / 8` bytes
// (id = shiftIndex, shiftIndex/2, ..., 1), adds the shifted and unshifted
// values with aievec::AddElemOp (via CastOps), and casts the sum back to the
// vector type. The scalar result is extracted with aievec::ExtElemOp using a
// constant-0 index.
// NOTE(review): `laneSize`, the trailing CastOp arguments and the rejecting
// branch bodies are not shown here.
1674 using OpConversionPattern::OpConversionPattern;
1678 ConversionPatternRewriter &rewriter)
const override {
1679 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1682 auto vType = cast<VectorType>(srcOp.getVector().getType());
1683 Type scalarType = vType.getElementType();
1684 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1687 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 32)
1690 int shiftIndex = laneSize / 2;
// The halving loop below relies on shiftIndex being a power of two.
1691 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
1692 "shiftIndex must be power of 2");
1694 Location loc = srcOp.getLoc();
1695 Value curValue = srcOp.getVector();
1696 aievec::CastOp curOp =
nullptr;
1698 for (
int id = shiftIndex;
id > 0;
id /= 2) {
// Shift amount in bytes: element count times element size in bytes.
1699 auto constOp = rewriter.create<arith::ConstantOp>(
1700 loc, rewriter.getI32IntegerAttr(
id * elWidth / 8));
1702 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
1703 loc, vType, curValue, curValue, constOp.getResult());
1705 auto lCastOp = rewriter.create<aievec::CastOp>(loc, vType, curValue,
1708 rewriter.create<aievec::CastOp>(loc, vType, shiftBytesOp.getResult(),
1710 auto elemOp = rewriter.create<aievec::AddElemOp>(
1711 loc, lCastOp.getResult().getType(), lCastOp.getResult(),
1712 rCastOp.getResult());
1713 curOp = rewriter.create<aievec::CastOp>(loc, vType, elemOp.getResult(),
1715 curValue = curOp.getResult();
// Extract the element at constant index 0 as the scalar reduction result.
1719 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1720 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
1721 zeroConstOp.getResult());
// Floating-point add-reduction lowering fragment for 16-lane vectors of
// 16-bit floats (CombiningKind::ADD only).
// The source vector is first converted with aievec::UPSOp to `accType`; each
// loop iteration shifts the running value by `id * accWidth / 8` bytes and
// adds it to itself with aievec::AddElemOp. The final sum goes through
// aievec::SRSOp (shift 0), is concatenated with itself via aievec::ConcatOp,
// and the scalar is extracted with aievec::ExtElemOp at constant index 0.
// NOTE(review): `laneSize`, `accType`, `accWidth`, `vecType` and the rejecting
// branch bodies are declared on lines not shown here.
1728 using OpConversionPattern::OpConversionPattern;
1732 ConversionPatternRewriter &rewriter)
const override {
1733 if (
auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1736 auto vType = cast<VectorType>(srcOp.getVector().getType());
1737 Type scalarType = vType.getElementType();
1738 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1741 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
1744 int shiftIndex = laneSize / 2;
// The halving loop below relies on shiftIndex being a power of two.
1745 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
1746 "shiftIndex must be power of 2");
1748 Value curValue = srcOp.getVector();
1749 Location loc = srcOp.getLoc();
1752 dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();
1755 rewriter.create<aievec::UPSOp>(loc, accType, srcOp.getVector());
1757 curValue = upsOp.getResult();
1760 aievec::AddElemOp curOp =
nullptr;
1762 for (
int id = shiftIndex;
id > 0;
id /= 2) {
// Shift amount in bytes, computed from the accumulator element width.
1763 auto constOp = rewriter.create<arith::ConstantOp>(
1764 loc, rewriter.getI32IntegerAttr(
id * accWidth / 8));
1765 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
1766 loc, accType, curValue, curValue, constOp,
true);
1767 curOp = rewriter.create<aievec::AddElemOp>(loc, accType, curValue,
1768 shiftBytesOp.getResult());
1769 curValue = curOp.getResult();
// Convert the accumulated sum back to the source vector type (shift 0).
1772 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1773 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1774 auto srsOp = rewriter.create<aievec::SRSOp>(loc, vType, curOp.getResult(),
1775 shiftParamOp.getResult());
// The same SRS result is used for both concat inputs.
1776 SmallVector<Value> concatSources = {srsOp.getResult(), srsOp.getResult()};
1778 rewriter.create<aievec::ConcatOp>(loc, vecType, concatSources);
1781 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1782 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, concatOp,
1783 zeroConstOp.getResult());
// Strided-slice extraction lowering fragment targeting the aie1 ops: for a
// rank-1 source whose element count is twice the slice size, an
// aievec::aie1::SelectOp is built with attributes from
// buildAttributeListForRotationSelectOp(rewriter, vType, offset), and the
// source op is replaced by an aievec::aie1::ExtOp with index attribute 0.
// NOTE(review): the checks on `stride` and the condition guarding the int8
// error are not shown here.
1792 using OpConversionPattern::OpConversionPattern;
1796 ConversionPatternRewriter &rewriter)
const override {
1797 auto vType = extractOp.getSourceVectorType();
1798 if (vType.getRank() != 1)
1801 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
1807 return extractOp.emitError()
1808 <<
"AIEv1 doesn't support select ops on int8 types";
1812 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
// Only slices of exactly half the source vector are handled.
1813 if (vType.getNumElements() != 2 * size)
1816 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
1817 auto selectOp = rewriter.create<aievec::aie1::SelectOp>(
1818 extractOp.getLoc(), vType, adaptor.getVector(),
1819 buildAttributeListForRotationSelectOp(rewriter, vType, offset));
1820 rewriter.replaceOpWithNewOp<aievec::aie1::ExtOp>(
1821 extractOp, extractOp.getType(), selectOp.getResult(),
1822 rewriter.getI8IntegerAttr(0));
// Strided-slice extraction lowering fragment using aievec::ExtOp and
// aievec::ShiftOp: for a rank-1 source whose element count is twice the slice
// size, the two halves are extracted with ExtOp (index attributes 0 and 1)
// and combined by a ShiftOp whose shift amount is the i32 constant
// `shiftBytes`.
// NOTE(review): the checks on `stride` and the computation of `shiftBytes`
// from `offset` are not shown here.
1831 using OpConversionPattern::OpConversionPattern;
1835 ConversionPatternRewriter &rewriter)
const override {
1836 auto vType = cast<VectorType>(adaptor.getVector().getType());
1837 if (vType.getRank() != 1)
1840 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
1846 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
// Only slices of exactly half the source vector are handled.
1847 if (vType.getNumElements() != 2 * size)
1850 auto shortVecType = cast<VectorType>(extractOp.getResult().getType());
1851 auto bottomHalf = rewriter
1852 .create<aievec::ExtOp>(
1853 extractOp.getLoc(), shortVecType,
1854 adaptor.getVector(), rewriter.getI8IntegerAttr(0))
1856 auto topHalf = rewriter
1857 .create<aievec::ExtOp>(extractOp.getLoc(), shortVecType,
1858 adaptor.getVector(),
1859 rewriter.getI8IntegerAttr(1))
1861 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
1863 auto shiftBytesConstOp = rewriter.create<arith::ConstantOp>(
1864 extractOp.getLoc(), rewriter.getIntegerType(32),
1865 rewriter.getI32IntegerAttr(shiftBytes));
1866 rewriter.replaceOpWithNewOp<aievec::ShiftOp>(
1867 extractOp, shortVecType, bottomHalf, topHalf, shiftBytesConstOp);
// UPD expansion fragment: creates a new aievec::UPDOp whose vector type has
// its innermost dimension doubled, then replaces the original UPD with an
// aievec::ExtOp (index attribute 0) of that wider result.
1876 using OpConversionPattern::OpConversionPattern;
1883 ConversionPatternRewriter &rewriter)
const override {
// The UPD's only user is already an ExtOp.
// NOTE(review): presumably the pattern bails out here to avoid re-expanding
// — the branch body is not shown; confirm.
1885 if (updOp->hasOneUse() && isa<aievec::ExtOp>(*updOp->getUsers().begin()))
1888 auto vecType = cast<VectorType>(updOp.getType());
1889 SmallVector<int64_t, 4> vecShape(vecType.getShape().begin(),
1890 vecType.getShape().end());
// Double the size of the last (innermost) dimension.
1891 vecShape[vecType.getRank() - 1] *= 2;
1892 auto longVecType = VectorType::get(vecShape, vecType.getElementType());
1893 auto newUpdOp = rewriter.create<aievec::UPDOp>(
1894 updOp.getLoc(), longVecType, adaptor.getSource(), adaptor.getIndices(),
1895 adaptor.getOffset(), adaptor.getIndex(), adaptor.getVector());
1896 rewriter.replaceOpWithNewOp<aievec::ExtOp>(
1897 updOp, vecType, newUpdOp.getResult(), rewriter.getI8IntegerAttr(0));
// Ext-of-UPD folding fragment: when `extOp` has index 0 and its source is an
// aievec::UPDOp with a single use, the ExtOp is replaced by a UPDOp that
// produces the ExtOp's result type directly from the same operands.
// NOTE(review): the bodies of the rejecting branches and the null check on
// `updOp` are not shown here.
1906 using OpConversionPattern::OpConversionPattern;
1912 ConversionPatternRewriter &rewriter)
const override {
1914 if (extOp.getIndex() != 0)
// NOTE(review): getDefiningOp() returns null for block arguments and
// llvm::dyn_cast does not accept a null pointer — confirm the source always
// has a defining op here, or use getDefiningOp<aievec::UPDOp>().
1917 auto updOp = dyn_cast<aievec::UPDOp>(extOp.getSource().getDefiningOp());
1922 if (!updOp->hasOneUse())
1925 rewriter.replaceOpWithNewOp<aievec::UPDOp>(
1926 extOp, extOp.getType(), updOp.getSource(), updOp.getIndices(),
1927 updOp.getOffset(), updOp.getIndex(), updOp.getVector());
// Exp lowering fragment based on a function call: when matchExpOpForLUT
// accepts the op, a declaration of "getExpBf16" (vector<16xbf16> ->
// vector<8xi64>) is obtained via getOrInsertFuncDecl, called on the converted
// operand, and its result is bit-cast to `accTypeNative` and passed through
// aievec::SRSOp (shift 0) to produce the source vector type.
// NOTE(review): `accTypeNative` and the rejecting branch body are not shown
// here.
1934 using OpConversionPattern::OpConversionPattern;
1938 ConversionPatternRewriter &rewriter)
const override {
1940 if (!matchExpOpForLUT(adaptor))
1943 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
1944 StringRef funcName =
"getExpBf16";
1945 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
// Argument and result types of the declared function.
1947 VectorType v16bf16Ty = mlir::VectorType::get({16}, rewriter.getBF16Type());
1948 VectorType v8i64Ty = mlir::VectorType::get({8}, rewriter.getI64Type());
1949 func::FuncOp fnOp = getOrInsertFuncDecl(
1950 rewriter, moduleOp, funcName, TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
1952 SmallVector<Value> expOperands = {adaptor.getOperand()};
1956 rewriter.create<func::CallOp>(expOp.getLoc(), fnOp, expOperands);
1957 auto resCastOp = rewriter.create<vector::BitCastOp>(
1958 expOp.getLoc(), accTypeNative, callOp.getResults());
1959 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1960 expOp.getLoc(), rewriter.getI32IntegerAttr(0));
1961 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1962 expOp, srcType, resCastOp.getResult(), shiftParamOp.getResult());
// Exp lowering fragment based on EmitC: when matchExpOpForLUT accepts the op,
// an emitc include of "lut_based_ops.h" is inserted at the start of the
// module's first block, the operand is cast to the opaque type "v16bfloat16",
// "getExpBf16" is invoked through emitc::CallOpaqueOp returning the opaque
// type "v16accfloat", and the result is cast to `accTypeNative` and passed
// through aievec::SRSOp (shift 0).
// NOTE(review): an include op is created on every successful match — confirm
// duplicates are acceptable or removed elsewhere. `accTypeNative` is declared
// on a line not shown here.
1969 using OpConversionPattern::OpConversionPattern;
1973 ConversionPatternRewriter &rewriter)
const override {
1974 if (!matchExpOpForLUT(adaptor))
1976 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
1977 StringRef includeName =
"lut_based_ops.h";
1978 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
1979 rewriter.setInsertionPointToStart(
1980 &moduleOp.getRegion().getBlocks().front());
1981 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
// Move the insertion point back to the op being rewritten.
1983 rewriter.setInsertionPoint(expOp);
1985 auto v16bf16OpaqueTy =
1986 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
1987 auto opaquedOperand =
1989 .create<UnrealizedConversionCastOp>(expOp.getLoc(), v16bf16OpaqueTy,
1990 adaptor.getOperand())
1992 SmallVector<Value> expOperands = {opaquedOperand};
1995 Type v16accf32OpaqueTy =
1996 emitc::OpaqueType::get(rewriter.getContext(),
"v16accfloat");
1997 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
1998 expOp.getLoc(), TypeRange{v16accf32OpaqueTy},
"getExpBf16",
nullptr,
1999 nullptr, expOperands);
2000 auto resCastOp = rewriter.create<UnrealizedConversionCastOp>(
2001 expOp.getLoc(), accTypeNative, callOp.getResults());
2002 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2003 expOp.getLoc(), rewriter.getI32IntegerAttr(0));
2004 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2005 expOp, srcType, resCastOp.getResult(0), shiftParamOp.getResult());
// Scalar division lowering fragment: for a single-use, 32-bit floating-point,
// non-vector `divOp` whose only user is a narrowing op, an emitc include of
// "lut_based_ops.h" is inserted at the start of the module's first block, the
// user (cast to arith::TruncFOp) is replaced by an UnrealizedConversionCastOp
// of an opaque call to "getInvBf16" on the converted rhs, and `divOp` is
// erased.
// NOTE(review): the constant the lhs is compared against and the rejecting
// branch bodies are not shown here.
2019 using OpConversionPattern::OpConversionPattern;
2023 ConversionPatternRewriter &rewriter)
const override {
2024 Type srcType = adaptor.getLhs().getType();
2025 if (!divOp->hasOneUse() || isa<VectorType>(srcType) ||
2026 !isa<FloatType>(srcType))
2029 if (!isNarrowingOp(*divOp->getUsers().begin()))
2032 auto fType = cast<FloatType>(srcType);
2033 if (fType.getWidth() != 32)
// NOTE(review): getDefiningOp() can be null (block argument) and
// llvm::dyn_cast does not accept a null pointer — confirm, or use
// getDefiningOp<arith::ConstantOp>().
2036 auto constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
2038 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
2042 StringRef includeName =
"lut_based_ops.h";
2043 auto moduleOp = divOp->getParentOfType<mlir::ModuleOp>();
2044 rewriter.setInsertionPointToStart(
2045 &moduleOp.getRegion().getBlocks().front());
2046 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
// NOTE(review): isNarrowingOp also accepts arith::TruncIOp and some
// aievec::SRSOp users, while this cast requires arith::TruncFOp — confirm
// those cases cannot reach this point.
2048 auto truncOp = cast<arith::TruncFOp>(*divOp->getUsers().begin());
2050 rewriter.setInsertionPoint(truncOp);
2052 emitc::OpaqueType::get(rewriter.getContext(),
"bfloat16");
2053 SmallVector<Value> invOperands = {adaptor.getRhs()};
2054 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2055 truncOp.getLoc(), bf16OpaqueTy,
"getInvBf16",
nullptr,
nullptr,
2057 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2058 truncOp, TypeRange{truncOp.getResult().getType()}, callOp.getResults());
2059 rewriter.eraseOp(divOp);
// Tanh lowering fragment: for 16-lane vectors of 16-bit floats, an emitc
// include of "lut_based_ops.h" is inserted at the start of the module's first
// block, the operand is cast to the opaque type "v16bfloat16", "getTanhBf16"
// is invoked through emitc::CallOpaqueOp, and the result is cast back to the
// original result type.
// NOTE(review): `laneSize` and the rejecting branch bodies are not shown
// here.
2067 using OpConversionPattern::OpConversionPattern;
2071 ConversionPatternRewriter &rewriter)
const override {
2072 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
2076 Type scalarType = srcType.getElementType();
2077 if (!isa<FloatType>(scalarType))
2081 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2082 if (elWidth != 16 || laneSize != 16)
2085 StringRef includeName =
"lut_based_ops.h";
2086 auto moduleOp = tanhOp->getParentOfType<mlir::ModuleOp>();
2087 rewriter.setInsertionPointToStart(
2088 &moduleOp.getRegion().getBlocks().front());
2089 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
// Move the insertion point back to the op being rewritten.
2091 rewriter.setInsertionPoint(tanhOp);
2092 Type v16bf16OpaqueTy =
2093 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2094 auto opaquedOperand =
2096 .create<UnrealizedConversionCastOp>(
2097 tanhOp.getLoc(), v16bf16OpaqueTy, adaptor.getOperand())
2099 SmallVector<Value> tanhOperands = {opaquedOperand};
2100 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2101 tanhOp.getLoc(), v16bf16OpaqueTy,
"getTanhBf16",
nullptr,
nullptr,
2103 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2104 tanhOp, TypeRange{tanhOp.getResult().getType()}, callOp.getResults());
// Sqrt lowering fragment: for 16- or 32-lane vectors of 16-bit floats, an
// emitc include of "vec_math.h" is inserted at the start of the module's
// first block, the operand is cast to an opaque "v16bfloat16" or
// "v32bfloat16" type, "getSqrtBf16" is invoked through emitc::CallOpaqueOp,
// and the result is cast back to the original result type.
// NOTE(review): `laneSize`, the condition choosing between the two opaque
// types, and the rejecting branch bodies are not shown here.
2113 using OpConversionPattern::OpConversionPattern;
2117 ConversionPatternRewriter &rewriter)
const override {
2118 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
2122 Type scalarType = srcType.getElementType();
2123 if (!isa<FloatType>(scalarType))
2127 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2128 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2131 StringRef includeName =
"vec_math.h";
2132 auto moduleOp = sqrtOp->getParentOfType<mlir::ModuleOp>();
2133 rewriter.setInsertionPointToStart(
2134 &moduleOp.getRegion().getBlocks().front());
2135 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
// Move the insertion point back to the op being rewritten.
2137 rewriter.setInsertionPoint(sqrtOp);
2138 Type vLNbf16OpaqueTy;
2141 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2144 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2145 auto opaquedOperand =
2147 .create<UnrealizedConversionCastOp>(
2148 sqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
2150 SmallVector<Value> sqrtOperands = {opaquedOperand};
2151 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2152 sqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getSqrtBf16",
nullptr,
2153 nullptr, sqrtOperands);
2154 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2155 sqrtOp, TypeRange{sqrtOp.getResult().getType()}, callOp.getResults());
// Rsqrt lowering fragment: for 16- or 32-lane vectors of 16-bit floats, an
// emitc include of "vec_math.h" is inserted at the start of the module's
// first block, the operand is cast to an opaque "v16bfloat16" or
// "v32bfloat16" type, "getRsqrtBf16" is invoked through emitc::CallOpaqueOp,
// and the result is cast back to the original result type.
// NOTE(review): `laneSize`, the condition choosing between the two opaque
// types, and the rejecting branch bodies are not shown here.
2164 using OpConversionPattern::OpConversionPattern;
2168 ConversionPatternRewriter &rewriter)
const override {
2169 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
2173 Type scalarType = srcType.getElementType();
2174 if (!isa<FloatType>(scalarType))
2178 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2179 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2182 StringRef includeName =
"vec_math.h";
2183 auto moduleOp = rsqrtOp->getParentOfType<mlir::ModuleOp>();
2184 rewriter.setInsertionPointToStart(
2185 &moduleOp.getRegion().getBlocks().front());
2186 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
// Move the insertion point back to the op being rewritten.
2188 rewriter.setInsertionPoint(rsqrtOp);
2189 Type vLNbf16OpaqueTy;
2192 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2195 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2196 auto opaquedOperand =
2198 .create<UnrealizedConversionCastOp>(
2199 rsqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
2201 SmallVector<Value> rsqrtOperands = {opaquedOperand};
2202 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2203 rsqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getRsqrtBf16",
nullptr,
2204 nullptr, rsqrtOperands);
2205 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2206 rsqrtOp, TypeRange{rsqrtOp.getResult().getType()}, callOp.getResults());
// Erf lowering fragment: for 16- or 32-lane vectors of 16-bit floats, an
// emitc include of "vec_math.h" is inserted at the start of the module's
// first block, the operand is cast to an opaque "v16bfloat16" or
// "v32bfloat16" type, "getErfBf16" is invoked through emitc::CallOpaqueOp,
// and the result is cast back to the original result type.
// NOTE(review): `laneSize`, the condition choosing between the two opaque
// types, and the rejecting branch bodies are not shown here.
2215 using OpConversionPattern::OpConversionPattern;
2219 ConversionPatternRewriter &rewriter)
const override {
2220 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
2224 Type scalarType = srcType.getElementType();
2225 if (!isa<FloatType>(scalarType))
2229 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2230 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2233 StringRef includeName =
"vec_math.h";
2234 auto moduleOp = erfOp->getParentOfType<mlir::ModuleOp>();
2235 rewriter.setInsertionPointToStart(
2236 &moduleOp.getRegion().getBlocks().front());
2237 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
// Move the insertion point back to the op being rewritten.
2239 rewriter.setInsertionPoint(erfOp);
2240 Type vLNbf16OpaqueTy;
2243 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2246 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2247 auto opaquedOperand =
2249 .create<UnrealizedConversionCastOp>(erfOp.getLoc(), vLNbf16OpaqueTy,
2250 adaptor.getOperand())
2252 SmallVector<Value> erfOperands = {opaquedOperand};
2253 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2254 erfOp.getLoc(), TypeRange{vLNbf16OpaqueTy},
"getErfBf16",
nullptr,
2255 nullptr, erfOperands);
2256 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2257 erfOp, TypeRange{erfOp.getResult().getType()}, callOp.getResults());
// Abs lowering fragment templated on the source op: an emitc include of
// "vec_math.h" is inserted at the start of the module's first block, the
// operand is cast to an opaque vector type whose name is assembled below,
// "getAbs" is invoked through emitc::CallOpaqueOp, and the result is cast
// back to the original result type.
// NOTE(review): `laneSize`, the shape/width checks and the conditions
// choosing between "bfloat16" and "float" are not shown here.
2265template <
typename SrcOpTy>
2272 ConversionPatternRewriter &rewriter)
const override {
2273 auto vecTy = dyn_cast<VectorType>(absOp.getOperand().getType());
2277 Type elemTy = vecTy.getElementType();
2280 unsigned elWidth = elemTy.getIntOrFloatBitWidth();
2282 StringRef includeName =
"vec_math.h";
2283 auto moduleOp = absOp->template getParentOfType<mlir::ModuleOp>();
2284 rewriter.setInsertionPointToStart(
2285 &moduleOp.getRegion().getBlocks().front());
2286 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
// Move the insertion point back to the op being rewritten.
2288 rewriter.setInsertionPoint(absOp);
// Opaque type name: "v" + lane count + element type suffix
// ("bfloat16" / "float" for float elements, "int<width>" otherwise).
2289 std::ostringstream typeName;
2290 typeName <<
"v" << laneSize;
2291 if (isa<FloatType>(elemTy)) {
2293 typeName <<
"bfloat16";
2295 typeName <<
"float";
2297 typeName <<
"int" << elWidth;
2299 emitc::OpaqueType::get(rewriter.getContext(), typeName.str());
2300 auto opaquedOperand =
2302 .create<UnrealizedConversionCastOp>(absOp.getLoc(), vecOpaqueTy,
2303 adaptor.getOperand())
2305 SmallVector<Value> absOperands = {opaquedOperand};
2306 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2307 absOp.getLoc(), TypeRange{vecOpaqueTy},
"getAbs",
nullptr,
nullptr,
2309 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2310 absOp, TypeRange{absOp.getResult().getType()}, callOp.getResults());
// Extension lowering fragment templated on the source op: the input is
// converted with aievec::UPSOp to `accType`; when the destination element
// width is 16 bits the result goes through aievec::SRSOp (shift 0), otherwise
// through aievec::CastOp.
// NOTE(review): `accType` and any legality checks are not shown here.
2319template <
typename SrcOpTy>
2326 ConversionPatternRewriter &rewriter)
const override {
2327 VectorType srcType = dyn_cast<VectorType>(extOp.getIn().getType());
2328 VectorType dstType = dyn_cast<VectorType>(extOp.getOut().getType());
// NOTE(review): UPS takes extOp.getIn() rather than an adaptor operand —
// confirm this is intended inside a conversion pattern.
2332 rewriter.create<aievec::UPSOp>(extOp.getLoc(), accType, extOp.getIn());
2334 if (dstType.getElementType().getIntOrFloatBitWidth() == 16) {
2335 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2336 extOp.getLoc(), rewriter.getI32IntegerAttr(0));
2337 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2338 extOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
2340 rewriter.replaceOpWithNewOp<aievec::CastOp>(
2341 extOp, dstType, upsOp.getResult(),
false);
// Truncation lowering fragment templated on the source op: a 16-bit source is
// first converted with aievec::UPSOp to `accType`; on the other visible path
// the input goes through aievec::CastOp to `accType`. In both cases the
// result is produced by aievec::SRSOp with a constant shift of 0.
// NOTE(review): the full `accType` selection expression and any legality
// checks are not shown here.
2350template <
typename SrcOpTy>
2357 ConversionPatternRewriter &rewriter)
const override {
2358 VectorType srcType = dyn_cast<VectorType>(truncOp.getIn().getType());
2359 VectorType dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
2360 Type scalarType = srcType.getElementType();
2361 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// The accumulator type depends on whether the source is a 32-bit integer.
2364 auto accType = isa<IntegerType>(scalarType) && (elWidth == 32)
2368 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2369 truncOp.getLoc(), rewriter.getI32IntegerAttr(0));
2370 if (elWidth == 16) {
2371 auto upsOp = rewriter.create<aievec::UPSOp>(truncOp.getLoc(), accType,
2373 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2374 truncOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
2376 auto castOp = rewriter.create<aievec::CastOp>(truncOp.getLoc(), accType,
2377 truncOp.getIn(),
true);
2378 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2379 truncOp, dstType, castOp.getResult(), shiftParamOp.getResult());
// Looks through the chain
//   value -> UnrealizedConversionCastOp -> emitc::CallOpaqueOp(funcName)
//         -> UnrealizedConversionCastOp(op)
// and returns the first input of the innermost cast, i.e. the original value
// passed to the opaque call.
// `op` is expected to be the outer UnrealizedConversionCastOp; `funcName` is
// the callee name the emitc call must match.
// NOTE(review): the values returned on the failing paths are not shown here.
2394static std::optional<Value>
2395getUnOpaquedOperandOfEmitCOpaqueCallOp(Operation *op, StringRef funcName) {
2396 auto uccOp = dyn_cast<UnrealizedConversionCastOp>(op);
2400 auto inVal = uccOp.getInputs()[0];
2401 if (!isa<emitc::OpaqueType>(inVal.getType()))
// NOTE(review): getDefiningOp<T>() yields a null op when the value is not
// produced by T, and `callOp` is dereferenced on the next line without a
// visible null check — confirm.
2404 auto callOp = inVal.getDefiningOp<emitc::CallOpaqueOp>();
2405 if (callOp.getCallee() != funcName)
2408 auto callOperandsUccOp =
2409 callOp.getOperands()[0].getDefiningOp<UnrealizedConversionCastOp>();
2410 if (!callOperandsUccOp)
2413 return callOperandsUccOp.getInputs()[0];
// Checks whether `divfOp` is the final division of a chain of the form
//   1.0 / (1.0 + exp(neg(x)))
// with the addition's operands in either order, and with the exp either
// still a math::ExpOp or already lowered to an opaque "getExpBf16" call.
// On success `negOp` is set to the arith::NegFOp feeding the exp and the
// function returns true; it returns false when no NegFOp is found at the end
// of the chain.
// NOTE(review): several early-exit branches and the conditions separating the
// arith::AddFOp path from the SRS/AddElem/UPS path are not shown here.
// NOTE(review): multiple dyn_cast<...>(x.getDefiningOp()) calls below receive
// a pointer that is null for block arguments; llvm::dyn_cast does not accept
// null — confirm the inputs always have defining ops.
2429template <
typename DivFOpTy>
2430static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) {
// The dividend must be a dense splat constant equal to 1.0.
2431 auto constOp = dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
2435 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
2439 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
// Defining ops of the two addends of the divisor.
2442 Operation *addLvalOp;
2443 Operation *addRvalOp;
// The divisor is either an arith::AddFOp or an aievec::SRSOp of an
// aievec::AddElemOp whose operands come from aievec::UPSOps.
2449 auto addOp = dyn_cast<arith::AddFOp>(divfOp.getRhs().getDefiningOp());
2451 auto srsOp = dyn_cast<aievec::SRSOp>(divfOp.getRhs().getDefiningOp());
2456 dyn_cast<aievec::AddElemOp>(srsOp.getSource().getDefiningOp());
2460 auto lUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getLhs().getDefiningOp());
2461 auto rUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getRhs().getDefiningOp());
2462 if (!lUpsOp || !rUpsOp)
2465 addLvalOp = lUpsOp.getSource().getDefiningOp();
2466 addRvalOp = rUpsOp.getSource().getDefiningOp();
// Of the two addends, pick the one that is not the constant.
2469 auto addDefOp = isa<arith::ConstantOp>(addLvalOp)
2470 ? dyn_cast<aievec::SRSOp>(addRvalOp)
2471 : dyn_cast<aievec::SRSOp>(addLvalOp);
2473 addLvalOp = isa<arith::ConstantOp>(addLvalOp)
2474 ? dyn_cast<math::ExpOp>(addRvalOp)
2475 : dyn_cast<math::ExpOp>(addLvalOp);
2477 addLvalOp = addDefOp.getSource().getDefiningOp();
2479 addRvalOp = isa<arith::ConstantOp>(addLvalOp)
2480 ? lUpsOp.getSource().getDefiningOp()
2481 : rUpsOp.getSource().getDefiningOp();
// arith::AddFOp path: take the addends' defining ops directly.
2483 addLvalOp = addOp.getLhs().getDefiningOp();
2484 addRvalOp = addOp.getRhs().getDefiningOp();
2487 if (!addLvalOp || !addRvalOp)
// Recover the exp operand from either a math::ExpOp or an already lowered
// "getExpBf16" opaque call.
2490 auto addLvalExpOp = dyn_cast<math::ExpOp>(addLvalOp);
2491 auto addRvalExpOp = dyn_cast<math::ExpOp>(addRvalOp);
2492 auto addLvalExpOpIn =
2493 getUnOpaquedOperandOfEmitCOpaqueCallOp(addLvalOp,
"getExpBf16")
2495 auto addRvalExpOpIn =
2496 getUnOpaquedOperandOfEmitCOpaqueCallOp(addRvalOp,
"getExpBf16")
2498 if (!addLvalExpOpIn && addLvalExpOp)
2499 addLvalExpOpIn = addLvalExpOp.getOperand();
2500 if (!addRvalExpOpIn && addRvalExpOp)
2501 addRvalExpOpIn = addRvalExpOp.getOperand();
// Exactly the shape "exp + constant" or "constant + exp" is required.
2503 if (!((addLvalExpOpIn && isa<arith::ConstantOp>(addRvalOp)) ||
2504 (addRvalExpOpIn && isa<arith::ConstantOp>(addLvalOp))))
2507 constOp = isa<arith::ConstantOp>(addLvalOp)
2508 ? cast<arith::ConstantOp>(addLvalOp)
2509 : cast<arith::ConstantOp>(addRvalOp);
// The added constant must also be a splat of 1.0.
2511 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
2514 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
2517 auto expOperand = addLvalExpOpIn ? addLvalExpOpIn : addRvalExpOpIn;
// The exp argument must be produced by a negation.
2519 negOp = expOperand.getDefiningOp<arith::NegFOp>();
2521 return negOp !=
nullptr;
// Sigmoid lowering fragment: for 16- or 32-lane vectors of 16-bit floats
// whose division is recognized by hasSigmoidComputationChain, an emitc
// include of "vec_math.h" is inserted at the start of the module's first
// block, a value is cast to an opaque "v16bfloat16" or "v32bfloat16" type,
// "getSigmoidBf16" is invoked through emitc::CallOpaqueOp, and `divfOp` is
// replaced by a cast of the call result to the lhs operand's type.
// NOTE(review): `laneSize`, `vecOpaqueTy`, the value passed to the opaque
// cast and the rejecting branch bodies are not shown here.
2538 using OpConversionPattern::OpConversionPattern;
2542 ConversionPatternRewriter &rewriter)
const override {
2543 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
2547 Type scalarType = srcType.getElementType();
2548 if (!isa<FloatType>(scalarType))
2552 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2553 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
// Filled in by hasSigmoidComputationChain when the chain matches.
2556 arith::NegFOp negOp =
nullptr;
2557 if (!hasSigmoidComputationChain(adaptor, negOp))
2560 StringRef includeName =
"vec_math.h";
2561 auto moduleOp = divfOp->getParentOfType<mlir::ModuleOp>();
2562 rewriter.setInsertionPointToStart(
2563 &moduleOp.getRegion().getBlocks().front());
2564 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
// Move the insertion point back to the op being rewritten.
2566 rewriter.setInsertionPoint(divfOp);
2570 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2573 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2574 auto opaquedOperand =
2576 .create<UnrealizedConversionCastOp>(divfOp.getLoc(), vecOpaqueTy,
2579 SmallVector<Value> sigmoidOperands = {opaquedOperand};
2580 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2581 divfOp.getLoc(), TypeRange{vecOpaqueTy},
"getSigmoidBf16",
nullptr,
2582 nullptr, sigmoidOperands);
2583 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2584 divfOp, TypeRange{adaptor.getLhs().getType()}, callOp.getResults());
2592 using OpConversionPattern::OpConversionPattern;
2596 ConversionPatternRewriter &rewriter)
const override {
2597 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
2601 Type scalarType = srcType.getElementType();
2602 if (!isa<FloatType>(scalarType))
2606 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2607 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2610 StringRef includeName =
"vec_math.h";
2611 auto moduleOp = ceilOp->getParentOfType<mlir::ModuleOp>();
2612 rewriter.setInsertionPointToStart(
2613 &moduleOp.getRegion().getBlocks().front());
2614 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2616 rewriter.setInsertionPoint(ceilOp);
2620 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2623 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2624 auto opaquedOperand =
2626 .create<UnrealizedConversionCastOp>(ceilOp.getLoc(), vecOpaqueTy,
2627 adaptor.getOperand())
2629 SmallVector<Value> ceilOperands = {opaquedOperand};
2630 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2631 ceilOp.getLoc(), TypeRange{vecOpaqueTy},
"getCeilBf16",
nullptr,
2632 nullptr, ceilOperands);
2633 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2634 ceilOp, TypeRange{ceilOp.getResult().getType()}, callOp.getResults());
2642 using OpConversionPattern::OpConversionPattern;
2646 ConversionPatternRewriter &rewriter)
const override {
2647 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
2651 Type scalarType = srcType.getElementType();
2652 if (!isa<FloatType>(scalarType))
2656 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2657 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2660 StringRef includeName =
"vec_math.h";
2661 auto moduleOp = floorOp->getParentOfType<mlir::ModuleOp>();
2662 rewriter.setInsertionPointToStart(
2663 &moduleOp.getRegion().getBlocks().front());
2664 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName,
false);
2666 rewriter.setInsertionPoint(floorOp);
2670 emitc::OpaqueType::get(rewriter.getContext(),
"v16bfloat16");
2673 emitc::OpaqueType::get(rewriter.getContext(),
"v32bfloat16");
2674 auto opaquedOperand =
2676 .create<UnrealizedConversionCastOp>(floorOp.getLoc(), vecOpaqueTy,
2677 adaptor.getOperand())
2679 SmallVector<Value> floorOperands = {opaquedOperand};
2680 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2681 floorOp.getLoc(), TypeRange{vecOpaqueTy},
"getFloorBf16",
nullptr,
2682 nullptr, floorOperands);
2683 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2684 floorOp, TypeRange{floorOp.getResult().getType()}, callOp.getResults());
2693 using OpConversionPattern::OpConversionPattern;
2697 ConversionPatternRewriter &rewriter)
const override {
2698 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
2702 Type scalarType = srcType.getElementType();
2703 if (!isa<FloatType>(scalarType))
2709 Location loc = negOp.getLoc();
2712 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2713 if (elWidth == 16) {
2715 rewriter.create<aievec::UPSOp>(loc, accType, adaptor.getOperand());
2717 rewriter.create<aievec::NegOp>(loc, accType, upsOp.getResult());
2718 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2719 negOp.getLoc(), rewriter.getI32IntegerAttr(0));
2720 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2721 negOp, srcType, aieNegOp.getResult(), shiftParamOp.getResult());
2723 auto castOp = rewriter.create<aievec::CastOp>(
2724 loc, accType, adaptor.getOperand(),
true);
2726 rewriter.create<aievec::NegOp>(loc, accType, castOp.getResult());
2727 rewriter.replaceOpWithNewOp<aievec::CastOp>(
2728 negOp, srcType, aieNegOp.getResult(),
false);
2737static bool hasConstNegOneValue(arith::ConstantOp constOp,
unsigned elWidth) {
2741 auto cstDense = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
2746 return cstDense.getSplatValue<int32_t>() == -1;
2748 return cstDense.getSplatValue<int16_t>() == -1;
2750 return cstDense.getSplatValue<int8_t>() == -1;
// Conversion of a vector arith.xori into aievec::BnegOp (when one operand is
// an all-ones constant) or aievec::BxorOp (every other accepted case).
2757 using OpConversionPattern::OpConversionPattern;
2761 ConversionPatternRewriter &rewriter)
const override {
// The checks below are made on the type of the original lhs operand.
2762 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
2766 Type scalarType = srcType.getElementType();
// Only integer element types are accepted.
2767 if (!isa<IntegerType>(scalarType))
2771 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// Only vectors with lanes * element width equal to 512 bits are accepted.
2772 if (laneSize * elWidth != 512)
// Value::getDefiningOp<OpTy>() yields a null op both for block arguments and
// for producers of another kind, whereas dyn_cast<> must never be handed the
// null Operation* that getDefiningOp() returns for a block argument.
2776 xorOp.getLhs().getDefiningOp<arith::ConstantOp>();
2778 xorOp.getRhs().getDefiningOp<arith::ConstantOp>();
// x ^ -1 is a bitwise negation of x: when either operand is an all-ones
// constant, emit BnegOp on the other (converted) operand.
2782 if ((lhsConstOp && hasConstNegOneValue(lhsConstOp, elWidth)) ||
2783 (rhsConstOp && hasConstNegOneValue(rhsConstOp, elWidth))) {
// The null guard is repeated here so that picking the surviving operand never
// queries a missing lhs constant.
2784 Value val = (lhsConstOp && hasConstNegOneValue(lhsConstOp, elWidth)) ? adaptor.getRhs()
2786 rewriter.replaceOpWithNewOp<aievec::BnegOp>(xorOp, srcType, val);
// General case: a plain bitwise xor of both converted operands.
2788 rewriter.replaceOpWithNewOp<aievec::BxorOp>(
2789 xorOp, srcType, adaptor.getLhs(), adaptor.getRhs());
2795template <
typename SrcOpTy,
typename DstOpTy>
2802 ConversionPatternRewriter &rewriter)
const override {
2803 VectorType srcType = dyn_cast<VectorType>(srcOp.getLhs().getType());
2807 Type scalarType = srcType.getElementType();
2808 if (!isa<IntegerType>(scalarType))
2812 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2813 if (laneSize * elWidth != 512)
2816 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getResult().getType(),
2817 adaptor.getLhs(), adaptor.getRhs());
// Conversion of a vector arith.shrsi whose shift amount comes from an
// aievec::BroadcastOp into an UPSOp/SRSOp sequence.
2833 using OpConversionPattern::OpConversionPattern;
2837 ConversionPatternRewriter &rewriter)
const override {
2838 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
2842 Type scalarType = srcType.getElementType();
2844 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
// Only vectors with lanes * element width equal to 512 bits are accepted.
2845 if (laneSize * elWidth != 512)
// Value::getDefiningOp<OpTy>() yields a null op when the shift amount is a
// block argument or is produced by some other op; dyn_cast<> must never be
// handed the null Operation* that getDefiningOp() returns for a block
// argument.
2849 adaptor.getRhs().getDefiningOp<aievec::BroadcastOp>();
// Recover the scalar shift amount: extract, from the broadcast source, the
// element at the broadcast index.
2853 auto constOp = rewriter.create<arith::ConstantOp>(
2854 bcastOp.getLoc(), rewriter.getI32IntegerAttr(bcastOp.getIdx()));
2855 auto extElemOp = rewriter.create<aievec::ExtElemOp>(
2856 bcastOp.getLoc(), scalarType, bcastOp, constOp.getResult());
2857 Location loc = rsOp.getLoc();
// Split path: halves 0 and 1 of the lhs are extracted, each one is widened
// with UPSOp and shifted with SRSOp by the extracted scalar, and the two
// results are concatenated back into a vector of the source type.
2864 rewriter.create<aievec::ExtOp>(loc, halfSrcType, adaptor.getLhs(), 0);
2866 rewriter.create<aievec::ExtOp>(loc, halfSrcType, adaptor.getLhs(), 1);
2869 rewriter.create<aievec::UPSOp>(loc, accType, rsOpLow.getResult());
2870 auto srsOpLow = rewriter.create<aievec::SRSOp>(
2871 loc, halfSrcType, upsOpLow.getResult(), extElemOp.getResult());
2873 rewriter.create<aievec::UPSOp>(loc, accType, rsOpHigh.getResult());
2874 auto srsOpHigh = rewriter.create<aievec::SRSOp>(
2875 loc, halfSrcType, upsOpHigh.getResult(), extElemOp.getResult());
2876 SmallVector<Value> inputSources = {srsOpLow.getResult(),
2877 srsOpHigh.getResult()};
2878 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(rsOp, srcType,
// Single-step path: UPSOp on the whole lhs followed by one SRSOp.
2883 rewriter.create<aievec::UPSOp>(loc, accType, adaptor.getLhs());
2884 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2885 rsOp, srcType, upsOp.getResult(), extElemOp.getResult());
2895 using OpConversionPattern::OpConversionPattern;
2902 auto vecTy = dyn_cast<VectorType>(v.getType());
2905 auto vecShape = vecTy.getShape();
2907 size_t numLeadUnitDims = 0;
2908 while (numLeadUnitDims < vecShape.size() && vecShape[numLeadUnitDims] == 1)
2911 if (!numLeadUnitDims)
2914 SmallVector<int64_t> newShape(vecShape.begin() + numLeadUnitDims,
2916 auto newVecTy = VectorType::get(newShape, vecTy.getElementType());
2917 return b.create<vector::ShapeCastOp>(v.getLoc(), newVecTy, v).getResult();
2922 ConversionPatternRewriter &rewriter)
const override {
2926 bool bReshapedAcc = (acc != adaptor.getAcc());
2929 acc = rewriter.create<aievec::CastOp>(contractOp.getLoc(), acc.getType(),
2932 auto matmulOp = rewriter.create<aievec::MatMulOp>(
2933 contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
2937 ScopedDiagnosticHandler diagHandler(
2938 contractOp.getContext(), [](Diagnostic &) { return success(); });
2939 if (failed(matmulOp.verifyInvariants())) {
2940 rewriter.eraseOp(matmulOp);
2944 lhs = adaptor.getLhs();
2945 auto wideLhsValue = getSourceOfWideningOp(lhs).value_or(
nullptr);
2949 rhs = adaptor.getRhs();
2950 auto wideRhsValue = getSourceOfWideningOp(rhs).value_or(
nullptr);
2954 matmulOp = rewriter.create<aievec::MatMulOp>(
2955 contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
2956 if (failed(matmulOp.verifyInvariants()))
2961 Value result = matmulOp.getResult();
2963 result = rewriter.create<aievec::CastOp>(contractOp.getLoc(),
2964 acc.getType(), matmulOp,
false);
2966 result = rewriter.create<vector::ShapeCastOp>(
2967 contractOp.getLoc(), adaptor.getAcc().getType(), result);
2968 rewriter.replaceOp(contractOp, result);
2979 using OpConversionPattern::OpConversionPattern;
2982 ConversionPatternRewriter &rewriter)
const override {
2983 auto resTy = transpOp.getResultVectorType();
2984 auto resShape = resTy.getShape();
2985 auto elemTyBitWidth = resTy.getElementTypeBitWidth();
2986 auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
2987 elemTyBitWidth, std::multiplies<>());
2988 if (vBitWidth != 512)
2991 if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
2995 for (int64_t i = 0; i < static_cast<int64_t>(resShape.size() - 2); ++i)
2996 if (resShape[i] != 1)
3000 ArrayRef<int64_t> perm = transpOp.getPermutation();
3001 for (int64_t i = 0; i < static_cast<int64_t>(perm.size() - 2); ++i)
3004 if (perm.back() !=
static_cast<int64_t
>(perm.size() - 2))
3007 auto shuffleMode = aievec::ShuffleMode::T32_4X4;
3008 if (elemTyBitWidth == 8) {
3009 switch (resShape.back()) {
3011 shuffleMode = aievec::ShuffleMode::T8_4X16;
3014 shuffleMode = aievec::ShuffleMode::T8_8X8;
3017 shuffleMode = aievec::ShuffleMode::T8_16X4;
3022 }
else if (elemTyBitWidth == 16) {
3023 switch (resShape.back()) {
3025 shuffleMode = aievec::ShuffleMode::T16_2X16;
3028 shuffleMode = aievec::ShuffleMode::T16_4X8;
3031 shuffleMode = aievec::ShuffleMode::T16_8X4;
3034 shuffleMode = aievec::ShuffleMode::T16_16X2;
3039 }
else if (resShape.back() != 4)
3043 VectorType::get({512 / elemTyBitWidth}, resTy.getElementType());
3044 auto loc = transpOp.getLoc();
3045 auto flatInput = rewriter.create<vector::ShapeCastOp>(loc, flatVecTy,
3046 adaptor.getVector());
3047 auto shuffOp = rewriter.create<aievec::ShuffleOp>(loc, flatVecTy, flatInput,
3048 nullptr, shuffleMode);
3049 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resTy, shuffOp);
3059static void populateAIEVecCommonConversionPatterns(RewritePatternSet &patterns,
3069static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns,
3086static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
3090 if (backend == TargetBackend::CPP) {
3093 >(patterns.getContext(), 128, 1024, 256, 1024);
3100 >(patterns.getContext());
3101 }
else if (backend == TargetBackend::LLVMIR){
3104 >(patterns.getContext());
3142 >(patterns.getContext());
3144 >(patterns.getContext(), backend == TargetBackend::CPP);
3154static bool isInSigmoidOperationChain(math::ExpOp expOp) {
3155 if (!expOp.getOperand().getDefiningOp<arith::NegFOp>())
3158 arith::AddFOp addOp =
nullptr;
3159 for (Operation *user : expOp->getUsers()) {
3160 addOp = dyn_cast<arith::AddFOp>(user);
3168 auto *addLvalOp = addOp.getLhs().getDefiningOp();
3169 auto *addRvalOp = addOp.getRhs().getDefiningOp();
3170 if (!((isa<math::ExpOp>(addLvalOp) && isa<arith::ConstantOp>(addRvalOp)) ||
3171 (isa<math::ExpOp>(addRvalOp) && isa<arith::ConstantOp>(addLvalOp))))
3174 auto constOp = isa<arith::ConstantOp>(addLvalOp)
3175 ? cast<arith::ConstantOp>(addLvalOp)
3176 : cast<arith::ConstantOp>(addRvalOp);
3178 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3182 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
3185 arith::DivFOp divOp =
nullptr;
3186 for (Operation *user : addOp->getUsers()) {
3187 divOp = dyn_cast<arith::DivFOp>(user);
3195 constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
3198 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3201 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
3207static void configureAIEVecCommonLegalizations(ConversionTarget &target,
3210 .addLegalDialect<xilinx::aievec::aie1::AIEVecAIE1Dialect,
3211 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
3212 ub::UBDialect, emitc::EmitCDialect, func::FuncDialect>();
3213 if (backend == TargetBackend::CPP) {
3214 target.addIllegalOp<vector::TransferReadOp>();
3216 target.addIllegalOp<vector::ExtractStridedSliceOp>();
3217 target.addLegalOp<vector::BitCastOp>();
3219 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
3220 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
3221 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
3222 if (!srcType || !dstType)
3225 Type srcScalarType = srcType.getElementType();
3226 Type dstScalarType = dstType.getElementType();
3227 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
3232 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3233 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3234 return srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 ||
3238 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
3239 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
3240 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
3241 if (!srcType || !dstType)
3244 Type srcScalarType = srcType.getElementType();
3245 Type dstScalarType = dstType.getElementType();
3246 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
3251 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3252 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3253 return srcLaneSize != 32 || (dstElWidth <= srcElWidth) ||
3254 (dstLaneSize != srcLaneSize);
3257 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
3258 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
3259 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
3260 if (!srcType || !dstType)
3263 Type srcScalarType = srcType.getElementType();
3264 Type dstScalarType = dstType.getElementType();
3265 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
3270 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3271 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3272 return srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 ||
3276 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
3277 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
3278 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
3279 if (!srcType || !dstType)
3282 Type srcScalarType = srcType.getElementType();
3283 Type dstScalarType = dstType.getElementType();
3284 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
3289 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3290 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3292 return srcLaneSize != 32 || (dstElWidth >= srcElWidth) ||
3293 (dstLaneSize != srcLaneSize);
3296 target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
3297 auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
3301 Type scalarType = srcType.getElementType();
3302 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3304 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
3306 if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp))
3312 target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
3313 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
3317 Type scalarType = srcType.getElementType();
3318 if (!isa<FloatType>(scalarType))
3322 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3323 return elWidth != 16 || laneSize != 16;
3326 target.addDynamicallyLegalOp<math::SqrtOp>([](math::SqrtOp sqrtOp) {
3327 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
3331 Type scalarType = srcType.getElementType();
3332 if (!isa<FloatType>(scalarType))
3336 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3337 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3340 target.addDynamicallyLegalOp<math::RsqrtOp>([](math::RsqrtOp rsqrtOp) {
3341 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
3342 Type scalarType = srcType.getElementType();
3343 if (!srcType || !isa<FloatType>(scalarType))
3347 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3348 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3351 target.addDynamicallyLegalOp<math::ErfOp>([](math::ErfOp erfOp) {
3352 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
3356 Type scalarType = srcType.getElementType();
3357 if (!isa<FloatType>(scalarType))
3361 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3362 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3365 target.addDynamicallyLegalOp<math::AbsFOp>([](math::AbsFOp absfOp) {
3366 auto srcType = dyn_cast<VectorType>(absfOp.getOperand().getType());
3370 Type scalarType = srcType.getElementType();
3372 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3373 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
3376 target.addDynamicallyLegalOp<math::AbsIOp>([](math::AbsIOp absiOp) {
3377 auto srcType = dyn_cast<VectorType>(absiOp.getOperand().getType());
3381 Type scalarType = srcType.getElementType();
3383 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3384 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
3387 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divfOp) {
3388 if (
auto srcType = dyn_cast<VectorType>(divfOp.getLhs().getType());
3390 Type scalarType = divfOp.getLhs().getType();
3391 if (!divfOp->hasOneUse() || !isa<FloatType>(scalarType))
3393 if (!isNarrowingOp(*divfOp->getUsers().begin()))
3396 auto fType = cast<FloatType>(scalarType);
3397 if (fType.getWidth() != 32)
3401 dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
3403 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
3407 Type scalarType = srcType.getElementType();
3408 if (!isa<FloatType>(scalarType))
3412 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3414 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3417 arith::NegFOp negOp =
nullptr;
3418 if (!hasSigmoidComputationChain(divfOp, negOp))
3425 target.addDynamicallyLegalOp<math::CeilOp>([](math::CeilOp ceilOp) {
3426 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
3429 Type scalarType = srcType.getElementType();
3430 if (!isa<FloatType>(scalarType))
3434 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3435 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3438 target.addDynamicallyLegalOp<math::FloorOp>([](math::FloorOp floorOp) {
3439 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
3442 Type scalarType = srcType.getElementType();
3443 if (!isa<FloatType>(scalarType))
3447 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3448 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3451 target.addDynamicallyLegalOp<arith::NegFOp>([](arith::NegFOp negOp) {
3452 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
3455 if (Type scalarType = srcType.getElementType(); !isa<FloatType>(scalarType))
3459 return laneSize != 16;
3462 target.addDynamicallyLegalOp<arith::XOrIOp>([](arith::XOrIOp xorOp) {
3463 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
3466 Type scalarType = srcType.getElementType();
3467 if (!isa<IntegerType>(scalarType))
3471 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3473 return laneSize * elWidth != 512;
3476 target.addDynamicallyLegalOp<arith::OrIOp>([](arith::OrIOp orOp) {
3477 auto srcType = dyn_cast<VectorType>(orOp.getLhs().getType());
3480 Type scalarType = srcType.getElementType();
3481 if (!isa<IntegerType>(scalarType))
3485 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3487 return laneSize * elWidth != 512;
3490 target.addDynamicallyLegalOp<arith::ShRSIOp>([](arith::ShRSIOp rsOp) {
3491 auto srcType = dyn_cast<VectorType>(rsOp.getLhs().getType());
3494 Type scalarType = srcType.getElementType();
3497 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3499 return laneSize * elWidth != 512;
3502 target.addDynamicallyLegalOp<arith::AndIOp>([](arith::AndIOp andOp) {
3503 auto srcType = dyn_cast<VectorType>(andOp.getLhs().getType());
3506 Type scalarType = srcType.getElementType();
3507 if (!isa<IntegerType>(scalarType))
3511 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3513 return laneSize * elWidth != 512;
3516 if (backend == TargetBackend::CPP) {
3517 target.addDynamicallyLegalOp<arith::AddIOp>(
3518 [](arith::AddIOp op) {
return !isa<VectorType>(op.getType()); });
3520 target.addDynamicallyLegalOp<arith::AddFOp>(
3521 [](arith::AddFOp op) {
return !isa<VectorType>(op.getType()); });
3522 target.addDynamicallyLegalOp<arith::SubIOp>(
3523 [](arith::SubIOp op) {
return !isa<VectorType>(op.getType()); });
3524 target.addDynamicallyLegalOp<arith::SubFOp>(
3525 [](arith::SubFOp op) {
return !isa<VectorType>(op.getType()); });
3528static void configureAIEVecV1Legalizations(ConversionTarget &target,
3530 target.addDynamicallyLegalOp<arith::MulIOp>(
3531 [](arith::MulIOp op) {
return !isa<VectorType>(op.getType()); });
3532 target.addDynamicallyLegalOp<arith::MulFOp>(
3533 [](arith::MulFOp op) {
return !isa<VectorType>(op.getType()); });
3534 target.addDynamicallyLegalOp<aievec::aie1::FMAOp>(
3535 [](xilinx::aievec::aie1::FMAOp op) {
3536 auto *lhsDefOp = op.getLhs().getDefiningOp();
3537 aievec::ConcatOp concatOp =
nullptr;
3539 concatOp = dyn_cast<aievec::ConcatOp>(op.getLhs().getDefiningOp());
3543 vector::SplatOp srcSplat =
nullptr;
3544 if (
auto *lhsOp = concatOp.getSources()[0].getDefiningOp())
3545 srcSplat = dyn_cast<vector::SplatOp>(lhsOp);
3547 auto *rhsOp = op.getRhs().getDefiningOp();
3550 srcSplat = dyn_cast<vector::SplatOp>(rhsOp);
3554 if (
auto *srcOp = srcSplat.getInput().getDefiningOp())
3555 return !isa<vector::ExtractOp>(srcOp);
3560 target.addDynamicallyLegalOp<aievec::aie1::AddOp>([](aievec::aie1::AddOp op) {
3561 auto lSrsOp = op.getLhs().getDefiningOp<aievec::SRSOp>();
3562 auto rSrsOp = op.getRhs().getDefiningOp<aievec::SRSOp>();
3564 !lSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>()) &&
3566 !rSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>());
3568 target.addLegalDialect<memref::MemRefDialect>();
3571static void configureAIEVecV2Legalizations(ConversionTarget &target,
3573 target.addLegalOp<UnrealizedConversionCastOp>();
3574 target.addLegalOp<vector::ShapeCastOp>();
3577 llvm::SmallSet<std::pair<unsigned, unsigned>, 16> laneSizeElWidthPairSet;
3578 laneSizeElWidthPairSet.insert({64, 8});
3579 laneSizeElWidthPairSet.insert({32, 16});
3580 laneSizeElWidthPairSet.insert({16, 32});
3581 laneSizeElWidthPairSet.insert({32, 32});
3584 llvm::SmallSet<unsigned, 16> elWidthSet;
3585 elWidthSet.insert(8);
3586 elWidthSet.insert(16);
3587 elWidthSet.insert(32);
3589 if (backend == TargetBackend::CPP) {
3590 target.addDynamicallyLegalOp<arith::AddIOp>([=](arith::AddIOp op) {
3591 auto resultType = dyn_cast<VectorType>(op.getType());
3595 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3598 return !laneSizeElWidthPairSet.count(
3599 std::make_pair(laneSize, resultElWidth));
3603 target.addDynamicallyLegalOp<arith::SubIOp>([=](arith::SubIOp op) {
3604 auto resultType = dyn_cast<VectorType>(op.getType());
3607 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3610 return !laneSizeElWidthPairSet.count(
3611 std::make_pair(laneSize, resultElWidth));
3614 target.addDynamicallyLegalOp<arith::AddFOp>([](arith::AddFOp op) {
3615 auto resultType = dyn_cast<VectorType>(op.getType());
3620 return laneSize != 16;
3623 target.addDynamicallyLegalOp<arith::SubFOp>([](arith::SubFOp op) {
3624 auto resultType = dyn_cast<VectorType>(op.getType());
3629 return laneSize != 16;
3632 target.addDynamicallyLegalOp<arith::MulIOp>([](arith::MulIOp op) {
3633 auto resultType = dyn_cast<VectorType>(op.getType());
3636 auto isAddOp = [&](Operation *op) {
return isa<arith::AddIOp>(op); };
3638 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
3641 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3644 return (laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
3645 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32);
3648 target.addDynamicallyLegalOp<arith::MulFOp>([](arith::MulFOp op) {
3649 auto resultType = dyn_cast<VectorType>(op.getType());
3653 auto isAddOp = [&](Operation *op) {
return isa<arith::AddFOp>(op); };
3655 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
3658 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3661 return laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32);
3664 target.addDynamicallyLegalOp<arith::MinSIOp>([=](arith::MinSIOp op) {
3665 auto resultType = dyn_cast<VectorType>(op.getType());
3669 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3672 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3675 target.addDynamicallyLegalOp<arith::MaxSIOp>([=](arith::MaxSIOp op) {
3676 auto resultType = dyn_cast<VectorType>(op.getType());
3680 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3683 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3686 target.addDynamicallyLegalOp<arith::MinimumFOp>([=](arith::MinimumFOp op) {
3687 auto resultType = dyn_cast<VectorType>(op.getType());
3691 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3694 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3697 target.addDynamicallyLegalOp<arith::MaximumFOp>([=](arith::MaximumFOp op) {
3698 auto resultType = dyn_cast<VectorType>(op.getType());
3702 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3705 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3708 target.addDynamicallyLegalOp<arith::CmpIOp>([=](arith::CmpIOp op) {
3709 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
3713 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
3716 return !elWidthSet.count(lhsElWidth) || laneSize * lhsElWidth != 512;
3719 target.addDynamicallyLegalOp<arith::CmpFOp>([=](arith::CmpFOp op) {
3720 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
3724 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
3727 return !elWidthSet.count(lhsElWidth) || laneSize * lhsElWidth != 512;
3730 target.addDynamicallyLegalOp<arith::SelectOp>([=](arith::SelectOp op) {
3731 auto resultType = dyn_cast<VectorType>(op.getType());
3735 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3738 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3741 target.addDynamicallyLegalOp<vector::ReductionOp>(
3742 [=](vector::ReductionOp op) {
3743 if (
auto kind = op.getKind(); kind != vector::CombiningKind::ADD &&
3744 kind != vector::CombiningKind::MINSI &&
3745 kind != vector::CombiningKind::MINUI &&
3746 kind != vector::CombiningKind::MINIMUMF &&
3747 kind != vector::CombiningKind::MAXSI &&
3748 kind != vector::CombiningKind::MAXUI &&
3749 kind != vector::CombiningKind::MAXIMUMF)
3752 auto vType = dyn_cast<VectorType>(op.getVector().getType());
3756 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
3757 laneSizeElWidthPairSet.insert({64, 8});
3758 laneSizeElWidthPairSet.insert({32, 16});
3759 laneSizeElWidthPairSet.insert({32, 32});
3760 laneSizeElWidthPairSet.insert({16, 32});
3762 Type scalarType = vType.getElementType();
3763 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3766 if (isa<IntegerType>(scalarType) &&
3767 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
3770 if (isa<FloatType>(scalarType) && laneSize != 16 && laneSize != 32)
3776 target.addIllegalOp<vector::ContractionOp, vector::TransposeOp,
3800 StringRef
getArgument() const final {
return "test-lower-vector-to-aievec"; }
3802 return "Lower vector operations to AIE vector intrinsics";
3806 .insert<affine::AffineDialect, xilinx::aievec::aie1::AIEVecAIE1Dialect,
3807 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
3808 memref::MemRefDialect, scf::SCFDialect, vector::VectorDialect,
3809 emitc::EmitCDialect>();
3813 *
this,
"aie-target",
3815 "Select AIE version: \"aie\", \"aie2\", or \"aie2p\". This will "
3816 "determine the vector size and available operations."),
3817 llvm::cl::init(
"aie")};
3820 *
this,
"target-backend",
3821 llvm::cl::desc(
"Select translation backend: \"cpp\" or \"llvmir\". This "
3822 "will determine the aievec operations used to convert "
3823 "from vector dialect."),
3824 llvm::cl::init(
"cpp")};
// Pass body: resolve the target architecture and the translation back-end
// from the pass options, fill the pattern set and the conversion target
// accordingly, then run a partial dialect conversion on the operation.
3827 auto *op = getOperation();
3828 MLIRContext *context = &getContext();
3829 RewritePatternSet patterns(context);
3830 ConversionTarget target(*context);
// AIEArch::AIE is the default; it is only overridden below.
3831 auto aieVersion = AIEArch::AIE;
// "aieml", "aie2" and "aie2p" all select AIEArch::AIE2; "aie" keeps the
// default; any other spelling is reported and fails the pass.
3834 if (targetStr ==
"aieml" || targetStr ==
"aie2" || targetStr ==
"aie2p")
3835 aieVersion = AIEArch::AIE2;
3836 else if (targetStr !=
"aie") {
3837 op->emitError() <<
"unknown AIE target '" <<
aieTarget <<
"'";
3838 return signalPassFailure();
// Back-end selection: "llvmir" switches to TargetBackend::LLVMIR and is
// rejected in combination with AIEArch::AIE; any spelling other than
// "llvmir" or "cpp" is reported as an error.
// NOTE(review): the two diagnostics below read "targetting" and
// "unknown target backend'" (no space before the quote); they are runtime
// strings and are left untouched here — confirm whether any test matches
// them before rewording.
3845 if (backendStr ==
"llvmir") {
3846 backend = TargetBackend::LLVMIR;
3847 if (aieVersion == AIEArch::AIE) {
3848 op->emitError() <<
"targetting LLVM IR is not supported for AIEv1";
3849 signalPassFailure();
3852 }
else if (backendStr !=
"cpp") {
3853 op->emitError() <<
"unknown target backend'" <<
targetBackend <<
"'";
3854 signalPassFailure();
// Patterns and legality rules shared by every architecture come first,
// followed by the set specific to the selected architecture.
3859 populateAIEVecCommonConversionPatterns(patterns, backend);
3860 configureAIEVecCommonLegalizations(target, backend);
3861 if (aieVersion == AIEArch::AIE) {
3862 populateAIEVecV1ConversionPatterns(patterns, backend);
3863 configureAIEVecV1Legalizations(target, backend);
3865 populateAIEVecV2ConversionPatterns(patterns, backend);
3866 configureAIEVecV2Legalizations(target, backend);
// Partial conversion: the pass fails when applyPartialConversion reports
// failure.
3869 if (failed(applyPartialConversion(op, target, std::move(patterns))))
3870 return signalPassFailure();
3874static std::unique_ptr<Pass>
3876 return std::make_unique<LowerVectorToAIEVec>(options);
// Pass body: runs a partial conversion in which the AIEVec dialect is marked
// legal, while the legality of aievec::UPDOp is decided per operation by the
// predicate below.
// NOTE(review): no pattern registration is visible in this listing (the
// embedded line numbers skip some source lines) — confirm in the full file.
3890 MLIRContext *context = &getContext();
3891 RewritePatternSet patterns(context);
3892 ConversionTarget target(*context);
3894 target.addLegalDialect<aievec::AIEVecDialect>();
// A UPDOp is considered legal when any of these holds: op.getVector()
// evaluates to true, the op has exactly one use and that user is another
// UPDOp, or every user of the op is an aievec::ExtOp.
3895 target.addDynamicallyLegalOp<aievec::UPDOp>([](aievec::UPDOp op) {
3896 return op.getVector() ||
3897 (op->hasOneUse() && isa<aievec::UPDOp>(*op->getUsers().begin())) ||
3898 llvm::all_of(op->getUsers(),
3899 [](Operation *op) {
return isa<aievec::ExtOp>(op); });
// A failed partial conversion is reported as a pass failure.
3902 if (
auto *op = getOperation();
3903 failed(applyPartialConversion(op, target, std::move(patterns)))) {
3904 return signalPassFailure();
// Pass body: runs a partial conversion in which the AIEVec dialect is marked
// legal, while the legality of aievec::ExtOp is decided per operation by the
// predicate below.
// NOTE(review): no pattern registration is visible in this listing (the
// embedded line numbers skip some source lines) — confirm in the full file.
3917 MLIRContext *context = &getContext();
3918 RewritePatternSet patterns(context);
3919 ConversionTarget target(*context);
3921 target.addLegalDialect<aievec::AIEVecDialect>();
// An ExtOp is considered legal when its source has no defining op, when that
// defining op is not an aievec::UPDOp, or when that op does not have exactly
// one use.
// NOTE(review): the expression ends in `||`, so at least one further
// condition sits on a line missing from this listing.
3922 target.addDynamicallyLegalOp<aievec::ExtOp>([](aievec::ExtOp op) {
3923 auto *defOp = op.getSource().getDefiningOp();
3924 return !defOp || !isa<aievec::UPDOp>(defOp) || !defOp->hasOneUse() ||
// A failed partial conversion is reported as a pass failure.
3928 if (
auto *op = getOperation();
3929 failed(applyPartialConversion(op, target, std::move(patterns)))) {
3930 return signalPassFailure();
// Pipeline body: appends passes to `pm` in this order — the vector-to-AIEVec
// lowering pass configured by `options`, a canonicalizer, ExtendUPDOpsPass,
// CSE, SimplifyUPDOpsPass, and a final canonicalizer.
// NOTE(review): the enclosing function header is not part of this listing;
// the embedded line numbers also skip lines between the first canonicalizer
// and ExtendUPDOpsPass — confirm nothing else is scheduled there.
3942 pm.addPass(createLowerVectorToAIEVec(options));
3943 pm.addPass(createCanonicalizerPass());
3946 pm.addPass(std::make_unique<ExtendUPDOpsPass>());
3947 pm.addPass(createCSEPass());
3948 pm.addPass(std::make_unique<SimplifyUPDOpsPass>());
3949 pm.addPass(createCanonicalizerPass());
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaximumFOp, aievec::MaxOp > LowerVectorMaximumFOpToAIEVecMaxOp
ComputeBandAndBorOpPattern< arith::OrIOp, aievec::BorOp > ComputeBorOpPattern
ComputeBandAndBorOpPattern< arith::AndIOp, aievec::BandOp > ComputeBandOpPattern
OneToOneVectorOpToAIEVecOpPattern< arith::SubFOp, aievec::aie1::SubOp > LowerVectorSubFOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsIOp > ComputeAbsIOpPattern
LowerTruncOpPattern< arith::TruncFOp > LowerTruncFOpPattern
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddFOp, aievec::AddElemOp > LowerVectorAddFOpToAIEVecAddElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaxSIOp, aievec::MaxOp > LowerVectorMaxSIOpToAIEVecMaxOp
LowerExtOpPattern< arith::ExtFOp > LowerExtFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpFOp, CmpFPredicate > LowerVectorCmpFOpToAIEVecCmpOp
OneToOneVectorOpToAIEVecOpPattern< arith::SubIOp, aievec::aie1::SubOp > LowerVectorSubIOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsFOp > ComputeAbsFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpIOp, CmpIPredicate > LowerVectorCmpIOpToAIEVecCmpOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::SubFOp, aievec::SubElemOp > LowerVectorSubFOpToAIEVecSubElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinSIOp, aievec::MinOp > LowerVectorMinSIOpToAIEVecMinOp
OneToOneVectorOpToAIEVecOpPattern< arith::AddFOp, aievec::aie1::AddOp > LowerVectorAddFOpToAIEVecAddOp
OneToOneVectorOpToAIEVecOpPattern< arith::MulFOp, aievec::aie1::MulOp > LowerVectorMulFOpToAIEVecMulOp
LowerExtOpPattern< arith::ExtSIOp > LowerExtSIOpPattern
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinimumFOp, aievec::MinOp > LowerVectorMinimumFOpToAIEVecMinOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddIOp, aievec::AddElemOp > LowerVectorAddIOpToAIEVecAddElemOp
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
unsigned getVectorLaneSize(mlir::VectorType type)
SmallVector< NamedAttribute > buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos, int64_t step=1)
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
mlir::VectorType createVectorType(unsigned lanes, mlir::Type elementType)
int32_t getElementSizeInBits(mlir::VectorType type)
void buildLowerVectorToAIEVec(mlir::OpPassManager &pm, const LowerVectorToAIEVecOptions &options)
mlir::VectorType getVectorOpDestType(mlir::VectorType type, bool AIE2)
LogicalResult matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulAddToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulFToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulIToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertVectorFMAOpToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
ExpandUPDToUPDAndExtPattern(MLIRContext *context)
LogicalResult matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FuseExtIntoUPDPattern(MLIRContext *context)
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LowerVectorContractionOpToAIEVecMatMulPattern(MLIRContext *context, bool matMoveToAcc=true)
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Value reshapeLeadingUnitDims(OpBuilder &b, Value v) const
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Lower incoming vector operations into their corresponding AIE vector intrinsics.
void getDependentDialects(DialectRegistry &registry) const override
LowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options)
void runOnOperation() override
StringRef getDescription() const final
Option< std::string > aieTarget
Option< std::string > targetBackend
StringRef getArgument() const final
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize, int64_t maxVectorSize, int64_t alignment, int64_t maxLoadSize)
LogicalResult matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
Options for the "lower-vector-to-aievec" pipeline.
PassOptions::Option< std::string > aieTarget
PassOptions::Option< std::string > targetBackend