MLIR-AIE
AIECreateCores.cpp
Go to the documentation of this file.
1//===- AIECreateCores.cpp ---------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2019 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
22#include "mlir/Transforms/DialectConversion.h"
23
24using namespace mlir;
25using namespace xilinx;
26using namespace xilinx::AIE;
27using namespace xilinx::AIEX;
28
29struct RemoveAIEFuncs : public OpConversionPattern<func::FuncOp> {
31 DenseMap<func::FuncOp, std::pair<int, int>> &funcs;
32
33 RemoveAIEFuncs(MLIRContext *context,
34 DenseMap<func::FuncOp, std::pair<int, int>> &funcs,
35 PatternBenefit benefit = 1)
36 : OpConversionPattern<func::FuncOp>(context, benefit), funcs(funcs) {}
37
38 LogicalResult
39 matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
40 ConversionPatternRewriter &rewriter) const override {
41 Operation *Op = op.getOperation();
42 if (funcs.find(op) == funcs.end())
43 return failure();
44
45 rewriter.eraseOp(Op);
46 return success();
47 }
48};
49
50struct RemoveAIECalls : public OpConversionPattern<func::CallOp> {
52
53 RemoveAIECalls(MLIRContext *context, PatternBenefit benefit = 1)
54 : OpConversionPattern<func::CallOp>(context, benefit) {}
55
56 LogicalResult
57 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
58 ConversionPatternRewriter &rewriter) const override {
59 Operation *Op = op.getOperation();
60 if (!op->getAttr("aie.x") || !op->getAttr("aie.y"))
61 return failure();
62
63 rewriter.eraseOp(Op);
64 return success();
65 }
66};
67
struct AIECreateCoresPass : public AIECreateCoresBase<AIECreateCoresPass> {
  // For every func::CallOp carrying integer "aie.x"/"aie.y" placement
  // attributes, this pass:
  //   1. gets or creates the TileOp at that (col, row);
  //   2. on first use of a tile, creates one BufferOp per call operand
  //      (scalar operands are promoted to one-element memrefs) plus an
  //      empty MemOp holding only a terminator;
  //   3. creates (or reuses) a CoreOp on the tile and clones the callee
  //      FuncOp's body into it, remapping function arguments to buffers;
  //   4. finally erases the consumed CallOps/FuncOps via a partial
  //      conversion using RemoveAIECalls / RemoveAIEFuncs.
  void runOnOperation() override {

    DeviceOp device = getOperation();
    // Default insertion point: just before the device terminator.
    OpBuilder builder = OpBuilder::atBlockTerminator(device.getBody());

    DenseMap<TileID, Operation *> tiles;  // (col, row) -> TileOp
    DenseMap<Operation *, CoreOp> cores;  // TileOp -> CoreOp created for it
    DenseMap<Operation *, MemOp> mems;    // TileOp -> MemOp created for it
    DenseMap<Value, Value> buffers;       // call operand -> BufferOp result
    DenseMap<func::FuncOp, std::pair<int, int>> funcs; // callee -> (col, row)

    // Collect existing TileOps
    for (auto tile : device.getOps<TileOp>()) {
      int colIndex = tile.colIndex();
      int rowIndex = tile.rowIndex();
      tiles[{colIndex, rowIndex}] = tile;
    }

    // Bind FuncOp to an AIE core based on attributes of the CallOp
    // A CoreOp will be created for the core, and the FuncOp body is cloned
    // to the CoreOp region
    for (auto callOp : device.getOps<func::CallOp>()) {
      // Only explicitly placed calls are lowered; others are untouched.
      if (!callOp->getAttr("aie.x") || !callOp->getAttr("aie.y"))
        continue;

      SmallVector<Value, 4> callOperands(callOp.getArgOperands());
      // Pairs of (buffer memref type, index of the originating operand).
      SmallVector<std::pair<MemRefType, int>, 4> coreBufTypes;

      int colIndex = callOp->getAttrOfType<IntegerAttr>("aie.x").getInt();
      int rowIndex = callOp->getAttrOfType<IntegerAttr>("aie.y").getInt();

      // get or create TileOp
      if (!tiles[{colIndex, rowIndex}]) {
        builder.setInsertionPointToStart(device.getBody());
        TileOp tile =
            builder.create<TileOp>(builder.getUnknownLoc(), colIndex, rowIndex);
        tiles[{colIndex, rowIndex}] = tile;
      }
      Operation *tileOp = tiles[{colIndex, rowIndex}];
      TileOp tile = dyn_cast<TileOp>(tileOp);
      builder.setInsertionPointAfter(tileOp);

      // create MemOp
      // Buffers and the MemOp are created once per tile; note that
      // coreBufTypes is only populated on this first visit.
      if (!mems[tileOp]) {
        for (unsigned i = 0; i < callOperands.size(); i++) {
          Value operand = callOperands[i]; // Should be produced by an AllocOp
          MemRefType t = nullptr;
          if (llvm::isa<MemRefType>(operand.getType())) {
            t = llvm::cast<MemRefType>(operand.getType());
          } else if (operand.getType().isIntOrFloat()) {
            // promote scalar type to memref type
            int64_t shape[1] = {1};
            t = MemRefType::get(shape, operand.getType());
          }

          assert(t && "Unsupported type!");
          coreBufTypes.push_back({t, i});
          BufferOp buf = builder.create<BufferOp>(
              builder.getUnknownLoc(), t, tile, /*sym_name*/ nullptr,
              /*address*/ nullptr, /*initial_value*/ nullptr,
              /*mem_bank*/ nullptr);
          buffers[callOperands[i]] = buf;
          // NOTE(review): for a scalar operand this replaces scalar-typed
          // uses with a memref-typed value; presumably the remaining uses
          // (the call itself) are erased by the conversion below — confirm.
          operand.replaceAllUsesWith(buf.getResult());
        }

        // The MemOp body gets a single block holding only the terminator.
        MemOp mem = builder.create<MemOp>(builder.getUnknownLoc(),
                                          builder.getIndexType(), tile);
        Region &r = mem.getBody();
        Block *endBlock = builder.createBlock(&r);

        // block terminator
        builder.setInsertionPointToStart(endBlock);
        builder.create<EndOp>(builder.getUnknownLoc());
        mems[tileOp] = mem;
      }

      // create CoreOp with buffer reference
      if (CallOpInterface call =
              dyn_cast<CallOpInterface>(callOp.getOperation())) {
        Operation *callable = call.resolveCallable();
        if (func::FuncOp func = dyn_cast<func::FuncOp>(callable)) {
          // Record the placement so RemoveAIEFuncs erases this FuncOp later.
          funcs[func] = {colIndex, rowIndex};

          IRMapping mapper;

          builder.setInsertionPoint(callOp);

          CoreOp core;
          Block *currentBlock;

          if (!cores[tileOp]) {
            // First call on this tile: create a fresh CoreOp and body block.
            core = builder.create<CoreOp>(builder.getUnknownLoc(), tile);
            Region &r = core.getBody();
            currentBlock = builder.createBlock(&r);
            builder.setInsertionPointToStart(currentBlock);
          } else {
            // Tile already has a core: append before its terminator.
            core = cores[tileOp];
            currentBlock = &core.getBody().back();
            builder.setInsertionPoint(currentBlock->getTerminator());
          }

          // Mapping between function arguments (FuncOp) and AIE buffers
          // (CoreOp) We will create one buffer for each function argument If
          // the function argument's type is a scalar, we promote it to a
          // one-element memref, and do a load to the buffer at index 0
          for (auto pair : coreBufTypes) {
            int operandID = pair.second;
            Value arg = func.getArgument(operandID);
            Value buf = buffers[callOperands[operandID]];
            if (arg.getType().isIntOrFloat()) {
              assert(pair.first.getShape().size() == 1 &&
                     "Expected MemRefType of shape 1");
              assert(pair.first.getShape()[0] == 1 &&
                     "Expected MemRefType of single element");

              // Scalar argument: load element 0 of its one-element buffer.
              Value zero = builder.create<arith::ConstantIndexOp>(
                  builder.getUnknownLoc(), 0);
              auto loadOp = builder.create<memref::LoadOp>(
                  builder.getUnknownLoc(), arg.getType(), buf, zero);
              mapper.map(arg, loadOp);
            } else {
              mapper.map(arg, buf);
            }
          }

          // Clone ops from the original function to CoreOp's body
          for (auto &childOp : func.getCallableRegion()->getOps()) {
            // skip ReturnOp since it lives only within a funcOp
            if (auto returnOp = dyn_cast<func::ReturnOp>(childOp))
              continue;

            builder.clone(childOp, mapper);
          }
          if (!cores[tileOp]) {
            // block terminator
            builder.create<EndOp>(builder.getUnknownLoc());
            cores[tileOp] = core;
          }
        }
      }
    }

    // Setup FlowOps
    // Since memcpy moves data from one memory module to another, we use
    // WireBundle::DMA for both the source and the destination In addition, we
    // only have two DMA ports per each direction (MM2S/S2MM), and in a
    // circuit-switch mode, dest port/channel sharing is not possible.
    // Therefore, we will generate error if the number of logical flows
    // (streams) targeting the same destination (S2MM) is more than 2
    // DenseMap<Value, int> destChannel;
    // for (auto op : device.getOps<MemcpyOp>()) {
    //   builder.setInsertionPoint(op);
    //   TileOp srcTile = dyn_cast<TileOp>(op.srcTile().getDefiningOp());
    //   TileOp dstTile = dyn_cast<TileOp>(op.dstTile().getDefiningOp());
    //   // TODO: perhaps a better approach is to not assert here, but rather
    //   have a subsequent pass
    //   // that legally relocates the ports
    //   assert(destChannel[op.dstTile()] <= 2 &&
    //          "Could not allocate more than two dest. channel when creating
    //          FlowOp");
    //   // WireBundle[1] = DMA
    //   builder.create<FlowOp>(builder.getUnknownLoc(), srcTile, 1, 0, dstTile,
    //   1, destChannel[op.dstTile()]); destChannel[op.dstTile()]++;
    // }

    // Partial conversion: only erase the consumed calls/functions; the DMA
    // and token ops listed below stay legal.
    ConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    target.addLegalOp<DMAStartOp>();
    target.addLegalOp<DMABDOp>();
    target.addLegalOp<UseTokenOp>();
    target.addLegalOp<NextBDOp>();

    // Remove standard CallOps and FuncOps that are bound to AIE CoreOps
    patterns.insert<RemoveAIECalls>(device.getContext());
    patterns.insert<RemoveAIEFuncs>(device.getContext(), funcs);

    if (failed(applyPartialConversion(device, target, std::move(patterns))))
      signalPassFailure();
  }
};
249
250std::unique_ptr<OperationPass<DeviceOp>>
252 return std::make_unique<AIECreateCoresPass>();
253}
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIECreateCoresPass()
Include the generated interface declarations.
void runOnOperation() override
RemoveAIECalls(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
RemoveAIEFuncs(MLIRContext *context, DenseMap< func::FuncOp, std::pair< int, int > > &funcs, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
DenseMap< func::FuncOp, std::pair< int, int > > & funcs