40#include "mlir/Dialect/Arith/IR/Arith.h"
41#include "mlir/Dialect/Func/IR/FuncOps.h"
42#include "mlir/Dialect/MemRef/IR/MemRef.h"
43#include "mlir/Dialect/Ptr/IR/PtrOps.h"
44#include "mlir/Dialect/SCF/IR/SCF.h"
45#include "mlir/Dialect/Vector/IR/VectorOps.h"
46#include "mlir/IR/PatternMatch.h"
47#include "mlir/Pass/Pass.h"
48#include "mlir/Transforms/DialectConversion.h"
49#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
51#include "llvm/ADT/DenseMap.h"
52#include "llvm/ADT/SetVector.h"
54#define DEBUG_TYPE "aie-vector-to-pointer-loops"
63static bool isLoopCarriedValue(Value val, scf::ForOp forOp) {
64 auto blockArgs = forOp.getRegionIterArgs();
65 return llvm::is_contained(blockArgs, val);
69struct MemrefVectorAccess {
71 SmallVector<vector::LoadOp> loads;
72 SmallVector<vector::StoreOp> stores;
73 SmallVector<Value> indices;
// Walks `forOp` and records, per base memref, every vector.load and
// vector.store whose single index is one of the loop's region iter_args.
// Results are accumulated into `memrefAccesses` (keyed by the base memref).
//
// NOTE(review): this listing is missing lines (the statement guarded by each
// rank check, any update of `foundPattern`, the closing braces and the final
// return), so the exact meaning of the returned bool cannot be confirmed here.
77static bool analyzeLoopForVectorAccesses(
78 scf::ForOp forOp, DenseMap<Value, MemrefVectorAccess> &memrefAccesses) {
80 bool foundPattern =
false;
// First pass: every vector.load nested under the loop.
82 forOp.walk([&](vector::LoadOp loadOp) {
83 Value base = loadOp.getBase();
84 auto indices = loadOp.getIndices();
// Only accesses with exactly one index are considered; the statement this
// condition guards is not visible (presumably an early return - confirm).
87 if (indices.size() != 1)
90 Value idx = indices[0];
// Record the load under its base memref when the index is loop-carried.
93 if (isLoopCarriedValue(idx, forOp)) {
94 memrefAccesses[base].memref = base;
95 memrefAccesses[base].loads.push_back(loadOp);
96 memrefAccesses[base].indices.push_back(idx);
// Second pass: same treatment for every vector.store nested under the loop.
101 forOp.walk([&](vector::StoreOp storeOp) {
102 Value base = storeOp.getBase();
103 auto indices = storeOp.getIndices();
// Same single-index restriction as for loads.
105 if (indices.size() != 1)
108 Value idx = indices[0];
// Record the store under its base memref when the index is loop-carried.
110 if (isLoopCarriedValue(idx, forOp)) {
111 memrefAccesses[base].memref = base;
112 memrefAccesses[base].stores.push_back(storeOp);
113 memrefAccesses[base].indices.push_back(idx);
// Rewrites an scf.for whose iter_args are used as vector.load / vector.store
// indices into a new scf.for that carries ptr values in those positions:
// index init values become base-pointer + byte-offset, loads/stores go through
// ptr.from_ptr, and arith.addi on a mapped pointer becomes ptr.ptr_add.
//
// NOTE(review): this listing is missing many lines (early returns, closing
// braces, the left-hand side of several declarations such as `newMemrefType`,
// `ptrOp`, `elementSize`, `initPtrOp`, `newForOp`, `mapper`, `fromPtrOp`,
// `valueToStore`, `ptrAddOp`, and the final return). Comments below describe
// only what the visible lines show.
125 LogicalResult matchAndRewrite(scf::ForOp forOp,
126 PatternRewriter &rewriter)
const override {
// Inspect the loop's iter_args for ones that already have ptr type; the
// statement guarded by this check is not visible (presumably the pattern
// bails out so it is not re-applied to its own output - confirm).
129 for (Value iterArg : forOp.getRegionIterArgs()) {
130 if (llvm::isa<ptr::PtrType>(iterArg.getType()))
// Collect the loop-carried-index vector accesses, grouped by base memref.
135 DenseMap<Value, MemrefVectorAccess> memrefAccesses;
136 if (!analyzeLoopForVectorAccesses(forOp, memrefAccesses))
139 if (memrefAccesses.empty())
// New setup ops are created immediately before the original loop; the guard
// restores the rewriter's insertion point when it goes out of scope.
142 Location loc = forOp.getLoc();
143 OpBuilder::InsertionGuard guard(rewriter);
144 rewriter.setInsertionPoint(forOp);
// Per-memref bookkeeping, keyed by the original memref value:
//  - memrefToPtrMap: result of ptr.to_ptr for that memref
//  - memrefToGenericTypeMap: memref type used when converting back from ptr
//  - memrefToConvertedMap: value assigned on a line that is only partly
//    visible here; later used as the source for ptr.get_metadata
147 DenseMap<Value, Value> memrefToPtrMap;
148 DenseMap<Value, Type>
149 memrefToGenericTypeMap;
150 DenseMap<Value, Value>
151 memrefToConvertedMap;
152 auto genericSpace = ptr::GenericSpaceAttr::get(rewriter.getContext());
// Convert every accessed memref to a pointer in the generic memory space.
154 for (
auto &[memref, access] : memrefAccesses) {
156 auto memrefType = cast<MemRefType>(memref.getType());
157 Attribute memorySpace = memrefType.getMemorySpace();
// Defaults used when no memory-space cast is needed.
159 Value memrefToConvert = memref;
160 Type genericMemrefType = memrefType;
// A memref with a non-generic memory space is first cast (via an
// unrealized conversion cast) to the same shape/element type/layout in the
// generic space.
163 if (memorySpace && !llvm::isa<ptr::GenericSpaceAttr>(memorySpace)) {
166 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
167 memrefType.getLayout(), genericSpace);
170 auto castOp = UnrealizedConversionCastOp::create(rewriter, loc,
171 newMemrefType, memref);
172 memrefToConvert = castOp.getResult(0);
173 genericMemrefType = newMemrefType;
// Materialize the base pointer and remember the per-memref results.
177 auto ptrType = ptr::PtrType::get(rewriter.getContext(), genericSpace);
179 ptr::ToPtrOp::create(rewriter, loc, ptrType, memrefToConvert);
180 memrefToPtrMap[memref] = ptrOp.getResult();
181 memrefToGenericTypeMap[memref] = genericMemrefType;
182 memrefToConvertedMap[memref] =
// Find which iter_arg positions are used as access indices, and remember the
// memref each one indexes (parallel vectors: same position in both).
187 SmallVector<unsigned> indexIterArgPositions;
188 SmallVector<Value> correspondingMemrefs;
190 for (
auto [idx, iterArg] : llvm::enumerate(forOp.getRegionIterArgs())) {
191 for (
auto &[memref, access] : memrefAccesses) {
192 if (llvm::is_contained(access.indices, iterArg)) {
193 indexIterArgPositions.push_back(idx);
194 correspondingMemrefs.push_back(memref);
200 if (indexIterArgPositions.empty())
// Build the init operands of the new loop. `c0` is the constant index later
// used to address the memrefs rebuilt from pointers.
204 SmallVector<Value> newInitArgs;
205 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
207 for (
auto [idx, initArg] : llvm::enumerate(forOp.getInitArgs())) {
208 auto it = llvm::find(indexIterArgPositions, idx);
209 if (it != indexIterArgPositions.end()) {
// This init value is an access index: replace it with basePtr + offset.
211 size_t pos = std::distance(indexIterArgPositions.begin(), it);
212 Value memref = correspondingMemrefs[pos];
213 Value basePtr = memrefToPtrMap[memref];
// Element size in bytes, rounding the bit width up to a whole byte.
216 auto memrefType = cast<MemRefType>(memref.getType());
217 unsigned elementSizeBits = memrefType.getElementTypeBitWidth();
218 unsigned elementSizeBytes = (elementSizeBits + 7) / 8;
// Scale the element index to a byte offset unless elements are 1 byte.
221 Value byteOffset = initArg;
222 if (elementSizeBytes != 1) {
224 arith::ConstantIndexOp::create(rewriter, loc, elementSizeBytes);
226 arith::MulIOp::create(rewriter, loc, initArg, elementSize);
231 ptr::PtrAddOp::create(rewriter, loc, basePtr, byteOffset);
232 newInitArgs.push_back(initPtrOp.getResult());
// Init values that are not access indices are forwarded unchanged.
235 newInitArgs.push_back(initArg);
// Create the replacement loop with the same bounds/step and the new inits.
241 scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
242 forOp.getUpperBound(), forOp.getStep(), newInitArgs);
// Map the old induction variable and iter_args to the new loop's ones so
// that cloned body ops refer to the new block arguments.
246 mapper.map(forOp.getInductionVar(), newForOp.getInductionVar());
248 for (
auto [oldArg, newArg] :
249 llvm::zip(forOp.getRegionIterArgs(), newForOp.getRegionIterArgs())) {
250 mapper.map(oldArg, newArg);
253 rewriter.setInsertionPointToStart(newForOp.getBody());
// Re-create the body (terminator excluded) in the new loop, rewriting
// loads, stores and index increments that involve a mapped pointer.
256 for (Operation &op : forOp.getBody()->without_terminator()) {
// vector.load: if its index now maps to a pointer, rebuild a memref from
// that pointer (ptr.from_ptr with metadata taken from the converted memref)
// and load from it at index c0.
258 if (
auto loadOp = dyn_cast<vector::LoadOp>(&op)) {
259 Value idx = loadOp.getIndices()[0];
260 Value mappedIdx = mapper.lookup(idx);
263 if (llvm::isa<ptr::PtrType>(mappedIdx.getType())) {
265 Type genericType = memrefToGenericTypeMap[loadOp.getBase()];
268 Value convertedMemref = memrefToConvertedMap[loadOp.getBase()];
269 auto metadataOp = ptr::GetMetadataOp::create(
270 rewriter, loadOp.getLoc(), convertedMemref);
275 ptr::FromPtrOp::create(rewriter, loadOp.getLoc(), genericType,
276 mappedIdx, metadataOp.getResult());
277 auto newLoad = vector::LoadOp::create(
278 rewriter, loadOp.getLoc(), loadOp.getVectorType(),
279 fromPtrOp.getResult(), ValueRange{c0});
// Later uses of the old load result resolve to the new load.
280 mapper.map(loadOp.getResult(), newLoad.getResult());
// vector.store: same pointer-to-memref reconstruction, then store the
// (mapped) value at index c0.
286 if (
auto storeOp = dyn_cast<vector::StoreOp>(&op)) {
287 Value idx = storeOp.getIndices()[0];
288 Value mappedIdx = mapper.lookup(idx);
290 if (llvm::isa<ptr::PtrType>(mappedIdx.getType())) {
292 Type genericType = memrefToGenericTypeMap[storeOp.getBase()];
295 Value convertedMemref = memrefToConvertedMap[storeOp.getBase()];
296 auto metadataOp = ptr::GetMetadataOp::create(
297 rewriter, storeOp.getLoc(), convertedMemref);
302 ptr::FromPtrOp::create(rewriter, storeOp.getLoc(), genericType,
303 mappedIdx, metadataOp.getResult());
305 mapper.lookupOrDefault(storeOp.getValueToStore());
306 vector::StoreOp::create(rewriter, storeOp.getLoc(), valueToStore,
307 fromPtrOp.getResult(), ValueRange{c0});
// arith.addi whose (mapped) lhs is a pointer becomes ptr.ptr_add with the
// rhs scaled to bytes. Only the lhs is tested for pointer type here.
313 if (
auto addiOp = dyn_cast<arith::AddIOp>(&op)) {
314 Value lhs = mapper.lookupOrDefault(addiOp.getLhs());
315 Value rhs = mapper.lookupOrDefault(addiOp.getRhs());
318 if (llvm::isa<ptr::PtrType>(lhs.getType())) {
// Pick the memref whose element size is used for scaling.
// NOTE(review): the lines between the loop header and the assignment are
// missing; as shown, no visible condition ties `lhs` to the chosen memref.
// Confirm how the memref is selected when several memrefs are accessed.
320 Value memrefForPtr =
nullptr;
321 for (
auto [memref, access] : memrefAccesses) {
324 memrefForPtr = memref;
328 Value byteOffset = rhs;
// Element size in bytes, rounding the bit width up to a whole byte.
330 auto memrefType = cast<MemRefType>(memrefForPtr.getType());
331 unsigned elementSizeBits = memrefType.getElementTypeBitWidth();
332 unsigned elementSizeBytes = (elementSizeBits + 7) / 8;
335 if (elementSizeBytes != 1) {
336 Value elementSize = arith::ConstantIndexOp::create(
337 rewriter, addiOp.getLoc(), elementSizeBytes);
338 byteOffset = arith::MulIOp::create(rewriter, addiOp.getLoc(), rhs,
344 ptr::PtrAddOp::create(rewriter, addiOp.getLoc(), lhs, byteOffset);
// Later uses of the old sum resolve to the advanced pointer.
345 mapper.map(addiOp.getResult(), ptrAddOp.getResult());
// Ops not handled above are cloned with operands remapped.
351 rewriter.clone(op, mapper);
// Re-create the terminator, yielding the remapped values.
355 auto oldYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
356 SmallVector<Value> newYieldOperands;
357 for (Value operand : oldYield.getOperands()) {
358 newYieldOperands.push_back(mapper.lookupOrDefault(operand));
360 scf::YieldOp::create(rewriter, loc, newYieldOperands);
// Replace the old loop with the new loop's results.
// NOTE(review): results at rewritten positions are ptr-typed in the new loop;
// confirm the old loop's corresponding results have no index-typed users.
363 rewriter.replaceOp(forOp, newForOp.getResults());
// Pass over a DeviceOp that greedily applies VectorToPointerLoopsPattern.
//
// Fix in this block: the parameter of getDependentDialects had been corrupted
// to a registered-trademark glyph followed by "istry" (a mis-decoded
// "&reg..." sequence); it is restored to `&registry`, the name the body uses.
//
// NOTE(review): this listing is missing lines (closing braces, the rest of the
// description string, and the body of the failure branch); they are not
// reconstructed here.
372struct AIEVectorToPointerLoopsPass
373 :
public PassWrapper<AIEVectorToPointerLoopsPass, OperationPass<DeviceOp>> {
375 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AIEVectorToPointerLoopsPass)
// Command-line argument that selects this pass.
377 StringRef getArgument()
const override {
378 return "aie-vector-to-pointer-loops";
// Human-readable summary; the string literal continues on a line that is not
// visible in this listing.
381 StringRef getDescription()
const override {
382 return "Transform vector.load/store with loop-carried indices to use ptr "
// Registers the dialects whose ops this pass's pattern creates or inspects.
386 void getDependentDialects(DialectRegistry &registry)
const override {
387 registry.insert<ptr::PtrDialect>();
388 registry.insert<scf::SCFDialect>();
389 registry.insert<vector::VectorDialect>();
390 registry.insert<arith::ArithDialect>();
391 registry.insert<memref::MemRefDialect>();
// Builds the pattern set and runs the greedy driver on the device op.
394 void runOnOperation()
override {
395 DeviceOp deviceOp = getOperation();
397 RewritePatternSet patterns(&getContext());
398 patterns.add<VectorToPointerLoopsPattern>(&getContext());
// The statement executed when the greedy driver fails is not visible here
// (presumably signalPassFailure() - confirm).
402 if (failed(applyPatternsGreedily(deviceOp, std::move(patterns)))) {
414 return std::make_unique<AIEVectorToPointerLoopsPass>();
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< DeviceOp > > createAIEVectorToPointerLoopsPass()