15#include "mlir/IR/Attributes.h"
16#include "mlir/Pass/Pass.h"
22#define GEN_PASS_DEF_AIEINSERTTRACEFLOWS
23#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc"
36 TracePacketType packetType;
41 std::optional<int> stopBroadcast;
50 std::vector<TraceInfo> traceSources;
51 std::optional<int> startBroadcast;
52 std::optional<int> stopBroadcast;
55struct AIEInsertTraceFlowsPass
56 : xilinx::AIE::impl::AIEInsertTraceFlowsBase<AIEInsertTraceFlowsPass> {
58 void runOnOperation()
override {
59 DeviceOp device = getOperation();
60 OpBuilder builder(device);
61 const auto &targetModel = device.getTargetModel();
64 bool hasLogicalTile =
false;
65 device.walk([&](LogicalTileOp op) {
66 op.emitError() <<
"LogicalTileOp must be resolved to TileOp before "
67 "running -aie-insert-trace-flows (run -aie-place-tiles "
69 hasLogicalTile =
true;
72 return signalPassFailure();
75 SmallVector<TraceOp> traces;
76 device.walk([&](TraceOp trace) { traces.push_back(trace); });
82 RuntimeSequenceOp runtimeSeq =
nullptr;
83 TraceHostConfigOp hostConfig =
nullptr;
84 for (
auto &op : device.getBody()->getOperations()) {
85 if (
auto seq = dyn_cast<RuntimeSequenceOp>(&op)) {
87 for (
auto &subOp : seq.getBody().front().getOperations()) {
88 if (
auto hc = dyn_cast<TraceHostConfigOp>(&subOp)) {
100 <<
"aie.trace ops found but no runtime_sequence defined";
101 return signalPassFailure();
106 runtimeSeq.emitError()
107 <<
"runtime_sequence with traces requires aie.trace.host_config";
108 return signalPassFailure();
112 int bufferSizeBytes = hostConfig.getBufferSize();
113 int traceArgIdx = hostConfig.getArgIdx();
114 auto routing = hostConfig.getRouting();
117 int traceBufferOffset = 0;
118 if (traceArgIdx == -1) {
119 auto args = runtimeSeq.getBody().getArguments();
120 assert(!
args.empty() &&
"runtime_sequence must have args for arg_idx=-1");
122 Value lastArg =
args.back();
123 traceArgIdx =
args.size() - 1;
125 auto memrefType = cast<MemRefType>(lastArg.getType());
126 traceBufferOffset = memrefType.getNumElements() *
127 (memrefType.getElementTypeBitWidth() / 8);
134 std::vector<TraceInfo> traceInfos;
135 int nextPacketId = clPacketIdStart;
137 for (
auto trace : traces) {
138 auto tile = cast<TileOp>(trace.getTile().getDefiningOp());
141 std::optional<int> packetId;
142 std::optional<TracePacketType> packetType;
143 TracePacketOp existingPacketOp =
nullptr;
144 for (
auto &op : trace.getBody().getOps()) {
145 if (
auto packetOp = dyn_cast<TracePacketOp>(op)) {
146 existingPacketOp = packetOp;
147 packetId = packetOp.getId();
148 packetType = packetOp.getType();
155 if (tile.isShimTile()) {
156 packetType = TracePacketType::ShimTile;
157 }
else if (tile.isMemTile()) {
158 packetType = TracePacketType::MemTile;
161 packetType = TracePacketType::Core;
167 packetId = nextPacketId++;
172 if (!existingPacketOp) {
173 OpBuilder traceBuilder(&trace.getBody().front(),
174 trace.getBody().front().begin());
175 TracePacketOp::create(traceBuilder, trace.getLoc(), *packetId,
180 WireBundle tracePort = WireBundle::Trace;
181 int traceChannel = 0;
182 if (*packetType == TracePacketType::Mem) {
187 std::optional<int> startBroadcast;
188 std::optional<int> stopBroadcast;
189 bool hasStartConfig =
false;
190 bool hasStopConfig =
false;
191 for (
auto &op : trace.getBody().getOps()) {
192 if (
auto startOp = dyn_cast<TraceStartEventOp>(op)) {
193 hasStartConfig =
true;
194 if (startOp.getBroadcast())
195 startBroadcast = *startOp.getBroadcast();
197 if (
auto stopOp = dyn_cast<TraceStopEventOp>(op)) {
198 hasStopConfig =
true;
199 if (stopOp.getBroadcast())
200 stopBroadcast = *stopOp.getBroadcast();
205 if (!hasStartConfig) {
206 trace.emitError() <<
"trace is missing 'aie.trace.start'";
207 return signalPassFailure();
209 if (!hasStopConfig) {
210 trace.emitError() <<
"trace is missing 'aie.trace.stop'";
211 return signalPassFailure();
215 info.traceOp = trace;
217 info.packetId = *packetId;
218 info.packetType = *packetType;
219 info.tracePort = tracePort;
220 info.traceChannel = traceChannel;
221 info.startBroadcast = startBroadcast;
222 info.stopBroadcast = stopBroadcast;
223 traceInfos.push_back(info);
227 std::map<int, ShimInfo> shimInfos;
229 if (routing == TraceShimRouting::Single) {
232 TileOp shimTile =
nullptr;
233 for (
auto tile : device.getOps<TileOp>()) {
234 if (tile.getCol() == targetCol && tile.getRow() == 0) {
241 builder.setInsertionPointToStart(&device.getRegion().front());
242 shimTile = TileOp::create(builder, device.getLoc(), targetCol, 0);
246 shimInfo.shimTile = shimTile;
247 shimInfo.channel = clShimChannel;
248 shimInfo.bdId = clDefaultBdId;
249 shimInfo.argIdx = traceArgIdx;
250 shimInfo.bufferOffset = traceBufferOffset;
251 shimInfo.traceSources = traceInfos;
253 for (
auto &trace : traceInfos) {
254 if (trace.startBroadcast && !shimInfo.startBroadcast)
255 shimInfo.startBroadcast = trace.startBroadcast;
256 if (trace.stopBroadcast && !shimInfo.stopBroadcast)
257 shimInfo.stopBroadcast = trace.stopBroadcast;
261 for (
auto &info : traceInfos) {
262 int col = info.tile.getCol();
263 if (shimInfos.find(
col) == shimInfos.end()) {
264 shimInfos[
col] = shimInfo;
267 if (shimInfos.find(targetCol) == shimInfos.end()) {
268 shimInfos[targetCol] = shimInfo;
274 for (
auto &info : traceInfos) {
275 if (!info.tile.isShimTile())
278 int col = info.tile.getCol();
279 auto shimIt = shimInfos.find(
col);
280 if (shimIt == shimInfos.end())
284 if (shimIt->second.shimTile.getTileID() != info.tile.getTileID())
289 for (
auto &op : info.traceOp.getBody().getOps()) {
290 if (
auto startOp = dyn_cast<TraceStartEventOp>(op)) {
291 if (startOp.getBroadcast()) {
293 startOp->removeAttr(
"broadcast");
294 startOp->setAttr(
"event", TraceEventAttr::get(builder.getContext(),
295 builder.getStringAttr(
299 if (
auto stopOp = dyn_cast<TraceStopEventOp>(op)) {
300 if (stopOp.getBroadcast()) {
302 stopOp->removeAttr(
"broadcast");
303 stopOp->setAttr(
"event", TraceEventAttr::get(builder.getContext(),
304 builder.getStringAttr(
310 info.startBroadcast = std::nullopt;
311 info.stopBroadcast = std::nullopt;
316 Block &deviceBlock = device.getRegion().front();
317 builder.setInsertionPoint(deviceBlock.getTerminator());
319 for (
auto &info : traceInfos) {
321 int col = info.tile.getCol();
322 ShimInfo &shimInfo = shimInfos[
col];
325 auto packetFlowOp = PacketFlowOp::create(
326 builder, device.getLoc(), builder.getI8IntegerAttr(info.packetId),
329 Block *flowBody =
new Block();
330 packetFlowOp.getPorts().push_back(flowBody);
331 OpBuilder flowBuilder = OpBuilder::atBlockEnd(flowBody);
333 PacketSourceOp::create(flowBuilder, device.getLoc(),
334 Value(info.tile.getResult()), info.tracePort,
337 PacketDestOp::create(flowBuilder, device.getLoc(),
338 Value(shimInfo.shimTile.getResult()),
339 WireBundle::DMA, shimInfo.channel);
341 EndOp::create(flowBuilder, device.getLoc());
343 packetFlowOp->setAttr(
"keep_pkt_header", builder.getBoolAttr(
true));
347 Block &seqBlock = runtimeSeq.getBody().front();
350 Operation *lastStartConfig =
nullptr;
351 for (
auto &op : seqBlock.getOperations()) {
352 if (isa<TraceStartConfigOp>(op)) {
353 lastStartConfig = &op;
358 if (lastStartConfig) {
359 builder.setInsertionPointAfter(lastStartConfig);
361 builder.setInsertionPointToStart(&seqBlock);
365 std::set<std::tuple<int, int, bool>> processedTiles;
366 for (
auto &info : traceInfos) {
367 if (!info.startBroadcast)
370 int col = info.tile.getCol();
371 int row = info.tile.getRow();
372 bool isMemTrace = info.packetType == TracePacketType::Mem;
374 if (processedTiles.count({col, row, isMemTrace}))
376 processedTiles.insert({
col,
row, isMemTrace});
379 if (info.tile.isShimTile()) {
380 auto shimIt = shimInfos.find(
col);
381 if (shimIt != shimInfos.end() &&
382 shimIt->second.shimTile.getTileID() == info.tile.getTileID()) {
388 uint32_t timerCtrlAddr =
389 computeTimerCtrlAddress(info.tile, targetModel, isMemTrace);
391 std::string broadcastEventName;
392 if (info.tile.isShimTile()) {
394 "BROADCAST_A_" + std::to_string(*info.startBroadcast);
397 "BROADCAST_" + std::to_string(*info.startBroadcast);
400 auto broadcastEvent = targetModel.lookupEvent(
401 broadcastEventName, info.tile.getTileID(), isMemTrace);
402 if (!broadcastEvent) {
403 info.traceOp.emitError() <<
"Failed to lookup broadcast event '"
404 << broadcastEventName <<
"'";
405 return signalPassFailure();
407 uint32_t timerCtrlValue = *broadcastEvent << 8;
409 xilinx::AIEX::NpuWrite32Op::create(
410 builder, runtimeSeq.getLoc(), timerCtrlAddr, timerCtrlValue,
nullptr,
411 builder.getI32IntegerAttr(
col), builder.getI32IntegerAttr(
row));
415 std::set<int> configuredShimCols;
416 for (
auto &[
col, shimInfo] : shimInfos) {
417 int shimCol = shimInfo.shimTile.getCol();
418 if (!configuredShimCols.insert(shimCol).second)
422 int bufferLengthWords = bufferSizeBytes / 4;
425 xilinx::AIEX::NpuWriteBdOp::create(
426 builder, runtimeSeq.getLoc(),
450 uint32_t bdAddress = computeBDAddress(shimCol, shimInfo.bdId,
451 shimInfo.shimTile, targetModel);
452 xilinx::AIEX::NpuAddressPatchOp::create(builder, runtimeSeq.getLoc(),
453 bdAddress, shimInfo.argIdx,
454 shimInfo.bufferOffset);
458 computeCtrlAddress(DMAChannelDir::S2MM, shimInfo.channel,
459 shimInfo.shimTile, targetModel);
460 xilinx::AIEX::NpuMaskWrite32Op::create(
461 builder, runtimeSeq.getLoc(), ctrlAddr, 3840, 7936,
462 nullptr, builder.getI32IntegerAttr(shimCol),
463 builder.getI32IntegerAttr(0));
466 uint32_t taskQueueAddr =
467 computeTaskQueueAddress(DMAChannelDir::S2MM, shimInfo.channel,
468 shimInfo.shimTile, targetModel);
469 uint32_t bdIdWithToken = (1U << 31) | shimInfo.bdId;
470 xilinx::AIEX::NpuWrite32Op::create(
471 builder, runtimeSeq.getLoc(), taskQueueAddr, bdIdWithToken,
nullptr,
472 builder.getI32IntegerAttr(shimCol), builder.getI32IntegerAttr(0));
475 if (shimInfo.startBroadcast) {
477 auto userEvent1 = targetModel.lookupEvent(
478 "USER_EVENT_1", shimInfo.shimTile.getTileID(),
false);
480 llvm::report_fatal_error(
"Failed to lookup USER_EVENT_1 event");
482 uint32_t shimTimerCtrlAddr =
483 computeTimerCtrlAddress(shimInfo.shimTile, targetModel,
false);
484 xilinx::AIEX::NpuWrite32Op::create(
485 builder, runtimeSeq.getLoc(), shimTimerCtrlAddr, *userEvent1 << 8,
486 nullptr, builder.getI32IntegerAttr(shimCol),
487 builder.getI32IntegerAttr(0));
490 std::string broadcastRegName =
491 "Event_Broadcast" + std::to_string(*shimInfo.startBroadcast) +
"_A";
492 const RegisterInfo *broadcastReg = targetModel.lookupRegister(
493 broadcastRegName, shimInfo.shimTile.getTileID());
495 llvm::report_fatal_error(llvm::Twine(
"Failed to lookup ") +
497 xilinx::AIEX::NpuWrite32Op::create(
498 builder, runtimeSeq.getLoc(), broadcastReg->
offset, *userEvent1,
499 nullptr, builder.getI32IntegerAttr(shimCol),
500 builder.getI32IntegerAttr(0));
503 const RegisterInfo *eventGenReg = targetModel.lookupRegister(
504 "Event_Generate", shimInfo.shimTile.getTileID());
506 llvm::report_fatal_error(
"Failed to lookup Event_Generate register");
507 xilinx::AIEX::NpuWrite32Op::create(
508 builder, runtimeSeq.getLoc(), eventGenReg->
offset, *userEvent1,
509 nullptr, builder.getI32IntegerAttr(shimCol),
510 builder.getI32IntegerAttr(0));
515 builder.setInsertionPointToEnd(&seqBlock);
517 std::set<int> stoppedShimCols;
518 for (
auto &[
col, shimInfo] : shimInfos) {
519 if (!shimInfo.stopBroadcast)
522 int shimCol = shimInfo.shimTile.getCol();
523 if (!stoppedShimCols.insert(shimCol).second)
526 auto userEvent0 = targetModel.lookupEvent(
527 "USER_EVENT_0", shimInfo.shimTile.getTileID(),
false);
529 llvm::report_fatal_error(
"Failed to lookup USER_EVENT_0 event");
531 std::string broadcastRegName =
532 "Event_Broadcast" + std::to_string(*shimInfo.stopBroadcast) +
"_A";
533 const RegisterInfo *broadcastReg = targetModel.lookupRegister(
534 broadcastRegName, shimInfo.shimTile.getTileID());
536 llvm::report_fatal_error(llvm::Twine(
"Failed to lookup ") +
538 xilinx::AIEX::NpuWrite32Op::create(
539 builder, runtimeSeq.getLoc(), broadcastReg->
offset, *userEvent0,
540 nullptr, builder.getI32IntegerAttr(shimCol),
541 builder.getI32IntegerAttr(0));
543 const RegisterInfo *stopEventGenReg = targetModel.lookupRegister(
544 "Event_Generate", shimInfo.shimTile.getTileID());
545 if (!stopEventGenReg)
546 llvm::report_fatal_error(
"Failed to lookup Event_Generate register");
547 xilinx::AIEX::NpuWrite32Op::create(
548 builder, runtimeSeq.getLoc(), stopEventGenReg->
offset, *userEvent0,
549 nullptr, builder.getI32IntegerAttr(shimCol),
550 builder.getI32IntegerAttr(0));
556 uint32_t computeBDAddress(
int col,
int bdId, TileOp shimTile,
563 llvm::report_fatal_error(
"Failed to lookup DMA_BD0_0 register");
564 const uint32_t BD_STRIDE = 0x20;
565 const uint32_t BUFFER_ADDR_OFFSET = 4;
567 (bdReg->
offset + bdId * BD_STRIDE + BUFFER_ADDR_OFFSET);
571 uint32_t computeTaskQueueAddress(DMAChannelDir dir,
int channel,
574 if (dir == DMAChannelDir::S2MM) {
576 (
channel == 0) ?
"DMA_S2MM_0_Task_Queue" :
"DMA_S2MM_1_Task_Queue";
579 (
channel == 0) ?
"DMA_MM2S_0_Task_Queue" :
"DMA_MM2S_1_Task_Queue";
583 llvm::report_fatal_error(llvm::Twine(
"Failed to lookup ") + regName);
588 uint32_t computeCtrlAddress(DMAChannelDir dir,
int channel, TileOp shimTile,
591 if (dir == DMAChannelDir::S2MM) {
592 regName = (
channel == 0) ?
"DMA_S2MM_0_Ctrl" :
"DMA_S2MM_1_Ctrl";
594 regName = (
channel == 0) ?
"DMA_MM2S_0_Ctrl" :
"DMA_MM2S_1_Ctrl";
598 llvm::report_fatal_error(llvm::Twine(
"Failed to lookup ") + regName);
603 uint32_t computeTimerCtrlAddress(TileOp tile,
const AIETargetModel &tm,
609 llvm::report_fatal_error(
"Failed to lookup Timer_Control register");
616std::unique_ptr<OperationPass<DeviceOp>>
618 return std::make_unique<AIEInsertTraceFlowsPass>();
const RegisterInfo * lookupRegister(llvm::StringRef name, TileID tile, bool isMem=false) const
Register Database API - provides access to register and event information for trace configuration and...
virtual uint32_t getColumnShift() const =0
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< DeviceOp > > createAIEInsertTraceFlowsPass()