MLIR-AIE
VectorToVectorConversions.cpp
Go to the documentation of this file.
1//===-VectorToVectorConversions.cpp - Conversions within Vector -*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023-2024 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10// This file contains conversions and rewrites to the Vector dialect to make
11// it compatible with the available vector instructions in AIE architectures
12//===----------------------------------------------------------------------===//
13
17#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
18#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/UB/IR/UBOps.h"
23#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include <algorithm>
31
32#define DEBUG_TYPE "aievec-canonicalization"
33
34using namespace mlir;
35using namespace arith;
36using namespace vector;
37using namespace xilinx;
38using namespace xilinx::aievec;
39
40//============================================================================//
41//================== Common AIE canonicalization analysis ====================//
42//============================================================================//
43
44static TargetBackend decodeTargetBackend(const std::string &backend) {
45 if (!backend.empty()) {
46 if (backend == "llvmir")
47 return TargetBackend::LLVMIR;
48 if (backend != "cpp")
49 return TargetBackend::UNKNOWN;
50 }
51 return TargetBackend::CPP;
52}
53
54static AIEArch decodeAIETarget(const std::string &target) {
55 if (!target.empty()) {
56 if (target == "aieml" || target == "aie2" || target == "aie2p")
57 return AIEArch::AIE2;
58 if (target != "aie")
59 return AIEArch::UNKNOWN;
60 }
61 return AIEArch::AIE;
62}
63
64//============================================================================//
65//================== Common AIE canonicalization analysis ====================//
66//============================================================================//
67
68static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
69 if (op.getKind() != vector::CombiningKind::ADD)
70 return false;
71
72 // Get and check shape of operands
73 auto lhsShape = op.getLhsType().getShape();
74 auto rhsShape = op.getRhsType().getShape();
75 auto accShape = cast<ShapedType>(op.getAccType()).getShape();
76 if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
77 return false;
78
79 // Check that the innermost iterators match gemm-like iterators
80 SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
81 if (iterators.size() < 3)
82 return false;
83 auto innerMostIterators =
84 SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end());
85 if (vector::IteratorType::parallel != innerMostIterators[0] ||
86 vector::IteratorType::parallel != innerMostIterators[1] ||
87 vector::IteratorType::reduction != innerMostIterators[2])
88 return false;
89
90 // Get indexing maps of iterators for operands
91 SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
92 SmallVector<int64_t> outerMostResults;
93 for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
94 outerMostResults.push_back(i);
95
96 auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
97 auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
98 auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);
99
100 // Check whether they conform to a "transposed B" gemm
101 auto *ctx = op.getContext();
102 auto mmAidxMap =
103 AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
104 .dropResults(0);
105 auto mmBidxMap =
106 AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
107 .dropResults(0);
108 auto mmCidxMap =
109 AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
110 .dropResults(0);
111 int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
112 return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
113 innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
114 innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
115}
116
117//============================================================================//
118//============ Common AIE canonicalization conversion patterns ===============//
119//============================================================================//
120
// This pattern converts a `vector.transfer_read` with an unaligned access
// into an aligned `vector.transfer_read` twice as long, followed by a
// `vector.extract_strided_slice` selecting the subvector matching the
// original `vector.transfer_read`.
// NOTE(review): extraction dropped several lines of this pattern — the
// `struct` header, the constructor's first parameter line and member
// initializers, the `getTransferReadAlignmentOffset` call (embedded lines
// 143-144), and the `maxVectorSize`/alignment member declarations (184-185).
// Recover them from the upstream file before editing this block.
126 : public OpConversionPattern<vector::TransferReadOp> {
127 using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
128
130 int64_t alignment)
131 : OpConversionPattern<vector::TransferReadOp>(context),
133
134 LogicalResult
135 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter) const override {
137 // Check that it's not a splat transfer read.
138 if (adaptor.getPermutationMap().isConstant())
139 return failure();
140
141 // Check if the transfer is unaligned.
142 auto vType = readOp.getVectorType();
143 int64_t offset =
145 .value_or(0);
146 if (offset == 0)
147 return failure();
148
149 // Verify that we can load a vector 2x as long as the original
150 auto vLen = vType.getShape().back();
151 auto longVecTy = VectorType::get(2 * vLen, vType.getElementType());
152 auto longVecSize = getElementSizeInBits(vType) * 2 * vLen;
153 if (longVecSize > maxVectorSize)
154 return failure();
155
156 // Calculate the aligned indices for the lower and higher parts.
157 // TODO: Add support for cases where the offset is greater than the
158 // TODO: vector length.
159 auto loc = readOp.getLoc();
160 Value oldInnerMostIdx = adaptor.getIndices().back();
// Rebase the innermost index onto the aligned address: newIdx = oldIdx - offset.
161 auto offsetCorrectionMap =
162 AffineMap::get(1, 0, getAffineDimExpr(0, readOp.getContext()) - offset);
163 Value newInnerMostIdx = affine::AffineApplyOp::create(
164 rewriter, readOp.getLoc(), offsetCorrectionMap,
165 SmallVector<Value, 1>({oldInnerMostIdx}))
166 .getResult();
167 SmallVector<Value, 8> alignedIdx;
168 alignedIdx.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
169 alignedIdx[alignedIdx.size() - 1] = newInnerMostIdx;
170
171 // Create the aligned transfer read for a vector 2x as long that covers the
172 // elements of the unaligned vector.
173 auto newReadOp = vector::TransferReadOp::create(
174 rewriter, loc, longVecTy, adaptor.getBase(), alignedIdx,
175 adaptor.getPadding());
176
177 // Create a `vector.extract_strided_slice` to extract the unaligned vector.
178 rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
179 readOp, newReadOp.getResult(), offset, vLen, 1);
180
181 return success();
182 }
183
186 };
187
// This pattern converts a `vector.transfer_read` with a splat permutation map
// into a contiguous `vector.transfer_read` followed by a `vector.extract` to
// obtain the splat value and a `vector.broadcast` to broadcast it into a
// vector of the right size.
// NOTE(review): extraction dropped the `struct` header (embedded line 192/193
// area) and the constructor name (line 196); recover them from the upstream
// file before editing this block.
193 : public OpConversionPattern<vector::TransferReadOp> {
194 using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
195
197 : OpConversionPattern<vector::TransferReadOp>(context) {}
198
199 LogicalResult
200 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
201 ConversionPatternRewriter &rewriter) const override {
// Only fire on splat (constant permutation map) reads.
202 AffineMap map = readOp.getPermutationMap();
203 if (!map.isConstant())
204 return failure();
205
206 Value srcMemRef = adaptor.getBase();
207 SmallVector<Value, 8> indices;
208 Value newIdx;
209 int64_t offset = 0;
210 // If it's a zero-rank memory access
211 if (cast<MemRefType>(srcMemRef.getType()).getRank() == 0) {
// Expand memref<T> to memref<1xT> so a regular indexed read can be emitted.
212 srcMemRef = memref::ExpandShapeOp::create(
213 rewriter, readOp.getLoc(), SmallVector<int64_t, 1>({1}),
214 srcMemRef, SmallVector<ReassociationIndices, 1>({}))
215 .getResult();
216 newIdx = arith::ConstantOp::create(rewriter, readOp.getLoc(),
217 rewriter.getIndexAttr(0L));
218 indices.push_back(newIdx);
219 } else {
220 indices.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
221 newIdx = indices[indices.size() - 1];
222 // If the innermost index comes from an `affine.apply` op, take the base
223 // as the new innermost index for the new `vector.transfer_read`, and the
224 // offset as the index for the `aievec.broadcast` op.
225 if (auto applyOp = newIdx.getDefiningOp<affine::AffineApplyOp>())
226 if (applyOp.getAffineMap().getNumDims() == 1) {
227 newIdx = applyOp.getMapOperands()[0];
228 offset = applyOp.getAffineMap().compose(ArrayRef<int64_t>{0})[0];
229 }
230 }
231 // XXX: We assume we are reading 1D vectors
232 int64_t vlen = readOp.getVector().getType().getShape()[0];
233 if (offset >= vlen) {
234 // If the splat element is beyond the first vector, we calculate the
235 // address of the vector containing the element.
236 int64_t numElemsToSkip = vlen * (offset / vlen);
237 offset = offset % vlen;
238 auto newAddrMap = AffineMap::get(
239 1, 0, getAffineDimExpr(0, readOp.getContext()) + numElemsToSkip);
240 newIdx =
241 affine::AffineApplyOp::create(rewriter, readOp.getLoc(), newAddrMap,
242 SmallVector<Value, 1>({newIdx}))
243 .getResult();
244 }
245 indices[indices.size() - 1] = newIdx;
// Contiguous read at the adjusted address, then extract + broadcast the lane.
246 auto newReadOp = vector::TransferReadOp::create(
247 rewriter, readOp.getLoc(), readOp.getVector().getType(), srcMemRef,
248 indices, adaptor.getPadding());
249 auto extractOp = vector::ExtractOp::create(rewriter, readOp.getLoc(),
250 newReadOp.getResult(),
251 ArrayRef<int64_t>{offset});
252 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
253 readOp, newReadOp.getVector().getType(), extractOp.getResult());
254 return success();
255 }
256 };
257
// This pattern moves cast operations as close as possible to the source of
// the data. This helps to simplify dealing with patterns that may vary only
// by these sorts of casts between data manipulation operations and arithmetic
// ops.
// TODO: Generalize this op and instantiate for different types of cast ops.
// NOTE(review): extraction dropped the `struct` header (embedded line 263);
// the constructor below indicates it is
// `struct HoistCastOpToDataSourcePattern : public RewritePattern {` — confirm
// against the upstream file before editing.
264 HoistCastOpToDataSourcePattern(MLIRContext *context)
265 : RewritePattern(arith::ExtSIOp::getOperationName(), /*benefit=*/1,
266 context) {}
267
268 LogicalResult matchAndRewrite(Operation *op,
269 PatternRewriter &rewriter) const override {
270 arith::ExtSIOp extOp = cast<arith::ExtSIOp>(op);
271 Operation *defOp = extOp.getIn().getDefiningOp();
272 // If it's a data source op, we're done.
273 if (!defOp || isa<vector::TransferReadOp, memref::LoadOp,
274 affine::AffineLoadOp, func::CallOp>(defOp))
275 return failure();
276
277 // At the moment, we only accept ops we know we can swap with cast.
278 if (!isa<vector::BroadcastOp, vector::ExtractOp,
279 vector::ExtractStridedSliceOp>(defOp))
280 return failure();
281
// For each operand of the producer, emit an extsi whose result type matches
// the role the operand plays (same type, scalar-of-vector, vector-of-scalar,
// or same element type under a different vector shape); pass through
// operands that don't relate to the cast's input type.
282 Type extOpInTy = extOp.getIn().getType();
283 SmallVector<Value, 4> inputs;
284 for (Value operand : defOp->getOperands()) {
285 Type operandTy = operand.getType();
286 VectorType extOpInVecTy = dyn_cast<VectorType>(extOpInTy);
287 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
288 if (operandTy == extOpInTy) {
289 Type outTy = extOp.getOut().getType();
290 inputs.push_back(
291 arith::ExtSIOp::create(rewriter, extOp.getLoc(), outTy, operand)
292 .getOut());
293 } else if (extOpInVecTy && extOpInVecTy.getElementType() == operandTy) {
294 // Promote from vector to scalar -> scalar conversion for this operand
295 Type outTy =
296 cast<VectorType>(extOp.getOut().getType()).getElementType();
297 inputs.push_back(
298 arith::ExtSIOp::create(rewriter, extOp.getLoc(), outTy, operand)
299 .getOut());
300 } else if (operandVecTy && operandVecTy.getElementType() == extOpInTy) {
301 // Promote from scalar to vector -> vector conversion for this operand
302 Type outTy =
303 VectorType::get(operandVecTy.getShape(), extOp.getOut().getType());
304 inputs.push_back(
305 arith::ExtSIOp::create(rewriter, extOp.getLoc(), outTy, operand)
306 .getOut());
307 } else if (extOpInVecTy && operandVecTy &&
308 (extOpInVecTy.getElementType() ==
309 operandVecTy.getElementType())) {
310 // Hoist through a vector shape change
311 Type outTy = VectorType::get(
312 operandVecTy.getShape(),
313 cast<VectorType>(extOp.getOut().getType()).getElementType());
314 inputs.push_back(
315 arith::ExtSIOp::create(rewriter, extOp.getLoc(), outTy, operand)
316 .getOut());
317 } else {
318 inputs.push_back(operand);
319 }
320 }
321
// Re-create the producer op on the widened operands, producing the cast's
// result type directly, and replace the cast with it.
322 auto *newOp =
323 rewriter.create(extOp->getLoc(), defOp->getName().getIdentifier(),
324 inputs, {extOp.getOut().getType()}, defOp->getAttrs());
325 rewriter.replaceOp(extOp, newOp->getResult(0));
326 return success();
327 }
328 };
329
// This pattern swaps a UnaryOpA followed by UnaryOpB. This pattern can be used
// to improve pattern matching for mixed-type arithmetic ops, by getting sign
// extension ops closer to the single-type arithmetic operations.
333 template <class UnaryOpA, class UnaryOpB>
334 struct SwapUnaryOpsPattern : public OpRewritePattern<UnaryOpB> {
336 // This function takes the chain of operations A->B, and returns the new type
337 // between B and A after the swap.
338 using InferTypeB2AFnTy = std::function<Type(UnaryOpA aOp, UnaryOpB bOp)>;
// NOTE(review): extraction dropped embedded line 339 here — presumably the
// member declaration `InferTypeB2AFnTy inferTypeB2A;` used by the constructor
// and by matchAndRewrite below; confirm against the upstream file.
340
341 SwapUnaryOpsPattern(MLIRContext *context, InferTypeB2AFnTy inferType)
342 : OpRewritePattern<UnaryOpB>(context), inferTypeB2A(inferType) {}
343
344 LogicalResult matchAndRewrite(UnaryOpB bOp,
345 PatternRewriter &rewriter) const override {
346 static_assert(
347 UnaryOpA::template hasTrait<OpTrait::OneOperand>(),
348 "SwapUnaryOps can only be instantiated for single-operand ops");
349 static_assert(
350 UnaryOpB::template hasTrait<OpTrait::OneOperand>(),
351 "SwapUnaryOps can only be instantiated for single-operand ops");
352 UnaryOpA aOp = bOp.getOperand().template getDefiningOp<UnaryOpA>();
353 if (!aOp)
354 return rewriter.notifyMatchFailure(bOp, UnaryOpB::getOperationName() +
355 " not preceeded by " +
356 UnaryOpA::getOperationName());
357
// Ask the caller-provided callback what type sits between the swapped ops.
358 Type newA2BTy = inferTypeB2A(aOp, bOp);
359
// Rebuild the chain as A's operand -> B -> A, reusing each op's attributes.
360 auto newA =
361 UnaryOpB::create(rewriter, bOp->getLoc(), SmallVector<Type>({newA2BTy}),
362 aOp->getOperands(), bOp->getAttrs());
363 auto newB = UnaryOpA::create(rewriter, bOp->getLoc(),
364 SmallVector<Type>({bOp.getResult().getType()}),
365 newA->getResults(), aOp->getAttrs());
366 rewriter.replaceOp(bOp, newB.getResult());
367 return success();
368 }
369 };
370
371static SmallVector<Value> collapseInnerMostDimIndices(PatternRewriter &b,
372 Location loc, int numDims,
373 ValueRange indices,
374 ArrayRef<int64_t> shape,
375 AffineMap layout) {
376 (void)layout; // Layout is verified by callers; index computation uses shape.
377 auto newIdxExpr = b.getAffineDimExpr(numDims - 1);
378 int64_t stride = 1;
379 for (int64_t dim = numDims - 2; dim >= 0; dim--) {
380 stride *= shape[shape.size() - (numDims - dim - 1)];
381 newIdxExpr = newIdxExpr + b.getAffineDimExpr(dim) * stride;
382 }
383 auto newIndexMap = AffineMap::get(numDims, 0, newIdxExpr);
384 Value newInnerMostIdxValue =
385 affine::AffineApplyOp::create(b, loc, newIndexMap,
386 indices.take_back(numDims))
387 .getResult();
388 SmallVector<Value> newIdxRange;
389 for (auto idx : indices.drop_back(numDims))
390 newIdxRange.push_back(idx);
391 newIdxRange.push_back(newInnerMostIdxValue);
392 return newIdxRange;
393}
394
395static Value collapseInnerMostShapeDims(PatternRewriter &b, Location loc,
396 int numDims, Value val) {
397 auto memRefTy = cast<MemRefType>(val.getType());
398 auto shape = memRefTy.getShape();
399 int64_t newInnerMostDim = std::accumulate(shape.end() - numDims, shape.end(),
400 1, std::multiplies<>());
401 SmallVector<int64_t, 4> newShape{shape.begin(), shape.end() - numDims + 1};
402 newShape[shape.size() - numDims] = newInnerMostDim;
403 auto reassocIndices =
404 getReassociationIndicesForCollapse(shape, newShape).value();
405 // Let CollapseShapeOp::inferResultType compute the correct result type,
406 // which preserves strided layout and dynamic offset from the source.
407 auto newMemRefTy =
408 memref::CollapseShapeOp::computeCollapsedType(memRefTy, reassocIndices);
409 return memref::CollapseShapeOp::create(b, loc, newMemRefTy, val,
410 reassocIndices)
411 .getResult();
412}
413
414/// Check if a memref has contiguous row-major strides (each stride equals the
415/// product of trailing dimensions). Dynamic strides are accepted when the
416/// corresponding dimension size is 1. Memrefs with dynamic offsets are fine.
417static bool hasContiguousRowMajorStrides(MemRefType memRefTy) {
418 SmallVector<int64_t> strides;
419 int64_t offset;
420 if (failed(memRefTy.getStridesAndOffset(strides, offset)))
421 return false;
422 auto shape = memRefTy.getShape();
423 int64_t expected = 1;
424 for (int i = shape.size() - 1; i >= 0; --i) {
425 if (strides[i] != ShapedType::kDynamic && strides[i] != expected)
426 return false;
427 if (shape[i] != ShapedType::kDynamic)
428 expected *= shape[i];
429 else
430 expected = ShapedType::kDynamic;
431 }
432 return true;
433}
434
// This pattern flatten multidimensional `vector.transfer_read` operations
// replacing them with a `memref.collapse_shape`, a 1D `vector.transfer_read`,
// and a `vector.shape_cast`.
// NOTE(review): extraction dropped the `struct` header (embedded line 438);
// recover it from the upstream file before editing this block.
439 : public OpConversionPattern<vector::TransferReadOp> {
440 using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
441
442 LogicalResult
443 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
444 ConversionPatternRewriter &rewriter) const override {
445 // We can only deal with unmasked transfer ops with an identity permutation
446 // map.
447 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
448 return failure();
449 VectorType vectorTy = readOp.getVector().getType();
450 if (vectorTy.getRank() < 2)
451 return failure();
452 // Work only on bufferized reads
453 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getBase().getType());
454 if (!memRefTy)
455 return failure();
456 auto memRefShape = memRefTy.getShape();
457 auto vecShape = vectorTy.getShape();
458
// The 1-D replacement vector holds all elements of the original vector.
459 auto newVectorTy =
460 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
461 std::multiplies<>())},
462 vectorTy.getElementType());
463 if (!hasContiguousRowMajorStrides(memRefTy))
464 return failure();
465
466 AffineMap layout = memRefTy.getLayout().getAffineMap();
467 auto newIndices =
468 collapseInnerMostDimIndices(rewriter, readOp.getLoc(), vecShape.size(),
469 adaptor.getIndices(), memRefShape, layout);
470 auto newSource = collapseInnerMostShapeDims(
471 rewriter, readOp.getLoc(), vecShape.size(), adaptor.getBase());
472 auto newVector = vector::TransferReadOp::create(
473 rewriter, readOp.getLoc(), newVectorTy, newSource, newIndices,
474 /*padding*/
475 arith::getZeroConstant(rewriter, readOp.getLoc(),
476 newVectorTy.getElementType()));
477
// The flattened read is in-bounds only if every original dim was in-bounds.
478 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
479 if (inBoundsArrayAttrOpt) {
480 SmallVector<bool> inBounds =
481 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
482 SmallVector<bool> newInBounds({false});
483 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
484 [](bool v) { return v; });
485 newVector.getProperties().setInBounds(
486 rewriter.getBoolArrayAttr(newInBounds));
487 }
488
489 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, vectorTy,
490 newVector);
491
492 return success();
493 }
494 };
495
// This pattern flatten multidimensional `vector.transfer_write` operations
// replacing them with a `memref.collapse_shape`, a `vector.shape_cast`, and a
// 1D `vector.transfer_write`,
// NOTE(review): extraction dropped the `struct` header (embedded line 499);
// recover it from the upstream file before editing this block.
500 : public OpConversionPattern<vector::TransferWriteOp> {
501 using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
502
503 LogicalResult
504 matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
505 ConversionPatternRewriter &rewriter) const override {
506 // We can only deal with unmasked transfer ops with an identity permutation
507 // map.
508 if (!adaptor.getPermutationMap().isMinorIdentity() || adaptor.getMask())
509 return failure();
510 VectorType vectorTy = cast<VectorType>(adaptor.getValueToStore().getType());
511 if (vectorTy.getRank() < 2)
512 return failure();
513 // Work only on bufferized reads
514 MemRefType memRefTy = dyn_cast<MemRefType>(adaptor.getBase().getType());
515 if (!memRefTy)
516 return failure();
517 auto memRefShape = memRefTy.getShape();
518 auto vecShape = vectorTy.getShape();
519
520 if (!hasContiguousRowMajorStrides(memRefTy))
521 return failure();
522
// The 1-D replacement vector holds all elements of the original vector.
523 auto newVectorTy =
524 VectorType::get({std::accumulate(vecShape.begin(), vecShape.end(), 1,
525 std::multiplies<>())},
526 vectorTy.getElementType());
527 AffineMap layout = memRefTy.getLayout().getAffineMap();
528 auto newVector =
529 vector::ShapeCastOp::create(rewriter, writeOp.getLoc(), newVectorTy,
530 adaptor.getValueToStore())
531 .getResult();
532 auto newIndices =
533 collapseInnerMostDimIndices(rewriter, writeOp.getLoc(), vecShape.size(),
534 adaptor.getIndices(), memRefShape, layout);
535 auto newSource = collapseInnerMostShapeDims(
536 rewriter, writeOp.getLoc(), vecShape.size(), adaptor.getBase());
537
538 auto newOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
539 writeOp, newVector, newSource, newIndices);
540
// The flattened write is in-bounds only if every original dim was in-bounds.
541 auto inBoundsArrayAttrOpt = adaptor.getInBounds();
542 if (inBoundsArrayAttrOpt) {
543 SmallVector<bool> inBounds =
544 llvm::to_vector(inBoundsArrayAttrOpt.getAsValueRange<BoolAttr>());
545 SmallVector<bool> newInBounds({false});
546 newInBounds[0] = std::all_of(inBounds.begin(), inBounds.end(),
547 [](bool v) { return v; });
548 newOp.getProperties().setInBounds(rewriter.getBoolArrayAttr(newInBounds));
549 }
550
551 return success();
552 }
553 };
554
// This pattern extracts an implicit transposition of the 2 innermost
// dimensions of `rhs` in a gemm-like contraction op, making it an explicit
// `vector.transpose` op.
// If `rhs` is coming from a widening op (`extf`/`extsi`/`extui`), the
// transposition will be hoisted above the widening op.
// NOTE(review): extraction dropped the `struct` header (embedded line 560);
// recover it from the upstream file before editing this block.
561 : public OpConversionPattern<vector::ContractionOp> {
562 using OpConversionPattern<vector::ContractionOp>::OpConversionPattern;
563
// Returns `vecTy` with its two innermost dimensions swapped.
564 static VectorType getTransposedVectorType(VectorType vecTy) {
565 SmallVector<int64_t> shape{vecTy.getShape()};
566 auto nDim = shape.size();
567 int64_t dimNm1 = shape[nDim - 1];
568 shape[nDim - 1] = shape[nDim - 2];
569 shape[nDim - 2] = dimNm1;
570 auto elemTy = vecTy.getElementType();
571 return VectorType::get(shape, elemTy);
572 }
573
574 LogicalResult
575 matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
576 ConversionPatternRewriter &rewriter) const override {
577 if (!isGemmBTransposedContractionOp(contractOp))
578 return failure();
579
580 Location loc = contractOp.getLoc();
581 auto *ctx = rewriter.getContext();
582
583 Value rhsVal = adaptor.getRhs();
584 VectorType rhsVecTy = contractOp.getRhsType();
585 Type rhsElemTy = rhsVecTy.getElementType();
586
// If rhs is produced by a widening cast, look through it so the transpose is
// applied to the narrow source; the cast is re-applied after the transpose.
587 bool doExtF = false, doExtSI = false, doExtUI = false;
588 if (auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
589 rhsVal = extfRhsOp.getIn();
590 rhsVecTy = cast<VectorType>(rhsVal.getType());
591 doExtF = true;
592 } else if (auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
593 rhsVal = extsiRhsOp.getIn();
594 rhsVecTy = cast<VectorType>(rhsVal.getType());
595 doExtSI = true;
596 } else if (auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
597 rhsVal = extuiRhsOp.getIn();
598 rhsVecTy = cast<VectorType>(rhsVal.getType());
599 doExtUI = true;
600 }
601
// Permutation that swaps only the two innermost vector dimensions.
602 int64_t nDim = rhsVecTy.getShape().size();
603 SmallVector<int64_t> rhsPermutation;
604 for (int64_t i = 0; i < nDim - 2; i++)
605 rhsPermutation.push_back(i);
606 rhsPermutation.push_back(nDim - 1);
607 rhsPermutation.push_back(nDim - 2);
608 auto transpRhsVecTy = getTransposedVectorType(rhsVecTy);
609 rhsVal = vector::TransposeOp::create(rewriter, loc, transpRhsVecTy, rhsVal,
610 rhsPermutation)
611 .getResult();
612
// Re-apply whichever widening cast was peeled off above, now on the
// transposed vector.
613 if (doExtF)
614 rhsVal =
615 arith::ExtFOp::create(
616 rewriter, loc,
617 VectorType::get(transpRhsVecTy.getShape(), rhsElemTy), rhsVal)
618 .getOut();
619 if (doExtSI)
620 rhsVal =
621 arith::ExtSIOp::create(
622 rewriter, loc,
623 VectorType::get(transpRhsVecTy.getShape(), rhsElemTy), rhsVal)
624 .getOut();
625 if (doExtUI)
626 rhsVal =
627 arith::ExtUIOp::create(
628 rewriter, loc,
629 VectorType::get(transpRhsVecTy.getShape(), rhsElemTy), rhsVal)
630 .getOut();
631
// Compose the rhs indexing map with the same inner-dim swap so the new
// contraction, fed by the explicitly transposed rhs, is equivalent.
632 SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());
633
634 nDim = oldIdxMaps[1].getNumDims();
635 SmallVector<int64_t> innerDimPerm;
636 for (int64_t i = 0; i < nDim - 2; i++)
637 innerDimPerm.push_back(i);
638 innerDimPerm.push_back(nDim - 1);
639 innerDimPerm.push_back(nDim - 2);
640 auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);
641
642 auto newIdxMaps = rewriter.getAffineMapArrayAttr(
643 {oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});
644
645 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
646 contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
647 adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());
648
649 return success();
650 }
651 };
652
653/// Utility function to check if all provided indices are constant zero values.
654/// @return success() if all indices are constant zeros, failure() otherwise
655static LogicalResult isAllZeroOffsetAccess(mlir::OperandRange indices) {
656 if (!llvm::all_of(indices, [](Value val) {
657 IntegerAttr attr;
658 if (!matchPattern(val, m_Constant(&attr)))
659 return false;
660 return attr.getInt() == 0;
661 })) {
662 return failure();
663 }
664 return success();
665}
666
667/// Utility function to convert OpFoldResult offsets from a SubView operation
668/// into a vector of Values.
669static SmallVector<Value> opFoldResultsToValues(PatternRewriter &rewriter,
670 Location loc,
671 memref::SubViewOp subViewOp) {
672 OpBuilder::InsertionGuard g(rewriter);
673 rewriter.setInsertionPoint(subViewOp);
674 SmallVector<Value> newIndices;
675 for (OpFoldResult offset : subViewOp.getMixedOffsets()) {
676 Value indexVal;
677 if (auto attr = dyn_cast<Attribute>(offset)) {
678 indexVal = arith::ConstantIndexOp::create(
679 rewriter, loc, cast<IntegerAttr>(attr).getInt());
680 } else {
681 indexVal = cast<Value>(offset);
682 }
683 newIndices.push_back(indexVal);
684 }
685 return newIndices;
686}
687
/// Pattern to canonicalize trivial vector.transfer_read operations on subviews.
///
/// This pattern eliminates unnecessary memref.subview operations when the
/// transfer_read accesses the subview with all-zero indices. It transforms:
///
/// INPUT:
/// %subview = memref.subview %memref [offset0, offset1, ...]
/// %result = vector.transfer_read %subview[0, 0, ...]
///
/// OUTPUT:
/// %result = vector.transfer_read %memref[offset0, offset1, ...]
///
/// The pattern only matches when:
/// - The base of transfer_read is defined by a memref.subview operation
/// - All indices in the transfer_read are constant zeros
/// NOTE(review): extraction dropped the `struct` header (embedded line 703);
/// recover it from the upstream file before editing this block.
704 : public OpRewritePattern<vector::TransferReadOp> {
705 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
706
707 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
708 PatternRewriter &rewriter) const override {
709 // Check if the base memref comes from a subview operation
710 auto subViewOp = dyn_cast_if_present<memref::SubViewOp>(
711 readOp.getBase().getDefiningOp());
712 if (!subViewOp)
713 return failure();
714
715 // Verify that all access indices are zero
716 if (failed(isAllZeroOffsetAccess(readOp.getIndices())))
717 return failure();
718
719 // Convert subview offsets to explicit index values
720 SmallVector<Value> newIndices =
721 opFoldResultsToValues(rewriter, readOp.getLoc(), subViewOp);
722
723 // Replace with direct access to the original memref using subview offsets
724 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
725 readOp, readOp.getType(), subViewOp.getSource(), newIndices,
726 readOp.getPadding(), readOp.getInBoundsValues());
727 return success();
728 }
729 };
730
/// Pattern to canonicalize trivial vector.transfer_write operations on
/// subviews.
///
/// This pattern eliminates unnecessary memref.subview operations when the
/// transfer_write accesses the subview with all-zero indices. It transforms:
///
/// INPUT:
/// %subview = memref.subview %memref [offset0, offset1, ...]
/// vector.transfer_write %value, %subview[0, 0, ...]
///
/// OUTPUT:
/// vector.transfer_write %value, %memref[offset0, offset1, ...]
///
/// The pattern only matches when:
/// - The base of transfer_write is defined by a memref.subview operation
/// - All indices in the transfer_write are constant zeros
/// NOTE(review): extraction dropped the `struct` header (embedded line 747);
/// recover it from the upstream file before editing this block.
748 : public OpRewritePattern<vector::TransferWriteOp> {
749 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
750
751 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
752 PatternRewriter &rewriter) const override {
753 // Check if the base memref comes from a subview operation
754 auto subViewOp = dyn_cast_if_present<memref::SubViewOp>(
755 writeOp.getBase().getDefiningOp());
756 if (!subViewOp)
757 return failure();
758
759 // Verify that all access indices are zero
760 if (failed(isAllZeroOffsetAccess(writeOp.getIndices())))
761 return failure();
762
763 // Convert subview offsets to explicit index values
764 SmallVector<Value> newIndices =
765 opFoldResultsToValues(rewriter, writeOp.getLoc(), subViewOp);
766
767 // Create new transfer_write with direct access to original memref
768 vector::TransferWriteOp::create(rewriter, writeOp.getLoc(),
769 writeOp.getVector(), subViewOp.getSource(),
770 newIndices, writeOp.getInBoundsValues());
771
772 // Remove the original transfer_write operation
773 rewriter.eraseOp(writeOp);
774 return success();
775 }
776 };
777
778//============================================================================//
779//============ AIE2 canonicalization conversion patterns ===============//
780//============================================================================//
781
// Converts a `vector.insert` whose destination differs from the source only
// by leading unit dimensions into a `vector.shape_cast` of the source.
// NOTE(review): extraction dropped the `struct` header (embedded line 782);
// recover it from the upstream file before editing this block.
783 : public OpRewritePattern<vector::InsertOp> {
784
785 using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
786
787 LogicalResult matchAndRewrite(vector::InsertOp insOp,
788 PatternRewriter &rewriter) const override {
// Only vector-into-vector inserts are considered (not scalar inserts).
789 auto insSrcTy = dyn_cast<VectorType>(insOp.getValueToStoreType());
790 if (!insSrcTy)
791 return failure();
792
793 auto srcShape = insSrcTy.getShape();
794 auto dstShape = insOp.getDestVectorType().getShape();
795
// Count leading unit dims of the destination; at least one is required.
796 unsigned long numLeadUnitDimDst = 0;
797 while (numLeadUnitDimDst < dstShape.size() &&
798 dstShape[numLeadUnitDimDst] == 1)
799 numLeadUnitDimDst++;
800
801 if (!numLeadUnitDimDst)
802 return failure();
803
804 unsigned long numLeadUnitDimSrc = 0;
805 while (numLeadUnitDimSrc < srcShape.size() &&
806 srcShape[numLeadUnitDimSrc] == 1)
807 numLeadUnitDimSrc++;
808
// After dropping leading unit dims, both shapes must be identical, i.e. the
// insert writes the whole destination and is just a reshape.
809 SmallVector<int64_t> nonLeadUnitDimDstShape(
810 dstShape.begin() + numLeadUnitDimDst, dstShape.end());
811 SmallVector<int64_t> nonLeadUnitDimSrcShape(
812 srcShape.begin() + numLeadUnitDimSrc, srcShape.end());
813
814 if (nonLeadUnitDimSrcShape != nonLeadUnitDimDstShape)
815 return failure();
816
817 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
818 insOp, insOp.getDestVectorType(), insOp.getValueToStore());
819 return success();
820 }
821 };
822
823//============================================================================//
824//================ Common AIE canonicalization configuration =================//
825//============================================================================//
826static void
827configureCommonAIECanonicalizeLegalizations(ConversionTarget &target,
828 TargetBackend backend) {
829 target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
830 memref::MemRefDialect, vector::VectorDialect,
831 ub::UBDialect>();
832}
833
// Registers the backend-independent canonicalization patterns.
// NOTE(review): the line naming the pattern class(es) added here
// (orig. line 837) was lost in extraction; only the trailing
// `patterns.getContext())` argument survives.
834 static void
835 populateCommonAIECanonicalizeConversionPatterns(RewritePatternSet &patterns,
836 TargetBackend backend) {
838 patterns.getContext());
839 }
840
841//============================================================================//
842//============== AIEv1-specific canonicalization configuration ===============//
843//============================================================================//
844
845static void configureAIEv1CanonicalizeLegalizations(ConversionTarget &target,
846 TargetBackend backend) {
847 target.addDynamicallyLegalOp<vector::TransferReadOp>(
848 [](vector::TransferReadOp op) {
849 return !op.getPermutationMap().isConstant() &&
850 getTransferReadAlignmentOffset(op, op.getVectorType(), 128)
851 .value_or(0) == 0;
852 });
853}
854
855static void
856populateAIEv1CanonicalizeConversionPatterns(RewritePatternSet &patterns,
857 TargetBackend backend) {
858 patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 512,
859 128);
860}
861
862//============================================================================//
863//============== AIE2-specific canonicalization configuration ===============//
864//============================================================================//
865
866static void configureAIE2CanonicalizeLegalizations(ConversionTarget &target,
867 TargetBackend backend) {
868 target.addDynamicallyLegalOp<vector::TransferReadOp>(
869 [](vector::TransferReadOp op) {
870 return !op.getPermutationMap().isConstant() &&
871 getTransferReadAlignmentOffset(op, op.getVectorType(), 256)
872 .value_or(0) == 0 &&
873 op.getVector().getType().getRank() < 2;
874 });
875 target.addDynamicallyLegalOp<vector::TransferWriteOp>(
876 [](vector::TransferWriteOp op) {
877 return cast<VectorType>(op.getVector().getType()).getRank() < 2;
878 });
879 target.addDynamicallyLegalOp<vector::ContractionOp>(
880 [](vector::ContractionOp op) {
881 return !isGemmBTransposedContractionOp(op);
882 });
883}
884
// Registers AIE2-specific patterns: unaligned transfer_read splitting
// (1024-bit max vector, 256-bit alignment) plus multi-dimensional transfer
// flattening. NOTE(review): the first line of the second `patterns.add`
// call (orig. line 891) was lost in extraction.
885 static void
886 populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
887 TargetBackend backend) {
888 patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 1024,
889 256);
890 patterns
892 FlattenMultDimTransferWritePattern>(patterns.getContext());
893 }
894
895//============================================================================//
896//=================== Common AIE Canonicalization Passes =====================//
897//============================================================================//
898
899//===----------------------------------------------------------------------===//
900// BF16 Emulation: Emulate f32 vector arithmetic using bf16 operations.
901//===----------------------------------------------------------------------===//
902
903// Smart truncation helper: if the value was produced by arith.extf from bf16,
904// reuse the bf16 source directly to avoid redundant extf->truncf chains.
905static Value smartTruncF32ToBF16(PatternRewriter &rewriter, Location loc,
906 Value val, Type bf16Type) {
907 if (auto extfOp = val.getDefiningOp<arith::ExtFOp>()) {
908 if (extfOp.getIn().getType() == bf16Type)
909 return extfOp.getIn();
910 }
911 return arith::TruncFOp::create(rewriter, loc, bf16Type, val);
912}
913
914 /// Pattern to emulate f32 binary vector arithmetic ops in bf16.
915 /// For an op like: %r = arith.addf %a, %b : vector<16xf32>
916 /// Produces:
917 /// %a_bf16 = arith.truncf %a : vector<16xf32> to vector<16xbf16>
918 /// %b_bf16 = arith.truncf %b : vector<16xf32> to vector<16xbf16>
919 /// %r_bf16 = arith.addf %a_bf16, %b_bf16 : vector<16xbf16>
920 /// %r = arith.extf %r_bf16 : vector<16xbf16> to vector<16xf32>
921 template <typename OpTy>
// NOTE(review): the struct declaration and `using` line (orig. lines
// 922-923) were lost in extraction; this is the body of the binary-op
// emulation pattern described above.
924
925 LogicalResult matchAndRewrite(OpTy op,
926 PatternRewriter &rewriter) const override {
// Only vector-of-f32 results are demoted to bf16.
927 auto resultType = dyn_cast<VectorType>(op.getType());
928 if (!resultType || !resultType.getElementType().isF32())
929 return failure();
930
931 Location loc = op.getLoc();
932 auto bf16VecType =
933 VectorType::get(resultType.getShape(), rewriter.getBF16Type());
934
// smartTruncF32ToBF16 reuses an existing bf16 source when possible,
// avoiding extf->truncf chains.
935 Value lhsBF16 =
936 smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
937 Value rhsBF16 =
938 smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
939
940 Value newResult =
941 OpTy::create(rewriter, loc, bf16VecType, lhsBF16, rhsBF16);
// Extend the bf16 result back so downstream users still see f32.
942 auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
943 rewriter.replaceOp(op, extOp);
944 return success();
945 }
946 };
947
948/// Pattern to emulate f32 comparison ops in bf16.
949/// Result type stays vector<Nxi1>, only operands are truncated.
950struct EmulateCmpFF32InBF16Pattern : public OpRewritePattern<arith::CmpFOp> {
951 using OpRewritePattern::OpRewritePattern;
952
953 LogicalResult matchAndRewrite(arith::CmpFOp op,
954 PatternRewriter &rewriter) const override {
955 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
956 if (!lhsType || !lhsType.getElementType().isF32())
957 return failure();
958
959 Location loc = op.getLoc();
960 auto bf16VecType =
961 VectorType::get(lhsType.getShape(), rewriter.getBF16Type());
962
963 Value lhsBF16 =
964 smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
965 Value rhsBF16 =
966 smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
967
968 rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, op.getPredicate(), lhsBF16,
969 rhsBF16);
970 return success();
971 }
972};
973
974 /// Pattern to emulate f32 select ops in bf16.
975 /// Condition stays vector<Nxi1>, true/false values are truncated.
// NOTE(review): the struct declaration line (orig. 976) was lost in
// extraction; this is the select-emulation pattern's body.
977 : public OpRewritePattern<arith::SelectOp> {
978 using OpRewritePattern::OpRewritePattern;
979
980 LogicalResult matchAndRewrite(arith::SelectOp op,
981 PatternRewriter &rewriter) const override {
// Only vector-of-f32 selects are demoted; the i1 condition is untouched.
982 auto resultType = dyn_cast<VectorType>(op.getType());
983 if (!resultType || !resultType.getElementType().isF32())
984 return failure();
985
986 Location loc = op.getLoc();
987 auto bf16VecType =
988 VectorType::get(resultType.getShape(), rewriter.getBF16Type());
989
990 Value trueValBF16 =
991 smartTruncF32ToBF16(rewriter, loc, op.getTrueValue(), bf16VecType);
992 Value falseValBF16 =
993 smartTruncF32ToBF16(rewriter, loc, op.getFalseValue(), bf16VecType);
994
995 Value newResult = arith::SelectOp::create(rewriter, loc, op.getCondition(),
996 trueValBF16, falseValBF16);
// Extend back to f32 so users of the select are unaffected.
997 auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
998 rewriter.replaceOp(op, extOp);
999 return success();
1000 }
1001 };
1002
1003/// Pattern to emulate f32 vector.fma in bf16.
1004/// All three operands (lhs, rhs, acc) are truncated.
1005struct EmulateFMAF32InBF16Pattern : public OpRewritePattern<vector::FMAOp> {
1006 using OpRewritePattern::OpRewritePattern;
1007
1008 LogicalResult matchAndRewrite(vector::FMAOp op,
1009 PatternRewriter &rewriter) const override {
1010 auto resultType = dyn_cast<VectorType>(op.getType());
1011 if (!resultType || !resultType.getElementType().isF32())
1012 return failure();
1013
1014 Location loc = op.getLoc();
1015 auto bf16VecType =
1016 VectorType::get(resultType.getShape(), rewriter.getBF16Type());
1017
1018 Value lhsBF16 =
1019 smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
1020 Value rhsBF16 =
1021 smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
1022 Value accBF16 =
1023 smartTruncF32ToBF16(rewriter, loc, op.getAcc(), bf16VecType);
1024
1025 Value newResult =
1026 vector::FMAOp::create(rewriter, loc, lhsBF16, rhsBF16, accBF16);
1027 auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
1028 rewriter.replaceOp(op, extOp);
1029 return success();
1030 }
1031};
1032
1033 /// Pattern to emulate f32 unary vector ops in bf16.
1034 template <typename OpTy>
// NOTE(review): the struct declaration and `using` line (orig. lines
// 1035-1036) were lost in extraction; this is the unary-op emulation
// pattern's body.
1037
1038 LogicalResult matchAndRewrite(OpTy op,
1039 PatternRewriter &rewriter) const override {
// Only vector-of-f32 results are demoted to bf16.
1040 auto resultType = dyn_cast<VectorType>(op.getType());
1041 if (!resultType || !resultType.getElementType().isF32())
1042 return failure();
1043
1044 Location loc = op.getLoc();
1045 auto bf16VecType =
1046 VectorType::get(resultType.getShape(), rewriter.getBF16Type());
1047
1048 Value inputBF16 =
1049 smartTruncF32ToBF16(rewriter, loc, op->getOperand(0), bf16VecType);
1050
1051 Value newResult = OpTy::create(rewriter, loc, bf16VecType, inputBF16);
// Extend the bf16 result back so downstream users still see f32.
1052 auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
1053 rewriter.replaceOp(op, extOp);
1054 return success();
1055 }
1056 };
1057
// BF16EmulationPass (struct header, orig. line 1058, lost in extraction):
// greedily applies the f32→bf16 emulation patterns above to the operation.
1059 : public PassWrapper<BF16EmulationPass, OperationPass<>> {
1060
1061 void runOnOperation() override {
1062 auto *op = getOperation();
1063 MLIRContext *context = &getContext();
1064 RewritePatternSet patterns(context);
1065
1066 // Binary arithmetic ops
// NOTE(review): the `patterns.add<...>` lines registering the binary
// emulation patterns (orig. lines 1067-1071) were lost in extraction.
1072
1073 // Note: arith.divf is NOT demoted because bf16 vector divf is unsupported
1074 // on all AIE targets (Peano does not legalize G_FDIV on <16 x s16>).
1075
1076 // Special-case ops (excluding ReductionOp — its scalar accumulator
1077 // and result lower to fp_to_bf16/bf16_to_fp which older Peano versions
1078 // cannot select on AIE2P; reductions are intentionally left in f32
1079 // to avoid these scalar bf16 conversions)
// NOTE(review): the registrations for the special-case patterns
// (orig. lines 1080-1081) were lost in extraction.
1082
1083 // Unary ops
1084 patterns.add<EmulateUnaryF32InBF16Pattern<arith::NegFOp>>(context);
1085
// Best-effort greedy rewrite; the convergence result is deliberately
// ignored.
1086 (void)applyPatternsGreedily(op, std::move(patterns));
1087 }
1088 };
1089
1090std::unique_ptr<::mlir::Pass> xilinx::aievec::createBF16EmulationPass() {
1091 return std::make_unique<BF16EmulationPass>();
1092}
1093
// VectorBroadcastLoweringPass (struct header, orig. line 1094, lost in
// extraction): applies the upstream vector.broadcast lowering patterns plus
// a local pattern whose registration line was also lost.
1095 : public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {
1096
1097 void runOnOperation() override {
1098 auto *op = getOperation();
1099 MLIRContext *context = &getContext();
1100 RewritePatternSet patterns(context);
1101 populateVectorBroadcastLoweringPatterns(patterns);
// NOTE(review): a `patterns.add<...>` line (orig. 1102) was lost here;
// only its trailing context argument survives.
1103 patterns.getContext());
1104
1105 (void)applyPatternsGreedily(op, std::move(patterns));
1106 }
1107 };
1108
1109static std::unique_ptr<::mlir::Pass> createVectorBroadcastLoweringPass() {
1110 return std::make_unique<VectorBroadcastLoweringPass>();
1111}
1112
1113 // This pass converts standard vector ops into a subset of `Vector` ops more
1114 // amenable to being converted to `AIEVec`. So far, this process consists of
1115 // two steps:
1116 // 1) Replace splat transfer reads with contiguous transfer reads followed
1117 // by `extract` + `splat` operations.
1118 // 2) Split unaligned transfer reads into a wider aligned transfer read
1119 // followed by a `vector.extract_strided_slice` operation.
// NOTE(review): the struct declaration (orig. line 1120) and the
// constructors (orig. lines 1124-1133) were lost in extraction.
1121 : public PassWrapper<CanonicalizeVectorForAIEVecPass, OperationPass<>> {
1122 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeVectorForAIEVecPass)
1123
1127
1134
1135 // In case we want to register this pass as a standalone pass for test
1136 // purposes.
1137 StringRef getArgument() const final {
1138 return "test-canonicalize-vector-for-aievec";
1139 }
1140
1141 StringRef getDescription() const final {
1142 return "Canonicalize vector operations for AIEVec conversion";
1143 }
1144
1145 void getDependentDialects(DialectRegistry &registry) const override {
1146 registry
1147 .insert<arith::ArithDialect, memref::MemRefDialect,
1148 vector::VectorDialect, affine::AffineDialect, ub::UBDialect>();
1149 }
1150
// Pipeline options mirrored as pass options so the pass can run standalone.
1151 Option<std::string> aieTarget{
1152 *this, "aie-target",
1153 llvm::cl::desc(
1154 "Select AIE version: \"aie\", \"aie2\", or \"aie2p\". This will "
1155 "determine the vector size and available operations."),
1156 llvm::cl::init("aie")};
1157
1158 Option<std::string> targetBackend{
1159 *this, "target-backend",
1160 llvm::cl::desc("Select translation backend: \"cpp\" or \"llvmir\". This "
1161 "will determine the aievec operations used to convert "
1162 "from vector dialect."),
1163 llvm::cl::init("cpp")};
1164
1165 void runOnOperation() override {
1166 auto *op = getOperation();
1167 MLIRContext *context = &getContext();
1168 RewritePatternSet patterns(context);
1169 ConversionTarget target(*context);
1170
// Validate both option strings up-front, failing the pass on bad input.
1171 AIEArch aieVersion = decodeAIETarget(aieTarget);
1172 if (aieVersion == AIEArch::UNKNOWN) {
1173 op->emitError() << "unknown AIE target '" << aieTarget << "'";
1174 signalPassFailure();
1175 return;
1176 }
1177
1178 TargetBackend backend = decodeTargetBackend(targetBackend);
1179 if (backend == TargetBackend::UNKNOWN) {
1180 op->emitError() << "unknown target backend '" << targetBackend << "'";
1181 signalPassFailure();
1182 return;
1183 }
// NOTE(review): "targetting" in the diagnostic below is a typo
// ("targeting"); it is a user-facing string, so fix it in a dedicated
// change, not a comment-only pass.
1184 if (backend == TargetBackend::LLVMIR && aieVersion == AIEArch::AIE) {
1185 op->emitError() << "targetting LLVM IR is not supported for AIEv1";
1186 signalPassFailure();
1187 return;
1188 }
1189
// Common patterns and legality first, then the version-specific set.
1190 populateCommonAIECanonicalizeConversionPatterns(patterns, backend);
1191 configureCommonAIECanonicalizeLegalizations(target, backend);
1192 if (aieVersion == AIEArch::AIE) {
1193 populateAIEv1CanonicalizeConversionPatterns(patterns, backend);
1194 configureAIEv1CanonicalizeLegalizations(target, backend);
1195 } else {
1196 populateAIE2CanonicalizeConversionPatterns(patterns, backend);
1197 configureAIE2CanonicalizeLegalizations(target, backend);
1198 }
1199
// A greedy pre-pass runs before the dialect conversion below.
// NOTE(review): its pattern registrations (orig. lines 1202-1203) were
// lost in extraction.
1200 {
1201 RewritePatternSet patterns(context);
1204 (void)applyPatternsGreedily(op, std::move(patterns));
1205 }
1206
1207 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
1208 signalPassFailure();
1209 }
1210 }
1211 };
1212
1213static std::unique_ptr<::mlir::Pass> createCanonicalizeVectorForAIEVecPass(
1214 const CanonicalizeVectorForAIEVecOptions &options) {
1215 return std::make_unique<CanonicalizeVectorForAIEVecPass>(options);
1216}
1217
// HoistCastOpToDataSourcePass (struct header, orig. line 1218, lost in
// extraction): greedily applies HoistCastOpToDataSourcePattern.
1219 : public PassWrapper<HoistCastOpToDataSourcePass, OperationPass<>> {
1220
1221 void runOnOperation() override {
1222 auto *op = getOperation();
1223 MLIRContext *context = &getContext();
1224 RewritePatternSet patterns(context);
1225
1226 patterns.add<HoistCastOpToDataSourcePattern>(patterns.getContext());
1227
// Best-effort greedy rewrite; convergence result intentionally ignored.
1228 (void)applyPatternsGreedily(op, std::move(patterns));
1229 }
1230 };
1231
1232static std::unique_ptr<::mlir::Pass> createHoistCastOpToDataSourcePass() {
1233 return std::make_unique<HoistCastOpToDataSourcePass>();
1234}
1235
// ReorderOperationsPass (struct header, orig. line 1236, lost in
// extraction): reorders arith.extsi / vector.broadcast pairs. The
// `patterns.add<...>` opening line (orig. 1244) was also lost; judging by
// the callback below it registers a SwapUnaryOpsPattern with a type
// inference function — TODO confirm against the repository source.
1237 : public PassWrapper<ReorderOperationsPass, OperationPass<>> {
1238
1239 void runOnOperation() override {
1240 auto *op = getOperation();
1241 MLIRContext *context = &getContext();
1242 RewritePatternSet patterns(context);
1243
1245 patterns.getContext(),
1246 [](arith::ExtSIOp extOp, vector::BroadcastOp bcastOp) -> Type {
// If the extension input is a vector, broadcast its element type;
// otherwise the scalar input type is broadcast directly.
1247 Type extInElemTy = extOp.getIn().getType();
1248 auto extInVecTy = dyn_cast<VectorType>(extInElemTy);
1249 if (extInVecTy)
1250 extInElemTy = extInVecTy.getElementType();
1251 return VectorType::get(bcastOp.getResultVectorType().getShape(),
1252 extInElemTy);
1253 });
1254
1255 (void)applyPatternsGreedily(op, std::move(patterns));
1256 }
1257 };
1258
1259static std::unique_ptr<::mlir::Pass> createReorderOperationsPass() {
1260 return std::make_unique<ReorderOperationsPass>();
1261}
1262
1263//============================================================================//
1264//=============== Main Vector2Vector Pipeline Configuration ==================//
1265//============================================================================//
1266
// buildCanonicalizeVectorForAIEVec (signature line, orig. 1267, lost in
// extraction): assembles the vector-to-vector canonicalization pipeline.
1268 OpPassManager &pm, const CanonicalizeVectorForAIEVecOptions &options) {
1269 // Add `Vector` code canonicalization passes
1270 // TODO: Add passes to unroll vector with unsupported types
1271 // TODO: Add passes to split vectors that won't fit in registers
1272
1273 // If bf16-emulation is enabled, demote f32 vector arithmetic to bf16 first.
1274 if (options.enableBF16Emulation)
1275 pm.addPass(createBF16EmulationPass());
1276
// Operation reordering is only added when targeting the LLVM IR backend.
1277 if (decodeTargetBackend(options.targetBackend) == TargetBackend::LLVMIR)
1278 pm.addPass(createReorderOperationsPass());
1279 pm.addPass(createCopyRemovalPass());
1280 pm.addPass(createVectorBroadcastLoweringPass());
1281 pm.addPass(createCanonicalizeVectorForAIEVecPass(options));
// Cast hoisting is applied only for the C++ backend.
1282 if (decodeTargetBackend(options.targetBackend) == TargetBackend::CPP)
1283 pm.addPass(createHoistCastOpToDataSourcePass());
1284 }
std::unique_ptr<::mlir::Pass > createCopyRemovalPass()
Create a pass that removes unnecessary Copy operations.
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
int32_t getElementSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:49
std::unique_ptr<::mlir::Pass > createBF16EmulationPass()
Create a pass that emulates f32 vector arithmetic using bf16 operations.
void buildCanonicalizeVectorForAIEVec(mlir::OpPassManager &pm, const CanonicalizeVectorForAIEVecOptions &options)
TargetBackend
Definition Passes.h:27
AIEArch
Definition Passes.h:21
Pattern to canonicalize trivial vector.transfer_read operations on subviews.
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override
Pattern to canonicalize trivial vector.transfer_write operations on subviews.
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override
void getDependentDialects(DialectRegistry &registry) const override
CanonicalizeVectorForAIEVecPass(const CanonicalizeVectorForAIEVecOptions &options)
LogicalResult matchAndRewrite(vector::InsertOp insOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Pattern to emulate f32 binary vector arithmetic ops in bf16.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
Pattern to emulate f32 comparison ops in bf16.
LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override
Pattern to emulate f32 vector.fma in bf16.
LogicalResult matchAndRewrite(vector::FMAOp op, PatternRewriter &rewriter) const override
Pattern to emulate f32 select ops in bf16.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Pattern to emulate f32 unary vector ops in bf16.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
static VectorType getTransposedVectorType(VectorType vecTy)
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
SplitUnalignedTransferReadPattern(MLIRContext *context, int64_t maxVectorSize, int64_t alignment)
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(UnaryOpB bOp, PatternRewriter &rewriter) const override
std::function< Type(UnaryOpA aOp, UnaryOpB bOp)> InferTypeB2AFnTy
SwapUnaryOpsPattern(MLIRContext *context, InferTypeB2AFnTy inferType)
Options for the "canonicalize-vector-for-aievec" pipeline.
Definition Passes.h:41
PassOptions::Option< bool > enableBF16Emulation
Definition Passes.h:53
PassOptions::Option< std::string > targetBackend
Definition Passes.h:47
PassOptions::Option< std::string > aieTarget
Definition Passes.h:42