MLIR-AIE
AIEFindFlows.cpp
Go to the documentation of this file.
1//===- AIEFindFlows.cpp -----------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2019 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10
13
14#include "mlir/IR/IRMapping.h"
15#include "mlir/Pass/Pass.h"
16
17#define DEBUG_TYPE "aie-find-flows"
18
19using namespace mlir;
20using namespace xilinx;
21using namespace xilinx::AIE;
22
23typedef struct MaskValue {
24 int mask;
25 int value;
27
28typedef struct PortConnection {
29 Operation *op;
32
37
42
44 DeviceOp &device;
45
46public:
47 ConnectivityAnalysis(DeviceOp &d) : device(d) {}
48
49private:
50 std::optional<PortConnection>
51 getConnectionThroughWire(Operation *op, Port masterPort) const {
52 LLVM_DEBUG(llvm::dbgs() << "Wire:" << *op << " "
53 << stringifyWireBundle(masterPort.bundle) << " "
54 << masterPort.channel << "\n");
55 for (auto wireOp : device.getOps<WireOp>()) {
56 if (wireOp.getSource().getDefiningOp() == op &&
57 wireOp.getSourceBundle() == masterPort.bundle) {
58 Operation *other = wireOp.getDest().getDefiningOp();
59 Port otherPort = {wireOp.getDestBundle(), masterPort.channel};
60 LLVM_DEBUG(llvm::dbgs() << "Connects To:" << *other << " "
61 << stringifyWireBundle(otherPort.bundle) << " "
62 << otherPort.channel << "\n");
63
64 return PortConnection{other, otherPort};
65 }
66 if (wireOp.getDest().getDefiningOp() == op &&
67 wireOp.getDestBundle() == masterPort.bundle) {
68 Operation *other = wireOp.getSource().getDefiningOp();
69 Port otherPort = {wireOp.getSourceBundle(), masterPort.channel};
70 LLVM_DEBUG(llvm::dbgs() << "Connects To:" << *other << " "
71 << stringifyWireBundle(otherPort.bundle) << " "
72 << otherPort.channel << "\n");
73 return PortConnection{other, otherPort};
74 }
75 }
76 LLVM_DEBUG(llvm::dbgs() << "*** Missing Wire!\n");
77 return std::nullopt;
78 }
79
80 std::vector<PortMaskValue>
81 getConnectionsThroughSwitchbox(Region &r, Port sourcePort) const {
82 LLVM_DEBUG(llvm::dbgs() << "Switchbox:\n");
83 Block &b = r.front();
84 std::vector<PortMaskValue> portSet;
85 for (auto connectOp : b.getOps<ConnectOp>()) {
86 if (connectOp.sourcePort() == sourcePort) {
87 MaskValue maskValue = {0, 0};
88 portSet.push_back({connectOp.destPort(), maskValue});
89 LLVM_DEBUG(llvm::dbgs()
90 << "To:" << stringifyWireBundle(connectOp.destPort().bundle)
91 << " " << connectOp.destPort().channel << "\n");
92 }
93 }
94 for (auto connectOp : b.getOps<PacketRulesOp>()) {
95 if (connectOp.sourcePort() == sourcePort) {
96 LLVM_DEBUG(llvm::dbgs()
97 << "Packet From: "
98 << stringifyWireBundle(connectOp.sourcePort().bundle) << " "
99 << sourcePort.channel << "\n");
100 for (auto masterSetOp : b.getOps<MasterSetOp>())
101 for (Value amsel : masterSetOp.getAmsels())
102 for (auto ruleOp :
103 connectOp.getRules().front().getOps<PacketRuleOp>()) {
104 if (ruleOp.getAmsel() == amsel) {
105 LLVM_DEBUG(llvm::dbgs()
106 << "To:"
107 << stringifyWireBundle(masterSetOp.destPort().bundle)
108 << " " << masterSetOp.destPort().channel << "\n");
109 MaskValue maskValue = {ruleOp.maskInt(), ruleOp.valueInt()};
110 portSet.push_back({masterSetOp.destPort(), maskValue});
111 }
112 }
113 }
114 }
115 return portSet;
116 }
117
118 std::vector<PacketConnection>
119 maskSwitchboxConnections(Operation *switchOp,
120 std::vector<PortMaskValue> nextPortMaskValues,
121 MaskValue maskValue) const {
122 std::vector<PacketConnection> worklist;
123 for (auto &nextPortMaskValue : nextPortMaskValues) {
124 Port nextPort = nextPortMaskValue.port;
125 MaskValue nextMaskValue = nextPortMaskValue.mv;
126 int maskConflicts = nextMaskValue.mask & maskValue.mask;
127 LLVM_DEBUG(llvm::dbgs() << "Mask: " << maskValue.mask << " "
128 << maskValue.value << "\n");
129 LLVM_DEBUG(llvm::dbgs() << "NextMask: " << nextMaskValue.mask << " "
130 << nextMaskValue.value << "\n");
131 LLVM_DEBUG(llvm::dbgs() << maskConflicts << "\n");
132
133 if ((maskConflicts & nextMaskValue.value) !=
134 (maskConflicts & maskValue.value)) {
135 // Incoming packets cannot match this rule. Skip it.
136 continue;
137 }
138 MaskValue newMaskValue = {maskValue.mask | nextMaskValue.mask,
139 maskValue.value |
140 (nextMaskValue.mask & nextMaskValue.value)};
141 auto nextConnection = getConnectionThroughWire(switchOp, nextPort);
142
143 // If there is no wire to follow then bail out.
144 if (!nextConnection)
145 continue;
146
147 worklist.push_back({*nextConnection, newMaskValue});
148 }
149 return worklist;
150 }
151
152public:
153 // Get the tiles connected to the given tile, starting from the given
154 // output port of the tile. This is 1:N relationship because each
155 // switchbox can broadcast.
156 std::vector<PacketConnection> getConnectedTiles(TileOp tileOp,
157 Port port) const {
158
159 LLVM_DEBUG(llvm::dbgs()
160 << "getConnectedTile(" << stringifyWireBundle(port.bundle) << " "
161 << port.channel << ")");
162 LLVM_DEBUG(tileOp.dump());
163
164 // The accumulated result;
165 std::vector<PacketConnection> connectedTiles;
166 // A worklist of PortConnections to visit. These are all input ports of
167 // some object (likely either a TileOp or a SwitchboxOp).
168 std::vector<PacketConnection> worklist;
169 // Start the worklist by traversing from the tile to its connected
170 // switchbox.
171 auto t = getConnectionThroughWire(tileOp.getOperation(), port);
172
173 // If there is no wire to traverse, then just return no connection
174 if (!t)
175 return connectedTiles;
176 worklist.push_back({*t, {0, 0}});
177
178 while (!worklist.empty()) {
179 PacketConnection t = worklist.back();
180 worklist.pop_back();
181 PortConnection portConnection = t.portConnection;
182 MaskValue maskValue = t.mv;
183 Operation *other = portConnection.op;
184 Port otherPort = portConnection.port;
185 if (isa<FlowEndPoint>(other)) {
186 // If we got to a tile, then add it to the result.
187 connectedTiles.push_back(t);
188 } else if (auto switchOp = dyn_cast_or_null<SwitchboxOp>(other)) {
189 std::vector<PortMaskValue> nextPortMaskValues =
190 getConnectionsThroughSwitchbox(switchOp.getConnections(),
191 otherPort);
192 std::vector<PacketConnection> newWorkList =
193 maskSwitchboxConnections(switchOp, nextPortMaskValues, maskValue);
194 // append to the worklist
195 worklist.insert(worklist.end(), newWorkList.begin(), newWorkList.end());
196 if (!nextPortMaskValues.empty() && newWorkList.empty()) {
197 // No rule matched some incoming packet. This is likely a
198 // configuration error.
199 LLVM_DEBUG(llvm::dbgs() << "No rule matched incoming packet here: ");
200 LLVM_DEBUG(other->dump());
201 }
202 } else if (auto switchOp = dyn_cast_or_null<ShimMuxOp>(other)) {
203 std::vector<PortMaskValue> nextPortMaskValues =
204 getConnectionsThroughSwitchbox(switchOp.getConnections(),
205 otherPort);
206 std::vector<PacketConnection> newWorkList =
207 maskSwitchboxConnections(switchOp, nextPortMaskValues, maskValue);
208 // append to the worklist
209 worklist.insert(worklist.end(), newWorkList.begin(), newWorkList.end());
210 if (!nextPortMaskValues.empty() && newWorkList.empty()) {
211 // No rule matched some incoming packet. This is likely a
212 // configuration error.
213 LLVM_DEBUG(llvm::dbgs() << "No rule matched incoming packet here: ");
214 LLVM_DEBUG(other->dump());
215 }
216 } else {
217 LLVM_DEBUG(llvm::dbgs()
218 << "*** Connection Terminated at unknown operation: ");
219 LLVM_DEBUG(other->dump());
220 }
221 }
222 return connectedTiles;
223 }
224};
225
226static void findFlowsFrom(TileOp op, ConnectivityAnalysis &analysis,
227 OpBuilder &rewriter) {
228 Operation *Op = op.getOperation();
229 rewriter.setInsertionPoint(Op->getBlock()->getTerminator());
230
231 std::vector bundles = {WireBundle::Core, WireBundle::DMA};
232 for (WireBundle bundle : bundles) {
233 LLVM_DEBUG(llvm::dbgs()
234 << op << stringifyWireBundle(bundle) << " has "
235 << op.getNumSourceConnections(bundle) << " Connections\n");
236 for (size_t i = 0; i < op.getNumSourceConnections(bundle); i++) {
237 std::vector<PacketConnection> tiles =
238 analysis.getConnectedTiles(op, {bundle, (int)i});
239 LLVM_DEBUG(llvm::dbgs() << tiles.size() << " Flows\n");
240
241 for (PacketConnection &c : tiles) {
242 PortConnection portConnection = c.portConnection;
243 MaskValue maskValue = c.mv;
244 Operation *destOp = portConnection.op;
245 Port destPort = portConnection.port;
246 if (maskValue.mask == 0) {
247 rewriter.create<FlowOp>(Op->getLoc(), Op->getResult(0), bundle, i,
248 destOp->getResult(0), destPort.bundle,
249 destPort.channel);
250 } else {
251 auto flowOp = rewriter.create<PacketFlowOp>(
252 Op->getLoc(), maskValue.value, nullptr, nullptr);
253 PacketFlowOp::ensureTerminator(flowOp.getPorts(), rewriter,
254 Op->getLoc());
255 OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
256 rewriter.setInsertionPoint(flowOp.getPorts().front().getTerminator());
257 rewriter.create<PacketSourceOp>(Op->getLoc(), Op->getResult(0),
258 bundle, i);
259 rewriter.create<PacketDestOp>(Op->getLoc(), destOp->getResult(0),
260 destPort.bundle, destPort.channel);
261 rewriter.restoreInsertionPoint(ip);
262 }
263 }
264 }
265 }
266}
267
268struct AIEFindFlowsPass : public AIEFindFlowsBase<AIEFindFlowsPass> {
269 void getDependentDialects(DialectRegistry &registry) const override {
270 registry.insert<func::FuncDialect>();
271 registry.insert<AIEDialect>();
272 }
273 void runOnOperation() override {
274
275 DeviceOp d = getOperation();
276 ConnectivityAnalysis analysis(d);
277 d.getTargetModel().validate();
278
279 OpBuilder builder = OpBuilder::atBlockTerminator(d.getBody());
280 for (auto tile : d.getOps<TileOp>()) {
281 findFlowsFrom(tile, analysis, builder);
282 }
283 }
284};
285
286std::unique_ptr<OperationPass<DeviceOp>> AIE::createAIEFindFlowsPass() {
287 return std::make_unique<AIEFindFlowsPass>();
288}
std::vector< PacketConnection > getConnectedTiles(TileOp tileOp, Port port) const
ConnectivityAnalysis(DeviceOp &d)
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< DeviceOp > > createAIEFindFlowsPass()
Port { WireBundle bundle Port
Definition AIEDialect.h:118
void getDependentDialects(DialectRegistry &registry) const override
void runOnOperation() override
PortConnection portConnection
Operation * op