//===- AIEHoistVectorTransferPointers.cpp -----------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2025 Advanced Micro Devices Inc.
//
//===----------------------------------------------------------------------===//
//
// This pass hoists vector transfer operations with IV-dependent pointers
// out of scf.for loops by using iter_args to track pointer updates. This
// optimization reduces address computation overhead in loops by maintaining
// a running pointer offset rather than recomputing addresses each iteration.
//
//===----------------------------------------------------------------------===//
17
#include "aie/Dialect/AIE/Transforms/AIEPasses.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33
// Pull in the tablegen-generated pass base class
// (AIEHoistVectorTransferPointersBase) into the xilinx::AIE namespace.
namespace xilinx::AIE {
#define GEN_PASS_DEF_AIEHOISTVECTORTRANSFERPOINTERS
#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc"
} // namespace xilinx::AIE

// Tag used by LLVM_DEBUG / -debug-only for this pass.
#define DEBUG_TYPE "aie-hoist-vector-transfer-pointers"

using namespace mlir;
using namespace xilinx;
using namespace xilinx::AIE;
44
45namespace {
46
//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//
50
51/// Check if a value depends on the given loop induction variable
52/// Uses a cache to avoid exponential recursion on complex dependency chains
53static bool dependsOnLoopIVForHoist(Value val, Value loopIV,
54 DenseMap<Value, bool> &cache) {
55 // Check cache - return cached result if already computed
56 auto it = cache.find(val);
57 if (it != cache.end())
58 return it->second;
59
60 // Mark as being computed (assume false initially to handle recursion)
61 // This prevents infinite recursion in case of cycles (though SSA shouldn't
62 // have cycles)
63 cache[val] = false;
64
65 bool result = false;
66 if (val == loopIV) {
67 result = true;
68 } else if (auto defOp = val.getDefiningOp()) {
69 // Check for operations that use the loop IV in their operands
70 for (Value operand : defOp->getOperands()) {
71 if (dependsOnLoopIVForHoist(operand, loopIV, cache)) {
72 result = true;
73 break;
74 }
75 }
76 }
77
78 // Store the computed result in cache
79 cache[val] = result;
80 return result;
81}
82
83/// Wrapper for dependsOnLoopIVForHoist that manages the cache
84static bool dependsOnLoopIVForHoist(Value val, Value loopIV) {
85 DenseMap<Value, bool> cache;
86 return dependsOnLoopIVForHoist(val, loopIV, cache);
87}
88
/// Clone an operation and its operands (recursively) that don't depend on the
/// loop IV. Uses memoization via the mapping to avoid exponential recursion.
///
/// Returns the cloned single result on success, or a null Value when the op
/// (a) produces more than one result, (b) transitively depends on `loopIV`,
/// or (c) has an operand whose chain cannot be cloned. Callers fall back to
/// the original value on the null case.
static Value cloneOpAndOperands(Operation *op, Value loopIV, OpBuilder &builder,
                                IRMapping &mapping) {
  // Only handle operations with exactly one result
  if (op->getNumResults() != 1)
    return Value();

  // If we've already cloned this operation, return the mapped result.
  // This memoization is critical for avoiding exponential recursion on
  // DAG-shaped operand chains.
  if (mapping.contains(op->getResult(0)))
    return mapping.lookup(op->getResult(0));

  // Check if this operation depends on the loop IV before trying to clone
  if (dependsOnLoopIVForHoist(op->getResult(0), loopIV))
    return Value();

  // Clone operands recursively
  SmallVector<Value> newOperands;
  for (Value operand : op->getOperands()) {
    if (auto defOp = operand.getDefiningOp()) {
      Value clonedOperand = cloneOpAndOperands(defOp, loopIV, builder, mapping);
      if (!clonedOperand)
        return Value(); // Failed to clone an operand
      newOperands.push_back(clonedOperand);
    } else {
      // Operand is a block argument or constant (guaranteed not to be the
      // loop IV thanks to the dependency check above)
      newOperands.push_back(operand);
    }
  }

  // Clone the operation at the builder's current insertion point, then swap
  // in the (possibly re-cloned) operands.
  Operation *clonedOp = builder.clone(*op);
  clonedOp->setOperands(newOperands);

  // Map the result to enable memoization
  mapping.map(op->getResult(0), clonedOp->getResult(0));
  return clonedOp->getResult(0);
}
129
130/// Get the total number of elements in a vector type
131static int64_t getVectorNumElements(VectorType vectorType) {
132 int64_t numElements = 1;
133 for (int64_t dim : vectorType.getShape()) {
134 numElements *= dim;
135 }
136 return numElements;
137}
138
//===----------------------------------------------------------------------===//
// HoistVectorTransferPointers Pattern
//===----------------------------------------------------------------------===//
142
/// Information about a vector transfer operation collected from a loop body.
struct TransferOpInfo {
  Operation *op;              // The vector.transfer_read / transfer_write op.
  Value base;                 // Memref the transfer reads from / writes to.
  MemRefType memrefType;      // Type of `base` (checked to be a MemRefType).
  VectorType vectorType;      // Vector type moved by the transfer.
  SmallVector<Value> indices; // Per-dimension index operands of the transfer.
  int64_t constantStride; // Total constant stride per iteration
  bool hasIVDependentIndices; // True when some index depends on the loop IV
                              // (cleared again if a dynamic stride is seen).
};
153
/// Pattern to hoist vector transfer operations with IV-dependent pointers
/// out of scf.for loops by using iter_args to track pointer updates.
///
/// Two rewrite modes:
///  1. If at least one transfer has IV-dependent indices with a computable
///     constant per-iteration stride, the loop gains one index iter_arg per
///     such transfer; each transfer is rewritten to a flattened 1-D transfer
///     addressed by its iter_arg, and the yield advances the pointer by the
///     constant stride.
///  2. Otherwise, transfers are merely flattened in place (multi-dim memref
///     collapsed to 1-D, indices linearized with an affine map) without
///     adding iter_args.
struct HoistVectorTransferPointersPattern
    : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    Value loopIV = forOp.getInductionVar();
    Location loc = forOp.getLoc();

    // Collect all vector transfer operations with IV-dependent indices
    SmallVector<TransferOpInfo> transferOps;

    for (Operation &op : forOp.getBody()->without_terminator()) {
      Value base;
      VectorType vectorType;
      SmallVector<Value> indices;

      if (auto readOp = dyn_cast<vector::TransferReadOp>(&op)) {
        base = readOp.getBase();
        vectorType = readOp.getVectorType();
        indices.assign(readOp.getIndices().begin(), readOp.getIndices().end());
      } else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(&op)) {
        base = writeOp.getBase();
        vectorType = writeOp.getVectorType();
        indices.assign(writeOp.getIndices().begin(),
                       writeOp.getIndices().end());
      } else {
        continue;
      }

      auto memrefType = dyn_cast<MemRefType>(base.getType());
      if (!memrefType)
        continue;

      // Check if any indices depend on loop IV and compute constant stride
      bool hasIVDependentIndices = false;
      int64_t constantStride = 0;

      // Get the loop step to account for in stride calculation; non-constant
      // steps are treated as 1.
      auto stepCst = forOp.getConstantStep();
      int64_t loopStep =
          stepCst.has_value() ? stepCst.value().getSExtValue() : 1;

      for (size_t dimIdx = 0; dimIdx < indices.size(); ++dimIdx) {
        Value idx = indices[dimIdx];
        if (dependsOnLoopIVForHoist(idx, loopIV)) {
          hasIVDependentIndices = true;

          // Calculate the stride for this dimension: product of the sizes of
          // all inner (faster-varying) dimensions.
          int64_t dimStride = 1;
          bool hasDynamicStride = false;
          for (size_t j = dimIdx + 1;
               j < static_cast<size_t>(memrefType.getRank()); ++j) {
            int64_t dimSize = memrefType.getShape()[j];
            if (dimSize == ShapedType::kDynamic) {
              hasDynamicStride = true;
              break;
            }
            dimStride *= dimSize;
          }

          // Multiply by loop step - the stride per iteration is:
          // (elements per dimension) * (loop step)
          if (!hasDynamicStride)
            constantStride += dimStride * loopStep;
          else
            hasIVDependentIndices = false; // Can't hoist if stride is dynamic
          // NOTE(review): if a dynamic-stride dim is followed by another
          // IV-dependent dim with a static stride, the flag flips back to
          // true while the dynamic dim's contribution is silently dropped —
          // confirm intent for multi-dim IV-dependent indices.
        }
      }

      transferOps.push_back({&op, base, memrefType, vectorType, indices,
                             constantStride, hasIVDependentIndices});
    }

    // If there are no transfer ops, don't modify
    if (transferOps.empty())
      return failure();

    // Prepare to add iter_args for each transfer operation with IV-dependent
    // indices
    SmallVector<Value> newInitArgs;
    SmallVector<Value> flatMemrefs;

    for (const auto &info : transferOps) {
      if (!info.hasIVDependentIndices)
        continue;

      // Flatten the memref if needed; all hoisting prep IR is created just
      // before the loop.
      rewriter.setInsertionPoint(forOp);
      Value flatMemref = info.base;
      if (info.memrefType.getRank() > 1) {
        int64_t totalSize = 1;
        for (int64_t dim : info.memrefType.getShape()) {
          if (dim == ShapedType::kDynamic)
            return failure(); // Dynamic memref shapes not supported
          totalSize *= dim;
        }

        // Preserve strided layout if present
        MemRefType flatMemrefType;
        if (auto stridedLayout = dyn_cast_or_null<StridedLayoutAttr>(
                info.memrefType.getLayout())) {
          // The collapsed stride is the innermost stride (last element)
          int64_t collapsedStride = stridedLayout.getStrides().back();
          int64_t offset = stridedLayout.getOffset();

          auto newLayout = StridedLayoutAttr::get(rewriter.getContext(), offset,
                                                  {collapsedStride});
          flatMemrefType =
              MemRefType::get({totalSize}, info.memrefType.getElementType(),
                              newLayout, info.memrefType.getMemorySpace());
        } else {
          flatMemrefType =
              MemRefType::get({totalSize}, info.memrefType.getElementType(),
                              AffineMap(), info.memrefType.getMemorySpace());
        }

        // Collapse all dims of the source memref into one.
        SmallVector<ReassociationIndices> reassociation;
        ReassociationIndices allDims;
        for (size_t i = 0; i < static_cast<size_t>(info.memrefType.getRank());
             ++i) {
          allDims.push_back(i);
        }
        reassociation.push_back(allDims);

        flatMemref = memref::CollapseShapeOp::create(
            rewriter, loc, flatMemrefType, info.base, reassociation);
      }
      flatMemrefs.push_back(flatMemref);

      // Build the row-major linearization map
      //   (d0, ..., d{r-1}) -> sum_i d_i * stride_i
      // over the memref's static shape.
      int64_t rank = info.memrefType.getRank();
      AffineExpr linearExpr = rewriter.getAffineConstantExpr(0);
      int64_t stride = 1;
      for (int64_t i = rank - 1; i >= 0; --i) {
        linearExpr = linearExpr + rewriter.getAffineDimExpr(i) * stride;
        if (i > 0)
          stride *= info.memrefType.getShape()[i];
      }
      auto linearMap = AffineMap::get(rank, 0, linearExpr);

      // For IV-dependent indices, evaluate them at the loop's lower bound
      // to preserve constant offsets (e.g., %iv+1 becomes lowerBound+1)
      SmallVector<Value> evaluatedIndices;
      IRMapping indexMapping;
      for (Value idx : info.indices) {
        if (dependsOnLoopIVForHoist(idx, loopIV)) {
          // Clone the computation with the IV replaced by lower bound
          if (auto affineOp = idx.getDefiningOp<affine::AffineApplyOp>()) {
            SmallVector<Value> mappedOperands;
            for (Value operand : affineOp.getMapOperands()) {
              if (operand == loopIV)
                mappedOperands.push_back(forOp.getLowerBound());
              else
                mappedOperands.push_back(operand);
            }
            Value evaluatedIdx = affine::AffineApplyOp::create(
                rewriter, loc, affineOp.getAffineMap(), mappedOperands);
            evaluatedIndices.push_back(evaluatedIdx);
          } else {
            // Direct IV usage - just use lower bound
            evaluatedIndices.push_back(forOp.getLowerBound());
          }
        } else {
          // Index doesn't depend on IV, clone it so it dominates the loop;
          // on clone failure fall back to the original value (which may be
          // defined inside the loop — presumably dominance still holds in
          // practice; NOTE(review): confirm).
          if (auto defOp = idx.getDefiningOp()) {
            Value clonedIdx =
                cloneOpAndOperands(defOp, loopIV, rewriter, indexMapping);
            if (clonedIdx)
              evaluatedIndices.push_back(clonedIdx);
            else
              evaluatedIndices.push_back(idx);
          } else {
            evaluatedIndices.push_back(idx);
          }
        }
      }

      // Linearized starting offset fed into the new iter_arg.
      Value basePointer = affine::AffineApplyOp::create(
          rewriter, loc, linearMap, evaluatedIndices);

      newInitArgs.push_back(basePointer);
    }

    // If there are no IV-dependent transfers, just process them to flatten
    // vectors
    if (newInitArgs.empty()) {
      // Check if any transfer needs flattening (avoid infinite rewrites)
      bool needsFlattening = false;
      bool hasProcessableTransfers = false;
      for (const auto &info : transferOps) {
        // Skip if base is defined inside the loop (e.g., a subview)
        // We can't hoist these
        if (info.base.getDefiningOp() &&
            forOp->isProperAncestor(info.base.getDefiningOp()))
          continue;

        hasProcessableTransfers = true;

        // Check if this transfer has already been flattened
        // (flattened transfers use 1D identity map)
        if (auto readOp = dyn_cast<vector::TransferReadOp>(info.op)) {
          if (readOp.getPermutationMap().getNumDims() != 1)
            needsFlattening = true;
        } else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(info.op)) {
          if (writeOp.getPermutationMap().getNumDims() != 1)
            needsFlattening = true;
        }
      }

      // If there are no processable transfers (all bases defined in loop)
      // or nothing needs flattening, bail out
      if (!hasProcessableTransfers || !needsFlattening)
        return failure();

      // First, create flattened memrefs outside the loop for bases not defined
      // inside; deduplicate per base value.
      DenseMap<Value, Value> baseFlatMemrefs;
      rewriter.setInsertionPoint(forOp);
      for (const auto &info : transferOps) {
        if (baseFlatMemrefs.count(info.base))
          continue;

        // Skip if base is defined inside the loop (e.g., a subview)
        if (info.base.getDefiningOp() &&
            forOp->isProperAncestor(info.base.getDefiningOp()))
          continue;

        Value flatMemref = info.base;
        if (info.memrefType.getRank() > 1) {
          // NOTE(review): unlike the iter_arg path above, this path does not
          // reject ShapedType::kDynamic dims before multiplying — a dynamic
          // shape would yield a bogus totalSize. Confirm callers only see
          // static shapes here.
          int64_t totalSize = 1;
          for (int64_t dim : info.memrefType.getShape()) {
            totalSize *= dim;
          }

          // Preserve strided layout if present
          MemRefType flatMemrefType;
          if (auto stridedLayout = dyn_cast_or_null<StridedLayoutAttr>(
                  info.memrefType.getLayout())) {
            int64_t collapsedStride = stridedLayout.getStrides().back();
            int64_t offset = stridedLayout.getOffset();

            auto newLayout = StridedLayoutAttr::get(rewriter.getContext(),
                                                    offset, {collapsedStride});
            flatMemrefType =
                MemRefType::get({totalSize}, info.memrefType.getElementType(),
                                newLayout, info.memrefType.getMemorySpace());
          } else {
            flatMemrefType =
                MemRefType::get({totalSize}, info.memrefType.getElementType(),
                                AffineMap(), info.memrefType.getMemorySpace());
          }

          // Collapse all dims into one.
          SmallVector<ReassociationIndices> reassociation;
          ReassociationIndices allDims;
          for (size_t i = 0; i < static_cast<size_t>(info.memrefType.getRank());
               ++i) {
            allDims.push_back(i);
          }
          reassociation.push_back(allDims);
          flatMemref = memref::CollapseShapeOp::create(
              rewriter, loc, flatMemrefType, info.base, reassociation);
        }
        baseFlatMemrefs[info.base] = flatMemref;
      }

      // Process all transfers without using iter_args
      bool madeChanges = false;
      for (const auto &info : transferOps) {
        // Skip if base is defined inside the loop
        if (info.base.getDefiningOp() &&
            forOp->isProperAncestor(info.base.getDefiningOp()))
          continue;

        // Skip if we don't have a flattened version
        if (!baseFlatMemrefs.count(info.base))
          continue;

        rewriter.setInsertionPoint(info.op);

        // Flatten vector type
        int64_t numElements = getVectorNumElements(info.vectorType);
        VectorType flatVectorType =
            VectorType::get({numElements}, info.vectorType.getElementType());

        // Get the flattened memref
        Value flatMemref = baseFlatMemrefs[info.base];

        // Compute pointer from indices via row-major linearization of the
        // original (multi-dim) indices.
        int64_t rank = info.memrefType.getRank();
        AffineExpr linearExpr = rewriter.getAffineConstantExpr(0);
        int64_t stride = 1;
        for (int64_t i = rank - 1; i >= 0; --i) {
          linearExpr = linearExpr + rewriter.getAffineDimExpr(i) * stride;
          if (i > 0)
            stride *= info.memrefType.getShape()[i];
        }
        auto linearMap = AffineMap::get(rank, 0, linearExpr);

        Value currentPointer = affine::AffineApplyOp::create(
            rewriter, loc, linearMap, info.indices);

        // Transform the transfer operation into a 1-D in-bounds transfer plus
        // a shape_cast back to the original vector type.
        AffineMap identityMap1D = AffineMap::get(
            1, 0, rewriter.getAffineDimExpr(0), rewriter.getContext());
        auto inBoundsAttr = rewriter.getBoolArrayAttr({true});

        if (auto readOp = dyn_cast<vector::TransferReadOp>(info.op)) {
          Value flatRead = vector::TransferReadOp::create(
              rewriter, loc, flatVectorType, flatMemref,
              ValueRange{currentPointer}, AffineMapAttr::get(identityMap1D),
              readOp.getPadding(),
              /*mask=*/Value(), inBoundsAttr);
          Value shapedRead = vector::ShapeCastOp::create(
              rewriter, loc, info.vectorType, flatRead);
          rewriter.replaceOp(readOp, shapedRead);
          madeChanges = true;
        } else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(info.op)) {
          Value flatValue = vector::ShapeCastOp::create(
              rewriter, loc, flatVectorType, writeOp.getVector());
          rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
              writeOp, flatValue, flatMemref, ValueRange{currentPointer},
              AffineMapAttr::get(identityMap1D), /*mask=*/Value(),
              inBoundsAttr);
          madeChanges = true;
        }
      }
      return madeChanges ? success() : failure();
    }

    // Use replaceWithAdditionalYields to add pointer iter_args. The callback
    // both rewrites the transfers to use the new block arguments and returns
    // the advanced pointers to yield.
    auto yieldValuesFn =
        [&](OpBuilder &b, Location yieldLoc,
            ArrayRef<BlockArgument> newBbArgs) -> SmallVector<Value> {
      SmallVector<Value> yieldValues;

      // Process each transfer operation with IV-dependent indices
      size_t iterArgIdx = 0;
      for (size_t i = 0; i < transferOps.size(); ++i) {
        const auto &info = transferOps[i];
        if (!info.hasIVDependentIndices)
          continue;

        // The new iter_args were appended, so index from the end.
        BlockArgument ptrIterArg =
            newBbArgs[newBbArgs.size() - newInitArgs.size() + iterArgIdx];
        Value flatMemref = flatMemrefs[iterArgIdx];

        // Flatten vector type
        int64_t numElements = getVectorNumElements(info.vectorType);
        VectorType flatVectorType =
            VectorType::get({numElements}, info.vectorType.getElementType());

        // Transform the transfer operation to use the iter_arg pointer
        b.setInsertionPoint(info.op);

        AffineMap identityMap1D =
            AffineMap::get(1, 0, b.getAffineDimExpr(0), b.getContext());
        auto inBoundsAttr = b.getBoolArrayAttr({true});

        // NOTE(review): new ops are built with `b` but replacements go
        // through the captured `rewriter` — confirm this mixing is safe with
        // the greedy driver's listener.
        if (auto readOp = dyn_cast<vector::TransferReadOp>(info.op)) {
          Value flatRead = vector::TransferReadOp::create(
              b, loc, flatVectorType, flatMemref, ValueRange{ptrIterArg},
              AffineMapAttr::get(identityMap1D), readOp.getPadding(),
              /*mask=*/Value(), inBoundsAttr);
          Value shapedRead =
              vector::ShapeCastOp::create(b, loc, info.vectorType, flatRead);
          rewriter.replaceOp(readOp, shapedRead);
        } else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(info.op)) {
          Value flatValue = vector::ShapeCastOp::create(b, loc, flatVectorType,
                                                        writeOp.getVector());
          rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
              writeOp, flatValue, flatMemref, ValueRange{ptrIterArg},
              AffineMapAttr::get(identityMap1D), /*mask=*/Value(),
              inBoundsAttr);
        }

        // Compute next pointer value: current_ptr + constant_stride
        Value strideConst =
            arith::ConstantIndexOp::create(b, yieldLoc, info.constantStride);
        Value nextPtr =
            arith::AddIOp::create(b, yieldLoc, ptrIterArg, strideConst);
        yieldValues.push_back(nextPtr);

        iterArgIdx++;
      }

      return yieldValues;
    };

    // Create new loop with additional iter_args for pointers
    FailureOr<LoopLikeOpInterface> newLoopResult =
        cast<LoopLikeOpInterface>(forOp.getOperation())
            .replaceWithAdditionalYields(
                rewriter, newInitArgs, // new init operands (base pointers)
                true,                  // replace uses in loop
                yieldValuesFn);

    if (failed(newLoopResult))
      return failure();

    return success();
  }
};
559
//===----------------------------------------------------------------------===//
// AIEHoistVectorTransferPointersPass
//===----------------------------------------------------------------------===//
563
564struct AIEHoistVectorTransferPointersPass
565 : xilinx::AIE::impl::AIEHoistVectorTransferPointersBase<
566 AIEHoistVectorTransferPointersPass> {
567 void getDependentDialects(DialectRegistry &registry) const override {
568 registry.insert<affine::AffineDialect, arith::ArithDialect,
569 memref::MemRefDialect, scf::SCFDialect,
570 vector::VectorDialect>();
571 }
572
573 void runOnOperation() override {
574 ModuleOp moduleOp = getOperation();
575 MLIRContext *context = &getContext();
576
577 RewritePatternSet patterns(context);
578 patterns.add<HoistVectorTransferPointersPattern>(context);
579
580 // Apply patterns to the entire module - the pattern will only match scf.for
581 // ops within aie.core regions
582 if (failed(applyPatternsGreedily(moduleOp, std::move(patterns))))
583 signalPassFailure();
584 }
585};
586
587} // namespace
588
589std::unique_ptr<OperationPass<ModuleOp>>
591 return std::make_unique<AIEHoistVectorTransferPointersPass>();
592}
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createAIEHoistVectorTransferPointersPass()