MLIR-AIE
AIEFindFlows.cpp
Go to the documentation of this file.
1//===- AIEFindFlows.cpp -----------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2019 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10
13
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"

#include <optional>
#include <set>
#include <tuple>
#include <vector>
16
17namespace xilinx::AIE {
18#define GEN_PASS_DEF_AIEFINDFLOWS
19#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc"
20} // namespace xilinx::AIE
21
22#define DEBUG_TYPE "aie-find-flows"
23
24using namespace mlir;
25using namespace xilinx;
26using namespace xilinx::AIE;
27
28typedef struct MaskValue {
29 int mask;
30 int value;
32
33typedef struct PortConnection {
34 Operation *op;
37
42
47
49 DeviceOp &device;
50
51public:
52 ConnectivityAnalysis(DeviceOp &d) : device(d) {}
53
54private:
55 std::optional<PortConnection>
56 getConnectionThroughWire(Operation *op, Port masterPort) const {
57 LLVM_DEBUG(llvm::dbgs() << "Wire:" << *op << " "
58 << stringifyWireBundle(masterPort.bundle) << " "
59 << masterPort.channel << "\n");
60 for (auto wireOp : device.getOps<WireOp>()) {
61 if (wireOp.getSource().getDefiningOp() == op &&
62 wireOp.getSourceBundle() == masterPort.bundle) {
63 Operation *other = wireOp.getDest().getDefiningOp();
64 Port otherPort = {wireOp.getDestBundle(), masterPort.channel};
65 LLVM_DEBUG(llvm::dbgs() << "Connects To:" << *other << " "
66 << stringifyWireBundle(otherPort.bundle) << " "
67 << otherPort.channel << "\n");
68
69 return PortConnection{other, otherPort};
70 }
71 if (wireOp.getDest().getDefiningOp() == op &&
72 wireOp.getDestBundle() == masterPort.bundle) {
73 Operation *other = wireOp.getSource().getDefiningOp();
74 Port otherPort = {wireOp.getSourceBundle(), masterPort.channel};
75 LLVM_DEBUG(llvm::dbgs() << "Connects To:" << *other << " "
76 << stringifyWireBundle(otherPort.bundle) << " "
77 << otherPort.channel << "\n");
78 return PortConnection{other, otherPort};
79 }
80 }
81 LLVM_DEBUG(llvm::dbgs() << "*** Missing Wire!\n");
82 return std::nullopt;
83 }
84
85 std::vector<PortMaskValue>
86 getConnectionsThroughSwitchbox(Region &r, Port sourcePort) const {
87 LLVM_DEBUG(llvm::dbgs() << "Switchbox:\n");
88 Block &b = r.front();
89 std::vector<PortMaskValue> portSet;
90 for (auto connectOp : b.getOps<ConnectOp>()) {
91 if (connectOp.sourcePort() == sourcePort) {
92 MaskValue maskValue = {0, 0};
93 portSet.push_back({connectOp.destPort(), maskValue});
94 LLVM_DEBUG(llvm::dbgs()
95 << "To:" << stringifyWireBundle(connectOp.destPort().bundle)
96 << " " << connectOp.destPort().channel << "\n");
97 }
98 }
99 for (auto connectOp : b.getOps<PacketRulesOp>()) {
100 if (connectOp.sourcePort() == sourcePort) {
101 LLVM_DEBUG(llvm::dbgs()
102 << "Packet From: "
103 << stringifyWireBundle(connectOp.sourcePort().bundle) << " "
104 << sourcePort.channel << "\n");
105 for (auto masterSetOp : b.getOps<MasterSetOp>())
106 for (Value amsel : masterSetOp.getAmsels())
107 for (auto ruleOp :
108 connectOp.getRules().front().getOps<PacketRuleOp>()) {
109 if (ruleOp.getAmsel() == amsel) {
110 LLVM_DEBUG(llvm::dbgs()
111 << "To:"
112 << stringifyWireBundle(masterSetOp.destPort().bundle)
113 << " " << masterSetOp.destPort().channel << "\n");
114 MaskValue maskValue = {ruleOp.maskInt(), ruleOp.valueInt()};
115 portSet.push_back({masterSetOp.destPort(), maskValue});
116 }
117 }
118 }
119 }
120 return portSet;
121 }
122
123 std::vector<PacketConnection>
124 maskSwitchboxConnections(Operation *switchOp,
125 std::vector<PortMaskValue> nextPortMaskValues,
126 MaskValue maskValue) const {
127 std::vector<PacketConnection> worklist;
128 for (auto &nextPortMaskValue : nextPortMaskValues) {
129 Port nextPort = nextPortMaskValue.port;
130 MaskValue nextMaskValue = nextPortMaskValue.mv;
131 int maskConflicts = nextMaskValue.mask & maskValue.mask;
132 LLVM_DEBUG(llvm::dbgs() << "Mask: " << maskValue.mask << " "
133 << maskValue.value << "\n");
134 LLVM_DEBUG(llvm::dbgs() << "NextMask: " << nextMaskValue.mask << " "
135 << nextMaskValue.value << "\n");
136 LLVM_DEBUG(llvm::dbgs() << maskConflicts << "\n");
137
138 if ((maskConflicts & nextMaskValue.value) !=
139 (maskConflicts & maskValue.value)) {
140 // Incoming packets cannot match this rule. Skip it.
141 continue;
142 }
143 MaskValue newMaskValue = {maskValue.mask | nextMaskValue.mask,
144 maskValue.value |
145 (nextMaskValue.mask & nextMaskValue.value)};
146 auto nextConnection = getConnectionThroughWire(switchOp, nextPort);
147
148 // If there is no wire to follow then bail out.
149 if (!nextConnection)
150 continue;
151
152 worklist.push_back({*nextConnection, newMaskValue});
153 }
154 return worklist;
155 }
156
157public:
158 // Get the tiles connected to the given tile, starting from the given
159 // output port of the tile. This is 1:N relationship because each
160 // switchbox can broadcast.
161 std::vector<PacketConnection> getConnectedTiles(TileOp tileOp,
162 Port port) const {
163
164 LLVM_DEBUG(llvm::dbgs()
165 << "getConnectedTile(" << stringifyWireBundle(port.bundle) << " "
166 << port.channel << ")");
167 LLVM_DEBUG(tileOp.dump());
168
169 // The accumulated result;
170 std::vector<PacketConnection> connectedTiles;
171 // A worklist of PortConnections to visit. These are all input ports of
172 // some object (likely either a TileOp or a SwitchboxOp).
173 std::vector<PacketConnection> worklist;
174 // Start the worklist by traversing from the tile to its connected
175 // switchbox.
176 auto t = getConnectionThroughWire(tileOp.getOperation(), port);
177
178 // If there is no wire to traverse, then just return no connection
179 if (!t)
180 return connectedTiles;
181 worklist.push_back({*t, {0, 0}});
182
183 while (!worklist.empty()) {
184 PacketConnection t = worklist.back();
185 worklist.pop_back();
186 PortConnection portConnection = t.portConnection;
187 MaskValue maskValue = t.mv;
188 Operation *other = portConnection.op;
189 Port otherPort = portConnection.port;
190 if (other && other->hasTrait<IsFlowEndPoint>()) {
191 // If we got to a tile, then add it to the result.
192 connectedTiles.push_back(t);
193 } else if (auto switchOp = dyn_cast_or_null<SwitchboxOp>(other)) {
194 std::vector<PortMaskValue> nextPortMaskValues =
195 getConnectionsThroughSwitchbox(switchOp.getConnections(),
196 otherPort);
197 std::vector<PacketConnection> newWorkList =
198 maskSwitchboxConnections(switchOp, nextPortMaskValues, maskValue);
199 // append to the worklist
200 worklist.insert(worklist.end(), newWorkList.begin(), newWorkList.end());
201 if (!nextPortMaskValues.empty() && newWorkList.empty()) {
202 // No rule matched some incoming packet. This is likely a
203 // configuration error.
204 LLVM_DEBUG(llvm::dbgs() << "No rule matched incoming packet here: ");
205 LLVM_DEBUG(other->dump());
206 }
207 } else if (auto switchOp = dyn_cast_or_null<ShimMuxOp>(other)) {
208 std::vector<PortMaskValue> nextPortMaskValues =
209 getConnectionsThroughSwitchbox(switchOp.getConnections(),
210 otherPort);
211 std::vector<PacketConnection> newWorkList =
212 maskSwitchboxConnections(switchOp, nextPortMaskValues, maskValue);
213 // append to the worklist
214 worklist.insert(worklist.end(), newWorkList.begin(), newWorkList.end());
215 if (!nextPortMaskValues.empty() && newWorkList.empty()) {
216 // No rule matched some incoming packet. This is likely a
217 // configuration error.
218 LLVM_DEBUG(llvm::dbgs() << "No rule matched incoming packet here: ");
219 LLVM_DEBUG(other->dump());
220 }
221 } else {
222 LLVM_DEBUG(llvm::dbgs()
223 << "*** Connection Terminated at unknown operation: ");
224 LLVM_DEBUG(other->dump());
225 }
226 }
227 return connectedTiles;
228 }
229};
230
231static void findFlowsFrom(TileOp op, ConnectivityAnalysis &analysis,
232 OpBuilder &rewriter) {
233 Operation *Op = op.getOperation();
234 rewriter.setInsertionPoint(Op->getBlock()->getTerminator());
235
236 std::vector bundles = {WireBundle::Core, WireBundle::DMA};
237 for (WireBundle bundle : bundles) {
238 LLVM_DEBUG(llvm::dbgs()
239 << op << stringifyWireBundle(bundle) << " has "
240 << op.getNumSourceConnections(bundle) << " Connections\n");
241 for (size_t i = 0; i < op.getNumSourceConnections(bundle); i++) {
242 std::vector<PacketConnection> tiles =
243 analysis.getConnectedTiles(op, {bundle, (int)i});
244 LLVM_DEBUG(llvm::dbgs() << tiles.size() << " Flows\n");
245
246 for (PacketConnection &c : tiles) {
247 PortConnection portConnection = c.portConnection;
248 MaskValue maskValue = c.mv;
249 Operation *destOp = portConnection.op;
250 Port destPort = portConnection.port;
251 if (maskValue.mask == 0) {
252 FlowOp::create(rewriter, Op->getLoc(), Op->getResult(0), bundle, i,
253 destOp->getResult(0), destPort.bundle,
254 destPort.channel);
255 } else {
256 auto flowOp = PacketFlowOp::create(rewriter, Op->getLoc(),
257 maskValue.value, nullptr, nullptr);
258 PacketFlowOp::ensureTerminator(flowOp.getPorts(), rewriter,
259 Op->getLoc());
260 OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
261 rewriter.setInsertionPoint(flowOp.getPorts().front().getTerminator());
262 PacketSourceOp::create(rewriter, Op->getLoc(), Op->getResult(0),
263 bundle, i);
264 PacketDestOp::create(rewriter, Op->getLoc(), destOp->getResult(0),
265 destPort.bundle, destPort.channel);
266 rewriter.restoreInsertionPoint(ip);
267 }
268 }
269 }
270 }
271}
272
274 : public xilinx::AIE::impl::AIEFindFlowsBase<AIEFindFlowsPass> {
275 void getDependentDialects(DialectRegistry &registry) const override {
276 registry.insert<func::FuncDialect>();
277 registry.insert<AIEDialect>();
278 }
279 void runOnOperation() override {
280
281 DeviceOp d = getOperation();
282 ConnectivityAnalysis analysis(d);
283 d.getTargetModel().validate();
284
285 OpBuilder builder = OpBuilder::atBlockTerminator(d.getBody());
286 for (auto tile : d.getOps<TileOp>()) {
287 findFlowsFrom(tile, analysis, builder);
288 }
289 }
290};
291
292std::unique_ptr<OperationPass<DeviceOp>> AIE::createAIEFindFlowsPass() {
293 return std::make_unique<AIEFindFlowsPass>();
294}
std::vector< PacketConnection > getConnectedTiles(TileOp tileOp, Port port) const
ConnectivityAnalysis(DeviceOp &d)
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< DeviceOp > > createAIEFindFlowsPass()
Port { WireBundle bundle Port
Definition AIEDialect.h:128
void getDependentDialects(DialectRegistry &registry) const override
void runOnOperation() override
PortConnection portConnection
Operation * op