MLIR-AIE
VectorToVectorConversions.cpp
1//===-VectorToVectorConversions.cpp - Conversions within Vector -*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023-2024 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10// This file contains conversions and rewrites to the Vector dialect to make
11// it compatible with the available vector instructions in AIE architectures
12//===----------------------------------------------------------------------===//
13
14#include "aie/Dialect/AIEVec/AIEVecUtils.h"
15#include "aie/Dialect/AIEVec/Pipelines/Passes.h"
16#include "aie/Dialect/AIEVec/Utils/Utils.h"
17#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
18#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/UB/IR/UBOps.h"
23#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include <algorithm>
31
32#define DEBUG_TYPE "aievec-canonicalization"
33
34using namespace mlir;
35using namespace arith;
36using namespace vector;
37using namespace xilinx;
38using namespace xilinx::aievec;
39
40//============================================================================//
41//================== Common AIE canonicalization analysis ====================//
42//============================================================================//
43
44static TargetBackend decodeTargetBackend(const std::string &backend) {
45 if (!backend.empty()) {
46 if (backend == "llvmir")
47 return TargetBackend::LLVMIR;
48 if (backend != "cpp")
49 return TargetBackend::UNKNOWN;
50 }
51 return TargetBackend::CPP;
52}
53
54static AIEArch decodeAIETarget(const std::string &target) {
55 if (!target.empty()) {
56 if (target == "aieml" || target == "aie2" || target == "aie2p")
57 return AIEArch::AIE2;
58 if (target != "aie")
59 return AIEArch::UNKNOWN;
60 }
61 return AIEArch::AIE;
62}
63
64//============================================================================//
65//================== Common AIE canonicalization analysis ====================//
66//============================================================================//
67
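// Returns true if `op` is a gemm-like contraction whose three innermost
// iterators are (parallel, parallel, reduction) and whose indexing maps, after
// dropping any outer results, index the innermost (m, n, k) dims as: A by
// (m, k), B by (n, k), and the accumulator by (m, n); i.e., a matmul whose B
// operand is transposed.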
68static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
69 if (op.getKind() != vector::CombiningKind::ADD)
70 return false;
71
72 // Get and check shape of operands
73 auto lhsShape = op.getLhsType().getShape();
74 auto rhsShape = op.getRhsType().getShape();
75 auto accShape = cast<ShapedType>(op.getAccType()).getShape();
76 if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
77 return false;
78
79 // Check that the innermost iterators match gemm-like iterators
80 SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
81 if (iterators.size() < 3)
82 return false;
83 auto innerMostIterators =
84 SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end());
85 if (vector::IteratorType::parallel != innerMostIterators[0] ||
86 vector::IteratorType::parallel != innerMostIterators[1] ||
87 vector::IteratorType::reduction != innerMostIterators[2])
88 return false;
89
90 // Get indexing maps of iterators for operands
91 SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
92 SmallVector<int64_t> outerMostResults;
93 for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
94 outerMostResults.push_back(i);
95
96 auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
97 auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
98 auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);
99
100 // Check whether they conform to a "transposed B" gemm
101 auto *ctx = op.getContext();
102 auto mmAidxMap =
103 AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
104 .dropResults(0);
105 auto mmBidxMap =
106 AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
107 .dropResults(0);
108 auto mmCidxMap =
109 AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
110 .dropResults(0);
111 int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
112 return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
113 innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
114 innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
115}
116
117//============================================================================//
118//============ Common AIE canonicalization conversion patterns ===============//
119//============================================================================//
120
121// This pattern converts a `vector.transfer_read` with an unaligned access
122// into an aligned `vector.transfer_read` twice as long, followed by a
123// `vector.extract_strided_slice` selecting the subvector matching the
124// original `vector.transfer_read`.
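//
// Illustrative example (hypothetical IR, assuming 32-bit elements and an
// 8-element read that starts 2 elements past an alignment boundary):
//   %v = vector.transfer_read %m[%i], %pad : memref<1024xi32>, vector<8xi32>
// is rewritten, roughly, into an aligned 16-element read plus a slice:
//   %j = affine.apply affine_map<(d0) -> (d0 - 2)>(%i)
//   %w = vector.transfer_read %m[%j], %pad : memref<1024xi32>, vector<16xi32>
//   %v = vector.extract_strided_slice %w
//          {offsets = [2], sizes = [8], strides = [1]}
//          : vector<16xi32> to vector<8xi32>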
125struct SplitUnalignedTransferReadPattern
126 : public OpConversionPattern<vector::TransferReadOp> {
127 using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
128
129 SplitUnalignedTransferReadPattern(MLIRContext *context, int64_t maxVectorSize,
130 int64_t alignment)
131 : OpConversionPattern<vector::TransferReadOp>(context),
132 maxVectorSize(maxVectorSize), alignment(alignment) {}
133
134 LogicalResult
135 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter) const override {
137 // Check that it's not a splat transfer read.
138 if (adaptor.getPermutationMap().isConstant())
139 return failure();
140
141 // Check if the transfer is unaligned.
142 auto vType = readOp.getVectorType();
143 int64_t offset =
144 getTransferReadAlignmentOffset(adaptor, vType, alignment)
145 .value_or(0);
146 if (offset == 0)
147 return failure();
148
149 // Verify that we can load a vector 2x as long as the original
150 auto vLen = vType.getShape().back();
151 auto longVecTy = VectorType::get(2 * vLen, vType.getElementType());
152 auto longVecSize = getElementSizeInBits(vType) * 2 * vLen;
153 if (longVecSize > maxVectorSize)
154 return failure();
155
156 // Calculate the aligned indices for the lower and higher parts.
157 // TODO: Add support for cases where the offset is greater than the
158 // TODO: vector length.
159 auto loc = readOp.getLoc();
160 Value oldInnerMostIdx = adaptor.getIndices().back();
161 auto offsetCorrectionMap =
162 AffineMap::get(1, 0, getAffineDimExpr(0, readOp.getContext()) - offset);
163 Value newInnerMostIdx = rewriter
164 .create<affine::AffineApplyOp>(
165 readOp.getLoc(), offsetCorrectionMap,
166 SmallVector<Value, 1>({oldInnerMostIdx}))
167 .getResult();
168 SmallVector<Value, 8> alignedIdx;
169 alignedIdx.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
170 alignedIdx[alignedIdx.size() - 1] = newInnerMostIdx;
171
172 // Create the aligned transfer read for a vector 2x as long that covers the
173 // elements of the unaligned vector.
174 auto newReadOp = rewriter.create<vector::TransferReadOp>(
175 loc, longVecTy, adaptor.getBase(), alignedIdx, adaptor.getPadding());
176
177 // Create a `vector.extract_strided_slice` to extract the unaligned vector.
178 rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
179 readOp, newReadOp.getResult(), offset, vLen, 1);
180
181 return success();
182 }
183
184 int64_t maxVectorSize;
185 int64_t alignment;
186};
187
188// This pattern converts a `vector.transfer_read` with a splat permutation map
189// into a contiguous `vector.transfer_read` followed by a `vector.extract` to
190// obtain the splat value and a `vector.broadcast` to broadcast it into a
191// vector of the right size.
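//
// Illustrative example (hypothetical IR): a splat read such as
//   %v = vector.transfer_read %m[%i], %pad
//          {permutation_map = affine_map<(d0) -> (0)>}
//          : memref<1024xi32>, vector<16xi32>
// becomes, roughly,
//   %w = vector.transfer_read %m[%i], %pad : memref<1024xi32>, vector<16xi32>
//   %s = vector.extract %w[0] : i32 from vector<16xi32>
//   %v = vector.splat %s : vector<16xi32>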
192struct ConvertSplatTransferReadToBroadcastPattern
193 : public OpConversionPattern<vector::TransferReadOp> {
194 using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
195
196 ConvertSplatTransferReadToBroadcastPattern(MLIRContext *context)
197 : OpConversionPattern<vector::TransferReadOp>(context) {}
198
199 LogicalResult
200 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
201 ConversionPatternRewriter &rewriter) const override {
202 AffineMap map = readOp.getPermutationMap();
203 if (!map.isConstant())
204 return failure();
205
206 Value srcMemRef = adaptor.getBase();
207 SmallVector<Value, 8> indices;
208 Value newIdx;
209 int64_t offset = 0;
210 // If it's a zero-rank memory access
211 if (cast<MemRefType>(srcMemRef.getType()).getRank() == 0) {
212 srcMemRef = rewriter
213 .create<memref::ExpandShapeOp>(
214 readOp.getLoc(), SmallVector<int64_t, 1>({1}),
215 srcMemRef, SmallVector<ReassociationIndices, 1>({}))
216 .getResult();
217 newIdx = rewriter.create<arith::ConstantOp>(readOp.getLoc(),
218 rewriter.getIndexAttr(0L));
219 indices.push_back(newIdx);
220 } else {
221 indices.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
222 newIdx = indices[indices.size() - 1];
223 // If the innermost index comes from an `affine.apply` op, take the base
224 // as the new innermost index for the new `vector.transfer_read`, and the
225 // offset as the index for the `aievec.broadcast` op.
226 if (auto applyOp = newIdx.getDefiningOp<affine::AffineApplyOp>())
227 if (applyOp.getAffineMap().getNumDims() == 1) {
228 newIdx = applyOp.getMapOperands()[0];
229 offset = applyOp.getAffineMap().compose(ArrayRef<int64_t>{0})[0];
230 }
231 }
232 // XXX: We assume we are reading 1D vectors
233 int64_t vlen = readOp.getVector().getType().getShape()[0];
234 if (offset >= vlen) {
235 // If the splat element is beyond the first vector, we calculate the
236 // address of the vector containing the element.
237 int64_t numElemsToSkip = vlen * (offset / vlen);
238 offset = offset % vlen;
239 auto newAddrMap = AffineMap::get(
240 1, 0, getAffineDimExpr(0, readOp.getContext()) + numElemsToSkip);
241 newIdx =
242 rewriter
243 .create<affine::AffineApplyOp>(readOp.getLoc(), newAddrMap,
244 SmallVector<Value, 1>({newIdx}))
245 .getResult();
246 }
247 indices[indices.size() - 1] = newIdx;
248 auto newReadOp = rewriter.create<vector::TransferReadOp>(
249 readOp.getLoc(), readOp.getVector().getType(), srcMemRef, indices,
250 adaptor.getPadding());
251 auto extractOp = rewriter.create<vector::ExtractOp>(
252 readOp.getLoc(), newReadOp.getResult(), ArrayRef<int64_t>{offset});
253 rewriter.replaceOpWithNewOp<vector::SplatOp>(
254 readOp, newReadOp.getVector().getType(), extractOp.getResult());
255 return success();
256 }
257};
258
259// This pattern moves cast operations as close as possible to the source of
260// the data. This helps simplify matching patterns that differ only by these
261// sorts of casts sitting between data manipulation operations and arithmetic
262// ops.
263// TODO: Generalize this op and instantiate for different types of cast ops.
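//
// Illustrative example (hypothetical IR): the sign extension is hoisted above
// the broadcast so it sits right next to the data source:
//   %b = vector.broadcast %v : vector<16xi8> to vector<4x16xi8>
//   %e = arith.extsi %b : vector<4x16xi8> to vector<4x16xi32>
// becomes, roughly,
//   %e = arith.extsi %v : vector<16xi8> to vector<16xi32>
//   %b = vector.broadcast %e : vector<16xi32> to vector<4x16xi32>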
264struct HoistCastOpToDataSourcePattern : public RewritePattern {
265 HoistCastOpToDataSourcePattern(MLIRContext *context)
266 : RewritePattern(arith::ExtSIOp::getOperationName(), /*benefit=*/1,
267 context) {}
268
269 LogicalResult matchAndRewrite(Operation *op,
270 PatternRewriter &rewriter) const override {
271 arith::ExtSIOp extOp = cast<arith::ExtSIOp>(op);
272 Operation *defOp = extOp.getIn().getDefiningOp();
273 // If it's a data source op, we're done.
274 if (!defOp || isa<vector::TransferReadOp, memref::LoadOp,
275 affine::AffineLoadOp, func::CallOp>(defOp))
276 return failure();
277
278 // At the moment, we only accept ops we know we can swap with cast.
279 if (!isa<vector::BroadcastOp, vector::ExtractOp, vector::SplatOp,
280 vector::ExtractStridedSliceOp>(defOp))
281 return failure();
282
283 Type extOpInTy = extOp.getIn().getType();
284 SmallVector<Value, 4> inputs;
285 for (Value operand : defOp->getOperands()) {
286 Type operandTy = operand.getType();
287 VectorType extOpInVecTy = dyn_cast<VectorType>(extOpInTy);
288 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
289 if (operandTy == extOpInTy) {
290 Type outTy = extOp.getOut().getType();
291 inputs.push_back(
292 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
293 .getOut());
294 } else if (extOpInVecTy && extOpInVecTy.getElementType() == operandTy) {
295 // Promote from vector to scalar -> scalar conversion for this operand
296 Type outTy =
297 cast<VectorType>(extOp.getOut().getType()).getElementType();
298 inputs.push_back(
299 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
300 .getOut());
301 } else if (operandVecTy && operandVecTy.getElementType() == extOpInTy) {
302 // Promote from scalar to vector -> vector conversion for this operand
303 Type outTy =
304 VectorType::get(operandVecTy.getShape(), extOp.getOut().getType());
305 inputs.push_back(
306 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
307 .getOut());
308 } else if (extOpInVecTy && operandVecTy &&
309 (extOpInVecTy.getElementType() ==
310 operandVecTy.getElementType())) {
311 // Hoist through a vector shape change
312 Type outTy = VectorType::get(
313 operandVecTy.getShape(),
314 cast<VectorType>(extOp.getOut().getType()).getElementType());
315 inputs.push_back(
316 rewriter.create<arith::ExtSIOp>(extOp.getLoc(), outTy, operand)
317 .getOut());
318 } else {
319 inputs.push_back(operand);
320 }
321 }
322
323 auto *newOp =
324 rewriter.create(extOp->getLoc(), defOp->getName().getIdentifier(),
325 inputs, {extOp.getOut().getType()}, defOp->getAttrs());
326 rewriter.replaceOp(extOp, newOp->getResult(0));
327 return success();
328 }
329};
330
331// This pattern swaps a UnaryOpA op followed by a UnaryOpB op. It can be used
332// to improve pattern matching for mixed-type arithmetic ops by moving sign
333// extension ops closer to the single-type arithmetic operations.
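//
// Illustrative example (hypothetical IR) for the instantiation used later in
// this file, SwapUnaryOpsPattern<arith::ExtSIOp, vector::BroadcastOp>:
//   %e = arith.extsi %v : vector<16xi8> to vector<16xi32>
//   %b = vector.broadcast %e : vector<16xi32> to vector<4x16xi32>
// becomes, roughly,
//   %b = vector.broadcast %v : vector<16xi8> to vector<4x16xi8>
//   %e = arith.extsi %b : vector<4x16xi8> to vector<4x16xi32>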
334template <class UnaryOpA, class UnaryOpB>
335struct SwapUnaryOpsPattern : public OpRewritePattern<UnaryOpB> {
336 using OpRewritePattern<UnaryOpB>::OpRewritePattern;
337 // This function takes the chain of operations A->B, and returns the new type
338 // between B and A after the swap.
339 using InferTypeB2AFnTy = std::function<Type(UnaryOpA aOp, UnaryOpB bOp)>;
340 InferTypeB2AFnTy inferTypeB2A;
341
342 SwapUnaryOpsPattern(MLIRContext *context, InferTypeB2AFnTy inferType)
343 : OpRewritePattern<UnaryOpB>(context), inferTypeB2A(inferType) {}
344
345 LogicalResult matchAndRewrite(UnaryOpB bOp,
346 PatternRewriter &rewriter) const override {
347 static_assert(
348 UnaryOpA::template hasTrait<OpTrait::OneOperand>(),
349 "SwapUnaryOps can only be instantiated for single-operand ops");
350 static_assert(
351 UnaryOpB::template hasTrait<OpTrait::OneOperand>(),
352 "SwapUnaryOps can only be instantiated for single-operand ops");
353 UnaryOpA aOp = bOp.getOperand().template getDefiningOp<UnaryOpA>();
354 if (!aOp)
355 return rewriter.notifyMatchFailure(bOp, UnaryOpB::getOperationName() +
356 " not preceded by " +
357 UnaryOpA::getOperationName());
358
359 Type newA2BTy = inferTypeB2A(aOp, bOp);
360
361 auto newA =
362 rewriter.create<UnaryOpB>(bOp->getLoc(), SmallVector<Type>({newA2BTy}),
363 aOp->getOperands(), bOp->getAttrs());
364 auto newB = rewriter.create<UnaryOpA>(
365 bOp->getLoc(), SmallVector<Type>({bOp.getResult().getType()}),
366 newA->getResults(), aOp->getAttrs());
367 rewriter.replaceOp(bOp, newB.getResult());
368 return success();
369 }
370};
371
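// Linearizes the innermost `numDims` indices of a memref access into a single
// index, assuming a row-major (minor identity) layout. Illustrative example
// (hypothetical values): for a memref of shape [4, 8, 16], numDims = 3, and
// indices (%i, %j, %k), the collapsed innermost index is computed by a single
// affine.apply as %k + 16 * %j + 128 * %i.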
372static SmallVector<Value> collapseInnerMostDimIndices(PatternRewriter &b,
373 Location loc, int numDims,
374 ValueRange indices,
375 ArrayRef<int64_t> shape,
376 AffineMap layout) {
377 // TODO: Don't assume trivial layout
378 assert(layout.isMinorIdentity() &&
379 "dimension collapse in non-identity layout is not implemented");
380 auto newIdxExpr = b.getAffineDimExpr(numDims - 1);
381 int64_t stride = 1;
382 for (int64_t dim = numDims - 2; dim >= 0; dim--) {
383 stride *= shape[shape.size() - (numDims - dim - 1)];
384 newIdxExpr = newIdxExpr + b.getAffineDimExpr(dim) * stride;
385 }
386 auto newIndexMap = AffineMap::get(numDims, 0, newIdxExpr);
387 Value newInnerMostIdxValue =
388 b.create<affine::AffineApplyOp>(loc, newIndexMap,
389 indices.take_back(numDims))
390 .getResult();
391 SmallVector<Value> newIdxRange;
392 for (auto idx : indices.drop_back(numDims))
393 newIdxRange.push_back(idx);
394 newIdxRange.push_back(newInnerMostIdxValue);
395 return newIdxRange;
396}
397
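// Collapses the innermost `numDims` dimensions of a memref into a single
// dimension via `memref.collapse_shape`. Illustrative example (hypothetical
// types): with numDims = 2, memref<4x8x16xf32> becomes memref<4x128xf32>.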
398static Value collapseInnerMostShapeDims(PatternRewriter &b, Location loc,
399 int numDims, Value val) {
400 auto memRefTy = cast<MemRefType>(val.getType());
401 auto shape = memRefTy.getShape();
402 int64_t newInnerMostDim = std::accumulate(shape.end() - numDims, shape.end(),
403 1, std::multiplies<>());
404 SmallVector<int64_t, 4> newShape{shape.begin(), shape.end() - numDims + 1};
405 newShape[shape.size() - numDims] = newInnerMostDim;
406 auto newNumDims = newShape.size();
407 auto *ctx = b.getContext();
408 auto newMemRefTy = MemRefType::get(
409 newShape, memRefTy.getElementType(),
410 AffineMap::getMinorIdentityMap(newNumDims, newNumDims, ctx),
411 memRefTy.getMemorySpace());
412 auto reassocIndices =
413 getReassociationIndicesForCollapse(shape, newShape).value();
414 return b
415 .create<memref::CollapseShapeOp>(loc, newMemRefTy, val, reassocIndices)
416 .getResult();
417}
418
419// This pattern flattens multidimensional `vector.transfer_read` operations,
420// replacing them with a `memref.collapse_shape`, a 1D `vector.transfer_read`,
421// and a `vector.shape_cast`.
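//
// Illustrative example (hypothetical IR): a 2-D read such as
//   %v = vector.transfer_read %m[%i, %j], %pad
//          : memref<64x64xi32>, vector<4x8xi32>
// becomes, roughly,
//   %c = memref.collapse_shape %m [[0, 1]]
//          : memref<64x64xi32> into memref<4096xi32>
//   %k = affine.apply affine_map<(d0, d1) -> (d1 + d0 * 64)>(%i, %j)
//   %w = vector.transfer_read %c[%k], %pad : memref<4096xi32>, vector<32xi32>
//   %v = vector.shape_cast %w : vector<32xi32> to vector<4x8xi32>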
422struct FlattenMultDimTransferReadPattern
423 : public OpConversionPattern<vector::TransferReadOp> {
424 using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
425
426 LogicalResult
427 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter) const override {
429 // We can only deal with unmasked transfer ops with an identity permutation
430 // map.
431 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
432 return failure();
433 VectorType vectorTy = readOp.getVector().getType();
434 if (vectorTy.getRank() < 2)
435 return failure();
436 // Work only on bufferized reads
437 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getBase().getType());
438 if (!memRefTy)
439 return failure();
440 auto memRefShape = memRefTy.getShape();
441 auto vecShape = vectorTy.getShape();
442
443 auto newVectorTy =
444 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
445 std::multiplies<>())},
446 vectorTy.getElementType());
447 AffineMap layout = memRefTy.getLayout().getAffineMap();
448 auto newIndices =
449 collapseInnerMostDimIndices(rewriter, readOp.getLoc(), vecShape.size(),
450 adaptor.getIndices(), memRefShape, layout);
451 auto newSource = collapseInnerMostShapeDims(
452 rewriter, readOp.getLoc(), vecShape.size(), adaptor.getBase());
453 auto newVector = rewriter.create<vector::TransferReadOp>(
454 readOp.getLoc(), newVectorTy, newSource, newIndices,
455 /*padding*/
456 arith::getZeroConstant(rewriter, readOp.getLoc(),
457 newVectorTy.getElementType()));
458
459 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
460 if (inBoundsArrayAttrOpt) {
461 SmallVector<bool> inBounds =
462 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
463 SmallVector<bool> newInBounds({false});
464 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
465 [](bool v) { return v; });
466 newVector.getProperties().setInBounds(
467 rewriter.getBoolArrayAttr(newInBounds));
468 }
469
470 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, vectorTy,
471 newVector);
472
473 return success();
474 }
475};
476
477// This pattern flattens multidimensional `vector.transfer_write` operations,
478// replacing them with a `memref.collapse_shape`, a `vector.shape_cast`, and a
479// 1D `vector.transfer_write`.
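//
// Illustrative example (hypothetical IR): a 2-D write of vector<4x8xi32> into
// memref<64x64xi32> becomes, roughly, a `vector.shape_cast` to vector<32xi32>,
// a `memref.collapse_shape` to memref<4096xi32>, and a 1-D
// `vector.transfer_write` at the linearized innermost index.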
480struct FlattenMultDimTransferWritePattern
481 : public OpConversionPattern<vector::TransferWriteOp> {
482 using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
483
484 LogicalResult
485 matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
486 ConversionPatternRewriter &rewriter) const override {
487 // We can only deal with unmasked transfer ops with an identity permutation
488 // map.
489 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
490 return failure();
491 VectorType vectorTy = cast<VectorType>(adaptor.getValueToStore().getType());
492 if (vectorTy.getRank() < 2)
493 return failure();
494 // Work only on bufferized writes
495 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getBase().getType());
496 if (!memRefTy)
497 return failure();
498 auto memRefShape = memRefTy.getShape();
499 auto vecShape = vectorTy.getShape();
500
501 auto newVectorTy =
502 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
503 std::multiplies<>())},
504 vectorTy.getElementType());
505 AffineMap layout = memRefTy.getLayout().getAffineMap();
506 auto newVector =
507 rewriter
508 .create<vector::ShapeCastOp>(writeOp.getLoc(), newVectorTy,
509 adaptor.getValueToStore())
510 .getResult();
511 auto newIndices =
512 collapseInnerMostDimIndices(rewriter, writeOp.getLoc(), vecShape.size(),
513 adaptor.getIndices(), memRefShape, layout);
514 auto newSource = collapseInnerMostShapeDims(
515 rewriter, writeOp.getLoc(), vecShape.size(), adaptor.getBase());
516
517 auto newOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
518 writeOp, newVector, newSource, newIndices);
519
520 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
521 if (inBoundsArrayAttrOpt) {
522 SmallVector<bool> inBounds =
523 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
524 SmallVector<bool> newInBounds({false});
525 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
526 [](bool v) { return v; });
527 newOp.getProperties().setInBounds(rewriter.getBoolArrayAttr(newInBounds));
528 }
529
530 return success();
531 }
532};
533
534// This pattern extracts an implicit transposition of the 2 innermost
535// dimensions of `rhs` in a gemm-like contraction op, making it an explicit
536// `vector.transpose` op.
537// If `rhs` is coming from a widening op (`extf`/`extsi`/`extui`), the
538// transposition will be hoisted above the widening op.
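//
// Illustrative example (hypothetical IR): for a contraction computing
//   C[m, n] += A[m, k] * B[n, k]
// the rewrite inserts an explicit transpose of the B operand, e.g.
//   %bt = vector.transpose %b, [1, 0] : vector<8x4xi16> to vector<4x8xi16>
// and composes B's indexing map with the same permutation so the contraction
// now indexes B as (k, n).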
539struct ExtractTransposeFromContractionOp
540 : public OpConversionPattern<vector::ContractionOp> {
541 using OpConversionPattern<vector::ContractionOp>::OpConversionPattern;
542
543 static VectorType getTransposedVectorType(VectorType vecTy) {
544 SmallVector<int64_t> shape{vecTy.getShape()};
545 auto nDim = shape.size();
546 int64_t dimNm1 = shape[nDim - 1];
547 shape[nDim - 1] = shape[nDim - 2];
548 shape[nDim - 2] = dimNm1;
549 auto elemTy = vecTy.getElementType();
550 return VectorType::get(shape, elemTy);
551 }
552
553 LogicalResult
554 matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
555 ConversionPatternRewriter &rewriter) const override {
556 if (!isGemmBTransposedContractionOp(contractOp))
557 return failure();
558
559 Location loc = contractOp.getLoc();
560 auto *ctx = rewriter.getContext();
561
562 Value rhsVal = adaptor.getRhs();
563 VectorType rhsVecTy = contractOp.getRhsType();
564 Type rhsElemTy = rhsVecTy.getElementType();
565
566 bool doExtF = false, doExtSI = false, doExtUI = false;
567 if (auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
568 rhsVal = extfRhsOp.getIn();
569 rhsVecTy = cast<VectorType>(rhsVal.getType());
570 doExtF = true;
571 } else if (auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
572 rhsVal = extsiRhsOp.getIn();
573 rhsVecTy = cast<VectorType>(rhsVal.getType());
574 doExtSI = true;
575 } else if (auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
576 rhsVal = extuiRhsOp.getIn();
577 rhsVecTy = cast<VectorType>(rhsVal.getType());
578 doExtUI = true;
579 }
580
581 int64_t nDim = rhsVecTy.getShape().size();
582 SmallVector<int64_t> rhsPermutation;
583 for (int64_t i = 0; i < nDim - 2; i++)
584 rhsPermutation.push_back(i);
585 rhsPermutation.push_back(nDim - 1);
586 rhsPermutation.push_back(nDim - 2);
587 auto transpRhsVecTy = getTransposedVectorType(rhsVecTy);
588 rhsVal = rewriter
589 .create<vector::TransposeOp>(loc, transpRhsVecTy, rhsVal,
590 rhsPermutation)
591 .getResult();
592
593 if (doExtF)
594 rhsVal =
595 rewriter
596 .create<arith::ExtFOp>(
597 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
598 rhsVal)
599 .getOut();
600 if (doExtSI)
601 rhsVal =
602 rewriter
603 .create<arith::ExtSIOp>(
604 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
605 rhsVal)
606 .getOut();
607 if (doExtUI)
608 rhsVal =
609 rewriter
610 .create<arith::ExtUIOp>(
611 loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
612 rhsVal)
613 .getOut();
614
615 SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());
616
617 nDim = oldIdxMaps[1].getNumDims();
618 SmallVector<int64_t> innerDimPerm;
619 for (int64_t i = 0; i < nDim - 2; i++)
620 innerDimPerm.push_back(i);
621 innerDimPerm.push_back(nDim - 1);
622 innerDimPerm.push_back(nDim - 2);
623 auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);
624
625 auto newIdxMaps = rewriter.getAffineMapArrayAttr(
626 {oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});
627
628 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
629 contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
630 adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());
631
632 return success();
633 }
634};
635
636/// Utility function to check if all provided indices are constant zero values.
637/// @return success() if all indices are constant zeros, failure() otherwise
638static LogicalResult isAllZeroOffsetAccess(mlir::OperandRange indices) {
639 if (!llvm::all_of(indices, [](Value val) {
640 IntegerAttr attr;
641 if (!matchPattern(val, m_Constant(&attr)))
642 return false;
643 return attr.getInt() == 0;
644 })) {
645 return failure();
646 }
647 return success();
648}
649
650/// Utility function to convert OpFoldResult offsets from a SubView operation
651/// into a vector of Values.
652static SmallVector<Value> opFoldResultsToValues(PatternRewriter &rewriter,
653 Location loc,
654 memref::SubViewOp subViewOp) {
655 OpBuilder::InsertionGuard g(rewriter);
656 rewriter.setInsertionPoint(subViewOp);
657 SmallVector<Value> newIndices;
658 for (OpFoldResult offset : subViewOp.getMixedOffsets()) {
659 Value indexVal;
660 if (auto attr = dyn_cast<Attribute>(offset)) {
661 indexVal = rewriter.create<arith::ConstantIndexOp>(
662 loc, cast<IntegerAttr>(attr).getInt());
663 } else {
664 indexVal = cast<Value>(offset);
665 }
666 newIndices.push_back(indexVal);
667 }
668 return newIndices;
669}
670
671/// Pattern to canonicalize trivial vector.transfer_read operations on subviews.
672///
673/// This pattern eliminates unnecessary memref.subview operations when the
674/// transfer_read accesses the subview with all-zero indices. It transforms:
675///
676/// INPUT:
677/// %subview = memref.subview %memref [offset0, offset1, ...]
678/// %result = vector.transfer_read %subview[0, 0, ...]
679///
680/// OUTPUT:
681/// %result = vector.transfer_read %memref[offset0, offset1, ...]
682///
683/// The pattern only matches when:
684/// - The base of transfer_read is defined by a memref.subview operation
685/// - All indices in the transfer_read are constant zeros
686struct CanonicalizeTrivialReadAccessSubviewOpPattern
687 : public OpRewritePattern<vector::TransferReadOp> {
688 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
689
690 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
691 PatternRewriter &rewriter) const override {
692 // Check if the base memref comes from a subview operation
693 auto subViewOp = dyn_cast_if_present<memref::SubViewOp>(
694 readOp.getBase().getDefiningOp());
695 if (!subViewOp)
696 return failure();
697
698 // Verify that all access indices are zero
699 if (failed(isAllZeroOffsetAccess(readOp.getIndices())))
700 return failure();
701
702 // Convert subview offsets to explicit index values
703 SmallVector<Value> newIndices =
704 opFoldResultsToValues(rewriter, readOp.getLoc(), subViewOp);
705
706 // Replace with direct access to the original memref using subview offsets
707 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
708 readOp, readOp.getType(), subViewOp.getSource(), newIndices,
709 readOp.getPadding(), readOp.getInBoundsValues());
710 return success();
711 }
712};
713
714/// Pattern to canonicalize trivial vector.transfer_write operations on
715/// subviews.
716///
717/// This pattern eliminates unnecessary memref.subview operations when the
718/// transfer_write accesses the subview with all-zero indices. It transforms:
719///
720/// INPUT:
721/// %subview = memref.subview %memref [offset0, offset1, ...]
722/// vector.transfer_write %value, %subview[0, 0, ...]
723///
724/// OUTPUT:
725/// vector.transfer_write %value, %memref[offset0, offset1, ...]
726///
727/// The pattern only matches when:
728/// - The base of transfer_write is defined by a memref.subview operation
729/// - All indices in the transfer_write are constant zeros
730struct CanonicalizeTrivialWriteAccessSubviewOpPattern
731 : public OpRewritePattern<vector::TransferWriteOp> {
732 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
733
734 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
735 PatternRewriter &rewriter) const override {
736 // Check if the base memref comes from a subview operation
737 auto subViewOp = dyn_cast_if_present<memref::SubViewOp>(
738 writeOp.getBase().getDefiningOp());
739 if (!subViewOp)
740 return failure();
741
742 // Verify that all access indices are zero
743 if (failed(isAllZeroOffsetAccess(writeOp.getIndices())))
744 return failure();
745
746 // Convert subview offsets to explicit index values
747 SmallVector<Value> newIndices =
748 opFoldResultsToValues(rewriter, writeOp.getLoc(), subViewOp);
749
750 // Create new transfer_write with direct access to original memref
751 rewriter.create<vector::TransferWriteOp>(
752 writeOp.getLoc(), writeOp.getVector(), subViewOp.getSource(),
753 newIndices, writeOp.getInBoundsValues());
754
755 // Remove the original transfer_write operation
756 rewriter.eraseOp(writeOp);
757 return success();
758 }
759};
760
761//============================================================================//
762//============ AIE2 canonicalization conversion patterns ===============//
763//============================================================================//
764
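// This pattern replaces a `vector.insert` whose source and destination shapes
// differ only by leading unit dimensions (e.g., inserting a vector<4x8xi32>
// into a vector<1x1x4x8xi32>) with a single `vector.shape_cast` of the
// inserted value to the destination type.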
765struct ConvertLeadingUnitDimInsertToReshapePattern
766 : public OpRewritePattern<vector::InsertOp> {
767
768 using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
769
770 LogicalResult matchAndRewrite(vector::InsertOp insOp,
771 PatternRewriter &rewriter) const override {
772 auto insSrcTy = dyn_cast<VectorType>(insOp.getValueToStoreType());
773 if (!insSrcTy)
774 return failure();
775
776 auto srcShape = insSrcTy.getShape();
777 auto dstShape = insOp.getDestVectorType().getShape();
778
779 unsigned long numLeadUnitDimDst = 0;
780 while (numLeadUnitDimDst < dstShape.size() &&
781 dstShape[numLeadUnitDimDst] == 1)
782 numLeadUnitDimDst++;
783
784 if (!numLeadUnitDimDst)
785 return failure();
786
787 unsigned long numLeadUnitDimSrc = 0;
788 while (numLeadUnitDimSrc < srcShape.size() &&
789 srcShape[numLeadUnitDimSrc] == 1)
790 numLeadUnitDimSrc++;
791
792 SmallVector<int64_t> nonLeadUnitDimDstShape(
793 dstShape.begin() + numLeadUnitDimDst, dstShape.end());
794 SmallVector<int64_t> nonLeadUnitDimSrcShape(
795 srcShape.begin() + numLeadUnitDimSrc, srcShape.end());
796
797 if (nonLeadUnitDimSrcShape != nonLeadUnitDimDstShape)
798 return failure();
799
800 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
801 insOp, insOp.getDestVectorType(), insOp.getValueToStore());
802 return success();
803 }
804};
805
806//============================================================================//
807//================ Common AIE canonicalization configuration =================//
808//============================================================================//
809static void
810configureCommonAIECanonicalizeLegalizations(ConversionTarget &target,
811 TargetBackend backend) {
812 target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
813 memref::MemRefDialect, vector::VectorDialect,
814 ub::UBDialect>();
815}
816
817static void
818populateCommonAIECanonicalizeConversionPatterns(RewritePatternSet &patterns,
819 TargetBackend backend) {
820 patterns.add<ConvertSplatTransferReadToBroadcastPattern>(
821 patterns.getContext());
822}
823
824//============================================================================//
825//============== AIEv1-specific canonicalization configuration ===============//
826//============================================================================//
827
828static void configureAIEv1CanonicalizeLegalizations(ConversionTarget &target,
829 TargetBackend backend) {
830 target.addDynamicallyLegalOp<vector::TransferReadOp>(
831 [](vector::TransferReadOp op) {
832 return !op.getPermutationMap().isConstant() &&
833 getTransferReadAlignmentOffset(op, op.getVectorType(), 128)
834 .value_or(0) == 0;
835 });
836}
837
838static void
839populateAIEv1CanonicalizeConversionPatterns(RewritePatternSet &patterns,
840 TargetBackend backend) {
841 patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 512,
842 128);
843}
844
845//============================================================================//
846//============== AIE2-specific canonicalization configuration ===============//
847//============================================================================//
848
849static void configureAIE2CanonicalizeLegalizations(ConversionTarget &target,
850 TargetBackend backend) {
851 target.addDynamicallyLegalOp<vector::TransferReadOp>(
852 [](vector::TransferReadOp op) {
853 return !op.getPermutationMap().isConstant() &&
854 getTransferReadAlignmentOffset(op, op.getVectorType(), 256)
855 .value_or(0) == 0 &&
856 op.getVector().getType().getRank() < 2;
857 });
858 target.addDynamicallyLegalOp<vector::TransferWriteOp>(
859 [](vector::TransferWriteOp op) {
860 return cast<VectorType>(op.getVector().getType()).getRank() < 2;
861 });
862 target.addDynamicallyLegalOp<vector::ContractionOp>(
863 [](vector::ContractionOp op) {
864 return !isGemmBTransposedContractionOp(op);
865 });
866}
867
868static void
869populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
870 TargetBackend backend) {
871 patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 1024,
872 256);
873 patterns
874 .add<ExtractTransposeFromContractionOp, FlattenMultDimTransferReadPattern,
875 FlattenMultDimTransferWritePattern>(patterns.getContext());
876}
877
878//============================================================================//
879//=================== Common AIE Canonicalization Passes =====================//
880//============================================================================//
881
882struct VectorBroadcastLoweringPass
883 : public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {
884
885 void runOnOperation() override {
886 auto *op = getOperation();
887 MLIRContext *context = &getContext();
888 RewritePatternSet patterns(context);
889 populateVectorBroadcastLoweringPatterns(patterns);
890 patterns.add<ConvertLeadingUnitDimInsertToReshapePattern>(
891 patterns.getContext());
892
893 (void)applyPatternsGreedily(op, std::move(patterns));
894 }
895};
896
897static std::unique_ptr<::mlir::Pass> createVectorBroadcastLoweringPass() {
898 return std::make_unique<VectorBroadcastLoweringPass>();
899}
900
901// This pass converts standard vector ops into a subset of `Vector` ops more
902// amenable to being converted to `AIEVec`. So far, this process consists of
903// two steps:
904// 1) Replace splat transfer reads with contiguous transfer reads followed
905// by `extract` + `splat` operations.
906// 2) Split unaligned transfer reads into a wider aligned transfer read
907// followed by a `vector.extract_strided_slice` operation.
908struct CanonicalizeVectorForAIEVecPass
909 : public PassWrapper<CanonicalizeVectorForAIEVecPass, OperationPass<>> {
910 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeVectorForAIEVecPass)
911
912 CanonicalizeVectorForAIEVecPass() = default;
913 CanonicalizeVectorForAIEVecPass(const CanonicalizeVectorForAIEVecPass &pass)
914 : PassWrapper(pass) {}
915
916 CanonicalizeVectorForAIEVecPass(
917 const CanonicalizeVectorForAIEVecOptions &options)
918 : CanonicalizeVectorForAIEVecPass() {
919 aieTarget = options.aieTarget;
920 targetBackend = options.targetBackend;
921 }
922
923 // In case we want to register this pass as a standalone pass for test
924 // purposes.
925 StringRef getArgument() const final {
926 return "test-canonicalize-vector-for-aievec";
927 }
928
929 StringRef getDescription() const final {
930 return "Canonicalize vector operations for AIEVec conversion";
931 }
932
933 void getDependentDialects(DialectRegistry &registry) const override {
934 registry
935 .insert<arith::ArithDialect, memref::MemRefDialect,
936 vector::VectorDialect, affine::AffineDialect, ub::UBDialect>();
937 }
938
939 Option<std::string> aieTarget{
940 *this, "aie-target",
941 llvm::cl::desc(
942 "Select AIE version: \"aie\", \"aie2\", or \"aie2p\". This will "
943 "determine the vector size and available operations."),
944 llvm::cl::init("aie")};
945
946 Option<std::string> targetBackend{
947 *this, "target-backend",
948 llvm::cl::desc("Select translation backend: \"cpp\" or \"llvmir\". This "
949 "will determine the aievec operations used to convert "
950 "from vector dialect."),
951 llvm::cl::init("cpp")};
952
953 void runOnOperation() override {
954 auto *op = getOperation();
955 MLIRContext *context = &getContext();
956 RewritePatternSet patterns(context);
957 ConversionTarget target(*context);
958
959 AIEArch aieVersion = decodeAIETarget(aieTarget);
960 if (aieVersion == AIEArch::UNKNOWN) {
961 op->emitError() << "unknown AIE target '" << aieTarget << "'";
962 signalPassFailure();
963 return;
964 }
965
966 TargetBackend backend = decodeTargetBackend(targetBackend);
967 if (backend == TargetBackend::UNKNOWN) {
968 op->emitError() << "unknown target backend '" << targetBackend << "'";
969 signalPassFailure();
970 return;
971 }
972 if (backend == TargetBackend::LLVMIR && aieVersion == AIEArch::AIE) {
973 op->emitError() << "targeting LLVM IR is not supported for AIEv1";
974 signalPassFailure();
975 return;
976 }
977
978 populateCommonAIECanonicalizeConversionPatterns(patterns, backend);
979 configureCommonAIECanonicalizeLegalizations(target, backend);
980 if (aieVersion == AIEArch::AIE) {
981 populateAIEv1CanonicalizeConversionPatterns(patterns, backend);
982 configureAIEv1CanonicalizeLegalizations(target, backend);
983 } else {
984 populateAIE2CanonicalizeConversionPatterns(patterns, backend);
985 configureAIE2CanonicalizeLegalizations(target, backend);
986 }
987
988 {
989 RewritePatternSet patterns(context);
990 patterns.add<CanonicalizeTrivialReadAccessSubviewOpPattern,
991 CanonicalizeTrivialWriteAccessSubviewOpPattern>(context);
992 (void)applyPatternsGreedily(op, std::move(patterns));
993 }
994
995 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
996 signalPassFailure();
997 }
998 }
999};
1000
1001static std::unique_ptr<::mlir::Pass> createCanonicalizeVectorForAIEVecPass(
1002 const CanonicalizeVectorForAIEVecOptions &options) {
1003 return std::make_unique<CanonicalizeVectorForAIEVecPass>(options);
1004}
1005
1006struct HoistCastOpToDataSourcePass
1007 : public PassWrapper<HoistCastOpToDataSourcePass, OperationPass<>> {
1008
1009 void runOnOperation() override {
1010 auto *op = getOperation();
1011 MLIRContext *context = &getContext();
1012 RewritePatternSet patterns(context);
1013
1014 patterns.add<HoistCastOpToDataSourcePattern>(patterns.getContext());
1015
1016 (void)applyPatternsGreedily(op, std::move(patterns));
1017 }
1018};
1019
1020static std::unique_ptr<::mlir::Pass> createHoistCastOpToDataSourcePass() {
1021 return std::make_unique<HoistCastOpToDataSourcePass>();
1022}
1023
1024struct ReorderOperationsPass
1025 : public PassWrapper<ReorderOperationsPass, OperationPass<>> {
1026
1027 void runOnOperation() override {
1028 auto *op = getOperation();
1029 MLIRContext *context = &getContext();
1030 RewritePatternSet patterns(context);
1031
1032 patterns.add<SwapUnaryOpsPattern<arith::ExtSIOp, vector::BroadcastOp>>(
1033 patterns.getContext(),
1034 [](arith::ExtSIOp extOp, vector::BroadcastOp bcastOp) -> Type {
1035 Type extInElemTy = extOp.getIn().getType();
1036 auto extInVecTy = dyn_cast<VectorType>(extInElemTy);
1037 if (extInVecTy)
1038 extInElemTy = extInVecTy.getElementType();
1039 return VectorType::get(bcastOp.getResultVectorType().getShape(),
1040 extInElemTy);
1041 });
1042
1043 (void)applyPatternsGreedily(op, std::move(patterns));
1044 }
1045};
1046
1047static std::unique_ptr<::mlir::Pass> createReorderOperationsPass() {
1048 return std::make_unique<ReorderOperationsPass>();
1049}
1050
1051//============================================================================//
1052//=============== Main Vector2Vector Pipeline Configuration ==================//
1053//============================================================================//
1054
1055void xilinx::aievec::buildCanonicalizeVectorForAIEVec(
1056 OpPassManager &pm, const CanonicalizeVectorForAIEVecOptions &options) {
1057 // Add `Vector` code canonicalization passes
1058 // TODO: Add passes to unroll vector with unsupported types
1059 // TODO: Add passes to split vectors that won't fit in registers
1060 if (decodeTargetBackend(options.targetBackend) == TargetBackend::LLVMIR)
1061 pm.addPass(createReorderOperationsPass());
1062 pm.addPass(createCopyRemovalPass());
1063 pm.addPass(createVectorBroadcastLoweringPass());
1064 pm.addPass(createCanonicalizeVectorForAIEVecPass(options));
1065 if (decodeTargetBackend(options.targetBackend) == TargetBackend::CPP)
1066 pm.addPass(createHoistCastOpToDataSourcePass());
1067}