MLIR-AIE
AIEVectorToPointerLoops.cpp
Go to the documentation of this file.
1//===- AIEVectorToPointerLoops.cpp -----------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2025 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10//
11// This pass transforms vector.load/store operations with loop-carried indices
12// to use ptr dialect operations (ptr.to_ptr, ptr.ptr_add, ptr.from_ptr).
13//
14// Goal: Make pointer increment patterns explicit to help LLVM backend generate
15// efficient post-increment addressing modes (GEP fusion).
16//
17// Transformation:
18// Before:
19// scf.for iter_args(%idx = %0) -> (index) {
20// %vec = vector.load %memref[%idx]
21// %next_idx = arith.addi %idx, %stride
22// scf.yield %next_idx
23// }
24//
25// After:
26// %base_ptr = ptr.to_ptr %memref
27// %init_ptr = ptr.ptr_add %base_ptr, %0
28// scf.for iter_args(%ptr = %init_ptr) -> (!ptr.ptr<...>) {
29// %memref_tmp = ptr.from_ptr %ptr
30// %vec = vector.load %memref_tmp[%c0]
31// %next_ptr = ptr.ptr_add %ptr, %stride
32// scf.yield %next_ptr
33// }
34//
35//===----------------------------------------------------------------------===//
36
// NOTE(review): doxygen extraction dropped the original source lines here —
// presumably the project include that declares DeviceOp and
// createAIEVectorToPointerLoopsPass (AIE dialect header); confirm against the
// repository source of AIEVectorToPointerLoops.cpp.
40#include "mlir/Dialect/Arith/IR/Arith.h"
41#include "mlir/Dialect/Func/IR/FuncOps.h"
42#include "mlir/Dialect/MemRef/IR/MemRef.h"
43#include "mlir/Dialect/Ptr/IR/PtrOps.h"
44#include "mlir/Dialect/SCF/IR/SCF.h"
45#include "mlir/Dialect/Vector/IR/VectorOps.h"
46#include "mlir/IR/PatternMatch.h"
47#include "mlir/Pass/Pass.h"
48#include "mlir/Transforms/DialectConversion.h"
49#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
50
51#include "llvm/ADT/DenseMap.h"
52#include "llvm/ADT/SetVector.h"
53
54#define DEBUG_TYPE "aie-vector-to-pointer-loops"
55
56using namespace mlir;
57using namespace xilinx;
58using namespace xilinx::AIE;
59
60namespace {
61
62/// Check if a value is an iter_arg of an scf.for loop
63static bool isLoopCarriedValue(Value val, scf::ForOp forOp) {
64 auto blockArgs = forOp.getRegionIterArgs();
65 return llvm::is_contained(blockArgs, val);
66}
67
68/// Structure to track memref base and its uses in vector ops
69struct MemrefVectorAccess {
70 Value memref;
71 SmallVector<vector::LoadOp> loads;
72 SmallVector<vector::StoreOp> stores;
73 SmallVector<Value> indices; // Loop-carried indices used
74};
75
76/// Analyze loop to find vector load/store patterns with loop-carried indices
77static bool analyzeLoopForVectorAccesses(
78 scf::ForOp forOp, DenseMap<Value, MemrefVectorAccess> &memrefAccesses) {
79
80 bool foundPattern = false;
81
82 forOp.walk([&](vector::LoadOp loadOp) {
83 Value base = loadOp.getBase();
84 auto indices = loadOp.getIndices();
85
86 // Only handle 1D access for now
87 if (indices.size() != 1)
88 return;
89
90 Value idx = indices[0];
91
92 // Check if index is loop-carried
93 if (isLoopCarriedValue(idx, forOp)) {
94 memrefAccesses[base].memref = base;
95 memrefAccesses[base].loads.push_back(loadOp);
96 memrefAccesses[base].indices.push_back(idx);
97 foundPattern = true;
98 }
99 });
100
101 forOp.walk([&](vector::StoreOp storeOp) {
102 Value base = storeOp.getBase();
103 auto indices = storeOp.getIndices();
104
105 if (indices.size() != 1)
106 return;
107
108 Value idx = indices[0];
109
110 if (isLoopCarriedValue(idx, forOp)) {
111 memrefAccesses[base].memref = base;
112 memrefAccesses[base].stores.push_back(storeOp);
113 memrefAccesses[base].indices.push_back(idx);
114 foundPattern = true;
115 }
116 });
117
118 return foundPattern;
119}
120
121/// Transform an scf.for loop to use pointer iter_args
/// Rewrites an scf.for whose iter_args carry vector access indices into a loop
/// whose iter_args carry !ptr.ptr values, making the per-iteration address
/// increment an explicit ptr.ptr_add chain (see the file header for the
/// before/after IR shape).
struct VectorToPointerLoopsPattern : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {

    // Skip if loop already uses pointer iter_args (already transformed).
    // This is the fixed point that stops the greedy driver from re-matching.
    for (Value iterArg : forOp.getRegionIterArgs()) {
      if (llvm::isa<ptr::PtrType>(iterArg.getType()))
        return failure(); // Already transformed
    }

    // Analyze the loop for vector access patterns
    DenseMap<Value, MemrefVectorAccess> memrefAccesses;
    if (!analyzeLoopForVectorAccesses(forOp, memrefAccesses))
      return failure();

    if (memrefAccesses.empty())
      return failure();

    Location loc = forOp.getLoc();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(forOp);

    // Step 1: Convert memrefs to pointers before the loop
    DenseMap<Value, Value> memrefToPtrMap;
    DenseMap<Value, Type>
        memrefToGenericTypeMap; // Track generic-space memref types
    DenseMap<Value, Value>
        memrefToConvertedMap; // Track the converted memref value (for metadata)
    auto genericSpace = ptr::GenericSpaceAttr::get(rewriter.getContext());

    for (auto &[memref, access] : memrefAccesses) {
      // Get the memory space from the memref type
      auto memrefType = cast<MemRefType>(memref.getType());
      Attribute memorySpace = memrefType.getMemorySpace();

      Value memrefToConvert = memref;
      Type genericMemrefType = memrefType;

      // If memref has a different memory space, cast it to generic_space first
      // (ptr.to_ptr below is created with a generic_space pointer type, so the
      // source memref must be in that space too).
      if (memorySpace && !llvm::isa<ptr::GenericSpaceAttr>(memorySpace)) {
        // Create new memref type with generic_space
        auto newMemrefType =
            MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                            memrefType.getLayout(), genericSpace);

        // Insert unrealized_conversion_cast to convert to generic_space
        auto castOp = UnrealizedConversionCastOp::create(rewriter, loc,
                                                         newMemrefType, memref);
        memrefToConvert = castOp.getResult(0);
        genericMemrefType = newMemrefType;
      }

      // Create pointer type with generic_space
      auto ptrType = ptr::PtrType::get(rewriter.getContext(), genericSpace);
      auto ptrOp =
          ptr::ToPtrOp::create(rewriter, loc, ptrType, memrefToConvert);
      memrefToPtrMap[memref] = ptrOp.getResult();
      memrefToGenericTypeMap[memref] = genericMemrefType;
      memrefToConvertedMap[memref] =
          memrefToConvert; // Store the converted memref
    }

    // Step 2: Identify which iter_args are indices used in vector ops.
    // The two vectors below are parallel: position i of indexIterArgPositions
    // is the iter_arg slot, and correspondingMemrefs[i] is the memref whose
    // element size scales that slot's offsets.
    SmallVector<unsigned> indexIterArgPositions;
    SmallVector<Value> correspondingMemrefs;

    for (auto [idx, iterArg] : llvm::enumerate(forOp.getRegionIterArgs())) {
      for (auto &[memref, access] : memrefAccesses) {
        if (llvm::is_contained(access.indices, iterArg)) {
          indexIterArgPositions.push_back(idx);
          correspondingMemrefs.push_back(memref);
          break; // First matching memref wins if an iter_arg indexes several.
        }
      }
    }

    if (indexIterArgPositions.empty())
      return failure();

    // Step 3: Build new init args with pointers
    SmallVector<Value> newInitArgs;
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);

    for (auto [idx, initArg] : llvm::enumerate(forOp.getInitArgs())) {
      auto it = llvm::find(indexIterArgPositions, idx);
      if (it != indexIterArgPositions.end()) {
        // This is an index iter_arg - convert to pointer
        size_t pos = std::distance(indexIterArgPositions.begin(), it);
        Value memref = correspondingMemrefs[pos];
        Value basePtr = memrefToPtrMap[memref];

        // Get element size in bytes (rounded up for sub-byte element types).
        auto memrefType = cast<MemRefType>(memref.getType());
        unsigned elementSizeBits = memrefType.getElementTypeBitWidth();
        unsigned elementSizeBytes = (elementSizeBits + 7) / 8;

        // Scale index by element size: byteOffset = initArg * elementSizeBytes
        Value byteOffset = initArg;
        if (elementSizeBytes != 1) {
          Value elementSize =
              arith::ConstantIndexOp::create(rewriter, loc, elementSizeBytes);
          byteOffset =
              arith::MulIOp::create(rewriter, loc, initArg, elementSize);
        }

        // Create: ptr.ptr_add basePtr, byteOffset
        auto initPtrOp =
            ptr::PtrAddOp::create(rewriter, loc, basePtr, byteOffset);
        newInitArgs.push_back(initPtrOp.getResult());
      } else {
        // Keep as-is
        newInitArgs.push_back(initArg);
      }
    }

    // Step 4: Create new loop with updated signature
    auto newForOp =
        scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
                           forOp.getUpperBound(), forOp.getStep(), newInitArgs);

    // Step 5: Transform loop body (simplified - doesn't handle all cases yet)
    IRMapping mapper;
    mapper.map(forOp.getInductionVar(), newForOp.getInductionVar());

    for (auto [oldArg, newArg] :
         llvm::zip(forOp.getRegionIterArgs(), newForOp.getRegionIterArgs())) {
      mapper.map(oldArg, newArg);
    }

    rewriter.setInsertionPointToStart(newForOp.getBody());

    // Clone operations with transformation (c0 already created above)
    for (Operation &op : forOp.getBody()->without_terminator()) {
      // Transform vector.load operations
      if (auto loadOp = dyn_cast<vector::LoadOp>(&op)) {
        Value idx = loadOp.getIndices()[0];
        Value mappedIdx = mapper.lookup(idx);

        // Check if the index is now a pointer (was transformed)
        if (llvm::isa<ptr::PtrType>(mappedIdx.getType())) {
          // Get the generic-space memref type for this base
          Type genericType = memrefToGenericTypeMap[loadOp.getBase()];

          // Get metadata from the converted memref (with generic_space)
          Value convertedMemref = memrefToConvertedMap[loadOp.getBase()];
          auto metadataOp = ptr::GetMetadataOp::create(
              rewriter, loadOp.getLoc(), convertedMemref);

          // Transform: vector.load %memref[%ptr] -> ptr.from_ptr + vector.load
          // [...[0]]  (the pointer already carries the offset, so index is c0)
          auto fromPtrOp =
              ptr::FromPtrOp::create(rewriter, loadOp.getLoc(), genericType,
                                     mappedIdx, metadataOp.getResult());
          auto newLoad = vector::LoadOp::create(
              rewriter, loadOp.getLoc(), loadOp.getVectorType(),
              fromPtrOp.getResult(), ValueRange{c0});
          mapper.map(loadOp.getResult(), newLoad.getResult());
          continue;
        }
      }

      // Transform vector.store operations
      if (auto storeOp = dyn_cast<vector::StoreOp>(&op)) {
        Value idx = storeOp.getIndices()[0];
        Value mappedIdx = mapper.lookup(idx);

        if (llvm::isa<ptr::PtrType>(mappedIdx.getType())) {
          // Get the generic-space memref type for this base
          Type genericType = memrefToGenericTypeMap[storeOp.getBase()];

          // Get metadata from the converted memref (with generic_space)
          Value convertedMemref = memrefToConvertedMap[storeOp.getBase()];
          auto metadataOp = ptr::GetMetadataOp::create(
              rewriter, storeOp.getLoc(), convertedMemref);

          // Transform: vector.store %val, %memref[%ptr] -> ptr.from_ptr +
          // vector.store[0]
          auto fromPtrOp =
              ptr::FromPtrOp::create(rewriter, storeOp.getLoc(), genericType,
                                     mappedIdx, metadataOp.getResult());
          Value valueToStore =
              mapper.lookupOrDefault(storeOp.getValueToStore());
          vector::StoreOp::create(rewriter, storeOp.getLoc(), valueToStore,
                                  fromPtrOp.getResult(), ValueRange{c0});
          continue;
        }
      }

      // Transform arith.addi to ptr.ptr_add when operating on pointers
      if (auto addiOp = dyn_cast<arith::AddIOp>(&op)) {
        Value lhs = mapper.lookupOrDefault(addiOp.getLhs());
        Value rhs = mapper.lookupOrDefault(addiOp.getRhs());

        // If LHS is a pointer, convert to ptr.ptr_add
        if (llvm::isa<ptr::PtrType>(lhs.getType())) {
          // Find which memref this pointer corresponds to.
          // NOTE(review): this takes the FIRST entry of memrefAccesses, not the
          // memref `lhs` actually derives from — with several memrefs of
          // differing element sizes the offset scaling below can be wrong.
          // (Iterating the map by value copies the MemrefVectorAccess; harmless
          // only because the loop breaks on the first entry.)
          Value memrefForPtr = nullptr;
          for (auto [memref, access] : memrefAccesses) {
            // Check if lhs comes from this memref's pointer chain
            // For now, find any memref being accessed (simplified)
            memrefForPtr = memref;
            break;
          }

          Value byteOffset = rhs;
          if (memrefForPtr) {
            auto memrefType = cast<MemRefType>(memrefForPtr.getType());
            unsigned elementSizeBits = memrefType.getElementTypeBitWidth();
            unsigned elementSizeBytes = (elementSizeBits + 7) / 8;

            // Scale offset by element size: byteOffset = rhs * elementSizeBytes
            if (elementSizeBytes != 1) {
              Value elementSize = arith::ConstantIndexOp::create(
                  rewriter, addiOp.getLoc(), elementSizeBytes);
              byteOffset = arith::MulIOp::create(rewriter, addiOp.getLoc(), rhs,
                                                 elementSize);
            }
          }

          auto ptrAddOp =
              ptr::PtrAddOp::create(rewriter, addiOp.getLoc(), lhs, byteOffset);
          mapper.map(addiOp.getResult(), ptrAddOp.getResult());
          continue;
        }
      }

      // Default: clone the operation
      rewriter.clone(op, mapper);
    }

    // Clone yield, remapping operands through the transformed body. Yielded
    // index increments that became ptr.ptr_add now yield pointers instead.
    auto oldYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    SmallVector<Value> newYieldOperands;
    for (Value operand : oldYield.getOperands()) {
      newYieldOperands.push_back(mapper.lookupOrDefault(operand));
    }
    scf::YieldOp::create(rewriter, loc, newYieldOperands);

    // Step 6: Replace old loop
    // NOTE(review): the transformed result slots changed type (index ->
    // !ptr.ptr), so any use of those loop results after the loop would now see
    // a pointer — confirm such results are dead in the IR this pass targets.
    rewriter.replaceOp(forOp, newForOp.getResults());

    // NOTE: This is a simplified implementation
    // Full version needs to properly transform vector.load/store and arith.addi

    return success();
  }
};
371
372struct AIEVectorToPointerLoopsPass
373 : public PassWrapper<AIEVectorToPointerLoopsPass, OperationPass<DeviceOp>> {
374
375 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AIEVectorToPointerLoopsPass)
376
377 StringRef getArgument() const override {
378 return "aie-vector-to-pointer-loops";
379 }
380
381 StringRef getDescription() const override {
382 return "Transform vector.load/store with loop-carried indices to use ptr "
383 "dialect";
384 }
385
386 void getDependentDialects(DialectRegistry &registry) const override {
387 registry.insert<ptr::PtrDialect>();
388 registry.insert<scf::SCFDialect>();
389 registry.insert<vector::VectorDialect>();
390 registry.insert<arith::ArithDialect>();
391 registry.insert<memref::MemRefDialect>();
392 }
393
394 void runOnOperation() override {
395 DeviceOp deviceOp = getOperation();
396
397 RewritePatternSet patterns(&getContext());
398 patterns.add<VectorToPointerLoopsPattern>(&getContext());
399
400 // Apply patterns to the entire device
401 // The pattern will match scf.for loops in aie.core regions
402 if (failed(applyPatternsGreedily(deviceOp, std::move(patterns)))) {
403 signalPassFailure();
404 }
405 }
406};
407
408} // namespace
409
410namespace xilinx {
411namespace AIE {
412
413std::unique_ptr<OperationPass<DeviceOp>> createAIEVectorToPointerLoopsPass() {
414 return std::make_unique<AIEVectorToPointerLoopsPass>();
415}
416
417} // namespace AIE
418} // namespace xilinx
Generated interface declarations are included by the build.
std::unique_ptr<mlir::OperationPass<DeviceOp>> createAIEVectorToPointerLoopsPass() — factory for the AIEVectorToPointerLoops pass.