MLIR-AIE
AIETraceToConfig.cpp
Go to the documentation of this file.
1//===- AIETraceToConfig.cpp -------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// Copyright (C) 2025, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10// Pass to lower aie.trace to aie.trace.config
11//===----------------------------------------------------------------------===//
12
15
16#include "mlir/IR/Attributes.h"
17#include "mlir/Pass/Pass.h"
18
19namespace xilinx::AIE {
20#define GEN_PASS_DEF_AIETRACETOCONFIG
21#define GEN_PASS_DEF_AIETRACEREGPACKWRITES
22#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc"
23} // namespace xilinx::AIE
24
25using namespace mlir;
26using namespace xilinx;
27using namespace xilinx::AIE;
28
29namespace {
30
31struct AIETraceToConfigPass
32 : xilinx::AIE::impl::AIETraceToConfigBase<AIETraceToConfigPass> {
33 void runOnOperation() override {
34 DeviceOp device = getOperation();
35 OpBuilder builder(device);
36 const auto &targetModel = device.getTargetModel();
37
38 // Collect all trace operations
39 SmallVector<TraceOp> traces;
40 device.walk([&](TraceOp trace) { traces.push_back(trace); });
41
42 for (auto trace : traces) {
43 // Create config symbol name
44 std::string configName = (trace.getSymName().str() + "_config");
45 auto tile = cast<TileOp>(trace.getTile().getDefiningOp());
46 TileID tileID = tile.getTileID();
47
48 // Find packet type (if any)
49 TracePacketType packetType = TracePacketType::Core; // default
50 for (auto &op : trace.getBody().getOps()) {
51 if (auto packetOp = dyn_cast<TracePacketOp>(op)) {
52 packetType = packetOp.getType();
53 break;
54 }
55 }
56
57 // Insert trace.config after trace declaration
58 builder.setInsertionPointAfter(trace);
59 auto configOp = builder.create<TraceConfigOp>(
60 trace.getLoc(), trace.getTile(), builder.getStringAttr(configName),
61 TracePacketTypeAttr::get(builder.getContext(), packetType));
62
63 // Build register writes inside config body
64 Block *configBody = new Block();
65 configOp.getBody().push_back(configBody);
66 OpBuilder configBuilder = OpBuilder::atBlockEnd(configBody);
67
68 bool isMem = (packetType == TracePacketType::Mem);
69
70 // Process combo/edge events FIRST (before other trace config)
71 // This ensures COMBO_EVENT_*/EDGE_DETECTION_EVENT_* are configured
72 // before they can be referenced in trace.event operations
73
74 // 0a. Emit combo event configurations
75 for (auto &op : trace.getBody().getOps()) {
76 if (auto comboOp = dyn_cast<TraceComboEventOp>(op)) {
77 uint32_t slot = comboOp.getSlot();
78
79 // Get input events - use getEventName() helper
80 std::string eventAName = comboOp.getEventA().getEventName();
81 std::string eventBName = comboOp.getEventB().getEventName();
82 ComboLogic logic = comboOp.getLogic();
83
84 // If enum, use enum value directly; otherwise lookup by name
85 std::optional<uint32_t> eventANum, eventBNum;
86
87 if (auto enumValA = comboOp.getEventA().getEnumValue()) {
88 eventANum = static_cast<uint32_t>(*enumValA);
89 } else {
90 eventANum = targetModel.lookupEvent(eventAName, tileID, isMem);
91 }
92
93 if (auto enumValB = comboOp.getEventB().getEnumValue()) {
94 eventBNum = static_cast<uint32_t>(*enumValB);
95 } else {
96 eventBNum = targetModel.lookupEvent(eventBName, tileID, isMem);
97 }
98
99 if (!eventANum) {
100 comboOp.emitError("unknown trace event '") << eventAName << "'";
101 return signalPassFailure();
102 }
103 if (!eventBNum) {
104 comboOp.emitError("unknown trace event '") << eventBName << "'";
105 return signalPassFailure();
106 }
107
108 // Map slot to input event fields
109 StringRef eventAField, eventBField, controlField;
110 if (slot == 0) {
111 eventAField = "eventA";
112 eventBField = "eventB";
113 controlField = "combo0";
114 } else if (slot == 1) {
115 eventAField = "eventC";
116 eventBField = "eventD";
117 controlField = "combo1";
118 } else if (slot == 2) {
119 // Combo2 is hierarchical - reuses eventA/B fields but represents
120 // combo0/combo1
121 eventAField = "eventA";
122 eventBField = "eventB";
123 controlField = "combo2";
124 }
125
126 // Emit Combo_event_inputs register fields
127 configBuilder.create<TraceRegOp>(
128 comboOp.getLoc(), builder.getStringAttr("Combo_event_inputs"),
129 builder.getStringAttr(eventAField), comboOp.getEventA(),
130 /*mask=*/nullptr,
131 builder.getStringAttr("combo" + std::to_string(slot) +
132 " eventA"));
133
134 configBuilder.create<TraceRegOp>(
135 comboOp.getLoc(), builder.getStringAttr("Combo_event_inputs"),
136 builder.getStringAttr(eventBField), comboOp.getEventB(),
137 /*mask=*/nullptr,
138 builder.getStringAttr("combo" + std::to_string(slot) +
139 " eventB"));
140
141 // Emit Combo_event_control register field
142 configBuilder.create<TraceRegOp>(
143 comboOp.getLoc(), builder.getStringAttr("Combo_event_control"),
144 builder.getStringAttr(controlField),
145 builder.getI32IntegerAttr(static_cast<uint32_t>(logic)),
146 /*mask=*/nullptr,
147 builder.getStringAttr("combo" + std::to_string(slot) + " logic"));
148 }
149 }
150
151 // 0b. Emit edge detection configurations
152 for (auto &op : trace.getBody().getOps()) {
153 if (auto edgeOp = dyn_cast<TraceEdgeEventOp>(op)) {
154 uint32_t slot = edgeOp.getSlot();
155 std::string eventName = edgeOp.getEvent().getEventName();
156 EdgeTrigger trigger = edgeOp.getTrigger();
157
158 // If enum, use enum value directly; otherwise lookup by name
159 std::optional<uint32_t> eventNum;
160
161 if (auto enumVal = edgeOp.getEvent().getEnumValue()) {
162 eventNum = static_cast<uint32_t>(*enumVal);
163 } else {
164 eventNum = targetModel.lookupEvent(eventName, tileID, isMem);
165 }
166
167 if (!eventNum) {
168 edgeOp.emitError("unknown trace event '") << eventName << "'";
169 return signalPassFailure();
170 }
171
172 // Map slot to field names
173 StringRef eventField =
174 (slot == 0) ? "Edge_Detection_Event_0" : "Edge_Detection_Event_1";
175 StringRef risingField = (slot == 0)
176 ? "Edge_Detection_0_Trigger_Rising"
177 : "Edge_Detection_1_Trigger_Rising";
178 StringRef fallingField = (slot == 0)
179 ? "Edge_Detection_0_Trigger_Falling"
180 : "Edge_Detection_1_Trigger_Falling";
181
182 // Source event
183 configBuilder.create<TraceRegOp>(
184 edgeOp.getLoc(),
185 builder.getStringAttr("Edge_Detection_event_control"),
186 builder.getStringAttr(eventField), edgeOp.getEvent(),
187 /*mask=*/nullptr,
188 builder.getStringAttr("edge" + std::to_string(slot) + " source"));
189
190 // Trigger mode
191 bool rising =
192 (trigger == EdgeTrigger::RISING || trigger == EdgeTrigger::BOTH);
193 bool falling =
194 (trigger == EdgeTrigger::FALLING || trigger == EdgeTrigger::BOTH);
195
196 configBuilder.create<TraceRegOp>(
197 edgeOp.getLoc(),
198 builder.getStringAttr("Edge_Detection_event_control"),
199 builder.getStringAttr(risingField),
200 builder.getI32IntegerAttr(rising ? 1 : 0),
201 /*mask=*/nullptr,
202 builder.getStringAttr("edge" + std::to_string(slot) + " rising"));
203
204 configBuilder.create<TraceRegOp>(
205 edgeOp.getLoc(),
206 builder.getStringAttr("Edge_Detection_event_control"),
207 builder.getStringAttr(fallingField),
208 builder.getI32IntegerAttr(falling ? 1 : 0),
209 /*mask=*/nullptr,
210 builder.getStringAttr("edge" + std::to_string(slot) +
211 " falling"));
212 }
213 }
214
215 // 1. Emit Trace_Control0 fields
216 // Check for start/stop events
217 for (auto &op : trace.getBody().getOps()) {
218 if (auto startOp = dyn_cast<TraceStartEventOp>(op)) {
219 uint32_t startEvent = 0;
220 if (startOp.getBroadcast()) {
221 uint32_t broadcastNum = *startOp.getBroadcast();
222 // Resolve broadcast channel to hardware event ID
223 std::string eventName;
224 if (tile.isShimTile()) {
225 eventName = "BROADCAST_A_" + std::to_string(broadcastNum);
226 } else {
227 eventName = "BROADCAST_" + std::to_string(broadcastNum);
228 }
229 auto eventNum = targetModel.lookupEvent(eventName, tileID, isMem);
230 if (eventNum) {
231 startEvent = *eventNum;
232 } else {
233 startOp.emitError("unknown broadcast event '")
234 << eventName << "'";
235 return signalPassFailure();
236 }
237 } else if (auto eventAttr = startOp.getEvent()) {
238 // Use getEventName() helper and check for enum
239 std::string eventName = eventAttr->getEventName();
240
241 std::optional<uint32_t> eventNum;
242 if (auto enumVal = eventAttr->getEnumValue()) {
243 eventNum = static_cast<uint32_t>(*enumVal);
244 } else {
245 eventNum = targetModel.lookupEvent(eventName, tileID, isMem);
246 }
247
248 if (eventNum) {
249 startEvent = *eventNum;
250 } else {
251 startOp.emitError("unknown trace event '") << eventName << "'";
252 return signalPassFailure();
253 }
254 }
255
256 configBuilder.create<TraceRegOp>(
257 trace.getLoc(), builder.getStringAttr("Trace_Control0"),
258 builder.getStringAttr("Trace_Start_Event"),
259 builder.getI32IntegerAttr(startEvent),
260 /*mask=*/nullptr, builder.getStringAttr("start event"));
261 }
262
263 if (auto stopOp = dyn_cast<TraceStopEventOp>(op)) {
264 uint32_t stopEvent = 0;
265 if (stopOp.getBroadcast()) {
266 uint32_t broadcastNum = *stopOp.getBroadcast();
267 // Resolve broadcast channel to hardware event ID
268 std::string eventName;
269 if (tile.isShimTile()) {
270 eventName = "BROADCAST_A_" + std::to_string(broadcastNum);
271 } else {
272 eventName = "BROADCAST_" + std::to_string(broadcastNum);
273 }
274 auto eventNum = targetModel.lookupEvent(eventName, tileID, isMem);
275 if (eventNum) {
276 stopEvent = *eventNum;
277 } else {
278 stopOp.emitError("unknown broadcast event '") << eventName << "'";
279 return signalPassFailure();
280 }
281 } else if (auto eventAttr = stopOp.getEvent()) {
282 // Use getEventName() helper and check for enum
283 std::string eventName = eventAttr->getEventName();
284
285 std::optional<uint32_t> eventNum;
286 if (auto enumVal = eventAttr->getEnumValue()) {
287 eventNum = static_cast<uint32_t>(*enumVal);
288 } else {
289 eventNum = targetModel.lookupEvent(eventName, tileID, isMem);
290 }
291
292 if (eventNum) {
293 stopEvent = *eventNum;
294 } else {
295 stopOp.emitError("unknown trace event '") << eventName << "'";
296 return signalPassFailure();
297 }
298 }
299
300 configBuilder.create<TraceRegOp>(
301 trace.getLoc(), builder.getStringAttr("Trace_Control0"),
302 builder.getStringAttr("Trace_Stop_Event"),
303 builder.getI32IntegerAttr(stopEvent),
304 /*mask=*/nullptr, builder.getStringAttr("stop event"));
305 }
306
307 // Emit mode if present.
308 // Only core traces expose Trace_Control0.Mode in the register DB.
309 // Memory, memory_tile, and shim modules do not have the Mode field.
310 bool isCore = (packetType == TracePacketType::Core);
311 if (auto modeOp = dyn_cast<TraceModeOp>(op); modeOp && isCore) {
312 configBuilder.create<TraceRegOp>(
313 trace.getLoc(), builder.getStringAttr("Trace_Control0"),
314 builder.getStringAttr("Mode"),
315 builder.getI32IntegerAttr(
316 static_cast<uint32_t>(modeOp.getMode())),
317 /*mask=*/nullptr, builder.getStringAttr("trace mode"));
318 }
319
320 // Emit packet config if present
321 if (auto packetOp = dyn_cast<TracePacketOp>(op)) {
322 configBuilder.create<TraceRegOp>(
323 trace.getLoc(), builder.getStringAttr("Trace_Control1"),
324 builder.getStringAttr("ID"),
325 builder.getI32IntegerAttr(packetOp.getId()),
326 /*mask=*/nullptr, builder.getStringAttr("packet ID"));
327
328 configBuilder.create<TraceRegOp>(
329 trace.getLoc(), builder.getStringAttr("Trace_Control1"),
330 builder.getStringAttr("Packet_Type"),
331 builder.getI32IntegerAttr(
332 static_cast<uint32_t>(packetOp.getType())),
333 /*mask=*/nullptr, builder.getStringAttr("packet type"));
334 }
335 }
336
337 // 2. Emit port configurations (Stream_Switch_Event_Port_Selection_0/1)
338 for (auto &op : trace.getBody().getOps()) {
339 if (auto portOp = dyn_cast<TracePortOp>(op)) {
340 uint32_t slot = portOp.getSlot();
341
342 // Determine which register based on slot
343 StringRef registerName = (slot < 4)
344 ? "Stream_Switch_Event_Port_Selection_0"
345 : "Stream_Switch_Event_Port_Selection_1";
346
347 // Generate field names
348 std::string idFieldName = "Port_" + std::to_string(slot) + "_ID";
349 std::string masterSlaveFieldName =
350 "Port_" + std::to_string(slot) + "_Master_Slave";
351
352 // Generate port value string "PORT:CHANNEL"
353 std::string portValue = stringifyWireBundle(portOp.getPort()).str() +
354 ":" + std::to_string(portOp.getChannel());
355
356 // Convert DMAChannelDir to master flag: S2MM=master(1), MM2S=slave(0)
357 int masterSlaveValue =
358 (portOp.getDirection() == DMAChannelDir::S2MM) ? 1 : 0;
359
360 // Emit Port_N_ID field
361 configBuilder.create<TraceRegOp>(
362 portOp.getLoc(), builder.getStringAttr(registerName),
363 builder.getStringAttr(idFieldName),
364 builder.getStringAttr(portValue), // "NORTH:1" format
365 /*mask=*/nullptr,
366 builder.getStringAttr("port " + std::to_string(slot) + " ID"));
367
368 // Emit Port_N_Master_Slave field
369 configBuilder.create<TraceRegOp>(
370 portOp.getLoc(), builder.getStringAttr(registerName),
371 builder.getStringAttr(masterSlaveFieldName),
372 builder.getI32IntegerAttr(masterSlaveValue),
373 /*mask=*/nullptr,
374 builder.getStringAttr("port " + std::to_string(slot) +
375 " master/slave"));
376 }
377 }
378
379 // 3. Emit event slots (Trace_Event0 / Trace_Event1)
380 SmallVector<TraceEventOp> events;
381 for (auto &op : trace.getBody().getOps()) {
382 if (auto eventOp = dyn_cast<TraceEventOp>(op)) {
383 events.push_back(eventOp);
384 }
385 }
386
387 for (size_t i = 0; i < events.size() && i < 8; ++i) {
388 std::string eventName = events[i].getEvent().getEventName();
389
390 // If enum, use enum value directly; otherwise lookup by name
391 std::optional<uint32_t> eventNum;
392
393 if (auto enumVal = events[i].getEvent().getEnumValue()) {
394 eventNum = static_cast<uint32_t>(*enumVal);
395 } else {
396 eventNum = targetModel.lookupEvent(eventName, tileID, isMem);
397 }
398
399 if (!eventNum) {
400 events[i].emitError("unknown trace event '") << eventName << "'";
401 return signalPassFailure();
402 }
403
404 // Determine which register and field
405 StringRef registerName = (i < 4) ? "Trace_Event0" : "Trace_Event1";
406 std::string fieldName = "Trace_Event" + std::to_string(i);
407
408 // Emit register write with event number as integer
409 configBuilder.create<TraceRegOp>(
410 trace.getLoc(), builder.getStringAttr(registerName),
411 builder.getStringAttr(fieldName),
412 events[i].getEvent(), // builder.getI32IntegerAttr(*eventNum),
413 /*mask=*/nullptr, builder.getStringAttr(eventName));
414 }
415
416 // Add terminator
417 configBuilder.create<EndOp>(trace.getLoc());
418
419 // Update all trace.start_config references
420 device.walk([&](TraceStartConfigOp startConfig) {
421 if (startConfig.getTraceConfig() == trace.getSymName()) {
422 startConfig.setTraceConfigAttr(
423 SymbolRefAttr::get(builder.getContext(), configName));
424 }
425 });
426
427 // Remove original trace op
428 trace.erase();
429 }
430 }
431};
432
433} // namespace
434
435std::unique_ptr<OperationPass<DeviceOp>>
437 return std::make_unique<AIETraceToConfigPass>();
438}
439
440//===----------------------------------------------------------------------===//
441// AIETraceRegPackWritesPass - Pack multiple register field writes
442//===----------------------------------------------------------------------===//
443
444namespace {
445
446struct AIETraceRegPackWritesPass
447 : xilinx::AIE::impl::AIETraceRegPackWritesBase<AIETraceRegPackWritesPass> {
448 void runOnOperation() override {
449 DeviceOp device = getOperation();
450 const auto &targetModel = device.getTargetModel();
451
452 // Process each trace config
453 device.walk([&](TraceConfigOp configOp) {
454 // Determine module based on tile type and packet type
455 auto tile = cast<TileOp>(configOp.getTile().getDefiningOp());
456
457 // Get packet type to determine if this is memory or core trace
458 bool isMem = false;
459 if (auto packetType = configOp.getPacketType()) {
460 isMem = (*packetType == TracePacketType::Mem);
461 }
462
463 // Phase 1: Convert field+value to mask+shifted_value
464 SmallVector<TraceRegOp> regsToConvert;
465 for (auto &op : configOp.getBody().front()) {
466 if (auto regOp = dyn_cast<TraceRegOp>(op)) {
467 if (regOp.getField() && !regOp.getMask()) {
468 regsToConvert.push_back(regOp);
469 }
470 }
471 }
472
473 OpBuilder builder(&configOp.getBody().front(),
474 configOp.getBody().front().begin());
475
476 TileID tileID = tile.getTileID();
477 for (auto regOp : regsToConvert) {
478 // Look up register and field information
479 const RegisterInfo *regInfo =
480 targetModel.lookupRegister(regOp.getRegName(), tileID, isMem);
481
482 if (!regInfo) {
483 regOp.emitError("Register not found in database: ")
484 << regOp.getRegName();
485 return signalPassFailure();
486 }
487
488 const BitFieldInfo *fieldInfo = regInfo->getField(*regOp.getField());
489 if (!fieldInfo) {
490 regOp.emitError("Field not found in register: ")
491 << *regOp.getField() << " in " << regOp.getRegName();
492 return signalPassFailure();
493 }
494
495 // Get the value - handle both integers and port strings
496 uint32_t value = 0;
497 Attribute valAttr = regOp.getValue();
498 if (auto traceEventAttr = dyn_cast<TraceEventAttr>(valAttr)) {
499 if (auto enumVal = traceEventAttr.getEnumValue()) {
500 value = static_cast<uint32_t>(*enumVal);
501 valAttr = builder.getI32IntegerAttr(value);
502 } else {
503 std::string eventName = traceEventAttr.getEventName();
504 std::optional<uint32_t> eventNum =
505 targetModel.lookupEvent(eventName, tileID, isMem);
506 if (!eventNum) {
507 regOp.emitError("unknown trace event '") << eventName << "'";
508 return signalPassFailure();
509 }
510 value = *eventNum;
511 valAttr = builder.getI32IntegerAttr(value);
512 }
513 }
514 if (auto intAttr = dyn_cast<IntegerAttr>(valAttr)) {
515 // Integer value
516 value = intAttr.getInt();
517 } else if (auto strAttr = dyn_cast<StringAttr>(valAttr)) {
518 // String value - check if it's a port specification
519 StringRef valueStr = strAttr.getValue();
520
521 // Determine master/slave from field name
522 // If field name contains "Master_Slave", this is not a port ID field
523 // Port ID fields are named "Port_N_ID"
524 bool isMasterSlaveField =
525 fieldInfo->name.find("Master_Slave") != std::string::npos;
526
527 if (!isMasterSlaveField && valueStr.contains(':')) {
528 // This looks like "PORT:CHANNEL" format
529 // We need master/slave info - look for corresponding Master_Slave
530 // field
531 bool master = false; // Default to slave
532
533 // Derive the port slot prefix from the current field name, e.g.:
534 // "Port_1_ID" -> "Port_1"
535 StringRef fieldName(fieldInfo->name);
536 StringRef portSlotPrefix;
537 if (fieldName.starts_with("Port_")) {
538 size_t idSuffixPos = fieldName.find("_ID");
539 if (idSuffixPos != std::string::npos)
540 portSlotPrefix = fieldName.take_front(idSuffixPos);
541 }
542 // Search for companion Master_Slave field in same register and,
543 // when possible, for the same port slot.
544 for (auto &siblingOp : configOp.getBody().front()) {
545 if (auto siblingReg = dyn_cast<TraceRegOp>(siblingOp)) {
546 if (siblingReg.getRegName() != regOp.getRegName() ||
547 !siblingReg.getField())
548 continue;
549 StringRef siblingFieldName = *siblingReg.getField();
550 // If we could determine a slot prefix (e.g. "Port_1"), require
551 // the sibling to match that slot and contain "Master_Slave".
552 if (!portSlotPrefix.empty()) {
553 if (!siblingFieldName.starts_with(portSlotPrefix) ||
554 !siblingFieldName.contains("Master_Slave"))
555 continue;
556 } else {
557 // Fallback: match any Master_Slave field in this register.
558 if (!siblingFieldName.contains("Master_Slave"))
559 continue;
560 }
561 // Found companion field - extract master flag
562 if (auto siblingInt =
563 dyn_cast<IntegerAttr>(siblingReg.getValue())) {
564 master = (siblingInt.getInt() != 0);
565 }
566 break;
567 }
568 }
569
570 // Resolve port value
571 auto portIndex =
572 targetModel.resolvePortValue(valueStr, tileID, master);
573 if (!portIndex) {
574 regOp.emitError("Failed to resolve port value: ") << valueStr;
575 return signalPassFailure();
576 }
577 value = *portIndex;
578 } else {
579 regOp.emitError("Unsupported string value: ") << valueStr;
580 return signalPassFailure();
581 }
582 } else {
583 regOp.emitError("Unsupported value type in pack pass");
584 return signalPassFailure();
585 }
586
587 // Compute mask and shifted value
588 auto mask = targetModel.getFieldMask(*fieldInfo);
589 if (!mask) {
590 regOp.emitError("Invalid field bit range for register write: ")
591 << regOp.getRegName() << "." << fieldInfo->name << " ["
592 << fieldInfo->bit_start << ":" << fieldInfo->bit_end
593 << "], width=" << fieldInfo->getWidth();
594 return signalPassFailure();
595 }
596 uint32_t shiftedValue = targetModel.encodeFieldValue(*fieldInfo, value);
597 // Create new operation with mask
598 builder.setInsertionPoint(regOp);
599 builder.create<TraceRegOp>(regOp.getLoc(), regOp.getRegNameAttr(),
600 nullptr, // no field
601 builder.getI32IntegerAttr(shiftedValue),
602 builder.getI32IntegerAttr(*mask),
603 regOp.getCommentAttr());
604
605 // Remove old operation
606 regOp.erase();
607 }
608
609 // Phase 2: Merge writes to same register with non-overlapping masks
610 bool changed = true;
611 while (changed) {
612 changed = false;
613
614 // Collect all register writes
615 SmallVector<TraceRegOp> regWrites;
616 for (TraceRegOp op : configOp.getBody().front().getOps<TraceRegOp>()) {
617 if (op.getMask()) {
618 regWrites.push_back(op);
619 }
620 }
621
622 // Try to merge pairs
623 for (size_t i = 0; i < regWrites.size() && !changed; ++i) {
624 for (size_t j = i + 1; j < regWrites.size() && !changed; ++j) {
625 TraceRegOp reg1 = regWrites[i];
626 TraceRegOp reg2 = regWrites[j];
627
628 // Must be same register
629 if (reg1.getRegName() != reg2.getRegName())
630 continue;
631
632 auto mask1Attr = reg1.getMask();
633 auto mask2Attr = reg2.getMask();
634 if (!mask1Attr || !mask2Attr)
635 continue;
636
637 uint32_t mask1 = *mask1Attr;
638 uint32_t mask2 = *mask2Attr;
639
640 // Check for overlap
641 if (mask1 & mask2) {
642 reg1.emitError("Overlapping masks for register ")
643 << reg1.getRegName() << ": mask1=" << mask1
644 << " mask2=" << mask2;
645 return signalPassFailure();
646 }
647
648 // Merge the two writes
649 auto value1Attr = dyn_cast<IntegerAttr>(reg1.getValue());
650 auto value2Attr = dyn_cast<IntegerAttr>(reg2.getValue());
651 if (!value1Attr || !value2Attr) {
652 reg1.emitError(
653 "Expected integer values for packed register writes");
654 return signalPassFailure();
655 }
656 uint32_t value1 = value1Attr.getInt();
657 uint32_t value2 = value2Attr.getInt();
658 uint32_t mergedValue = value1 | value2;
659 uint32_t mergedMask = mask1 | mask2;
660
661 // Create merged operation
662 builder.setInsertionPoint(reg1);
663 std::string comment;
664 if (reg1.getComment())
665 comment += reg1.getComment()->str();
666 if (reg2.getComment()) {
667 if (!comment.empty())
668 comment += " + ";
669 comment += reg2.getComment()->str();
670 }
671
672 builder.create<TraceRegOp>(
673 reg1.getLoc(), reg1.getRegNameAttr(), nullptr,
674 builder.getI32IntegerAttr(mergedValue),
675 builder.getI32IntegerAttr(mergedMask),
676 comment.empty() ? nullptr : builder.getStringAttr(comment));
677
678 // Remove both old operations
679 reg1.erase();
680 reg2.erase();
681 changed = true;
682 }
683 }
684 }
685 });
686 }
687};
688
689} // namespace
690
691std::unique_ptr<OperationPass<DeviceOp>>
693 return std::make_unique<AIETraceRegPackWritesPass>();
694}
std::shared_ptr< Value > value()
Definition cxxopts.hpp:1026
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< DeviceOp > > createAIETraceRegPackWritesPass()
TileID { friend std::ostream &operator<<(std::ostream &os, const TileID &s) { os<< "TileID("<< s.col<< ", "<< s.row<< ")" TileID
std::unique_ptr< mlir::OperationPass< DeviceOp > > createAIETraceToConfigPass()
Bit field information for a register.
const BitFieldInfo * getField(llvm::StringRef fieldName) const