17#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
18#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Pass/PassManager.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#define DEBUG_TYPE "aievec-canonicalization"
34using namespace vector;
42static TargetBackend decodeTargetBackend(
const std::string &backend) {
43 if (!backend.empty()) {
44 if (backend ==
"llvmir")
45 return TargetBackend::LLVMIR;
47 return TargetBackend::UNKNOWN;
49 return TargetBackend::CPP;
52static AIEArch decodeAIETarget(
const std::string &target) {
53 if (!target.empty()) {
54 if (target ==
"aieml" || target ==
"aie2")
57 return AIEArch::UNKNOWN;
66static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
67 if (op.getKind() != vector::CombiningKind::ADD)
71 auto lhsShape = op.getLhsType().getShape();
72 auto rhsShape = op.getRhsType().getShape();
73 auto accShape = cast<ShapedType>(op.getAccType()).getShape();
74 if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
78 SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
79 if (iterators.size() < 3)
81 auto innerMostIterators =
82 SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end());
83 if (vector::IteratorType::parallel != innerMostIterators[0] ||
84 vector::IteratorType::parallel != innerMostIterators[1] ||
85 vector::IteratorType::reduction != innerMostIterators[2])
89 SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
90 SmallVector<int64_t> outerMostResults;
91 for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
92 outerMostResults.push_back(i);
94 auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
95 auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
96 auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);
99 auto *ctx = op.getContext();
101 AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
104 AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
107 AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
109 int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
110 return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
111 innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
112 innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
// NOTE(review): doxygen-extracted fragments of two TransferReadOp rewrite
// patterns; the pattern class headers, match-failure returns, and several
// interior lines were lost in extraction. Do not edit blindly — recover the
// full text from the original source file.
//
// Fragment 1: splits an unaligned transfer_read into a wider (2x) aligned
// read followed by an extract_strided_slice at `offset`.
134 ConversionPatternRewriter &rewriter)
const override {
// Only non-splat (non-constant permutation map) reads are handled here.
136 if (adaptor.getPermutationMap().isConstant())
140 auto vType = readOp.getVectorType();
// Double-length vector so the requested window is fully covered by one
// aligned load.
148 auto vLen = vType.getShape().back();
149 auto longVecTy = VectorType::get(2 * vLen, vType.getElementType());
157 auto loc = readOp.getLoc();
158 Value oldInnerMostIdx = adaptor.getIndices().back();
// Rebase the innermost index back by `offset` to reach the aligned address.
159 auto offsetCorrectionMap =
160 AffineMap::get(1, 0, getAffineDimExpr(0, readOp.getContext()) - offset);
161 Value newInnerMostIdx = rewriter
162 .create<affine::AffineApplyOp>(
163 readOp.getLoc(), offsetCorrectionMap,
164 SmallVector<Value, 1>({oldInnerMostIdx}))
166 SmallVector<Value, 8> alignedIdx;
167 alignedIdx.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
168 alignedIdx[alignedIdx.size() - 1] = newInnerMostIdx;
// Wide aligned read, then slice the originally requested vLen elements.
172 auto newReadOp = rewriter.create<vector::TransferReadOp>(
173 loc, longVecTy, adaptor.getSource(), alignedIdx, adaptor.getPadding());
176 rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
177 readOp, newReadOp.getResult(), offset, vLen, 1);
//
// Fragment 2: converts a splat-like transfer_read (constant permutation map)
// into a scalar-element read + vector.splat (broadcast).
199 ConversionPatternRewriter &rewriter)
const override {
200 AffineMap map = readOp.getPermutationMap();
201 if (!map.isConstant())
204 Value srcMemRef = adaptor.getSource();
205 SmallVector<Value, 8> indices;
// Rank-0 memrefs are expanded to rank-1 so an index can be supplied.
209 if (cast<MemRefType>(srcMemRef.getType()).getRank() == 0) {
211 .create<memref::ExpandShapeOp>(
212 readOp.getLoc(), SmallVector<int64_t, 1>({1}),
213 srcMemRef, SmallVector<ReassociationIndices, 1>({}))
215 newIdx = rewriter.create<arith::ConstantOp>(readOp.getLoc(),
216 rewriter.getIndexAttr(0L));
217 indices.push_back(newIdx);
219 indices.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
220 newIdx = indices[indices.size() - 1];
// If the innermost index is an affine.apply of one dim, fold its constant
// offset into `offset` and use the underlying operand as the index.
224 if (
auto applyOp = newIdx.getDefiningOp<affine::AffineApplyOp>())
225 if (applyOp.getAffineMap().getNumDims() == 1) {
226 newIdx = applyOp.getMapOperands()[0];
227 offset = applyOp.getAffineMap().compose(ArrayRef<int64_t>{0})[0];
// Keep offset within one vector length; skip whole vectors via the index.
231 int64_t vlen = readOp.getVector().getType().getShape()[0];
232 if (offset >= vlen) {
235 int64_t numElemsToSkip = vlen * (offset / vlen);
236 offset = offset % vlen;
237 auto newAddrMap = AffineMap::get(
238 1, 0, getAffineDimExpr(0, readOp.getContext()) + numElemsToSkip);
241 .create<affine::AffineApplyOp>(readOp.getLoc(), newAddrMap,
242 SmallVector<Value, 1>({newIdx}))
245 indices[indices.size() - 1] = newIdx;
246 auto newReadOp = rewriter.create<vector::TransferReadOp>(
247 readOp.getLoc(), readOp.getVector().getType(), srcMemRef, indices,
248 adaptor.getPadding());
249 auto extractOp = rewriter.create<vector::ExtractOp>(
250 readOp.getLoc(), newReadOp.getResult(), ArrayRef<int64_t>{offset});
251 rewriter.replaceOpWithNewOp<vector::SplatOp>(
252 readOp, newReadOp.getVector().getType(), extractOp.getResult());
// NOTE(review): doxygen-extracted fragments; class headers and several
// interior lines are missing. Recover full text from the original source.
//
// Fragment 1 (HoistCastOpToDataSourcePattern::matchAndRewrite): moves an
// arith.extsi past its defining data-movement op (broadcast/extract/splat/
// extract_strided_slice) so the widening happens closer to the data source.
268 PatternRewriter &rewriter)
const override {
269 arith::ExtSIOp extOp = cast<arith::ExtSIOp>(op);
270 Operation *defOp = extOp.getIn().getDefiningOp();
// Do not hoist past ops that materialize data from memory or calls.
272 if (!defOp || isa<vector::TransferReadOp, memref::LoadOp,
273 affine::AffineLoadOp, func::CallOp>(defOp))
// Only hoist past pure vector shuffling/broadcasting ops.
277 if (!isa<vector::BroadcastOp, vector::ExtractOp, vector::SplatOp,
278 vector::ExtractStridedSliceOp>(defOp))
281 Type extOpInTy = extOp.getIn().getType();
282 SmallVector<Value, 4> inputs;
// Widen each operand of defOp to the corresponding extended type. The four
// branches cover vector/scalar combinations of operand vs. ext input type.
283 for (Value operand : defOp->getOperands()) {
284 Type operandTy = operand.getType();
285 VectorType extOpInVecTy = dyn_cast<VectorType>(extOpInTy);
286 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
287 if (operandTy == extOpInTy) {
288 Type outTy = extOp.getOut().getType();
290 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
292 }
else if (extOpInVecTy && extOpInVecTy.getElementType() == operandTy) {
295 cast<VectorType>(extOp.getOut().getType()).getElementType();
297 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
299 }
else if (operandVecTy && operandVecTy.getElementType() == extOpInTy) {
302 VectorType::get(operandVecTy.getShape(), extOp.getOut().getType());
304 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
306 }
else if (extOpInVecTy && operandVecTy &&
307 (extOpInVecTy.getElementType() ==
308 operandVecTy.getElementType())) {
310 Type outTy = VectorType::get(
311 operandVecTy.getShape(),
312 cast<VectorType>(extOp.getOut().getType()).getElementType());
314 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
// Fallback: pass the operand through unwidened (e.g. index operands).
317 inputs.push_back(operand);
// Recreate defOp with widened inputs/result, then drop the original ext.
322 rewriter.create(extOp->getLoc(), defOp->getName().getIdentifier(),
323 inputs, {extOp.getOut().getType()}, defOp->getAttrs());
324 rewriter.replaceOp(extOp, newOp->getResult(0));
//
// Fragment 2 (SwapUnaryOpsPattern): swaps a pair of chained single-operand
// ops A->B into B->A, with the intermediate type supplied by inferTypeB2A.
332template <
class UnaryOpA,
class UnaryOpB>
344 PatternRewriter &rewriter)
const override {
// static_asserts: both template parameters must be single-operand ops.
346 UnaryOpA::template hasTrait<OpTrait::OneOperand>(),
347 "SwapUnaryOps can only be instantiated for single-operand ops");
349 UnaryOpB::template hasTrait<OpTrait::OneOperand>(),
350 "SwapUnaryOps can only be instantiated for single-operand ops");
351 UnaryOpA aOp = bOp.getOperand().template getDefiningOp<UnaryOpA>();
353 return rewriter.notifyMatchFailure(bOp, UnaryOpB::getOperationName() +
354 " not preceeded by " +
355 UnaryOpA::getOperationName());
// Build B over A's operands, then A over B's result, preserving attrs.
360 rewriter.create<UnaryOpB>(bOp->getLoc(), SmallVector<Type>({newA2BTy}),
361 aOp->getOperands(), bOp->getAttrs());
362 auto newB = rewriter.create<UnaryOpA>(
363 bOp->getLoc(), SmallVector<Type>({bOp.getResult().getType()}),
364 newA->getResults(), aOp->getAttrs());
365 rewriter.replaceOp(bOp, newB.getResult());
370static SmallVector<Value> collapseInnerMostDimIndices(PatternRewriter &b,
371 Location loc,
int numDims,
373 ArrayRef<int64_t> shape,
376 assert(layout.isMinorIdentity() &&
377 "dimension collapse in non-identity layout is not implemented");
378 auto newIdxExpr = b.getAffineDimExpr(numDims - 1);
380 for (int64_t dim = numDims - 2; dim >= 0; dim--) {
381 stride *= shape[shape.size() - (numDims - dim - 1)];
382 newIdxExpr = newIdxExpr + b.getAffineDimExpr(dim) * stride;
384 auto newIndexMap = AffineMap::get(numDims, 0, newIdxExpr);
385 Value newInnerMostIdxValue =
386 b.create<affine::AffineApplyOp>(loc, newIndexMap,
387 indices.take_back(numDims))
389 SmallVector<Value> newIdxRange;
390 for (
auto idx : indices.drop_back(numDims))
391 newIdxRange.push_back(idx);
392 newIdxRange.push_back(newInnerMostIdxValue);
396static Value collapseInnerMostShapeDims(PatternRewriter &b, Location loc,
397 int numDims, Value val) {
398 auto memRefTy = cast<MemRefType>(val.getType());
399 auto shape = memRefTy.getShape();
400 int64_t newInnerMostDim = std::accumulate(shape.end() - numDims, shape.end(),
401 1, std::multiplies<>());
402 SmallVector<int64_t, 4> newShape{shape.begin(), shape.end() - numDims + 1};
403 newShape[shape.size() - numDims] = newInnerMostDim;
404 auto newNumDims = newShape.size();
405 auto *ctx = b.getContext();
406 auto newMemRefTy = MemRefType::get(
407 newShape, memRefTy.getElementType(),
408 AffineMap::getMinorIdentityMap(newNumDims, newNumDims, ctx),
409 memRefTy.getMemorySpace());
410 auto reassocIndices =
411 getReassociationIndicesForCollapse(shape, newShape).value();
413 .create<memref::CollapseShapeOp>(loc, newMemRefTy, val, reassocIndices)
// NOTE(review): doxygen-extracted fragments of two flattening patterns
// (multi-dim transfer_read / transfer_write -> 1-D); class headers and
// match-failure lines are missing. Recover full text from the original file.
//
// Fragment 1: flattens an N-D minor-identity transfer_read into a 1-D read
// over a collapsed memref, followed by a shape_cast back to the N-D vector.
426 ConversionPatternRewriter &rewriter)
const override {
// Only unmasked, minor-identity reads of rank >= 2 are flattened.
429 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
431 VectorType vectorTy = readOp.getVector().getType();
432 if (vectorTy.getRank() < 2)
435 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getSource().getType());
438 auto memRefShape = memRefTy.getShape();
439 auto vecShape = vectorTy.getShape();
// 1-D vector type holding the product of all vector dims.
442 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
443 std::multiplies<>())},
444 vectorTy.getElementType());
445 AffineMap layout = memRefTy.getLayout().getAffineMap();
447 collapseInnerMostDimIndices(rewriter, readOp.getLoc(), vecShape.size(),
448 adaptor.getIndices(), memRefShape, layout);
449 auto newSource = collapseInnerMostShapeDims(
450 rewriter, readOp.getLoc(), vecShape.size(), adaptor.getSource());
451 auto newVector = rewriter.create<vector::TransferReadOp>(
452 readOp.getLoc(), newVectorTy, newSource, newIndices);
// Collapse the in_bounds attribute: the 1-D read is in-bounds only if every
// original dimension was in-bounds.
454 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
455 if (inBoundsArrayAttrOpt) {
456 SmallVector<bool> inBounds =
457 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
458 SmallVector<bool> newInBounds({
false});
459 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
460 [](
bool v) { return v; });
461 newVector.getProperties().setInBounds(
462 rewriter.getBoolArrayAttr(newInBounds));
465 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, vectorTy,
//
// Fragment 2: the mirror-image pattern for transfer_write — shape_cast the
// N-D vector to 1-D, collapse indices/memref, emit a 1-D write.
481 ConversionPatternRewriter &rewriter)
const override {
484 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
486 VectorType vectorTy = cast<VectorType>(adaptor.getVector().getType());
487 if (vectorTy.getRank() < 2)
490 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getSource().getType());
493 auto memRefShape = memRefTy.getShape();
494 auto vecShape = vectorTy.getShape();
497 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
498 std::multiplies<>())},
499 vectorTy.getElementType());
500 AffineMap layout = memRefTy.getLayout().getAffineMap();
501 auto newVector = rewriter
502 .create<vector::ShapeCastOp>(
503 writeOp.getLoc(), newVectorTy, adaptor.getVector())
506 collapseInnerMostDimIndices(rewriter, writeOp.getLoc(), vecShape.size(),
507 adaptor.getIndices(), memRefShape, layout);
508 auto newSource = collapseInnerMostShapeDims(
509 rewriter, writeOp.getLoc(), vecShape.size(), adaptor.getSource());
511 auto newOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
512 writeOp, newVector, newSource, newIndices);
// Same in_bounds collapse as the read pattern above.
514 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
515 if (inBoundsArrayAttrOpt) {
516 SmallVector<bool> inBounds =
517 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
518 SmallVector<bool> newInBounds({
false});
519 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
520 [](
bool v) { return v; });
521 newOp.getProperties().setInBounds(rewriter.getBoolArrayAttr(newInBounds));
538 SmallVector<int64_t> shape{vecTy.getShape()};
539 auto nDim = shape.size();
540 int64_t dimNm1 = shape[nDim - 1];
541 shape[nDim - 1] = shape[nDim - 2];
542 shape[nDim - 2] = dimNm1;
543 auto elemTy = vecTy.getElementType();
544 return VectorType::get(shape, elemTy);
// NOTE(review): doxygen-extracted fragments; class headers, several
// assignment targets, and match-failure lines are missing. Recover full
// text from the original file.
//
// Fragment 1: rewrites a "B-transposed" vector.contract by explicitly
// transposing the rhs (hoisting any ext(f|si|ui) below the transpose) and
// composing the rhs indexing map with the same permutation.
549 ConversionPatternRewriter &rewriter)
const override {
550 if (!isGemmBTransposedContractionOp(contractOp))
553 Location loc = contractOp.getLoc();
554 auto *ctx = rewriter.getContext();
556 Value rhsVal = adaptor.getRhs();
557 VectorType rhsVecTy = contractOp.getRhsType();
558 Type rhsElemTy = rhsVecTy.getElementType();
// Peel an extension op off the rhs so the transpose runs on the narrow
// type; the matching flag records which extension to re-apply afterwards.
560 bool doExtF =
false, doExtSI =
false, doExtUI =
false;
561 if (
auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
562 rhsVal = extfRhsOp.getIn();
563 rhsVecTy = cast<VectorType>(rhsVal.getType());
565 }
else if (
auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
566 rhsVal = extsiRhsOp.getIn();
567 rhsVecTy = cast<VectorType>(rhsVal.getType());
569 }
else if (
auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
570 rhsVal = extuiRhsOp.getIn();
571 rhsVecTy = cast<VectorType>(rhsVal.getType());
// Permutation that swaps the two innermost dims of the rhs.
575 int64_t nDim = rhsVecTy.getShape().size();
576 SmallVector<int64_t> rhsPermutation;
577 for (int64_t i = 0; i < nDim - 2; i++)
578 rhsPermutation.push_back(i);
579 rhsPermutation.push_back(nDim - 1);
580 rhsPermutation.push_back(nDim - 2);
583 .create<vector::TransposeOp>(loc, transpRhsVecTy, rhsVal,
// Re-apply the previously peeled extension on the transposed value.
590 .create<arith::ExtFOp>(
591 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
597 .create<arith::ExtSIOp>(
598 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
604 .create<arith::ExtUIOp>(
605 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
// Compose the rhs indexing map with the same innermost-dim swap so the
// contraction semantics are unchanged.
609 SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());
611 nDim = oldIdxMaps[1].getNumDims();
612 SmallVector<int64_t> innerDimPerm;
613 for (int64_t i = 0; i < nDim - 2; i++)
614 innerDimPerm.push_back(i);
615 innerDimPerm.push_back(nDim - 1);
616 innerDimPerm.push_back(nDim - 2);
617 auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);
619 auto newIdxMaps = rewriter.getAffineMapArrayAttr(
620 {oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});
622 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
623 contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
624 adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());
//
// Fragment 2: replaces a vector.insert whose src/dst differ only by leading
// unit dims with a plain shape_cast.
640 PatternRewriter &rewriter)
const override {
641 auto insSrcTy = dyn_cast<VectorType>(insOp.getSourceType());
645 auto srcShape = insSrcTy.getShape();
646 auto dstShape = insOp.getDestVectorType().getShape();
// Count leading unit dims on each side.
648 unsigned long numLeadUnitDimDst = 0;
649 while (numLeadUnitDimDst < dstShape.size() &&
650 dstShape[numLeadUnitDimDst] == 1)
653 if (!numLeadUnitDimDst)
656 unsigned long numLeadUnitDimSrc = 0;
657 while (numLeadUnitDimSrc < srcShape.size() &&
658 srcShape[numLeadUnitDimSrc] == 1)
// After dropping leading unit dims, the shapes must match exactly.
661 SmallVector<int64_t> nonLeadUnitDimDstShape(
662 dstShape.begin() + numLeadUnitDimDst, dstShape.end());
663 SmallVector<int64_t> nonLeadUnitDimSrcShape(
664 srcShape.begin() + numLeadUnitDimSrc, srcShape.end());
666 if (nonLeadUnitDimSrcShape != nonLeadUnitDimDstShape)
669 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
670 insOp, insOp.getDestVectorType(), insOp.getSource());
// NOTE(review): doxygen-extracted fragments of the legalization/pattern
// registration helpers and the VectorBroadcastLoweringPass; signatures and
// several lines are missing. Recover full text from the original file.
//
// Common legalizations: the four dialects below stay legal by default.
679configureCommonAIECanonicalizeLegalizations(ConversionTarget &target,
681 target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
682 memref::MemRefDialect, vector::VectorDialect>();
686populateCommonAIECanonicalizeConversionPatterns(RewritePatternSet &patterns,
689 patterns.getContext());
// AIEv1: transfer_read is legal only when it is not a splat read (constant
// permutation map) — the trailing alignment condition is missing here.
696static void configureAIEv1CanonicalizeLegalizations(ConversionTarget &target,
698 target.addDynamicallyLegalOp<vector::TransferReadOp>(
699 [](vector::TransferReadOp op) {
700 return !op.getPermutationMap().isConstant() &&
707populateAIEv1CanonicalizeConversionPatterns(RewritePatternSet &patterns,
// AIE2: reads/writes must additionally be rank < 2 (flattened), and
// B-transposed contractions are illegal (handled by the pattern above).
717static void configureAIE2CanonicalizeLegalizations(ConversionTarget &target,
719 target.addDynamicallyLegalOp<vector::TransferReadOp>(
720 [](vector::TransferReadOp op) {
721 return !op.getPermutationMap().isConstant() &&
724 op.getVector().getType().getRank() < 2;
726 target.addDynamicallyLegalOp<vector::TransferWriteOp>(
727 [](vector::TransferWriteOp op) {
728 return cast<VectorType>(op.getVector().getType()).getRank() < 2;
730 target.addDynamicallyLegalOp<vector::ContractionOp>(
731 [](vector::ContractionOp op) {
732 return !isGemmBTransposedContractionOp(op);
737populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
//
// VectorBroadcastLoweringPass: greedily applies upstream broadcast-lowering
// patterns (plus at least one local pattern added at the elided line 758).
751 :
public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {
754 auto *op = getOperation();
755 MLIRContext *context = &getContext();
756 RewritePatternSet patterns(context);
757 populateVectorBroadcastLoweringPatterns(patterns);
759 patterns.getContext());
761 (void)applyPatternsGreedily(op, std::move(patterns));
765static std::unique_ptr<::mlir::Pass> createVectorBroadcastLoweringPass() {
766 return std::make_unique<VectorBroadcastLoweringPass>();
// NOTE(review): doxygen-extracted fragment of CanonicalizeVectorForAIEVecPass;
// the class header, option declarations' first lines, early-return /
// signalPassFailure lines, and closing braces are missing. Recover full text
// from the original file before editing.
777 :
public PassWrapper<CanonicalizeVectorForAIEVecPass, OperationPass<>> {
// getArgument(): command-line name of the test pass.
794 return "test-canonicalize-vector-for-aievec";
// getDescription(): one-line summary shown in --help.
798 return "Canonicalize vector operations for AIEVec conversion";
// getDependentDialects(): dialects the rewrites may create ops in.
802 registry.insert<arith::ArithDialect, memref::MemRefDialect,
803 vector::VectorDialect, affine::AffineDialect>();
// Option `aie-target` (see decodeAIETarget); defaults to "aie".
808 llvm::cl::desc(
"Select AIE version: \"aie\" or \"aie2\". This will "
809 "determine the vector size and available operations."),
810 llvm::cl::init(
"aie")};
// Option `target-backend` (see decodeTargetBackend); defaults to "cpp".
813 *
this,
"target-backend",
814 llvm::cl::desc(
"Select translation backend: \"cpp\" or \"llvmir\". This "
815 "will determine the aievec operations used to convert "
816 "from vector dialect."),
817 llvm::cl::init(
"cpp")};
// runOnOperation(): decode options, emit diagnostics for invalid
// combinations, then run the partial conversion.
820 auto *op = getOperation();
821 MLIRContext *context = &getContext();
822 RewritePatternSet patterns(context);
823 ConversionTarget target(*context);
826 if (aieVersion == AIEArch::UNKNOWN) {
827 op->emitError() <<
"unknown AIE target '" <<
aieTarget <<
"'";
833 if (backend == TargetBackend::UNKNOWN) {
834 op->emitError() <<
"unknown target backend '" <<
targetBackend <<
"'";
// LLVM IR backend is only supported from AIE2 onwards.
838 if (backend == TargetBackend::LLVMIR && aieVersion == AIEArch::AIE) {
839 op->emitError() <<
"targetting LLVM IR is not supported for AIEv1";
844 populateCommonAIECanonicalizeConversionPatterns(patterns, backend);
845 configureCommonAIECanonicalizeLegalizations(target, backend);
846 if (aieVersion == AIEArch::AIE) {
847 populateAIEv1CanonicalizeConversionPatterns(patterns, backend);
848 configureAIEv1CanonicalizeLegalizations(target, backend);
850 populateAIE2CanonicalizeConversionPatterns(patterns, backend);
851 configureAIE2CanonicalizeLegalizations(target, backend);
854 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
860static std::unique_ptr<::mlir::Pass> createCanonicalizeVectorForAIEVecPass(
862 return std::make_unique<CanonicalizeVectorForAIEVecPass>(options);
// NOTE(review): doxygen-extracted fragment of HoistCastOpToDataSourcePass;
// the class header, the pattern registration line (elided line 872-874), and
// closing braces are missing. Greedily applies HoistCastOpToDataSourcePattern.
866 :
public PassWrapper<HoistCastOpToDataSourcePass, OperationPass<>> {
869 auto *op = getOperation();
870 MLIRContext *context = &getContext();
871 RewritePatternSet patterns(context);
875 (void)applyPatternsGreedily(op, std::move(patterns));
879static std::unique_ptr<::mlir::Pass> createHoistCastOpToDataSourcePass() {
880 return std::make_unique<HoistCastOpToDataSourcePass>();
// NOTE(review): doxygen-extracted fragment of ReorderOperationsPass; the
// class header, the SwapUnaryOpsPattern registration line, and closing
// braces are missing. Registers SwapUnaryOpsPattern<extsi, broadcast> with
// an inference callback computing the intermediate (swapped) type.
884 :
public PassWrapper<ReorderOperationsPass, OperationPass<>> {
887 auto *op = getOperation();
888 MLIRContext *context = &getContext();
889 RewritePatternSet patterns(context);
892 patterns.getContext(),
// Infer the type of broadcast-before-extsi: broadcast's result shape with
// extsi's (element) input type.
893 [](arith::ExtSIOp extOp, vector::BroadcastOp bcastOp) -> Type {
894 Type extInElemTy = extOp.getIn().getType();
895 auto extInVecTy = dyn_cast<VectorType>(extInElemTy);
897 extInElemTy = extInVecTy.getElementType();
898 return VectorType::get(bcastOp.getResultVectorType().getShape(),
902 (void)applyPatternsGreedily(op, std::move(patterns));
906static std::unique_ptr<::mlir::Pass> createReorderOperationsPass() {
907 return std::make_unique<ReorderOperationsPass>();
919 if (decodeTargetBackend(options.
targetBackend) == TargetBackend::LLVMIR)
920 pm.addPass(createReorderOperationsPass());
922 pm.addPass(createVectorBroadcastLoweringPass());
923 pm.addPass(createCanonicalizeVectorForAIEVecPass(options));
924 if (decodeTargetBackend(options.
targetBackend) == TargetBackend::CPP)
925 pm.addPass(createHoistCastOpToDataSourcePass());
std::unique_ptr<::mlir::Pass > createCopyRemovalPass()
Create a pass that removes unnecessary Copy operations.
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
int32_t getElementSizeInBits(mlir::VectorType type)
void buildCanonicalizeVectorForAIEVec(mlir::OpPassManager &pm, const CanonicalizeVectorForAIEVecOptions &options)
StringRef getArgument() const final
void getDependentDialects(DialectRegistry ®istry) const override
StringRef getDescription() const final
Option< std::string > aieTarget
CanonicalizeVectorForAIEVecPass(const CanonicalizeVectorForAIEVecOptions &options)
Option< std::string > targetBackend
void runOnOperation() override
LogicalResult matchAndRewrite(vector::InsertOp insOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertSplatTransferReadToBroadcastPattern(MLIRContext *context)
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
HoistCastOpToDataSourcePattern(MLIRContext *context)
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
void runOnOperation() override
SplitUnalignedTransferReadPattern(MLIRContext *context, int64_t maxVectorSize, int64_t alignment)
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
InferTypeB2AFnTy inferTypeB2A
LogicalResult matchAndRewrite(UnaryOpB bOp, PatternRewriter &rewriter) const override
std::function< Type(UnaryOpA aOp, UnaryOpB bOp)> InferTypeB2AFnTy
SwapUnaryOpsPattern(MLIRContext *context, InferTypeB2AFnTy inferType)
void runOnOperation() override
Options for the "canonicalize-vector-for-aievec" pipeline.
PassOptions::Option< std::string > targetBackend
PassOptions::Option< std::string > aieTarget