MLIR-AIE
AIEMaterializeBDChains.cpp
Go to the documentation of this file.
1//===- AIEMaterializeBDChains.cpp -------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2024 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
14
15#include "mlir/Analysis/CallGraph.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Pass/PassManager.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23namespace xilinx::AIEX {
24#define GEN_PASS_DEF_AIEMATERIALIZEBDCHAINS
25#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
26} // namespace xilinx::AIEX
27
28using namespace mlir;
29using namespace xilinx;
30using namespace xilinx::AIEX;
31
33
35 : RewritePattern(DMAStartBdChainForOp::getOperationName(),
36 PatternBenefit(1), ctx) {}
37
38 LogicalResult matchAndRewrite(Operation *op_any,
39 PatternRewriter &rewriter) const override {
40 DMAStartBdChainForOp op = llvm::dyn_cast<DMAStartBdChainForOp>(op_any);
41 if (!op) {
42 return failure();
43 }
44 AIE::DeviceOp device = op->getParentOfType<AIE::DeviceOp>();
45
46 AIE::ShimDMAAllocationOp alloc_op =
47 AIE::ShimDMAAllocationOp::getForSymbol(device, op.getAlloc());
48 if (!alloc_op) {
49 return op.emitOpError("no shim DMA allocation found for symbol");
50 }
51
52 AIE::TileOp tile = alloc_op.getTileOp();
53 if (!tile) {
54 return op.emitOpError(
55 "shim DMA allocation must reference a valid TileOp");
56 }
57
58 DMAStartBdChainOp new_op = DMAStartBdChainOp::create(
59 rewriter, op.getLoc(), rewriter.getIndexType(), op.getSymbol(),
60 op.getArgs(), tile.getResult(), alloc_op.getChannelDir(),
61 (int32_t)alloc_op.getChannelIndex(), op.getIssueToken(),
62 op.getRepeatCount());
63 rewriter.replaceAllUsesWith(op.getResult(), new_op.getResult());
64 rewriter.eraseOp(op);
65 return success();
66 }
67};
68
70
71 DMAInlineBDChainPattern(MLIRContext *ctx)
72 : RewritePattern(DMAStartBdChainOp::getOperationName(), PatternBenefit(1),
73 ctx) {}
74
75 LogicalResult matchAndRewrite(Operation *op,
76 PatternRewriter &rewriter) const override {
77 DMAStartBdChainOp start_op = llvm::dyn_cast<DMAStartBdChainOp>(op);
78 if (!start_op) { // Not a match.
79 return failure();
80 }
81 rewriter.setInsertionPointAfter(start_op);
82
83 // Get referenced abstract BD chain
84 AIE::BDChainOp chain_def = start_op.getBDChainOp();
85 assert(chain_def);
86 Region &source_region = chain_def.getBody();
87
88 // Create BD op into which the result will be inlined
89 DMAConfigureTaskOp configure_op = DMAConfigureTaskOp::create(
90 rewriter, start_op.getLoc(), rewriter.getIndexType(),
91 start_op.getTile(), start_op.getDirection(), start_op.getChannel(),
92 start_op.getIssueToken(), start_op.getRepeatCount());
93 Region &target_region = configure_op.getBody();
94
95 // Clone BD definition into usage site, replacing abstract SSA values with
96 // concrete ones
97 IRMapping arg_map;
98 ValueRange values = start_op.getArgs();
99 for (unsigned i = 0, n = source_region.getNumArguments(); i < n; i++) {
100 BlockArgument arg = source_region.getArgument(i);
101 Value val = values[i];
102 assert(arg.getType() == val.getType());
103 arg_map.map(arg, val);
104 }
105 source_region.cloneInto(&target_region, arg_map);
106
107 // Replace result of dma start task with result of bd chain configuration
108 rewriter.replaceAllUsesWith(start_op.getResult(), configure_op.getResult());
109
110 // Add a start BDs instruction
111 DMAStartTaskOp::create(rewriter, start_op.getLoc(),
112 configure_op.getResult());
113
114 // After fully inlining, remove the original instruction
115 rewriter.eraseOp(start_op);
116
117 return success();
118 }
119};
120
123 AIEMaterializeBDChainsPass> {
124
125 void runOnOperation() override {
126 MLIRContext *ctx = &getContext();
127 AIE::DeviceOp device = getOperation();
128 GreedyRewriteConfig rewriter_config = GreedyRewriteConfig();
129 rewriter_config.setRegionSimplificationLevel(
130 GreedySimplifyRegionLevel::Disabled);
131
132 RewritePatternSet patterns_0(ctx);
133 patterns_0.insert<DMAStartBdChainForOpPattern>(ctx);
134 DMAConfigureTaskOp::getCanonicalizationPatterns(patterns_0, ctx);
135 if (failed(applyPatternsGreedily(device, std::move(patterns_0),
136 rewriter_config))) {
137 signalPassFailure();
138 }
139
140 RewritePatternSet patterns_1(ctx);
141 patterns_1.insert<DMAInlineBDChainPattern>(ctx);
142 rewriter_config.setRegionSimplificationLevel(
143 GreedySimplifyRegionLevel::Disabled);
144 DMAConfigureTaskOp::getCanonicalizationPatterns(patterns_1, ctx);
145 if (failed(applyPatternsGreedily(device, std::move(patterns_1),
146 rewriter_config))) {
147 signalPassFailure();
148 }
149 }
150};
151
152std::unique_ptr<OperationPass<AIE::DeviceOp>>
154 return std::make_unique<AIEMaterializeBDChainsPass>();
155}
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIEMaterializeBDChainsPass()
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
DMAInlineBDChainPattern(MLIRContext *ctx)
LogicalResult matchAndRewrite(Operation *op_any, PatternRewriter &rewriter) const override