#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallSet.h"
using namespace vector;

#define DEBUG_TYPE "aie-vect"
static llvm::cl::opt<bool>
    unalignedLoadsCheck("unaligned-loads-check",
                        llvm::cl::desc("Enable the unaligned loads check"),
                        llvm::cl::init(true));

static llvm::cl::opt<bool> AIEML("aieml", llvm::cl::desc("AI Engine-ML"),
                                 llvm::cl::init(false));
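// VectState carries the global state shared by all steps of the
// vectorization pass:
//  - reuseIntervals: the access-extent intervals used to detect vector reuse
//    across transfer_read ops.
//  - opToIntervalMap: maps each read op to the IntervalReuse object that
//    owns it.
//  - linearizedAccess: caches the linearized (1-D) affine access expression
//    of each transfer_read op.
//  - indexToExprDimMap: maps index values to the affine dim exprs used while
//    linearizing accesses.
//  - blockToEnclosingLoops: the affine for loops enclosing each block.
//  - pairedOp: links each i8xi8 mul/fma op to its paired op.
//  - opToColOffsets: the (x, z) column offsets of fused FMA column
//    topologies.
//  - sextTruncDefMap: maps arith.extsi/arith.trunci results to the op
//    defining their operand, so widened values can be traced to the source.
//  - mscOps: the fma ops that subtract (msc) instead of add.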
struct VectState {
  SmallVector<IntervalReuse *, 16> reuseIntervals;
  mlir::DenseMap<Operation *, IntervalReuse *> opToIntervalMap;
  mlir::DenseMap<Operation *, AffineExpr> linearizedAccess;
  mlir::DenseMap<Value, AffineExpr> indexToExprDimMap;
  mlir::DenseMap<Block *, SmallVector<Operation *, 8>> blockToEnclosingLoops;
  mlir::DenseMap<Operation *, Operation *> pairedOp;
  mlir::DenseMap<Operation *, std::pair<int32_t, int32_t>> opToColOffsets;
  mlir::DenseMap<Operation *, Operation *> sextTruncDefMap;
  llvm::SmallSet<Operation *, 8> mscOps;
  OpBuilder builder;
  int8_t shift;
  int32_t zeroOffset;
  int32_t dupFactor;
  bool unalignedLoadsCheck, aieml;

  VectState(MLIRContext *context, int8_t s, int32_t z, int32_t d,
            bool unalignedLoadsCheck, bool aieml)
      : builder(context), shift(s), zeroOffset(z), dupFactor(d),
        unalignedLoadsCheck(unalignedLoadsCheck), aieml(aieml) {}

  IntervalReuse *getIntervalForOperation(Operation *op);
};
IntervalReuse *VectState::getIntervalForOperation(Operation *op) {
  assert(opToIntervalMap.count(op) &&
         "could not find the IntervalReuse object for op");
  return opToIntervalMap[op];
}
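// Attribute strings for the operands of an AIE mul/fma/add/sub/select
// intrinsic. Entry 0 describes the xbuff operand and entry 1 the zbuff
// operand; select is only used by select ops.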
struct AIEOpAttributes {
  std::string select;
  SmallVector<std::string, 2> start;
  SmallVector<std::string, 2> offset, offset_hi;
  SmallVector<std::string, 2> step;
  SmallVector<std::string, 2> square;
};
struct AIEVecAttributes {
  unsigned lanes;
  int32_t vecSizeInBits;
  Type elementType;
  int32_t elementSizeInBits;
  bool loadFromMemory;
  bool isSplat;
  AIEVecAttributes(unsigned l, unsigned vs, Type et, int32_t es)
      : lanes(l), vecSizeInBits(vs), elementType(et), elementSizeInBits(es),
        loadFromMemory(false), isSplat(false) {}
};
struct Scheme {
  int32_t lanes, cols;
  int32_t xbits, zbits;
  Scheme(int32_t l, int32_t c, int32_t x, int32_t z)
      : lanes(l), cols(c), xbits(x), zbits(z) {}
};
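// Helpers to compute the AIEVecAttributes of the result and operands of an
// operation. Operand lookups go through sextTruncDefMap, so a value widened
// by arith.extsi (or narrowed by arith.trunci) is characterized by its
// original defining op.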
static AIEVecAttributes getVectorStats(VectorType type) {
  return AIEVecAttributes(getVectorLaneSize(type), getVectorSizeInBits(type),
                          type.getElementType(), getElementSizeInBits(type));
}

static AIEVecAttributes getResultVecStats(Operation *op, unsigned idx = 0) {
  auto vtype = cast<VectorType>(op->getResult(idx).getType());
  return getVectorStats(vtype);
}
static Operation *getOperandDefOp(VectState *state, Operation *op,
                                  unsigned idx) {
  return state->sextTruncDefMap.count(op->getOperand(idx).getDefiningOp())
             ? state->sextTruncDefMap[op->getOperand(idx).getDefiningOp()]
             : op->getOperand(idx).getDefiningOp();
}

static AIEVecAttributes getOperandVecStats(Operation *op, VectState *state,
                                           unsigned idx) {
  assert(op->getNumOperands() > idx);
  Operation *defOp = getOperandDefOp(state, op, idx);
  auto vtype = cast<VectorType>(defOp->getResult(0).getType());
  auto ret = getVectorStats(vtype);
  if (auto readOp = dyn_cast<TransferReadOp>(defOp)) {
    ret.loadFromMemory = true;
    ret.isSplat = readOp.getPermutationMap().isConstant();
  }
  return ret;
}
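// For a mul/fma op, compute the (lanes, columns) shape of the AIE intrinsic
// scheme. The total multiplier width depends on the operand element sizes:
// i8 x i8 gets 128 bits on AIE (256 on AIE-ML) and i16 x i8 gets 64, which is
// then divided across the output lanes to obtain the column count.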
static std::pair<int32_t, int32_t> getNumRowsAndCols(Operation *op,
                                                     VectState *state) {
  assert(op->getNumOperands() >= 2 && op->getNumResults() == 1);

  Operation *left = getOperandDefOp(state, op, 0);
  Operation *right = getOperandDefOp(state, op, 1);

  auto vtype = cast<VectorType>(op->getResult(0).getType());
  auto ltype = cast<VectorType>(left->getResult(0).getType());
  auto rtype = cast<VectorType>(right->getResult(0).getType());
  int32_t lsize = getElementSizeInBits(ltype);
  int32_t rsize = getElementSizeInBits(rtype);

  int32_t width = (lsize == 8 && rsize == 8)    ? (state->aieml ? 256 : 128)
                  : (lsize == 16 && rsize == 8) ? 64
                                                : 32;

  int32_t cols = width / (m * lanes);
  return std::make_pair(lanes, cols);
}
static void fuseAccessExtent(Operation *Op1, Operation *Op2, VectState *state) {
  bool expectedTypes =
      (isa<vector::FMAOp>(Op2) && isa<MulIOp, MulFOp, vector::FMAOp>(Op1));
  if (!expectedTypes) {
    printf("incorrect operation types\n");
    return;
  }

  for (int idx = 0; idx < 2; ++idx) {
    Operation *op1 = getOperandDefOp(state, Op1, idx);
    Operation *op2 = getOperandDefOp(state, Op2, idx);
    if (isa<TransferReadOp>(op1) && isa<TransferReadOp>(op2)) {
      std::make_pair(std::min(op1Extent.first, op2Extent.first),
                     std::max(op1Extent.second, op2Extent.second));
    }
  }
}
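// Returns true if this mul/fma/add/sub op maps to a "simple" AIE intrinsic,
// i.e., one that needs no start/offset/step/square attributes: all operands
// and the result must have the same vector size and element type, neither
// operand may be a splat, and mul/fma ops must not be floating point.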
static bool isSimpleVectIntrinsic(Operation *Op, VectState *state) {
  bool isMulOrFMAOp = isa<MulIOp, MulFOp, vector::FMAOp>(Op);
  bool isSubOrAddOp = isa<SubIOp, SubFOp, AddIOp, AddFOp>(Op);
  if (!isMulOrFMAOp && !isSubOrAddOp)
    return false;

  AIEVecAttributes vstat = getResultVecStats(Op);
  AIEVecAttributes lstat = getOperandVecStats(Op, state, 0);
  AIEVecAttributes rstat = getOperandVecStats(Op, state, 1);

  bool sizeMatches = lstat.vecSizeInBits == rstat.vecSizeInBits &&
                     vstat.vecSizeInBits == rstat.vecSizeInBits &&
                     lstat.elementType == rstat.elementType &&
                     vstat.elementType == rstat.elementType;
  bool noSplat = !lstat.isSplat && !rstat.isSplat;
  bool noFloat = !isa<FloatType>(vstat.elementType) &&
                 !isa<FloatType>(lstat.elementType) &&
                 !isa<FloatType>(rstat.elementType);

  return sizeMatches && noSplat && (isSubOrAddOp || noFloat);
}
static bool isWellFormedVectorOp(Operation *Op) {
  if (Op->getNumOperands() == 0 && Op->getNumResults() == 0)
    return false;

  SmallVector<Value, 8> operandsAndResults;
  operandsAndResults.append(Op->operand_begin(), Op->operand_end());
  operandsAndResults.append(Op->result_begin(), Op->result_end());

  for (auto val : operandsAndResults) {
    if (!isa<VectorType>(val.getType()))
      return false;
  }

  auto refType = cast<VectorType>(operandsAndResults.back().getType());
  Type scalarType = refType.getElementType();
  for (auto val : operandsAndResults) {
    auto vtype = cast<VectorType>(val.getType());
    if (scalarType != vtype.getElementType())
      return false;
  }
  return true;
}
static bool writesToAccumulator(Operation *op) {
  if (!isAIEOp(op))
    return false;
  if (auto mulOp = dyn_cast<aievec::aie1::MulOp>(op))
    return isa<IntegerType>(
        cast<VectorType>(mulOp.getResult().getType()).getElementType());
  if (auto fmaOp = dyn_cast<aievec::aie1::FMAOp>(op))
    return isa<IntegerType>(
        cast<VectorType>(fmaOp.getResult().getType()).getElementType());

  return isa<aievec::FMAElemOp, aievec::MulElemOp, aievec::FMAConvOp,
             aievec::MulConvOp, aievec::UPSOp>(op);
}
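// Linearize a multidimensional access into a single affine expression, given
// the memref sizes and one index expression per dimension. Each dimension's
// stride is the product of all faster-varying sizes: for sizes [M, N] and
// exprs [d0, d1], the result is d0 * N + d1. Once a dynamic size is seen, the
// remaining strides are unknown, so they become fresh affine symbols.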
static AffineExpr makeFlattenedStridedExpr(ArrayRef<int64_t> sizes,
                                           ArrayRef<AffineExpr> exprs,
                                           MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs, context);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return expr;
}
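// Build (and cache in state->linearizedAccess) the linearized affine access
// expression for a transfer_read op. Indices defined by affine.apply ops are
// expanded through their affine maps, constant indices become affine
// constants, and any other index value gets a fresh affine dimension.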
static AffineExpr constructLinearizedAffineExpr(TransferReadOp readOp,
                                                VectState *state) {
  if (state->linearizedAccess.count(readOp))
    return state->linearizedAccess[readOp];

  SmallVector<Value, 4> indices(readOp.getIndices().begin(),
                                readOp.getIndices().end());
  auto memRefType = cast<MemRefType>(readOp.getSource().getType());
  MLIRContext *context = memRefType.getContext();

  SmallVector<AffineExpr, 8> exprVec;
  for (auto idxAndValue : llvm::enumerate(indices)) {
    auto value = idxAndValue.value();
    if (auto apOf = value.getDefiningOp<affine::AffineApplyOp>()) {
      AffineMap map = apOf.getAffineMap();
      assert(map.getNumResults() == 1 &&
             "Failed to create linearized affineExpr for complicated index");
      SmallVector<AffineExpr, 4> indexExprs;
      for (auto index : apOf.getMapOperands()) {
        if (auto cIdx = index.getDefiningOp<arith::ConstantOp>()) {
          auto idxVal = cast<IntegerAttr>(cIdx.getValue()).getValue();
          unsigned idx = idxVal.getSExtValue();
          indexExprs.push_back(getAffineConstantExpr(idx, context));
        } else {
          if (!state->indexToExprDimMap.count(index))
            state->indexToExprDimMap[index] =
                getAffineDimExpr(state->indexToExprDimMap.size(), context);
          indexExprs.push_back(state->indexToExprDimMap[index]);
        }
      }
      exprVec.push_back(map.getResult(0).replaceDims(indexExprs));
    } else if (auto cOp = value.getDefiningOp<arith::ConstantOp>()) {
      auto idxVal = cast<IntegerAttr>(cOp.getValue()).getValue();
      unsigned idx = idxVal.getSExtValue();
      exprVec.push_back(getAffineConstantExpr(idx, context));
    } else {
      if (!state->indexToExprDimMap.count(value))
        state->indexToExprDimMap[value] =
            getAffineDimExpr(state->indexToExprDimMap.size(), context);
      exprVec.push_back(state->indexToExprDimMap[value]);
    }
  }

  assert(!exprVec.empty() && "Could not construct linearized affineExpr");

  auto ret = makeFlattenedStridedExpr(memRefType.getShape(), exprVec,
                                      memRefType.getContext());
  state->linearizedAccess[readOp] = ret;
  return ret;
}
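// Split an affine access expression into a base and a constant offset: for
// (d0 * 16 + d1 + 5), the base is d0 * 16 + d1 and the offset is 5; a purely
// constant expression yields a null base and the constant as offset.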
static std::pair<AffineExpr, int32_t> getBaseAndOffset(AffineExpr expr) {
  AffineExpr base = expr;
  int32_t offset = 0;

  if (auto constExpr = llvm::dyn_cast<AffineConstantExpr>(expr)) {
    base = nullptr;
    offset += constExpr.getValue();
  } else if (auto binopExpr = llvm::dyn_cast<AffineBinaryOpExpr>(expr)) {
    if (binopExpr.getKind() == AffineExprKind::Add) {
      AffineExpr lhs = binopExpr.getLHS(), rhs = binopExpr.getRHS();
      if (auto constExpr = llvm::dyn_cast<AffineConstantExpr>(lhs)) {
        base = rhs;
        offset += constExpr.getValue();
      }
      if (auto constExpr = llvm::dyn_cast<AffineConstantExpr>(rhs)) {
        base = base == rhs ? nullptr : lhs;
        offset += constExpr.getValue();
      }
    }
  }
  return std::make_pair(base, offset);
}
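// The generate*Op helpers below wrap the creation of individual AIE dialect
// ops (cast, srs, ups, broadcast, concat, select, ext, pack, add, sub, shift,
// shuffle). Each derives its result type from the source operands and the
// current VectState, and asserts that the op was actually created.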
static aievec::CastOp generateCastOp(Value source, VectorType resType,
                                     bool isResAcc, VectState *state,
                                     Location loc) {
  auto castOp =
      state->builder.create<aievec::CastOp>(loc, resType, source, isResAcc);
  assert(castOp && "could not create cast op");
  return castOp;
}
static aievec::SRSOp generateSRSOp(Value source, Type scalarType,
                                   VectState *state, Location loc) {
  Type accType = source.getType();
  assert(writesToAccumulator(source.getDefiningOp()) &&
         "srs source should write to accumulator");

  auto shiftParamOp = state->builder.create<arith::ConstantOp>(
      loc, state->builder.getI32IntegerAttr(state->shift));
  auto srsOp = state->builder.create<aievec::SRSOp>(loc, srsType, source,
                                                    shiftParamOp.getResult());
  assert(srsOp && "could not create srs op");
  return srsOp;
}
static aievec::UPSOp generateUPSOp(Value source, VectState *state,
                                   Location loc) {
  Type sourceType = source.getType();
  assert(!writesToAccumulator(source.getDefiningOp()) &&
         "ups source should not be accumulator");

  auto upsOp =
      state->builder.create<aievec::UPSOp>(loc, accType, source, state->shift);
  assert(upsOp && "could not create ups op");
  return upsOp;
}
static aievec::BroadcastOp generateBroadcastOp(Value source, int8_t idx,
                                               VectState *state, Location loc) {
  auto type = cast<VectorType>(source.getType());

  auto broadcastOp =
      state->builder.create<aievec::BroadcastOp>(loc, type, source, idx);
  assert(broadcastOp && "could not create broadcast op");
  return broadcastOp;
}
static aievec::ConcatOp generateConcatOp(SmallVector<Value> &sources,
                                         VectState *state, Location loc,
                                         VectorType concatType = nullptr) {
  assert(sources.size() > 1 && "must concat at least two vectors");

  auto vecType = cast<VectorType>(sources.back().getType());
  for (auto source : sources) {
    auto type = cast<VectorType>(source.getType());
    if (type != vecType) {
      printf("sources of concat op not of same type\n");
      return nullptr;
    }
  }

  Type scalarType = vecType.getElementType();

  auto concatOp =
      state->builder.create<aievec::ConcatOp>(loc, concatType, sources);
  assert(concatOp && "could not create concat op");
  return concatOp;
}
static aievec::aie1::SelectOp
generateSelectOp(Value xbuff, AIEOpAttributes &opAttr, unsigned lanes,
                 VectState *state, Location loc, Value ybuff = nullptr) {
  assert(!opAttr.select.empty());
  assert(opAttr.start.size() == opAttr.offset.size() &&
         opAttr.start.size() == 2);

  auto xtype = cast<VectorType>(xbuff.getType());

  auto selectOp = state->builder.create<aievec::aie1::SelectOp>(
      loc, resultType, xbuff, opAttr.select, opAttr.start[0], opAttr.offset[0],
      opAttr.offset_hi[0], opAttr.square[0], opAttr.start[1], opAttr.offset[1],
      opAttr.offset_hi[1], opAttr.square[1], ybuff);
  assert(selectOp && "could not create select op");
  return selectOp;
}
static aievec::aie1::ExtOp generateExtOp(Value source, unsigned lanes,
                                         int8_t idx, VectState *state,
                                         Location loc) {
  auto stype = cast<VectorType>(source.getType());

  auto extOp =
      state->builder.create<aievec::aie1::ExtOp>(loc, resultType, source, idx);
  assert(extOp && "could not create ext op");
  return extOp;
}
static aievec::PackOp generatePackOp(Value source, VectState *state,
                                     Location loc) {
  auto stype = cast<VectorType>(source.getType());
  Type i8Type = IntegerType::get(source.getContext(), 8);

  auto packOp = state->builder.create<aievec::PackOp>(loc, resultType, source);
  assert(packOp && "could not create pack op");
  return packOp;
}
static aievec::aie1::AddOp generateAddOp(Operation *Op, AIEOpAttributes &opAttr,
                                         VectState *state) {
  assert(opAttr.start.size() == opAttr.offset.size() &&
         opAttr.start.size() == 2);

  auto addOp = state->builder.create<aievec::aie1::AddOp>(
      Op->getLoc(), Op->getResult(0).getType(), Op->getOperand(0),
      Op->getOperand(1), opAttr.start[0], opAttr.offset[0], opAttr.offset_hi[0],
      opAttr.square[0], opAttr.start[1], opAttr.offset[1], opAttr.offset_hi[1],
      opAttr.square[1]);
  return addOp;
}

static aievec::aie1::SubOp generateSubOp(Operation *Op, AIEOpAttributes &opAttr,
                                         VectState *state) {
  assert(opAttr.start.size() == opAttr.offset.size() &&
         opAttr.start.size() == 2);

  auto subOp = state->builder.create<aievec::aie1::SubOp>(
      Op->getLoc(), Op->getResult(0).getType(), Op->getOperand(0),
      Op->getOperand(1), opAttr.start[0], opAttr.offset[0], opAttr.offset_hi[0],
      opAttr.square[0], opAttr.start[1], opAttr.offset[1], opAttr.offset_hi[1],
      opAttr.square[1]);
  return subOp;
}
static aievec::ShiftOp generateShiftOp(Value lhs, Value rhs, int32_t shiftBytes,
                                       VectState *state, Location loc,
                                       VectorType resType = nullptr) {
  auto vecType = cast<VectorType>(rhs.getType());
  auto type = cast<VectorType>(lhs.getType());
  if (type != vecType) {
    printf("lhs and rhs do not have same type\n");
    return nullptr;
  }

  Type scalarType = vecType.getElementType();

  auto constOp = state->builder.create<arith::ConstantOp>(
      loc, state->builder.getI32IntegerAttr(shiftBytes));
  auto shiftOp = state->builder.create<aievec::ShiftOp>(loc, resType, lhs, rhs,
                                                        constOp.getResult());
  return shiftOp;
}
static aievec::LegacyShuffleOp generateShuffleOp(Value source, VectState *state,
                                                 Location loc, unsigned mode,
                                                 VectorType resType = nullptr) {
  auto vecType = cast<VectorType>(source.getType());
  Type scalarType = vecType.getElementType();

  auto shuffleOp = state->builder.create<aievec::LegacyShuffleOp>(loc, resType,
                                                                  source, mode);
  return shuffleOp;
}
static Operation *generateMulOrFMAConvOpForInt8(Operation *Op,
                                                AIEOpAttributes &opAttr,
                                                VectState *state) {
  assert(opAttr.start.size() == opAttr.offset.size() &&
         opAttr.start.size() == 2 && state->dupFactor == 2);

  Value lhs = state->sextTruncDefMap.count(Op->getOperand(1).getDefiningOp())
                  ? Op->getOperand(1).getDefiningOp()->getOperand(0)
                  : Op->getOperand(1);
  Value rhs = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
                  ? Op->getOperand(0).getDefiningOp()->getOperand(0)
                  : Op->getOperand(0);
  auto vType = cast<VectorType>(lhs.getType());
  Type stype = vType.getElementType();
  auto itype = cast<IntegerType>(stype);
  unsigned width = itype.getWidth() <= 8 ? 32 : 64;

  Type ctype = IntegerType::get(itype.getContext(), width);
  Type opType = VectorType::get(vType.getShape(), ctype);
  auto defOp = rhs.getDefiningOp();
  state->builder.setInsertionPointAfter(defOp);
  Location loc = defOp->getLoc();

  Operation *shuffleOp = generateShuffleOp(defOp->getResult(0), state, loc, 0);

  if (shiftBytes) {
    state->builder.setInsertionPointAfter(shuffleOp);
    loc = shuffleOp->getLoc();
    rhs = generateShiftOp(shuffleOp->getResult(0), shuffleOp->getResult(0),
                          shiftBytes, state, loc);
  } else {
    rhs = shuffleOp->getResult(0);
  }

  state->builder.setInsertionPoint(Op);

  Operation *convOp = nullptr;
  if (isa<MulIOp>(Op)) {
    convOp =
        state->builder.create<aievec::MulConvOp>(loc, opType, lhs, rhs, M, N);
  }
  if (isa<vector::FMAOp>(Op)) {
    Value acc = Op->getOperand(2);
    bool isSub = state->mscOps.count(Op);
    convOp = state->builder.create<aievec::FMAConvOp>(loc, opType, lhs, rhs,
                                                      acc, M, N, isSub);
  }
  return convOp;
}
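// Lower a vector.fma op to the AIE dialect. On AIE-ML this becomes an
// fma_elem op, with a splat operand turned into a broadcast; on AIE1 it
// becomes an aie1 fma op carrying the start/offset/step/square attributes,
// with the accumulator moved through a ups op when needed and a 256-bit lhs
// doubled up via concat. The branch structure below is partially
// reconstructed.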
static Operation *generateFMAOp(vector::FMAOp fmaOp, AIEOpAttributes &opAttr,
                                VectState *state, bool i8xi8_pairedOp = false) {
  assert(opAttr.start.size() == opAttr.offset.size() &&
         opAttr.start.size() == 2);

  Value lhs = state->sextTruncDefMap.count(fmaOp.getLhs().getDefiningOp())
                  ? fmaOp.getLhs().getDefiningOp()->getOperand(0)
                  : fmaOp.getLhs();
  Value rhs = state->sextTruncDefMap.count(fmaOp.getRhs().getDefiningOp())
                  ? fmaOp.getRhs().getDefiningOp()->getOperand(0)
                  : fmaOp.getRhs();
  Value acc = state->sextTruncDefMap.count(fmaOp.getAcc().getDefiningOp())
                  ? fmaOp.getAcc().getDefiningOp()->getOperand(0)
                  : fmaOp.getAcc();

  bool isSub = state->mscOps.count(fmaOp);
  bool isInt = isa<IntegerType>(
      cast<VectorType>(fmaOp.getLhs().getType()).getElementType());

  Operation *xfmaOp = nullptr;
  if (state->aieml) {
    if (!writesToAccumulator(acc.getDefiningOp())) {
      acc = generateUPSOp(acc, state, fmaOp->getLoc());
      LLVM_DEBUG(llvm::dbgs()
                 << "\n\nCreated UPS op " << acc << " to move the output of "
                 << fmaOp << " into accumulator");
    }
    if (!isSimpleVectIntrinsic(fmaOp, state)) {
      AIEVecAttributes rstat = getOperandVecStats(fmaOp, state, 1);
      if (rstat.isSplat)
        rhs = generateBroadcastOp(rhs, stoi(opAttr.start[1]), state,
                                  fmaOp->getLoc());
    }
    xfmaOp = state->builder.create<aievec::FMAElemOp>(fmaOp->getLoc(), lhs, rhs,
                                                      acc, isSub);
  } else {
    if (i8xi8_pairedOp) {
      Operation *defOp = acc.getDefiningOp();
      if (state->pairedOp.count(defOp))
        acc = state->pairedOp[defOp]->getResult(0);
    }
    if (isInt && !writesToAccumulator(acc.getDefiningOp())) {
      acc = generateUPSOp(acc, state, fmaOp->getLoc());
      LLVM_DEBUG(llvm::dbgs()
                 << "\n\nCreated UPS op " << acc << " to move the output of "
                 << fmaOp << " into accumulator");
    }
    if (!isSimpleVectIntrinsic(fmaOp, state)) {
      AIEVecAttributes lstat = getOperandVecStats(fmaOp, state, 0);
      assert(lstat.vecSizeInBits % 256 == 0);
      if (lstat.vecSizeInBits == 256) {
        VectorType concatType =
            createVectorType(2 * lstat.lanes, lstat.elementType);
        SmallVector<Value> sources = {lhs, lhs};
        lhs = generateConcatOp(sources, state, fmaOp->getLoc(), concatType);
      }
    }
    xfmaOp = state->builder.create<aievec::aie1::FMAOp>(
        fmaOp->getLoc(), lhs, rhs, acc, opAttr.start[0], opAttr.offset[0],
        opAttr.offset_hi[0], opAttr.step[0], opAttr.square[0], opAttr.start[1],
        opAttr.offset[1], opAttr.offset_hi[1], opAttr.step[1], opAttr.square[1],
        isSub);
  }
  assert(xfmaOp && "could not create fma op");
  return xfmaOp;
}
template <typename T>
static Operation *generateMulOp(T mulOp, AIEOpAttributes &opAttr,
                                VectState *state) {
  assert(opAttr.start.size() == opAttr.offset.size() &&
         opAttr.start.size() == 2);

  Value lhs = state->sextTruncDefMap.count(mulOp.getLhs().getDefiningOp())
                  ? mulOp.getLhs().getDefiningOp()->getOperand(0)
                  : mulOp.getLhs();
  Value rhs = state->sextTruncDefMap.count(mulOp.getRhs().getDefiningOp())
                  ? mulOp.getRhs().getDefiningOp()->getOperand(0)
                  : mulOp.getRhs();

  if (!isSimpleVectIntrinsic(mulOp, state)) {
    AIEVecAttributes lstat = getOperandVecStats(mulOp, state, 0);
    assert(lstat.vecSizeInBits % 256 == 0);
    if (lstat.vecSizeInBits == 256) {
      VectorType concatType =
          createVectorType(2 * lstat.lanes, lstat.elementType);
      SmallVector<Value> sources = {lhs, lhs};
      lhs = generateConcatOp(sources, state, mulOp->getLoc(), concatType);
    }
  }

  Operation *xmulOp = state->builder.create<aievec::aie1::MulOp>(
      mulOp->getLoc(), lhs, rhs, opType, opAttr.start[0], opAttr.offset[0],
      opAttr.offset_hi[0], opAttr.step[0], opAttr.square[0], opAttr.start[1],
      opAttr.offset[1], opAttr.offset_hi[1], opAttr.step[1], opAttr.square[1]);
  assert(xmulOp && "could not create mul op");
  return xmulOp;
}
static aievec::UPDOp
generateUPDOp(TransferReadOp readOp,
              mlir::DenseMap<std::tuple<IntervalReuse *, int32_t, int32_t>,
                             std::pair<aievec::UPDOp, int8_t>> &memToUpdMap,
              Region &region, VectState *state) {
  IntervalReuse *iv = state->getIntervalForOperation(readOp);
  auto interval = iv->getInterval(readOp);
  int32_t intervalWidth = interval.second - interval.first;
  assert(intervalWidth >= 128 && "Interval computation incorrect");

  auto vecType = cast<VectorType>(readOp.getVector().getType());
  Type elementType = vecType.getElementType();
  int32_t elementSizeInBits = getElementSizeInBits(vecType);
  int intervalWidthInBytes = intervalWidth / elementSizeInBits;

  auto extent = iv->getAccessExtent(readOp);
  int32_t mid = interval.first + intervalWidth / 2;
  int32_t lb =
      intervalWidth <= (state->aieml && elementSizeInBits == 8 ? 512 : 256) ||
              extent.first < mid
          ? interval.first
          : mid;
  int32_t ub =
      intervalWidth <= (state->aieml && elementSizeInBits == 8 ? 512 : 256) ||
              extent.second > mid
          ? interval.second
          : mid;

  aievec::UPDOp updOp = nullptr;
  int8_t updIndices = 0;
  auto key = std::make_tuple(iv, interval.first, interval.second);
  if (memToUpdMap.count(key)) {
    updOp = memToUpdMap[key].first;
    updIndices = memToUpdMap[key].second;
  }

  SmallVector<Value, 4> indices(readOp.getIndices().begin(),
                                readOp.getIndices().end());
  AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state);
  auto [base, offset] = getBaseAndOffset(linearAccess);
  offset *= elementSizeInBits;

  bool singleBlock = region.getBlocks().size() == 1;
  if (singleBlock)
    state->builder.setInsertionPoint(readOp);
  else
    state->builder.setInsertionPointToStart(&region.front());

  int width = state->aieml ? elementSizeInBits == 8 ? 512 : 256 : 256;
  int32_t incr = std::max(width, intervalWidth / 2);
  int8_t idx = 1;
  for (int32_t start = interval.first; start < interval.second;
       start += incr, ++idx) {
    assert(idx <= 2 && "The only allowed values for UPD index are 0 and 1");
    int32_t end = std::min(interval.second, start + incr);
    if (lb <= start && ub >= end && (updIndices & idx) == 0) {
      updOp = state->builder.create<aievec::UPDOp>(
          readOp.getLoc(), updVecType, readOp.getSource(), indices,
          start - offset, idx - 1,
          updOp ? updOp.getResult() : TypedValue<VectorType>(nullptr));
      LLVM_DEBUG(llvm::dbgs()
                 << "\n\nCreated UPD op " << updOp << " for read op " << readOp);
      for (auto &value : indices) {
        if (auto apOf = value.getDefiningOp<affine::AffineApplyOp>()) {
          if (apOf->getBlock() == readOp->getBlock() &&
              apOf->isBeforeInBlock(updOp))
            continue;
          apOf.getOperation()->moveBefore(updOp);
        }
      }
      updIndices |= idx;
    }
  }

  memToUpdMap[key] = std::make_pair(updOp, updIndices);
  return updOp;
}
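// Compute the per-iteration step of the vectorized loop in units of vectors:
// the step of the enclosing loop that actually varies this access, divided by
// the vector lane count. For example, a 16-lane read in a loop with step 16
// yields 1, and with step 32 yields 2; splat reads are mapped to 0 by the
// callers.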
static int32_t computeVectorizedLoopStepSize(Operation *op, VectState *state) {
  auto readOp = dyn_cast<TransferReadOp>(op);
  if (!readOp)
    return 1;

  auto vectorType = cast<VectorType>(readOp.getResult().getType());
  SmallVector<Value, 4> indices(readOp.getIndices().begin(),
                                readOp.getIndices().end());
  assert(vectorType && !indices.empty());

  auto block = readOp->getBlock();
  assert(state->blockToEnclosingLoops.count(block) &&
         "enclosing loops should have been computed for the read operation");
  auto enclosingLoops = state->blockToEnclosingLoops[block];

  int32_t step = 0;
  AffineExpr expr = readOp.getPermutationMap().getResults().back();
  if (auto dimExpr = llvm::dyn_cast<AffineDimExpr>(expr)) {
    assert(dimExpr.getPosition() <= indices.size() &&
           "Failed to find the permutation index in index map");
    auto index = indices[dimExpr.getPosition()];
    [[maybe_unused]] bool found = false;
    for (auto loop : enclosingLoops) {
      auto iv = cast<affine::AffineForOp>(loop).getInductionVar();
      auto invariants = affine::getInvariantAccesses(iv, indices);
      if (!invariants.count(index)) {
        assert(
            !found &&
            "stepsize computation already has an entry along the variant dim");
        found = true;
        step = cast<affine::AffineForOp>(loop).getStepAsInt();
      }
    }
  }

  unsigned lanes = getVectorLaneSize(vectorType);
  assert(llvm::isPowerOf2_32(lanes) &&
         "non-power-of-two vectorization factor not supported");
  return step / lanes;
}
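// Compute the start position (in elements) of a read access within the
// IntervalReuse interval it belongs to. The signature and some locals below
// are reconstructed from the call sites.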
static int32_t computeStartInAIEVec(Operation *op, VectState *state) {
  if (!isa<TransferReadOp>(op))
    return 0;

  auto readOp = cast<TransferReadOp>(op);
  auto vtype = cast<VectorType>(readOp.getVector().getType());
  int32_t scalarSizeInBits = getElementSizeInBits(vtype);

  AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state);
  auto [base, offset] = getBaseAndOffset(linearAccess);
  offset *= scalarSizeInBits;

  IntervalReuse *iv = state->getIntervalForOperation(op);
  std::pair<int32_t, int32_t> interval = iv->getInterval(op);
  assert(offset >= interval.first && "Failed to compute the start");
  return (offset - interval.first) / scalarSizeInBits;
}
static Operation *concatAndInterleave_i8xi8(Operation *source1,
                                            Operation *source2,
                                            VectState *state, Location loc) {
  Type i16Type =
      IntegerType::get(source1->getResult(0).getType().getContext(), 16);
  auto srsOp1 = generateSRSOp(source1->getResult(0), i16Type, state, loc);
  auto srsOp2 = generateSRSOp(source2->getResult(0), i16Type, state, loc);

  SmallVector<Value> sources = {srsOp1->getResult(0), srsOp2->getResult(0)};
  auto concatOp = generateConcatOp(sources, state, loc);

  AIEOpAttributes opAttr;
  opAttr.select = "0xcccccccc";
  opAttr.start.push_back("0");
  opAttr.start.push_back("4");
  for (size_t idx = 0; idx < 2; ++idx) {
    opAttr.offset.push_back("0x0c080400");
    opAttr.offset_hi.push_back("0x0");
    opAttr.square.push_back("0x1010");
  }

  auto selectOp =
      generateSelectOp(concatOp->getResult(0), opAttr, 32, state, loc);
  auto extOp = generateExtOp(selectOp->getResult(0), 16, 0, state, loc);
  auto packOp = generatePackOp(extOp->getResult(0), state, loc);
  return packOp;
}
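// Returns true if an add/sub op and the mul op feeding its second operand can
// be fused into a vector.fma op: the three candidate operands must all be
// vectors of the same size and element type, defined in the same block.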
static bool canFuseMulAndAddOrSubIntoFMAOp(Operation *Op, VectState *state) {
  assert((isa<AddIOp>(Op) || isa<AddFOp>(Op) || isa<SubIOp>(Op) ||
          isa<SubFOp>(Op)) &&
         "operation must be an add or sub op");
  assert(Op->getNumOperands() == 2 && Op->getNumResults() == 1);

  Operation *mulOp = getOperandDefOp(state, Op, 1);
  if (!isa<MulIOp, MulFOp>(mulOp))
    return false;
  assert(mulOp->getNumOperands() == 2 && mulOp->getNumResults() == 1);

  Value lhs = state->sextTruncDefMap.count(mulOp->getOperand(0).getDefiningOp())
                  ? mulOp->getOperand(0).getDefiningOp()->getOperand(0)
                  : mulOp->getOperand(0);
  Value rhs = state->sextTruncDefMap.count(mulOp->getOperand(1).getDefiningOp())
                  ? mulOp->getOperand(1).getDefiningOp()->getOperand(0)
                  : mulOp->getOperand(1);
  Value acc = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
                  ? Op->getOperand(0).getDefiningOp()->getOperand(0)
                  : Op->getOperand(0);
  assert(lhs && rhs && acc &&
         "Failed to find the three operands of the FMA op");

  if (!isa<VectorType>(lhs.getType()) || !isa<VectorType>(rhs.getType()) ||
      !isa<VectorType>(acc.getType()))
    return false;

  if (lhs.getParentBlock() != rhs.getParentBlock() ||
      rhs.getParentBlock() != acc.getParentBlock())
    return false;

  auto lhsType = cast<VectorType>(lhs.getType());
  auto rhsType = cast<VectorType>(rhs.getType());
  VectorType accType =
      state->sextTruncDefMap.count(
          acc.getDefiningOp()->getOperand(0).getDefiningOp())
          ? cast<VectorType>(acc.getDefiningOp()->getOperand(0).getType())
          : cast<VectorType>(acc.getType());

  unsigned lhsVecSize = getVectorSizeInBits(lhsType);
  unsigned rhsVecSize = getVectorSizeInBits(rhsType);
  unsigned accVecSize = getVectorSizeInBits(accType);
  if (lhsVecSize != rhsVecSize || rhsVecSize != accVecSize)
    return false;

  if (lhsType.getElementType() != rhsType.getElementType() ||
      rhsType.getElementType() != accType.getElementType())
    return false;

  return true;
}
static void reassociateMulOpBasedOnVecSize(Operation *Op, VectState *state) {
  AIEVecAttributes lstat = getOperandVecStats(Op, state, 0);
  AIEVecAttributes rstat = getOperandVecStats(Op, state, 1);
  if (lstat.vecSizeInBits == rstat.vecSizeInBits)
    return;

  bool is8x8 = lstat.elementSizeInBits == 8 && rstat.elementSizeInBits == 8;
  bool flip = is8x8 ? lstat.vecSizeInBits > rstat.vecSizeInBits
                    : rstat.vecSizeInBits > lstat.vecSizeInBits;
  if (flip) {
    LLVM_DEBUG(llvm::dbgs()
               << "\n\nReassociating op " << *Op
               << " to correctly place operand coming from bigger vector");
    Value left = Op->getOperand(0);
    Value right = Op->getOperand(1);
    Op->setOperand(0, right);
    Op->setOperand(1, left);
    LLVM_DEBUG(llvm::dbgs() << "\n\tOp after reassociation: " << *Op);
  }
}
static void reassociateMulOpWithSplat(Operation *Op, VectState *state) {
  assert(Op->getNumOperands() == 2 || Op->getNumOperands() == 3);
  assert(Op->getNumResults() == 1);

  AIEVecAttributes lstat = getOperandVecStats(Op, state, 0);
  AIEVecAttributes rstat = getOperandVecStats(Op, state, 1);
  if (lstat.isSplat && rstat.isSplat)
    return;

  bool is8x8 = lstat.elementSizeInBits == 8 && rstat.elementSizeInBits == 8;
  bool flip = is8x8 ? rstat.isSplat : lstat.isSplat;

  Value left = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
                   ? Op->getOperand(0).getDefiningOp()->getOperand(0)
                   : Op->getOperand(0);
  Value right = state->sextTruncDefMap.count(Op->getOperand(1).getDefiningOp())
                    ? Op->getOperand(1).getDefiningOp()->getOperand(0)
                    : Op->getOperand(1);

  if (flip) {
    LLVM_DEBUG(llvm::dbgs() << "\n\nReassociating op " << *Op
                            << " to place splat as correct operand");
    Op->setOperand(0, right);
    Op->setOperand(1, left);
    LLVM_DEBUG(llvm::dbgs() << "\n\tOp after reassociation: " << *Op);
  } else {
    Op->setOperand(0, left);
    Op->setOperand(1, right);
  }

  Op->getResult(0).setType(Op->getOperand(0).getType());
  if (Op->hasOneUse() &&
      isa<AddIOp, AddFOp, SubIOp, SubFOp>(*Op->getUsers().begin())) {
    Operation *usrOp = *Op->getUsers().begin();
    usrOp->getResult(0).setType(usrOp->getOperand(0).getType());
  }
}
static void fuseMulAndAddOrSubIntoFMAOp(Operation *Op, VectState *state) {
  Value acc = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
                  ? Op->getOperand(0).getDefiningOp()->getOperand(0)
                  : Op->getOperand(0);
  Operation *mulOp = getOperandDefOp(state, Op, 1);
  Value lhs = state->sextTruncDefMap.count(mulOp->getOperand(0).getDefiningOp())
                  ? mulOp->getOperand(0).getDefiningOp()->getOperand(0)
                  : mulOp->getOperand(0);
  Value rhs = state->sextTruncDefMap.count(mulOp->getOperand(1).getDefiningOp())
                  ? mulOp->getOperand(1).getDefiningOp()->getOperand(0)
                  : mulOp->getOperand(1);

  state->builder.setInsertionPointAfter(Op);
  Operation *fmaOp =
      state->builder.create<vector::FMAOp>(Op->getLoc(), lhs, rhs, acc);

  bool isSub = isa<SubIOp, SubFOp>(Op);
  if (isSub)
    state->mscOps.insert(fmaOp);

  LLVM_DEBUG(llvm::dbgs() << "\n\nFused " << (isSub ? "sub" : "add") << " op "
                          << *Op << "\n\tand mul op " << *mulOp
                          << "\n\tinto fma op " << *fmaOp);

  Op->replaceAllUsesWith(fmaOp);
  Op->erase();
  if (mulOp->use_empty())
    mulOp->erase();
}
static void generateMulOrFMAOp(Operation *Op, Scheme &scheme,
                               AIEOpAttributes &opAttr, VectState *state,
                               const std::string &nextStart = "") {
  assert(opAttr.start.size() == opAttr.offset.size() &&
         opAttr.start.size() == 2);

  state->builder.setInsertionPointAfter(Op);

  auto notMulOrFMAOp = [&](Operation *op) {
    return !isa<MulIOp, MulFOp, vector::FMAOp>(op);
  };

  auto genOp = [&](Operation *Op, AIEOpAttributes &opAttr, VectState *state,
                   bool i8xi8_pairedOp = false) {
    Operation *repOp = nullptr;
    if (auto fmaOp = dyn_cast<vector::FMAOp>(Op))
      repOp = generateFMAOp(fmaOp, opAttr, state, i8xi8_pairedOp);
    else if (auto mulOp = dyn_cast<MulIOp>(Op))
      repOp = generateMulOp<MulIOp>(mulOp, opAttr, state);
    else if (auto mulOp = dyn_cast<MulFOp>(Op))
      repOp = generateMulOp<MulFOp>(mulOp, opAttr, state);
    else
      llvm_unreachable("Operation not mul/fma op");
    return repOp;
  };

  Operation *repOp = genOp(Op, opAttr, state);
  LLVM_DEBUG(llvm::dbgs() << "\n\nGenerated AIE dialect mul/fma op " << *repOp);

  if (!nextStart.empty()) {
    if (state->aieml && scheme.lanes == 32 && scheme.xbits == 8 &&
        scheme.zbits == 8) {
      repOp = generateMulOrFMAConvOpForInt8(Op, opAttr, state);
      if (any_of(repOp->getUsers(), notMulOrFMAOp)) {
        Type i8Type =
            IntegerType::get(repOp->getResult(0).getType().getContext(), 8);
        repOp =
            generateSRSOp(repOp->getResult(0), i8Type, state, repOp->getLoc());
      }
    } else {
      opAttr.start[1] = nextStart;
      Operation *pairedOp = genOp(Op, opAttr, state, true);
      LLVM_DEBUG(llvm::dbgs() << "\n\nGenerated the paired AIE dialect "
                              << "mul/fma op for 8x8 scheme " << *repOp);
      assert(!state->pairedOp.count(repOp));
      state->pairedOp[repOp] = pairedOp;
      if (any_of(Op->getUsers(), notMulOrFMAOp))
        repOp = concatAndInterleave_i8xi8(repOp, pairedOp, state, Op->getLoc());
    }
  }

  Op->replaceAllUsesWith(repOp);
  Op->erase();
}
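// The computeBuffAttr_* helpers encode the scheme attributes as hex strings.
// For the i32 x i32 scheme, the offset string packs one hex digit per lane
// (the loop body below is a reconstruction): e.g., eight lanes with unit
// stride would give "0x76543210", meaning lane i selects element start + i.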
static void computeBuffAttr_i32xi32(unsigned vecSize, int32_t start,
                                    int32_t accIncr, AIEOpAttributes &opAttr) {
  std::string startStr = std::to_string(start);

  std::string offsetStr = "0x";
  for (int i = vecSize - 1; i >= 0; --i)
    offsetStr += getHexValue(i * accIncr);

  opAttr.start.push_back(startStr);
  opAttr.offset.push_back(offsetStr);
  opAttr.offset_hi.push_back("");
  opAttr.square.push_back("");
  opAttr.step.push_back("");
}
static void computeXbuffAttr_i16xi16(unsigned vecSize, int32_t start,
                                     int32_t accIncr, int32_t colOffset,
                                     AIEOpAttributes &opAttr) {
  assert(colOffset >= -1 && (colOffset <= 1 || colOffset % 2 == 0) &&
         "cannot compute offset and square for xbuff");
  assert((accIncr <= 1 || colOffset <= 1) &&
         "cannot generate offset and square for xbuff");

  int32_t m2start = (start / 2) * 2;
  std::string startStr = std::to_string(m2start);
  int32_t m2Offset = start - m2start;

  std::string offsetStr = "0x";
  int32_t offset = std::max(colOffset, accIncr);
  for (int i = vecSize / 2 - 2; i >= 0; i -= 2) {
    offsetStr.push_back(offset <= 1 ? '0' : getHexValue((offset - 2) / 2));
    offsetStr.push_back(getHexValue((i * accIncr) / 2));
  }
  std::string offsetHiStr = "0x";
  for (int i = vecSize - 2, e = vecSize / 2; i >= e; i -= 2) {
    offsetHiStr.push_back(offset <= 1 ? '0' : getHexValue((offset - 2) / 2));
    offsetHiStr.push_back(getHexValue((i * accIncr) / 2));
  }

  int32_t cstep = std::min(2, std::abs(colOffset));
  int32_t astep = std::min(2, accIncr);
  assert(m2Offset == 0 || (astep <= 1 && cstep <= 1));

  SmallVector<int32_t> sqPattern = {astep + cstep, astep, cstep, 0};
  std::string squareStr = "0x";
  for (auto sq : sqPattern)
    squareStr += getHexValue(m2Offset + sq);

  opAttr.start.push_back(startStr);
  opAttr.offset.push_back(offsetStr);
  opAttr.offset_hi.push_back(offsetHiStr);
  opAttr.square.push_back(squareStr);
  opAttr.step.push_back("");
}
static void computeZbuffAttr_i16xi16(unsigned vecSize, int32_t start,
                                     int32_t accIncr, int32_t zeroOffset,
                                     int32_t colOffset, bool aieml,
                                     AIEOpAttributes &opAttr) {
  std::string offsetStr, offsetHiStr;
  assert(start < (aieml ? 32 : 16) && "zstart must be 4b value");
  std::string startStr = std::to_string(start);

  if (accIncr == 0) {
    offsetStr = offsetHiStr = "0";
  } else {
    offsetStr = offsetHiStr = "0x";
    for (int i = vecSize / 2 - 1; i >= 0; --i)
      offsetStr += getHexValue(i * accIncr);
    for (auto i = vecSize - 1, e = vecSize / 2; i >= e; --i)
      offsetHiStr += getHexValue(i * accIncr);
  }

  int32_t step = colOffset == -1 ? zeroOffset - 1 - start : colOffset;
  assert(step >= 0 && "zstep cannot be negative");
  std::string stepStr = std::to_string(step);

  opAttr.start.push_back(startStr);
  opAttr.offset.push_back(offsetStr);
  opAttr.offset_hi.push_back(offsetHiStr);
  opAttr.square.push_back("");
  opAttr.step.push_back(stepStr);
}
static void computeXbuffAttr_i8xi8(unsigned vecSize, int32_t start,
                                   int32_t colOffset, AIEOpAttributes &opAttr) {
  assert(
      colOffset % 2 == 0 &&
      "each filter entry must be replicated at least twice for i8xi8 scheme");
  int32_t colStep = 2 * colOffset;
  assert(colStep % 4 == 0 && "xstep must be multiple of 4");

  int32_t m4start = (start / 4) * 4;
  std::string startStr = std::to_string(m4start);
  int32_t m4Offset = start - m4start;
  assert(m4Offset == 0 || m4Offset == 2);

  std::string offsetStr = "0x";
  for (int i = vecSize / 4 - 1; i >= 0; --i) {
    offsetStr.push_back(getHexValue(colStep / 4 - 1));
  }
  std::string stepStr = std::to_string(colStep);

  int32_t offsetWithoutDup = colOffset / 2;
  int32_t rstep = offsetWithoutDup >= 2 ? 2
                  : colOffset == -1     ? 1
                                        : offsetWithoutDup;
  assert(m4Offset == 0 || rstep <= 1);

  SmallVector<int32_t> sqPattern = {rstep, 0, rstep, 0};
  std::string squareStr = "0x";
  for (auto sq : sqPattern)
    squareStr += getHexValue(m4Offset + sq);

  opAttr.start.push_back(startStr);
  opAttr.offset.push_back(offsetStr);
  opAttr.offset_hi.push_back("");
  opAttr.square.push_back(squareStr);
  opAttr.step.push_back(stepStr);
}
static void computeZbuffAttr_i8xi8(unsigned vecSize, int32_t start,
                                   int32_t accIncr, int32_t colOffset,
                                   AIEOpAttributes &opAttr,
                                   std::string &nextStart) {
  assert((colOffset <= 1 || colOffset % 2 == 0) && "zbuff value not supported");

  int32_t m2start = (start / 2) * 2;
  std::string startStr = std::to_string(m2start);
  int32_t m2Offset = start - m2start;

  std::string offsetStr = "0x";
  for (int i = vecSize / 4 - 1; i >= 0; --i) {
    int32_t val = i * accIncr + (colOffset + 1) / 2;
    offsetStr.push_back(getHexValue(val));
  }
  std::string stepStr = std::to_string(2 * std::abs(colOffset));
  nextStart = std::to_string(m2start + 2 * accIncr * (vecSize / 4));

  int32_t rstep = colOffset >= 2 ? 2 : std::abs(colOffset);
  assert(m2Offset == 0 || rstep <= 1);

  SmallVector<int32_t> sqPattern = {accIncr + rstep, accIncr, rstep, 0};
  std::string squareStr = "0x";
  for (auto sq : sqPattern)
    squareStr += getHexValue(m2Offset + sq);

  opAttr.start.push_back(startStr);
  opAttr.offset.push_back(offsetStr);
  opAttr.offset_hi.push_back("");
  opAttr.square.push_back(squareStr);
  opAttr.step.push_back(stepStr);
}
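// Fuse a chain of mul/fma ops over adjacent memory accesses into a single
// column-topology op: starting from refOp, follow single-use fma users whose
// operands come from the same reuse intervals at a consistent offset, record
// the resulting (xOffset, zOffset) column offsets, and redirect all uses to
// refOp. The start-position guards below are partially reconstructed.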
static void fuseFMAOps(Operation *refOp,
                       llvm::SmallSet<Operation *, 8> &fusedOpSet, int32_t cols,
                       VectState *state) {
  if (cols <= 1 || !isa<MulIOp, MulFOp, vector::FMAOp>(refOp) ||
      isSimpleVectIntrinsic(refOp, state))
    return;

  Operation *lOp = getOperandDefOp(state, refOp, 0);
  Operation *rOp = getOperandDefOp(state, refOp, 1);
  int32_t lstart = computeStartInAIEVec(lOp, state);
  int32_t rstart = computeStartInAIEVec(rOp, state);

  int xOffset = -1, zOffset = -1;

  Operation *curOp = refOp;
  SmallVector<Operation *, 8> fusedOps;
  for (auto len = 0; len < cols - 1; ++len) {
    if (!curOp->hasOneUse())
      break;
    Operation *usrOp = *curOp->getUsers().begin();
    if (!isa<vector::FMAOp>(usrOp) || curOp->getBlock() != usrOp->getBlock() ||
        isSimpleVectIntrinsic(usrOp, state))
      break;
    if (isa<vector::FMAOp>(curOp) &&
        state->mscOps.count(curOp) != state->mscOps.count(usrOp))
      break;

    SmallVector<int32_t, 2> offsets;
    for (size_t idx = 0; idx < 2; ++idx) {
      AIEVecAttributes cstat = getOperandVecStats(curOp, state, idx);
      AIEVecAttributes ustat = getOperandVecStats(usrOp, state, idx);
      if (cstat.vecSizeInBits != ustat.vecSizeInBits ||
          cstat.elementSizeInBits != ustat.elementSizeInBits ||
          cstat.loadFromMemory != ustat.loadFromMemory ||
          cstat.isSplat != ustat.isSplat)
        break;

      Operation *cdefOp = getOperandDefOp(state, curOp, idx);
      Operation *udefOp = getOperandDefOp(state, usrOp, idx);
      bool related = cdefOp == udefOp;
      if (!related && cstat.loadFromMemory && ustat.loadFromMemory) {
        IntervalReuse *civ = state->getIntervalForOperation(cdefOp);
        IntervalReuse *uiv = state->getIntervalForOperation(udefOp);
        related = civ == uiv;
      }
      if (!related)
        break;

      int32_t offset = start2 - start1;
      if (offset > 1 && offset % 2 != 0)
        break;
      int32_t refStart = idx == 0 ? lstart : rstart;
      if (!ustat.isSplat && offset > 1 && refStart != 0)
        break;
      offsets.push_back(offset);
    }

    if (offsets.size() < 2)
      break;
    if ((xOffset != -1 && xOffset != offsets[0]) ||
        (zOffset != -1 && zOffset != offsets[1]))
      break;

    xOffset = offsets[0];
    zOffset = offsets[1];
    fusedOps.push_back(usrOp);
    curOp = usrOp;
  }

  if (fusedOps.empty())
    return;

  LLVM_DEBUG(llvm::dbgs() << "\n\nFused following fma ops with op " << *refOp);
  for (auto &op : fusedOps) {
    LLVM_DEBUG(llvm::dbgs() << "\n\tfma op " << *op);
    fusedOpSet.insert(op);
    fuseAccessExtent(refOp, op, state);
    op->replaceAllUsesWith(refOp);
  }

  assert(!state->opToColOffsets.count(refOp));
  state->opToColOffsets[refOp] = std::make_pair(xOffset, zOffset);
}
static void computeXbuffAttributes(Scheme &scheme, int32_t start,
                                   int32_t colOffset, int32_t accIncr,
                                   int32_t dupFactor, bool aieml,
                                   AIEOpAttributes &opAttr) {
  if ((scheme.lanes == 8 || (aieml && scheme.lanes == 16)) &&
      scheme.cols == 1 && scheme.xbits == 32 && scheme.zbits == 32)
    computeBuffAttr_i32xi32(scheme.lanes, start, accIncr, opAttr);
  else if ((scheme.lanes == 16 || (aieml && scheme.lanes == 32)) &&
           scheme.cols == 2 && scheme.xbits == 16 && scheme.zbits == 16) {
    assert((accIncr <= 1 || accIncr % 2 == 0) &&
           "loop step size value not supported");
    computeXbuffAttr_i16xi16(scheme.lanes, start, accIncr, colOffset, opAttr);
  } else if ((scheme.lanes == 16 || (aieml && scheme.lanes == 32)) &&
             scheme.cols == 8 && scheme.xbits == 8 && scheme.zbits == 8) {
    assert(accIncr <= 1 && "loop step size greater than 1 not supported");
    if (colOffset == -1)
      colOffset = dupFactor;
    computeXbuffAttr_i8xi8(scheme.lanes, start, colOffset, opAttr);
  } else
    llvm_unreachable("Unsupported vectorization scheme");
}
static void computeZbuffAttributes(Scheme &scheme, int32_t start,
                                   int32_t colOffset, int32_t accIncr,
                                   int32_t zeroOffset, bool aieml,
                                   std::string &nextStart,
                                   AIEOpAttributes &opAttr) {
  if ((scheme.lanes == 8 || (aieml && scheme.lanes == 16)) &&
      scheme.cols == 1 && scheme.xbits == 32 && scheme.zbits == 32)
    computeBuffAttr_i32xi32(scheme.lanes, start, accIncr, opAttr);
  else if ((scheme.lanes == 16 || (aieml && scheme.lanes == 32)) &&
           scheme.cols == 2 && scheme.xbits == 16 && scheme.zbits == 16) {
    assert(accIncr <= 1 && "loop step size greater than 1 not supported");
    zeroOffset = zeroOffset == 0 ? scheme.lanes
                                 : start + zeroOffset - (start % zeroOffset);
    computeZbuffAttr_i16xi16(scheme.lanes, start, accIncr, zeroOffset,
                             colOffset, aieml, opAttr);
  } else if ((scheme.lanes == 16 || (aieml && scheme.lanes == 32)) &&
             scheme.cols == 8 && scheme.xbits == 8 && scheme.zbits == 8) {
    assert(accIncr <= 1 && "loop step size greater than 1 not supported");
    computeZbuffAttr_i8xi8(scheme.lanes, start, accIncr, colOffset, opAttr,
                           nextStart);
  } else
    llvm_unreachable("Unsupported vectorization scheme");
}
static void generateSchemeBasedMulOrFMAOp(Operation *Op, VectState *state) {
  int32_t lanes, cols;
  std::tie(lanes, cols) = getNumRowsAndCols(Op, state);

  Value lhs = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp())
                  ? Op->getOperand(0).getDefiningOp()->getOperand(0)
                  : Op->getOperand(0);
  Value rhs = state->sextTruncDefMap.count(Op->getOperand(1).getDefiningOp())
                  ? Op->getOperand(1).getDefiningOp()->getOperand(0)
                  : Op->getOperand(1);
  int32_t xbits = getElementSizeInBits(cast<VectorType>(lhs.getType()));
  int32_t zbits = getElementSizeInBits(cast<VectorType>(rhs.getType()));
  Scheme scheme(lanes, cols, xbits, zbits);

  if (isSimpleVectIntrinsic(Op, state)) {
    AIEOpAttributes opAttr;
    for (size_t idx = 0; idx < 2; ++idx) {
      opAttr.start.push_back("");
      opAttr.offset.push_back("");
      opAttr.offset_hi.push_back("");
      opAttr.square.push_back("");
      opAttr.step.push_back("");
    }
    generateMulOrFMAOp(Op, scheme, opAttr, state);
    return;
  }

  auto colOffset = state->opToColOffsets.count(Op) ? state->opToColOffsets[Op]
                                                   : std::make_pair(-1, -1);

  AIEOpAttributes opAttr;
  std::string nextStart;

  for (size_t idx = 0; idx < 2; ++idx) {
    AIEVecAttributes stat = getOperandVecStats(Op, state, idx);
    Operation *op = getOperandDefOp(state, Op, idx);
    int32_t start = 0, accIncr = 1;
    if (stat.loadFromMemory) {
      auto readOp = cast<TransferReadOp>(op);
      start = computeStartInAIEVec(readOp, state);
      accIncr = stat.isSplat ? 0 : computeVectorizedLoopStepSize(readOp, state);
    }
    if (idx == 0)
      computeXbuffAttributes(scheme, start, colOffset.first, accIncr,
                             state->dupFactor, state->aieml, opAttr);
    else
      computeZbuffAttributes(scheme, start, colOffset.second, accIncr,
                             state->zeroOffset, state->aieml, nextStart,
                             opAttr);
  }
  generateMulOrFMAOp(Op, scheme, opAttr, state, nextStart);
}
static void fuseFMAOpsForColumnTopology(func::FuncOp func, VectState *state) {
  llvm::SmallSet<Operation *, 8> fusedOpSet;

  func.walk([&](Operation *op) {
    if (isa<MulIOp, MulFOp, vector::FMAOp>(op)) {
      if (!fusedOpSet.count(op)) {
        auto [lanes, cols] = getNumRowsAndCols(op, state);
        fuseFMAOps(op, fusedOpSet, cols, state);
      }
    }
  });

  for (auto op : fusedOpSet)
    op->erase();
}
template <typename T1, typename T2>
static bool matchAttributesAndDistanceForFusion(T1 curOp, T2 defOp) {
  return curOp.getOffset(0) == defOp.getOffset(0) &&
         curOp.getOffsetHi(0) == defOp.getOffsetHi(0) &&
         curOp.getSquare(0) == defOp.getSquare(0) &&
         curOp.getStep(0) == defOp.getStep(0) &&
         curOp.getOffset(1) == defOp.getOffset(1) &&
         curOp.getOffsetHi(1) == defOp.getOffsetHi(1) &&
         curOp.getSquare(1) == defOp.getSquare(1) &&
         curOp.getStep(1) == defOp.getStep(1) &&
         stoi(static_cast<std::string>(curOp.getStart(0))) -
                 stoi(static_cast<std::string>(defOp.getStart(0))) ==
         stoi(static_cast<std::string>(curOp.getStart(1))) -
                 stoi(static_cast<std::string>(defOp.getStart(1))) ==
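// For i16 data, check whether an aie1 fma op and the mul/fma op feeding its
// accumulator read the same operand vectors with matching attributes and the
// required start distance, so that the pair can be fused into a single
// mul_conv/fma_conv op.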
static bool canFuseMulFMAOpsForInt16(Operation *Op) {
  assert(isa<aievec::aie1::FMAOp>(Op) && "operation must be an aievec fma op");
  auto curOp = cast<aievec::aie1::FMAOp>(Op);

  auto vType = cast<VectorType>(Op->getOperand(1).getType());
  Type stype = vType.getElementType();
  auto itype = llvm::dyn_cast<IntegerType>(stype);
  if (!itype)
    return false;
  if (unsigned width = itype.getWidth(); width != 16)
    return false;

  Operation *mulOrFMAOp = Op->getOperand(2).getDefiningOp();
  if (!isa<aievec::aie1::MulOp, aievec::aie1::FMAOp>(mulOrFMAOp))
    return false;
  if (!mulOrFMAOp->hasOneUse())
    return false;
  if (mulOrFMAOp->getOperand(0) != Op->getOperand(0) ||
      mulOrFMAOp->getOperand(1) != Op->getOperand(1))
    return false;

  Value lhs = nullptr;
  Value rhs = nullptr;
  Value acc = nullptr;
  bool isMulOp = false;
  if (auto mulOp = dyn_cast<aievec::aie1::MulOp>(mulOrFMAOp)) {
    isMulOp = true;
    lhs = mulOp->getOperand(0);
    rhs = mulOp->getOperand(1);
  } else {
    auto fmaOp = cast<aievec::aie1::FMAOp>(mulOrFMAOp);
    lhs = fmaOp->getOperand(0);
    rhs = fmaOp->getOperand(1);
    acc = fmaOp->getOperand(2);
  }

  auto lUpdOp = dyn_cast<aievec::UPDOp>(lhs.getDefiningOp());
  auto rUpdOp = dyn_cast<aievec::UPDOp>(rhs.getDefiningOp());
  if (!lUpdOp || !rUpdOp)
    return false;

  if (lhs.getParentBlock() != rhs.getParentBlock())
    return false;
  if (acc && rhs.getParentBlock() != acc.getParentBlock())
    return false;

  if (isMulOp)
    return matchAttributesAndDistanceForFusion(
        curOp, cast<aievec::aie1::MulOp>(mulOrFMAOp));
  return matchAttributesAndDistanceForFusion(
      curOp, cast<aievec::aie1::FMAOp>(mulOrFMAOp));
}
static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) {
  auto curOp = cast<aievec::aie1::FMAOp>(Op);
  Value lhs = curOp->getOperand(0);

  auto lUpdOp = dyn_cast<aievec::UPDOp>(lhs.getDefiningOp());
  if (lUpdOp.getIndex() == 1) {
    auto lUpdOp0 = dyn_cast<aievec::UPDOp>(lUpdOp.getVector().getDefiningOp());
    lUpdOp->replaceAllUsesWith(lUpdOp0);
    lUpdOp->erase();
  }

  auto rUpdOp = dyn_cast<aievec::UPDOp>(curOp->getOperand(1).getDefiningOp());
  state->builder.setInsertionPointAfter(rUpdOp);
  AIEVecAttributes rstat = getOperandVecStats(curOp, state, 1);
  assert(rstat.vecSizeInBits % 256 == 0);

  Value concatRhs = nullptr;
  if (rstat.vecSizeInBits == 256) {
    VectorType concatType =
        createVectorType(2 * rstat.lanes, rstat.elementType);
    SmallVector<Value> sources = {rUpdOp->getResult(0), rUpdOp->getResult(0)};
    concatRhs = generateConcatOp(sources, state, rUpdOp->getLoc(), concatType);
  }

  Operation *convOp = nullptr;
  Operation *mulOrFMAOp = Op->getOperand(2).getDefiningOp();
  auto mulOp = dyn_cast<aievec::aie1::MulOp>(mulOrFMAOp);
  auto fmaOp = dyn_cast<aievec::aie1::FMAOp>(mulOrFMAOp);

  int32_t zStart;
  if (mulOp) {
    aievec::aie1::MulOp defOp = mulOp;
    zStart = stoi(static_cast<std::string>(defOp.getStart(1)));
  } else {
    aievec::aie1::FMAOp defOp = fmaOp;
    zStart = stoi(static_cast<std::string>(defOp.getStart(1)));
  }

  auto vType = cast<VectorType>(Op->getOperand(1).getType());

  auto defOp = mulOp ? mulOp.getOperation() : fmaOp.getOperation();
  state->builder.setInsertionPoint(defOp);
  Location loc = defOp->getLoc();

  concatRhs = generateShiftOp(concatRhs, concatRhs, shiftBytes, state, loc);

  Type stype = vType.getElementType();
  auto itype = cast<IntegerType>(stype);
  unsigned width = itype.getWidth() <= 8 ? 32 : 64;
  Type ctype = IntegerType::get(itype.getContext(), width);
  Type opType = VectorType::get(vType.getShape(), ctype);
  Value acc = nullptr;

  int32_t M = itype.getWidth();

  lhs = curOp->getOperand(0);
  if (mulOp) {
    convOp = state->builder.create<aievec::MulConvOp>(loc, opType, lhs,
                                                      concatRhs, M, N);
  } else {
    acc = defOp->getOperand(2);
    bool isSub = state->mscOps.count(defOp);
    convOp = state->builder.create<aievec::FMAConvOp>(
        loc, opType, lhs, concatRhs, acc, M, N, isSub);
  }

  Op->replaceAllUsesWith(convOp);
}
static void fuseMulFMAOpsByMulFMAConv(func::FuncOp func, VectState *state) {
  func.walk([&](Operation *Op) {
    if (isa<aievec::aie1::FMAOp>(Op) && canFuseMulFMAOpsForInt16(Op))
      fuseMulFMAOpsForInt16(Op, state);
  });
}
static void generateAIEMulOrFMAOpsInFunc(func::FuncOp func, VectState *state) {
  func.walk([&](Operation *op) {
    if (isa<MulIOp, MulFOp, vector::FMAOp>(op))
      generateSchemeBasedMulOrFMAOp(op, state);
  });
}
static void generateAddOrSubOp(Operation *Op, AIEOpAttributes &opAttr,
                               VectState *state) {
  state->builder.setInsertionPointAfter(Op);

  Operation *repOp = nullptr;
  if (isa<SubIOp, SubFOp>(Op)) {
    repOp = generateSubOp(Op, opAttr, state);
    LLVM_DEBUG(llvm::dbgs() << "\n\nGenerated AIE dialect sub op " << *repOp);
  } else {
    repOp = generateAddOp(Op, opAttr, state);
    LLVM_DEBUG(llvm::dbgs() << "\n\nGenerated AIE dialect add op " << *repOp);
  }

  Op->replaceAllUsesWith(repOp);
  Op->erase();
}
static void generateSchemeBasedAddOrSubOp(Operation *Op, VectState *state) {
  AIEOpAttributes opAttr;

  if (isSimpleVectIntrinsic(Op, state)) {
    for (size_t idx = 0; idx < 2; ++idx) {
      opAttr.start.push_back("");
      opAttr.offset.push_back("");
      opAttr.offset_hi.push_back("");
      opAttr.square.push_back("");
    }
    generateAddOrSubOp(Op, opAttr, state);
    return;
  }

  for (size_t idx = 0; idx < 2; ++idx) {
    AIEVecAttributes stat = getOperandVecStats(Op, state, idx);
    assert(stat.elementSizeInBits >= 16 &&
           "advanced scheme for add op on int8 data type not supported");

    int32_t start = 0, accIncr = 1;
    std::string startStr;
    std::string offsetStr, offsetHiStr;
    std::string squareStr;

    if (stat.loadFromMemory) {
      Operation *op = Op->getOperand(idx).getDefiningOp();
      auto readOp = cast<TransferReadOp>(op);
      start = computeStartInAIEVec(readOp, state);
      accIncr = stat.isSplat ? 0 : computeVectorizedLoopStepSize(readOp, state);
    }

    if (stat.elementSizeInBits == 32) {
      startStr = std::to_string(start);
      offsetStr = "0x";
      for (int i = 7; i >= 0; --i)
        offsetStr += getHexValue(i * accIncr);
      if (stat.lanes > 8) {
        assert(stat.lanes == 16 && "Cannot generate offset for add/sub op");
        assert(accIncr <= 1 && "Cannot generate offset for given loop stride");
        offsetHiStr = "0x";
        for (int i = 15; i >= 8; --i)
          offsetHiStr += getHexValue(i * accIncr);
      }
    } else if (stat.elementSizeInBits == 16) {
      assert(accIncr <= 1 && "cannot generate offset for given loop stride");
      int32_t m2Offset = start % 2;
      startStr = std::to_string(start - m2Offset);

      offsetStr = offsetHiStr = "0";
      for (int i = 6; i >= 0; i -= 2) {
        offsetStr.push_back('0');
        offsetStr.push_back(getHexValue((i * accIncr) / 2));
      }
      for (int i = 14; i >= 8; i -= 2) {
        offsetHiStr.push_back('0');
        offsetHiStr.push_back(getHexValue((i * accIncr) / 2));
      }

      if (m2Offset == 0 && accIncr == 0)
        squareStr = "0";
      else {
        assert(m2Offset == 0 || accIncr == 0);
        squareStr = "0x";
        int32_t astep = std::min(1, accIncr);
        SmallVector<int32_t> sqPattern = {3 * astep, 2 * astep, astep, 0};
        for (auto sq : sqPattern)
          squareStr += getHexValue(m2Offset + sq);
      }
    } else
      llvm_unreachable("Cannot generate advanced add op for given datatype");

    opAttr.start.push_back(startStr);
    opAttr.offset.push_back(offsetStr);
    opAttr.offset_hi.push_back(offsetHiStr);
    opAttr.square.push_back(squareStr);
  }

  generateAddOrSubOp(Op, opAttr, state);
}
static void generateAIEAddOrSubOpsInFunc(func::FuncOp func, VectState *state) {
  func.walk([&](Operation *op) {
    if (isa<AddIOp, AddFOp, SubIOp, SubFOp>(op))
      generateSchemeBasedAddOrSubOp(op, state);
  });
}
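// Insert upd ops for all the transfer_read ops in a loop nest, recursing into
// nested affine for loops first. Reads in the same loop that cover the same
// access extent share a single upd op through memToUpdMap.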
static void insertUPDOpsInLoop(affine::AffineForOp forOp, VectState *state) {
  for (affine::AffineForOp nestedOp :
       forOp.getRegion().getOps<affine::AffineForOp>())
    insertUPDOpsInLoop(nestedOp, state);

  mlir::DenseMap<std::tuple<IntervalReuse *, int32_t, int32_t>,
                 std::pair<aievec::UPDOp, int8_t>>
      memToUpdMap;
  mlir::DenseMap<Operation *, aievec::UPDOp> readOpToUpdMap;

  Region &region = forOp.getRegion();
  for (TransferReadOp readOp : region.getOps<TransferReadOp>()) {
    aievec::UPDOp updOp = generateUPDOp(readOp, memToUpdMap, region, state);
    readOpToUpdMap[readOp] = updOp;
  }

  for (auto &map : readOpToUpdMap) {
    Operation *op = map.first;
    op->replaceAllUsesWith(map.second);
    op->erase();
  }
}
static void insertUPDOpsInFunc(func::FuncOp func, VectState *state) {
  for (affine::AffineForOp forOp : func.getOps<affine::AffineForOp>()) {
    insertUPDOpsInLoop(forOp, state);
  }
}
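// If an op writes to an accumulator and has non-AIE users, insert an srs op
// (one per scalar result type) to move the accumulator value back into a
// vector; on AIE-ML, an i8 store of matching width uses a cast op instead.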
static void insertSRSOp(Operation *Op, VectState *state) {
  if (Op->use_empty() || Op->getNumResults() == 0)
    return;

  assert(writesToAccumulator(Op));

  auto isNonAIEOp = [&](Operation *op) { return !isAIEOp(op); };
  if (!any_of(Op->getUsers(), isNonAIEOp))
    return;

  mlir::DenseMap<Type, aievec::SRSOp> typeToSRSOpMap;

  state->builder.setInsertionPointAfter(Op);

  for (auto user : Op->getUsers()) {
    if (isAIEOp(user))
      continue;

    Type scalarType = nullptr;
    MemRefType memRefType = nullptr;
    if (auto writeOp = dyn_cast<TransferWriteOp>(user)) {
      memRefType = cast<MemRefType>(writeOp.getSource().getType());
      scalarType = memRefType.getElementType();
    } else
      scalarType = getElementTypeOrSelf(*user->getResultTypes().begin());
    assert(scalarType && "failed to form SRS op");

    for (auto operand : user->getOperands()) {
      if (operand.getDefiningOp() == Op) {
        if (state->aieml && memRefType &&
            cast<VectorType>(Op->getOperand(0).getType())
                    .getElementType()
                    .getIntOrFloatBitWidth() == 8 &&
            cast<VectorType>(Op->getResult(0).getType())
                    .getElementType()
                    .getIntOrFloatBitWidth() ==
                scalarType.getIntOrFloatBitWidth()) {
          auto castType = cast<VectorType>(Op->getResult(0).getType());
          aievec::CastOp castOp = generateCastOp(Op->getResult(0), castType,
                                                 false, state, Op->getLoc());
          assert(castOp && "Failed to create Cast intrinsic");
          user->replaceUsesOfWith(operand, castOp);
          continue;
        }

        aievec::SRSOp srsOp;
        if (!typeToSRSOpMap.count(scalarType)) {
          srsOp =
              generateSRSOp(Op->getResult(0), scalarType, state, Op->getLoc());
          LLVM_DEBUG(llvm::dbgs() << "\n\nCreated SRS op " << srsOp
                                  << " for the acc output of operation " << Op);
          typeToSRSOpMap[scalarType] = srsOp;
        } else
          srsOp = typeToSRSOpMap[scalarType];
        assert(srsOp && "Failed to create SRS intrinsic");
        user->replaceUsesOfWith(operand, srsOp);
      }
    }
  }
}
static void insertSRSOpsInFunc(func::FuncOp func, VectState *state) {
  func.walk([&](Operation *op) {
    if (writesToAccumulator(op))
      insertSRSOp(op, state);
  });
}
template <typename TransferOp>
static void setInBounds(TransferOp op) {
  if (op.getTransferRank() == 0)
    return;
  SmallVector<bool, 4> bools(op.getTransferRank(), true);
  OpBuilder b(op.getContext());
  op->setAttr(op.getInBoundsAttrName(), b.getBoolArrayAttr(bools));
}
static void redundantLoadStoreOptimization(ModuleOp module) {
  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
    func.walk([&](Operation *Op) {
      if (auto readOp = dyn_cast<TransferReadOp>(Op)) {
        if (!readOp.getInBounds())
          setInBounds<TransferReadOp>(readOp);
      } else if (auto writeOp = dyn_cast<TransferWriteOp>(Op)) {
        if (!writeOp.getInBounds())
          setInBounds<TransferWriteOp>(writeOp);
      }
    });

    IRRewriter rewriter(module.getContext());
    vector::transferOpflowOpt(rewriter, func);
  }
}
static void preCanonicalizeIR(ModuleOp module) {
  PassManager pm(module.getContext());
  pm.addPass(createCanonicalizerPass());
  [[maybe_unused]] bool success = pm.run(module).succeeded();
  assert(success);

  redundantLoadStoreOptimization(module);
}
static void postCanonicalizeIR(ModuleOp module) {
  PassManager pm(module.getContext());
  pm.addPass(createCanonicalizerPass());
  pm.addPass(createCSEPass());
  pm.addPass(createLoopInvariantCodeMotionPass());
  pm.addPass(createLowerAffinePass());
  [[maybe_unused]] bool success = pm.run(module).succeeded();
  assert(success);
}
static void
computeEnclosingLoopsPerBlock(affine::AffineForOp forOp, VectState *state,
                              SmallVector<Operation *, 8> &enclosingLoops) {
  for (affine::AffineForOp nestedOp :
       forOp.getRegion().getOps<affine::AffineForOp>()) {
    enclosingLoops.push_back(nestedOp);
    computeEnclosingLoopsPerBlock(nestedOp, state, enclosingLoops);
    enclosingLoops.pop_back();
  }

  for (TransferReadOp readOp : forOp.getRegion().getOps<TransferReadOp>()) {
    Block *block = readOp->getBlock();
    state->blockToEnclosingLoops[block] = enclosingLoops;
  }
}
// Reassociate the operands of vector mul/fma ops into a canonical order.
static void reassociateMulOpInFunc(func::FuncOp func, VectState *state) {
  func.walk([&](Operation *op) {
    if (isa<MulIOp, MulFOp, vector::FMAOp>(op) && isWellFormedVectorOp(op)) {
      // First order the operands with respect to any splat operand, then
      // with respect to the operand vector sizes.
      reassociateMulOpWithSplat(op, state);
      reassociateMulOpBasedOnVecSize(op, state);
    }
  });
}
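
// For example (illustrative): in a filter kernel where one multiplicand is a
// broadcast coefficient vector, reassociateMulOpWithSplat puts the splat in
// a fixed operand position and reassociateMulOpBasedOnVecSize orders the
// operands by vector width, so the downstream AIE op generation only has to
// match one canonical operand shape.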
// Reassociate the operands of vector add ops so that a mul op, if present,
// always appears as the rhs operand. This sets up the mul+add -> fma fusion.
static void reassociateAddOpInFunc(func::FuncOp func, VectState *state) {
  func.walk([&](Operation *op) {
    if (isa<AddIOp, AddFOp>(op) && isWellFormedVectorOp(op)) {
      assert(op->getNumOperands() == 2 && op->getNumResults() == 1);
      Operation *rhsOp = getOperandDefOp(state, op, 1);
      // Look through any sext/trunc to get the underlying operand values.
      Value left =
          state->sextTruncDefMap.count(op->getOperand(0).getDefiningOp())
              ? op->getOperand(0).getDefiningOp()->getOperand(0)
              : op->getOperand(0);
      Value right =
          state->sextTruncDefMap.count(op->getOperand(1).getDefiningOp())
              ? op->getOperand(1).getDefiningOp()->getOperand(0)
              : op->getOperand(1);
      if (!isa<MulIOp, MulFOp>(rhsOp)) {
        Operation *lhsOp = getOperandDefOp(state, op, 0);
        // If the mul is on the lhs, swap the two operands.
        if (isa<MulIOp, MulFOp>(lhsOp)) {
          LLVM_DEBUG(llvm::dbgs() << "\n\nReassociating addOp " << *op
                                  << " to place mul as rhs operand");
          op->setOperand(0, right);
          op->setOperand(1, left);
          LLVM_DEBUG(llvm::dbgs()
                     << "\n\taddOp after reassociation: " << *op);
        }
      } else {
        op->setOperand(0, left);
        op->setOperand(1, right);
      }
    }
  });
}
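
// Illustrative effect of the swap: with the mul initially on the lhs,
//   %m = arith.muli %a, %b : vector<16xi32>
//   %s = arith.addi %m, %c : vector<16xi32>
// becomes %s = arith.addi %c, %m, so the later mul+add -> fma rewrite can
// always look for the mul as the rhs operand.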
// The vector feeding the lhs operand of a mul/fma op may need to be
// exclusive to that use. Mark such vectors, then coalesce the intervals.
static void coalesceLHSOpVectorsInFunc(func::FuncOp func, VectState *state) {
  func.walk([&](TransferReadOp op) {
    // Check whether every user of this read op is a mul/fma that consumes
    // it as its lhs operand.
    bool onlyLHS = true;
    for (auto user : op->getUsers()) {
      if (!isa<MulIOp, MulFOp, vector::FMAOp>(user) ||
          user->getOperand(0).getDefiningOp() != op) {
        onlyLHS = false;
        break;
      }
    }
    if (onlyLHS) {
      IntervalReuse *iv = state->getIntervalForOperation(op);
      iv->markLHSOperandVec(op);
    }
  });
  // Recompute the intervals now that some vectors are marked lhs-only.
  for (auto interval : state->reuseIntervals)
    interval->coalesceIntervals();
}
// Record each sign-extension/truncation op together with the op that defines
// its input.
static void recordSextOps(func::FuncOp func, VectState *state) {
  func.walk([&](ExtSIOp op) {
    state->sextTruncDefMap[op] = op->getOperand(0).getDefiningOp();
  });
  func.walk([&](TruncIOp op) {
    state->sextTruncDefMap[op] = op->getOperand(0).getDefiningOp();
  });
}
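
// Example (illustrative): for a widened 8-bit load
//   %r = vector.transfer_read ... : memref<...xi8>, vector<32xi8>
//   %e = arith.extsi %r : vector<32xi8> to vector<32xi32>
// sextTruncDefMap[%e] is the transfer_read, letting later analyses look
// through the extension to the underlying load.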
// Compute the data reuse for a transfer read op: either fold it into an
// existing reuse interval, or start a new one.
static void computeReuse(TransferReadOp readOp, VectState *state) {
  // Linearize the access and split it into a loop-variant base and a
  // constant offset.
  AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state);
  auto [base, offset] = getBaseAndOffset(linearAccess);

  // Step size of the vectorized loop enclosing this read.
  int32_t step = computeVecorizedLoopStepSize(readOp, state);

  // A constant permutation map means this read is a splat.
  bool isSplat = readOp.getPermutationMap().isConstant();

  // Determine the minimum vector size (in bits) for the interval this read
  // belongs to. A read that feeds a mul/fma op, directly or through a
  // sign-extension, needs a wider minimum vector.
  unsigned minVecSize = 128;
  for (auto user : readOp->getUsers()) {
    if (isa<MulIOp, MulFOp, vector::FMAOp>(user)) {
      if (user->getOperand(0).getDefiningOp() == readOp ||
          user->getOperand(1).getDefiningOp() == readOp) {
        minVecSize = 256;
      }
    }
    if (isa<ExtSIOp>(user)) {
      auto extsiOp = cast<ExtSIOp>(user);
      for (auto consumer : extsiOp->getUsers()) {
        if (isa<MulIOp, MulFOp, vector::FMAOp>(consumer)) {
          if ((state->sextTruncDefMap.count(
                   consumer->getOperand(0).getDefiningOp()) &&
               state->sextTruncDefMap[consumer->getOperand(0)
                                          .getDefiningOp()] == readOp) ||
              (state->sextTruncDefMap.count(
                   consumer->getOperand(1).getDefiningOp()) &&
               state->sextTruncDefMap[consumer->getOperand(1)
                                          .getDefiningOp()] == readOp)) {
            minVecSize = 256;
          }
        }
      }
    }
  }

  auto vecType = cast<VectorType>(readOp.getVector().getType());
  (void)vecType; // an AIE-ML-specific refinement of minVecSize based on
                 // vecType is elided from this listing

  // Fold the read into an existing interval if it has potential reuse with
  // the reads already there; otherwise start a new interval.
  for (auto interval : state->reuseIntervals) {
    if (interval->potentialReuse(readOp, base, state->blockToEnclosingLoops)) {
      interval->insertInterval(readOp, state->opToIntervalMap, offset, step,
                               isSplat, minVecSize);
      return;
    }
  }
  IntervalReuse *iv = new IntervalReuse(readOp, base);
  iv->insertInterval(readOp, state->opToIntervalMap, offset, step, isSplat,
                     minVecSize);
  state->reuseIntervals.push_back(iv);
}
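
// Reuse example (hypothetical sizes): in a vectorized loop of step 8, the
// accesses a[i] and a[i+4] linearize to the same base with offsets 0 and 4.
// The second read has potential reuse with the first, so both are inserted
// into one IntervalReuse object; the interval can then be realized as one
// wide load whose lanes serve both reads.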
// Verify that a transfer read op is aligned to the vector-lane boundary.
// Returns failure (with an error emitted on the op) if it is not.
static LogicalResult isUnalignedLoad(TransferReadOp readOp, VectState *state) {
  auto vectorType = cast<VectorType>(readOp.getResult().getType());
  unsigned lanes = getVectorLaneSize(vectorType);

  // Linearize the access; a purely symbolic or constant access cannot be
  // made misaligned by a loop.
  AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state);
  if (linearAccess.isSymbolicOrConstant())
    return success();

  auto memRefType = cast<MemRefType>(readOp.getSource().getType());
  MLIRContext *context = memRefType.getContext();
  ArrayRef<int64_t> sizes = memRefType.getShape();
  int numDims = sizes.size();

  auto block = readOp->getBlock();
  assert(state->blockToEnclosingLoops.count(block) &&
         "enclosing loops should have been computed for the read operation\n");
  auto enclosingLoops = state->blockToEnclosingLoops[block];

  SmallVector<Value, 4> indices(readOp.getIndices().begin(),
                                readOp.getIndices().end());

  // Check the innermost dimension: for every enclosing loop whose induction
  // variable appears in that index, the loop step and the constant offset of
  // a non-constant upper bound must be multiples of the vector lane count.
  if (auto dimExpr =
          dyn_cast<AffineDimExpr>(getAffineDimExpr(numDims - 1, context))) {
    auto index = indices[dimExpr.getPosition()];
    for (auto loop : enclosingLoops) {
      auto affineForOp = cast<affine::AffineForOp>(loop);
      auto iv = affineForOp.getInductionVar();
      auto invariants = affine::getInvariantAccesses(iv, indices);
      if (!invariants.count(index)) {
        int step = affineForOp.getStepAsInt();
        if (step % lanes) {
          return readOp->emitError()
                 << "Loop step of inner index of " << readOp->getName()
                 << " is not divisible by number of vector lanes.";
        }
        affine::AffineBound ub = affineForOp.getUpperBound();
        AffineMap origUbMap = ub.getMap();
        if (!origUbMap.isEmpty() && !origUbMap.isConstant()) {
          AffineExpr origUbMapResult = origUbMap.getResult(0);
          AffineExpr base;
          int32_t offset;
          std::tie(base, offset) = getBaseAndOffset(origUbMapResult);
          if (offset % lanes) {
            return readOp->emitError()
                   << "Loop upper bound's affine map offset of inner index of "
                   << readOp->getName()
                   << " is not divisible by number of vector lanes.";
          }
        }
      }
    }
  }

  // Every non-outermost memref dimension must also be a multiple of the
  // lane count; dynamic sizes are skipped.
  for (int i = 1; i < numDims; ++i) {
    if (sizes[i] == -1)
      continue;
    if (sizes[i] % lanes) {
      return readOp->emitError()
             << readOp->getName() << "'s shape size of index " << i
             << " is not divisible by number of vector lanes.";
    }
  }
  return success();
}
// Walk all transfer read ops in the function, failing if any is unaligned.
static LogicalResult hasUnalignedLoads(func::FuncOp func, VectState *state) {
  WalkResult result = func.walk([&](TransferReadOp op) {
    if (failed(isUnalignedLoad(op, state)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  if (result.wasInterrupted())
    return failure();
  return success();
}
// Compute the data reuse intervals for all transfer read ops in the function.
static void computeReuseInFunc(func::FuncOp func, VectState *state) {
  func.walk([&](TransferReadOp op) { computeReuse(op, state); });
}
// Fuse vector mul and add/sub pairs into FMA ops wherever possible.
static void rewriteFMAOpsInFunc(func::FuncOp func, VectState *state) {
  func.walk([&](Operation *Op) {
    if (isa<AddIOp, AddFOp, SubIOp, SubFOp>(Op) && isWellFormedVectorOp(Op)) {
      if (canFuseMulAndAddOrSubIntoFMAOp(Op, state))
        fuseMulAndAddOrSubIntoFMAOp(Op, state);
    }
  });
}
// Reassociate mul and add ops into the canonical operand order expected by
// the AIE op generation.
static void reassociateOpsInFunc(func::FuncOp func, VectState *state) {
  reassociateMulOpInFunc(func, state);
  reassociateAddOpInFunc(func, state);
}
// Generate AIE vector intrinsics for the current module.
void AIEVectorize::runOnOperation() {
  assert(shiftParam < 64 && "SRS shift parameter should be between 0 and 63");
  assert(zeroOffset < 128 &&
         "Zero offset in the filter should be between 0 and 127");
  assert(dupFactor < 128 &&
         "Duplicate offset in the filter should be between 0 and 127");

  ModuleOp module = getOperation();

  // Canonicalize the incoming IR and clean up redundant loads and stores.
  preCanonicalizeIR(module);

  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
    // Pass options, when set, override the global command-line flags.
    bool aieml = ::AIEML;
    bool unalignedCheck = ::unalignedLoadsCheck;
    if (this->unalignedLoadsCheck.hasValue())
      unalignedCheck = this->unalignedLoadsCheck;
    if (this->aieml.hasValue())
      aieml = this->aieml;
    auto *state = new VectState(func.getContext(), shiftParam, zeroOffset,
                                dupFactor, unalignedCheck, aieml);

    // Record the sign-extension/truncation ops up front.
    recordSextOps(func, state);

    // Compute the stack of loops enclosing each block of interest.
    for (auto forOp : func.getOps<affine::AffineForOp>()) {
      SmallVector<Operation *, 8> enclosingLoops;
      enclosingLoops.push_back(forOp);
      computeEnclosingLoopsPerBlock(forOp, state, enclosingLoops);
    }

    // Bail out early if any load is not aligned to the vector size.
    if (state->unalignedLoadsCheck && failed(hasUnalignedLoads(func, state))) {
      func.emitError() << "Cannot apply aie-vectorize to " << func->getName()
                       << " because alignment check has failed.\n";
      signalPassFailure();
      return;
    }

    // Compute the data reuse intervals for all the loads.
    computeReuseInFunc(func, state);
    // Canonicalize the operand order of mul and add ops.
    reassociateOpsInFunc(func, state);
    // Fuse mul and add/sub pairs into vector FMA ops.
    rewriteFMAOpsInFunc(func, state);
    // Make lhs operand vectors of mul/fma ops exclusive, then coalesce.
    coalesceLHSOpVectorsInFunc(func, state);
    // Fuse FMA ops that can share the column topology of the intrinsic.
    fuseFMAOpsForColumnTopology(func, state);
    // Lower mul/fma ops to AIE mul/mac intrinsics.
    generateAIEMulOrFMAOpsInFunc(func, state);
    // Move accumulator outputs back to vectors where needed.
    insertSRSOpsInFunc(func, state);
    // Lower the remaining add/sub ops to AIE intrinsics.
    generateAIEAddOrSubOpsInFunc(func, state);
    // Insert UPD ops to feed the vector registers from memory.
    insertUPDOpsInFunc(func, state);
    // Fuse mul/fma ops into convolution-style mul_conv/fma_conv ops.
    fuseMulFMAOpsByMulFMAConv(func, state);
  }

  // Final cleanup: canonicalize, CSE, LICM, and lower affine constructs.
  postCanonicalizeIR(module);
}

std::unique_ptr<mlir::Pass> createAIEVectorizePass() {
  return std::make_unique<AIEVectorize>();
}
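
// Sketch of driving this pass programmatically (the surrounding setup and
// module are assumed):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(createAIEVectorizePass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "aie-vectorize failed\n";
//
// The input is expected to be affine loop nests containing vector transfer
// ops; the output uses aievec ops (mul/mac, srs, upd, ...) plus standard
// dialects, cleaned up by the post-canonicalization pipeline above.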