14#include "llvm/Support/Debug.h"
15#include "llvm/Support/raw_os_ostream.h"
17#include "llvm/ADT/MapVector.h"
23#define DEBUG_TYPE "aie-pathfinder"
26 LLVM_DEBUG(llvm::dbgs() <<
"\t---Begin DynamicTileAnalysis Constructor---\n");
28 maxCol = device.getTargetModel().columns();
29 maxRow = device.getTargetModel().rows();
37 for (PacketFlowOp pktFlowOp : device.getOps<PacketFlowOp>()) {
38 Region &r = pktFlowOp.getPorts();
40 Port srcPort, dstPort;
41 TileOp srcTile, dstTile;
44 for (Operation &Op : b.getOperations()) {
45 if (
auto pktSource = dyn_cast<PacketSourceOp>(Op)) {
46 srcTile = dyn_cast<TileOp>(pktSource.getTile().getDefiningOp());
47 srcPort = pktSource.port();
48 srcCoords = {srcTile.colIndex(), srcTile.rowIndex()};
52 return pktFlowOp.emitOpError(
"packet_flow has no packet_source");
55 for (Operation &Op : b.getOperations()) {
56 if (
auto pktDest = dyn_cast<PacketDestOp>(Op)) {
57 dstTile = dyn_cast<TileOp>(pktDest.getTile().getDefiningOp());
58 dstPort = pktDest.port();
59 dstCoords = {dstTile.colIndex(), dstTile.rowIndex()};
60 LLVM_DEBUG(llvm::dbgs()
61 <<
"\tAdding Packet Flow: (" <<
srcCoords.col <<
", "
63 << stringifyWireBundle(srcPort.bundle) << srcPort.channel
65 << stringifyWireBundle(dstPort.bundle) << dstPort.channel
69 pktFlowOp.getPriorityRoute()
70 ? *pktFlowOp.getPriorityRoute()
80 pathfinder->sortFlows(device.getTargetModel().columns(),
81 device.getTargetModel().rows());
84 for (FlowOp flowOp : device.getOps<FlowOp>()) {
85 TileOp srcTile = cast<TileOp>(flowOp.getSource().getDefiningOp());
86 TileOp dstTile = cast<TileOp>(flowOp.getDest().getDefiningOp());
89 Port srcPort = {flowOp.getSourceBundle(), flowOp.getSourceChannel()};
90 Port dstPort = {flowOp.getDestBundle(), flowOp.getDestChannel()};
91 LLVM_DEBUG(llvm::dbgs()
93 <<
")" << stringifyWireBundle(srcPort.bundle) << srcPort.channel
95 << stringifyWireBundle(dstPort.bundle) << dstPort.channel
103 for (SwitchboxOp switchboxOp : device.getOps<SwitchboxOp>()) {
104 if (!
pathfinder->addFixedConnection(switchboxOp))
105 return switchboxOp.emitOpError() <<
"Unable to add fixed connections";
114 return device.emitError(
"Unable to find a legal routing");
122 for (
auto tileOp : device.getOps<TileOp>()) {
124 col = tileOp.colIndex();
125 row = tileOp.rowIndex();
129 for (
auto switchboxOp : device.getOps<SwitchboxOp>()) {
130 int col = switchboxOp.colIndex();
131 int row = switchboxOp.rowIndex();
135 for (
auto shimmuxOp : device.getOps<ShimMuxOp>()) {
136 int col = shimmuxOp.colIndex();
137 int row = shimmuxOp.rowIndex();
142 LLVM_DEBUG(llvm::dbgs() <<
"\t---End DynamicTileAnalysis Constructor---\n");
150 auto tileOp = TileOp::create(builder, builder.getUnknownLoc(),
col,
row);
156 return getTile(builder, tileId.col, tileId.row);
166 auto switchboxOp = SwitchboxOp::create(builder, builder.getUnknownLoc(),
168 SwitchboxOp::ensureTerminator(switchboxOp.getConnections(), builder,
169 builder.getUnknownLoc());
181 auto switchboxOp = ShimMuxOp::create(builder, builder.getUnknownLoc(),
183 SwitchboxOp::ensureTerminator(switchboxOp.getConnections(), builder,
184 builder.getUnknownLoc());
192 std::map<WireBundle, int> maxChannels;
193 auto intraconnect = [&](
int col,
int row) {
197 for (
int i = 0, e = getMaxEnumValForWireBundle() + 1; i < e; ++i) {
198 WireBundle bundle = symbolizeWireBundle(i).value();
218 maxChannels[bundle] = channels;
222 for (
size_t i = 0; i < sb.srcPorts.size(); i++) {
223 for (
size_t j = 0; j < sb.dstPorts.size(); j++) {
224 auto &pIn = sb.srcPorts[i];
225 auto &pOut = sb.dstPorts[j];
227 pOut.bundle, pOut.channel))
233 auto isBundleInList = [](WireBundle bundle,
234 std::vector<WireBundle> bundles) {
235 return std::find(bundles.begin(), bundles.end(), bundle) !=
238 const std::vector<WireBundle> bundles = {
239 WireBundle::DMA, WireBundle::NOC, WireBundle::PLIO};
240 if (isBundleInList(pIn.bundle, bundles) ||
241 isBundleInList(pOut.bundle, bundles))
250 auto interconnect = [&](
int col,
int row,
int targetCol,
int targetRow,
251 WireBundle srcBundle, WireBundle dstBundle) {
258 for (
size_t i = 0; i < sb.srcPorts.size(); i++) {
264 for (
int row = 0;
row <= maxRow;
row++) {
265 for (
int col = 0;
col <= maxCol;
col++) {
300 for (
auto &[_, prioritized,
src,
dsts] : flows) {
318 for (
auto &[existingId, _,
src,
dsts] : flows) {
346 std::vector<Flow> priorityFlows;
347 std::vector<Flow> normalFlows;
348 for (
auto f : flows) {
349 if (f.isPriorityFlow)
350 priorityFlows.push_back(f);
352 normalFlows.push_back(f);
354 std::sort(priorityFlows.begin(), priorityFlows.end(),
355 [](
const auto &lhs,
const auto &rhs) {
359 std::make_tuple(lhs.src.coords.col, lhs.src.coords.row,
360 getWireBundleAsInt(lhs.src.port.bundle),
361 lhs.src.port.channel);
363 std::make_tuple(rhs.src.coords.col, rhs.src.coords.row,
364 getWireBundleAsInt(rhs.src.port.bundle),
365 rhs.src.port.channel);
366 return lhsKey < rhsKey;
368 flows = priorityFlows;
369 flows.insert(flows.end(), normalFlows.begin(), normalFlows.end());
375 int col = switchboxOp.colIndex();
376 int row = switchboxOp.rowIndex();
379 for (ConnectOp connectOp : switchboxOp.getOps<ConnectOp>()) {
381 for (
size_t i = 0; i < sb.srcPorts.size(); i++) {
382 if (sb.srcPorts[i] != connectOp.sourcePort())
387 for (
size_t j = 0; j < sb.dstPorts.size(); j++) {
388 if (sb.dstPorts[j] == connectOp.destPort() &&
403static constexpr double INF = std::numeric_limits<double>::max();
405std::map<PathEndPoint, PathEndPoint>
409 std::map<PathEndPoint, double> distance;
410 std::map<PathEndPoint, PathEndPoint> preds;
411 std::map<PathEndPoint, uint64_t> indexInHeap;
412 enum Color { WHITE, GRAY, BLACK };
413 std::map<PathEndPoint, Color> colors;
416 std::map<PathEndPoint, uint64_t>,
417 std::map<PathEndPoint, double> &,
420 MutableQueue Q(distance, indexInHeap);
429 if (channels.count(
src) == 0) {
430 auto &sb = graph[std::make_pair(
src.coords,
src.coords)];
431 for (
size_t i = 0; i < sb.srcPorts.size(); i++) {
432 for (
size_t j = 0; j < sb.dstPorts.size(); j++) {
433 if (sb.srcPorts[i] ==
src.port &&
441 std::vector<std::pair<TileID, Port>> neighbors = {
442 {{
src.coords.col,
src.coords.row - 1},
443 {WireBundle::North,
src.port.channel}},
444 {{
src.coords.col - 1,
src.coords.row},
445 {WireBundle::East,
src.port.channel}},
446 {{
src.coords.col,
src.coords.row + 1},
447 {WireBundle::South,
src.port.channel}},
448 {{
src.coords.col + 1,
src.coords.row},
449 {WireBundle::West,
src.port.channel}}};
451 for (
const auto &[neighborCoords, neighborPort] : neighbors) {
452 if (graph.count(std::make_pair(
src.coords, neighborCoords)) > 0 &&
454 auto &sb = graph[std::make_pair(
src.coords, neighborCoords)];
455 if (std::find(sb.dstPorts.begin(), sb.dstPorts.end(), neighborPort) !=
457 channels[
src].push_back({neighborCoords, neighborPort});
460 std::sort(channels[
src].begin(), channels[
src].end());
463 for (
auto &dest : channels[
src]) {
464 if (distance.count(dest) == 0)
465 distance[dest] = INF;
466 auto &sb = graph[std::make_pair(
src.coords, dest.coords)];
467 size_t i = std::distance(
469 std::find(sb.srcPorts.begin(), sb.srcPorts.end(),
src.port));
470 size_t j = std::distance(
472 std::find(sb.dstPorts.begin(), sb.dstPorts.end(), dest.port));
473 assert(i < sb.srcPorts.size());
474 assert(j < sb.dstPorts.size());
475 bool relax = distance[
src] + sb.demand[i][j] < distance[dest];
476 if (colors.count(dest) == 0) {
479 distance[dest] = distance[
src] + sb.demand[i][j];
484 }
else if (colors[dest] == GRAY && relax) {
485 distance[dest] = distance[
src] + sb.demand[i][j];
501std::optional<std::map<PathEndPoint, SwitchSettings>>
503 LLVM_DEBUG(llvm::dbgs() <<
"\t---Begin Pathfinder::findPaths---\n");
504 std::map<PathEndPoint, SwitchSettings> routingSolution;
506 for (
auto &[_, sb] : graph) {
507 for (
size_t i = 0; i < sb.srcPorts.size(); i++) {
508 for (
size_t j = 0; j < sb.dstPorts.size(); j++) {
509 sb.usedCapacity[i][j] = 0;
510 sb.overCapacity[i][j] = 0;
511 sb.isPriority[i][j] =
false;
517 llvm::MapVector<int, std::vector<Flow>> groupedFlows;
518 for (
auto &f : flows) {
519 if (groupedFlows.count(f.packetGroupId) == 0) {
520 groupedFlows[f.packetGroupId] = std::vector<Flow>();
522 groupedFlows[f.packetGroupId].push_back(f);
525 int iterationCount = -1;
526 int illegalEdges = 0;
528 int totalPathLength = 0;
532 if (++iterationCount >= maxIterations) {
533 LLVM_DEBUG(llvm::dbgs()
534 <<
"\t\tPathfinder: maxIterations has been exceeded ("
536 <<
" iterations)...unable to find routing for flows.\n");
540 LLVM_DEBUG(llvm::dbgs() <<
"\t\t---Begin findPaths iteration #"
541 << iterationCount <<
"---\n");
543 for (
auto &[_, sb] : graph) {
552 routingSolution.clear();
553 for (
auto &[_, sb] : graph) {
554 for (
size_t i = 0; i < sb.srcPorts.size(); i++) {
555 for (
size_t j = 0; j < sb.dstPorts.size(); j++) {
556 sb.usedCapacity[i][j] = 0;
557 sb.packetFlowCount[i][j] = 0;
558 sb.packetGroupId[i][j] = -1;
566 for (
const auto &[_, flows] : groupedFlows) {
572 std::set<PathEndPoint> processed;
578 processed.insert(
src);
579 for (
auto endPoint :
dsts) {
580 if (endPoint ==
src) {
582 switchSettings[
src.coords].srcs.push_back(
src.port);
583 switchSettings[
src.coords].dsts.push_back(
src.port);
585 auto curr = endPoint;
587 while (!processed.count(curr)) {
588 auto &sb = graph[std::make_pair(preds[curr].
coords, curr.coords)];
590 std::distance(sb.srcPorts.begin(),
591 std::find(sb.srcPorts.begin(), sb.srcPorts.end(),
593 size_t j = std::distance(
595 std::find(sb.dstPorts.begin(), sb.dstPorts.end(), curr.port));
596 assert(i < sb.srcPorts.size());
597 assert(j < sb.dstPorts.size());
600 (sb.packetGroupId[i][j] == -1 ||
602 for (
size_t k = 0; k < sb.srcPorts.size(); k++) {
603 for (
size_t l = 0; l < sb.dstPorts.size(); l++) {
604 if (k == i || l == j) {
609 sb.packetFlowCount[i][j]++;
612 sb.packetFlowCount[i][j] = 0;
613 sb.usedCapacity[i][j]++;
616 sb.usedCapacity[i][j]++;
621 if (preds[curr].
coords == curr.coords) {
622 switchSettings[preds[curr].coords].srcs.push_back(
624 switchSettings[curr.coords].dsts.push_back(curr.port);
626 processed.insert(curr);
631 routingSolution[
src] = switchSettings;
633 for (
auto &[_, sb] : graph) {
634 for (
size_t i = 0; i < sb.srcPorts.size(); i++) {
635 for (
size_t j = 0; j < sb.dstPorts.size(); j++) {
637 if (sb.packetFlowCount[i][j] > 0) {
638 sb.packetFlowCount[i][j] = 0;
639 sb.usedCapacity[i][j]++;
647 for (
auto &[_, sb] : graph) {
648 for (
size_t i = 0; i < sb.srcPorts.size(); i++) {
649 for (
size_t j = 0; j < sb.dstPorts.size(); j++) {
652 sb.overCapacity[i][j]++;
656 <<
"\t\t\tToo much capacity on (" << sb.srcCoords.col <<
","
657 << sb.srcCoords.row <<
") " << sb.srcPorts[i].bundle
658 << sb.srcPorts[i].channel <<
" -> (" << sb.dstCoords.col <<
","
659 << sb.dstCoords.row <<
") " << sb.dstPorts[j].bundle
660 << sb.dstPorts[j].channel <<
", used_capacity = "
661 << sb.usedCapacity[i][j] <<
", demand = " << sb.demand[i][j]
662 <<
", over_capacity_count = " << sb.overCapacity[i][j] <<
"\n");
666 if (sb.srcCoords != sb.dstCoords) {
667 totalPathLength += sb.usedCapacity[i][j];
675 for (
const auto &[
PathEndPoint, switchSetting] : routingSolution) {
676 LLVM_DEBUG(llvm::dbgs()
677 <<
"\t\t\tFlow starting at (" <<
PathEndPoint.coords.col <<
","
679 LLVM_DEBUG(llvm::dbgs() << switchSetting);
681 LLVM_DEBUG(llvm::dbgs()
682 <<
"\t\t---End findPaths iteration #" << iterationCount
683 <<
" , illegal edges count = " << illegalEdges
684 <<
", total path length = " << totalPathLength <<
"---\n");
686 }
while (illegalEdges >
689 LLVM_DEBUG(llvm::dbgs() <<
"\t---End Pathfinder::findPaths---\n");
690 return routingSolution;
695 return static_cast<typename std::underlying_type<WireBundle>::type
>(bundle);
#define MAX_CIRCUIT_STREAM_CAPACITY
#define MAX_PACKET_STREAM_CAPACITY
virtual uint32_t getNumSourceShimMuxConnections(int col, int row, WireBundle bundle) const =0
Return the number of sources of connections inside a shimmux.
virtual bool isLegalTileConnection(int col, int row, WireBundle srcBundle, int srcChan, WireBundle dstBundle, int dstChan) const =0
bool isShimNOCorPLTile(int col, int row) const
Return true if the given tile is either a ShimNOC or ShimPL tile.
virtual uint32_t getNumDestShimMuxConnections(int col, int row, WireBundle bundle) const =0
Return the number of destinations of connections inside a shimmux.
virtual uint32_t getNumDestSwitchboxConnections(int col, int row, WireBundle bundle) const =0
Return the number of destinations of connections inside a switchbox.
virtual uint32_t getNumSourceSwitchboxConnections(int col, int row, WireBundle bundle) const =0
Return the number of sources of connections inside a switchbox.
ShimMuxOp getShimMux(mlir::OpBuilder &builder, int col)
SwitchboxOp getSwitchbox(mlir::OpBuilder &builder, int col, int row)
llvm::DenseMap< TileID, SwitchboxOp > coordToSwitchbox
llvm::DenseMap< TileID, ShimMuxOp > coordToShimMux
TileOp getTile(mlir::OpBuilder &builder, int col, int row)
std::shared_ptr< Router > pathfinder
std::map< PathEndPoint, bool > processedFlows
mlir::LogicalResult runAnalysis(DeviceOp &device)
llvm::DenseMap< TileID, TileOp > coordToTile
std::map< PathEndPoint, SwitchSettings > flowSolutions
std::map< PathEndPoint, PathEndPoint > dijkstraShortestPaths(PathEndPoint src)
bool addFixedConnection(SwitchboxOp switchboxOp) override
void initialize(int maxCol, int maxRow, const AIETargetModel &targetModel) override
std::optional< std::map< PathEndPoint, SwitchSettings > > findPaths(int maxIterations) override
void sortFlows(const int maxCol, const int maxRow) override
void addFlow(TileID srcCoords, Port srcPort, TileID dstCoords, Port dstPort, bool isPacketFlow, bool isPriorityFlow) override
Include the generated interface declarations.
int getWireBundleAsInt(WireBundle bundle)
SwitchboxConnect { SwitchboxConnect()=default SwitchboxConnect
std::vector< PathEndPoint > dsts
TileID { friend std::ostream &operator<<(std::ostream &os, const TileID &s) { os<< "TileID("<< s.col<< ", "<< s.row<< ")" TileID
std::vector< std::vector< int > > packetGroupId
std::map< TileID, SwitchSetting > SwitchSettings
Port { WireBundle bundle Port
std::vector< std::vector< bool > > isPriority
Flow { int packetGroupId Flow
PathEndPoint { PathEndPoint()=default PathEndPoint
WireBundle getConnectingBundle(WireBundle dir)