21#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
22#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
23#include "mlir/Dialect/Affine/IR/AffineOps.h"
24#include "mlir/Dialect/Func/IR/FuncOps.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
28#include "mlir/IR/TypeUtilities.h"
29#include "mlir/Pass/PassManager.h"
30#include "mlir/Transforms/Passes.h"
32#include "llvm/ADT/SmallSet.h"
// Pull in the TableGen-generated base class for the AIEVectorize pass,
// plus file-local namespace shorthand and the LLVM_DEBUG tag for this file.
35#define GEN_PASS_DEF_AIEVECTORIZE
36#include "aie/Dialect/AIEVec/Transforms/Passes.h.inc"
42using namespace vector;
46#define DEBUG_TYPE "aie-vect"
// Command-line flag: when true (the default), the pass enforces its
// unaligned-loads legality check during vectorization.
48static llvm::cl::opt<bool>
49    unalignedLoadsCheck(
        "unaligned-loads-check",
50                        llvm::cl::desc(
        "Enable the unaligned loads check"),
51                        llvm::cl::init(
        true));
// Command-line flag selecting AI Engine-ML (AIE-ML) code generation.
// Off by default; several width/scheme decisions below branch on it.
53static llvm::cl::opt<bool> AIEML(
    "aieml", llvm::cl::desc(
        "AI Engine-ML"),
54                            llvm::cl::init(
    false));
// Data members of the pass-wide analysis state (the enclosing VectState
// struct's opening lines are not visible in this excerpt).
// All vector-load reuse intervals discovered in the function.
65  SmallVector<IntervalReuse *, 16> reuseIntervals;
// Maps each transfer_read-like op to the IntervalReuse object it belongs to.
68  mlir::DenseMap<Operation *, IntervalReuse *> opToIntervalMap;
// Cache of the linearized (flattened) affine access expression per read op.
73  mlir::DenseMap<Operation *, AffineExpr> linearizedAccess;
// Maps a loop-index SSA value to the AffineDimExpr used to represent it
// when linearizing accesses.
77  mlir::DenseMap<Value, AffineExpr> indexToExprDimMap;
// For each block, the stack of loops enclosing it (used to find the
// vectorized induction variable of a read access).
81  mlir::DenseMap<Block *, SmallVector<Operation *, 8>> blockToEnclosingLoops;
// For the i8xi8 scheme: links a mul/fma op to its generated twin op.
84  mlir::DenseMap<Operation *, Operation *> pairedOp;
// Per-op (left, right) column offsets — presumably for the mac scheme
// computation; the producing code is not in this excerpt (TODO confirm).
92  mlir::DenseMap<Operation *, std::pair<int32_t, int32_t>> opToColOffsets;
// Maps a sext/trunc op to the op it wraps, so lookups can see through it.
94  mlir::DenseMap<Operation *, Operation *> sextTruncDefMap;
// fma ops that actually represent multiply-subtract (msc) semantics.
98  llvm::SmallSet<Operation *, 8> mscOps;
// Copies of the command-line options, captured at pass construction.
114  bool unalignedLoadsCheck, aieml;
// Constructor: captures the SRS shift amount (s), zero offset (z),
// duplication factor (d), and the two feature flags. The builder/shift/
// zeroOffset/dupFactor members it initializes are declared on lines not
// visible in this excerpt.
117  VectState(MLIRContext *context, int8_t s, int32_t z, int32_t d,
118            bool unalignedLoadsCheck,
            bool aieml)
