17#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
18#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/UB/IR/UBOps.h"
23#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#define DEBUG_TYPE "aievec-canonicalization"
36using namespace vector;
44static TargetBackend decodeTargetBackend(
const std::string &backend) {
45 if (!backend.empty()) {
46 if (backend ==
"llvmir")
47 return TargetBackend::LLVMIR;
49 return TargetBackend::UNKNOWN;
51 return TargetBackend::CPP;
54static AIEArch decodeAIETarget(
const std::string &target) {
55 if (!target.empty()) {
56 if (target ==
"aieml" || target ==
"aie2" || target ==
"aie2p")
59 return AIEArch::UNKNOWN;
// Returns true iff `op` is an additive vector.contract whose innermost three
// iterators are (parallel, parallel, reduction) and whose innermost indexing
// maps match a GEMM with a transposed B operand.
// NOTE(review): this extraction is missing lines (the `return false;` guards
// and the mm{A,B,C}idxMap variable definitions around original lines
// 100-110); comments below describe only the visible logic — confirm the
// dropped control flow against the original source.
68static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
// Only ADD-combining contractions can correspond to a GEMM accumulate.
69 if (op.getKind() != vector::CombiningKind::ADD)
73 auto lhsShape = op.getLhsType().getShape();
74 auto rhsShape = op.getRhsType().getShape();
75 auto accShape = cast<ShapedType>(op.getAccType()).getShape();
// All three operands must be at least 2-D for a matrix-matrix product.
76 if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
80 SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
81 if (iterators.size() < 3)
// Inspect only the innermost three iterators: (par, par, red) is the GEMM
// iteration pattern; any outer iterators are treated as batch dimensions.
83 auto innerMostIterators =
84 SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end())
85 if (vector::IteratorType::parallel != innerMostIterators[0] ||
86 vector::IteratorType::parallel != innerMostIterators[1] ||
87 vector::IteratorType::reduction != innerMostIterators[2])
// Strip the outermost (batch) results from each indexing map so that only
// the innermost two result dims of each operand's map are compared.
91 SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
92 SmallVector<int64_t> outerMostResults;
93 for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
94 outerMostResults.push_back(i);
96 auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
97 auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
98 auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);
// Reference permutation maps for a B-transposed GEMM; the surrounding
// assignments were dropped by the extraction.
101 auto *ctx = op.getContext();
103 AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
106 AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
109 AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
// Shift the reference maps past the batch dims before comparing.
111 int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
112 return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
113 innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
114 innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
// matchAndRewrite for SplitUnalignedTransferReadPattern: replaces an
// unaligned 1-D transfer_read with an aligned read of twice the vector
// length followed by an extract_strided_slice at the original offset.
// NOTE(review): the signature head and several guard lines are missing from
// this extraction (including the computation of `offset`); confirm against
// the original file.
136 ConversionPatternRewriter &rewriter)
const override {
// Splat (constant-map) reads are handled by a different pattern.
138 if (adaptor.getPermutationMap().isConstant())
142 auto vType = readOp.getVectorType();
// Read a vector twice as long, starting at the aligned address, so the
// requested data is fully contained in it.
150 auto vLen = vType.getShape().back();
151 auto longVecTy = VectorType::get(2 * vLen, vType.getElementType());
159 auto loc = readOp.getLoc();
160 Value oldInnerMostIdx = adaptor.getIndices().back();
// Subtract `offset` from the innermost index to land on the aligned address.
161 auto offsetCorrectionMap =
162 AffineMap::get(1, 0, getAffineDimExpr(0, readOp.getContext()) - offset);
163 Value newInnerMostIdx = affine::AffineApplyOp::create(
164 rewriter, readOp.getLoc(), offsetCorrectionMap,
165 SmallVector<Value, 1>({oldInnerMostIdx}))
167 SmallVector<Value, 8> alignedIdx;
168 alignedIdx.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
169 alignedIdx[alignedIdx.size() - 1] = newInnerMostIdx;
// Aligned double-length read, then slice out the originally requested part.
173 auto newReadOp = vector::TransferReadOp::create(
174 rewriter, loc, longVecTy, adaptor.getBase(), alignedIdx,
175 adaptor.getPadding());
178 rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
179 readOp, newReadOp.getResult(), offset, vLen, 1);
// matchAndRewrite for ConvertSplatTransferReadToBroadcastPattern: rewrites a
// transfer_read with a constant (splat) permutation map into a plain
// transfer_read + vector.extract + vector.broadcast.
// NOTE(review): several lines are missing from this extraction (declarations
// of `newIdx`/`offset`, else-branches, early returns); comments describe the
// visible logic only.
201 ConversionPatternRewriter &rewriter)
const override {
// Only reads whose permutation map is constant (i.e. a splat) apply.
202 AffineMap map = readOp.getPermutationMap();
203 if (!map.isConstant())
206 Value srcMemRef = adaptor.getBase();
207 SmallVector<Value, 8> indices;
// Rank-0 memrefs cannot be indexed: expand to a rank-1 memref of size 1 and
// read at index 0.
211 if (cast<MemRefType>(srcMemRef.getType()).getRank() == 0) {
212 srcMemRef = memref::ExpandShapeOp::create(
213 rewriter, readOp.getLoc(), SmallVector<int64_t, 1>({1}),
214 srcMemRef, SmallVector<ReassociationIndices, 1>({}))
216 newIdx = arith::ConstantOp::create(rewriter, readOp.getLoc(),
217 rewriter.getIndexAttr(0L));
218 indices.push_back(newIdx);
220 indices.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
221 newIdx = indices[indices.size() - 1];
// If the innermost index is a single-dim affine.apply, fold its constant
// displacement into `offset` and index on the underlying operand instead.
225 if (
auto applyOp = newIdx.getDefiningOp<affine::AffineApplyOp>())
226 if (applyOp.getAffineMap().getNumDims() == 1) {
227 newIdx = applyOp.getMapOperands()[0];
228 offset = applyOp.getAffineMap().compose(ArrayRef<int64_t>{0})[0];
// Keep the extract position within [0, vlen) by folding whole-vector
// multiples of the offset back into the read address.
232 int64_t vlen = readOp.getVector().getType().getShape()[0];
233 if (offset >= vlen) {
236 int64_t numElemsToSkip = vlen * (offset / vlen);
237 offset = offset % vlen;
238 auto newAddrMap = AffineMap::get(
239 1, 0, getAffineDimExpr(0, readOp.getContext()) + numElemsToSkip);
241 affine::AffineApplyOp::create(rewriter, readOp.getLoc(), newAddrMap,
242 SmallVector<Value, 1>({newIdx}))
245 indices[indices.size() - 1] = newIdx;
// Plain read + extract the splatted element + broadcast to the result type.
246 auto newReadOp = vector::TransferReadOp::create(
247 rewriter, readOp.getLoc(), readOp.getVector().getType(), srcMemRef,
248 indices, adaptor.getPadding());
249 auto extractOp = vector::ExtractOp::create(rewriter, readOp.getLoc(),
250 newReadOp.getResult(),
251 ArrayRef<int64_t>{offset});
252 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
253 readOp, newReadOp.getVector().getType(), extractOp.getResult());
// matchAndRewrite for HoistCastOpToDataSourcePattern: moves an arith.extsi
// past its defining vector shuffle op (broadcast/extract/
// extract_strided_slice) so the extension happens on the operands instead,
// closer to the data source.
// NOTE(review): several early-return and closing lines are missing from this
// extraction; comments describe the visible logic only.
269 PatternRewriter &rewriter)
const override {
270 arith::ExtSIOp extOp = cast<arith::ExtSIOp>(op);
271 Operation *defOp = extOp.getIn().getDefiningOp();
// Do not hoist past loads/calls: the extension would then apply to memory
// traffic rather than a cheap vector rearrangement.
273 if (!defOp || isa<vector::TransferReadOp, memref::LoadOp,
274 affine::AffineLoadOp, func::CallOp>(defOp))
// Only hoist through element-rearranging ops whose semantics commute with
// element-wise extension.
278 if (!isa<vector::BroadcastOp, vector::ExtractOp,
279 vector::ExtractStridedSliceOp>(defOp))
282 Type extOpInTy = extOp.getIn().getType();
283 SmallVector<Value, 4> inputs;
// For each operand of `defOp`, build an extended version whose type is
// derived case-by-case from how the operand relates to extOp's input type.
284 for (Value operand : defOp->getOperands()) {
285 Type operandTy = operand.getType();
286 VectorType extOpInVecTy = dyn_cast<VectorType>(extOpInTy);
287 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
// Case 1: operand type equals extOp's input type — reuse extOp's out type.
288 if (operandTy == extOpInTy) {
289 Type outTy = extOp.getOut().getType();
291 arith::ExtSIOp::create(rewriter, extOp.getLoc(), outTy, operand)
293 }
// Case 2: scalar operand of a vector ext — extend to the scalar elem type.
else if (extOpInVecTy && extOpInVecTy.getElementType() == operandTy) {
296 cast<VectorType>(extOp.getOut().getType()).getElementType();
298 arith::ExtSIOp::create(rewriter, extOp.getLoc(), outTy, operand)
300 }
// Case 3: vector operand of a scalar ext — extend element-wise.
else if (operandVecTy && operandVecTy.getElementType() == extOpInTy) {
303 VectorType::get(operandVecTy.getShape(), extOp.getOut().getType());
305 arith::ExtSIOp::create(rewriter, extOp.getLoc(), outTy, operand)
307 }
// Case 4: vector operand with matching element type but different shape —
// keep the operand's shape, use the extended element type.
else if (extOpInVecTy && operandVecTy &&
308 (extOpInVecTy.getElementType() ==
309 operandVecTy.getElementType())) {
311 Type outTy = VectorType::get(
312 operandVecTy.getShape(),
313 cast<VectorType>(extOp.getOut().getType()).getElementType());
315 arith::ExtSIOp::create(rewriter, extOp.getLoc(), outTy, operand)
// Fallback: pass the operand through unchanged (e.g. index operands).
318 inputs.push_back(operand);
// Recreate `defOp` on the extended inputs, producing extOp's result type,
// and replace the original extension with it.
323 rewriter.create(extOp->getLoc(), defOp->getName().getIdentifier(),
324 inputs, {extOp.getOut().getType()}, defOp->getAttrs());
325 rewriter.replaceOp(extOp, newOp->getResult(0));
// SwapUnaryOpsPattern: given a chain A -> B of single-operand ops, rewrites
// it to B -> A, using the user-supplied inferTypeB2A callback (visible in the
// class, not in this span) to compute the intermediate type.
// NOTE(review): static_assert heads, the match-failure guard, and closing
// lines are missing from this extraction.
333template <
class UnaryOpA,
class UnaryOpB>
345 PatternRewriter &rewriter)
const override {
// Both ops must be single-operand for the swap to be well-defined.
347 UnaryOpA::template hasTrait<OpTrait::OneOperand>(),
348 "SwapUnaryOps can only be instantiated for single-operand ops");
350 UnaryOpB::template hasTrait<OpTrait::OneOperand>(),
351 "SwapUnaryOps can only be instantiated for single-operand ops");
// Only fires when B's operand is produced by an A op.
352 UnaryOpA aOp = bOp.getOperand().template getDefiningOp<UnaryOpA>();
354 return rewriter.notifyMatchFailure(bOp, UnaryOpB::getOperationName() +
355 " not preceeded by " +
356 UnaryOpA::getOperationName());
// Create B on A's operands with the inferred intermediate type, then A on
// B's results with bOp's original result type; attrs are carried over.
361 UnaryOpB::create(rewriter, bOp->getLoc(), SmallVector<Type>({newA2BTy}),
362 aOp->getOperands(), bOp->getAttrs());
363 auto newB = UnaryOpA::create(rewriter, bOp->getLoc(),
364 SmallVector<Type>({bOp.getResult().getType()}),
365 newA->getResults(), aOp->getAttrs());
366 rewriter.replaceOp(bOp, newB.getResult());
371static SmallVector<Value> collapseInnerMostDimIndices(PatternRewriter &b,
372 Location loc,
int numDims,
374 ArrayRef<int64_t> shape,
377 auto newIdxExpr = b.getAffineDimExpr(numDims - 1);
379 for (int64_t dim = numDims - 2; dim >= 0; dim--) {
380 stride *= shape[shape.size() - (numDims - dim - 1)];
381 newIdxExpr = newIdxExpr + b.getAffineDimExpr(dim) * stride;
383 auto newIndexMap = AffineMap::get(numDims, 0, newIdxExpr);
384 Value newInnerMostIdxValue =
385 affine::AffineApplyOp::create(b, loc, newIndexMap,
386 indices.take_back(numDims))
388 SmallVector<Value> newIdxRange;
389 for (
auto idx : indices.drop_back(numDims))
390 newIdxRange.push_back(idx);
391 newIdxRange.push_back(newInnerMostIdxValue);
395static Value collapseInnerMostShapeDims(PatternRewriter &b, Location loc,
396 int numDims, Value val) {
397 auto memRefTy = cast<MemRefType>(val.getType());
398 auto shape = memRefTy.getShape();
399 int64_t newInnerMostDim = std::accumulate(shape.end() - numDims, shape.end(),
400 1, std::multiplies<>());
401 SmallVector<int64_t, 4> newShape{shape.begin(), shape.end() - numDims + 1};
402 newShape[shape.size() - numDims] = newInnerMostDim;
403 auto reassocIndices =
404 getReassociationIndicesForCollapse(shape, newShape).value();
408 memref::CollapseShapeOp::computeCollapsedType(memRefTy, reassocIndices);
409 return memref::CollapseShapeOp::create(b, loc, newMemRefTy, val,
417static bool hasContiguousRowMajorStrides(MemRefType memRefTy) {
418 SmallVector<int64_t> strides;
420 if (failed(memRefTy.getStridesAndOffset(strides, offset)))
422 auto shape = memRefTy.getShape();
423 int64_t expected = 1;
424 for (
int i = shape.size() - 1; i >= 0; --i) {
425 if (strides[i] != ShapedType::kDynamic && strides[i] != expected)
427 if (shape[i] != ShapedType::kDynamic)
428 expected *= shape[i];
430 expected = ShapedType::kDynamic;
// matchAndRewrite for the multi-dim transfer_read flattening pattern:
// rewrites an N-D minor-identity transfer_read of a contiguous row-major
// memref into a 1-D read of a collapsed memref plus a shape_cast back to the
// original vector type.
// NOTE(review): early-return lines and some closing braces are missing from
// this extraction; comments describe the visible logic only.
444 ConversionPatternRewriter &rewriter)
const override {
// Only unmasked minor-identity reads are handled.
447 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
449 VectorType vectorTy = readOp.getVector().getType();
450 if (vectorTy.getRank() < 2)
453 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getBase().getType());
456 auto memRefShape = memRefTy.getShape();
457 auto vecShape = vectorTy.getShape();
// 1-D vector type with the product of all original vector extents.
460 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
461 std::multiplies<>())},
462 vectorTy.getElementType());
// Flattening is only valid when the memref is contiguous row-major.
463 if (!hasContiguousRowMajorStrides(memRefTy))
466 AffineMap layout = memRefTy.getLayout().getAffineMap();
// Linearize the indices and collapse the source memref to match.
468 collapseInnerMostDimIndices(rewriter, readOp.getLoc(), vecShape.size(),
469 adaptor.getIndices(), memRefShape, layout);
470 auto newSource = collapseInnerMostShapeDims(
471 rewriter, readOp.getLoc(), vecShape.size(), adaptor.getBase());
472 auto newVector = vector::TransferReadOp::create(
473 rewriter, readOp.getLoc(), newVectorTy, newSource, newIndices,
475 arith::getZeroConstant(rewriter, readOp.getLoc(),
476 newVectorTy.getElementType()));
// The flattened read is in-bounds only if every original dim was.
478 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
479 if (inBoundsArrayAttrOpt) {
480 SmallVector<bool> inBounds =
481 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
482 SmallVector<bool> newInBounds({
false});
483 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
484 [](
bool v) { return v; });
485 newVector.getProperties().setInBounds(
486 rewriter.getBoolArrayAttr(newInBounds));
// Restore the original N-D vector type for the read's users.
489 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, vectorTy,
// matchAndRewrite for the multi-dim transfer_write flattening pattern:
// mirror of the read case — shape_cast the value to 1-D, collapse the
// destination memref, and emit a 1-D transfer_write.
// NOTE(review): early-return lines and some closing braces are missing from
// this extraction; comments describe the visible logic only.
505 ConversionPatternRewriter &rewriter)
const override {
// Only unmasked minor-identity writes of rank >= 2 vectors are handled.
508 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
510 VectorType vectorTy = cast<VectorType>(adaptor.getValueToStore().getType());
511 if (vectorTy.getRank() < 2)
514 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getBase().getType());
517 auto memRefShape = memRefTy.getShape();
518 auto vecShape = vectorTy.getShape();
// Flattening is only valid when the memref is contiguous row-major.
520 if (!hasContiguousRowMajorStrides(memRefTy))
// 1-D vector type holding all elements of the original value.
524 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
525 std::multiplies<>())},
526 vectorTy.getElementType());
527 AffineMap layout = memRefTy.getLayout().getAffineMap();
// Cast the stored value to 1-D before writing.
529 vector::ShapeCastOp::create(rewriter, writeOp.getLoc(), newVectorTy,
530 adaptor.getValueToStore())
533 collapseInnerMostDimIndices(rewriter, writeOp.getLoc(), vecShape.size(),
534 adaptor.getIndices(), memRefShape, layout);
535 auto newSource = collapseInnerMostShapeDims(
536 rewriter, writeOp.getLoc(), vecShape.size(), adaptor.getBase());
538 auto newOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
539 writeOp, newVector, newSource, newIndices);
// The flattened write is in-bounds only if every original dim was.
541 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
542 if (inBoundsArrayAttrOpt) {
543 SmallVector<bool> inBounds =
544 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
545 SmallVector<bool> newInBounds({
false});
546 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
547 [](
bool v) { return v; });
548 newOp.getProperties().setInBounds(rewriter.getBoolArrayAttr(newInBounds));
// Helper body (signature not visible in this extraction): builds a vector
// type identical to `vecTy` but with the last two dimensions swapped —
// presumably used to type the transposed-B operand below. TODO confirm the
// enclosing signature against the original file.
565 SmallVector<int64_t> shape{vecTy.getShape()};
566 auto nDim = shape.size();
// Swap the two innermost extents.
567 int64_t dimNm1 = shape[nDim - 1];
568 shape[nDim - 1] = shape[nDim - 2];
569 shape[nDim - 2] = dimNm1;
570 auto elemTy = vecTy.getElementType();
571 return VectorType::get(shape, elemTy);
// matchAndRewrite for the GEMM-with-transposed-B canonicalization: inserts
// an explicit vector.transpose of the RHS (hoisting any ext{f,si,ui} past
// the transpose) and composes the RHS indexing map with the corresponding
// permutation, so the contract becomes a plain (non-transposed-B) GEMM.
// NOTE(review): some lines are missing from this extraction (the doExt*
// assignments inside the if-chain, the transpose-type computation, and the
// ext-op re-creation guards); comments describe the visible logic only.
576 ConversionPatternRewriter &rewriter)
const override {
577 if (!isGemmBTransposedContractionOp(contractOp))
580 Location loc = contractOp.getLoc();
581 auto *ctx = rewriter.getContext();
583 Value rhsVal = adaptor.getRhs();
584 VectorType rhsVecTy = contractOp.getRhsType();
585 Type rhsElemTy = rhsVecTy.getElementType();
// If the RHS is an extension, transpose the narrower pre-extension value
// and re-apply the extension afterwards (cheaper transpose).
587 bool doExtF =
false, doExtSI =
false, doExtUI =
false;
588 if (
auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
589 rhsVal = extfRhsOp.getIn();
590 rhsVecTy = cast<VectorType>(rhsVal.getType());
592 }
else if (
auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
593 rhsVal = extsiRhsOp.getIn();
594 rhsVecTy = cast<VectorType>(rhsVal.getType());
596 }
else if (
auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
597 rhsVal = extuiRhsOp.getIn();
598 rhsVecTy = cast<VectorType>(rhsVal.getType());
// Permutation swapping the last two dims, identity on the batch dims.
602 int64_t nDim = rhsVecTy.getShape().size();
603 SmallVector<int64_t> rhsPermutation;
604 for (int64_t i = 0; i < nDim - 2; i++)
605 rhsPermutation.push_back(i);
606 rhsPermutation.push_back(nDim - 1);
607 rhsPermutation.push_back(nDim - 2);
609 rhsVal = vector::TransposeOp::create(rewriter, loc, transpRhsVecTy, rhsVal,
// Re-apply whichever extension was peeled off above.
615 arith::ExtFOp::create(
617 VectorType::get(transpRhsVecTy.getShape(), rhsElemTy), rhsVal)
621 arith::ExtSIOp::create(
623 VectorType::get(transpRhsVecTy.getShape(), rhsElemTy), rhsVal)
627 arith::ExtUIOp::create(
629 VectorType::get(transpRhsVecTy.getShape(), rhsElemTy), rhsVal)
// Compose the RHS indexing map with the same last-two-dims permutation so
// the contract's semantics are unchanged.
632 SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());
634 nDim = oldIdxMaps[1].getNumDims();
635 SmallVector<int64_t> innerDimPerm;
636 for (int64_t i = 0; i < nDim - 2; i++)
637 innerDimPerm.push_back(i);
638 innerDimPerm.push_back(nDim - 1);
639 innerDimPerm.push_back(nDim - 2);
640 auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);
642 auto newIdxMaps = rewriter.getAffineMapArrayAttr(
643 {oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});
645 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
646 contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
647 adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());
655static LogicalResult isAllZeroOffsetAccess(mlir::OperandRange indices) {
656 if (!llvm::all_of(indices, [](Value val) {
658 if (!matchPattern(val, m_Constant(&attr)))
660 return attr.getInt() == 0;
669static SmallVector<Value> opFoldResultsToValues(PatternRewriter &rewriter,
671 memref::SubViewOp subViewOp) {
672 OpBuilder::InsertionGuard g(rewriter);
673 rewriter.setInsertionPoint(subViewOp);
674 SmallVector<Value> newIndices;
675 for (OpFoldResult offset : subViewOp.getMixedOffsets()) {
677 if (
auto attr = dyn_cast<Attribute>(offset)) {
678 indexVal = arith::ConstantIndexOp::create(
679 rewriter, loc, cast<IntegerAttr>(attr).getInt());
681 indexVal = cast<Value>(offset);
683 newIndices.push_back(indexVal);
// matchAndRewrite for the subview transfer_read canonicalization: a
// transfer_read at all-zero indices of a memref.subview is replaced by a
// transfer_read of the subview's source at the subview's offsets.
// NOTE(review): guard/early-return lines are missing from this extraction.
708 PatternRewriter &rewriter)
const override {
710 auto subViewOp = dyn_cast_if_present<memref::SubViewOp>(
711 readOp.getBase().getDefiningOp());
// Only all-zero-offset reads can be folded onto the subview source.
716 if (failed(isAllZeroOffsetAccess(readOp.getIndices())))
720 SmallVector<Value> newIndices =
721 opFoldResultsToValues(rewriter, readOp.getLoc(), subViewOp);
724 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
725 readOp, readOp.getType(), subViewOp.getSource(), newIndices,
726 readOp.getPadding(), readOp.getInBoundsValues());
// matchAndRewrite for the subview transfer_write canonicalization: mirror of
// the read case — a transfer_write at all-zero indices of a memref.subview
// is redirected to the subview's source at the subview's offsets.
// NOTE(review): guard/early-return lines are missing from this extraction.
752 PatternRewriter &rewriter)
const override {
754 auto subViewOp = dyn_cast_if_present<memref::SubViewOp>(
755 writeOp.getBase().getDefiningOp());
// Only all-zero-offset writes can be folded onto the subview source.
760 if (failed(isAllZeroOffsetAccess(writeOp.getIndices())))
764 SmallVector<Value> newIndices =
765 opFoldResultsToValues(rewriter, writeOp.getLoc(), subViewOp);
// A write has no results, so create the new op and erase the old one.
768 vector::TransferWriteOp::create(rewriter, writeOp.getLoc(),
769 writeOp.getVector(), subViewOp.getSource(),
770 newIndices, writeOp.getInBoundsValues());
773 rewriter.eraseOp(writeOp);
// matchAndRewrite for the leading-unit-dim vector.insert canonicalization:
// when source and destination shapes differ only by leading 1s, the insert
// is just a reshape and is replaced by a vector.shape_cast.
// NOTE(review): guard/early-return lines and the loop-increment statements
// are missing from this extraction.
788 PatternRewriter &rewriter)
const override {
789 auto insSrcTy = dyn_cast<VectorType>(insOp.getValueToStoreType());
793 auto srcShape = insSrcTy.getShape();
794 auto dstShape = insOp.getDestVectorType().getShape();
// Count leading unit dims of the destination; bail if there are none.
796 unsigned long numLeadUnitDimDst = 0;
797 while (numLeadUnitDimDst < dstShape.size() &&
798 dstShape[numLeadUnitDimDst] == 1)
801 if (!numLeadUnitDimDst)
// Count leading unit dims of the source.
804 unsigned long numLeadUnitDimSrc = 0;
805 while (numLeadUnitDimSrc < srcShape.size() &&
806 srcShape[numLeadUnitDimSrc] == 1)
// After stripping leading 1s, both shapes must match exactly for the
// insert to be a pure reshape.
809 SmallVector<int64_t> nonLeadUnitDimDstShape(
810 dstShape.begin() + numLeadUnitDimDst, dstShape.end());
811 SmallVector<int64_t> nonLeadUnitDimSrcShape(
812 srcShape.begin() + numLeadUnitDimSrc, srcShape.end());
814 if (nonLeadUnitDimSrcShape != nonLeadUnitDimDstShape)
817 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
818 insOp, insOp.getDestVectorType(), insOp.getValueToStore());
// Legalizations and patterns shared by all AIE targets: the listed dialects
// are legal wholesale, and the common canonicalization patterns are added.
// NOTE(review): the pattern list between these fragments is missing from
// this extraction.
827configureCommonAIECanonicalizeLegalizations(ConversionTarget &target,
829 target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
830 memref::MemRefDialect, vector::VectorDialect,
835populateCommonAIECanonicalizeConversionPatterns(RewritePatternSet &patterns,
838 patterns.getContext());
// AIEv1-specific legalizations: a transfer_read is legal only when it is not
// a splat (non-constant permutation map) and satisfies the (not visible
// here) alignment condition; matching conversion patterns follow.
845static void configureAIEv1CanonicalizeLegalizations(ConversionTarget &target,
847 target.addDynamicallyLegalOp<vector::TransferReadOp>(
848 [](vector::TransferReadOp op) {
849 return !op.getPermutationMap().isConstant() &&
856populateAIEv1CanonicalizeConversionPatterns(RewritePatternSet &patterns,
// AIE2-specific legalizations: transfers must be rank < 2 (multi-dim ones
// get flattened), splat reads are illegal, and B-transposed GEMM contracts
// are illegal (they get rewritten with an explicit transpose).
866static void configureAIE2CanonicalizeLegalizations(ConversionTarget &target,
868 target.addDynamicallyLegalOp<vector::TransferReadOp>(
869 [](vector::TransferReadOp op) {
870 return !op.getPermutationMap().isConstant() &&
873 op.getVector().getType().getRank() < 2;
875 target.addDynamicallyLegalOp<vector::TransferWriteOp>(
876 [](vector::TransferWriteOp op) {
877 return cast<VectorType>(op.getVector().getType()).getRank() < 2;
879 target.addDynamicallyLegalOp<vector::ContractionOp>(
880 [](vector::ContractionOp op) {
881 return !isGemmBTransposedContractionOp(op);
886populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
905static Value smartTruncF32ToBF16(PatternRewriter &rewriter, Location loc,
906 Value val, Type bf16Type) {
907 if (
auto extfOp = val.getDefiningOp<arith::ExtFOp>()) {
908 if (extfOp.getIn().getType() == bf16Type)
909 return extfOp.getIn();
911 return arith::TruncFOp::create(rewriter, loc, bf16Type, val);
// Emulate an f32 elementwise binary vector op (template parameter OpTy) in
// bf16: truncate both operands (round-trip-aware via smartTruncF32ToBF16),
// perform the op in bf16, then extend the result back to f32.
921template <
typename OpTy>
926 PatternRewriter &rewriter)
const override {
// Only f32-element vector results are emulated.
927 auto resultType = dyn_cast<VectorType>(op.getType());
928 if (!resultType || !resultType.getElementType().isF32())
931 Location loc = op.getLoc();
933 VectorType::get(resultType.getShape(), rewriter.getBF16Type());
936 smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
938 smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
// Compute in bf16 and widen back so users still see f32.
941 OpTy::create(rewriter, loc, bf16VecType, lhsBF16, rhsBF16);
942 auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
943 rewriter.replaceOp(op, extOp);
// Emulate an f32 vector arith.cmpf in bf16: truncate both sides and compare
// in bf16 with the same predicate (result is i1, so no widening is needed).
951 using OpRewritePattern::OpRewritePattern;
954 PatternRewriter &rewriter)
const override {
// Only f32-element vector comparisons are emulated.
955 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
956 if (!lhsType || !lhsType.getElementType().isF32())
959 Location loc = op.getLoc();
961 VectorType::get(lhsType.getShape(), rewriter.getBF16Type());
964 smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
966 smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
968 rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, op.getPredicate(), lhsBF16,
// Emulate an f32 vector arith.select in bf16: truncate both branch values,
// select in bf16 with the original condition, then extend back to f32.
978 using OpRewritePattern::OpRewritePattern;
981 PatternRewriter &rewriter)
const override {
// Only f32-element vector selects are emulated.
982 auto resultType = dyn_cast<VectorType>(op.getType());
983 if (!resultType || !resultType.getElementType().isF32())
986 Location loc = op.getLoc();
988 VectorType::get(resultType.getShape(), rewriter.getBF16Type());
991 smartTruncF32ToBF16(rewriter, loc, op.getTrueValue(), bf16VecType);
993 smartTruncF32ToBF16(rewriter, loc, op.getFalseValue(), bf16VecType);
// The condition is untouched; only the selected values change precision.
995 Value newResult = arith::SelectOp::create(rewriter, loc, op.getCondition(),
996 trueValBF16, falseValBF16);
997 auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
998 rewriter.replaceOp(op, extOp);
// Emulate an f32 vector.fma in bf16: truncate lhs, rhs, and accumulator,
// perform the fused multiply-add in bf16, then extend the result to f32.
1006 using OpRewritePattern::OpRewritePattern;
1009 PatternRewriter &rewriter)
const override {
// Only f32-element vector FMAs are emulated.
1010 auto resultType = dyn_cast<VectorType>(op.getType());
1011 if (!resultType || !resultType.getElementType().isF32())
1014 Location loc = op.getLoc();
1016 VectorType::get(resultType.getShape(), rewriter.getBF16Type());
1019 smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
1021 smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
1023 smartTruncF32ToBF16(rewriter, loc, op.getAcc(), bf16VecType);
1026 vector::FMAOp::create(rewriter, loc, lhsBF16, rhsBF16, accBF16);
1027 auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
1028 rewriter.replaceOp(op, extOp);
// Emulate an f32 elementwise unary vector op (template parameter OpTy) in
// bf16: truncate the single operand, apply the op in bf16, extend to f32.
1034template <
typename OpTy>
1039 PatternRewriter &rewriter)
const override {
// Only f32-element vector results are emulated.
1040 auto resultType = dyn_cast<VectorType>(op.getType());
1041 if (!resultType || !resultType.getElementType().isF32())
1044 Location loc = op.getLoc();
1046 VectorType::get(resultType.getShape(), rewriter.getBF16Type());
1049 smartTruncF32ToBF16(rewriter, loc, op->getOperand(0), bf16VecType);
1051 Value newResult = OpTy::create(rewriter, loc, bf16VecType, inputBF16);
1052 auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
1053 rewriter.replaceOp(op, extOp);
1059 :
public PassWrapper<BF16EmulationPass, OperationPass<>> {
1062 auto *op = getOperation();
1063 MLIRContext *context = &getContext();
1064 RewritePatternSet patterns(context);
1086 (void)applyPatternsGreedily(op, std::move(patterns));
1091 return std::make_unique<BF16EmulationPass>();
1095 :
public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {
1098 auto *op = getOperation();
1099 MLIRContext *context = &getContext();
1100 RewritePatternSet patterns(context);
1101 populateVectorBroadcastLoweringPatterns(patterns);
1103 patterns.getContext());
1105 (void)applyPatternsGreedily(op, std::move(patterns));
1109static std::unique_ptr<::mlir::Pass> createVectorBroadcastLoweringPass() {
1110 return std::make_unique<VectorBroadcastLoweringPass>();
1121 :
public PassWrapper<CanonicalizeVectorForAIEVecPass, OperationPass<>> {
1138 return "test-canonicalize-vector-for-aievec";
1142 return "Canonicalize vector operations for AIEVec conversion";
1147 .insert<arith::ArithDialect, memref::MemRefDialect,
1148 vector::VectorDialect, affine::AffineDialect, ub::UBDialect>();
1152 *
this,
"aie-target",
1154 "Select AIE version: \"aie\", \"aie2\", or \"aie2p\". This will "
1155 "determine the vector size and available operations."),
1156 llvm::cl::init(
"aie")};
1159 *
this,
"target-backend",
1160 llvm::cl::desc(
"Select translation backend: \"cpp\" or \"llvmir\". This "
1161 "will determine the aievec operations used to convert "
1162 "from vector dialect."),
1163 llvm::cl::init(
"cpp")};
1166 auto *op = getOperation();
1167 MLIRContext *context = &getContext();
1168 RewritePatternSet patterns(context);
1169 ConversionTarget target(*context);
1172 if (aieVersion == AIEArch::UNKNOWN) {
1173 op->emitError() <<
"unknown AIE target '" <<
aieTarget <<
"'";
1174 signalPassFailure();
1179 if (backend == TargetBackend::UNKNOWN) {
1180 op->emitError() <<
"unknown target backend '" <<
targetBackend <<
"'";
1181 signalPassFailure();
1184 if (backend == TargetBackend::LLVMIR && aieVersion == AIEArch::AIE) {
1185 op->emitError() <<
"targetting LLVM IR is not supported for AIEv1";
1186 signalPassFailure();
1190 populateCommonAIECanonicalizeConversionPatterns(patterns, backend);
1191 configureCommonAIECanonicalizeLegalizations(target, backend);
1192 if (aieVersion == AIEArch::AIE) {
1193 populateAIEv1CanonicalizeConversionPatterns(patterns, backend);
1194 configureAIEv1CanonicalizeLegalizations(target, backend);
1196 populateAIE2CanonicalizeConversionPatterns(patterns, backend);
1197 configureAIE2CanonicalizeLegalizations(target, backend);
1201 RewritePatternSet patterns(context);
1204 (void)applyPatternsGreedily(op, std::move(patterns));
1207 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
1208 signalPassFailure();
1213static std::unique_ptr<::mlir::Pass> createCanonicalizeVectorForAIEVecPass(
1215 return std::make_unique<CanonicalizeVectorForAIEVecPass>(options);
1219 :
public PassWrapper<HoistCastOpToDataSourcePass, OperationPass<>> {
1222 auto *op = getOperation();
1223 MLIRContext *context = &getContext();
1224 RewritePatternSet patterns(context);
1228 (void)applyPatternsGreedily(op, std::move(patterns));
1232static std::unique_ptr<::mlir::Pass> createHoistCastOpToDataSourcePass() {
1233 return std::make_unique<HoistCastOpToDataSourcePass>();
1237 :
public PassWrapper<ReorderOperationsPass, OperationPass<>> {
1240 auto *op = getOperation();
1241 MLIRContext *context = &getContext();
1242 RewritePatternSet patterns(context);
1245 patterns.getContext(),
1246 [](arith::ExtSIOp extOp, vector::BroadcastOp bcastOp) -> Type {
1247 Type extInElemTy = extOp.getIn().getType();
1248 auto extInVecTy = dyn_cast<VectorType>(extInElemTy);
1250 extInElemTy = extInVecTy.getElementType();
1251 return VectorType::get(bcastOp.getResultVectorType().getShape(),
1255 (void)applyPatternsGreedily(op, std::move(patterns));
1259static std::unique_ptr<::mlir::Pass> createReorderOperationsPass() {
1260 return std::make_unique<ReorderOperationsPass>();
// Pipeline assembly: reorder ops only for the LLVM IR backend, then lower
// broadcasts, canonicalize for AIEVec, and (C++ backend only) hoist casts.
1277 if (decodeTargetBackend(options.
targetBackend) == TargetBackend::LLVMIR)
1278 pm.addPass(createReorderOperationsPass());
1280 pm.addPass(createVectorBroadcastLoweringPass());
1281 pm.addPass(createCanonicalizeVectorForAIEVecPass(options));
// Cast hoisting is only beneficial for the C++ emission path.
1282 if (decodeTargetBackend(options.
targetBackend) == TargetBackend::CPP)
1283 pm.addPass(createHoistCastOpToDataSourcePass());
std::unique_ptr<::mlir::Pass > createCopyRemovalPass()
Create a pass that removes unnecessary Copy operations.
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
int32_t getElementSizeInBits(mlir::VectorType type)
std::unique_ptr<::mlir::Pass > createBF16EmulationPass()
Create a pass that emulates f32 vector arithmetic using bf16 operations.
void buildCanonicalizeVectorForAIEVec(mlir::OpPassManager &pm, const CanonicalizeVectorForAIEVecOptions &options)
void runOnOperation() override
Pattern to canonicalize trivial vector.transfer_read operations on subviews.
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override
Pattern to canonicalize trivial vector.transfer_write operations on subviews.
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override
StringRef getArgument() const final
void getDependentDialects(DialectRegistry ®istry) const override
StringRef getDescription() const final
Option< std::string > aieTarget
CanonicalizeVectorForAIEVecPass(const CanonicalizeVectorForAIEVecOptions &options)
Option< std::string > targetBackend
void runOnOperation() override
LogicalResult matchAndRewrite(vector::InsertOp insOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertSplatTransferReadToBroadcastPattern(MLIRContext *context)
Pattern to emulate f32 binary vector arithmetic ops in bf16.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
Pattern to emulate f32 comparison ops in bf16.
LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override
Pattern to emulate f32 vector.fma in bf16.
LogicalResult matchAndRewrite(vector::FMAOp op, PatternRewriter &rewriter) const override
Pattern to emulate f32 select ops in bf16.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Pattern to emulate f32 unary vector ops in bf16.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void runOnOperation() override
HoistCastOpToDataSourcePattern(MLIRContext *context)
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
void runOnOperation() override
SplitUnalignedTransferReadPattern(MLIRContext *context, int64_t maxVectorSize, int64_t alignment)
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
InferTypeB2AFnTy inferTypeB2A
LogicalResult matchAndRewrite(UnaryOpB bOp, PatternRewriter &rewriter) const override
std::function< Type(UnaryOpA aOp, UnaryOpB bOp)> InferTypeB2AFnTy
SwapUnaryOpsPattern(MLIRContext *context, InferTypeB2AFnTy inferType)
void runOnOperation() override
Options for the "canonicalize-vector-for-aievec" pipeline.
PassOptions::Option< bool > enableBF16Emulation
PassOptions::Option< std::string > targetBackend
PassOptions::Option< std::string > aieTarget