21#include "mlir/Dialect/Affine/IR/AffineOps.h"
22#include "mlir/Dialect/Arith/IR/Arith.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/MemRef/IR/MemRef.h"
25#include "mlir/Dialect/SCF/IR/SCF.h"
26#include "mlir/Dialect/SCF/Utils/Utils.h"
27#include "mlir/Dialect/Vector/IR/VectorOps.h"
28#include "mlir/IR/IRMapping.h"
29#include "mlir/IR/PatternMatch.h"
30#include "mlir/Interfaces/LoopLikeInterface.h"
31#include "mlir/Pass/Pass.h"
32#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35#define GEN_PASS_DEF_AIEHOISTVECTORTRANSFERPOINTERS
36#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc"
39#define DEBUG_TYPE "aie-hoist-vector-transfer-pointers"
// Returns true if `val` transitively depends on the loop induction variable
// `loopIV`. Results are memoized in `cache` so shared sub-expressions are
// walked only once during the recursive traversal.
// NOTE(review): this excerpt is truncated — the base cases between the cache
// lookup and the defining-op walk (presumably `val == loopIV`, the cache-hit
// return, and block-argument handling) are not visible here; confirm against
// the full source.
53static bool dependsOnLoopIVForHoist(Value val, Value loopIV,
54 DenseMap<Value, bool> &cache) {
56 auto it = cache.find(val);
// Cache hit — presumably returns the memoized it->second (truncated here).
57 if (it != cache.end())
68 }
// A value produced by an operation depends on the IV iff any of the
// operation's operands does.
else if (
auto defOp = val.getDefiningOp()) {
70 for (Value operand : defOp->getOperands()) {
71 if (dependsOnLoopIVForHoist(operand, loopIV, cache)) {
// Convenience overload: runs the cached IV-dependence walk above with a
// fresh, call-local memoization cache.
// NOTE(review): the closing brace of this function was elided by extraction.
84static bool dependsOnLoopIVForHoist(Value val, Value loopIV) {
85 DenseMap<Value, bool> cache;
86 return dependsOnLoopIVForHoist(val, loopIV, cache);
// Recursively clones `op` (and, transitively, its operand-defining ops) at
// the builder's current insertion point, so the computation can be evaluated
// outside the loop. Previously cloned results are reused via `mapping`.
// Returns the cloned result Value.
// NOTE(review): the signature tail (the IRMapping parameter line) and the
// bodies of the early-return guards below were elided by extraction; the
// guards presumably bail out (return a null Value or the original result) —
// confirm against the full source.
91static Value cloneOpAndOperands(Operation *op, Value loopIV, OpBuilder &builder,
// Only single-result ops are supported.
94 if (op->getNumResults() != 1)
// Reuse a clone made earlier in this walk.
99 if (mapping.contains(op->getResult(0)))
100 return mapping.lookup(op->getResult(0));
// An IV-dependent value cannot be hoisted out of the loop.
103 if (dependsOnLoopIVForHoist(op->getResult(0), loopIV))
107 SmallVector<Value> newOperands;
108 for (Value operand : op->getOperands()) {
// Operands with a defining op are cloned recursively; block arguments
// and other values are used as-is.
109 if (
auto defOp = operand.getDefiningOp()) {
110 Value clonedOperand = cloneOpAndOperands(defOp, loopIV, builder, mapping);
113 newOperands.push_back(clonedOperand);
117 newOperands.push_back(operand);
// Clone the op itself, rewire it to the (possibly cloned) operands, and
// memoize the result for later reuse.
122 Operation *clonedOp = builder.clone(*op);
123 clonedOp->setOperands(newOperands);
126 mapping.map(op->getResult(0), clonedOp->getResult(0));
127 return clonedOp->getResult(0);
131static int64_t getVectorNumElements(VectorType vectorType) {
132 int64_t numElements = 1;
133 for (int64_t dim : vectorType.getShape()) {
144struct TransferOpInfo {
147 MemRefType memrefType;
148 VectorType vectorType;
149 SmallVector<Value> indices;
150 int64_t constantStride;
151 bool hasIVDependentIndices;
// Pattern on scf.for: for vector.transfer_read/write ops whose indices
// depend on the induction variable with a constant per-iteration stride,
// flatten the memref to 1-D and carry the linearized pointer as a new loop
// iter_arg, replacing per-iteration index arithmetic with a pointer add.
// NOTE(review): this excerpt is an extraction with original line numbers
// fused into the text and many lines elided (the base-class clause, several
// `continue`/error branches, closing braces). The code below is kept
// verbatim; comments only mark what each visible section does.
156struct HoistVectorTransferPointersPattern
160 LogicalResult matchAndRewrite(scf::ForOp forOp,
161 PatternRewriter &rewriter)
const override {
162 Value loopIV = forOp.getInductionVar();
163 Location loc = forOp.getLoc();
// Phase 1: scan the loop body and collect transfer_read/transfer_write
// ops together with their base memref, indices, and stride info.
166 SmallVector<TransferOpInfo> transferOps;
168 for (Operation &op : forOp.getBody()->without_terminator()) {
170 VectorType vectorType;
171 SmallVector<Value> indices;
173 if (
auto readOp = dyn_cast<vector::TransferReadOp>(&op)) {
174 base = readOp.getBase();
175 vectorType = readOp.getVectorType();
176 indices.assign(readOp.getIndices().begin(), readOp.getIndices().end());
177 }
else if (
auto writeOp = dyn_cast<vector::TransferWriteOp>(&op)) {
178 base = writeOp.getBase();
179 vectorType = writeOp.getVectorType();
180 indices.assign(writeOp.getIndices().begin(),
181 writeOp.getIndices().end());
186 auto memrefType = dyn_cast<MemRefType>(base.getType());
// For each IV-dependent index dimension, compute the contribution of one
// loop step to the linearized (row-major) offset. Dynamic inner dims make
// the stride non-constant, which disables hoisting for this transfer.
191 bool hasIVDependentIndices =
false;
192 int64_t constantStride = 0;
// Non-constant step presumably falls back to 1 — TODO confirm intent.
195 auto stepCst = forOp.getConstantStep();
197 stepCst.has_value() ? stepCst.value().getSExtValue() : 1;
199 for (
size_t dimIdx = 0; dimIdx < indices.size(); ++dimIdx) {
200 Value idx = indices[dimIdx];
201 if (dependsOnLoopIVForHoist(idx, loopIV)) {
202 hasIVDependentIndices =
true;
// Row-major stride of this dim = product of all inner dim sizes.
205 int64_t dimStride = 1;
206 bool hasDynamicStride =
false;
207 for (
size_t j = dimIdx + 1;
208 j < static_cast<size_t>(memrefType.getRank()); ++j) {
209 int64_t dimSize = memrefType.getShape()[j];
210 if (dimSize == ShapedType::kDynamic) {
211 hasDynamicStride =
true;
214 dimStride *= dimSize;
219 if (!hasDynamicStride)
220 constantStride += dimStride * loopStep;
// (elided branch) presumably resets the flag when the stride is dynamic.
222 hasIVDependentIndices =
false;
226 transferOps.push_back({&op, base, memrefType, vectorType, indices,
227 constantStride, hasIVDependentIndices});
231 if (transferOps.empty())
// Phase 2: for each hoistable transfer, build (outside the loop) a
// flattened 1-D view of the memref and the initial linearized pointer.
236 SmallVector<Value> newInitArgs;
237 SmallVector<Value> flatMemrefs;
239 for (
const auto &info : transferOps) {
240 if (!info.hasIVDependentIndices)
244 rewriter.setInsertionPoint(forOp);
245 Value flatMemref = info.base;
246 if (info.memrefType.getRank() > 1) {
247 int64_t totalSize = 1;
248 for (int64_t dim : info.memrefType.getShape()) {
249 if (dim == ShapedType::kDynamic)
// Preserve a strided layout's innermost stride/offset when collapsing;
// otherwise use the default (identity) layout.
255 MemRefType flatMemrefType;
256 if (
auto stridedLayout = dyn_cast_or_null<StridedLayoutAttr>(
257 info.memrefType.getLayout())) {
259 int64_t collapsedStride = stridedLayout.getStrides().back();
260 int64_t offset = stridedLayout.getOffset();
262 auto newLayout = StridedLayoutAttr::get(rewriter.getContext(), offset,
265 MemRefType::get({totalSize}, info.memrefType.getElementType(),
266 newLayout, info.memrefType.getMemorySpace());
269 MemRefType::get({totalSize}, info.memrefType.getElementType(),
270 AffineMap(), info.memrefType.getMemorySpace());
// Collapse all dims into one via memref.collapse_shape.
273 SmallVector<ReassociationIndices> reassociation;
274 ReassociationIndices allDims;
275 for (
size_t i = 0; i < static_cast<size_t>(info.memrefType.getRank());
277 allDims.push_back(i);
279 reassociation.push_back(allDims);
281 flatMemref = memref::CollapseShapeOp::create(
282 rewriter, loc, flatMemrefType, info.base, reassociation);
284 flatMemrefs.push_back(flatMemref);
// Affine map linearizing rank-d indices row-major: sum(d_i * stride_i).
287 int64_t rank = info.memrefType.getRank();
288 AffineExpr linearExpr = rewriter.getAffineConstantExpr(0);
290 for (int64_t i = rank - 1; i >= 0; --i) {
291 linearExpr = linearExpr + rewriter.getAffineDimExpr(i) * stride;
293 stride *= info.memrefType.getShape()[i];
295 auto linearMap = AffineMap::get(rank, 0, linearExpr);
// Evaluate each index at the loop's first iteration (IV := lower bound)
// to get the initial pointer value fed into the new iter_arg.
299 SmallVector<Value> evaluatedIndices;
300 IRMapping indexMapping;
301 for (Value idx : info.indices) {
302 if (dependsOnLoopIVForHoist(idx, loopIV)) {
// affine.apply over the IV: re-apply its map with IV -> lowerBound.
304 if (
auto affineOp = idx.getDefiningOp<affine::AffineApplyOp>()) {
305 SmallVector<Value> mappedOperands;
306 for (Value operand : affineOp.getMapOperands()) {
307 if (operand == loopIV)
308 mappedOperands.push_back(forOp.getLowerBound());
310 mappedOperands.push_back(operand);
312 Value evaluatedIdx = affine::AffineApplyOp::create(
313 rewriter, loc, affineOp.getAffineMap(), mappedOperands);
314 evaluatedIndices.push_back(evaluatedIdx);
317 evaluatedIndices.push_back(forOp.getLowerBound());
// IV-independent index defined inside the loop: clone its def chain
// above the loop so it is available for the init computation.
321 if (
auto defOp = idx.getDefiningOp()) {
323 cloneOpAndOperands(defOp, loopIV, rewriter, indexMapping);
325 evaluatedIndices.push_back(clonedIdx);
327 evaluatedIndices.push_back(idx);
329 evaluatedIndices.push_back(idx);
334 Value basePointer = affine::AffineApplyOp::create(
335 rewriter, loc, linearMap, evaluatedIndices);
337 newInitArgs.push_back(basePointer);
// Phase 2b (fallback): nothing is hoistable, but multi-dim transfers can
// still be flattened in place to 1-D reads/writes (no iter_args added).
342 if (newInitArgs.empty()) {
344 bool needsFlattening =
false;
345 bool hasProcessableTransfers =
false;
346 for (
const auto &info : transferOps) {
// Skip memrefs defined inside the loop body itself.
349 if (info.base.getDefiningOp() &&
350 forOp->isProperAncestor(info.base.getDefiningOp()))
353 hasProcessableTransfers =
true;
357 if (
auto readOp = dyn_cast<vector::TransferReadOp>(info.op)) {
358 if (readOp.getPermutationMap().getNumDims() != 1)
359 needsFlattening =
true;
360 }
else if (
auto writeOp = dyn_cast<vector::TransferWriteOp>(info.op)) {
361 if (writeOp.getPermutationMap().getNumDims() != 1)
362 needsFlattening =
true;
368 if (!hasProcessableTransfers || !needsFlattening)
// Build one flat view per distinct base memref (same collapse logic as
// the hoisting path above).
373 DenseMap<Value, Value> baseFlatMemrefs;
374 rewriter.setInsertionPoint(forOp);
375 for (
const auto &info : transferOps) {
376 if (baseFlatMemrefs.count(info.base))
380 if (info.base.getDefiningOp() &&
381 forOp->isProperAncestor(info.base.getDefiningOp()))
384 Value flatMemref = info.base;
385 if (info.memrefType.getRank() > 1) {
386 int64_t totalSize = 1;
387 for (int64_t dim : info.memrefType.getShape()) {
392 MemRefType flatMemrefType;
393 if (
auto stridedLayout = dyn_cast_or_null<StridedLayoutAttr>(
394 info.memrefType.getLayout())) {
395 int64_t collapsedStride = stridedLayout.getStrides().back();
396 int64_t offset = stridedLayout.getOffset();
398 auto newLayout = StridedLayoutAttr::get(rewriter.getContext(),
399 offset, {collapsedStride});
401 MemRefType::get({totalSize}, info.memrefType.getElementType(),
402 newLayout, info.memrefType.getMemorySpace());
405 MemRefType::get({totalSize}, info.memrefType.getElementType(),
406 AffineMap(), info.memrefType.getMemorySpace());
409 SmallVector<ReassociationIndices> reassociation;
410 ReassociationIndices allDims;
411 for (
size_t i = 0; i < static_cast<size_t>(info.memrefType.getRank());
413 allDims.push_back(i);
415 reassociation.push_back(allDims);
416 flatMemref = memref::CollapseShapeOp::create(
417 rewriter, loc, flatMemrefType, info.base, reassociation);
419 baseFlatMemrefs[info.base] = flatMemref;
// Rewrite each transfer to a 1-D transfer on the flat view at the
// linearized index, shape_cast-ing the vector to/from 1-D.
423 bool madeChanges =
false;
424 for (
const auto &info : transferOps) {
426 if (info.base.getDefiningOp() &&
427 forOp->isProperAncestor(info.base.getDefiningOp()))
431 if (!baseFlatMemrefs.count(info.base))
434 rewriter.setInsertionPoint(info.op);
437 int64_t numElements = getVectorNumElements(info.vectorType);
438 VectorType flatVectorType =
439 VectorType::get({numElements}, info.vectorType.getElementType());
442 Value flatMemref = baseFlatMemrefs[info.base];
445 int64_t rank = info.memrefType.getRank();
446 AffineExpr linearExpr = rewriter.getAffineConstantExpr(0);
448 for (int64_t i = rank - 1; i >= 0; --i) {
449 linearExpr = linearExpr + rewriter.getAffineDimExpr(i) * stride;
451 stride *= info.memrefType.getShape()[i];
453 auto linearMap = AffineMap::get(rank, 0, linearExpr);
455 Value currentPointer = affine::AffineApplyOp::create(
456 rewriter, loc, linearMap, info.indices);
459 AffineMap identityMap1D = AffineMap::get(
460 1, 0, rewriter.getAffineDimExpr(0), rewriter.getContext());
461 auto inBoundsAttr = rewriter.getBoolArrayAttr({
true});
463 if (
auto readOp = dyn_cast<vector::TransferReadOp>(info.op)) {
464 Value flatRead = vector::TransferReadOp::create(
465 rewriter, loc, flatVectorType, flatMemref,
466 ValueRange{currentPointer}, AffineMapAttr::get(identityMap1D),
468 Value(), inBoundsAttr);
469 Value shapedRead = vector::ShapeCastOp::create(
470 rewriter, loc, info.vectorType, flatRead);
471 rewriter.replaceOp(readOp, shapedRead);
473 }
else if (
auto writeOp = dyn_cast<vector::TransferWriteOp>(info.op)) {
474 Value flatValue = vector::ShapeCastOp::create(
475 rewriter, loc, flatVectorType, writeOp.getVector());
476 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
477 writeOp, flatValue, flatMemref, ValueRange{currentPointer},
478 AffineMapAttr::get(identityMap1D), Value(),
483 return madeChanges ? success() : failure();
// Phase 3: yield-builder callback used when adding the pointer iter_args.
// Rewrites each hoistable transfer to use its pointer block argument and
// yields ptr + constantStride for the next iteration.
488 [&](OpBuilder &b, Location yieldLoc,
489 ArrayRef<BlockArgument> newBbArgs) -> SmallVector<Value> {
490 SmallVector<Value> yieldValues;
493 size_t iterArgIdx = 0;
494 for (
size_t i = 0; i < transferOps.size(); ++i) {
495 const auto &info = transferOps[i];
496 if (!info.hasIVDependentIndices)
// The new pointer args are appended at the end of the block args.
499 BlockArgument ptrIterArg =
500 newBbArgs[newBbArgs.size() - newInitArgs.size() + iterArgIdx];
501 Value flatMemref = flatMemrefs[iterArgIdx];
504 int64_t numElements = getVectorNumElements(info.vectorType);
505 VectorType flatVectorType =
506 VectorType::get({numElements}, info.vectorType.getElementType());
509 b.setInsertionPoint(info.op);
511 AffineMap identityMap1D =
512 AffineMap::get(1, 0, b.getAffineDimExpr(0), b.getContext());
513 auto inBoundsAttr = b.getBoolArrayAttr({
true});
515 if (
auto readOp = dyn_cast<vector::TransferReadOp>(info.op)) {
516 Value flatRead = vector::TransferReadOp::create(
517 b, loc, flatVectorType, flatMemref, ValueRange{ptrIterArg},
518 AffineMapAttr::get(identityMap1D), readOp.getPadding(),
519 Value(), inBoundsAttr);
521 vector::ShapeCastOp::create(b, loc, info.vectorType, flatRead);
522 rewriter.replaceOp(readOp, shapedRead);
523 }
else if (
auto writeOp = dyn_cast<vector::TransferWriteOp>(info.op)) {
524 Value flatValue = vector::ShapeCastOp::create(b, loc, flatVectorType,
525 writeOp.getVector());
526 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
527 writeOp, flatValue, flatMemref, ValueRange{ptrIterArg},
528 AffineMapAttr::get(identityMap1D), Value(),
// Advance the pointer by the constant per-iteration stride.
534 arith::ConstantIndexOp::create(b, yieldLoc, info.constantStride);
536 arith::AddIOp::create(b, yieldLoc, ptrIterArg, strideConst);
537 yieldValues.push_back(nextPtr);
// Install the new iter_args/yields via the LoopLikeOpInterface helper.
546 FailureOr<LoopLikeOpInterface> newLoopResult =
547 cast<LoopLikeOpInterface>(forOp.getOperation())
548 .replaceWithAdditionalYields(
549 rewriter, newInitArgs,
553 if (failed(newLoopResult))
564struct AIEHoistVectorTransferPointersPass
565 : xilinx::AIE::impl::AIEHoistVectorTransferPointersBase<
566 AIEHoistVectorTransferPointersPass> {
567 void getDependentDialects(DialectRegistry ®istry)
const override {
568 registry.insert<affine::AffineDialect, arith::ArithDialect,
569 memref::MemRefDialect, scf::SCFDialect,
570 vector::VectorDialect>();
573 void runOnOperation()
override {
574 ModuleOp moduleOp = getOperation();
575 MLIRContext *context = &getContext();
577 RewritePatternSet patterns(context);
578 patterns.add<HoistVectorTransferPointersPattern>(context);
582 if (failed(applyPatternsGreedily(moduleOp, std::move(patterns))))
589std::unique_ptr<OperationPass<ModuleOp>>
591 return std::make_unique<AIEHoistVectorTransferPointersPass>();
// Include the generated interface declarations.
// Factory (declared in the generated header):
//   std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
//   createAIEHoistVectorTransferPointersPass()