MLIR-AIE
AIEMaterializeBDChains.cpp
Go to the documentation of this file.
1//===- AIEMaterializeBDChains.cpp -------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2024 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
14
15#include "mlir/Analysis/CallGraph.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Pass/PassManager.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23using namespace mlir;
24using namespace xilinx;
25using namespace xilinx::AIEX;
26
28
30 : RewritePattern(DMAStartBdChainForOp::getOperationName(),
31 PatternBenefit(1), ctx) {}
32
33 LogicalResult matchAndRewrite(Operation *op_any,
34 PatternRewriter &rewriter) const override {
35 DMAStartBdChainForOp op = llvm::dyn_cast<DMAStartBdChainForOp>(op_any);
36 if (!op) {
37 return failure();
38 }
39 AIE::DeviceOp device = op->getParentOfType<AIE::DeviceOp>();
40
41 AIE::ShimDMAAllocationOp alloc_op =
42 AIE::ShimDMAAllocationOp::getForSymbol(device, op.getAlloc());
43 if (!alloc_op) {
44 return op.emitOpError("no shim DMA allocation found for symbol");
45 }
46
47 const int col = alloc_op.getCol();
48 AIE::TileOp tile = AIE::TileOp::getOrCreate(rewriter, device, col, 0);
49 DMAStartBdChainOp new_op = rewriter.create<DMAStartBdChainOp>(
50 op.getLoc(), rewriter.getIndexType(), op.getSymbol(), op.getArgs(),
51 tile.getResult(), alloc_op.getChannelDir(),
52 (int32_t)alloc_op.getChannelIndex(), op.getIssueToken(),
53 op.getRepeatCount());
54 rewriter.replaceAllUsesWith(op.getResult(), new_op.getResult());
55 rewriter.eraseOp(op);
56 return success();
57 }
58};
59
61
62 DMAInlineBDChainPattern(MLIRContext *ctx)
63 : RewritePattern(DMAStartBdChainOp::getOperationName(), PatternBenefit(1),
64 ctx) {}
65
66 LogicalResult matchAndRewrite(Operation *op,
67 PatternRewriter &rewriter) const override {
68 DMAStartBdChainOp start_op = llvm::dyn_cast<DMAStartBdChainOp>(op);
69 if (!start_op) { // Not a match.
70 return failure();
71 }
72 rewriter.setInsertionPointAfter(start_op);
73
74 // Get referenced abstract BD chain
75 AIE::BDChainOp chain_def = start_op.getBDChainOp();
76 assert(chain_def);
77 Region &source_region = chain_def.getBody();
78
79 // Create BD op into which the result will be inlined
80 DMAConfigureTaskOp configure_op = rewriter.create<DMAConfigureTaskOp>(
81 start_op.getLoc(), rewriter.getIndexType(), start_op.getTile(),
82 start_op.getDirection(), start_op.getChannel(),
83 start_op.getIssueToken(), start_op.getRepeatCount());
84 Region &target_region = configure_op.getBody();
85
86 // Clone BD definition into usage site, replacing abstract SSA values with
87 // concrete ones
88 IRMapping arg_map;
89 ValueRange values = start_op.getArgs();
90 for (unsigned i = 0, n = source_region.getNumArguments(); i < n; i++) {
91 BlockArgument arg = source_region.getArgument(i);
92 Value val = values[i];
93 assert(arg.getType() == val.getType());
94 arg_map.map(arg, val);
95 }
96 source_region.cloneInto(&target_region, arg_map);
97
98 // Replace result of dma start task with result of bd chain configuration
99 rewriter.replaceAllUsesWith(start_op.getResult(), configure_op.getResult());
100
101 // Add a start BDs instruction
102 rewriter.create<DMAStartTaskOp>(start_op.getLoc(),
103 configure_op.getResult());
104
105 // After fully inlining, remove the original instruction
106 rewriter.eraseOp(start_op);
107
108 return success();
109 }
110};
111
113 : AIEMaterializeBDChainsBase<AIEMaterializeBDChainsPass> {
114
115 void runOnOperation() override {
116 MLIRContext *ctx = &getContext();
117 AIE::DeviceOp device = getOperation();
118 GreedyRewriteConfig rewriter_config = GreedyRewriteConfig();
119 rewriter_config.enableRegionSimplification =
120 GreedySimplifyRegionLevel::Disabled;
121
122 RewritePatternSet patterns_0(ctx);
123 patterns_0.insert<DMAStartBdChainForOpPattern>(ctx);
124 DMAConfigureTaskOp::getCanonicalizationPatterns(patterns_0, ctx);
125 if (failed(applyPatternsGreedily(device, std::move(patterns_0),
126 rewriter_config))) {
127 signalPassFailure();
128 }
129
130 RewritePatternSet patterns_1(ctx);
131 patterns_1.insert<DMAInlineBDChainPattern>(ctx);
132 rewriter_config.enableRegionSimplification =
133 GreedySimplifyRegionLevel::Disabled;
134 DMAConfigureTaskOp::getCanonicalizationPatterns(patterns_1, ctx);
135 if (failed(applyPatternsGreedily(device, std::move(patterns_1),
136 rewriter_config))) {
137 signalPassFailure();
138 }
139 }
140};
141
142std::unique_ptr<OperationPass<AIE::DeviceOp>>
144 return std::make_unique<AIEMaterializeBDChainsPass>();
145}
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIEMaterializeBDChainsPass()
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
DMAInlineBDChainPattern(MLIRContext *ctx)
LogicalResult matchAndRewrite(Operation *op_any, PatternRewriter &rewriter) const override