119      : builder(context), shift(s), zeroOffset(z), dupFactor(d),
120        unalignedLoadsCheck(unalignedLoadsCheck), aieml(aieml) {}
// Return the IntervalReuse object that owns `op`'s access extent.
// Precondition: the op was previously registered in opToIntervalMap.
126IntervalReuse *VectState::getIntervalForOperation(Operation *op) {
127  assert(opToIntervalMap.count(op) &&
128         "could not find the IntervalReuse object for op");
129  return opToIntervalMap[op];
// String-encoded scheme attributes (one entry per operand, hence size 2)
// in the exact textual form the aievec::aie1 ops expect: start index,
// low/high offsets, step, and square (shuffle) descriptors.
// NOTE(review): a `select` member is referenced by generateSelectOp and
// concatAndInterleave_i8xi8 below but its declaration falls on lines not
// visible in this excerpt.
134struct AIEOpAttributes {
136  SmallVector<std::string, 2> start;
137  SmallVector<std::string, 2> offset, offset_hi;
138  SmallVector<std::string, 2> step;
139  SmallVector<std::string, 2> square;
// Summary of a vector value: total size and element size in bits, plus
// lanes/elementType/loadFromMemory/isSplat members that are initialized by
// the constructor but whose declarations fall on lines not shown here.
143struct AIEVecAttributes {
149  int32_t vecSizeInBits;
153  int32_t elementSizeInBits;
159  AIEVecAttributes(
      unsigned l,
                   unsigned vs, Type et, int32_t es)
160      : lanes(l), vecSizeInBits(vs), elementType(et), elementSizeInBits(es),
161        loadFromMemory(false), isSplat(false) {}
// Interior of the Scheme struct (header not visible): bit widths of the x
// and z operand buffers, with lanes/cols initialized alongside them.
171  int32_t xbits, zbits;
173  Scheme(int32_t l, int32_t c, int32_t x, int32_t z)
174      : lanes(l), cols(c), xbits(x), zbits(z) {}
// Build an AIEVecAttributes summary from a VectorType (body on lines not
// visible in this excerpt).
183static AIEVecAttributes getVectorStats(VectorType type) {
// Vector statistics for op's idx'th result (defaults to result 0).
// The result must be of VectorType.
189static AIEVecAttributes getResultVecStats(Operation *op,
                                          unsigned idx = 0) {
190  auto vtype = cast<VectorType>(op->getResult(idx).getType());
191  return getVectorStats(vtype);
// Defining op of op's idx'th operand, looking through a sext/trunc wrapper
// if one was recorded in state->sextTruncDefMap.
194static Operation *getOperandDefOp(VectState *state, Operation *op,
196  return state->sextTruncDefMap.count(op->getOperand(idx).getDefiningOp())
197             ? state->sextTruncDefMap[op->getOperand(idx).getDefiningOp()]
198             : op->getOperand(idx).getDefiningOp();
// Vector statistics for op's idx'th operand (through sext/trunc). When the
// operand is produced by a transfer_read, additionally mark that it comes
// from memory and whether it is a splat (constant permutation map).
202static AIEVecAttributes getOperandVecStats(Operation *op, VectState *state,
204  assert(op->getNumOperands() > idx);
205  Operation *defOp = getOperandDefOp(state, op, idx);
206  auto vtype = cast<VectorType>(defOp->getResult(0).getType());
207  auto ret = getVectorStats(vtype);
209  if (
auto readOp = dyn_cast<TransferReadOp>(defOp)) {
213    ret.loadFromMemory =
true;
215    ret.isSplat = readOp.getPermutationMap().isConstant();
// Compute the (lanes, cols) pair of the mul/mac scheme for `op` from the
// element widths of its two operands: 8x8 uses width 128 (256 on AIE-ML),
// 16x8 uses 64, with further cases on lines not shown. cols is then the
// scheme width divided by (m * lanes); m's computation is also not visible
// here — TODO confirm against the full source.
221static std::pair<int32_t, int32_t> getNumRowsAndCols(Operation *op,
223  assert(op->getNumOperands() >= 2 && op->getNumResults() == 1);
225  Operation *left = getOperandDefOp(state, op, 0);
226  Operation *right = getOperandDefOp(state, op, 1);
229  auto vtype = cast<VectorType>(op->getResult(0).getType());
233  auto ltype = cast<VectorType>(left->getResult(0).getType());
234  auto rtype = cast<VectorType>(right->getResult(0).getType());
238  int32_t width = (lsize == 8 && rsize == 8)    ? (state->aieml ? 256 : 128)
239                  : (lsize == 16 && rsize == 8) ? 64
252  int32_t cols = width / (m * lanes);
253  return std::make_pair(lanes, cols);
// Fuse the memory-access extents of the corresponding transfer_read
// operands of Op1 and Op2 (an fma and the mul/fma it is being combined
// with), so both reads can later be served by the same UPD load. The new
// extent is the union [min(first), max(second)] of the two extents.
263static void fuseAccessExtent(Operation *Op1, Operation *Op2, VectState *state) {
267      (isa<vector::FMAOp>(Op2) && isa<MulIOp, MulFOp, vector::FMAOp>(Op1));
268  if (!expectedTypes) {
269    printf(
        "incorrect operation types\n");
276  for (
int idx = 0; idx < 2; ++idx) {
277    Operation *op1 = getOperandDefOp(state, Op1, idx);
278    Operation *op2 = getOperandDefOp(state, Op2, idx);
282    if (isa<TransferReadOp>(op1) && isa<TransferReadOp>(op2)) {
292          std::make_pair(std::min(op1Extent.first, op2Extent.first),
293                         std::max(op1Extent.second, op2Extent.second));
// True when `Op` (a mul/fma or add/sub) can be lowered to the plain vector
// intrinsic without an advanced scheme: operand/result vector sizes and
// element types all match, neither operand is a splat, and for mul/fma the
// element type must additionally be non-float.
305static bool isSimpleVectIntrinsic(Operation *Op, VectState *state) {
307  bool isMulOrFMAOp = isa<MulIOp, MulFOp, vector::FMAOp>(Op);
308  bool isSubOrAddOp = isa<SubIOp, SubFOp, AddIOp, AddFOp>(Op);
309  if (!isMulOrFMAOp && !isSubOrAddOp)
313  AIEVecAttributes vstat = getResultVecStats(Op);
314  AIEVecAttributes lstat = getOperandVecStats(Op, state, 0);
315  AIEVecAttributes rstat = getOperandVecStats(Op, state, 1);
317  bool sizeMatches = lstat.vecSizeInBits == rstat.vecSizeInBits &&
318                     vstat.vecSizeInBits == rstat.vecSizeInBits &&
319                     lstat.elementType == rstat.elementType &&
320                     vstat.elementType == rstat.elementType;
321  bool noSplat = !lstat.isSplat && !rstat.isSplat;
322  bool noFloat = !isa<FloatType>(vstat.elementType) &&
323                 !isa<FloatType>(lstat.elementType) &&
324                 !isa<FloatType>(rstat.elementType);
326  return sizeMatches && noSplat && (isSubOrAddOp || noFloat);
// Check that `Op` is a well-formed candidate: every operand and result is
// of VectorType and they all share the same element type (the last
// operand/result is used as the reference). Ops with no operands and no
// results are handled by the early exit.
333static bool isWellFormedVectorOp(Operation *Op) {
335  if (Op->getNumOperands() == 0 && Op->getNumResults() == 0)
338  SmallVector<Value, 8> operandsAndResults;
339  operandsAndResults.append(Op->operand_begin(), Op->operand_end());
340  operandsAndResults.append(Op->result_begin(), Op->result_end());
343  for (
auto val : operandsAndResults) {
344    if (!isa<VectorType>(val.getType()))
348  auto refType = cast<VectorType>(operandsAndResults.back().getType());
349  Type scalarType = refType.getElementType();
351  for (
auto val : operandsAndResults) {
352    auto vtype = cast<VectorType>(val.getType());
357    if (scalarType != vtype.getElementType())
// Whether `op` produces its result in an AIE accumulator register:
// integer aie1 mul/fma ops do; FMAElem/MulElem/FMAConv/MulConv/UPS always do.
366static bool writesToAccumulator(Operation *op) {
370  if (
auto mulOp = dyn_cast<aievec::aie1::MulOp>(op))
371    return isa<IntegerType>(
372        cast<VectorType>(mulOp.getResult().getType()).getElementType());
373  if (
auto fmaOp = dyn_cast<aievec::aie1::FMAOp>(op))
374    return isa<IntegerType>(
375        cast<VectorType>(fmaOp.getResult().getType()).getElementType());
377  return isa<aievec::FMAElemOp, aievec::MulElemOp, aievec::FMAConvOp,
378             aievec::MulConvOp, aievec::UPSOp>(op);
// Linearize a multi-dimensional access into a single strided AffineExpr:
// expr = sum(dim_i * stride_i), iterating from the innermost (last)
// dimension outward with a running row-major size. Once a dynamic size is
// seen (or the running size overflows), strides become fresh symbols
// (dynamicPoisonBit). Any zero size makes the whole access constant 0.
388static AffineExpr makeFlattenedStridedExpr(ArrayRef<int64_t> sizes,
389                                           ArrayRef<AffineExpr> exprs,
390                                           MLIRContext *context) {
391  assert(!sizes.empty() && !exprs.empty() &&
392         "expected non-empty sizes and exprs");
395  if (llvm::is_contained(sizes, 0))
396    return getAffineConstantExpr(0, context);
398  auto maps = AffineMap::inferFromExprList(exprs, context);
399  assert(!maps.empty() &&
         "Expected one non-empty map");
400  unsigned nSymbols = maps[0].getNumSymbols();
403  bool dynamicPoisonBit =
false;
404  int64_t runningSize = 1;
405  for (
auto en :
       llvm::zip(
           llvm::reverse(exprs),
           llvm::reverse(sizes))) {
406    int64_t size = std::get<1>(en);
410    AffineExpr dimExpr = std::get<0>(en);
411    AffineExpr stride = dynamicPoisonBit
412                            ? getAffineSymbolExpr(nSymbols++, context)
413                            : getAffineConstantExpr(runningSize, context);
414    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
417      assert(runningSize > 0 &&
             "integer overflow in size computation");
419      dynamicPoisonBit =
true;
// Construct (and memoize in state->linearizedAccess) the linearized affine
// access expression of a transfer_read. Each index is translated to an
// AffineExpr: an affine.apply is inlined by substituting its operands
// (constants become constant exprs, loop indices get a stable AffineDimExpr
// from indexToExprDimMap); an arith.constant becomes a constant expr; any
// other SSA index gets its own dim expr. The per-dimension exprs are then
// flattened against the memref shape via makeFlattenedStridedExpr.
426static AffineExpr constructLinearizedAffineExpr(TransferReadOp readOp,
430  if (state->linearizedAccess.count(readOp))
431    return state->linearizedAccess[readOp];
433  SmallVector<Value, 4> indices(readOp.getIndices().begin(),
434                                readOp.getIndices().end());
435  auto memRefType = cast<MemRefType>(readOp.getBase().getType());
436  MLIRContext *context = memRefType.getContext();
438  SmallVector<AffineExpr, 8> exprVec;
442  for (
auto idxAndValue :
       llvm::enumerate(indices)) {
443    auto value = idxAndValue.value();
447    if (
auto apOf =
            value.getDefiningOp<affine::AffineApplyOp>()) {
448      AffineMap map = apOf.getAffineMap();
449      assert(map.getNumResults() == 1 &&
450             "Failed to create linearized affineExpr for complicated index");
451      SmallVector<AffineExpr, 4> indexExprs;
454      for (
auto index : apOf.getMapOperands()) {
455        if (
auto cIdx = index.getDefiningOp<arith::ConstantOp>()) {
456          auto idxVal = cast<IntegerAttr>(cIdx.getValue()).getValue();
457          unsigned idx = idxVal.getSExtValue();
458          indexExprs.push_back(getAffineConstantExpr(idx, context));
460          if (!state->indexToExprDimMap.count(index))
461            state->indexToExprDimMap[index] =
462                getAffineDimExpr(state->indexToExprDimMap.size(), context);
463          indexExprs.push_back(state->indexToExprDimMap[index]);
467      exprVec.push_back(map.getResult(0).replaceDims(indexExprs));
471    else if (
auto cOp =
                 value.getDefiningOp<arith::ConstantOp>()) {
472      auto idxVal = cast<IntegerAttr>(cOp.getValue()).getValue();
473      unsigned idx = idxVal.getSExtValue();
474      exprVec.push_back(getAffineConstantExpr(idx, context));
478      if (!state->indexToExprDimMap.count(value))
479        state->indexToExprDimMap[
value] =
480            getAffineDimExpr(state->indexToExprDimMap.size(), context);
481      exprVec.push_back(state->indexToExprDimMap[value]);
485  assert(!exprVec.empty() &&
         "Could not construct linearized affineExpr");
488  auto ret = makeFlattenedStridedExpr(memRefType.getShape(), exprVec,
489                                      memRefType.getContext());
491  state->linearizedAccess[readOp] = ret;
// Split `expr` into (base, constant offset). A pure constant yields a
// base adjusted accordingly; for an Add binop with a constant on either
// side, the constant is folded into the offset and the other side becomes
// the base (nullptr if nothing remains). Non-add expressions pass through
// unchanged with offset contribution 0.
499static std::pair<AffineExpr, int32_t> getBaseAndOffset(AffineExpr expr) {
500  AffineExpr base = expr;
503  if (
auto constExpr = llvm::dyn_cast<AffineConstantExpr>(expr)) {
505    offset += constExpr.getValue();
510  else if (
auto binopExpr = llvm::dyn_cast<AffineBinaryOpExpr>(expr)) {
511    if (binopExpr.getKind() == AffineExprKind::Add) {
512      AffineExpr lhs = binopExpr.getLHS(), rhs = binopExpr.getRHS();
513      if (
auto constExpr = llvm::dyn_cast<AffineConstantExpr>(lhs)) {
515        offset += constExpr.getValue();
517      if (
auto constExpr = llvm::dyn_cast<AffineConstantExpr>(rhs)) {
518        base = base == rhs ? nullptr : lhs;
519        offset += constExpr.getValue();
523  return std::make_pair(base, offset);
// Emit an aievec.cast of `source` to `resType`, optionally marking the
// result as an accumulator.
// NOTE(review): the assert message reads "could not create srs op" — this
// looks like a copy-paste from generateSRSOp and should say "cast op";
// confirm against the full source before changing the string.
530static aievec::CastOp generateCastOp(Value source, VectorType resType,
531                                     bool isResAcc, VectState *state,
535      aievec::CastOp::create(state->builder, loc, resType, source, isResAcc);
537  assert(castOp &&
         "could not create srs op");
// Emit an aievec.srs (shift-round-saturate) moving `source` out of an
// accumulator into a vector of `scalarType` elements, using the pass-wide
// shift amount (materialized as an i32 arith.constant).
543static aievec::SRSOp generateSRSOp(Value source, Type scalarType,
544                                   VectState *state, Location loc) {
546  Type accType = source.getType();
547  assert(writesToAccumulator(source.getDefiningOp()) &&
548         "srs source should write to accumulator");
555  auto shiftParamOp = arith::ConstantOp::create(
556      state->builder, loc, state->builder.getI32IntegerAttr(state->shift));
558  auto srsOp = aievec::SRSOp::create(state->builder, loc, srsType, source,
559                                     shiftParamOp.getResult());
561  assert(srsOp &&
         "could not create srs op");
// Emit an aievec.ups op moving `source` (which must NOT already live in an
// accumulator) into an accumulator, applying the pass-wide shift.
567static aievec::UPSOp generateUPSOp(Value source, VectState *state,
569  Type sourceType = source.getType();
572  assert(!writesToAccumulator(source.getDefiningOp()) &&
573         "ups source should not be accumulator");
577      aievec::UPSOp::create(state->builder, loc, accType, source, state->shift);
579  assert(upsOp &&
         "could not create ups op");
// Emit an aievec.broadcast replicating lane `idx` of `source` across all
// lanes; result type equals the source vector type.
584static aievec::BroadcastOp generateBroadcastOp(Value source, int8_t idx,
585                                               VectState *state, Location loc) {
586  auto type = cast<VectorType>(source.getType());
589      aievec::BroadcastOp::create(state->builder, loc, type, source, idx);
591  assert(broadcastOp &&
         "could not create broadcast op");
// Emit an aievec.concat of two or more source vectors, which must all have
// the same type. If `concatType` is not supplied it is derived from the
// sources (derivation lines not visible here).
596static aievec::ConcatOp generateConcatOp(SmallVector<Value> &sources,
597                                         VectState *state, Location loc,
598                                         VectorType concatType =
                                             nullptr) {
599  assert(sources.size() > 1 &&
         "must concat at least two vectors");
601  auto vecType = cast<VectorType>(sources.back().getType());
604  for (
auto source : sources) {
605    auto type = cast<VectorType>(source.getType());
606    if (type != vecType) {
607      printf(
          "sources of concat op not of same type\n");
617  Type scalarType = vecType.getElementType();
623      aievec::ConcatOp::create(state->builder, loc, concatType, sources);
625  assert(concatOp &&
         "could not create concat op");
// Emit an aievec.aie1.select over xbuff (and optional ybuff) using the
// string-encoded select mask plus per-lane start/offset/offset_hi/square
// attributes for both halves. Requires exactly two entries per attribute.
631static aievec::aie1::SelectOp
632generateSelectOp(Value xbuff, AIEOpAttributes &opAttr,
                 unsigned lanes,
633                 VectState *state, Location loc, Value ybuff =
                     nullptr) {
636  assert(!opAttr.select.empty());
637  assert(opAttr.start.size() == opAttr.offset.size() &&
638         opAttr.start.size() == 2);
640  auto xtype = cast<VectorType>(xbuff.getType());
647  auto selectOp = aievec::aie1::SelectOp::create(
648      state->builder, loc, resultType, xbuff, opAttr.select, opAttr.start[0],
649      opAttr.offset[0], opAttr.offset_hi[0], opAttr.square[0], opAttr.start[1],
650      opAttr.offset[1], opAttr.offset_hi[1], opAttr.square[1], ybuff);
652  assert(selectOp &&
         "could not create select op");
// Emit an aievec.aie1.ext extracting the `idx`'th `lanes`-wide slice from
// `source`; the result element type matches the source's.
658static aievec::aie1::ExtOp generateExtOp(Value source,
                                         unsigned lanes,
659                                         int8_t idx, VectState *state,
661  auto stype = cast<VectorType>(source.getType());
669      aievec::aie1::ExtOp::create(state->builder, loc, resultType, source, idx);
671  assert(extOp &&
         "could not create ext op");
// Emit an aievec.pack narrowing `source` to an i8-element vector of the
// same shape.
676static aievec::PackOp generatePackOp(Value source, VectState *state,
679  auto stype = cast<VectorType>(source.getType());
681  Type i8Type = IntegerType::get(source.getContext(), 8);
685  auto packOp = aievec::PackOp::create(state->builder, loc, resultType, source);
687  assert(packOp &&
         "could not create pack op");
// Emit an aievec.aie1.add replacing arith add `Op`, threading through the
// two operands and the per-operand scheme attribute strings.
692static aievec::aie1::AddOp generateAddOp(Operation *Op, AIEOpAttributes &opAttr,
695  assert(opAttr.start.size() == opAttr.offset.size() &&
696         opAttr.start.size() == 2);
698  auto addOp = aievec::aie1::AddOp::create(
699      state->builder, Op->getLoc(), Op->getResult(0).getType(),
700      Op->getOperand(0), Op->getOperand(1), opAttr.start[0], opAttr.offset[0],
701      opAttr.offset_hi[0], opAttr.square[0], opAttr.start[1], opAttr.offset[1],
702      opAttr.offset_hi[1], opAttr.square[1]);
// Emit an aievec.aie1.sub replacing arith sub `Op`; mirrors generateAddOp.
707static aievec::aie1::SubOp generateSubOp(Operation *Op, AIEOpAttributes &opAttr,
710  assert(opAttr.start.size() == opAttr.offset.size() &&
711         opAttr.start.size() == 2);
713  auto subOp = aievec::aie1::SubOp::create(
714      state->builder, Op->getLoc(), Op->getResult(0).getType(),
715      Op->getOperand(0), Op->getOperand(1), opAttr.start[0], opAttr.offset[0],
716      opAttr.offset_hi[0], opAttr.square[0], opAttr.start[1], opAttr.offset[1],
717      opAttr.offset_hi[1], opAttr.square[1]);
// Emit an aievec.shift combining lhs/rhs with a byte shift amount
// (materialized as an i32 constant). lhs and rhs must share a vector type;
// `resType` defaults from the operand type when not supplied.
721static aievec::ShiftOp generateShiftOp(Value lhs, Value rhs, int32_t shiftBytes,
722                                       VectState *state, Location loc,
723                                       VectorType resType =
                                           nullptr) {
724  auto vecType = cast<VectorType>(rhs.getType());
727  auto type = cast<VectorType>(lhs.getType());
728  if (type != vecType) {
729    printf(
        "lhs and rhs do not have same type\n");
737  Type scalarType = vecType.getElementType();
741  auto constOp = arith::ConstantOp::create(
742      state->builder, loc, state->builder.getI32IntegerAttr(shiftBytes));
743  auto shiftOp = aievec::ShiftOp::create(state->builder, loc, resType, lhs, rhs,
744                                         constOp.getResult());
// Emit an aievec.legacy_shuffle of `source` with the given `mode`; result
// type defaults from the source when not supplied. (The create call's
// trailing arguments fall on lines not visible in this excerpt.)
749static aievec::LegacyShuffleOp generateShuffleOp(Value source, VectState *state,
750                                                 Location loc,
                                                 unsigned mode,
751                                                 VectorType resType =
                                                     nullptr) {
752  auto vecType = cast<VectorType>(source.getType());
756  Type scalarType = vecType.getElementType();
760  auto shuffleOp = aievec::LegacyShuffleOp::create(state->builder, loc, resType,
// AIE-ML i8xi8 path: lower a MulI or vector.fma to aievec.mul_conv /
// aievec.fma_conv. The rhs is rearranged first (legacy shuffle mode 0,
// then a byte shift when needed); operands are read through recorded
// sext/trunc wrappers. Note operand 1 feeds lhs and operand 0 feeds rhs.
// Requires dupFactor == 2. M and N (conv sizes) are computed on lines not
// shown here — TODO confirm.
769static Operation *generateMulOrFMAConvOpForInt8(Operation *Op,
770                                                AIEOpAttributes &opAttr,
774  assert(opAttr.start.size() == opAttr.offset.size() &&
775         opAttr.start.size() == 2 && state->dupFactor == 2);
777  Value lhs = state->sextTruncDefMap.count(Op->getOperand(1).getDefiningOp())
778                  ? Op->getOperand(1).getDefiningOp()->getOperand(0)
780  Value rhs = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
781                  ? Op->getOperand(0).getDefiningOp()->getOperand(0)
783  auto vType = cast<VectorType>(lhs.getType());
784  Type stype = vType.getElementType();
785  auto itype = cast<IntegerType>(stype);
786  unsigned width = itype.getWidth() <= 8 ? 32 : 64;
790  Type ctype = IntegerType::get(itype.getContext(), width);
791  Type opType = VectorType::get(vType.getShape(), ctype);
792  auto defOp = rhs.getDefiningOp();
793  state->builder.setInsertionPointAfter(defOp);
794  Location loc = defOp->getLoc();
799  Operation *shuffleOp = generateShuffleOp(defOp->getResult(0), state, loc, 0);
806    state->builder.setInsertionPointAfter(shuffleOp);
807    loc = shuffleOp->getLoc();
808    rhs = generateShiftOp(shuffleOp->getResult(0), shuffleOp->getResult(0),
809                          shiftBytes, state, loc);
811    rhs = shuffleOp->getResult(0);
814  state->builder.setInsertionPoint(Op);
817  Operation *convOp =
nullptr;
819  if (isa<MulIOp>(Op)) {
821        aievec::MulConvOp::create(state->builder, loc, opType, lhs, rhs, M, N);
824  if (isa<vector::FMAOp>(Op)) {
825    Value acc = Op->getOperand(2);
826    bool isSub = state->mscOps.count(Op);
827    convOp = aievec::FMAConvOp::create(state->builder, loc, opType, lhs, rhs,
// Lower a vector.fma to an aievec FMA. Operands are read through recorded
// sext/trunc wrappers; msc membership (state->mscOps) selects
// multiply-subtract. On the AIE-ML path (first branch) an FMAElemOp is
// built, with the rhs broadcast from lane start[1] when the op is not a
// simple intrinsic; on the aie1 path an aie1::FMAOp is built with the full
// start/offset/step/square attribute strings, self-concatenating a 256-bit
// lhs up to the scheme width. In both paths the accumulator is moved into
// an accumulator register via UPS when needed; for the i8xi8 paired op the
// accumulator is redirected to the previously generated twin op.
838static Operation *generateFMAOp(vector::FMAOp fmaOp, AIEOpAttributes &opAttr,
839                                VectState *state,
                                bool i8xi8_pairedOp =
                                    false) {
842  assert(opAttr.start.size() == opAttr.offset.size() &&
843         opAttr.start.size() == 2);
845  Value lhs = state->sextTruncDefMap.count(fmaOp.getLhs().getDefiningOp())
846                  ? fmaOp.getLhs().getDefiningOp()->getOperand(0)
848  Value rhs = state->sextTruncDefMap.count(fmaOp.getRhs().getDefiningOp())
849                  ? fmaOp.getRhs().getDefiningOp()->getOperand(0)
851  Value acc = state->sextTruncDefMap.count(fmaOp.getAcc().getDefiningOp())
852                  ? fmaOp.getAcc().getDefiningOp()->getOperand(0)
856  bool isSub = state->mscOps.count(fmaOp);
860  bool isInt = isa<IntegerType>(
861      cast<VectorType>(fmaOp.getLhs().getType()).getElementType());
866    if (!writesToAccumulator(acc.getDefiningOp())) {
867      acc = generateUPSOp(acc, state, fmaOp->getLoc());
868      LLVM_DEBUG(llvm::dbgs()
869                 <<
                 "\n\nCreated UPS op " << acc <<
                 " to move the output of "
870                 << fmaOp <<
                 " into accumulator");
873    if (!isSimpleVectIntrinsic(fmaOp, state)) {
877      AIEVecAttributes rstat = getOperandVecStats(fmaOp, state, 1);
879        rhs = generateBroadcastOp(rhs, stoi(opAttr.start[1]), state,
884    xfmaOp = aievec::FMAElemOp::create(state->builder, fmaOp->getLoc(), lhs,
889    if (i8xi8_pairedOp) {
890      Operation *defOp = acc.getDefiningOp();
891      if (state->pairedOp.count(defOp))
892        acc = state->pairedOp[defOp]->getResult(0);
895    if (isInt && !writesToAccumulator(acc.getDefiningOp())) {
896      acc = generateUPSOp(acc, state, fmaOp->getLoc());
897      LLVM_DEBUG(llvm::dbgs()
898                 <<
                 "\n\nCreated UPS op " << acc <<
                 " to move the output of "
899                 << fmaOp <<
                 " into accumulator");
904    if (!isSimpleVectIntrinsic(fmaOp, state)) {
905      AIEVecAttributes lstat = getOperandVecStats(fmaOp, state, 0);
906      assert(lstat.vecSizeInBits % 256 == 0);
908      if (lstat.vecSizeInBits == 256) {
909        VectorType concatType =
911        SmallVector<Value> sources = {lhs, lhs};
912        lhs = generateConcatOp(sources, state, fmaOp->getLoc(), concatType);
916    xfmaOp = aievec::aie1::FMAOp::create(
917        state->builder, fmaOp->getLoc(), lhs, rhs, acc, opAttr.start[0],
918        opAttr.offset[0], opAttr.offset_hi[0], opAttr.step[0], opAttr.square[0],
919        opAttr.start[1], opAttr.offset[1], opAttr.offset_hi[1], opAttr.step[1],
920        opAttr.square[1], isSub);
923  assert(xfmaOp &&
         "could not create fma op");
// Lower an arith MulI/MulF (template parameter T) to an aievec.aie1.mul
// with the scheme attribute strings. Operands are read through recorded
// sext/trunc wrappers; a 256-bit lhs is self-concatenated when the op is
// not a simple vector intrinsic. `opType` is computed on lines not shown.
930static Operation *generateMulOp(T mulOp, AIEOpAttributes &opAttr,
934  assert(opAttr.start.size() == opAttr.offset.size() &&
935         opAttr.start.size() == 2);
942  Value lhs = state->sextTruncDefMap.count(mulOp.getLhs().getDefiningOp())
943                  ? mulOp.getLhs().getDefiningOp()->getOperand(0)
945  Value rhs = state->sextTruncDefMap.count(mulOp.getRhs().getDefiningOp())
946                  ? mulOp.getRhs().getDefiningOp()->getOperand(0)
948  if (!isSimpleVectIntrinsic(mulOp, state)) {
949    AIEVecAttributes lstat = getOperandVecStats(mulOp, state, 0);
950    assert(lstat.vecSizeInBits % 256 == 0);
951    if (lstat.vecSizeInBits == 256) {
952      VectorType concatType =
954      SmallVector<Value> sources = {lhs, lhs};
955      lhs = generateConcatOp(sources, state, mulOp->getLoc(), concatType);
960  Operation *xmulOp = aievec::aie1::MulOp::create(
961      state->builder, mulOp->getLoc(), lhs, rhs, opType, opAttr.start[0],
962      opAttr.offset[0], opAttr.offset_hi[0], opAttr.step[0], opAttr.square[0],
963      opAttr.start[1], opAttr.offset[1], opAttr.offset_hi[1], opAttr.step[1],
966  assert(xmulOp &&
         "could not create mul op");
// Create (or reuse via memToUpdMap) the aievec.upd op(s) that load the
// access interval of `readOp`. The interval (at least 128 bits wide) may
// need two chained UPD ops (idx 0 and 1), each covering `incr` bits; a UPD
// is emitted only when the required sub-range [start, end) is within the
// legal bounds [lb, ub] and not already covered (bit idx in updIndices).
// The UPD chain threads the previous result as the `updOp` operand.
// Insertion point: at readOp in multi-block regions, else hoisted to the
// region entry; any affine.apply feeding the indices is moved before the
// UPD when it would otherwise come after it.
976generateUPDOp(TransferReadOp readOp,
977              mlir::DenseMap<std::tuple<IntervalReuse *, int32_t, int32_t>,
978                             std::pair<aievec::UPDOp, int8_t>> &memToUpdMap,
979              Region &region, VectState *state) {
985  int32_t intervalWidth = interval.second - interval.first;
986  assert(intervalWidth >= 128 &&
         "Interval computation incorrect");
991  auto vecType = cast<VectorType>(readOp.getVector().getType());
992  Type elementType = vecType.getElementType();
994  int intervalWidthInBytes = intervalWidth / elementSizeInBits;
1001  int32_t mid = interval.first + intervalWidth / 2;
1005      intervalWidth <= (state->aieml && elementSizeInBits == 8 ? 512 : 256) ||
1010      intervalWidth <= (state->aieml && elementSizeInBits == 8 ? 512 : 256) ||
1016  aievec::UPDOp updOp =
nullptr;
1019  int8_t updIndices = 0;
1020  auto key = std::make_tuple(iv, interval.first, interval.second);
1021  if (memToUpdMap.count(key)) {
1022    updOp = memToUpdMap[key].first;
1023    updIndices = memToUpdMap[key].second;
1033  SmallVector<Value, 4> indices(readOp.getIndices().begin(),
1034                                readOp.getIndices().end());
1036  AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state);
1038  auto [base, offset] = getBaseAndOffset(linearAccess);
1039  offset *= elementSizeInBits;
1046  bool singleBlock = region.getBlocks().size() == 1;
1048    state->builder.setInsertionPoint(readOp);
1050    state->builder.setInsertionPointToStart(&region.front());
1055  int width = state->aieml ? elementSizeInBits == 8
1059  int32_t incr = std::max(width, intervalWidth / 2);
1061  for (int32_t start = interval.first; start < interval.second;
1062       start += incr, ++idx) {
1065    assert(idx <= 2 &&
           "The only allowed values for UPD index are 0 and 1");
1066    int32_t end = std::min(interval.second, start + incr);
1070    if (lb <= start && ub >= end && (updIndices & idx) == 0) {
1073      updOp = aievec::UPDOp::create(
1074          state->builder, readOp.getLoc(), updVecType, readOp.getBase(),
1075          indices, start - offset, idx - 1,
1076          updOp ? updOp.getResult() : TypedValue<VectorType>(nullptr));
1078      LLVM_DEBUG(llvm::dbgs() <<
          "\n\nCreated UPD op " << updOp
1079                              <<
                              " for read op " << readOp);
1083      for (
auto &value : indices) {
1084        if (
auto apOf =
                value.getDefiningOp<affine::AffineApplyOp>()) {
1086          if (apOf->getBlock() == readOp->getBlock() &&
1087              apOf->isBeforeInBlock(updOp))
1089          apOf.getOperation()->moveBefore(updOp);
1099  memToUpdMap[key] = std::make_pair(updOp, updIndices);
// Step size, in vector lanes, of the loop along the vectorized dimension
// of a transfer_read: find the innermost index (last result of the
// permutation map), locate the enclosing affine.for whose induction
// variable makes that index variant, and return its step divided by the
// number of lanes. Asserts the vectorization factor is a power of two.
// NOTE(review): "Vecorized" in the name is a typo, but renaming would
// break callers outside this excerpt.
1109static int32_t computeVecorizedLoopStepSize(Operation *op, VectState *state) {
1110  auto readOp = dyn_cast<TransferReadOp>(op);
1116  auto vectorType = cast<VectorType>(readOp.getResult().getType());
1117  SmallVector<Value, 4> indices(readOp.getIndices().begin(),
1118                                readOp.getIndices().end());
1119  assert(vectorType && !indices.empty());
1122  auto block = readOp->getBlock();
1123  assert(state->blockToEnclosingLoops.count(block) &&
1124         "enclosing loops should have been computed for the read operation");
1125  auto enclosingLoops = state->blockToEnclosingLoops[block];
1129  AffineExpr expr = readOp.getPermutationMap().getResults().back();
1130  if (
auto dimExpr = llvm::dyn_cast<AffineDimExpr>(expr)) {
1131    assert(dimExpr.getPosition() <= indices.size() &&
1132           "Failed to find the permutation index in index map");
1133    auto index = indices[dimExpr.getPosition()];
1136    [[maybe_unused]]
    bool found =
        false;
1137    for (
auto loop : enclosingLoops) {
1138      auto iv = cast<affine::AffineForOp>(loop).getInductionVar();
1139      auto invariants = affine::getInvariantAccesses(iv, indices);
1140      if (!invariants.count(index)) {
1143            "stepsize computation already has an entry along the variant dim");
1144        step = cast<affine::AffineForOp>(loop).getStepAsInt();
1150         "non-power-of-two vectorization factor not supported");
1154  return step / lanes;
// Tail of a helper whose signature falls on lines missing from this
// excerpt: for a transfer_read it linearizes the access, scales the
// constant offset to bits, and returns the element-index distance from the
// start of the op's reuse interval — presumably the "start" lane of the
// access within its fused load (TODO confirm against the full source).
1163  if (!isa<TransferReadOp>(op))
1166  auto readOp = cast<TransferReadOp>(op);
1169  auto vtype = cast<VectorType>(readOp.getVector().getType());
1173  AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state);
1175  auto [base, offset] = getBaseAndOffset(linearAccess);
1176  offset *= scalarSizeInBits;
1179  std::pair<int32_t, int32_t> interval = iv->
getInterval(op);
1185  assert(offset >= interval.first &&
         "Failed to compute the start");
1186  return (offset - interval.first) / scalarSizeInBits;
// i8xi8 epilogue: SRS both source accumulators to i16 vectors, concat
// them, then interleave lanes with a select op (mask 0xcccccccc, starts
// 0/4, offsets 0x0c080400, squares 0x1010), extract the low half (32->16
// lanes), and pack back down to i8. Returns the pack op.
1195static Operation *concatAndInterleave_i8xi8(Operation *source1,
1197                                            VectState *state, Location loc) {
1203      IntegerType::get(source1->getResult(0).getType().getContext(), 16);
1204  auto srsOp1 = generateSRSOp(source1->getResult(0), i16Type, state, loc);
1205  auto srsOp2 = generateSRSOp(source2->getResult(0), i16Type, state, loc);
1208  SmallVector<Value> sources = {srsOp1->getResult(0), srsOp2->getResult(0)};
1209  auto concatOp = generateConcatOp(sources, state, loc);
1213  AIEOpAttributes opAttr;
1216  opAttr.select =
      "0xcccccccc";
1219  opAttr.start.push_back(
      "0");
1220  opAttr.start.push_back(
      "4");
1221  for (
size_t idx = 0; idx < 2; ++idx) {
1227    opAttr.offset.push_back(
        "0x0c080400");
1230    opAttr.offset_hi.push_back(
        "0x0");
1232    opAttr.square.push_back(
        "0x1010");
1236      generateSelectOp(concatOp->getResult(0), opAttr, 32, state, loc);
1239  auto extOp = generateExtOp(selectOp->getResult(0), 16, 0, state, loc);
1241  auto packOp = generatePackOp(extOp->getResult(0), state, loc);
// Decide whether an add/sub `Op` whose second operand is produced by a mul
// can be fused into a vector.fma: all three values (mul lhs/rhs, acc =
// Op's first operand, each read through sext/trunc wrappers) must be
// vectors, live in the same block, and agree in vector size and element
// type. The acc type is taken from underneath a sext/trunc chain when one
// was recorded.
1248static bool canFuseMulAndAddOrSubIntoFMAOp(Operation *Op, VectState *state) {
1250  assert((isa<AddIOp>(Op) || isa<AddFOp>(Op) || isa<SubIOp>(Op) ||
1252         "operation must be an add or sub op");
1255  assert(Op->getNumOperands() == 2 && Op->getNumResults() == 1);
1260  Operation *mulOp = getOperandDefOp(state, Op, 1);
1261  if (!isa<MulIOp, MulFOp>(mulOp))
1265  assert(mulOp->getNumOperands() == 2 && mulOp->getNumResults() == 1);
1268  Value lhs = state->sextTruncDefMap.count(mulOp->getOperand(0).getDefiningOp())
1269                  ? mulOp->getOperand(0).getDefiningOp()->getOperand(0)
1270                  : mulOp->getOperand(0);
1271  Value rhs = state->sextTruncDefMap.count(mulOp->getOperand(1).getDefiningOp())
1272                  ? mulOp->getOperand(1).getDefiningOp()->getOperand(0)
1273                  : mulOp->getOperand(1);
1274  Value acc = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
1275                  ? Op->getOperand(0).getDefiningOp()->getOperand(0)
1276                  : Op->getOperand(0);
1278  assert(lhs && rhs && acc &&
1279         "Failed to find the three operands of the FMA op");
1282  if (!isa<VectorType>(lhs.getType()) || !isa<VectorType>(rhs.getType()) ||
1283      !isa<VectorType>(acc.getType()))
1288  if (lhs.getParentBlock() != rhs.getParentBlock() ||
1289      rhs.getParentBlock() != acc.getParentBlock())
1293  auto lhsType = cast<VectorType>(lhs.getType());
1294  auto rhsType = cast<VectorType>(rhs.getType());
1295  VectorType accType = state->sextTruncDefMap.count(
1296                           acc.getDefiningOp()->getOperand(0).getDefiningOp())
1297                           ? cast<VectorType>(acc.getDefiningOp()
1302                           : cast<VectorType>(acc.getType());
1308  if (lhsVecSize != rhsVecSize || rhsVecSize != accVecSize)
1313  if (lhsType.getElementType() != rhsType.getElementType() ||
1314      rhsType.getElementType() != accType.getElementType())
// When a mul op's two operands come from different-sized vectors, swap
// them so the operand from the bigger vector is in the position the AIE
// scheme expects: for the 8x8 scheme the bigger vector goes second
// (flip if lhs is bigger), otherwise it goes first (flip if rhs is bigger).
1326static void reassociateMulOpBasedOnVecSize(Operation *Op, VectState *state) {
1328  AIEVecAttributes lstat = getOperandVecStats(Op, state, 0);
1329  AIEVecAttributes rstat = getOperandVecStats(Op, state, 1);
1332  if (lstat.vecSizeInBits == rstat.vecSizeInBits)
1336  bool is8x8 = lstat.elementSizeInBits == 8 && rstat.elementSizeInBits == 8;
1339  bool flip = is8x8 ? lstat.vecSizeInBits > rstat.vecSizeInBits
1340                    : rstat.vecSizeInBits > lstat.vecSizeInBits;
1342  LLVM_DEBUG(llvm::dbgs()
1343             <<
             "\n\nReassociating op " << *Op
1344             <<
             " to correctly place operand coming from bigger vector");
1345  Value left = Op->getOperand(0);
1346  Value right = Op->getOperand(1);
1347  Op->setOperand(0, right);
1348  Op->setOperand(1, left);
1349  LLVM_DEBUG(llvm::dbgs() <<
             "\n\tOp after reassociation: " << *Op);
// When exactly one operand of a mul/fma is a splat, place it in the
// position the scheme requires: for 8x8 the splat goes first (flip when
// rhs is the splat), otherwise it goes second (flip when lhs is the
// splat). Afterwards the result type is realigned to operand 0's type,
// and a sole add/sub user has its result type realigned too. (The
// branch structure tying the two setOperand sequences to `flip` spans
// lines not visible here — TODO confirm.)
1356static void reassociateMulOpWithSplat(Operation *Op, VectState *state) {
1359  assert(Op->getNumOperands() == 2 || Op->getNumOperands() == 3);
1360  assert(Op->getNumResults() == 1);
1363  AIEVecAttributes lstat = getOperandVecStats(Op, state, 0);
1364  AIEVecAttributes rstat = getOperandVecStats(Op, state, 1);
1367  if (lstat.isSplat && rstat.isSplat)
1371  bool is8x8 = lstat.elementSizeInBits == 8 && rstat.elementSizeInBits == 8;
1375  bool flip = is8x8 ? rstat.isSplat : lstat.isSplat;
1376  Value left = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
1377                   ? Op->getOperand(0).getDefiningOp()->getOperand(0)
1378                   : Op->getOperand(0);
1379  Value right = state->sextTruncDefMap.count(Op->getOperand(1).getDefiningOp())
1380                    ? Op->getOperand(1).getDefiningOp()->getOperand(0)
1381                    : Op->getOperand(1);
1383    LLVM_DEBUG(llvm::dbgs() <<
               "\n\nReassociating op " << *Op
1384                            <<
                            " to place splat as correct operand");
1385    Op->setOperand(0, right);
1386    Op->setOperand(1, left);
1387    LLVM_DEBUG(llvm::dbgs() <<
               "\n\tOp after reassociation: " << *Op);
1389    Op->setOperand(0, left);
1390    Op->setOperand(1, right);
1393  Op->getResult(0).setType(Op->getOperand(0).getType());
1395  if (Op->hasOneUse() &&
1396      isa<AddIOp, AddFOp, SubIOp, SubFOp>(*Op->getUsers().begin())) {
1397    Operation *usrOp = *Op->getUsers().begin();
1398    usrOp->getResult(0).setType(usrOp->getOperand(0).getType());
// Fuse an add/sub `Op` with its feeding mul into a single vector.fma:
// acc comes from Op's first operand and lhs/rhs from the mul (all read
// through sext/trunc wrappers). A sub is recorded in state->mscOps so the
// later lowering emits multiply-subtract. Op's uses are redirected to the
// new fma, and the mul is (presumably erased) once it has no users — the
// erase call falls on lines not shown.
1403static void fuseMulAndAddOrSubIntoFMAOp(Operation *Op, VectState *state) {
1404  Value acc = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
1405                  ? Op->getOperand(0).getDefiningOp()->getOperand(0)
1406                  : Op->getOperand(0);
1407  Operation *mulOp = getOperandDefOp(state, Op, 1);
1408  Value lhs = state->sextTruncDefMap.count(mulOp->getOperand(0).getDefiningOp())
1409                  ? mulOp->getOperand(0).getDefiningOp()->getOperand(0)
1410                  : mulOp->getOperand(0);
1411  Value rhs = state->sextTruncDefMap.count(mulOp->getOperand(1).getDefiningOp())
1412                  ? mulOp->getOperand(1).getDefiningOp()->getOperand(0)
1413                  : mulOp->getOperand(1);
1416  state->builder.setInsertionPointAfter(Op);
1418      vector::FMAOp::create(state->builder, Op->getLoc(), lhs, rhs, acc);
1421  bool isSub = isa<SubIOp, SubFOp>(Op);
1423    state->mscOps.insert(fmaOp);
1425  LLVM_DEBUG(llvm::dbgs() <<
             "\n\nFused " << (isSub ?
             "sub" :
             "add") <<
             " op "
1426                          << *Op <<
                          "\n\tand mul op " << *mulOp
1427                          <<
                          "\n\tinto fma op " << *fmaOp);
1430  Op->replaceAllUsesWith(fmaOp);
1434  if (mulOp->use_empty())
// Replace a vectorized mul/fma `Op` with its AIE-dialect equivalent via
// the genOp dispatcher (FMAOp/MulIOp/MulFOp). When `nextStart` is
// non-empty the op is an i8xi8 scheme op: on AIE-ML (32 lanes, 8x8 bits)
// it is instead lowered to mul_conv/fma_conv — with an i8 SRS appended if
// any user is not itself a mul/fma — otherwise a paired twin op is
// generated with start[1] = nextStart, recorded in state->pairedOp, and
// concatAndInterleave_i8xi8 stitches the pair together when a non-mul/fma
// user needs the combined result. All uses of Op are redirected to repOp.
1444static void generateMulOrFMAOp(Operation *Op, Scheme &scheme,
1445                               AIEOpAttributes &opAttr, VectState *state,
1446                               const std::string &nextStart =
                                   "") {
1448  assert(opAttr.start.size() == opAttr.offset.size() &&
1449         opAttr.start.size() == 2);
1452  state->builder.setInsertionPointAfter(Op);
1455  auto notMulOrFMAOp = [&](Operation *op) {
1456    return !isa<MulIOp, MulFOp, vector::FMAOp>(op);
1460  auto genOp = [&](Operation *Op, AIEOpAttributes &opAttr, VectState *state,
1461                   bool i8xi8_pairedOp =
                       false) {
1464    if (
auto fmaOp = dyn_cast<vector::FMAOp>(Op))
1465      repOp = generateFMAOp(fmaOp, opAttr, state, i8xi8_pairedOp);
1467    else if (
auto mulOp = dyn_cast<MulIOp>(Op))
1468      repOp = generateMulOp<MulIOp>(mulOp, opAttr, state);
1470    else if (
auto mulOp = dyn_cast<MulFOp>(Op))
1471      repOp = generateMulOp<MulFOp>(mulOp, opAttr, state);
1473      llvm_unreachable(
          "Operation not mul/fma op");
1477  Operation *repOp = genOp(Op, opAttr, state);
1478  LLVM_DEBUG(llvm::dbgs() <<
             "\n\nGenerated AIE dialect mul/fma op " << *repOp);
1484  if (!nextStart.empty()) {
1485    if (state->aieml && scheme.lanes == 32 && scheme.xbits == 8 &&
1486        scheme.zbits == 8) {
1487      repOp = generateMulOrFMAConvOpForInt8(Op, opAttr, state);
1488      if (any_of(repOp->getUsers(), notMulOrFMAOp)) {
1490            IntegerType::get(repOp->getResult(0).getType().getContext(), 8);
1492            generateSRSOp(repOp->getResult(0), i8Type, state, repOp->getLoc());
1495      opAttr.start[1] = nextStart;
1496      Operation *pairedOp = genOp(Op, opAttr, state,
                                  true);
1497      LLVM_DEBUG(llvm::dbgs() <<
                 "\n\nGenerated the paired AIE dialect "
1498                              <<
                              "mul/fma op for 8x8 scheme " << *repOp);
1500      assert(!state->pairedOp.count(repOp));
1501      state->pairedOp[repOp] = pairedOp;
1504      if (any_of(Op->getUsers(), notMulOrFMAOp))
1505        repOp = concatAndInterleave_i8xi8(repOp, pairedOp, state, Op->getLoc());
1511  Op->replaceAllUsesWith(repOp);
// Compute start/offset attribute strings shared by xbuff and zbuff for the
// i32xi32 scheme (offset_hi/square/step are unused and pushed empty).
// NOTE(review): the parameter list (vecSize, start, accIncr per the call
// sites) is on lines lost in extraction; code kept byte-identical.
1516static void computeBuffAttr_i32xi32(
1520    AIEOpAttributes &opAttr) {
1522  std::string startStr = std::to_string(start);
// Offset is a hex string built lane-by-lane (body of loop lost in extraction).
1524  std::string offsetStr =
"0x";
1525  for (
int i = vecSize - 1; i >= 0; --i)
1529  opAttr.start.push_back(startStr);
1530  opAttr.offset.push_back(offsetStr);
// No hi offset / square / step for this scheme.
1531  opAttr.offset_hi.push_back(
"");
1532  opAttr.square.push_back(
"");
1533  opAttr.step.push_back(
"");
// Compute xbuff attributes (start/offset/offset_hi/square) for the 16x16
// scheme. The start is rounded down to a multiple of 2; the residue is
// expressed via the square pattern. NOTE(review): extraction-mangled; code
// kept byte-identical; some loop bodies are on lost lines.
1537static void computeXbuffAttr_i16xi16(
1542    AIEOpAttributes &opAttr) {
// Column offset must be -1 (unset), <=1, or even to be encodable.
1544  assert(colOffset >= -1 && (colOffset <= 1 || colOffset % 2 == 0) &&
1545         "cannot compute offset and square for xbuff");
1548  assert((accIncr <= 1 || colOffset <= 1) &&
1549         "cannot generate offset and square for xbuff");
// xstart must be 2-aligned; m2Offset is the leftover (0 or 1).
1552  int32_t m2start = (start / 2) * 2;
1553  std::string startStr = std::to_string(m2start);
1555  int32_t m2Offset = start - m2start;
// Low half of the offset string: one hex pair per 2 lanes.
1559  std::string offsetStr =
"0x";
1560  int32_t offset = std::max(colOffset, accIncr);
1561  for (
int i = vecSize / 2 - 2; i >= 0; i -= 2) {
1562    offsetStr.push_back(offset <= 1 ?
'0' :
getHexValue((offset - 2) / 2));
1563    offsetStr.push_back(
getHexValue((i * accIncr) / 2));
// High half of the offset string for the upper lanes.
1565  std::string offsetHiStr =
"0x";
1566  for (
int i = vecSize - 2, e = vecSize / 2; i >= e; i -= 2) {
1567    offsetHiStr.push_back(offset <= 1 ?
'0' :
getHexValue((offset - 2) / 2));
1568    offsetHiStr.push_back(
getHexValue((i * accIncr) / 2));
// Square encodes the intra-pair permutation: column step and access step
// are each saturated at 2.
1572  int32_t cstep = std::min(2, std::abs(colOffset));
1573  int32_t astep = std::min(2, accIncr);
// An odd start (m2Offset==1) is only representable with small steps.
1574  assert(m2Offset == 0 || (astep <= 1 && cstep <= 1));
1576  SmallVector<int32_t> sqPattern = {astep + cstep, astep, cstep, 0};
1577  std::string squareStr =
"0x";
1578  for (
auto sq : sqPattern)
1582  opAttr.start.push_back(startStr);
1583  opAttr.offset.push_back(offsetStr);
1584  opAttr.offset_hi.push_back(offsetHiStr);
1585  opAttr.square.push_back(squareStr);
// xbuff has no step attribute in this scheme.
1586  opAttr.step.push_back(
"");
// Compute zbuff attributes (start/offset/offset_hi/step) for the 16x16
// scheme. AIE-ML allows a 5-bit zstart (32), AIE1 only 4-bit (16).
// NOTE(review): extraction-mangled; the offset-building loop bodies are on
// lost lines; code kept byte-identical.
1590static void computeZbuffAttr_i16xi16(
1596    bool aieml, AIEOpAttributes &opAttr) {
1597  std::string offsetStr, offsetHiStr;
1599  assert(start < (aieml ? 32 : 16) &&
"zstart must be 4b value");
1600  std::string startStr = std::to_string(start);
// Offsets start from "0" and are extended lane-by-lane below.
1604    offsetStr = offsetHiStr =
"0";
1608    for (
int i = vecSize / 2 - 1; i >= 0; --i)
1611    for (
auto i = vecSize - 1, e = vecSize / 2; i >= e; --i)
// colOffset == -1 means "unset": derive the step from the zero point.
1616  int32_t step = colOffset == -1 ? zeroOffset - 1 - start : colOffset;
1617  assert(step >= 0 &&
"zstep cannot be negative");
1618  std::string stepStr = std::to_string(step);
1621  opAttr.start.push_back(startStr);
1622  opAttr.offset.push_back(offsetStr);
1623  opAttr.offset_hi.push_back(offsetHiStr);
// zbuff has no square in this scheme.
1624  opAttr.square.push_back(
"");
1625  opAttr.step.push_back(stepStr);
// Compute xbuff attributes for the 8x8 scheme (16/32 lanes, 8 columns).
// Filter entries must be duplicated (colOffset even), start is 4-aligned,
// and the square pattern handles the 0/2 residue. NOTE(review):
// extraction-mangled; code kept byte-identical.
1633static void computeXbuffAttr_i8xi8(
1637    AIEOpAttributes &opAttr) {
1642         "each filter entry must be replicated at least twice for i8xi8 scheme");
// The hardware step operates on duplicated entries, hence 2x.
1643  int32_t colStep = 2 * colOffset;
1644  assert(colStep % 4 == 0 &&
"xstep must be multiple of 4");
// xstart must be 4-aligned; residue m4Offset can only be 0 or 2.
1647  int32_t m4start = (start / 4) * 4;
1648  std::string startStr = std::to_string(m4start);
1650  int32_t m4Offset = start - m4start;
1652  assert(m4Offset == 0 || m4Offset == 2);
// One hex digit per group of 4 lanes.
1656  std::string offsetStr =
"0x";
1657  for (
int i = vecSize / 4 - 1; i >= 0; --i) {
1658    offsetStr.push_back(
getHexValue(colStep / 4 - 1));
1661  std::string stepStr = std::to_string(colStep);
// Row step inside a square, computed on the non-duplicated offset.
1664  int32_t offsetWithoutDup = colOffset / 2;
1665  int32_t rstep = offsetWithoutDup >= 2 ? 2
1666                  : colOffset == -1     ? 1
// A 2-residue start is only encodable with rstep <= 1.
1668  assert(m4Offset == 0 || rstep <= 1);
1670  SmallVector<int32_t> sqPattern = {rstep, 0, rstep, 0};
1671  std::string squareStr =
"0x";
1672  for (
auto sq : sqPattern)
1676  opAttr.start.push_back(startStr);
1677  opAttr.offset.push_back(offsetStr);
// No hi offset for i8xi8 xbuff.
1678  opAttr.offset_hi.push_back(
"");
1679  opAttr.square.push_back(squareStr);
1680  opAttr.step.push_back(stepStr);
// Compute zbuff attributes for the 8x8 scheme and report, via `nextStart`,
// the start value for the paired mul/fma op generated later by
// generateMulOrFMAOp. NOTE(review): extraction-mangled; code kept
// byte-identical.
1686static void computeZbuffAttr_i8xi8(
1691    AIEOpAttributes &opAttr, std::string &nextStart) {
1693  assert((colOffset <= 1 || colOffset % 2 == 0) &&
"zbuff value not supported");
// zstart must be 2-aligned; residue handled by the square.
1696  int32_t m2start = (start / 2) * 2;
1697  std::string startStr = std::to_string(m2start);
1699  int32_t m2Offset = start - m2start;
// One hex digit per group of 4 lanes (push lost in extraction).
1703  std::string offsetStr =
"0x";
1704  for (
int i = vecSize / 4 - 1; i >= 0; --i) {
1705    int32_t val = i * accIncr + (colOffset + 1) / 2;
1709  std::string stepStr = std::to_string(2 * std::abs(colOffset));
// Start for the paired op: advance past the lanes covered by this op.
1710  nextStart = std::to_string(m2start + 2 * accIncr * (vecSize / 4));
1714  int32_t rstep = colOffset >= 2 ? 2 : std::abs(colOffset);
1715  assert(m2Offset == 0 || rstep <= 1);
1717  SmallVector<int32_t> sqPattern = {accIncr + rstep, accIncr, rstep, 0};
1718  std::string squareStr =
"0x";
1719  for (
auto sq : sqPattern)
1723  opAttr.start.push_back(startStr);
1724  opAttr.offset.push_back(offsetStr);
// No hi offset for i8xi8 zbuff.
1725  opAttr.offset_hi.push_back(
"");
1726  opAttr.square.push_back(squareStr);
1727  opAttr.step.push_back(stepStr);
// Fuse a chain of FMA ops hanging off `refOp` into a single multi-column
// mac, exploiting the AIE column topology. Walks single-use FMA chains in
// the same block, checks that each link's operands differ from the previous
// by a constant column offset, records fused ops in `fusedOpSet`, and stores
// the final (xOffset, zOffset) in state->opToColOffsets. NOTE(review):
// extraction-mangled (breaks/continues and some declarations on lost lines);
// code kept byte-identical.
1736static void fuseFMAOps(Operation *refOp,
1737                       llvm::SmallSet<Operation *, 8> &fusedOpSet, int32_t cols,
// Nothing to do for single-column schemes, non-mul/fma ops, or ops lowered
// via the simple vector intrinsic path.
1741  if (cols <= 1 || !isa<MulIOp, MulFOp, vector::FMAOp>(refOp) ||
1742      isSimpleVectIntrinsic(refOp, state))
1747  Operation *lOp = getOperandDefOp(state, refOp, 0);
1748  Operation *rOp = getOperandDefOp(state, refOp, 1);
// Column offsets for the x (lhs) and z (rhs) buffers; -1 = not yet fixed.
1755  int xOffset = -1, zOffset = -1;
1765  Operation *curOp = refOp;
1766  SmallVector<Operation *, 8> fusedOps;
// Try to extend the chain up to cols-1 times.
1768  for (
auto len = 0; len < cols - 1; ++len) {
// Chain must be single-use to fold into one op.
1770    if (!curOp->hasOneUse())
1773    Operation *usrOp = *curOp->getUsers().begin();
// User must be an FMA in the same block, not on the simple-intrinsic path.
1776    if (!isa<vector::FMAOp>(usrOp) || curOp->getBlock() != usrOp->getBlock() ||
1777        isSimpleVectIntrinsic(usrOp, state))
// mac and msc ops cannot be fused together.
1780    if (isa<vector::FMAOp>(curOp) &&
1781        state->mscOps.count(curOp) != state->mscOps.count(usrOp))
1784    SmallVector<int32_t, 2> offsets;
// Check both operands (idx 0 = xbuff, idx 1 = zbuff) for compatibility.
1785    for (
size_t idx = 0; idx < 2; ++idx) {
1787      AIEVecAttributes cstat = getOperandVecStats(curOp, state, idx);
1788      AIEVecAttributes ustat = getOperandVecStats(usrOp, state, idx);
// Vector shape, element size, source kind and splat-ness must all match.
1792      if (cstat.vecSizeInBits != ustat.vecSizeInBits ||
1793          cstat.elementSizeInBits != ustat.elementSizeInBits ||
1794          cstat.loadFromMemory != ustat.loadFromMemory ||
1795          cstat.isSplat != ustat.isSplat)
1798      Operation *cdefOp = getOperandDefOp(state, curOp, idx);
1799      Operation *udefOp = getOperandDefOp(state, usrOp, idx);
// Same defining op, or two loads from the same reuse interval.
1801      bool related = cdefOp == udefOp;
1802      if (!related && cstat.loadFromMemory && ustat.loadFromMemory) {
1803        IntervalReuse *civ = state->getIntervalForOperation(cdefOp);
1804        IntervalReuse *uiv = state->getIntervalForOperation(udefOp);
// Column offset between the two accesses.
1816      int32_t offset = start2 - start1;
// Offsets > 1 must be even to be encodable in the square/step fields.
1823      if (offset > 1 && offset % 2 != 0)
1826      int32_t refStart = idx == 0 ? lstart : rstart;
1827      if (!ustat.isSplat && offset > 1 && refStart != 0)
1831      offsets.push_back(offset);
// Both operands must have produced a valid offset.
1834    if (offsets.size() < 2)
// Offsets must stay constant along the whole chain.
1838    if ((xOffset != -1 && xOffset != offsets[0]) ||
1839        (zOffset != -1 && zOffset != offsets[1]))
1842    xOffset = offsets[0];
1843    zOffset = offsets[1];
1844    fusedOps.push_back(usrOp);
1851  if (fusedOps.empty())
1854  LLVM_DEBUG(llvm::dbgs() <<
"\n\nFused following fma ops with op " << *refOp);
// Fold each chained FMA into refOp: widen refOp's access extent and
// redirect uses.
1858  for (
auto &op : fusedOps) {
1859    LLVM_DEBUG(llvm::dbgs() <<
"\n\tfma op " << *op);
1860    fusedOpSet.insert(op);
1862    fuseAccessExtent(refOp, op, state);
1864    op->replaceAllUsesWith(refOp);
// Record the column offsets exactly once for refOp.
1868  assert(!state->opToColOffsets.count(refOp));
1869  state->opToColOffsets[refOp] = std::make_pair(xOffset, zOffset);
// Dispatch xbuff attribute computation based on the (lanes, cols, xbits,
// zbits) scheme; AIE-ML doubles the lane counts relative to AIE1.
// NOTE(review): extraction-mangled; parameter list on lost lines; code kept
// byte-identical.
1873static void computeXbuffAttributes(
1879    bool aieml, AIEOpAttributes &opAttr) {
// 8x(32x32) (16 lanes on AIE-ML): shared simple attribute computation.
1882  if ((scheme.lanes == 8 || (aieml && scheme.lanes == 16)) &&
1883      scheme.cols == 1 && scheme.xbits == 32 && scheme.zbits == 32)
1884    computeBuffAttr_i32xi32(scheme.lanes, start, accIncr, opAttr);
// 16x2x(16x16) (32 lanes on AIE-ML).
1886  else if ((scheme.lanes == 16 || (aieml && scheme.lanes == 32)) &&
1887           scheme.cols == 2 && scheme.xbits == 16 && scheme.zbits == 16) {
// Loop step must be <=1 or even to be encodable.
1889    assert((accIncr <= 1 || accIncr % 2 == 0) &&
1890           "loop step size value not supported");
1891    computeXbuffAttr_i16xi16(scheme.lanes, start, accIncr, colOffset, opAttr);
// 16x8x(8x8) (32 lanes on AIE-ML).
1894  else if ((scheme.lanes == 16 || (aieml && scheme.lanes == 32)) &&
1895           scheme.cols == 8 && scheme.xbits == 8 && scheme.zbits == 8) {
1897    assert(accIncr <= 1 &&
"loop step size greater than 1 not supported");
// Unset column offset falls back to the filter duplication factor.
1900    if (colOffset == -1)
1901      colOffset = dupFactor;
1902    computeXbuffAttr_i8xi8(scheme.lanes, start, colOffset, opAttr);
1904    llvm_unreachable(
"Unsupported vectorization scheme");
// Dispatch zbuff attribute computation based on the scheme; mirrors
// computeXbuffAttributes but also threads `zeroOffset` and `nextStart`
// through for the i16/i8 cases. NOTE(review): extraction-mangled; parameter
// list on lost lines; code kept byte-identical.
1908static void computeZbuffAttributes(
1915    std::string &nextStart,
1916    AIEOpAttributes &opAttr) {
// 8x(32x32) (16 lanes on AIE-ML): shared simple attribute computation.
1919  if ((scheme.lanes == 8 || (aieml && scheme.lanes == 16)) &&
1920      scheme.cols == 1 && scheme.xbits == 32 && scheme.zbits == 32)
1921    computeBuffAttr_i32xi32(scheme.lanes, start, accIncr, opAttr);
// 16x2x(16x16) (32 lanes on AIE-ML).
1923  else if ((scheme.lanes == 16 || (aieml && scheme.lanes == 32)) &&
1924           scheme.cols == 2 && scheme.xbits == 16 && scheme.zbits == 16) {
1926    assert(accIncr <= 1 &&
"loop step size greater than 1 not supported");
// Normalize the zero point: 0 means "use the full lane count", otherwise
// round up past `start` to the next multiple of zeroOffset.
1929    zeroOffset = zeroOffset == 0 ? scheme.lanes
1930                                 : start + zeroOffset - (start % zeroOffset);
1931    computeZbuffAttr_i16xi16(scheme.lanes, start, accIncr, zeroOffset,
1932                             colOffset, aieml, opAttr);
// 16x8x(8x8) (32 lanes on AIE-ML); also fills nextStart for the paired op.
1935  else if ((scheme.lanes == 16 || (aieml && scheme.lanes == 32)) &&
1936           scheme.cols == 8 && scheme.xbits == 8 && scheme.zbits == 8) {
1938    assert(accIncr <= 1 &&
"loop step size greater than 1 not supported");
1939    computeZbuffAttr_i8xi8(scheme.lanes, start, accIncr, colOffset, opAttr,
1942    llvm_unreachable(
"Unsupported vectorization scheme");
// Lower one vector mul/fma op to the AIE dialect: determine the (lanes,
// cols) scheme, compute per-operand x/z buffer attributes, then call
// generateMulOrFMAOp. Simple vector intrinsics bypass the attribute
// computation with empty strings. NOTE(review): extraction-mangled (xbits/
// zbits derivation and early return on lost lines); code kept byte-identical.
1947static void generateSchemeBasedMulOrFMAOp(Operation *Op, VectState *state) {
1948  int32_t lanes, cols;
1949  std::tie(lanes, cols) = getNumRowsAndCols(Op, state);
// Look through recorded sext/trunc producers for both operands.
1951  Value lhs = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
1952                  ? Op->getOperand(0).getDefiningOp()->getOperand(0)
1953                  : Op->getOperand(0);
1954  Value rhs = state->sextTruncDefMap.count(Op->getOperand(1).getDefiningOp())
1955                  ? Op->getOperand(1).getDefiningOp()->getOperand(0)
1956                  : Op->getOperand(1);
1959  Scheme scheme(lanes, cols, xbits, zbits);
// Simple intrinsic: no start/offset encoding needed — push empties.
1963  if (isSimpleVectIntrinsic(Op, state)) {
1966    AIEOpAttributes opAttr;
1968    for (
size_t idx = 0; idx < 2; ++idx) {
1969      opAttr.start.push_back(
"");
1970      opAttr.offset.push_back(
"");
1971      opAttr.offset_hi.push_back(
"");
1972      opAttr.square.push_back(
"");
1973      opAttr.step.push_back(
"");
1975    generateMulOrFMAOp(Op, scheme, opAttr, state);
// Column offsets recorded by fuseFMAOps, or (-1,-1) if not fused.
1984  auto colOffset = state->opToColOffsets.count(Op) ? state->opToColOffsets[Op]
1985                                                   : std::make_pair(-1, -1);
1989  AIEOpAttributes opAttr;
// Filled by computeZbuffAttributes for the i8xi8 paired-op case.
1993  std::string nextStart;
// idx 0 = xbuff (lhs), idx 1 = zbuff (rhs).
1996  for (
size_t idx = 0; idx < 2; ++idx) {
1997    AIEVecAttributes stat = getOperandVecStats(Op, state, idx);
1998    Operation *op = getOperandDefOp(state, Op, idx);
2000    int32_t start = 0, accIncr = 1;
// Memory operands: derive start and loop-step from the transfer_read;
// splats do not advance with the loop (accIncr = 0).
2003    if (stat.loadFromMemory) {
2004      auto readOp = cast<TransferReadOp>(op);
2006      accIncr = stat.isSplat ? 0 : computeVecorizedLoopStepSize(readOp, state);
2012      computeXbuffAttributes(scheme, start, colOffset.first, accIncr,
2013                             state->dupFactor, state->aieml, opAttr);
2015      computeZbuffAttributes(scheme, start, colOffset.second, accIncr,
2016                             state->zeroOffset, state->aieml, nextStart,
2020  generateMulOrFMAOp(Op, scheme, opAttr, state, nextStart);
// Walk the function and fuse FMA chains for the AIE column topology, then
// erase the ops that were folded away. NOTE(review): extraction-mangled
// (the erase statement is on a lost line); code kept byte-identical.
2026static void fuseFMAOpsForColumnTopology(func::FuncOp func, VectState *state) {
2028  llvm::SmallSet<Operation *, 8> fusedOpSet;
2031  func.walk([&](Operation *op) {
2032    if (isa<MulIOp, MulFOp, vector::FMAOp>(op)) {
// Skip ops already consumed by an earlier fusion.
2034      if (!fusedOpSet.count(op)) {
2035        auto [lanes, cols] = getNumRowsAndCols(op, state);
2038        fuseFMAOps(op, fusedOpSet, cols, state);
// Dead fused ops are erased here (erase lost in extraction).
2044  for (
auto op : fusedOpSet)
// Return true if two aievec mul/fma ops have identical offset/square/step
// attributes for both operands and their start values differ by the expected
// per-column distance (the distance constants were on lines lost in
// extraction — confirm against upstream before relying on this). Code kept
// byte-identical.
2048template <
typename T1,
typename T2>
2049static bool matchAttributesAndDistanceForFusion(T1 curOp, T2 defOp) {
// Operand-0 (xbuff) attributes must match exactly ...
2050  return curOp.getOffset(0) == defOp.getOffset(0) &&
2051         curOp.getOffsetHi(0) == defOp.getOffsetHi(0) &&
2052         curOp.getSquare(0) == defOp.getSquare(0) &&
2053         curOp.getStep(0) == defOp.getStep(0) &&
// ... as must operand-1 (zbuff) attributes.
2054         curOp.getOffset(1) == defOp.getOffset(1) &&
2055         curOp.getOffsetHi(1) == defOp.getOffsetHi(1) &&
2056         curOp.getSquare(1) == defOp.getSquare(1) &&
2057         curOp.getStep(1) == defOp.getStep(1) &&
// Start-distance check: starts are stored as strings, hence stoi.
2058         stoi(
static_cast<std::string
>(curOp.getStart(0))) -
2059                 stoi(
static_cast<std::string
>(defOp.getStart(0))) ==
2061         stoi(
static_cast<std::string
>(curOp.getStart(1))) -
2062                 stoi(
static_cast<std::string
>(defOp.getStart(1))) ==
// Check whether an aievec.aie1 FMA op and the mul/fma producing its
// accumulator can be fused into a mul_conv/fma_conv for the int16 case:
// same 16-bit element type, single-use accumulator chain, identical lhs/rhs
// operands backed by UPD ops in the same block, and matching attributes.
// NOTE(review): extraction-mangled (several returns on lost lines); code
// kept byte-identical.
2101static bool canFuseMulFMAOpsForInt16(Operation *Op) {
2103  assert(isa<aievec::aie1::FMAOp>(Op) &&
"operation must be an aievec fma op");
2104  auto curOp = cast<aievec::aie1::FMAOp>(Op);
// Element type of the zbuff operand must be 16-bit integer.
2107  auto vType = cast<VectorType>(Op->getOperand(1).getType());
2108  Type stype = vType.getElementType();
2109  auto itype = llvm::dyn_cast<IntegerType>(stype);
2114  if (
unsigned width = itype.getWidth(); width != 16)
// The accumulator must come from an aievec mul/fma ...
2118  Operation *mulOrFMAOp = Op->getOperand(2).getDefiningOp();
2120  if (!isa<aievec::aie1::MulOp, aievec::aie1::FMAOp>(mulOrFMAOp))
// ... that only feeds this FMA ...
2124  if (!mulOrFMAOp->hasOneUse())
// ... and uses the exact same lhs/rhs values.
2128  if (mulOrFMAOp->getOperand(0) != Op->getOperand(0) ||
2129      mulOrFMAOp->getOperand(1) != Op->getOperand(1))
2132  Value lhs =
nullptr;
2133  Value rhs =
nullptr;
2134  Value acc =
nullptr;
2135  bool isMulOp =
false;
// Extract operands from whichever kind the producer is.
2139  if (
auto mulOp = dyn_cast<aievec::aie1::MulOp>(mulOrFMAOp)) {
2143    lhs = mulOp->getOperand(0);
2144    rhs = mulOp->getOperand(1);
2146    auto fmaOp = cast<aievec::aie1::FMAOp>(mulOrFMAOp);
2149    lhs = fmaOp->getOperand(0);
2150    rhs = fmaOp->getOperand(1);
2151    acc = fmaOp->getOperand(2);
// Both operands must be produced by UPD (vector load) ops.
2155  auto lUpdOp = dyn_cast<aievec::UPDOp>(lhs.getDefiningOp());
2156  auto rUpdOp = dyn_cast<aievec::UPDOp>(rhs.getDefiningOp());
2158  if (!lUpdOp || !rUpdOp) {
// Everything must live in one block so the fused op dominates its uses.
2164  if (lhs.getParentBlock() != rhs.getParentBlock())
2167  if (acc && rhs.getParentBlock() != acc.getParentBlock())
// Finally the attribute/distance check, typed per producer kind.
2172  return (isMulOp && matchAttributesAndDistanceForFusion(
2173                         curOp, cast<aievec::aie1::MulOp>(mulOrFMAOp))) ||
2174         matchAttributesAndDistanceForFusion(
2175             curOp, cast<aievec::aie1::FMAOp>(mulOrFMAOp));
// Fuse a canFuseMulFMAOpsForInt16-approved pair into a single
// aievec.mul_conv / aievec.fma_conv op: normalize the lhs UPD chain, widen
// the rhs via concat (and shift by the zbuff start), then build the conv op
// and redirect all uses. NOTE(review): extraction-mangled (several
// declarations, the N computation and erase statements are on lost lines);
// code kept byte-identical.
2179static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) {
2180  auto curOp = cast<aievec::aie1::FMAOp>(Op);
2182  Value lhs = curOp->getOperand(0);
// If lhs is the second half of a two-part UPD, use the first part instead.
2188  auto lUpdOp = dyn_cast<aievec::UPDOp>(lhs.getDefiningOp());
2189  if (lUpdOp.getIndex() == 1) {
2190    auto lUpdOp0 = dyn_cast<aievec::UPDOp>(lUpdOp.getVector().getDefiningOp());
2191    lUpdOp->replaceAllUsesWith(lUpdOp0);
// rhs must span a 256-bit multiple; a 256-bit rhs is doubled via concat.
2198  auto rUpdOp = dyn_cast<aievec::UPDOp>(curOp->getOperand(1).getDefiningOp());
2199  state->builder.setInsertionPointAfter(rUpdOp);
2200  AIEVecAttributes rstat = getOperandVecStats(curOp, state, 1);
2201  assert(rstat.vecSizeInBits % 256 == 0);
2202  Value concatRhs =
nullptr;
2204  if (rstat.vecSizeInBits == 256) {
2205    VectorType concatType =
2207    SmallVector<Value> sources = {rUpdOp->getResult(0), rUpdOp->getResult(0)};
2208    concatRhs = generateConcatOp(sources, state, rUpdOp->getLoc(), concatType);
2212  Operation *convOp =
nullptr;
2213  Operation *mulOrFMAOp = Op->getOperand(2).getDefiningOp();
2214  auto mulOp = dyn_cast<aievec::aie1::MulOp>(mulOrFMAOp);
2215  auto fmaOp = dyn_cast<aievec::aie1::FMAOp>(mulOrFMAOp);
// zStart (string attribute) of the accumulator producer drives the shift.
2219    aievec::aie1::MulOp defOp = mulOp;
2220    zStart = stoi(
static_cast<std::string
>(defOp.getStart(1)));
2222    aievec::aie1::FMAOp defOp = fmaOp;
2223    zStart = stoi(
static_cast<std::string
>(defOp.getStart(1)));
2226  auto vType = cast<VectorType>(Op->getOperand(1).getType());
// Build the conv op at the producer's position.
2229  auto defOp = mulOp ? mulOp : fmaOp;
2230  state->builder.setInsertionPoint(defOp);
2231  Location loc = defOp->getLoc();
// Align the concatenated rhs with the zbuff start.
2235  concatRhs = generateShiftOp(concatRhs, concatRhs, shiftBytes, state, loc);
// Accumulator lane width: 32-bit for <=8-bit elements, else 64-bit.
2237  Type stype = vType.getElementType();
2238  auto itype = cast<IntegerType>(stype);
2239  unsigned width = itype.getWidth() <= 8 ? 32 : 64;
2240  Type ctype = IntegerType::get(itype.getContext(), width);
2241  Type opType = VectorType::get(vType.getShape(), ctype);
2242  Value acc =
nullptr;
2245  int32_t M = itype.getWidth();
2249  lhs = curOp->getOperand(0);
// mul producer -> mul_conv; fma producer -> fma_conv (msc flag preserved).
2252    convOp = aievec::MulConvOp::create(state->builder, loc, opType, lhs,
2255    acc = defOp->getOperand(2);
2256    bool isSub = state->mscOps.count(defOp);
2257    convOp = aievec::FMAConvOp::create(state->builder, loc, opType, lhs,
2258                                       concatRhs, acc, M, N, isSub);
2261  Op->replaceAllUsesWith(convOp);
// Walk the function and fuse eligible aievec FMA ops (int16 case) into
// mul_conv/fma_conv ops. Code kept byte-identical (closing braces lost in
// extraction).
2266static void fuseMulFMAOpsByMulFMAConv(func::FuncOp func, VectState *state) {
2267  func.walk([&](Operation *Op) {
2268    if (isa<aievec::aie1::FMAOp>(Op) && canFuseMulFMAOpsForInt16(Op))
2269      fuseMulFMAOpsForInt16(Op, state);
// Lower every vector mul/fma in the function to the AIE dialect. Code kept
// byte-identical (closing braces lost in extraction).
2279static void generateAIEMulOrFMAOpsInFunc(func::FuncOp func, VectState *state) {
2282  func.walk([&](Operation *op) {
2283    if (isa<MulIOp, MulFOp, vector::FMAOp>(op))
2284      generateSchemeBasedMulOrFMAOp(op, state);
2290static void generateAddOrSubOp(Operation *Op, AIEOpAttributes &opAttr,
2294 state->builder.setInsertionPointAfter(Op);
2297 Operation *repOp =
nullptr;
2298 if (isa<SubIOp, SubFOp>(Op)) {
2299 repOp = generateSubOp(Op, opAttr, state);
2300 LLVM_DEBUG(llvm::dbgs() <<
"\n\nGenerated AIE dialect sub op " << *repOp);
2302 repOp = generateAddOp(Op, opAttr, state);
2303 LLVM_DEBUG(llvm::dbgs() <<
"\n\nGenerated AIE dialect sub op " << *repOp);
2308 Op->replaceAllUsesWith(repOp);
// Lower one vector add/sub op to the AIE dialect. Simple vector intrinsics
// get empty attribute strings; otherwise per-operand start/offset/square
// strings are computed for 32-bit and 16-bit element types (the "advanced"
// scheme is unsupported for int8). NOTE(review): extraction-mangled (loop
// bodies that append hex digits, and an early return, are on lost lines);
// code kept byte-identical.
2314static void generateSchemeBasedAddOrSubOp(Operation *Op, VectState *state) {
2317  AIEOpAttributes opAttr;
// Simple intrinsic path: no encoding needed.
2321  if (isSimpleVectIntrinsic(Op, state)) {
2323    for (
size_t idx = 0; idx < 2; ++idx) {
2324      opAttr.start.push_back(
"");
2325      opAttr.offset.push_back(
"");
2326      opAttr.offset_hi.push_back(
"");
2327      opAttr.square.push_back(
"");
2329    generateAddOrSubOp(Op, opAttr, state);
// Advanced path: compute attributes per operand.
2336  for (
size_t idx = 0; idx < 2; ++idx) {
2337    AIEVecAttributes stat = getOperandVecStats(Op, state, idx);
2338    assert(stat.elementSizeInBits >= 16 &&
2339           "advanced scheme for add op on int8 data type not supported");
2341    int32_t start = 0, accIncr = 1;
2342    std::string startStr;
2343    std::string offsetStr, offsetHiStr;
2344    std::string squareStr;
// Memory operands: start/step come from the transfer_read; splats have
// accIncr 0 so lanes do not advance with the loop.
2348    if (stat.loadFromMemory) {
2349      Operation *op = Op->getOperand(idx).getDefiningOp();
2350      auto readOp = cast<TransferReadOp>(op);
2352      accIncr = stat.isSplat ? 0 : computeVecorizedLoopStepSize(readOp, state);
// 32-bit elements: offsets built lane-by-lane (loop bodies lost in
// extraction); 16 lanes additionally fill offset_hi.
2358    if (stat.elementSizeInBits == 32) {
2359      startStr = std::to_string(start);
2361      for (
int i = 7; i >= 0; --i)
2364      if (stat.lanes > 8) {
2365        assert(stat.lanes == 16 &&
"Cannot generate offset for add/sub op");
2367        assert(accIncr <= 1 &&
"Cannot generate offset for given loop stride");
2369        for (
int i = 15; i >= 8; --i)
2372    }
else if (stat.elementSizeInBits == 16) {
2373      assert(accIncr <= 1 &&
"cannot generate offset for given loop stride");
// start is rounded down to even; the residue goes into the square.
2375      int32_t m2Offset = start % 2;
2376      startStr = std::to_string(start - m2Offset);
2380      offsetStr = offsetHiStr =
"0";
// Two lanes per hex pair, low then high half.
2383      for (
int i = 6; i >= 0; i -= 2) {
2384        offsetStr.push_back(
'0');
2385        offsetStr.push_back(
getHexValue((i * accIncr) / 2));
2388      for (
int i = 14; i >= 8; i -= 2) {
2389        offsetHiStr.push_back(
'0');
2390        offsetHiStr.push_back(
getHexValue((i * accIncr) / 2));
// Odd start only encodable when the access does not advance (splat).
2395      if (m2Offset == 0 && accIncr == 0)
2398      assert(m2Offset == 0 || accIncr == 0);
2400      int32_t astep = std::min(1, accIncr);
2401      SmallVector<int32_t> sqPattern = {3 * astep, 2 * astep, astep, 0};
2402      for (
auto sq : sqPattern)
2406      llvm_unreachable(
"Cannot generate advanced add op for given datatype");
2409    opAttr.start.push_back(startStr);
2410    opAttr.offset.push_back(offsetStr);
2411    opAttr.offset_hi.push_back(offsetHiStr);
2412    opAttr.square.push_back(squareStr);
2415  generateAddOrSubOp(Op, opAttr, state);
// Lower every remaining vector add/sub in the function to the AIE dialect.
// Code kept byte-identical (closing braces lost in extraction).
2421static void generateAIEAddOrSubOpsInFunc(func::FuncOp func, VectState *state) {
2422  func.walk([&](Operation *op) {
2423    if (isa<AddIOp, AddFOp, SubIOp, SubFOp>(op))
2424      generateSchemeBasedAddOrSubOp(op, state);
// Recursively replace every transfer_read in this affine.for nest with an
// aievec.upd op. The memToUpdMap keys on (reuse interval, interval bounds)
// so reads covered by the same interval share one UPD. Code kept
// byte-identical (extraction-mangled).
2432static void insertUPDOpsInLoop(affine::AffineForOp forOp, VectState *state) {
// Depth-first: handle inner loops before this loop's own body.
2434  for (affine::AffineForOp nestedOp :
2435       forOp.getRegion().getOps<affine::AffineForOp>())
2436    insertUPDOpsInLoop(nestedOp, state);
// Map from (interval, lo, hi) to the UPD op (and its index) already
// generated for that access extent.
2442  mlir::DenseMap<std::tuple<IntervalReuse *, int32_t, int32_t>,
2443                 std::pair<aievec::UPDOp, int8_t>>
2448  mlir::DenseMap<Operation *, aievec::UPDOp> readOpToUpdMap;
2450  Region &region = forOp.getRegion();
2451  for (TransferReadOp readOp : region.getOps<TransferReadOp>()) {
2452    aievec::UPDOp updOp = generateUPDOp(readOp, memToUpdMap, region, state);
2453    readOpToUpdMap[readOp] = updOp;
// Second pass: swap uses after all UPDs exist, so iteration is stable.
2457  for (
auto &map : readOpToUpdMap) {
2458    Operation *op = map.first;
2459    op->replaceAllUsesWith(map.second);
// Insert UPD ops in every top-level loop nest of the function. Code kept
// byte-identical (closing braces lost in extraction).
2465static void insertUPDOpsInFunc(func::FuncOp func, VectState *state) {
2466  for (affine::AffineForOp forOp : func.getOps<affine::AffineForOp>()) {
2467    insertUPDOpsInLoop(forOp, state);
// For an op whose result lives in an accumulator, insert shift-round-
// saturate (SRS) ops so that non-AIE users receive a plain vector value.
// One SRS per distinct scalar type is cached in typeToSRSOpMap. On AIE-ML,
// an 8-bit accumulator whose width already matches the destination element
// is handled with a cheaper aievec.cast instead. NOTE(review): extraction-
// mangled (some declarations and continues on lost lines); code kept
// byte-identical.
2474static void insertSRSOp(Operation *Op, VectState *state) {
// Nothing to do for dead ops or ops with no result.
2476  if (Op->use_empty() || Op->getNumResults() == 0)
2480  assert(writesToAccumulator(Op));
// Only needed if some user is outside the AIE dialect.
2485  auto isNonAIEOp = [&](Operation *op) {
return !
isAIEOp(op); };
2486  if (!any_of(Op->getUsers(), isNonAIEOp))
// Cache: one SRS op per target scalar type.
2491  mlir::DenseMap<Type, aievec::SRSOp> typeToSRSOpMap;
2494  state->builder.setInsertionPointAfter(Op);
2498  for (
auto user : Op->getUsers()) {
// Target scalar type: the written memref's element type for stores,
// otherwise the user's first result element type.
2506    MemRefType memRefType =
nullptr;
2507    if (
auto writeOp = dyn_cast<TransferWriteOp>(user)) {
2509      memRefType = cast<MemRefType>(writeOp.getBase().getType());
2510      scalarType = memRefType.getElementType();
2512      scalarType = getElementTypeOrSelf(*user->getResultTypes().begin());
2513    assert(scalarType &&
"failed to form SRS op");
// Replace only the operands of `user` that come from Op.
2516    for (
auto operand : user->getOperands()) {
2517      if (operand.getDefiningOp() == Op) {
// AIE-ML fast path: 8-bit element stores where the accumulator width
// already equals the destination width only need a cast, not an SRS.
2520        if (state->aieml && memRefType &&
2521            cast<VectorType>(Op->getOperand(0).getType())
2523                    .getIntOrFloatBitWidth() == 8 &&
2524            cast<VectorType>(Op->getResult(0).getType())
2526                    .getIntOrFloatBitWidth() ==
2527                scalarType.getIntOrFloatBitWidth()) {
2531          aievec::CastOp castOp = generateCastOp(Op->getResult(0), castType,
2532                                                 false, state, Op->getLoc());
2533          assert(castOp &&
"Failed to create Cast intrinsic");
2534          user->replaceUsesOfWith(operand, castOp);
// Normal path: reuse or create the SRS for this scalar type.
2537        aievec::SRSOp srsOp;
2538        if (!typeToSRSOpMap.count(scalarType)) {
2540              generateSRSOp(Op->getResult(0), scalarType, state, Op->getLoc());
2541          LLVM_DEBUG(llvm::dbgs() <<
"\n\nCreated SRS op " << srsOp
2542                                  <<
" for the acc output of operation " << Op);
2543          typeToSRSOpMap[scalarType] = srsOp;
2545          srsOp = typeToSRSOpMap[scalarType];
2546        assert(srsOp &&
"Failed to create SRS intrinsic");
2548        user->replaceUsesOfWith(operand, srsOp);
// Insert SRS ops after every accumulator-writing op in the function. Code
// kept byte-identical (closing braces lost in extraction).
2556static void insertSRSOpsInFunc(func::FuncOp func, VectState *state) {
2557  func.walk([&](Operation *op) {
2559    if (writesToAccumulator(op))
2560      insertSRSOp(op, state);
2567template <
typename TransferOp>
2568static void setInBounds(TransferOp op) {
2569 if (op.getTransferRank() == 0)
2571 SmallVector<bool, 4> bools(op.getTransferRank(),
true);
2572 OpBuilder b(op.getContext());
2573 op->setAttr(op.getInBoundsAttrName(), b.getBoolArrayAttr(bools));
// Per function: force all transfer_read/transfer_write ops to be in-bounds,
// then run MLIR's store-to-load forwarding / dead-store elimination
// (transferOpflowOpt). Code kept byte-identical (extraction-mangled).
2585static void redundantLoadStoreOptimization(ModuleOp module) {
2586  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
// Annotate in-bounds so the flow optimization can match accesses.
2588    func.walk([&](Operation *Op) {
2589      if (
auto readOp = dyn_cast<TransferReadOp>(Op)) {
2590        if (!readOp.getInBounds())
2591          setInBounds<TransferReadOp>(readOp);
2592      }
else if (
auto writeOp = dyn_cast<TransferWriteOp>(Op)) {
2593        if (!writeOp.getInBounds())
2594          setInBounds<TransferWriteOp>(writeOp);
// Remove redundant load/store pairs within the function.
2599    IRRewriter rewriter(module.getContext());
2600    vector::transferOpflowOpt(rewriter, func);
// Canonicalize the module and clean up redundant vector loads/stores before
// vectorization analysis begins. Code kept byte-identical (extraction-
// mangled; the assert on `success` is on a lost line).
2606static void preCanonicalizeIR(ModuleOp module) {
2607  PassManager pm(module.getContext());
2608  pm.addPass(createCanonicalizerPass());
// Run result only consulted in debug builds (via the lost assert).
2609  [[maybe_unused]]
bool success = pm.run(module).succeeded();
2611  redundantLoadStoreOptimization(module);
// Final cleanup after vectorization: canonicalize, CSE, hoist loop
// invariants, and lower leftover affine constructs to SCF/std. Code kept
// byte-identical (extraction-mangled).
2617static void postCanonicalizeIR(ModuleOp module) {
2618  PassManager pm(module.getContext());
2619  pm.addPass(createCanonicalizerPass());
2620  pm.addPass(createCSEPass());
2621  pm.addPass(createLoopInvariantCodeMotionPass());
2622  pm.addPass(createLowerAffinePass());
// Run result only consulted in debug builds.
2623  [[maybe_unused]]
bool success = pm.run(module).succeeded();
// Recursively record, for each block containing a transfer_read, the stack
// of affine.for loops enclosing it (outermost first) into
// state->blockToEnclosingLoops. Code kept byte-identical (extraction-
// mangled; return type line lost).
2630computeEnclosingLoopsPerBlock(affine::AffineForOp forOp, VectState *state,
2631                              SmallVector<Operation *, 8> &enclosingLoops) {
// Push each nested loop, recurse, then pop — classic DFS loop-stack walk.
2633  for (affine::AffineForOp nestedOp :
2634       forOp.getRegion().getOps<affine::AffineForOp>()) {
2635    enclosingLoops.push_back(nestedOp);
2636    computeEnclosingLoopsPerBlock(nestedOp, state, enclosingLoops);
2637    enclosingLoops.pop_back();
// Record the current loop stack for every read at this nesting level.
2642  for (TransferReadOp readOp : forOp.getRegion().getOps<TransferReadOp>()) {
2644    Block *block = readOp->getBlock();
2645    state->blockToEnclosingLoops[block] = enclosingLoops;
// Reassociate well-formed vector mul/fma ops so splats and same-size
// vectors end up in the operand positions the AIE schemes expect. Code kept
// byte-identical (closing braces lost in extraction).
2654static void reassociateMulOpInFunc(func::FuncOp func, VectState *state) {
2655  func.walk([&](Operation *op) {
2658    if (isa<MulIOp, MulFOp, vector::FMAOp>(op) && isWellFormedVectorOp(op)) {
2660      reassociateMulOpWithSplat(op, state);
2663      reassociateMulOpBasedOnVecSize(op, state);
// For each well-formed vector add, ensure that when exactly one operand is
// produced by a mul, the mul ends up as the RHS (operand 1) — that is the
// shape the later mul+add -> FMA fusion expects. Operands are looked up
// through recorded sext/trunc producers. Code kept byte-identical
// (extraction-mangled; `Value left/right =` declarations on lost lines).
2673static void reassociateAddOpInFunc(func::FuncOp func, VectState *state) {
2674  func.walk([&](Operation *op) {
2676    if (isa<AddIOp, AddFOp>(op) && isWellFormedVectorOp(op)) {
2678      assert(op->getNumOperands() == 2 && op->getNumResults() == 1);
2681      Operation *rhsOp = getOperandDefOp(state, op, 1);
// `left` = operand 0 looked through sext/trunc.
2683          state->sextTruncDefMap.count(op->getOperand(0).getDefiningOp())
2684              ? op->getOperand(0).getDefiningOp()->getOperand(0)
2685              : op->getOperand(0);
// `right` = operand 1 looked through sext/trunc.
2687          state->sextTruncDefMap.count(op->getOperand(1).getDefiningOp())
2688              ? op->getOperand(1).getDefiningOp()->getOperand(0)
2689              : op->getOperand(1);
// Only swap when the mul is on the LHS but not the RHS.
2691      if (!isa<MulIOp, MulFOp>(rhsOp)) {
2692        Operation *lhsOp = getOperandDefOp(state, op, 0);
2694        if (isa<MulIOp, MulFOp>(lhsOp)) {
2695          LLVM_DEBUG(llvm::dbgs() <<
"\n\nReassociating addOp " << *op
2696                                  <<
" to place mul as rhs operand");
2697          op->setOperand(0, right);
2698          op->setOperand(1, left);
2699          LLVM_DEBUG(llvm::dbgs() <<
"\n\taddOp after reassociation: " << *op);
// No swap needed: just strip the sext/trunc wrappers in place.
2702          op->setOperand(0, left);
2703          op->setOperand(1, right);
// Mark transfer_reads whose every user consumes them as the LHS of a
// mul/fma, then coalesce the reuse intervals so those LHS vectors can be
// loaded with wider vectors. Code kept byte-identical (extraction-mangled;
// the markLHSOperandVec call and break are on lost lines).
2716static void coalesceLHSOpVectorsInFunc(func::FuncOp func, VectState *state) {
2718  func.walk([&](TransferReadOp op) {
// onlyLHS stays true iff every user is a mul/fma taking op as operand 0.
2721    bool onlyLHS =
true;
2722    for (
auto user : op->getUsers()) {
2723      if (!isa<MulIOp, MulFOp, vector::FMAOp>(user) ||
2724          user->getOperand(0).getDefiningOp() != op) {
// With LHS-only reads marked, widen the affected intervals.
2740  for (
auto interval : state->reuseIntervals) {
2741    interval->coalesceIntervals();
// Record, for every arith.extsi and arith.trunci in the function, the op
// producing its input. Later passes use this map to "look through" the
// widening/narrowing when matching operands. Code kept byte-identical
// (closing braces lost in extraction).
2746static void recordSextOps(func::FuncOp func, VectState *state) {
2747  func.walk([&](ExtSIOp op) {
2748    state->sextTruncDefMap[op] = op->getOperand(0).getDefiningOp();
2750  func.walk([&](TruncIOp op) {
2751    state->sextTruncDefMap[op] = op->getOperand(0).getDefiningOp();
// Place a transfer_read into an IntervalReuse object: linearize its access
// expression, split it into (base, offset), compute the vectorized loop
// step, derive the minimum vector size its users require, and either join
// an existing interval with potential reuse or start a new one.
// NOTE(review): extraction-mangled (minVecSize updates inside the user loop
// and the `new IntervalReuse` are on lost lines); code kept byte-identical.
2757static void computeReuse(TransferReadOp readOp, VectState *state) {
2759  AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state);
2761  auto [base, offset] = getBaseAndOffset(linearAccess);
// Loop step of the innermost vectorized loop for this access.
2764  int32_t step = computeVecorizedLoopStepSize(readOp, state);
// A constant permutation map means the read broadcasts one element.
2767  bool isSplat = readOp.getPermutationMap().isConstant();
// Default minimum vector size; shrunk/grown per user below (updates on
// lost lines).
2772  unsigned minVecSize = 128;
2773  for (
auto user : readOp->getUsers()) {
// Direct mul/fma users of the read.
2774    if (isa<MulIOp, MulFOp, vector::FMAOp>(user)) {
2775      if (user->getOperand(0).getDefiningOp() == readOp ||
2776          user->getOperand(1).getDefiningOp() == readOp) {
// Or mul/fma users reached through an extsi wrapper.
2781    if (isa<ExtSIOp>(user)) {
2782      auto extsiOp = cast<ExtSIOp>(user);
2783      for (
auto consumer : extsiOp->getUsers()) {
2784        if (isa<MulIOp, MulFOp, vector::FMAOp>(consumer)) {
2785          if ((state->sextTruncDefMap.count(
2786                   consumer->getOperand(0).getDefiningOp()) &&
2787               state->sextTruncDefMap[consumer->getOperand(0)
2788                                          .getDefiningOp()] == readOp) ||
2789              (state->sextTruncDefMap.count(
2790                   consumer->getOperand(1).getDefiningOp()) &&
2791               state->sextTruncDefMap[consumer->getOperand(1)
2792                                          .getDefiningOp()] == readOp)) {
2801  auto vecType = cast<VectorType>(readOp.getVector().getType());
// Try to place the read in an existing interval with the same base access
// within the same loop nest.
2814  for (
auto interval : state->reuseIntervals) {
2816    if (interval->potentialReuse(readOp, base, state->blockToEnclosingLoops)) {
2819      interval->insertInterval(readOp, state->opToIntervalMap, offset, step,
2820                               isSplat, minVecSize);
// No reuse found: start a fresh interval (allocation on a lost line).
2829  iv->
insertInterval(readOp, state->opToIntervalMap, offset, step, isSplat,
2831  state->reuseIntervals.push_back(iv);
// Check whether a transfer_read can be realized with aligned vector loads.
// Emits a diagnostic and fails when: the innermost loop's step is not a
// multiple of the vector lane count, a non-constant loop upper bound's
// offset is not lane-aligned, or any non-leading memref dimension is not a
// multiple of the lanes. NOTE(review): extraction-mangled (the lanes
// computation, early success returns, and some condition lines are lost);
// code kept byte-identical.
2835static LogicalResult isUnalignedLoad(TransferReadOp readOp, VectState *state) {
2836  auto vectorType = cast<VectorType>(readOp.getResult().getType());
2839  AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state);
// Loop-invariant (constant/symbolic) accesses are trivially aligned.
2840  if (linearAccess.isSymbolicOrConstant()) {
2844  auto memRefType = cast<MemRefType>(readOp.getBase().getType());
2845  MLIRContext *context = memRefType.getContext();
2846  ArrayRef<int64_t> sizes = memRefType.getShape();
2847  int numDims = sizes.size();
2849  auto block = readOp->getBlock();
2850  assert(state->blockToEnclosingLoops.count(block) &&
2851         "enclosing loops should have been computed for the read operation\n");
2852  auto enclosingLoops = state->blockToEnclosingLoops[block];
2854  SmallVector<Value, 4> indices(readOp.getIndices().begin(),
2855                                readOp.getIndices().end());
// Inspect the index used for the innermost (fastest-varying) dimension.
2860          dyn_cast<AffineDimExpr>(getAffineDimExpr(numDims - 1, context))) {
2861    auto index = indices[dimExpr.getPosition()];
2864    for (
auto loop : enclosingLoops) {
2865      auto affineForOp = cast<affine::AffineForOp>(loop);
2866      auto iv = affineForOp.getInductionVar();
2867      auto invariants = affine::getInvariantAccesses(iv, indices);
// Only loops whose IV actually drives this index matter.
2869      if (!invariants.count(index)) {
2870        int step = affineForOp.getStepAsInt();
// Step must be lane-aligned (the % comparison is on a lost line).
2872          return readOp->emitError()
2873                 <<
"Loop step of inner index of " << readOp->getName()
2874                 <<
" is not divisible by number of vector lanes.";
// Non-constant upper bounds: their affine offset must be lane-aligned
// too, or the last iteration would read a partial vector.
2881        affine::AffineBound ub = affineForOp.getUpperBound();
2882        AffineMap origUbMap = ub.getMap();
2883        if (!origUbMap.isEmpty() && !origUbMap.isConstant()) {
2884          AffineExpr origUbMapResult = origUbMap.getResult(0);
2887          std::tie(base, offset) = getBaseAndOffset(origUbMapResult);
2888          if (offset % lanes) {
2889            return readOp->emitError()
2890                   <<
"Loop upper bound's affine map offset of inner index of "
2891                   << readOp->getName()
2892                   <<
" is not divisible by number of vector lanes.";
// All trailing dims of the memref must be lane-aligned so row starts
// stay aligned; dynamic dims (-1) are skipped.
2901  for (
int i = 1; i < numDims; ++i) {
2903    if (sizes[i] == -1) {
2907    if (sizes[i] % lanes) {
2908      return readOp->emitError()
2909             << readOp->getName() <<
"'s shape size of index " << i
2910             <<
" is not divisible by number of vector lanes.";
2917static LogicalResult hasUnalignedLoads(func::FuncOp func, VectState *state) {
2918 WalkResult result = func.walk([&](TransferReadOp op) {
2919 if (failed(isUnalignedLoad(op, state))) {
2920 return WalkResult::interrupt();
2922 return WalkResult::advance();
2925 if (result.wasInterrupted()) {
2937static void computeReuseInFunc(func::FuncOp func, VectState *state) {
2940 func.walk([&](TransferReadOp op) { computeReuse(op, state); });
// Fuse mul+add/sub pairs into vector.fma ops wherever the pattern is legal
// (see canFuseMulAndAddOrSubIntoFMAOp). Code kept byte-identical (closing
// braces lost in extraction).
2945static void rewriteFMAOpsInFunc(func::FuncOp func, VectState *state) {
2947  func.walk([&](Operation *Op) {
2948    if (isa<AddIOp, AddFOp, SubIOp, SubFOp>(Op) && isWellFormedVectorOp(Op)) {
2951      if (canFuseMulAndAddOrSubIntoFMAOp(Op, state))
2952        fuseMulAndAddOrSubIntoFMAOp(Op, state);
2959static void reassociateOpsInFunc(func::FuncOp func, VectState *state) {
2963 reassociateMulOpInFunc(func, state);
2968 reassociateAddOpInFunc(func, state);
// Body of AIEVectorize::runOnOperation() — the pass driver. NOTE(review):
// the function signature and several lines (returns, the loop close) were
// lost in extraction; code kept byte-identical. Pipeline: validate pass
// options, pre-canonicalize, then per function: record sext/trunc, compute
// enclosing loops, check alignment, compute reuse, reassociate, fuse FMAs,
// generate AIE ops, insert SRS/UPD, fuse conv ops; finally post-canonicalize.
2983  assert(shiftParam < 64 &&
"SRS shift parameter should be between 0 and 63");
2984  assert(zeroOffset < 128 &&
2985         "Zero offset in the filter should be between 0 and 127");
2986  assert(dupFactor < 128 &&
2987         "Duplicate offset in the filter should be between 0 and 127");
2989  ModuleOp
module = getOperation();
2992  preCanonicalizeIR(module);
2995  for (func::FuncOp func :
module.getOps<func::FuncOp>()) {
// Pass options override the global cl::opt defaults when set.
2997    bool aieml = ::AIEML;
2998    bool unallignedCheck = ::unalignedLoadsCheck;
2999    if (this->unalignedLoadsCheck.hasValue())
3000      unallignedCheck = this->unalignedLoadsCheck;
3001    if (this->aieml.hasValue())
3002      aieml = this->aieml;
// NOTE(review): heap-allocated state is apparently never deleted — leaks
// once per function; consider std::unique_ptr. Confirm against upstream.
3003    auto *state =
new VectState(func.getContext(), shiftParam, zeroOffset,
3004                                dupFactor, unallignedCheck, aieml);
// Record sext/trunc producers before any matching relies on them.
3007    recordSextOps(func, state);
// Build the block -> enclosing-loops map used by reuse and alignment.
3011    for (
auto forOp : func.getOps<affine::AffineForOp>()) {
3012      SmallVector<Operation *, 8> enclosingLoops;
3013      enclosingLoops.push_back(forOp);
3014      computeEnclosingLoopsPerBlock(forOp, state, enclosingLoops);
// Bail out on this function if any load is misaligned.
3018    if (state->unalignedLoadsCheck && failed(hasUnalignedLoads(func, state))) {
3019      func.emitError() <<
"Cannot apply aie-vectorize to " << func->getName()
3020                       <<
" because alignment check has failed.\n";
3026    computeReuseInFunc(func, state);
3031    reassociateOpsInFunc(func, state);
3034    rewriteFMAOpsInFunc(func, state);
3037    coalesceLHSOpVectorsInFunc(func, state);
3040    fuseFMAOpsForColumnTopology(func, state);
3043    generateAIEMulOrFMAOpsInFunc(func, state);
3047    insertSRSOpsInFunc(func, state);
3052    generateAIEAddOrSubOpsInFunc(func, state);
3057    insertUPDOpsInFunc(func, state);
3061    fuseMulFMAOpsByMulFMAConv(func, state);
3066  postCanonicalizeIR(module);
3070 return std::make_unique<AIEVectorize>();
int32_t computeStartInAIEVec(Operation *op, VectState *state)
void insertInterval(mlir::vector::TransferReadOp readOp, llvm::DenseMap< mlir::Operation *, IntervalReuse * > &dataAccessToIntervalMap, int32_t offset, int32_t forLoopStepSize, bool isSplat=false, unsigned minVecSize=128)
void setAccessExtent(mlir::Operation *op, std::pair< int32_t, int32_t > &extent)
int32_t getIntervalWidth(mlir::Operation *op)
std::pair< int32_t, int32_t > getAccessExtent(mlir::Operation *op)
std::pair< int32_t, int32_t > getInterval(mlir::Operation *op)
void markLHSOperandVec(mlir::Operation *op)
std::shared_ptr< Value > value()
bool isPowerOfTwo(int32_t n)
int32_t getVectorSizeInBits(mlir::VectorType type)
unsigned getVectorLaneSize(mlir::VectorType type)
char getHexValue(int val)
std::unique_ptr< mlir::Pass > createAIEVectorizePass()
mlir::VectorType createVectorType(unsigned lanes, mlir::Type elementType)
int32_t getElementSizeInBits(mlir::VectorType type)
mlir::VectorType getVectorOpDestType(mlir::VectorType type, bool AIE2)
bool isAIEOp(mlir::Operation *op)
void runOnOperation() override
Generate AIE vector intrinsics for the current module.