MLIR-AIE
AIEHerdRouting.cpp
Go to the documentation of this file.
1//===- AIEHerdRouting.cpp ---------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2019 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10
14
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
18#include "mlir/Transforms/DialectConversion.h"
19
20namespace xilinx::AIEX {
21#define GEN_PASS_DEF_AIEHERDROUTING
22#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
23} // namespace xilinx::AIEX
24
25#define DEBUG_TYPE "aie-herd-routing"
26
27using namespace mlir;
28using namespace xilinx;
29using namespace xilinx::AIE;
30using namespace xilinx::AIEX;
31
32template <typename MyOp>
33struct AIEOpRemoval : OpConversionPattern<MyOp> {
35 using OpAdaptor = typename MyOp::Adaptor;
36
37 explicit AIEOpRemoval(MLIRContext *context, PatternBenefit benefit = 1)
38 : OpConversionPattern<MyOp>(context, benefit) {}
39
40 LogicalResult
41 matchAndRewrite(MyOp op, OpAdaptor operands,
42 ConversionPatternRewriter &rewriter) const override {
43 Operation *Op = op.getOperation();
44
45 rewriter.eraseOp(Op);
46 return success();
47 }
48};
49
50std::optional<int> getAvailableDestChannel(SmallVector<Connect, 8> &connects,
51 Port sourcePort,
52 WireBundle destBundle) {
53
54 if (connects.empty())
55 return {0};
56
57 int numChannels;
58
59 if (destBundle == WireBundle::North)
60 numChannels = 6;
61 else if (destBundle == WireBundle::South || destBundle == WireBundle::East ||
62 destBundle == WireBundle::West)
63 numChannels = 4;
64 else
65 numChannels = 2;
66
67 // look for existing connect
68 for (int i = 0; i < numChannels; i++) {
69 if (Port port = {destBundle, i};
70 std::find(connects.begin(), connects.end(),
71 Connect{sourcePort, port}) != connects.end())
72 return {i};
73 }
74
75 // if not, look for available destination port
76 for (int i = 0; i < numChannels; i++) {
77 Port port = {destBundle, i};
78 SmallVector<Port, 8> ports;
79 for (auto [src, dst] : connects)
80 ports.push_back(dst);
81
82 if (std::find(ports.begin(), ports.end(), port) == ports.end())
83 return {i};
84 }
85
86 return std::nullopt;
87}
88
89void buildRoute(int xSrc, int ySrc, int xDest, int yDest,
90 WireBundle sourceBundle, int sourceChannel,
91 WireBundle destBundle, int destChannel, Operation *herdOp,
92 DenseMap<std::pair<Operation *, TileID>,
93 SmallVector<Connect, 8>> &switchboxes) {
94
95 int xCur = xSrc;
96 int yCur = ySrc;
97 WireBundle curBundle = WireBundle::Core;
98 int curChannel = 0;
99 WireBundle lastBundle = WireBundle::Core;
100 Port lastPort = {sourceBundle, sourceChannel};
101
102 SmallVector<TileID, 4> congestion;
103
104 LLVM_DEBUG(llvm::dbgs() << "Build route: " << xSrc << " " << ySrc << " --> "
105 << xDest << " " << yDest << '\n');
106 // traverse horizontally, then vertically
107 while (!(xCur == xDest && yCur == yDest)) {
108 LLVM_DEBUG(llvm::dbgs() << "coord " << xCur << " " << yCur << '\n');
109 TileID curCoord = {xCur, yCur};
110 SmallVector<WireBundle, 4> moves;
111
112 if (xCur < xDest)
113 moves.push_back(WireBundle::East);
114 if (xCur > xDest)
115 moves.push_back(WireBundle::West);
116 if (yCur < yDest)
117 moves.push_back(WireBundle::North);
118 if (yCur > yDest)
119 moves.push_back(WireBundle::South);
120
121 if (std::find(moves.begin(), moves.end(), WireBundle::East) == moves.end())
122 moves.push_back(WireBundle::East);
123 if (std::find(moves.begin(), moves.end(), WireBundle::West) == moves.end())
124 moves.push_back(WireBundle::West);
125 if (std::find(moves.begin(), moves.end(), WireBundle::North) == moves.end())
126 moves.push_back(WireBundle::North);
127 if (std::find(moves.begin(), moves.end(), WireBundle::South) == moves.end())
128 moves.push_back(WireBundle::South);
129
130 for (auto move : moves) {
131 if (auto maybeDestChannel = getAvailableDestChannel(
132 switchboxes[std::make_pair(herdOp, curCoord)], lastPort, move))
133 curChannel = maybeDestChannel.value();
134 else
135 continue;
136
137 if (move == lastBundle)
138 continue;
139
140 if (move == WireBundle::East)
141 xCur = xCur + 1;
142 // yCur = yCur;
143 else if (move == WireBundle::West)
144 xCur = xCur - 1;
145 // yCur = yCur;
146 else if (move == WireBundle::North)
147 // xCur = xCur;
148 yCur = yCur + 1;
149 else if (move == WireBundle::South)
150 // xCur = xCur;
151 yCur = yCur - 1;
152
153 if (std::find(congestion.begin(), congestion.end(), TileID{xCur, yCur}) !=
154 congestion.end())
155 continue;
156
157 curBundle = move;
158 lastBundle = move == WireBundle::East ? WireBundle::West
159 : move == WireBundle::West ? WireBundle::East
160 : move == WireBundle::North ? WireBundle::South
161 : move == WireBundle::South ? WireBundle::North
162 : lastBundle;
163 break;
164 }
165
166 assert(curChannel >= 0 && "Could not find available destination port!");
167 LLVM_DEBUG(llvm::dbgs()
168 << "[" << stringifyWireBundle(lastPort.bundle) << " : "
169 << lastPort.channel << "], [" << stringifyWireBundle(curBundle)
170 << " : " << curChannel << "]\n");
171
172 Port curPort = {curBundle, curChannel};
173 Connect connect = {lastPort, curPort};
174 if (std::find(switchboxes[std::make_pair(herdOp, curCoord)].begin(),
175 switchboxes[std::make_pair(herdOp, curCoord)].end(),
176 connect) ==
177 switchboxes[std::make_pair(herdOp, curCoord)].end())
178 switchboxes[std::make_pair(herdOp, curCoord)].push_back(connect);
179 lastPort = {lastBundle, curChannel};
180 }
181
182 LLVM_DEBUG(llvm::dbgs() << "coord " << xCur << " " << yCur << '\n');
183 LLVM_DEBUG(llvm::dbgs() << "[" << stringifyWireBundle(lastPort.bundle)
184 << " : " << lastPort.channel << "], ["
185 << stringifyWireBundle(destBundle) << " : "
186 << destChannel << "]\n");
187
188 switchboxes[std::make_pair(herdOp, TileID{xCur, yCur})].push_back(
189 {lastPort, Port{destBundle, destChannel}});
190}
191
193 : xilinx::AIEX::impl::AIEHerdRoutingBase<AIEHerdRoutingPass> {
194 void runOnOperation() override {
195
196 DeviceOp device = getOperation();
197 OpBuilder builder(device.getBody()->getTerminator());
198
199 SmallVector<HerdOp, 4> herds;
200 SmallVector<Operation *, 4> placeOps;
201 SmallVector<Operation *, 4> routeOps;
202 DenseMap<std::pair<Operation *, Operation *>, std::pair<int, int>>
203 distances;
204 SmallVector<std::pair<std::pair<int, int>, std::pair<int, int>>, 4> routes;
205 DenseMap<std::pair<Operation *, TileID>, SmallVector<Connect, 8>>
206 switchboxes;
207
208 for (auto herd : device.getOps<HerdOp>())
209 herds.push_back(herd);
210
211 for (auto placeOp : device.getOps<PlaceOp>()) {
212 placeOps.push_back(placeOp);
213 Operation *sourceHerd = placeOp.getSourceHerd().getDefiningOp();
214 Operation *destHerd = placeOp.getDestHerd().getDefiningOp();
215 int distX = placeOp.getDistXValue();
216 int distY = placeOp.getDistYValue();
217 distances[std::make_pair(sourceHerd, destHerd)] =
218 std::make_pair(distX, distY);
219 }
220
221 // FIXME: multiple route ops with different sourceHerds does not seem to be
222 // aware of the routes done before
223 for (auto routeOp : device.getOps<RouteOp>()) {
224 routeOps.push_back(routeOp);
225
226 auto sourceHerds =
227 dyn_cast<SelectOp>(routeOp.getSourceHerds().getDefiningOp());
228 auto destHerds =
229 dyn_cast<SelectOp>(routeOp.getDestHerds().getDefiningOp());
230 WireBundle sourceBundle = routeOp.getSourceBundle();
231 WireBundle destBundle = routeOp.getDestBundle();
232 int sourceChannel = routeOp.getSourceChannelValue();
233 int destChannel = routeOp.getDestChannelValue();
234
235 HerdOp sourceHerd =
236 dyn_cast<HerdOp>(sourceHerds.getStartHerd().getDefiningOp());
237 IterOp sourceIterX =
238 dyn_cast<IterOp>(sourceHerds.getIterX().getDefiningOp());
239 IterOp sourceIterY =
240 dyn_cast<IterOp>(sourceHerds.getIterY().getDefiningOp());
241
242 HerdOp destHerd =
243 dyn_cast<HerdOp>(destHerds.getStartHerd().getDefiningOp());
244 IterOp destIterX = dyn_cast<IterOp>(destHerds.getIterX().getDefiningOp());
245 IterOp destIterY = dyn_cast<IterOp>(destHerds.getIterY().getDefiningOp());
246
247 int sourceStartX = sourceIterX.getStartValue();
248 int sourceEndX = sourceIterX.getEndValue();
249 int sourceStrideX = sourceIterX.getStrideValue();
250 int sourceStartY = sourceIterY.getStartValue();
251 int sourceEndY = sourceIterY.getEndValue();
252 int sourceStrideY = sourceIterY.getStrideValue();
253
254 int destStartX = destIterX.getStartValue();
255 int destEndX = destIterX.getEndValue();
256 int destStrideX = destIterX.getStrideValue();
257 int destStartY = destIterY.getStartValue();
258 int destEndY = destIterY.getEndValue();
259 int destStrideY = destIterY.getStrideValue();
260
261 assert(distances.count(std::make_pair(sourceHerd, destHerd)) == 1);
262
263 auto [distX, distY] = distances[std::make_pair(sourceHerd, destHerd)];
264 // FIXME: this looks like it can be improved further ...
265 for (int xSrc = sourceStartX; xSrc < sourceEndX; xSrc += sourceStrideX)
266 for (int ySrc = sourceStartY; ySrc < sourceEndY; ySrc += sourceStrideY)
267 for (int xDst = destStartX; xDst < destEndX; xDst += destStrideX)
268 for (int yDst = destStartY; yDst < destEndY; yDst += destStrideY) {
269 // Build route (x0, y0) --> (x1, y1)
270 int x0 = xSrc;
271 int y0 = ySrc;
272 int x1 = xDst;
273 int y1 = yDst;
274 if (destIterX == sourceIterX)
275 x1 = x0;
276 if (destIterY == sourceIterX)
277 y1 = x0;
278 if (destIterX == sourceIterY)
279 x1 = y0;
280 if (destIterY == sourceIterY)
281 y1 = y0;
282
283 auto route = std::make_pair(
284 std::make_pair(x0, y0),
285 std::make_pair(distX + x1 - x0, distY + y1 - y0));
286 if (std::find(routes.begin(), routes.end(), route) !=
287 routes.end())
288 continue;
289
290 buildRoute(x0, y0, x1 + distX, y1 + distY, sourceBundle,
291 sourceChannel, destBundle, destChannel, sourceHerd,
292 switchboxes);
293
294 routes.push_back(route);
295 }
296 }
297
298 for (const auto &swboxCfg : switchboxes) {
299 Operation *herdOp = swboxCfg.first.first;
300 int x = swboxCfg.first.second.col;
301 int y = swboxCfg.first.second.row;
302 auto connects = swboxCfg.second;
303 HerdOp herd = dyn_cast<HerdOp>(herdOp);
304
305 builder.setInsertionPoint(device.getBody()->getTerminator());
306
307 auto iterx =
308 IterOp::create(builder, builder.getUnknownLoc(), x, x + 1, 1);
309 auto itery =
310 IterOp::create(builder, builder.getUnknownLoc(), y, y + 1, 1);
311 auto sel = SelectOp::create(builder, builder.getUnknownLoc(), herd, iterx,
312 itery);
313 auto swbox = SwitchboxOp::create(builder, builder.getUnknownLoc(), sel);
314 SwitchboxOp::ensureTerminator(swbox.getConnections(), builder,
315 builder.getUnknownLoc());
316 Block &b = swbox.getConnections().front();
317 builder.setInsertionPoint(b.getTerminator());
318
319 for (auto [sourcePort, destPort] : connects) {
320 WireBundle sourceBundle = sourcePort.bundle;
321 int sourceChannel = sourcePort.channel;
322 WireBundle destBundle = destPort.bundle;
323 int destChannel = destPort.channel;
324
325 ConnectOp::create(builder, builder.getUnknownLoc(), sourceBundle,
326 sourceChannel, destBundle, destChannel);
327 }
328 }
329
330 ConversionTarget target(getContext());
331
332 RewritePatternSet patterns(&getContext());
334 device.getContext());
335
336 if (failed(applyPartialConversion(device, target, std::move(patterns))))
337 signalPassFailure();
338 }
339};
340
341std::unique_ptr<OperationPass<DeviceOp>> AIEX::createAIEHerdRoutingPass() {
342 return std::make_unique<AIEHerdRoutingPass>();
343}
std::optional< int > getAvailableDestChannel(SmallVector< Connect, 8 > &connects, Port sourcePort, WireBundle destBundle)
void buildRoute(int xSrc, int ySrc, int xDest, int yDest, WireBundle sourceBundle, int sourceChannel, WireBundle destBundle, int destChannel, Operation *herdOp, DenseMap< std::pair< Operation *, TileID >, SmallVector< Connect, 8 > > &switchboxes)
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIEHerdRoutingPass()
Include the generated interface declarations.
Connect { Port src Connect
Definition AIEDialect.h:158
TileID { friend std::ostream &operator<<(std::ostream &os, const TileID &s) { os<< "TileID("<< s.col<< ", "<< s.row<< ")" TileID
Port { WireBundle bundle Port
Definition AIEDialect.h:128
PathEndPoint src
void runOnOperation() override
LogicalResult matchAndRewrite(MyOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override
AIEOpRemoval(MLIRContext *context, PatternBenefit benefit=1)
typename MyAIEOp::Adaptor OpAdaptor