MLIR-AIE
AIEHerdRouting.cpp
Go to the documentation of this file.
1//===- AIEHerdRouting.cpp ---------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2019 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10
14
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
#include "mlir/Transforms/DialectConversion.h"
19
20#define DEBUG_TYPE "aie-herd-routing"
21
22using namespace mlir;
23using namespace xilinx;
24using namespace xilinx::AIE;
25using namespace xilinx::AIEX;
26
// Conversion pattern that erases every MyOp it is applied to and always
// reports success. It creates no replacement ops and never reads the
// converted `operands`.
27template <typename MyOp> struct AIEOpRemoval : OpConversionPattern<MyOp> {
29 using OpAdaptor = typename MyOp::Adaptor;
30
// `context` and `benefit` (default 1) are forwarded unchanged to
// OpConversionPattern.
31 explicit AIEOpRemoval(MLIRContext *context, PatternBenefit benefit = 1)
32 : OpConversionPattern<MyOp>(context, benefit) {}
33
// Erase `op` through the ConversionPatternRewriter and report success.
34 LogicalResult
35 matchAndRewrite(MyOp op, OpAdaptor operands,
36 ConversionPatternRewriter &rewriter) const override {
37 Operation *Op = op.getOperation();
38
39 rewriter.eraseOp(Op);
40 return success();
41 }
42};
43
44std::optional<int> getAvailableDestChannel(SmallVector<Connect, 8> &connects,
45 Port sourcePort,
46 WireBundle destBundle) {
47
48 if (connects.empty())
49 return {0};
50
51 int numChannels;
52
53 if (destBundle == WireBundle::North)
54 numChannels = 6;
55 else if (destBundle == WireBundle::South || destBundle == WireBundle::East ||
56 destBundle == WireBundle::West)
57 numChannels = 4;
58 else
59 numChannels = 2;
60
61 // look for existing connect
62 for (int i = 0; i < numChannels; i++) {
63 if (Port port = {destBundle, i};
64 std::find(connects.begin(), connects.end(),
65 Connect{sourcePort, port}) != connects.end())
66 return {i};
67 }
68
69 // if not, look for available destination port
70 for (int i = 0; i < numChannels; i++) {
71 Port port = {destBundle, i};
72 SmallVector<Port, 8> ports;
73 for (auto [src, dst] : connects)
74 ports.push_back(dst);
75
76 if (std::find(ports.begin(), ports.end(), port) == ports.end())
77 return {i};
78 }
79
80 return std::nullopt;
81}
82
83void buildRoute(int xSrc, int ySrc, int xDest, int yDest,
84 WireBundle sourceBundle, int sourceChannel,
85 WireBundle destBundle, int destChannel, Operation *herdOp,
86 DenseMap<std::pair<Operation *, TileID>,
87 SmallVector<Connect, 8>> &switchboxes) {
88
89 int xCur = xSrc;
90 int yCur = ySrc;
91 WireBundle curBundle = WireBundle::Core;
92 int curChannel = 0;
93 WireBundle lastBundle = WireBundle::Core;
94 Port lastPort = {sourceBundle, sourceChannel};
95
96 SmallVector<TileID, 4> congestion;
97
98 LLVM_DEBUG(llvm::dbgs() << "Build route: " << xSrc << " " << ySrc << " --> "
99 << xDest << " " << yDest << '\n');
100 // traverse horizontally, then vertically
101 while (!(xCur == xDest && yCur == yDest)) {
102 LLVM_DEBUG(llvm::dbgs() << "coord " << xCur << " " << yCur << '\n');
103 TileID curCoord = {xCur, yCur};
104 SmallVector<WireBundle, 4> moves;
105
106 if (xCur < xDest)
107 moves.push_back(WireBundle::East);
108 if (xCur > xDest)
109 moves.push_back(WireBundle::West);
110 if (yCur < yDest)
111 moves.push_back(WireBundle::North);
112 if (yCur > yDest)
113 moves.push_back(WireBundle::South);
114
115 if (std::find(moves.begin(), moves.end(), WireBundle::East) == moves.end())
116 moves.push_back(WireBundle::East);
117 if (std::find(moves.begin(), moves.end(), WireBundle::West) == moves.end())
118 moves.push_back(WireBundle::West);
119 if (std::find(moves.begin(), moves.end(), WireBundle::North) == moves.end())
120 moves.push_back(WireBundle::North);
121 if (std::find(moves.begin(), moves.end(), WireBundle::South) == moves.end())
122 moves.push_back(WireBundle::South);
123
124 for (auto move : moves) {
125 if (auto maybeDestChannel = getAvailableDestChannel(
126 switchboxes[std::make_pair(herdOp, curCoord)], lastPort, move))
127 curChannel = maybeDestChannel.value();
128 else
129 continue;
130
131 if (move == lastBundle)
132 continue;
133
134 if (move == WireBundle::East)
135 xCur = xCur + 1;
136 // yCur = yCur;
137 else if (move == WireBundle::West)
138 xCur = xCur - 1;
139 // yCur = yCur;
140 else if (move == WireBundle::North)
141 // xCur = xCur;
142 yCur = yCur + 1;
143 else if (move == WireBundle::South)
144 // xCur = xCur;
145 yCur = yCur - 1;
146
147 if (std::find(congestion.begin(), congestion.end(), TileID{xCur, yCur}) !=
148 congestion.end())
149 continue;
150
151 curBundle = move;
152 lastBundle = move == WireBundle::East ? WireBundle::West
153 : move == WireBundle::West ? WireBundle::East
154 : move == WireBundle::North ? WireBundle::South
155 : move == WireBundle::South ? WireBundle::North
156 : lastBundle;
157 break;
158 }
159
160 assert(curChannel >= 0 && "Could not find available destination port!");
161 LLVM_DEBUG(llvm::dbgs()
162 << "[" << stringifyWireBundle(lastPort.bundle) << " : "
163 << lastPort.channel << "], [" << stringifyWireBundle(curBundle)
164 << " : " << curChannel << "]\n");
165
166 Port curPort = {curBundle, curChannel};
167 Connect connect = {lastPort, curPort};
168 if (std::find(switchboxes[std::make_pair(herdOp, curCoord)].begin(),
169 switchboxes[std::make_pair(herdOp, curCoord)].end(),
170 connect) ==
171 switchboxes[std::make_pair(herdOp, curCoord)].end())
172 switchboxes[std::make_pair(herdOp, curCoord)].push_back(connect);
173 lastPort = {lastBundle, curChannel};
174 }
175
176 LLVM_DEBUG(llvm::dbgs() << "coord " << xCur << " " << yCur << '\n');
177 LLVM_DEBUG(llvm::dbgs() << "[" << stringifyWireBundle(lastPort.bundle)
178 << " : " << lastPort.channel << "], ["
179 << stringifyWireBundle(destBundle) << " : "
180 << destChannel << "]\n");
181
182 switchboxes[std::make_pair(herdOp, TileID{xCur, yCur})].push_back(
183 {lastPort, Port{destBundle, destChannel}});
184}
185
// Pass that routes herd-to-herd connections described by AIEX PlaceOp and
// RouteOp ops. It proceeds in four steps:
//  1. Record each PlaceOp's (distX, distY) offset for its (source herd,
//     destination herd) pair.
//  2. Expand each RouteOp's source/destination SelectOps into tile pairs and
//     compute the per-tile switchbox connections with buildRoute.
//  3. Materialize those connections as SwitchboxOp/ConnectOp ops before the
//     device terminator.
//  4. Apply a partial conversion over the device with the patterns collected
//     in `patterns`.
186struct AIEHerdRoutingPass : AIEHerdRoutingBase<AIEHerdRoutingPass> {
187 void runOnOperation() override {
188
189 DeviceOp device = getOperation();
190 OpBuilder builder(device.getBody()->getTerminator());
191
// NOTE(review): herds, placeOps and routeOps are filled below but are not
// read anywhere else in this function.
192 SmallVector<HerdOp, 4> herds;
193 SmallVector<Operation *, 4> placeOps;
194 SmallVector<Operation *, 4> routeOps;
// (source herd op, destination herd op) -> (distX, distY) taken from PlaceOps.
195 DenseMap<std::pair<Operation *, Operation *>, std::pair<int, int>>
196 distances;
// Routes already built, as ((x0, y0), (dx, dy)): the source tile and the
// offset to the destination tile. Used to skip duplicate routes.
197 SmallVector<std::pair<std::pair<int, int>, std::pair<int, int>>, 4> routes;
// (herd op, tile) -> connections to emit in that tile's switchbox.
198 DenseMap<std::pair<Operation *, TileID>, SmallVector<Connect, 8>>
199 switchboxes;
200
201 for (auto herd : device.getOps<HerdOp>())
202 herds.push_back(herd);
203
204 for (auto placeOp : device.getOps<PlaceOp>()) {
205 placeOps.push_back(placeOp);
206 Operation *sourceHerd = placeOp.getSourceHerd().getDefiningOp();
207 Operation *destHerd = placeOp.getDestHerd().getDefiningOp();
208 int distX = placeOp.getDistXValue();
209 int distY = placeOp.getDistYValue();
210 distances[std::make_pair(sourceHerd, destHerd)] =
211 std::make_pair(distX, distY);
212 }
213
214 // FIXME: multiple route ops with different sourceHerds does not seem to be
215 // aware of the routes done before
216 for (auto routeOp : device.getOps<RouteOp>()) {
217 routeOps.push_back(routeOp);
218
// NOTE(review): the dyn_cast results below are used without null checks. If a
// defining op is missing or has a different type, this dereferences a null
// op. cast<> or an explicit check with emitError would be safer.
219 auto sourceHerds =
220 dyn_cast<SelectOp>(routeOp.getSourceHerds().getDefiningOp());
221 auto destHerds =
222 dyn_cast<SelectOp>(routeOp.getDestHerds().getDefiningOp());
223 WireBundle sourceBundle = routeOp.getSourceBundle();
224 WireBundle destBundle = routeOp.getDestBundle();
225 int sourceChannel = routeOp.getSourceChannelValue();
226 int destChannel = routeOp.getDestChannelValue();
227
228 HerdOp sourceHerd =
229 dyn_cast<HerdOp>(sourceHerds.getStartHerd().getDefiningOp());
230 IterOp sourceIterX =
231 dyn_cast<IterOp>(sourceHerds.getIterX().getDefiningOp());
232 IterOp sourceIterY =
233 dyn_cast<IterOp>(sourceHerds.getIterY().getDefiningOp());
234
235 HerdOp destHerd =
236 dyn_cast<HerdOp>(destHerds.getStartHerd().getDefiningOp());
237 IterOp destIterX = dyn_cast<IterOp>(destHerds.getIterX().getDefiningOp());
238 IterOp destIterY = dyn_cast<IterOp>(destHerds.getIterY().getDefiningOp());
239
// Half-open [start, end) ranges with a stride, as used by the loops below.
240 int sourceStartX = sourceIterX.getStartValue();
241 int sourceEndX = sourceIterX.getEndValue();
242 int sourceStrideX = sourceIterX.getStrideValue();
243 int sourceStartY = sourceIterY.getStartValue();
244 int sourceEndY = sourceIterY.getEndValue();
245 int sourceStrideY = sourceIterY.getStrideValue();
246
247 int destStartX = destIterX.getStartValue();
248 int destEndX = destIterX.getEndValue();
249 int destStrideX = destIterX.getStrideValue();
250 int destStartY = destIterY.getStartValue();
251 int destEndY = destIterY.getEndValue();
252 int destStrideY = destIterY.getStrideValue();
253
// NOTE(review): with assertions disabled, a missing PlaceOp for this pair is
// not reported. operator[] below then default-inserts and routes with a
// (0, 0) offset. Consider find() + emitError + signalPassFailure.
254 assert(distances.count(std::make_pair(sourceHerd, destHerd)) == 1);
255
256 auto [distX, distY] = distances[std::make_pair(sourceHerd, destHerd)];
257 // FIXME: this looks like it can be improved further ...
258 for (int xSrc = sourceStartX; xSrc < sourceEndX; xSrc += sourceStrideX)
259 for (int ySrc = sourceStartY; ySrc < sourceEndY; ySrc += sourceStrideY)
260 for (int xDst = destStartX; xDst < destEndX; xDst += destStrideX)
261 for (int yDst = destStartY; yDst < destEndY; yDst += destStrideY) {
262 // Build route (x0, y0) --> (x1, y1)
263 int x0 = xSrc;
264 int y0 = ySrc;
265 int x1 = xDst;
266 int y1 = yDst;
// If the destination selection reuses one of the source selection's
// IterOps, tie that destination coordinate to the matching source
// coordinate instead of iterating it independently.
267 if (destIterX == sourceIterX)
268 x1 = x0;
269 if (destIterY == sourceIterX)
270 y1 = x0;
271 if (destIterX == sourceIterY)
272 x1 = y0;
273 if (destIterY == sourceIterY)
274 y1 = y0;
275
// Key = (source tile, offset to destination tile (x1 + distX, y1 + distY)).
// NOTE(review): `routes` is shared across all RouteOps and the key omits
// herd, bundles and channels. A later RouteOp with the same geometry but a
// different bundle/channel is therefore skipped — confirm this is intended.
276 auto route = std::make_pair(
277 std::make_pair(x0, y0),
278 std::make_pair(distX + x1 - x0, distY + y1 - y0));
279 if (std::find(routes.begin(), routes.end(), route) !=
280 routes.end())
281 continue;
282
283 buildRoute(x0, y0, x1 + distX, y1 + distY, sourceBundle,
284 sourceChannel, destBundle, destChannel, sourceHerd,
285 switchboxes);
286
287 routes.push_back(route);
288 }
289 }
290
// For every recorded (herd, tile) switchbox, emit the following before the
// device terminator:
//  - single-tile IterOps [x, x+1) and [y, y+1) with stride 1;
//  - a SelectOp over the herd;
//  - a SwitchboxOp;
//  - one ConnectOp per recorded connection.
// NOTE(review): DenseMap iteration order is not guaranteed, and these keys
// contain pointers, so the order of emitted switchbox ops may vary between
// runs — confirm whether output order matters.
291 for (const auto &swboxCfg : switchboxes) {
292 Operation *herdOp = swboxCfg.first.first;
293 int x = swboxCfg.first.second.col;
294 int y = swboxCfg.first.second.row;
// NOTE(review): `auto` copies the connection vector; `const auto &` would
// avoid the copy.
295 auto connects = swboxCfg.second;
296 HerdOp herd = dyn_cast<HerdOp>(herdOp);
297
298 builder.setInsertionPoint(device.getBody()->getTerminator());
299
300 auto iterx = builder.create<IterOp>(builder.getUnknownLoc(), x, x + 1, 1);
301 auto itery = builder.create<IterOp>(builder.getUnknownLoc(), y, y + 1, 1);
302 auto sel =
303 builder.create<SelectOp>(builder.getUnknownLoc(), herd, iterx, itery);
304 auto swbox = builder.create<SwitchboxOp>(builder.getUnknownLoc(), sel);
305 SwitchboxOp::ensureTerminator(swbox.getConnections(), builder,
306 builder.getUnknownLoc());
// ConnectOps are inserted in front of the switchbox body's terminator.
307 Block &b = swbox.getConnections().front();
308 builder.setInsertionPoint(b.getTerminator());
309
310 for (auto [sourcePort, destPort] : connects) {
311 WireBundle sourceBundle = sourcePort.bundle;
312 int sourceChannel = sourcePort.channel;
313 WireBundle destBundle = destPort.bundle;
314 int destChannel = destPort.channel;
315
316 builder.create<ConnectOp>(builder.getUnknownLoc(), sourceBundle,
317 sourceChannel, destBundle, destChannel);
318 }
319 }
320
// Apply the registered patterns with a partial conversion over the device.
// A failed conversion fails the pass.
321 ConversionTarget target(getContext());
322
323 RewritePatternSet patterns(&getContext());
325 device.getContext());
326
327 if (failed(applyPartialConversion(device, target, std::move(patterns))))
328 signalPassFailure();
329 }
330};
331
332std::unique_ptr<OperationPass<DeviceOp>> AIEX::createAIEHerdRoutingPass() {
333 return std::make_unique<AIEHerdRoutingPass>();
334}
std::optional< int > getAvailableDestChannel(SmallVector< Connect, 8 > &connects, Port sourcePort, WireBundle destBundle)
void buildRoute(int xSrc, int ySrc, int xDest, int yDest, WireBundle sourceBundle, int sourceChannel, WireBundle destBundle, int destChannel, Operation *herdOp, DenseMap< std::pair< Operation *, TileID >, SmallVector< Connect, 8 > > &switchboxes)
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIEHerdRoutingPass()
Include the generated interface declarations.
Connect { Port src Connect
Definition AIEDialect.h:171
TileID { friend std::ostream &operator<<(std::ostream &os, const TileID &s) { os<< "TileID("<< s.col<< ", "<< s.row<< ")" TileID
Port { WireBundle bundle Port
Definition AIEDialect.h:118
PathEndPoint src
void runOnOperation() override
LogicalResult matchAndRewrite(MyOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override
AIEOpRemoval(MLIRContext *context, PatternBenefit benefit=1)
typename MyAIEOp::Adaptor OpAdaptor