17#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
18#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/UB/IR/UBOps.h"
23#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#define DEBUG_TYPE "aievec-canonicalization"
36using namespace vector;
// Decodes a backend name string into a TargetBackend value.
// From the lines visible here: a non-empty string equal to llvmir yields
// LLVMIR, there is a path returning UNKNOWN, and the final fall-through
// return yields CPP (this is also what an empty string reaches).
// NOTE(review): this listing omits lines (embedded numbers jump 47 -> 49 -> 51),
// so the exact condition guarding the UNKNOWN return is not visible - confirm
// against the full file.
44static TargetBackend decodeTargetBackend(
const std::string &backend) {
45 if (!backend.empty()) {
46 if (backend ==
"llvmir")
47 return TargetBackend::LLVMIR;
49 return TargetBackend::UNKNOWN;
51 return TargetBackend::CPP;
// Decodes a target name string into an AIEArch value.
// Visible here: a non-empty string is compared against the three spellings
// aieml, aie2 and aie2p, and there is a path that returns UNKNOWN.
// NOTE(review): the values returned for a match and for an empty string are
// on lines missing from this listing - confirm against the full file.
54static AIEArch decodeAIETarget(
const std::string &target) {
55 if (!target.empty()) {
56 if (target ==
"aieml" || target ==
"aie2" || target ==
"aie2p")
59 return AIEArch::UNKNOWN;
// Predicate on a vector.contract op. Checks visible in this listing, in order:
//   - the combining kind is ADD,
//   - the lhs, rhs and acc shapes each have rank >= 2,
//   - there are at least 3 iterators and the innermost three are
//     parallel, parallel, reduction,
//   - the lhs/rhs/acc indexing maps, with their outer results dropped, equal
//     three reference maps shifted by the number of outer dims.
// NOTE(review): the early-return statements and the complete initialisers of
// mmAidxMap / mmBidxMap / mmCidxMap are on lines missing from this listing, so
// the exact reference maps cannot be confirmed from here.
68static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
69 if (op.getKind() != vector::CombiningKind::ADD)
73 auto lhsShape = op.getLhsType().getShape();
74 auto rhsShape = op.getRhsType().getShape();
75 auto accShape = cast<ShapedType>(op.getAccType()).getShape();
76 if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
// Only the three innermost iterators are inspected; any outer ones are ignored.
80 SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
81 if (iterators.size() < 3)
83 auto innerMostIterators =
84 SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end());
85 if (vector::IteratorType::parallel != innerMostIterators[0] ||
86 vector::IteratorType::parallel != innerMostIterators[1] ||
87 vector::IteratorType::reduction != innerMostIterators[2])
// Collect the positions of all but the last two results of the lhs map. The
// same position list is then dropped from the rhs and acc maps as well.
91 SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
92 SmallVector<int64_t> outerMostResults;
93 for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
94 outerMostResults.push_back(i);
96 auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
97 auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
98 auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);
// Reference permutations over three dims (initialisers only partly visible).
101 auto *ctx = op.getContext();
103 AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
106 AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
109 AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
// Shift the reference maps past the outer dims before comparing.
111 int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
112 return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
113 innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
114 innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
// Rewrite body of a transfer_read conversion pattern.
// Steps visible in this listing:
//   1. test whether the permutation map is constant,
//   2. build a 1-D vector type with twice the innermost length of the read,
//   3. subtract `offset` from the innermost index through an affine.apply,
//   4. issue a transfer_read of the long vector at the corrected indices,
//   5. replace the original read with an extract_strided_slice of vLen
//      elements starting at `offset` with stride 1.
// NOTE(review): the function header, the definition of `offset` and the early
// returns are on lines missing from this listing - confirm against the full
// file.
136 ConversionPatternRewriter &rewriter)
const override {
138 if (adaptor.getPermutationMap().isConstant())
142 auto vType = readOp.getVectorType();
150 auto vLen = vType.getShape().back();
151 auto longVecTy = VectorType::get(2 * vLen, vType.getElementType());
// New innermost index = old innermost index - offset.
159 auto loc = readOp.getLoc();
160 Value oldInnerMostIdx = adaptor.getIndices().back();
161 auto offsetCorrectionMap =
162 AffineMap::get(1, 0, getAffineDimExpr(0, readOp.getContext()) - offset);
163 Value newInnerMostIdx = rewriter
164 .create<affine::AffineApplyOp>(
165 readOp.getLoc(), offsetCorrectionMap,
166 SmallVector<Value, 1>({oldInnerMostIdx}))
// Same indices as the original read, except for the corrected innermost one.
168 SmallVector<Value, 8> alignedIdx;
169 alignedIdx.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
170 alignedIdx[alignedIdx.size() - 1] = newInnerMostIdx;
174 auto newReadOp = rewriter.create<vector::TransferReadOp>(
175 loc, longVecTy, adaptor.getBase(), alignedIdx, adaptor.getPadding());
// Slice the originally requested vLen elements out of the longer read.
178 rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
179 readOp, newReadOp.getResult(), offset, vLen, 1);
// Rewrite body of a transfer_read conversion pattern that keys on a constant
// permutation map. The visible code ends by replacing the read with a new
// transfer_read of the same vector type, a vector.extract of one element at
// `offset`, and a vector.splat of that element.
// NOTE(review): the function header, the declarations of `newIdx` and
// `offset`, and the early return are on lines missing from this listing.
201 ConversionPatternRewriter &rewriter)
const override {
202 AffineMap map = readOp.getPermutationMap();
203 if (!map.isConstant())
206 Value srcMemRef = adaptor.getBase();
207 SmallVector<Value, 8> indices;
// Rank-0 memref: an expand_shape to shape {1} is created and a constant 0
// index is used as the only index.
211 if (cast<MemRefType>(srcMemRef.getType()).getRank() == 0) {
213 .create<memref::ExpandShapeOp>(
214 readOp.getLoc(), SmallVector<int64_t, 1>({1}),
215 srcMemRef, SmallVector<ReassociationIndices, 1>({}))
217 newIdx = rewriter.create<arith::ConstantOp>(readOp.getLoc(),
218 rewriter.getIndexAttr(0L));
219 indices.push_back(newIdx);
// Otherwise keep the original indices; newIdx tracks the last one.
221 indices.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
222 newIdx = indices[indices.size() - 1];
// If the last index comes from a one-dim affine.apply, use its operand as the
// index and evaluate the map at 0 to obtain the offset.
226 if (
auto applyOp = newIdx.getDefiningOp<affine::AffineApplyOp>())
227 if (applyOp.getAffineMap().getNumDims() == 1) {
228 newIdx = applyOp.getMapOperands()[0];
229 offset = applyOp.getAffineMap().compose(ArrayRef<int64_t>{0})[0];
// When the offset is at least one vector length, move whole vectors worth of
// elements into the address (via a new affine.apply) and keep only the
// remainder as the in-vector offset.
233 int64_t vlen = readOp.getVector().getType().getShape()[0];
234 if (offset >= vlen) {
237 int64_t numElemsToSkip = vlen * (offset / vlen);
238 offset = offset % vlen;
239 auto newAddrMap = AffineMap::get(
240 1, 0, getAffineDimExpr(0, readOp.getContext()) + numElemsToSkip);
243 .create<affine::AffineApplyOp>(readOp.getLoc(), newAddrMap,
244 SmallVector<Value, 1>({newIdx}))
// Read a full vector at the adjusted index, extract lane `offset`, splat it.
247 indices[indices.size() - 1] = newIdx;
248 auto newReadOp = rewriter.create<vector::TransferReadOp>(
249 readOp.getLoc(), readOp.getVector().getType(), srcMemRef, indices,
250 adaptor.getPadding());
251 auto extractOp = rewriter.create<vector::ExtractOp>(
252 readOp.getLoc(), newReadOp.getResult(), ArrayRef<int64_t>{offset});
253 rewriter.replaceOpWithNewOp<vector::SplatOp>(
254 readOp, newReadOp.getVector().getType(), extractOp.getResult());
// Rewrite body operating on an arith.extsi op and the op defining its input.
// Visible behaviour: for each operand of the defining op whose type relates to
// the extsi input type (same type, element of it, vector of it, or vector with
// the same element type), a new arith.extsi is created on that operand; the
// defining op is then re-created by name with the collected inputs, the extsi
// result type and the original attributes, and the extsi is replaced by it.
// NOTE(review): the function header, the return statements and the statements
// that push the extended values into `inputs` are on lines missing from this
// listing.
270 PatternRewriter &rewriter)
const override {
271 arith::ExtSIOp extOp = cast<arith::ExtSIOp>(op);
272 Operation *defOp = extOp.getIn().getDefiningOp();
// First filter: no defining op, or one of the listed read/load/call ops.
274 if (!defOp || isa<vector::TransferReadOp, memref::LoadOp,
275 affine::AffineLoadOp, func::CallOp>(defOp))
// Second filter: only these four vector ops are accepted as the defining op.
279 if (!isa<vector::BroadcastOp, vector::ExtractOp, vector::SplatOp,
280 vector::ExtractStridedSliceOp>(defOp))
283 Type extOpInTy = extOp.getIn().getType();
284 SmallVector<Value, 4> inputs;
285 for (Value operand : defOp->getOperands()) {
286 Type operandTy = operand.getType();
287 VectorType extOpInVecTy = dyn_cast<VectorType>(extOpInTy);
288 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
// Case 1: operand has exactly the extsi input type.
289 if (operandTy == extOpInTy) {
290 Type outTy = extOp.getOut().getType();
292 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
294 }
// Case 2: operand is a scalar of the extsi input vector element type.
else if (extOpInVecTy && extOpInVecTy.getElementType() == operandTy) {
297 cast<VectorType>(extOp.getOut().getType()).getElementType();
299 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
301 }
// Case 3: operand is a vector whose element type is the extsi input type.
else if (operandVecTy && operandVecTy.getElementType() == extOpInTy) {
304 VectorType::get(operandVecTy.getShape(), extOp.getOut().getType());
306 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
308 }
// Case 4: both are vectors sharing an element type (shapes may differ).
else if (extOpInVecTy && operandVecTy &&
309 (extOpInVecTy.getElementType() ==
310 operandVecTy.getElementType())) {
312 Type outTy = VectorType::get(
313 operandVecTy.getShape(),
314 cast<VectorType>(extOp.getOut().getType()).getElementType());
316 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
// Fallback visible here: the operand is forwarded unchanged.
319 inputs.push_back(operand);
// Re-create the defining op generically (by operation name) on the new inputs.
324 rewriter.create(extOp->getLoc(), defOp->getName().getIdentifier(),
325 inputs, {extOp.getOut().getType()}, defOp->getAttrs());
326 rewriter.replaceOp(extOp, newOp->getResult(0));
// Pattern template over two single-operand op classes. Visible behaviour:
// given a UnaryOpB whose operand is defined by a UnaryOpA, it creates a
// UnaryOpB on the operands of the A op (result type newA2BTy, attributes of
// the B op), then a UnaryOpA on that result (result type of the original B
// op, attributes of the A op), and replaces the original B op with it - i.e.
// the order of the two ops is swapped.
// NOTE(review): the struct header, the computation of newA2BTy and the
// declaration of newA are on lines missing from this listing.
// NOTE(review): the match-failure message below spells preceded as
// preceeded; it is a runtime string and is left untouched here.
334template <
class UnaryOpA,
class UnaryOpB>
346 PatternRewriter &rewriter)
const override {
// Both op classes must carry the OneOperand trait.
348 UnaryOpA::template hasTrait<OpTrait::OneOperand>(),
349 "SwapUnaryOps can only be instantiated for single-operand ops");
351 UnaryOpB::template hasTrait<OpTrait::OneOperand>(),
352 "SwapUnaryOps can only be instantiated for single-operand ops");
353 UnaryOpA aOp = bOp.getOperand().template getDefiningOp<UnaryOpA>();
355 return rewriter.notifyMatchFailure(bOp, UnaryOpB::getOperationName() +
356 " not preceeded by " +
357 UnaryOpA::getOperationName());
362 rewriter.create<UnaryOpB>(bOp->getLoc(), SmallVector<Type>({newA2BTy}),
363 aOp->getOperands(), bOp->getAttrs());
364 auto newB = rewriter.create<UnaryOpA>(
365 bOp->getLoc(), SmallVector<Type>({bOp.getResult().getType()}),
366 newA->getResults(), aOp->getAttrs());
367 rewriter.replaceOp(bOp, newB.getResult());
// Builds the index list for an access whose `numDims` innermost dimensions
// have been collapsed into one. The last `numDims` indices are linearised
// through a single affine.apply: the innermost index plus each outer index
// multiplied by the running product of the inner dimension sizes taken from
// `shape`. The leading indices are copied through unchanged and the
// linearised index is appended.
// Only minor-identity layouts are supported (enforced by the assert).
// NOTE(review): the remaining parameters, the initial value of `stride` and
// the return statement are on lines missing from this listing.
372static SmallVector<Value> collapseInnerMostDimIndices(PatternRewriter &b,
373 Location loc,
int numDims,
375 ArrayRef<int64_t> shape,
378 assert(layout.isMinorIdentity() &&
379 "dimension collapse in non-identity layout is not implemented");
// Start from the innermost dim, then walk outwards accumulating the stride.
380 auto newIdxExpr = b.getAffineDimExpr(numDims - 1);
382 for (int64_t dim = numDims - 2; dim >= 0; dim--) {
383 stride *= shape[shape.size() - (numDims - dim - 1)];
384 newIdxExpr = newIdxExpr + b.getAffineDimExpr(dim) * stride;
386 auto newIndexMap = AffineMap::get(numDims, 0, newIdxExpr);
387 Value newInnerMostIdxValue =
388 b.create<affine::AffineApplyOp>(loc, newIndexMap,
389 indices.take_back(numDims))
// Untouched leading indices first, then the single collapsed index.
391 SmallVector<Value> newIdxRange;
392 for (
auto idx : indices.drop_back(numDims))
393 newIdxRange.push_back(idx);
394 newIdxRange.push_back(newInnerMostIdxValue);
// Collapses the `numDims` innermost dimensions of the memref value `val` into
// a single dimension by creating a memref.collapse_shape. The new innermost
// size is the product of the collapsed sizes; the result type uses a
// minor-identity layout and keeps the element type and memory space.
// NOTE(review): std::accumulate is seeded with the int literal 1, so the
// accumulator type is int and the product of int64_t sizes is narrowed at
// every step; a seed of type int64_t would avoid that. Not changed here.
// NOTE(review): .value() on the reassociation result assumes the collapse is
// always expressible - confirm for the shapes callers pass (e.g. dynamic dims).
// NOTE(review): the statement that returns the new op is only partly visible.
398static Value collapseInnerMostShapeDims(PatternRewriter &b, Location loc,
399 int numDims, Value val) {
400 auto memRefTy = cast<MemRefType>(val.getType());
401 auto shape = memRefTy.getShape();
402 int64_t newInnerMostDim = std::accumulate(shape.end() - numDims, shape.end(),
403 1, std::multiplies<>());
// Keep the outer dims plus one slot, then overwrite that slot with the product.
404 SmallVector<int64_t, 4> newShape{shape.begin(), shape.end() - numDims + 1};
405 newShape[shape.size() - numDims] = newInnerMostDim;
406 auto newNumDims = newShape.size();
407 auto *ctx = b.getContext();
408 auto newMemRefTy = MemRefType::get(
409 newShape, memRefTy.getElementType(),
410 AffineMap::getMinorIdentityMap(newNumDims, newNumDims, ctx),
411 memRefTy.getMemorySpace());
412 auto reassocIndices =
413 getReassociationIndicesForCollapse(shape, newShape).value();
415 .create<memref::CollapseShapeOp>(loc, newMemRefTy, val, reassocIndices)
// Rewrite body that flattens a multi-dimensional transfer_read.
// Visible steps: filter on a minor-identity permutation map, no mask and a
// vector rank of at least 2; build a 1-D vector type holding the product of
// the vector dims; collapse the matching innermost memref dims and indices;
// read the flat vector; derive a single in_bounds flag as the conjunction of
// the original flags; replace the original read with a shape_cast back to
// the original vector type.
// NOTE(review): the padding operand visible for the new read is a zero
// constant of the element type - whether the original padding is
// intentionally dropped cannot be confirmed from this listing.
// NOTE(review): the function header, early returns and the null check after
// the dyn_cast to MemRefType are on lines missing from this listing.
428 ConversionPatternRewriter &rewriter)
const override {
431 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
433 VectorType vectorTy = readOp.getVector().getType();
434 if (vectorTy.getRank() < 2)
437 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getBase().getType());
// Flat vector type: one dim whose size is the product of all vector dims.
440 auto memRefShape = memRefTy.getShape();
441 auto vecShape = vectorTy.getShape();
444 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
445 std::multiplies<>())},
446 vectorTy.getElementType());
447 AffineMap layout = memRefTy.getLayout().getAffineMap();
// Collapse as many innermost memref dims as the vector has dims.
449 collapseInnerMostDimIndices(rewriter, readOp.getLoc(), vecShape.size(),
450 adaptor.getIndices(), memRefShape, layout);
451 auto newSource = collapseInnerMostShapeDims(
452 rewriter, readOp.getLoc(), vecShape.size(), adaptor.getBase());
453 auto newVector = rewriter.create<vector::TransferReadOp>(
454 readOp.getLoc(), newVectorTy, newSource, newIndices,
456 arith::getZeroConstant(rewriter, readOp.getLoc(),
457 newVectorTy.getElementType()));
// The flat read is in-bounds only if every original dim was in-bounds.
459 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
460 if (inBoundsArrayAttrOpt) {
461 SmallVector<bool> inBounds =
462 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
463 SmallVector<bool> newInBounds({
false});
464 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
465 [](
bool v) { return v; });
466 newVector.getProperties().setInBounds(
467 rewriter.getBoolArrayAttr(newInBounds));
470 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, vectorTy,
// Rewrite body that flattens a multi-dimensional transfer_write.
// Visible steps: filter on a minor-identity permutation map, no mask and a
// vector rank of at least 2; shape_cast the stored value to a 1-D vector
// holding the product of the vector dims; collapse the matching innermost
// memref dims and indices; replace the write with a transfer_write of the
// flat vector; derive a single in_bounds flag as the conjunction of the
// original flags.
// NOTE(review): the function header, early returns and the null check after
// the dyn_cast to MemRefType are on lines missing from this listing.
486 ConversionPatternRewriter &rewriter)
const override {
489 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
491 VectorType vectorTy = cast<VectorType>(adaptor.getValueToStore().getType());
492 if (vectorTy.getRank() < 2)
495 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getBase().getType());
// Flat vector type: one dim whose size is the product of all vector dims.
498 auto memRefShape = memRefTy.getShape();
499 auto vecShape = vectorTy.getShape();
502 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
503 std::multiplies<>())},
504 vectorTy.getElementType());
505 AffineMap layout = memRefTy.getLayout().getAffineMap();
508 .create<vector::ShapeCastOp>(writeOp.getLoc(), newVectorTy,
509 adaptor.getValueToStore())
// Collapse as many innermost memref dims as the vector has dims.
512 collapseInnerMostDimIndices(rewriter, writeOp.getLoc(), vecShape.size(),
513 adaptor.getIndices(), memRefShape, layout);
514 auto newSource = collapseInnerMostShapeDims(
515 rewriter, writeOp.getLoc(), vecShape.size(), adaptor.getBase());
517 auto newOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
518 writeOp, newVector, newSource, newIndices);
// The flat write is in-bounds only if every original dim was in-bounds.
520 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
521 if (inBoundsArrayAttrOpt) {
522 SmallVector<bool> inBounds =
523 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
524 SmallVector<bool> newInBounds({
false});
525 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
526 [](
bool v) { return v; });
527 newOp.getProperties().setInBounds(rewriter.getBoolArrayAttr(newInBounds));
// Body of a helper that returns a vector type equal to `vecTy` with its two
// innermost dimensions swapped; the element type is preserved.
// NOTE(review): the function header is on lines missing from this listing.
// The indexing below requires a rank of at least 2 - presumably guaranteed by
// callers; confirm.
544 SmallVector<int64_t> shape{vecTy.getShape()};
545 auto nDim = shape.size();
546 int64_t dimNm1 = shape[nDim - 1];
547 shape[nDim - 1] = shape[nDim - 2];
548 shape[nDim - 2] = dimNm1;
549 auto elemTy = vecTy.getElementType();
550 return VectorType::get(shape, elemTy);
// Rewrite body for vector.contract ops accepted by
// isGemmBTransposedContractionOp. Visible steps:
//   1. if the rhs is produced by arith.extf / extsi / extui, look through the
//      extension and work on its input instead,
//   2. create a vector.transpose of the rhs with a permutation that swaps the
//      two innermost dims,
//   3. create an arith.extf / extsi / extui whose result has the transposed
//      shape and the original rhs element type (presumably selected by the
//      doExt* flags - the guarding conditions are not visible),
//   4. compose the rhs indexing map with a permutation swapping its two
//      innermost dims and create a new vector.contract with the updated maps.
// NOTE(review): the function header, the early return, the assignments to the
// doExt* flags and to transpRhsVecTy, and the operands of the created
// transpose/extension ops are on lines missing from this listing.
555 ConversionPatternRewriter &rewriter)
const override {
556 if (!isGemmBTransposedContractionOp(contractOp))
559 Location loc = contractOp.getLoc();
560 auto *ctx = rewriter.getContext();
562 Value rhsVal = adaptor.getRhs();
563 VectorType rhsVecTy = contractOp.getRhsType();
564 Type rhsElemTy = rhsVecTy.getElementType();
// Look through a single extension op feeding the rhs, if any.
566 bool doExtF =
false, doExtSI =
false, doExtUI =
false;
567 if (
auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
568 rhsVal = extfRhsOp.getIn();
569 rhsVecTy = cast<VectorType>(rhsVal.getType());
571 }
else if (
auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
572 rhsVal = extsiRhsOp.getIn();
573 rhsVecTy = cast<VectorType>(rhsVal.getType());
575 }
else if (
auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
576 rhsVal = extuiRhsOp.getIn();
577 rhsVecTy = cast<VectorType>(rhsVal.getType());
// Permutation: identity on the outer dims, swap of the two innermost dims.
581 int64_t nDim = rhsVecTy.getShape().size();
582 SmallVector<int64_t> rhsPermutation;
583 for (int64_t i = 0; i < nDim - 2; i++)
584 rhsPermutation.push_back(i);
585 rhsPermutation.push_back(nDim - 1);
586 rhsPermutation.push_back(nDim - 2);
589 .create<vector::TransposeOp>(loc, transpRhsVecTy, rhsVal,
// Extension to the original rhs element type, on the transposed shape.
596 .create<arith::ExtFOp>(
597 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
603 .create<arith::ExtSIOp>(
604 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
610 .create<arith::ExtUIOp>(
611 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
// Same swap applied to the iteration dims of the rhs indexing map.
615 SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());
617 nDim = oldIdxMaps[1].getNumDims();
618 SmallVector<int64_t> innerDimPerm;
619 for (int64_t i = 0; i < nDim - 2; i++)
620 innerDimPerm.push_back(i);
621 innerDimPerm.push_back(nDim - 1);
622 innerDimPerm.push_back(nDim - 2);
623 auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);
625 auto newIdxMaps = rewriter.getAffineMapArrayAttr(
626 {oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});
// Only the rhs operand and its indexing map change in the new contraction.
628 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
629 contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
630 adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());
// Checks the given index operands with llvm::all_of: each value is matched
// against a constant and, when it matches, compared with 0.
// NOTE(review): the declaration of `attr`, the lambda result for non-constant
// values and the final success/failure returns are on lines missing from this
// listing.
638static LogicalResult isAllZeroOffsetAccess(mlir::OperandRange indices) {
639 if (!llvm::all_of(indices, [](Value val) {
641 if (!matchPattern(val, m_Constant(&attr)))
643 return attr.getInt() == 0;
// Materialises the mixed (static or dynamic) offsets of `subViewOp` as SSA
// values: attribute offsets become arith.constant index ops, value offsets
// are used as they are. New constants are inserted right before the subview
// op; the insertion guard restores the previous insertion point on exit.
// NOTE(review): the `loc` parameter, the declaration of `indexVal` and the
// return statement are on lines missing from this listing.
652static SmallVector<Value> opFoldResultsToValues(PatternRewriter &rewriter,
654 memref::SubViewOp subViewOp) {
655 OpBuilder::InsertionGuard g(rewriter);
656 rewriter.setInsertionPoint(subViewOp);
657 SmallVector<Value> newIndices;
658 for (OpFoldResult offset : subViewOp.getMixedOffsets()) {
660 if (
auto attr = dyn_cast<Attribute>(offset)) {
661 indexVal = rewriter.create<arith::ConstantIndexOp>(
662 loc, cast<IntegerAttr>(attr).getInt());
664 indexVal = cast<Value>(offset);
666 newIndices.push_back(indexVal);
// Rewrite body for a transfer_read whose base is produced by a
// memref.subview. When all read indices are constant zero, the read is
// replaced by a transfer_read on the subview source, indexed by the subview
// offsets, keeping the result type, padding and in_bounds values.
// NOTE(review): the function header and the failure returns are on lines
// missing from this listing. Only the subview offsets are used; whether
// non-unit strides or rank-reducing subviews are rejected elsewhere cannot be
// seen from here - confirm.
691 PatternRewriter &rewriter)
const override {
693 auto subViewOp = dyn_cast_if_present<memref::SubViewOp>(
694 readOp.getBase().getDefiningOp());
699 if (failed(isAllZeroOffsetAccess(readOp.getIndices())))
703 SmallVector<Value> newIndices =
704 opFoldResultsToValues(rewriter, readOp.getLoc(), subViewOp);
707 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
708 readOp, readOp.getType(), subViewOp.getSource(), newIndices,
709 readOp.getPadding(), readOp.getInBoundsValues());
// Rewrite body for a transfer_write whose base is produced by a
// memref.subview. When all write indices are constant zero, a new
// transfer_write is created on the subview source, indexed by the subview
// offsets and keeping the in_bounds values, and the original write is erased.
// NOTE(review): the function header and the failure returns are on lines
// missing from this listing. Only the subview offsets are used; whether
// non-unit strides or rank-reducing subviews are rejected elsewhere cannot be
// seen from here - confirm.
735 PatternRewriter &rewriter)
const override {
737 auto subViewOp = dyn_cast_if_present<memref::SubViewOp>(
738 writeOp.getBase().getDefiningOp());
743 if (failed(isAllZeroOffsetAccess(writeOp.getIndices())))
747 SmallVector<Value> newIndices =
748 opFoldResultsToValues(rewriter, writeOp.getLoc(), subViewOp);
751 rewriter.create<vector::TransferWriteOp>(
752 writeOp.getLoc(), writeOp.getVector(), subViewOp.getSource(),
753 newIndices, writeOp.getInBoundsValues());
// The new write has been created above, so the old op is erased explicitly.
756 rewriter.eraseOp(writeOp);
// Rewrite body for vector.insert with a vector-typed value to store. It
// counts the leading unit dimensions of the destination and source shapes
// and, when the destination has at least one and the remaining (non-leading-
// unit) shapes are equal, replaces the insert with a vector.shape_cast of the
// stored value to the destination vector type.
// NOTE(review): the function header, the loop bodies that increment the
// counters and the failure returns are on lines missing from this listing.
771 PatternRewriter &rewriter)
const override {
772 auto insSrcTy = dyn_cast<VectorType>(insOp.getValueToStoreType());
776 auto srcShape = insSrcTy.getShape();
777 auto dstShape = insOp.getDestVectorType().getShape();
// Leading unit dims of the destination.
779 unsigned long numLeadUnitDimDst = 0;
780 while (numLeadUnitDimDst < dstShape.size() &&
781 dstShape[numLeadUnitDimDst] == 1)
784 if (!numLeadUnitDimDst)
// Leading unit dims of the source.
787 unsigned long numLeadUnitDimSrc = 0;
788 while (numLeadUnitDimSrc < srcShape.size() &&
789 srcShape[numLeadUnitDimSrc] == 1)
// Compare what is left of both shapes after stripping the leading unit dims.
792 SmallVector<int64_t> nonLeadUnitDimDstShape(
793 dstShape.begin() + numLeadUnitDimDst, dstShape.end());
794 SmallVector<int64_t> nonLeadUnitDimSrcShape(
795 srcShape.begin() + numLeadUnitDimSrc, srcShape.end());
797 if (nonLeadUnitDimSrcShape != nonLeadUnitDimDstShape)
800 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
801 insOp, insOp.getDestVectorType(), insOp.getValueToStore());
// Fragments of the target-independent setup helpers. The configure function
// marks the arith, affine, memref and vector dialects (and possibly more, the
// list is cut off) as legal on the conversion target; the populate function
// registers patterns using the context of the pattern set.
// NOTE(review): return types, remaining parameters and most of both bodies
// are on lines missing from this listing.
810configureCommonAIECanonicalizeLegalizations(ConversionTarget &target,
812 target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
813 memref::MemRefDialect, vector::VectorDialect,
818populateCommonAIECanonicalizeConversionPatterns(RewritePatternSet &patterns,
821 patterns.getContext());
// Fragments of the AIEv1-specific setup helpers. transfer_read is made
// dynamically legal; the visible part of the predicate requires a
// non-constant permutation map, combined with further conditions that are
// cut off in this listing.
// NOTE(review): remaining parameters and the rest of both bodies are on
// lines missing from this listing.
828static void configureAIEv1CanonicalizeLegalizations(ConversionTarget &target,
830 target.addDynamicallyLegalOp<vector::TransferReadOp>(
831 [](vector::TransferReadOp op) {
832 return !op.getPermutationMap().isConstant() &&
839populateAIEv1CanonicalizeConversionPatterns(RewritePatternSet &patterns,
// Fragments of the AIE2-specific setup helpers. Dynamic legality visible here:
//   - transfer_read: non-constant permutation map and vector rank below 2
//     (plus conditions on lines cut off in this listing),
//   - transfer_write: vector rank below 2,
//   - vector.contract: legal unless isGemmBTransposedContractionOp accepts it.
// NOTE(review): remaining parameters and the rest of both bodies are on
// lines missing from this listing.
849static void configureAIE2CanonicalizeLegalizations(ConversionTarget &target,
851 target.addDynamicallyLegalOp<vector::TransferReadOp>(
852 [](vector::TransferReadOp op) {
853 return !op.getPermutationMap().isConstant() &&
856 op.getVector().getType().getRank() < 2;
858 target.addDynamicallyLegalOp<vector::TransferWriteOp>(
859 [](vector::TransferWriteOp op) {
860 return cast<VectorType>(op.getVector().getType()).getRank() < 2;
862 target.addDynamicallyLegalOp<vector::ContractionOp>(
863 [](vector::ContractionOp op) {
864 return !isGemmBTransposedContractionOp(op);
869populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
// Pass (anchored on any operation) that applies the upstream vector broadcast
// lowering patterns, plus at least one more pattern whose registration is cut
// off in this listing, using the greedy rewrite driver. The driver result is
// deliberately discarded, so a non-converging rewrite does not fail the pass.
// NOTE(review): the struct header and parts of the body are on lines missing
// from this listing.
883 :
public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {
886 auto *op = getOperation();
887 MLIRContext *context = &getContext();
888 RewritePatternSet patterns(context);
889 populateVectorBroadcastLoweringPatterns(patterns);
891 patterns.getContext());
893 (void)applyPatternsGreedily(op, std::move(patterns));
// Factory for the pass above.
897static std::unique_ptr<::mlir::Pass> createVectorBroadcastLoweringPass() {
898 return std::make_unique<VectorBroadcastLoweringPass>();
// Pass (anchored on any operation) that canonicalizes vector ops ahead of the
// AIEVec conversion. Visible behaviour of the run method:
//   - emits an error for an unknown AIE target or an unknown backend,
//   - emits an error when the LLVM IR backend is combined with AIEv1,
//   - populates the common patterns/legalizations, then the AIEv1 or the AIE2
//     specific ones depending on the decoded target,
//   - runs a greedy rewrite with a second pattern set (its contents are cut
//     off here), then a partial dialect conversion.
// Two string options are declared: the AIE target (default aie) and the
// target backend (default cpp).
// NOTE(review): the struct header, constructors, the decoding of the option
// strings into aieVersion/backend and the failure handling are on lines
// missing from this listing.
// NOTE(review): the AIEv1 error message spells targeting as targetting; it is
// a runtime string and is left untouched here.
909 :
public PassWrapper<CanonicalizeVectorForAIEVecPass, OperationPass<>> {
// Command-line argument name and description strings of the pass.
926 return "test-canonicalize-vector-for-aievec";
930 return "Canonicalize vector operations for AIEVec conversion";
// Dialects registered as dependencies of this pass.
935 .insert<arith::ArithDialect, memref::MemRefDialect,
936 vector::VectorDialect, affine::AffineDialect, ub::UBDialect>();
942 "Select AIE version: \"aie\", \"aie2\", or \"aie2p\". This will "
943 "determine the vector size and available operations."),
944 llvm::cl::init(
"aie")};
947 *
this,
"target-backend",
948 llvm::cl::desc(
"Select translation backend: \"cpp\" or \"llvmir\". This "
949 "will determine the aievec operations used to convert "
950 "from vector dialect."),
951 llvm::cl::init(
"cpp")};
954 auto *op = getOperation();
955 MLIRContext *context = &getContext();
956 RewritePatternSet patterns(context);
957 ConversionTarget target(*context);
// Option validation: each unsupported combination reports an error on the op.
960 if (aieVersion == AIEArch::UNKNOWN) {
961 op->emitError() <<
"unknown AIE target '" <<
aieTarget <<
"'";
967 if (backend == TargetBackend::UNKNOWN) {
968 op->emitError() <<
"unknown target backend '" <<
targetBackend <<
"'";
972 if (backend == TargetBackend::LLVMIR && aieVersion == AIEArch::AIE) {
973 op->emitError() <<
"targetting LLVM IR is not supported for AIEv1";
// Common setup first, then the architecture-specific setup.
978 populateCommonAIECanonicalizeConversionPatterns(patterns, backend);
979 configureCommonAIECanonicalizeLegalizations(target, backend);
980 if (aieVersion == AIEArch::AIE) {
981 populateAIEv1CanonicalizeConversionPatterns(patterns, backend);
982 configureAIEv1CanonicalizeLegalizations(target, backend);
984 populateAIE2CanonicalizeConversionPatterns(patterns, backend);
985 configureAIE2CanonicalizeLegalizations(target, backend);
// A separate pattern set is applied greedily (result discarded) before the
// partial conversion below.
989 RewritePatternSet patterns(context);
992 (void)applyPatternsGreedily(op, std::move(patterns));
995 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
// Factory for the pass above, forwarding the pipeline options.
1001static std::unique_ptr<::mlir::Pass> createCanonicalizeVectorForAIEVecPass(
1003 return std::make_unique<CanonicalizeVectorForAIEVecPass>(options);
// Pass (anchored on any operation) that applies a pattern set with the greedy
// rewrite driver; the driver result is deliberately discarded.
// NOTE(review): the struct header and the pattern registration are on lines
// missing from this listing.
1007 :
public PassWrapper<HoistCastOpToDataSourcePass, OperationPass<>> {
1010 auto *op = getOperation();
1011 MLIRContext *context = &getContext();
1012 RewritePatternSet patterns(context);
1016 (void)applyPatternsGreedily(op, std::move(patterns));
// Factory for the pass above.
1020static std::unique_ptr<::mlir::Pass> createHoistCastOpToDataSourcePass() {
1021 return std::make_unique<HoistCastOpToDataSourcePass>();
// Pass (anchored on any operation) that registers a pattern parameterised by
// a type-inference callback taking an arith.extsi and a vector.broadcast, and
// applies it with the greedy rewrite driver (result discarded).
// The callback takes the element type of the extsi input (unwrapping a vector
// type when the input is one) and returns a vector type with the shape of the
// broadcast result.
// NOTE(review): the struct header, the pattern registration call and the
// element-type argument of the returned vector type are on lines missing from
// this listing.
1025 :
public PassWrapper<ReorderOperationsPass, OperationPass<>> {
1028 auto *op = getOperation();
1029 MLIRContext *context = &getContext();
1030 RewritePatternSet patterns(context);
1033 patterns.getContext(),
1034 [](arith::ExtSIOp extOp, vector::BroadcastOp bcastOp) -> Type {
1035 Type extInElemTy = extOp.getIn().getType();
1036 auto extInVecTy = dyn_cast<VectorType>(extInElemTy);
1038 extInElemTy = extInVecTy.getElementType();
1039 return VectorType::get(bcastOp.getResultVectorType().getShape(),
1043 (void)applyPatternsGreedily(op, std::move(patterns));
// Factory for the pass above.
1047static std::unique_ptr<::mlir::Pass> createReorderOperationsPass() {
1048 return std::make_unique<ReorderOperationsPass>();
// Pipeline assembly. Pass order visible here:
//   1. ReorderOperationsPass, only for the LLVM IR backend,
//   2. VectorBroadcastLoweringPass,
//   3. CanonicalizeVectorForAIEVecPass (receives the pipeline options),
//   4. HoistCastOpToDataSourcePass, only for the C++ backend.
// NOTE(review): the function header and any passes added on lines missing
// from this listing are not visible.
1060 if (decodeTargetBackend(options.
targetBackend) == TargetBackend::LLVMIR)
1061 pm.addPass(createReorderOperationsPass());
1063 pm.addPass(createVectorBroadcastLoweringPass());
1064 pm.addPass(createCanonicalizeVectorForAIEVecPass(options));
1065 if (decodeTargetBackend(options.
targetBackend) == TargetBackend::CPP)
1066 pm.addPass(createHoistCastOpToDataSourcePass());
std::unique_ptr<::mlir::Pass > createCopyRemovalPass()
Create a pass that removes unnecessary Copy operations.
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
int32_t getElementSizeInBits(mlir::VectorType type)
void buildCanonicalizeVectorForAIEVec(mlir::OpPassManager &pm, const CanonicalizeVectorForAIEVecOptions &options)
Pattern to canonicalize trivial vector.transfer_read operations on subviews.
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override
Pattern to canonicalize trivial vector.transfer_write operations on subviews.
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override
StringRef getArgument() const final
void getDependentDialects(DialectRegistry &registry) const override
StringRef getDescription() const final
Option< std::string > aieTarget
CanonicalizeVectorForAIEVecPass(const CanonicalizeVectorForAIEVecOptions &options)
Option< std::string > targetBackend
void runOnOperation() override
LogicalResult matchAndRewrite(vector::InsertOp insOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertSplatTransferReadToBroadcastPattern(MLIRContext *context)
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
HoistCastOpToDataSourcePattern(MLIRContext *context)
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
void runOnOperation() override
SplitUnalignedTransferReadPattern(MLIRContext *context, int64_t maxVectorSize, int64_t alignment)
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
InferTypeB2AFnTy inferTypeB2A
LogicalResult matchAndRewrite(UnaryOpB bOp, PatternRewriter &rewriter) const override
std::function< Type(UnaryOpA aOp, UnaryOpB bOp)> InferTypeB2AFnTy
SwapUnaryOpsPattern(MLIRContext *context, InferTypeB2AFnTy inferType)
void runOnOperation() override
Options for the "canonicalize-vector-for-aievec" pipeline.
PassOptions::Option< std::string > targetBackend
PassOptions::Option< std::string > aieTarget