//===- AIEVecTransformOps.cpp -----------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023-2024 Advanced Micro Devices, Inc. or its affiliates
//
//===----------------------------------------------------------------------===//

#include "aie/Dialect/AIEVec/TransformOps/AIEVecTransformOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

#define DEBUG_TYPE "aievec-transforms"

//===----------------------------------------------------------------------===//
// VectorizeContractionOp
//===----------------------------------------------------------------------===//

// Emit IR to convert the given tensor in the form tensor<...xMxNxTy> into a
// tensor<...xvector<MxNxTy>>. It does so by bufferizing, casting, and
// tensorizing.
//
// E.g., for a `%t : tensor<64x64x8x4xf32>`, it will generate the following IR:
// ```
// %0 = bufferization.to_memref %t : memref<64x64x8x4xf32>
// %1 = vector.type_cast %0 : memref<64x64xvector<8x4xf32>>
// %2 = bufferization.to_tensor %1 restrict : memref<64x64xvector<8x4xf32>>
// ```
static Value vectorizeTensor(OpBuilder &rewriter, Location loc, Value tensor) {
  auto opTy = tensor.getType();
  auto shapeTy = cast<ShapedType>(opTy);
  auto shape = shapeTy.getShape();
  auto elemTy = shapeTy.getElementType();
  auto toMemRefOp = rewriter.create<bufferization::ToMemrefOp>(
      loc, MemRefType::get(shape, elemTy), tensor);
  auto rank = shape.size();
  assert(rank >= 2 && "expected a tensor of rank >= 2");
  auto newShape = shape.slice(0, rank - 2);
  auto opVecElemTy = VectorType::get(shape.slice(rank - 2, 2), elemTy);
  auto opMemrefVecTy = MemRefType::get(newShape, opVecElemTy);
  auto typeCastOp =
      rewriter.create<vector::TypeCastOp>(loc, opMemrefVecTy, toMemRefOp);
  auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, RankedTensorType::get(newShape, opVecElemTy), typeCastOp);
  toTensorOp.setRestrict(true);
  return toTensorOp.getResult();
}

// Emit IR to convert the given tensor in the form tensor<...xvector<MxNxTy>>
// into a tensor<...xMxNxTy>. It performs the inverse operation to
// `vectorizeTensor` above.
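//
// E.g., for a `%t : tensor<64x64xvector<8x4xf32>>`, it will generate the
// following IR:
// ```
// %0 = bufferization.to_memref %t : memref<64x64xvector<8x4xf32>>
// %1 = vector.type_cast %0 : memref<64x64x8x4xf32>
// %2 = bufferization.to_tensor %1 restrict : memref<64x64x8x4xf32>
// ```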
static Value scalarizeTensor(OpBuilder &rewriter, Location loc, Value tensor) {
  auto opTy = tensor.getType();
  auto shapeTy = cast<ShapedType>(opTy);

  auto vecShape = shapeTy.getShape();
  auto vecElemTy = cast<VectorType>(shapeTy.getElementType());
  auto elemTy = vecElemTy.getElementType();
  auto toMemRefVecTyOp = rewriter.create<bufferization::ToMemrefOp>(
      loc, MemRefType::get(vecShape, vecElemTy), tensor);

  // Concatenate the outer (tensor) and inner (vector) shapes into the fully
  // scalar shape.
  SmallVector<int64_t> scalShape;
  for (auto d : vecShape)
    scalShape.push_back(d);
  for (auto d : vecElemTy.getShape())
    scalShape.push_back(d);
  auto opMemrefScalTy = MemRefType::get(scalShape, elemTy);
  auto typeCastOp =
      rewriter.create<vector::TypeCastOp>(loc, opMemrefScalTy, toMemRefVecTyOp);

  auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, RankedTensorType::get(scalShape, elemTy), typeCastOp);
  toTensorOp.setRestrict(true);
  return toTensorOp.getResult();
}

static bool vectorizeContractionOpBlock(OpBuilder &rewriter, Location loc,
                                        Block &srcBlock, Block &dstBlock) {
  auto ctx = rewriter.getContext();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(&dstBlock);
  auto baA = static_cast<Value>(dstBlock.getArgument(0));
  auto baB = static_cast<Value>(dstBlock.getArgument(1));
  auto baC = static_cast<Value>(dstBlock.getArgument(2));
  // Store vectorized values for op replacement
  llvm::DenseMap<Value, Value> convertedValues;
  convertedValues.try_emplace(srcBlock.getArgument(0), baA);
  convertedValues.try_emplace(srcBlock.getArgument(1), baB);
  convertedValues.try_emplace(srcBlock.getArgument(2), baC);
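  // Indexing maps and iterator types for a single gemm-like contraction,
  // C[m, n] += A[m, k] * B[k, n], with m and n parallel and k a reduction.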
  auto indexingMaps = rewriter.getAffineMapArrayAttr(
      {AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
           .dropResults(0),
       AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 2, 1}, ctx)
           .dropResults(0),
       AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
           .dropResults(0)});
  auto iteratorTypes = rewriter.getArrayAttr(
      {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel),
       vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel),
       vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction)});
  bool addOpFound = false, mulOpFound = false;
  WalkResult walkResult = srcBlock.walk([&](Operation *op) {
    return llvm::TypeSwitch<Operation *, WalkResult>(op)
        .Case<arith::AddIOp, arith::AddFOp>([&](auto addOp) {
          if (addOpFound)
            return WalkResult::interrupt();
          addOpFound = true;
          auto lhs = addOp->getOperand(0);
          auto rhs = addOp->getOperand(1);
          Value opA, opB, opC;
          auto lhsDefOp = lhs.getDefiningOp();
          auto rhsDefOp = rhs.getDefiningOp();
          if (lhsDefOp && isa<arith::MulIOp, arith::MulFOp>(lhsDefOp)) {
            opA = convertedValues[lhsDefOp->getOperand(0)];
            opB = convertedValues[lhsDefOp->getOperand(1)];
            opC = convertedValues[rhs];
          } else if (rhsDefOp && isa<arith::MulIOp, arith::MulFOp>(rhsDefOp)) {
            opA = convertedValues[rhsDefOp->getOperand(0)];
            opB = convertedValues[rhsDefOp->getOperand(1)];
            opC = convertedValues[lhs];
          } else
            return WalkResult::interrupt();
          auto conOp = rewriter.create<vector::ContractionOp>(
              loc, opA, opB, opC, indexingMaps, iteratorTypes);
          convertedValues.try_emplace(op->getResult(0), conOp.getResult());
          return WalkResult::advance();
        })
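        // The multiply feeding the add has already been absorbed into the
        // vector.contract created above, so it is only noted and skipped here.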
        .Case<arith::MulIOp, arith::MulFOp>([&](auto) {
          if (mulOpFound)
            return WalkResult::interrupt();
          mulOpFound = true;
          return WalkResult::skip();
        })
        .Case<linalg::YieldOp>([&](linalg::YieldOp yieldOp) {
          rewriter.create<linalg::YieldOp>(
              loc, convertedValues[yieldOp.getValues()[0]]);
          // Advance, not interrupt: an interrupted walk is reported as a
          // failed match below, and the yield is the last op in the block
          // anyway.
          return WalkResult::advance();
        })
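        // Any other op must be elementwise with exactly one operand and one
        // result; re-create it over the vectorized operand, widening the
        // result type to a vector when the op changes the element type.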
        .Default([&](Operation *unaryOp) {
          if (unaryOp->getNumResults() != 1 || unaryOp->getNumOperands() != 1)
            return WalkResult::interrupt();
          auto srcOpIn = unaryOp->getOperand(0);
          auto srcOpInTy = srcOpIn.getType();
          auto srcOpTy = unaryOp->getResultTypes()[0];
          auto dstOpIn = convertedValues[srcOpIn];
          Type dstOpTy = dstOpIn.getType();
          if (srcOpInTy != srcOpTy) {
            auto vecElemTy = dyn_cast<VectorType>(dstOpTy);
            if (!vecElemTy)
              return WalkResult::interrupt();
            dstOpTy = VectorType::get(vecElemTy.getShape(), srcOpTy);
          }
          auto newOp =
              rewriter.create(loc, unaryOp->getName().getIdentifier(),
                              {dstOpIn}, {dstOpTy}, unaryOp->getAttrs());
          convertedValues.try_emplace(unaryOp->getResult(0),
                                      newOp->getResult(0));
          return WalkResult::advance();
        });
  });
  return mulOpFound && addOpFound && !walkResult.wasInterrupted();
}

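// Rewrites a gemm-like `linalg.generic`, whose three innermost iterators form
// a contraction, into a `linalg.generic` over the remaining outer dimensions
// that contracts `vector`-typed operands in its body.
//
// As an illustrative sketch (shapes and names here are hypothetical), a
// payload such as:
// ```
// #mapA = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// #mapB = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
// #mapC = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
// %r = linalg.generic
//          {indexing_maps = [#mapA, #mapB, #mapC],
//           iterator_types = ["parallel", "parallel", "reduction",
//                             "parallel", "parallel", "reduction"]}
//          ins(%A, %B : tensor<16x8x4x8xf32>, tensor<8x4x8x4xf32>)
//          outs(%C : tensor<16x4x4x4xf32>) { ... }
// ```
// is rewritten into a 3-D `linalg.generic` over `tensor<16x8xvector<4x8xf32>>`,
// `tensor<8x4xvector<8x4xf32>>`, and `tensor<16x4xvector<4x4xf32>>` operands
// whose body performs a `vector.contract`.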
DiagnosedSilenceableFailure transform::VectorizeContractionOp::applyToOne(
    TransformRewriter &rewriter, linalg::GenericOp target,
    ApplyToEachResultList &results, TransformState &state) {

  auto ctx = target.getContext();
  SmallVector<Value> inputs = target.getInputs();
  if (SmallVector<Value> outputs = target.getOutputs();
      inputs.size() != 2 || outputs.size() != 1)
    return emitSilenceableError() << "payload is not a contraction.";

  // Split the iterators in two: inner contraction + remaining
  SmallVector<utils::IteratorType> iterators = target.getIteratorTypesArray();
  if (iterators.size() < 3)
    return emitSilenceableError()
           << "linalg.generic op must have at least 3 iterators.";
  auto innerMostIterators =
      SmallVector<utils::IteratorType>(iterators.end() - 3, iterators.end());
  auto outerMostIterators =
      SmallVector<utils::IteratorType>(iterators.begin(), iterators.end() - 3);

  if (!linalg::isParallelIterator(innerMostIterators[0]) ||
      !linalg::isParallelIterator(innerMostIterators[1]) ||
      !linalg::isReductionIterator(innerMostIterators[2]))
    return emitSilenceableError()
           << "linalg.generic op innermost iterators don't correspond with a "
              "gemm-like contraction.";

  auto indexingMaps = target.getIndexingMapsArray();
  //===
  // Verify that the innermost dimensions are a contraction
  //===

  // 1. Build the indexing maps for the operands of a GEMM contraction
  auto mmAidxMap =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
          .dropResults(0);
  auto mmBidxMap =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 2, 1}, ctx)
          .dropResults(0);
  auto mmCidxMap =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
          .dropResults(0);
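  // Concretely:
  //   mmAidxMap: (d0, d1, d2) -> (d0, d2)   // A[m, k]
  //   mmBidxMap: (d0, d1, d2) -> (d2, d1)   // B[k, n]
  //   mmCidxMap: (d0, d1, d2) -> (d0, d1)   // C[m, n]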

  // 2. Get the indexing maps for the two innermost dimensions of each operand
  SmallVector<int64_t> outerMostResults;
  for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
    outerMostResults.push_back(i);

  auto innerMostA = indexingMaps[0].dropResults(outerMostResults);
  auto innerMostB = indexingMaps[1].dropResults(outerMostResults);
  auto innerMostC = indexingMaps[2].dropResults(outerMostResults);

  // 3. Compare the extended GEMM contraction indexing maps with the indexing
  //    maps of the innermost results.
  int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
  if (innerMostA != mmAidxMap.shiftDims(numOuterMostDims) ||
      innerMostB != mmBidxMap.shiftDims(numOuterMostDims) ||
      innerMostC != mmCidxMap.shiftDims(numOuterMostDims))
    return emitSilenceableError()
           << "linalg.generic op innermost indexing maps don't correspond with "
              "a gemm-like contraction.";

  //===
  // Create new indexing maps for the vectorized operation
  //===

  SmallVector<AffineExpr> remOuterDims;
  for (unsigned i = 0; i < numOuterMostDims; i++)
    remOuterDims.push_back(getAffineDimExpr(i, ctx));
  unsigned numResults = indexingMaps[0].getNumResults();
  SmallVector<int64_t> positions = {numResults - 2, numResults - 1};
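  // Drop the two innermost results from each map and renumber the remaining
  // dims; e.g., with three outer dims, an input map
  //   (d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)
  // becomes
  //   (d0, d1, d2) -> (d0, d2).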
  auto outerMostAidxMap =
      indexingMaps[0].dropResults(positions).replaceDimsAndSymbols(
          remOuterDims, {}, numOuterMostDims, 0);
  auto outerMostBidxMap =
      indexingMaps[1].dropResults(positions).replaceDimsAndSymbols(
          remOuterDims, {}, numOuterMostDims, 0);
  auto outerMostCidxMap =
      indexingMaps[2].dropResults(positions).replaceDimsAndSymbols(
          remOuterDims, {}, numOuterMostDims, 0);

  rewriter.setInsertionPoint(target);
  Location loc = target.getLoc();
  // Insert reshape ops for input operands
  auto opA = vectorizeTensor(rewriter, loc, target.getInputs()[0]);
  auto opB = vectorizeTensor(rewriter, loc, target.getInputs()[1]);
  auto opC = vectorizeTensor(rewriter, loc, target.getOutputs()[0]);

  // Create new linalg.generic with vector arguments and vectorized basic block
  auto newOp = rewriter.create<linalg::GenericOp>(
      loc, TypeRange({opC.getType()}), ValueRange({opA, opB}),
      ValueRange({opC}),
      SmallVector<AffineMap>(
          {outerMostAidxMap, outerMostBidxMap, outerMostCidxMap}),
      outerMostIterators);
  auto &opBody = newOp->getRegion(0);
  opBody.push_back(new Block());
  auto &opBlock = opBody.front();
  opBlock.addArguments({cast<TensorType>(opA.getType()).getElementType(),
                        cast<TensorType>(opB.getType()).getElementType(),
                        cast<TensorType>(opC.getType()).getElementType()},
                       {loc, loc, loc});
  if (!vectorizeContractionOpBlock(rewriter, loc, target->getRegion(0).front(),
                                   opBlock))
    return emitSilenceableError()
           << "linalg.generic op payload does not correspond with a "
              "vectorizable contraction.";

  // Insert reshape ops for output operand
  auto res = scalarizeTensor(rewriter, loc, newOp.getResults()[0]);
  rewriter.replaceOp(target, res);

  results.push_back(newOp);

  return DiagnosedSilenceableFailure::success();
}

#define GET_OP_CLASSES
#include "aie/Dialect/AIEVec/TransformOps/AIEVecTransformOps.cpp.inc"