MLIR-AIE
VectorToVectorConversions.cpp
//===- VectorToVectorConversions.cpp - Conversions within Vector -*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023-2024 Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//
// This file contains conversions and rewrites to the Vector dialect to make
// it compatible with the vector instructions available in AIE architectures.
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <algorithm>
#include <functional>
#include <numeric>

#define DEBUG_TYPE "aievec-canonicalization"

using namespace mlir;
using namespace arith;
using namespace vector;
using namespace xilinx;
using namespace xilinx::aievec;

//============================================================================//
//================ AIE target and backend decoding utilities =================//
//============================================================================//

static TargetBackend decodeTargetBackend(const std::string &backend) {
  if (!backend.empty()) {
    if (backend == "llvmir")
      return TargetBackend::LLVMIR;
    if (backend != "cpp")
      return TargetBackend::UNKNOWN;
  }
  return TargetBackend::CPP;
}

static AIEArch decodeAIETarget(const std::string &target) {
  if (!target.empty()) {
    if (target == "aieml" || target == "aie2")
      return AIEArch::AIE2;
    if (target != "aie")
      return AIEArch::UNKNOWN;
  }
  return AIEArch::AIE;
}

//============================================================================//
//================== Common AIE canonicalization analysis ====================//
//============================================================================//

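// Returns true if `op` is an additive contraction whose three innermost
// iterators are (parallel, parallel, reduction) and whose innermost indexing
// maps access the operands as A[m, k], B[n, k], and C[m, n], i.e. a gemm-like
// contraction in which the B operand is read as if it were transposed.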
static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
  if (op.getKind() != vector::CombiningKind::ADD)
    return false;

  // Get and check shape of operands
  auto lhsShape = op.getLhsType().getShape();
  auto rhsShape = op.getRhsType().getShape();
  auto accShape = cast<ShapedType>(op.getAccType()).getShape();
  if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
    return false;

  // Check that the innermost iterators match gemm-like iterators
  SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
  if (iterators.size() < 3)
    return false;
  auto innerMostIterators =
      SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end());
  if (vector::IteratorType::parallel != innerMostIterators[0] ||
      vector::IteratorType::parallel != innerMostIterators[1] ||
      vector::IteratorType::reduction != innerMostIterators[2])
    return false;

  // Get indexing maps of iterators for operands
  SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
  SmallVector<int64_t> outerMostResults;
  for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
    outerMostResults.push_back(i);

  auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
  auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
  auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);

  // Check whether they conform to a "transposed B" gemm
  auto *ctx = op.getContext();
  auto mmAidxMap =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
          .dropResults(0);
  auto mmBidxMap =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
          .dropResults(0);
  auto mmCidxMap =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
          .dropResults(0);
  int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
  return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
         innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
         innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
}

//============================================================================//
//============ Common AIE canonicalization conversion patterns ===============//
//============================================================================//

// This pattern converts a `vector.transfer_read` with an unaligned access
// into an aligned `vector.transfer_read` twice as long, followed by a
// `vector.extract_strided_slice` selecting the subvector matching the
// original `vector.transfer_read`.
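//
// For example, assuming 32-bit elements, a 256-bit alignment boundary, and an
// access starting 4 elements past an aligned address (illustrative IR, not
// taken from a test):
//
//   %v = vector.transfer_read %m[%i], %p : memref<64xi32>, vector<8xi32>
//
// is rewritten as:
//
//   %i0 = affine.apply affine_map<(d0) -> (d0 - 4)>(%i)
//   %w  = vector.transfer_read %m[%i0], %p : memref<64xi32>, vector<16xi32>
//   %v  = vector.extract_strided_slice %w
//           {offsets = [4], sizes = [8], strides = [1]}
//           : vector<16xi32> to vector<8xi32>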
struct SplitUnalignedTransferReadPattern
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  SplitUnalignedTransferReadPattern(MLIRContext *context,
                                    int64_t maxVectorSize, int64_t alignment)
      : OpConversionPattern<vector::TransferReadOp>(context),
        maxVectorSize(maxVectorSize), alignment(alignment) {}

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that it's not a splat transfer read.
    if (adaptor.getPermutationMap().isConstant())
      return failure();

    // Check if the transfer is unaligned.
    auto vType = readOp.getVectorType();
    int64_t offset =
        getTransferReadAlignmentOffset(adaptor, vType, alignment).value_or(0);
    if (offset == 0)
      return failure();

    // Verify that we can load a vector 2x as long as the original
    auto vLen = vType.getShape().back();
    auto longVecTy = VectorType::get(2 * vLen, vType.getElementType());
    auto longVecSize = getElementSizeInBits(vType) * 2 * vLen;
    if (longVecSize > maxVectorSize)
      return failure();

    // Calculate the aligned indices for the lower and higher parts.
    // TODO: Add support for cases where the offset is greater than the
    // TODO: vector length.
    auto loc = readOp.getLoc();
    Value oldInnerMostIdx = adaptor.getIndices().back();
    auto offsetCorrectionMap =
        AffineMap::get(1, 0, getAffineDimExpr(0, readOp.getContext()) - offset);
    Value newInnerMostIdx = rewriter
                                .create<affine::AffineApplyOp>(
                                    readOp.getLoc(), offsetCorrectionMap,
                                    SmallVector<Value, 1>({oldInnerMostIdx}))
                                .getResult();
    SmallVector<Value, 8> alignedIdx;
    alignedIdx.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
    alignedIdx[alignedIdx.size() - 1] = newInnerMostIdx;

    // Create the aligned transfer read for a vector 2x as long that covers the
    // elements of the unaligned vector.
    auto newReadOp = rewriter.create<vector::TransferReadOp>(
        loc, longVecTy, adaptor.getSource(), alignedIdx, adaptor.getPadding());

    // Create a `vector.extract_strided_slice` to extract the unaligned vector.
    rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
        readOp, newReadOp.getResult(), offset, vLen, 1);

    return success();
  }

  int64_t maxVectorSize;
  int64_t alignment;
};

// This pattern converts a `vector.transfer_read` with a splat permutation map
// into a contiguous `vector.transfer_read` followed by a `vector.extract` to
// obtain the splat value and a `vector.splat` to broadcast it into a vector
// of the right size.
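//
// For example (illustrative IR, not taken from a test):
//
//   %v = vector.transfer_read %m[%i], %p
//          {permutation_map = affine_map<(d0) -> (0)>}
//          : memref<64xi32>, vector<8xi32>
//
// becomes, roughly:
//
//   %w = vector.transfer_read %m[%i], %p : memref<64xi32>, vector<8xi32>
//   %e = vector.extract %w[0] : i32 from vector<8xi32>
//   %v = vector.splat %e : vector<8xi32>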
struct ConvertSplatTransferReadToBroadcastPattern
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  ConvertSplatTransferReadToBroadcastPattern(MLIRContext *context)
      : OpConversionPattern<vector::TransferReadOp>(context) {}

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    AffineMap map = readOp.getPermutationMap();
    if (!map.isConstant())
      return failure();

    Value srcMemRef = adaptor.getSource();
    SmallVector<Value, 8> indices;
    Value newIdx;
    int64_t offset = 0;
    // If it's a zero-rank memory access
    if (cast<MemRefType>(srcMemRef.getType()).getRank() == 0) {
      srcMemRef = rewriter
                      .create<memref::ExpandShapeOp>(
                          readOp.getLoc(), SmallVector<int64_t, 1>({1}),
                          srcMemRef, SmallVector<ReassociationIndices, 1>({}))
                      .getResult();
      newIdx = rewriter.create<arith::ConstantOp>(readOp.getLoc(),
                                                  rewriter.getIndexAttr(0L));
      indices.push_back(newIdx);
    } else {
      indices.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
      newIdx = indices[indices.size() - 1];
      // If the innermost index comes from an `affine.apply` op, take the base
      // as the new innermost index for the new `vector.transfer_read`, and the
      // offset as the position for the `vector.extract` op below.
      if (auto applyOp = newIdx.getDefiningOp<affine::AffineApplyOp>())
        if (applyOp.getAffineMap().getNumDims() == 1) {
          newIdx = applyOp.getMapOperands()[0];
          offset = applyOp.getAffineMap().compose(ArrayRef<int64_t>{0})[0];
        }
    }
    // XXX: We assume we are reading 1D vectors
    int64_t vlen = readOp.getVector().getType().getShape()[0];
    if (offset >= vlen) {
      // If the splat element is beyond the first vector, we calculate the
      // address of the vector containing the element.
      int64_t numElemsToSkip = vlen * (offset / vlen);
      offset = offset % vlen;
      auto newAddrMap = AffineMap::get(
          1, 0, getAffineDimExpr(0, readOp.getContext()) + numElemsToSkip);
      newIdx =
          rewriter
              .create<affine::AffineApplyOp>(readOp.getLoc(), newAddrMap,
                                             SmallVector<Value, 1>({newIdx}))
              .getResult();
    }
    indices[indices.size() - 1] = newIdx;
    auto newReadOp = rewriter.create<vector::TransferReadOp>(
        readOp.getLoc(), readOp.getVector().getType(), srcMemRef, indices,
        adaptor.getPadding());
    auto extractOp = rewriter.create<vector::ExtractOp>(
        readOp.getLoc(), newReadOp.getResult(), ArrayRef<int64_t>{offset});
    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        readOp, newReadOp.getVector().getType(), extractOp.getResult());
    return success();
  }
};

// This pattern moves cast operations as close as possible to the source of
// the data. This simplifies matching sequences that differ only by where such
// casts sit between data movement operations and arithmetic ops.
// TODO: Generalize this pattern and instantiate it for different cast ops.
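//
// For example (illustrative IR), the sign extension below is hoisted above the
// broadcast so that it applies directly to the narrower scalar:
//
//   %b = vector.broadcast %s : i8 to vector<32xi8>
//   %e = arith.extsi %b : vector<32xi8> to vector<32xi32>
//
// becomes
//
//   %x = arith.extsi %s : i8 to i32
//   %e = vector.broadcast %x : i32 to vector<32xi32>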
struct HoistCastOpToDataSourcePattern : public RewritePattern {
  HoistCastOpToDataSourcePattern(MLIRContext *context)
      : RewritePattern(arith::ExtSIOp::getOperationName(), /*benefit=*/1,
                       context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    arith::ExtSIOp extOp = cast<arith::ExtSIOp>(op);
    Operation *defOp = extOp.getIn().getDefiningOp();
    // If the input already comes straight from a data source op, there is
    // nothing left to hoist.
    if (!defOp || isa<vector::TransferReadOp, memref::LoadOp,
                      affine::AffineLoadOp, func::CallOp>(defOp))
      return failure();

    // At the moment, we only accept ops we know we can swap with cast.
    if (!isa<vector::BroadcastOp, vector::ExtractOp, vector::SplatOp,
             vector::ExtractStridedSliceOp>(defOp))
      return failure();

    Type extOpInTy = extOp.getIn().getType();
    SmallVector<Value, 4> inputs;
    for (Value operand : defOp->getOperands()) {
      Type operandTy = operand.getType();
      VectorType extOpInVecTy = dyn_cast<VectorType>(extOpInTy);
      VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
      if (operandTy == extOpInTy) {
        Type outTy = extOp.getOut().getType();
        inputs.push_back(
            rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
                .getOut());
      } else if (extOpInVecTy && extOpInVecTy.getElementType() == operandTy) {
        // The cast input is a vector but this operand is a scalar of its
        // element type: narrow to a scalar -> scalar conversion.
        Type outTy =
            cast<VectorType>(extOp.getOut().getType()).getElementType();
        inputs.push_back(
            rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
                .getOut());
      } else if (operandVecTy && operandVecTy.getElementType() == extOpInTy) {
        // The cast input is a scalar but this operand is a vector of that
        // type: promote to a vector -> vector conversion.
        Type outTy =
            VectorType::get(operandVecTy.getShape(), extOp.getOut().getType());
        inputs.push_back(
            rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
                .getOut());
      } else if (extOpInVecTy && operandVecTy &&
                 (extOpInVecTy.getElementType() ==
                  operandVecTy.getElementType())) {
        // Hoist through a vector shape change
        Type outTy = VectorType::get(
            operandVecTy.getShape(),
            cast<VectorType>(extOp.getOut().getType()).getElementType());
        inputs.push_back(
            rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
                .getOut());
      } else {
        inputs.push_back(operand);
      }
    }

    auto *newOp =
        rewriter.create(extOp->getLoc(), defOp->getName().getIdentifier(),
                        inputs, {extOp.getOut().getType()}, defOp->getAttrs());
    rewriter.replaceOp(extOp, newOp->getResult(0));
    return success();
  }
};

// This pattern swaps a `UnaryOpA` followed by a `UnaryOpB`. It can be used to
// improve pattern matching for mixed-type arithmetic ops, by moving sign
// extension ops closer to the single-type arithmetic operations.
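//
// For example, SwapUnaryOpsPattern<arith::ExtSIOp, vector::BroadcastOp>, as
// instantiated in ReorderOperationsPass below, rewrites (illustrative types):
//
//   %e = arith.extsi %v : vector<8xi8> to vector<8xi32>
//   %b = vector.broadcast %e : vector<8xi32> to vector<1x8xi32>
//
// into
//
//   %b = vector.broadcast %v : vector<8xi8> to vector<1x8xi8>
//   %e = arith.extsi %b : vector<1x8xi8> to vector<1x8xi32>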
template <class UnaryOpA, class UnaryOpB>
struct SwapUnaryOpsPattern : public OpRewritePattern<UnaryOpB> {
  // This function takes the chain of operations A->B, and returns the new type
  // between B and A after the swap.
  using InferTypeB2AFnTy = std::function<Type(UnaryOpA aOp, UnaryOpB bOp)>;
  InferTypeB2AFnTy inferTypeB2A;

  SwapUnaryOpsPattern(MLIRContext *context, InferTypeB2AFnTy inferType)
      : OpRewritePattern<UnaryOpB>(context), inferTypeB2A(inferType) {}

  LogicalResult matchAndRewrite(UnaryOpB bOp,
                                PatternRewriter &rewriter) const override {
    static_assert(
        UnaryOpA::template hasTrait<OpTrait::OneOperand>(),
        "SwapUnaryOps can only be instantiated for single-operand ops");
    static_assert(
        UnaryOpB::template hasTrait<OpTrait::OneOperand>(),
        "SwapUnaryOps can only be instantiated for single-operand ops");
    UnaryOpA aOp = bOp.getOperand().template getDefiningOp<UnaryOpA>();
    if (!aOp)
      return rewriter.notifyMatchFailure(bOp, UnaryOpB::getOperationName() +
                                                  " not preceded by " +
                                                  UnaryOpA::getOperationName());

    Type newA2BTy = inferTypeB2A(aOp, bOp);

    auto newA =
        rewriter.create<UnaryOpB>(bOp->getLoc(), SmallVector<Type>({newA2BTy}),
                                  aOp->getOperands(), bOp->getAttrs());
    auto newB = rewriter.create<UnaryOpA>(
        bOp->getLoc(), SmallVector<Type>({bOp.getResult().getType()}),
        newA->getResults(), aOp->getAttrs());
    rewriter.replaceOp(bOp, newB.getResult());
    return success();
  }
};

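// Linearizes the `numDims` innermost indices of a memref access (assumed to
// have an identity-like, row-major layout) into a single index, emitting an
// `affine.apply` op. E.g., for numDims = 2 and an innermost dimension of size
// 8, indices [.., %i, %j] become [.., %i * 8 + %j].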
static SmallVector<Value> collapseInnerMostDimIndices(PatternRewriter &b,
                                                      Location loc, int numDims,
                                                      ValueRange indices,
                                                      ArrayRef<int64_t> shape,
                                                      AffineMap layout) {
  // TODO: Don't assume trivial layout
  assert(layout.isMinorIdentity() &&
         "dimension collapse in non-identity layout is not implemented");
  auto newIdxExpr = b.getAffineDimExpr(numDims - 1);
  int64_t stride = 1;
  for (int64_t dim = numDims - 2; dim >= 0; dim--) {
    stride *= shape[shape.size() - (numDims - dim - 1)];
    newIdxExpr = newIdxExpr + b.getAffineDimExpr(dim) * stride;
  }
  auto newIndexMap = AffineMap::get(numDims, 0, newIdxExpr);
  Value newInnerMostIdxValue =
      b.create<affine::AffineApplyOp>(loc, newIndexMap,
                                      indices.take_back(numDims))
          .getResult();
  SmallVector<Value> newIdxRange;
  for (auto idx : indices.drop_back(numDims))
    newIdxRange.push_back(idx);
  newIdxRange.push_back(newInnerMostIdxValue);
  return newIdxRange;
}

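// Collapses the `numDims` innermost dimensions of the memref value `val` into
// a single dimension with a `memref.collapse_shape` op. E.g., for numDims = 2,
// a memref<4x8x16xi32> becomes a memref<4x128xi32>.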
static Value collapseInnerMostShapeDims(PatternRewriter &b, Location loc,
                                        int numDims, Value val) {
  auto memRefTy = cast<MemRefType>(val.getType());
  auto shape = memRefTy.getShape();
  int64_t newInnerMostDim = std::accumulate(shape.end() - numDims, shape.end(),
                                            1, std::multiplies<>());
  SmallVector<int64_t, 4> newShape{shape.begin(), shape.end() - numDims + 1};
  newShape[shape.size() - numDims] = newInnerMostDim;
  auto newNumDims = newShape.size();
  auto *ctx = b.getContext();
  auto newMemRefTy = MemRefType::get(
      newShape, memRefTy.getElementType(),
      AffineMap::getMinorIdentityMap(newNumDims, newNumDims, ctx),
      memRefTy.getMemorySpace());
  auto reassocIndices =
      getReassociationIndicesForCollapse(shape, newShape).value();
  return b
      .create<memref::CollapseShapeOp>(loc, newMemRefTy, val, reassocIndices)
      .getResult();
}

// This pattern flattens multidimensional `vector.transfer_read` operations,
// replacing them with a `memref.collapse_shape`, a 1D `vector.transfer_read`,
// and a `vector.shape_cast`.
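//
// For example (illustrative IR, not taken from a test):
//
//   %v = vector.transfer_read %m[%i, %c0, %c0], %p
//          : memref<16x4x8xi32>, vector<4x8xi32>
//
// becomes, roughly:
//
//   %mc = memref.collapse_shape %m [[0], [1, 2]]
//           : memref<16x4x8xi32> into memref<16x32xi32>
//   %k  = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%c0, %c0)
//   %w  = vector.transfer_read %mc[%i, %k], %p
//           : memref<16x32xi32>, vector<32xi32>
//   %v  = vector.shape_cast %w : vector<32xi32> to vector<4x8xi32>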
struct FlattenMultDimTransferReadPattern
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We can only deal with unmasked transfer ops with an identity permutation
    // map.
    if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
      return failure();
    VectorType vectorTy = readOp.getVector().getType();
    if (vectorTy.getRank() < 2)
      return failure();
    // Work only on bufferized reads
    MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getSource().getType());
    if (!memRefTy)
      return failure();
    auto memRefShape = memRefTy.getShape();
    auto vecShape = vectorTy.getShape();

    auto newVectorTy =
        VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
                                         std::multiplies<>())},
                        vectorTy.getElementType());
    AffineMap layout = memRefTy.getLayout().getAffineMap();
    auto newIndices =
        collapseInnerMostDimIndices(rewriter, readOp.getLoc(), vecShape.size(),
                                    adaptor.getIndices(), memRefShape, layout);
    auto newSource = collapseInnerMostShapeDims(
        rewriter, readOp.getLoc(), vecShape.size(), adaptor.getSource());
    auto newVector = rewriter.create<vector::TransferReadOp>(
        readOp.getLoc(), newVectorTy, newSource, newIndices);

    auto inBoundsArrayAttrOpt = adaptor.getInBounds();
    if (inBoundsArrayAttrOpt) {
      SmallVector<bool> inBounds =
          llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
      SmallVector<bool> newInBounds({false});
      newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
                                   [](bool v) { return v; });
      newVector.getProperties().setInBounds(
          rewriter.getBoolArrayAttr(newInBounds));
    }

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, vectorTy,
                                                     newVector);

    return success();
  }
};

// This pattern flattens multidimensional `vector.transfer_write` operations,
// replacing them with a `memref.collapse_shape`, a `vector.shape_cast`, and a
// 1D `vector.transfer_write`.
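//
// For example (illustrative IR, not taken from a test):
//
//   vector.transfer_write %v, %m[%i, %c0, %c0]
//     : vector<4x8xi32>, memref<16x4x8xi32>
//
// becomes, roughly:
//
//   %w  = vector.shape_cast %v : vector<4x8xi32> to vector<32xi32>
//   %mc = memref.collapse_shape %m [[0], [1, 2]]
//           : memref<16x4x8xi32> into memref<16x32xi32>
//   %k  = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%c0, %c0)
//   vector.transfer_write %w, %mc[%i, %k] : vector<32xi32>, memref<16x32xi32>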
struct FlattenMultDimTransferWritePattern
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We can only deal with unmasked transfer ops with an identity permutation
    // map.
    if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
      return failure();
    VectorType vectorTy = cast<VectorType>(adaptor.getVector().getType());
    if (vectorTy.getRank() < 2)
      return failure();
    // Work only on bufferized writes
    MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getSource().getType());
    if (!memRefTy)
      return failure();
    auto memRefShape = memRefTy.getShape();
    auto vecShape = vectorTy.getShape();

    auto newVectorTy =
        VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
                                         std::multiplies<>())},
                        vectorTy.getElementType());
    AffineMap layout = memRefTy.getLayout().getAffineMap();
    auto newVector = rewriter
                         .create<vector::ShapeCastOp>(
                             writeOp.getLoc(), newVectorTy, adaptor.getVector())
                         .getResult();
    auto newIndices =
        collapseInnerMostDimIndices(rewriter, writeOp.getLoc(), vecShape.size(),
                                    adaptor.getIndices(), memRefShape, layout);
    auto newSource = collapseInnerMostShapeDims(
        rewriter, writeOp.getLoc(), vecShape.size(), adaptor.getSource());

    auto newOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, newVector, newSource, newIndices);

    auto inBoundsArrayAttrOpt = adaptor.getInBounds();
    if (inBoundsArrayAttrOpt) {
      SmallVector<bool> inBounds =
          llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
      SmallVector<bool> newInBounds({false});
      newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
                                   [](bool v) { return v; });
      newOp.getProperties().setInBounds(rewriter.getBoolArrayAttr(newInBounds));
    }

    return success();
  }
};

// This pattern extracts an implicit transposition of the 2 innermost
// dimensions of `rhs` in a gemm-like contraction op, making it an explicit
// `vector.transpose` op.
// If `rhs` is coming from a widening op (`extf`/`extsi`/`extui`), the
// transposition will be hoisted above the widening op.
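//
// For example (illustrative IR, indexing maps shown in shorthand and other
// attributes elided), a contraction computing C[m, n] += A[m, k] * B[n, k]:
//
//   %c = vector.contract {indexing_maps = [(m, n, k) -> (m, k),
//                                           (m, n, k) -> (n, k),
//                                           (m, n, k) -> (m, n)], ...}
//          %a, %b, %acc : vector<4x8xf32>, vector<4x8xf32> into vector<4x4xf32>
//
// becomes
//
//   %bt = vector.transpose %b, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
//   %c  = vector.contract {indexing_maps = [(m, n, k) -> (m, k),
//                                            (m, n, k) -> (k, n),
//                                            (m, n, k) -> (m, n)], ...}
//           %a, %bt, %acc : vector<4x8xf32>, vector<8x4xf32> into vector<4x4xf32>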
struct ExtractTransposeFromContractionOp
    : public OpConversionPattern<vector::ContractionOp> {
  using OpConversionPattern<vector::ContractionOp>::OpConversionPattern;

  static VectorType getTransposedVectorType(VectorType vecTy) {
    SmallVector<int64_t> shape{vecTy.getShape()};
    auto nDim = shape.size();
    int64_t dimNm1 = shape[nDim - 1];
    shape[nDim - 1] = shape[nDim - 2];
    shape[nDim - 2] = dimNm1;
    auto elemTy = vecTy.getElementType();
    return VectorType::get(shape, elemTy);
  }

  LogicalResult
  matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isGemmBTransposedContractionOp(contractOp))
      return failure();

    Location loc = contractOp.getLoc();
    auto *ctx = rewriter.getContext();

    Value rhsVal = adaptor.getRhs();
    VectorType rhsVecTy = contractOp.getRhsType();
    Type rhsElemTy = rhsVecTy.getElementType();

    bool doExtF = false, doExtSI = false, doExtUI = false;
    if (auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
      rhsVal = extfRhsOp.getIn();
      rhsVecTy = cast<VectorType>(rhsVal.getType());
      doExtF = true;
    } else if (auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
      rhsVal = extsiRhsOp.getIn();
      rhsVecTy = cast<VectorType>(rhsVal.getType());
      doExtSI = true;
    } else if (auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
      rhsVal = extuiRhsOp.getIn();
      rhsVecTy = cast<VectorType>(rhsVal.getType());
      doExtUI = true;
    }

    int64_t nDim = rhsVecTy.getShape().size();
    SmallVector<int64_t> rhsPermutation;
    for (int64_t i = 0; i < nDim - 2; i++)
      rhsPermutation.push_back(i);
    rhsPermutation.push_back(nDim - 1);
    rhsPermutation.push_back(nDim - 2);
    auto transpRhsVecTy = getTransposedVectorType(rhsVecTy);
    rhsVal = rewriter
                 .create<vector::TransposeOp>(loc, transpRhsVecTy, rhsVal,
                                              rhsPermutation)
                 .getResult();

    if (doExtF)
      rhsVal =
          rewriter
              .create<arith::ExtFOp>(
                  loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
                  rhsVal)
              .getOut();
    if (doExtSI)
      rhsVal =
          rewriter
              .create<arith::ExtSIOp>(
                  loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
                  rhsVal)
              .getOut();
    if (doExtUI)
      rhsVal =
          rewriter
              .create<arith::ExtUIOp>(
                  loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
                  rhsVal)
              .getOut();

    SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());

    nDim = oldIdxMaps[1].getNumDims();
    SmallVector<int64_t> innerDimPerm;
    for (int64_t i = 0; i < nDim - 2; i++)
      innerDimPerm.push_back(i);
    innerDimPerm.push_back(nDim - 1);
    innerDimPerm.push_back(nDim - 2);
    auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);

    auto newIdxMaps = rewriter.getAffineMapArrayAttr(
        {oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
        adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());

    return success();
  }
};

//============================================================================//
//=============== AIE2 canonicalization conversion patterns ==================//
//============================================================================//

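// This pattern replaces a `vector.insert` whose source and destination vector
// types differ only by leading unit dimensions with a single
// `vector.shape_cast` of the source to the destination type.
//
// For example (illustrative IR):
//
//   %r = vector.insert %v, %d [0, 0] : vector<16xi8> into vector<1x1x16xi8>
//
// becomes
//
//   %r = vector.shape_cast %v : vector<16xi8> to vector<1x1x16xi8>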
struct ConvertLeadingUnitDimInsertToReshapePattern
    : public OpRewritePattern<vector::InsertOp> {

  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insOp,
                                PatternRewriter &rewriter) const override {
    auto insSrcTy = dyn_cast<VectorType>(insOp.getSourceType());
    if (!insSrcTy)
      return failure();

    auto srcShape = insSrcTy.getShape();
    auto dstShape = insOp.getDestVectorType().getShape();

    unsigned long numLeadUnitDimDst = 0;
    while (numLeadUnitDimDst < dstShape.size() &&
           dstShape[numLeadUnitDimDst] == 1)
      numLeadUnitDimDst++;

    if (!numLeadUnitDimDst)
      return failure();

    unsigned long numLeadUnitDimSrc = 0;
    while (numLeadUnitDimSrc < srcShape.size() &&
           srcShape[numLeadUnitDimSrc] == 1)
      numLeadUnitDimSrc++;

    SmallVector<int64_t> nonLeadUnitDimDstShape(
        dstShape.begin() + numLeadUnitDimDst, dstShape.end());
    SmallVector<int64_t> nonLeadUnitDimSrcShape(
        srcShape.begin() + numLeadUnitDimSrc, srcShape.end());

    if (nonLeadUnitDimSrcShape != nonLeadUnitDimDstShape)
      return failure();

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        insOp, insOp.getDestVectorType(), insOp.getSource());
    return success();
  }
};

//============================================================================//
//================ Common AIE canonicalization configuration =================//
//============================================================================//
static void
configureCommonAIECanonicalizeLegalizations(ConversionTarget &target,
                                            TargetBackend backend) {
  target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
                         memref::MemRefDialect, vector::VectorDialect>();
}

static void
populateCommonAIECanonicalizeConversionPatterns(RewritePatternSet &patterns,
                                                TargetBackend backend) {
  patterns.add<ConvertSplatTransferReadToBroadcastPattern>(
      patterns.getContext());
}

//============================================================================//
//============== AIEv1-specific canonicalization configuration ===============//
//============================================================================//

static void configureAIEv1CanonicalizeLegalizations(ConversionTarget &target,
                                                    TargetBackend backend) {
  target.addDynamicallyLegalOp<vector::TransferReadOp>(
      [](vector::TransferReadOp op) {
        return !op.getPermutationMap().isConstant() &&
               getTransferReadAlignmentOffset(op, op.getVectorType(), 128)
                       .value_or(0) == 0;
      });
}

static void
populateAIEv1CanonicalizeConversionPatterns(RewritePatternSet &patterns,
                                            TargetBackend backend) {
  patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 512,
                                                  128);
}

//============================================================================//
//=============== AIE2-specific canonicalization configuration ===============//
//============================================================================//

static void configureAIE2CanonicalizeLegalizations(ConversionTarget &target,
                                                   TargetBackend backend) {
  target.addDynamicallyLegalOp<vector::TransferReadOp>(
      [](vector::TransferReadOp op) {
        return !op.getPermutationMap().isConstant() &&
               getTransferReadAlignmentOffset(op, op.getVectorType(), 256)
                       .value_or(0) == 0 &&
               op.getVector().getType().getRank() < 2;
      });
  target.addDynamicallyLegalOp<vector::TransferWriteOp>(
      [](vector::TransferWriteOp op) {
        return cast<VectorType>(op.getVector().getType()).getRank() < 2;
      });
  target.addDynamicallyLegalOp<vector::ContractionOp>(
      [](vector::ContractionOp op) {
        return !isGemmBTransposedContractionOp(op);
      });
}

static void
populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
                                           TargetBackend backend) {
  patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 1024,
                                                  256);
  patterns.add<ExtractTransposeFromContractionOp,
               FlattenMultDimTransferReadPattern,
               FlattenMultDimTransferWritePattern>(patterns.getContext());
}

//============================================================================//
//=================== Common AIE Canonicalization Passes =====================//
//============================================================================//

struct VectorBroadcastLoweringPass
    : public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {

  void runOnOperation() override {
    auto *op = getOperation();
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateVectorBroadcastLoweringPatterns(patterns);
    patterns.add<ConvertLeadingUnitDimInsertToReshapePattern>(
        patterns.getContext());

    (void)applyPatternsGreedily(op, std::move(patterns));
  }
};

static std::unique_ptr<::mlir::Pass> createVectorBroadcastLoweringPass() {
  return std::make_unique<VectorBroadcastLoweringPass>();
}

// This pass converts standard vector ops into a subset of `Vector` ops more
// amenable to being converted to `AIEVec`. So far, this process consists of
// two steps:
//   1) Replace splat transfer reads with contiguous transfer reads followed
//      by `extract` + `splat` operations.
//   2) Split unaligned transfer reads into a wider aligned transfer read
//      followed by a `vector.extract_strided_slice` operation.
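//
// If registered as a standalone test pass, it can be exercised from an
// mlir-opt-like driver through its argument (illustrative invocation; the
// driver name is a placeholder):
//   <opt-tool> --test-canonicalize-vector-for-aievec="aie-target=aie2 target-backend=llvmir" input.mlir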
struct CanonicalizeVectorForAIEVecPass
    : public PassWrapper<CanonicalizeVectorForAIEVecPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeVectorForAIEVecPass)

  CanonicalizeVectorForAIEVecPass() = default;
  CanonicalizeVectorForAIEVecPass(const CanonicalizeVectorForAIEVecPass &pass)
      : PassWrapper(pass) {}

  CanonicalizeVectorForAIEVecPass(
      const CanonicalizeVectorForAIEVecOptions &options)
      : CanonicalizeVectorForAIEVecPass() {
    aieTarget = options.aieTarget;
    targetBackend = options.targetBackend;
  }

  // In case we want to register this pass as a standalone pass for test
  // purposes.
  StringRef getArgument() const final {
    return "test-canonicalize-vector-for-aievec";
  }

  StringRef getDescription() const final {
    return "Canonicalize vector operations for AIEVec conversion";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, memref::MemRefDialect,
                    vector::VectorDialect, affine::AffineDialect>();
  }

  Option<std::string> aieTarget{
      *this, "aie-target",
      llvm::cl::desc("Select AIE version: \"aie\" or \"aie2\". This will "
                     "determine the vector size and available operations."),
      llvm::cl::init("aie")};

  Option<std::string> targetBackend{
      *this, "target-backend",
      llvm::cl::desc("Select translation backend: \"cpp\" or \"llvmir\". This "
                     "will determine the aievec operations used to convert "
                     "from vector dialect."),
      llvm::cl::init("cpp")};

  void runOnOperation() override {
    auto *op = getOperation();
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    AIEArch aieVersion = decodeAIETarget(aieTarget);
    if (aieVersion == AIEArch::UNKNOWN) {
      op->emitError() << "unknown AIE target '" << aieTarget << "'";
      signalPassFailure();
      return;
    }

    TargetBackend backend = decodeTargetBackend(targetBackend);
    if (backend == TargetBackend::UNKNOWN) {
      op->emitError() << "unknown target backend '" << targetBackend << "'";
      signalPassFailure();
      return;
    }
    if (backend == TargetBackend::LLVMIR && aieVersion == AIEArch::AIE) {
      op->emitError() << "targeting LLVM IR is not supported for AIEv1";
      signalPassFailure();
      return;
    }

    populateCommonAIECanonicalizeConversionPatterns(patterns, backend);
    configureCommonAIECanonicalizeLegalizations(target, backend);
    if (aieVersion == AIEArch::AIE) {
      populateAIEv1CanonicalizeConversionPatterns(patterns, backend);
      configureAIEv1CanonicalizeLegalizations(target, backend);
    } else {
      populateAIE2CanonicalizeConversionPatterns(patterns, backend);
      configureAIE2CanonicalizeLegalizations(target, backend);
    }

    if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

static std::unique_ptr<::mlir::Pass> createCanonicalizeVectorForAIEVecPass(
    const CanonicalizeVectorForAIEVecOptions &options) {
  return std::make_unique<CanonicalizeVectorForAIEVecPass>(options);
}

struct HoistCastOpToDataSourcePass
    : public PassWrapper<HoistCastOpToDataSourcePass, OperationPass<>> {

  void runOnOperation() override {
    auto *op = getOperation();
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    patterns.add<HoistCastOpToDataSourcePattern>(patterns.getContext());

    (void)applyPatternsGreedily(op, std::move(patterns));
  }
};

static std::unique_ptr<::mlir::Pass> createHoistCastOpToDataSourcePass() {
  return std::make_unique<HoistCastOpToDataSourcePass>();
}

struct ReorderOperationsPass
    : public PassWrapper<ReorderOperationsPass, OperationPass<>> {

  void runOnOperation() override {
    auto *op = getOperation();
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    patterns.add<SwapUnaryOpsPattern<arith::ExtSIOp, vector::BroadcastOp>>(
        patterns.getContext(),
        [](arith::ExtSIOp extOp, vector::BroadcastOp bcastOp) -> Type {
          Type extInElemTy = extOp.getIn().getType();
          auto extInVecTy = dyn_cast<VectorType>(extInElemTy);
          if (extInVecTy)
            extInElemTy = extInVecTy.getElementType();
          return VectorType::get(bcastOp.getResultVectorType().getShape(),
                                 extInElemTy);
        });

    (void)applyPatternsGreedily(op, std::move(patterns));
  }
};

static std::unique_ptr<::mlir::Pass> createReorderOperationsPass() {
  return std::make_unique<ReorderOperationsPass>();
}

//============================================================================//
//=============== Main Vector2Vector Pipeline Configuration ==================//
//============================================================================//

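// Example of assembling this pipeline on a pass manager (illustrative sketch;
// the pass manager's anchor op and nesting depend on the surrounding flow):
//
//   CanonicalizeVectorForAIEVecOptions options;
//   options.aieTarget = "aie2";
//   options.targetBackend = "llvmir";
//   mlir::PassManager pm(&context);
//   buildCanonicalizeVectorForAIEVec(pm, options);
//   if (failed(pm.run(module)))
//     /* handle the failure */;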
void xilinx::aievec::buildCanonicalizeVectorForAIEVec(
    OpPassManager &pm, const CanonicalizeVectorForAIEVecOptions &options) {
  // Add `Vector` code canonicalization passes
  // TODO: Add passes to unroll vector with unsupported types
  // TODO: Add passes to split vectors that won't fit in registers
  if (decodeTargetBackend(options.targetBackend) == TargetBackend::LLVMIR)
    pm.addPass(createReorderOperationsPass());
  pm.addPass(createCopyRemovalPass());
  pm.addPass(createVectorBroadcastLoweringPass());
  pm.addPass(createCanonicalizeVectorForAIEVecPass(options));
  if (decodeTargetBackend(options.targetBackend) == TargetBackend::CPP)
    pm.addPass(createHoistCastOpToDataSourcePass());
}