#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "aievec-transforms"
43static Value vectorizeTensor(OpBuilder &rewriter, Location loc, Value tensor) {
44 auto opTy = tensor.getType();
45 auto shapeTy = cast<ShapedType>(opTy);
46 auto shape = shapeTy.getShape();
47 auto elemTy = shapeTy.getElementType();
48 auto toMemRefOp = rewriter.create<bufferization::ToMemrefOp>(
49 loc, MemRefType::get(shape, elemTy), tensor);
50 auto rank = shape.size();
51 auto newShape = shape.slice(0, rank - 2);
52 auto opVecElemTy = VectorType::get(shape.slice(rank - 2, 2), elemTy);
53 auto opMemrefVecTy = MemRefType::get(newShape, opVecElemTy);
55 rewriter.create<vector::TypeCastOp>(loc, opMemrefVecTy, toMemRefOp);
56 auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
57 loc, RankedTensorType::get(newShape, opVecElemTy), typeCastOp);
58 toTensorOp.setRestrict(
true);
59 return toTensorOp.getResult();
65static Value scalarizeTensor(OpBuilder &rewriter, Location loc, Value tensor) {
66 auto opTy = tensor.getType();
67 auto shapeTy = cast<ShapedType>(opTy);
69 auto vecShape = shapeTy.getShape();
70 auto vecElemTy = cast<VectorType>(shapeTy.getElementType());
71 auto elemTy = vecElemTy.getElementType();
72 auto toMemRefVecTyOp = rewriter.create<bufferization::ToMemrefOp>(
73 loc, MemRefType::get(vecShape, vecElemTy), tensor);
75 SmallVector<int64_t> scalShape;
76 for (
auto d : shapeTy.getShape())
77 scalShape.push_back(d);
78 for (
auto d : vecElemTy.getShape())
79 scalShape.push_back(d);
80 auto opMemrefScalTy = MemRefType::get(scalShape, elemTy);
82 rewriter.create<vector::TypeCastOp>(loc, opMemrefScalTy, toMemRefVecTyOp);
84 auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
85 loc, RankedTensorType::get(scalShape, elemTy), typeCastOp);
86 toTensorOp.setRestrict(
true);
87 return toTensorOp.getResult();
90static bool vectorizeContractionOpBlock(OpBuilder &rewriter, Location loc,
91 Block &srcBlock, Block &dstBlock) {
92 auto ctx = rewriter.getContext();
93 OpBuilder::InsertionGuard g(rewriter);
94 rewriter.setInsertionPointToStart(&dstBlock);
95 auto baA =
static_cast<Value
>(dstBlock.getArgument(0));
96 auto baB =
static_cast<Value
>(dstBlock.getArgument(1));
97 auto baC =
static_cast<Value
>(dstBlock.getArgument(2));
99 llvm::DenseMap<Value, Value> convertedValues;
100 convertedValues.try_emplace(srcBlock.getArgument(0), baA);
101 convertedValues.try_emplace(srcBlock.getArgument(1), baB);
102 convertedValues.try_emplace(srcBlock.getArgument(2), baC);
103 auto indexingMaps = rewriter.getAffineMapArrayAttr(
104 {AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
106 AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 2, 1}, ctx)
108 AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
110 auto iteratorTypes = rewriter.getArrayAttr(
111 {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel),
112 vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel),
113 vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction)});
114 bool addOpFound =
false, mulOpFound =
false;
115 WalkResult walkResult = srcBlock.walk([&](Operation *op) {
116 return llvm::TypeSwitch<Operation *, WalkResult>(op)
117 .Case<arith::AddIOp, arith::AddFOp>([&](
auto addOp) {
119 return WalkResult::interrupt();
121 auto lhs = addOp->getOperand(0);
122 auto rhs = addOp->getOperand(1);
124 auto lhsDefOp = lhs.getDefiningOp();
125 auto rhsDefOp = rhs.getDefiningOp();
126 if (lhsDefOp && isa<arith::MulIOp, arith::MulFOp>(lhsDefOp)) {
127 opA = convertedValues[lhsDefOp->getOperand(0)];
128 opB = convertedValues[lhsDefOp->getOperand(1)];
129 opC = convertedValues[rhs];
130 }
else if (rhsDefOp && isa<arith::MulIOp, arith::MulFOp>(rhsDefOp)) {
131 opA = convertedValues[rhsDefOp->getOperand(0)];
132 opB = convertedValues[rhsDefOp->getOperand(1)];
133 opC = convertedValues[lhs];
135 return WalkResult::interrupt();
136 auto conOp = rewriter.create<vector::ContractionOp>(
137 loc, opA, opB, opC, indexingMaps, iteratorTypes);
138 convertedValues.try_emplace(op->getResult(0), conOp.getResult());
139 return WalkResult::advance();
141 .Case<arith::MulIOp, arith::MulFOp>([&](
auto) {
143 return WalkResult::interrupt();
145 return WalkResult::skip();
147 .Case<linalg::YieldOp>([&](linalg::YieldOp yieldOp) {
148 rewriter.create<linalg::YieldOp>(
149 loc, convertedValues[yieldOp.getValues()[0]]);
150 return WalkResult::advance();
152 .Default([&](Operation *unaryOp) {
153 if (unaryOp->getNumResults() != 1 || unaryOp->getNumOperands() != 1)
154 return WalkResult::interrupt();
155 auto srcOpIn = unaryOp->getOperand(0);
156 auto srcOpInTy = srcOpIn.getType();
157 auto srcOpTy = unaryOp->getResultTypes()[0];
158 auto dstOpIn = convertedValues[srcOpIn];
159 Type dstOpTy = dstOpIn.getType();
160 if (srcOpInTy != srcOpTy) {
161 auto vecElemTy = dyn_cast<VectorType>(dstOpTy);
163 return WalkResult::interrupt();
164 dstOpTy = VectorType::get(vecElemTy.getShape(), srcOpTy);
167 rewriter.create(loc, unaryOp->getName().getIdentifier(),
168 {dstOpIn}, {dstOpTy}, unaryOp->getAttrs());
169 convertedValues.try_emplace(unaryOp->getResult(0),
170 newOp->getResult(0));
171 return WalkResult::advance();
174 return mulOpFound && addOpFound && !walkResult.wasInterrupted();
177DiagnosedSilenceableFailure transform::VectorizeContractionOp::applyToOne(
178 TransformRewriter &rewriter, linalg::GenericOp target,
179 ApplyToEachResultList &results, TransformState &state) {
181 auto ctx = target.getContext();
182 SmallVector<Value> inputs = target.getInputs();
183 if (SmallVector<Value> outputs = target.getOutputs();
184 inputs.size() != 2 || outputs.size() != 1)
185 return emitSilenceableError() <<
"payload is not a contraction.";
188 SmallVector<utils::IteratorType> iterators = target.getIteratorTypesArray();
189 auto innerMostIterators =
190 SmallVector<utils::IteratorType>(iterators.end() - 3, iterators.end());
191 auto outerMostIterators =
192 SmallVector<utils::IteratorType>(iterators.begin(), iterators.end() - 3);
194 if (!linalg::isParallelIterator(innerMostIterators[0]) ||
195 !linalg::isParallelIterator(innerMostIterators[1]) ||
196 !linalg::isReductionIterator(innerMostIterators[2]))
197 return emitSilenceableError()
198 <<
"linalg.generic op innermost iterators don't correspond with a "
199 "gemm-like contraction.";
201 auto indexingMaps = target.getIndexingMapsArray();
208 AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
211 AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 2, 1}, ctx)
214 AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
218 SmallVector<int64_t> outerMostResults;
219 for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
220 outerMostResults.push_back(i);
222 auto innerMostA = indexingMaps[0].dropResults(outerMostResults);
223 auto innerMostB = indexingMaps[1].dropResults(outerMostResults);
224 auto innerMostC = indexingMaps[2].dropResults(outerMostResults);
228 int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
229 if (innerMostA != mmAidxMap.shiftDims(numOuterMostDims) ||
230 innerMostB != mmBidxMap.shiftDims(numOuterMostDims) ||
231 innerMostC != mmCidxMap.shiftDims(numOuterMostDims))
232 return emitSilenceableError()
233 <<
"linalg.generic op innermost indexing maps don't correspond with "
234 "a gemm-like contraction.";
240 SmallVector<AffineExpr> remOuterDims;
241 for (
unsigned i = 0; i < numOuterMostDims; i++)
242 remOuterDims.push_back(getAffineDimExpr(i, ctx));
243 unsigned numResults = indexingMaps[0].getNumResults();
244 SmallVector<int64_t> positions = {numResults - 2, numResults - 1};
245 auto outerMostAidxMap =
246 indexingMaps[0].dropResults(positions).replaceDimsAndSymbols(
247 remOuterDims, {}, numOuterMostDims, 0);
248 auto outerMostBidxMap =
249 indexingMaps[1].dropResults(positions).replaceDimsAndSymbols(
250 remOuterDims, {}, numOuterMostDims, 0);
251 auto outerMostCidxMap =
252 indexingMaps[2].dropResults(positions).replaceDimsAndSymbols(
253 remOuterDims, {}, numOuterMostDims, 0);
255 rewriter.setInsertionPoint(target);
256 Location loc = target.getLoc();
258 auto opA = vectorizeTensor(rewriter, loc, target.getInputs()[0]);
259 auto opB = vectorizeTensor(rewriter, loc, target.getInputs()[1]);
260 auto opC = vectorizeTensor(rewriter, loc, target.getOutputs()[0]);
263 auto newOp = rewriter.create<linalg::GenericOp>(
264 loc, TypeRange({opC.getType()}), ValueRange({opA, opB}),
266 SmallVector<AffineMap>(
267 {outerMostAidxMap, outerMostBidxMap, outerMostCidxMap}),
269 auto &opBody = newOp->getRegion(0);
270 opBody.push_back(
new Block());
271 auto &opBlock = opBody.front();
272 opBlock.addArguments({cast<TensorType>(opA.getType()).getElementType(),
273 cast<TensorType>(opB.getType()).getElementType(),
274 cast<TensorType>(opC.getType()).getElementType()},
276 if (!vectorizeContractionOpBlock(rewriter, loc, target->getRegion(0).front(),
278 return emitSilenceableError()
279 <<
"linalg.generic op payload does not correspond with a "
280 "vectorizable contraction.";
283 auto res = scalarizeTensor(rewriter, loc, newOp.getResults()[0]);
284 rewriter.replaceOp(target, res);
286 results.push_back(newOp);
288 return DiagnosedSilenceableFailure::success();
#define GET_OP_CLASSES
#include "aie/Dialect/AIEVec/TransformOps/AIEVecTransformOps.cpp.inc"