MLIR-AIE
AIEMaterializeRuntimeSequences.cpp
Go to the documentation of this file.
1//===- AIEMaterializeRuntimeSequences.cpp -----------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2025 Advanced Micro Devices Inc.
8//
9//===----------------------------------------------------------------------===//
10
16#include "aie/Targets/AIERT.h"
17
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25#include "mlir/Transforms/WalkPatternRewriteDriver.h"
26
27namespace xilinx::AIEX {
28#define GEN_PASS_DEF_AIEMATERIALIZERUNTIMESEQUENCES
29#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
30} // namespace xilinx::AIEX
31
32#define DEBUG_TYPE "aie-materialize-runtime-sequence"
33
34using namespace mlir;
35using namespace xilinx;
36using namespace xilinx::AIEX;
37
39 AnalysisManager &analysisManager;
40
41 // if invalid, analysis failed and results should not be considered
42 bool isValid = false;
43
44 // Call graph is cyclic
45 bool isCyclic = false;
46
47 RuntimeCallGraphCyclicityAnalysis(Operation *op, AnalysisManager &am)
48 : analysisManager(am) {
49 AIE::RuntimeSequenceOp runtimeSequenceOp =
50 llvm::dyn_cast<AIE::RuntimeSequenceOp>(op);
51 if (!runtimeSequenceOp) {
52 op->emitError("RuntimeCallGraphCyclicityAnalysis can only be called on "
53 "aiex.runtime_sequence operations.");
54 return;
55 }
56
57 // Use DFS with a stack to detect cycles
58 // A cycle exists if we encounter a sequence already on the current path
59 llvm::DenseSet<AIE::RuntimeSequenceOp> callStack;
60 llvm::DenseSet<AIE::RuntimeSequenceOp> visited;
61
62 std::function<bool(AIE::RuntimeSequenceOp)> hasCycle =
63 [&](AIE::RuntimeSequenceOp seq) -> bool {
64 if (callStack.contains(seq)) {
65 return true; // Found a cycle
66 }
67 if (visited.contains(seq)) {
68 return false; // Already checked this sequence
69 }
70
71 callStack.insert(seq);
72 visited.insert(seq);
73
74 // Check all sequences called by this one
75 bool foundCycle = false;
76 seq.walk([&](RunOp runOp) {
77 if (AIE::RuntimeSequenceOp callee =
78 runOp.getCalleeRuntimeSequenceOp()) {
79 if (hasCycle(callee)) {
80 foundCycle = true;
81 return WalkResult::interrupt();
82 }
83 }
84 return WalkResult::advance();
85 });
86
87 callStack.erase(seq);
88 return foundCycle;
89 };
90
91 if (hasCycle(runtimeSequenceOp)) {
92 isCyclic = true;
93 isValid = true;
94 return;
95 }
96 isCyclic = false;
97 isValid = true;
98 }
99};
100
// Insert a LoadPDI op (AIEX::NpuLoadPdiOp) that references @device at the
// start of each top-level aiex.configure @device op.
// TODO: add a check that the liveness of two aiex.configure ops does not
// overlap (i.e., after configuring A and then B, a runtime sequence of A must
// not be called anymore).
// TODO: add code to remove repeated aiex.configure ops.
107
109 PatternBenefit benefit = 1)
110 : RewritePattern(ConfigureOp::getOperationName(), benefit, context) {}
111
112 LogicalResult matchAndRewrite(Operation *op,
113 PatternRewriter &rewriter) const override {
114 ConfigureOp configureOp = llvm::dyn_cast<ConfigureOp>(op);
115 if (!configureOp) {
116 return failure();
117 }
118
119 // LoadPDI resets the whole device, hence cannot do partial reconfiguration;
120 // therefore, this only supports top-level configure ops
121 if (!llvm::isa<AIE::RuntimeSequenceOp>(configureOp->getParentOp())) {
122 return failure();
123 }
124
125 AIE::DeviceOp referencedDevice = configureOp.getReferencedDeviceOp();
126 if (!referencedDevice) {
127 configureOp.emitError("Referenced symbol is not a device");
128 return failure();
129 }
130
131 Block *configureBlock;
132 if (configureOp.getBody().empty()) {
133 configureBlock = rewriter.createBlock(&configureOp.getBody());
134 } else {
135 configureBlock = &configureOp.getBody().front();
136 }
137
138 rewriter.setInsertionPointToStart(configureBlock);
139 AIEX::NpuLoadPdiOp::create(
140 rewriter, configureOp.getLoc(),
141 FlatSymbolRefAttr::get(referencedDevice.getSymNameAttr()));
142
143 return success();
144 }
145};
146
147// Collects all external SSA values referenced by an operation (and its nested
148// operations).
149// 1. Collects SSA values from the operation's operands.
150// 2. Recursively walks through all operations in the operation's regions.
151// 3. For each nested operation, collects SSA values from its operands.
152// 4. Skips values that are already in argMap or defined within the operation.
153// 5. For memref.subview operations, traces to the root block argument
154static void
155collectReferencedSSAValues(Operation *op, const IRMapping &argMap,
156 llvm::SetVector<Value> &referencedValues) {
157
158 auto processValue = [&](Value operand) {
159 if (argMap.contains(operand)) {
160 return;
161 }
162
163 // If this is a subview, trace to the root block argument
164 if (auto traceResult = traceSubviewToBlockArgument(operand)) {
165 // Check if the root argument is already mapped
166 if (!argMap.contains(traceResult->rootArg)) {
167 referencedValues.insert(traceResult->rootArg);
168 }
169 return;
170 }
171
172 // Not a subview chain leading to block arg, add as-is
173 referencedValues.insert(operand);
174 };
175
176 // Collect SSA values from the operation's direct operands.
177 for (Value operand : op->getOperands()) {
178 processValue(operand);
179 }
180
181 // Recursively collect SSA values from nested operations in all regions.
182 for (Region &region : op->getRegions()) {
183 region.walk([&](Operation *nestedOp) {
184 for (Value operand : nestedOp->getOperands()) {
185 if (argMap.contains(operand)) {
186 return;
187 }
188
189 // Check if defined within the parent operation
190 Operation *defOp = operand.getDefiningOp();
191 if (defOp && op->isProperAncestor(defOp)) {
192 return;
193 }
194
195 processValue(operand);
196 }
197 });
198 }
199}
200
201// Copies SSA value definitions into the caller device.
202// Currently, only `aie.tile` operations are supported.
203// Updates argMap to map old values to new/existing values.
204static LogicalResult
205copyReferencedSSAValues(PatternRewriter &rewriter,
206 const llvm::SetVector<Value> &referencedValues,
207 AIE::DeviceOp callerDevice, IRMapping &argMap,
208 mlir::OpBuilder::InsertPoint &clonedSSAInsertPoint,
209 Operation *errorReportOp) {
210
211 llvm::SetVector<Value> referencedValuesToVisit = referencedValues;
212 std::vector<Operation *> referencedOpsToClone = {};
213 while (!referencedValuesToVisit.empty()) {
214 Value referencedValue = referencedValuesToVisit.pop_back_val();
215 Operation *definingOp = referencedValue.getDefiningOp();
216 if (!definingOp) {
217 return errorReportOp->emitError()
218 << "Referenced value is not defined by an operation";
219 }
220 if (std::find(referencedOpsToClone.begin(), referencedOpsToClone.end(),
221 definingOp) != referencedOpsToClone.end()) {
222 continue;
223 }
224
225 if (auto tileOp = llvm::dyn_cast<AIE::TileOp>(definingOp)) {
226 referencedOpsToClone.insert(referencedOpsToClone.begin(), definingOp);
227 } else if (auto lockOp = llvm::dyn_cast<AIE::LockOp>(definingOp)) {
228 Value lockTile = lockOp.getTile();
229 if (lockTile) {
230 referencedValuesToVisit.insert(lockTile);
231 }
232 referencedOpsToClone.push_back(definingOp);
233 } else {
234 return errorReportOp->emitError()
235 << "Referenced SSA value defined by unsupported operation type: "
236 << definingOp->getName().getStringRef()
237 << ". Currently only aie.tile and aie.lock operations are "
238 "supported.";
239 }
240 }
241
242 for (Operation *definingOp : referencedOpsToClone) {
243 if (auto tileOp = llvm::dyn_cast<AIE::TileOp>(definingOp)) {
244 int col = tileOp.getCol();
245 int row = tileOp.getRow();
246
247 rewriter.restoreInsertionPoint(clonedSSAInsertPoint);
248 mlir::Operation *clonedTile = nullptr;
249
250 // Check if a tile with matching col/row already exists in the caller
251 // device
252 AIE::TileOp existingTile = nullptr;
253 for (AIE::TileOp tile : callerDevice.getOps<AIE::TileOp>()) {
254 if (tile.getCol() == col && tile.getRow() == row) {
255 existingTile = tile;
256 break;
257 }
258 }
259
260 if (existingTile) {
261 clonedTile = existingTile.getOperation();
262 // Verify that all attributes match
263 if (tileOp->getAttrDictionary() != existingTile->getAttrDictionary()) {
264 // Filter out result type attributes and symbol attributes for
265 // comparison
266 auto filterAttrs = [](DictionaryAttr dict) -> DictionaryAttr {
267 SmallVector<NamedAttribute> filteredAttrs;
268 for (auto namedAttr : dict) {
269 StringRef name = namedAttr.getName().getValue();
270 if (name != "col" && name != "row") {
271 filteredAttrs.push_back(namedAttr);
272 }
273 }
274 return DictionaryAttr::get(dict.getContext(), filteredAttrs);
275 };
276
277 DictionaryAttr tileAttrs = filterAttrs(tileOp->getAttrDictionary());
278 DictionaryAttr existingAttrs =
279 filterAttrs(existingTile->getAttrDictionary());
280
281 if (tileAttrs != existingAttrs) {
282 return errorReportOp->emitError()
283 << "aie.tile(" << col << ", " << row
284 << ") already exists in the device with different "
285 "attributes";
286 }
287 }
288 } else {
289 // Clone the tile operation into the caller device
290 rewriter.restoreInsertionPoint(clonedSSAInsertPoint);
291 clonedTile = rewriter.clone(*tileOp);
292 clonedSSAInsertPoint = rewriter.saveInsertionPoint();
293 }
294
295 argMap.map(definingOp->getResult(0), clonedTile->getResult(0));
296 rewriter.replaceOpUsesWithIf(
297 definingOp, clonedTile->getResult(0), [&](OpOperand &operand) {
298 return operand.getOwner()->getParentOfType<AIE::DeviceOp>() ==
299 callerDevice;
300 });
301
302 } else if (auto lockOp = llvm::dyn_cast<AIE::LockOp>(definingOp)) {
303 rewriter.restoreInsertionPoint(clonedSSAInsertPoint);
304 Operation *clonedLock = rewriter.clone(*lockOp, argMap);
305 clonedSSAInsertPoint = rewriter.saveInsertionPoint();
306 rewriter.replaceOpUsesWithIf(
307 definingOp, clonedLock->getResult(0), [&](OpOperand &operand) {
308 return operand.getOwner()->getParentOfType<AIE::DeviceOp>() ==
309 callerDevice;
310 });
311 } else {
312 return errorReportOp->emitError()
313 << "Referenced SSA value defined by unsupported operation type: "
314 << definingOp->getName().getStringRef()
315 << ". Currently only aie.tile and aie.lock operations are "
316 "supported.";
317 }
318 }
319
320 return success();
321}
322
323// Inlines the definitions of all symbols referenced in the given operation
324// at the current insertion point in the given rewriter, unless the symbol
325// definition is in the "previouslyInlinedSymbolMap" map. While inlining,
326// symbols will be renamed to have a unique name.
327// Also copies in SSA values referenced by the inlined symbol definitions.
328static LogicalResult inlineReferencedSymbolDefinitions(
329 PatternRewriter &rewriter, Operation *op, Operation *lookupFrom,
330 IRMapping argMap,
331 llvm::DenseMap<SymbolRefAttr, SymbolRefAttr> &previouslyInlinedSymbolMap,
332 AIE::DeviceOp callerDevice,
333 mlir::OpBuilder::InsertPoint &clonedDefOpsInsertionPoint,
334 llvm::SetVector<SymbolRefAttr> &allSymbolNames) {
335 MLIRContext *ctx = op->getContext();
336 for (NamedAttribute namedAttr : op->getAttrs()) {
337 Attribute attr = namedAttr.getValue();
338 auto newAttr = attr.replace([&](SymbolRefAttr oldSymbolRef) {
339 SymbolRefAttr newSymbolRef;
340 if (!previouslyInlinedSymbolMap.count(oldSymbolRef)) {
341 llvm::StringRef oldName = oldSymbolRef.getRootReference().getValue();
342 std::string uniqueName = oldName.str();
343 unsigned uniquingCounter = 0;
344 while (allSymbolNames.count(SymbolRefAttr::get(ctx, uniqueName))) {
345 uniqueName = oldName.str() + "_" + std::to_string(uniquingCounter);
346 uniquingCounter++;
347 }
348 newSymbolRef = SymbolRefAttr::get(ctx, uniqueName);
349 allSymbolNames.insert(newSymbolRef);
350 previouslyInlinedSymbolMap[oldSymbolRef] = newSymbolRef;
351
352 // Add the new symbol definition
353 // First try to look up from the lookupFrom operation (e.g., within the
354 // callee device). If not found, try looking up from the module level
355 // (for cross-device references).
356 Operation *symbolDefOp =
357 SymbolTable::lookupNearestSymbolFrom(lookupFrom, oldSymbolRef);
358 if (!symbolDefOp) {
359 if (ModuleOp moduleOp = lookupFrom->getParentOfType<ModuleOp>()) {
360 symbolDefOp = SymbolTable::lookupSymbolIn(moduleOp, oldSymbolRef);
361 }
362 }
363 if (!symbolDefOp) {
364 return std::make_pair(newSymbolRef, WalkResult::interrupt());
365 }
366
367 // If the symbol is a device, don't clone it - keep the original
368 // reference. Device ops must stay at module level.
369 if (llvm::isa<AIE::DeviceOp>(symbolDefOp)) {
370 return std::make_pair(oldSymbolRef, WalkResult::advance());
371 }
372
373 // Collect SSA values referenced by the symbol definition operation
374 llvm::SetVector<Value> symbolReferencedValues;
375 collectReferencedSSAValues(symbolDefOp, argMap, symbolReferencedValues);
376
377 // Copy SSA values referenced by the symbol definition
378 // This updates clonedDefOpsInsertionPoint to be after the copied SSA
379 // values
380 if (failed(copyReferencedSSAValues(rewriter, symbolReferencedValues,
381 callerDevice, argMap,
382 clonedDefOpsInsertionPoint, op))) {
383 return std::make_pair(newSymbolRef, WalkResult::interrupt());
384 }
385
386 // Insert the cloned symbol at the device level, after its SSA
387 // dependencies
388 rewriter.restoreInsertionPoint(clonedDefOpsInsertionPoint);
389 Operation *clonedSymbolDefOp = rewriter.clone(*symbolDefOp, argMap);
390 clonedSymbolDefOp->setAttr(SymbolTable::getSymbolAttrName(),
391 StringAttr::get(ctx, uniqueName));
392 clonedDefOpsInsertionPoint = rewriter.saveInsertionPoint();
393 } else {
394 newSymbolRef = previouslyInlinedSymbolMap[oldSymbolRef];
395 }
396
397 return std::make_pair(newSymbolRef, WalkResult::advance());
398 });
399 if (!newAttr) {
400 return failure();
401 }
402 op->setAttr(namedAttr.getName(), newAttr);
403 }
404 return success();
405}
406
408
409 mlir::OpBuilder::InsertPoint &ssaDefInsertPoint;
410 mlir::OpBuilder::InsertPoint &symbolDefInsertPoint;
411 llvm::SetVector<SymbolRefAttr> &allSymbolNames;
412
414 mlir::OpBuilder::InsertPoint &ssaDefInsertPoint,
415 mlir::OpBuilder::InsertPoint &symbolDefInsertPoint,
416 llvm::SetVector<SymbolRefAttr> &allSymbolNames)
417 : RewritePattern(RunOp::getOperationName(), PatternBenefit(1), ctx),
421
422 LogicalResult matchAndRewrite(Operation *op,
423 PatternRewriter &rewriter) const override {
424 llvm::DenseMap<SymbolRefAttr, SymbolRefAttr> previouslyInlinedSymbolMap;
425
426 RunOp runOp = llvm::dyn_cast<RunOp>(op);
427 if (!runOp) {
428 return failure();
429 }
430
431 AIE::DeviceOp calleeDevice = runOp.getCalleeDeviceOp();
432 AIE::RuntimeSequenceOp calleeRuntimeSequence =
433 runOp.getCalleeRuntimeSequenceOp();
434 if (!calleeDevice || !calleeRuntimeSequence) {
435 return failure();
436 }
437
438 // rewrite logic
439
440 // Get caller and callee bodies. The callee body will be inlined into the
441 // caller body at the point of the RunOp.
442 Region &calleeBody = calleeRuntimeSequence.getBody();
443 AIE::DeviceOp callerDevice =
444 runOp.getOperation()->getParentOfType<AIE::DeviceOp>();
445 if (!callerDevice) {
446 runOp.emitError() << "needs to be in a DeviceOp";
447 return failure();
448 }
449
450 // The argMap maps callee arguments to caller SSA values.
451 IRMapping argMap;
452 ValueRange values = runOp.getArgs();
453 for (unsigned i = 0, n = calleeBody.getNumArguments(); i < n; i++) {
454 BlockArgument arg = calleeBody.getArgument(i);
455 Value val = values[i];
456 argMap.map(arg, val);
457 }
458
459 // The callee body may reference SSA values and symbols that are defined
460 // in the callee device (outside the callee runtime sequence). We will
461 // inline a supported set of these and error otherwise.
462
463 // Collect SSA values referenced in the callee not defined by the callee and
464 // not in the argMap.
465 llvm::SetVector<Value> referencedValues;
466 for (Operation &op : calleeBody.getOps()) {
467 collectReferencedSSAValues(&op, argMap, referencedValues);
468 }
469 llvm::SetVector<Value> filteredValues;
470 for (Value val : referencedValues) {
471 if (val.getParentRegion() != &calleeBody) {
472 filteredValues.insert(val);
473 }
474 }
475 referencedValues = std::move(filteredValues);
476
477 // Copy the operations that define these SSA values into the caller device
478 if (failed(copyReferencedSSAValues(rewriter, referencedValues, callerDevice,
479 argMap, ssaDefInsertPoint, runOp))) {
480 return failure();
481 }
482
483 // Now, also inline symbol definitions referenced in the callee body;
484 // this may pull in additional SSA values referenced by the symbol
485 // definitions.
486 rewriter.setInsertionPoint(runOp);
487 mlir::OpBuilder::InsertPoint clonedOpInsertionPoint =
488 rewriter.saveInsertionPoint();
489 for (Operation &op : calleeBody.getOps()) {
490 rewriter.restoreInsertionPoint(clonedOpInsertionPoint);
491 Operation *clonedOp = rewriter.clone(op, argMap);
492 clonedOpInsertionPoint = rewriter.saveInsertionPoint();
493
494 if (failed(inlineReferencedSymbolDefinitions(
495 rewriter, clonedOp, calleeRuntimeSequence.getOperation(), argMap,
496 previouslyInlinedSymbolMap, callerDevice, symbolDefInsertPoint,
497 allSymbolNames))) {
498 return failure();
499 }
500 }
501
502 // The aiex.run op has been inlined; erase it.
503 rewriter.eraseOp(runOp);
504
505 return success();
506 }
507};
508
509/// Validate all aiex.run ops inside a aiex.configure op against the referenced
510/// device. This must be called sequentially (before any parallel pass
511/// execution) because it performs cross-DeviceOp symbol table lookups.
512/// Placing this validation in RunOp::verify() is unsafe: MLIR's pass manager
513/// may invoke op verifiers concurrently on sibling DeviceOps, causing a data
514/// race on the referenced device's symbol table.
515static LogicalResult verifyRunOpsInConfigureOp(ConfigureOp configureOp,
516 AIE::DeviceOp referencedDev) {
517 if (configureOp.getBody().empty())
518 return success();
519 for (RunOp runOp : configureOp.getBody().front().getOps<RunOp>()) {
520 auto seqName = runOp.getRuntimeSequenceSymbol();
521 Operation *maybeSeq = SymbolTable::lookupSymbolIn(referencedDev, seqName);
522 if (!maybeSeq) {
523 auto err = runOp.emitError()
524 << "No such runtime sequence for device '"
525 << referencedDev.getSymName() << "': '" << seqName << "'";
526 err.attachNote(referencedDev.getLoc())
527 << "This device does not have a '" << seqName << "' runtime sequence";
528 return failure();
529 }
530 auto runtimeSeq = llvm::dyn_cast<AIE::RuntimeSequenceOp>(maybeSeq);
531 if (!runtimeSeq) {
532 return runOp.emitError()
533 << "'" << seqName << "' is not a runtime sequence";
534 }
535
536 // Validate argument count and types against the callee signature.
537 // An empty body region (no blocks) means the sequence takes no arguments.
538 Region &calleeRegion = runtimeSeq.getBody();
539 unsigned numCalleeArgs =
540 calleeRegion.empty() ? 0 : calleeRegion.getNumArguments();
541 ValueRange args = runOp.getArgs();
542 if (numCalleeArgs != args.size())
543 return runOp.emitOpError() << "argument count mismatch: expected "
544 << numCalleeArgs << " but got " << args.size();
545 for (unsigned i = 0; i < numCalleeArgs; i++) {
546 Type expected = calleeRegion.front().getArgument(i).getType();
547 Type actual = args[i].getType();
548 if (expected != actual)
549 return runOp.emitOpError()
550 << "argument " << i << " type mismatch: expected " << expected
551 << " but got " << actual;
552 }
553 }
554 return success();
555}
556
559 AIEMaterializeRuntimeSequencesPass> {
/// Materializes runtime sequences in every aie.device of the module:
///  1. Validates each runtime sequence (verifyBeforeMaterialization) and the
///     aiex.run ops nested in its top-level aiex.configure ops.
///  2. Rejects runtime sequences whose aiex.run call graph is cyclic.
///  3. Inlines aiex.run calls greedily (InlineRuntimeCallsPattern).
///  4. Inserts a LoadPDI op into each top-level aiex.configure op, then
///     canonicalizes to remove duplicate back-to-back LoadPDI ops.
///  5. Hoists the contents of each top-level aiex.configure op into the
///     enclosing runtime sequence and erases the configure op.
/// Processing stops, and pass failure is signalled, at the first device that
/// fails a check above or whose greedy rewrite reports failure.
/// NOTE(review): walkAndApplyPatterns does not report pattern failures, so an
/// error emitted by InsertLoadPdiForConfigurePattern does not by itself fail
/// the pass.
void runOnOperation() override {
  ModuleOp moduleOp = getOperation();

  // Process each device in the module
  for (AIE::DeviceOp deviceOp : moduleOp.getOps<AIE::DeviceOp>()) {

    // Verify all runtime sequences before materialization
    for (AIE::RuntimeSequenceOp runtimeSequenceOp :
         deviceOp.getOps<AIE::RuntimeSequenceOp>()) {
      if (failed(runtimeSequenceOp.verifyBeforeMaterialization())) {
        return signalPassFailure();
      }

      // Validate aiex.run ops inside aiex.configure ops. This cross-device
      // check is performed here (sequentially) rather than in RunOp::verify()
      // to avoid a race condition: MLIR's pass manager runs verifiers on
      // sibling DeviceOps concurrently, and looking up symbols in a sibling
      // DeviceOp from a verifier causes a data race on its symbol table.
      for (ConfigureOp configureOp :
           runtimeSequenceOp.getOps<ConfigureOp>()) {
        AIE::DeviceOp referencedDev = configureOp.getReferencedDeviceOp();
        if (!referencedDev) {
          // ConfigureOp::verify() already reported the error (no such
          // device, not a device, or device type mismatch) at parse time;
          // this is a safety net if verification was disabled.
          return signalPassFailure();
        }
        if (failed(verifyRunOpsInConfigureOp(configureOp, referencedDev)))
          return signalPassFailure();
      }
    }

    // Check for cycles in runtime sequence calls. The analysis is requested
    // through an AnalysisManager nested on device -> runtime sequence. A
    // result with isValid == false means the analysis itself already emitted
    // an error.
    for (AIE::RuntimeSequenceOp runtimeSequenceOp :
         deviceOp.getOps<AIE::RuntimeSequenceOp>()) {
      AnalysisManager am =
          getAnalysisManager().nest(deviceOp).nest(runtimeSequenceOp);
          am.getAnalysis<RuntimeCallGraphCyclicityAnalysis>();
      if (!cyclicity.isValid) {
        return signalPassFailure();
      }
      if (cyclicity.isCyclic) {
        runtimeSequenceOp.emitError(
            "Runtime sequence call graph contains a cycle");
        return signalPassFailure();
      }
    }

    // Greedily inline all runtime sequences that can be inlined;
    // this will start with runtime sequences that do not call other runtime
    // sequences (leaves); once their callers inline them, the callers can
    // be inlined as well, and so on
    mlir::Block &deviceBodyFirstBlock = deviceOp.getBodyRegion().front();
    auto runtimeSequenceOps = deviceOp.getOps<AIE::RuntimeSequenceOp>();
    if (runtimeSequenceOps.begin() == runtimeSequenceOps.end()) {
      // No runtime sequences to materialize
      continue;
    }
    AIE::RuntimeSequenceOp firstRuntimeSequenceOp =
        *runtimeSequenceOps.begin();
    // Copied SSA definitions (tiles, locks) go at the very start of the
    // device body. Inlined symbol definitions go right before the first
    // runtime sequence. The pattern holds these insert points by reference
    // and advances them as it clones.
    mlir::OpBuilder::InsertPoint ssaDefInsertPoint(
        &deviceBodyFirstBlock, deviceBodyFirstBlock.begin());
    mlir::OpBuilder::InsertPoint symbolDefInsertPoint(
        &deviceBodyFirstBlock, mlir::Block::iterator(firstRuntimeSequenceOp));
    // Seed the set of taken symbol names with every symbol defined directly
    // in the device body, so inlined definitions get non-clashing names.
    llvm::SetVector<SymbolRefAttr> allSymbolNames = {};
    for (Operation &op : deviceBodyFirstBlock) {
      if (auto symbolName = op.getAttrOfType<StringAttr>(
              SymbolTable::getSymbolAttrName())) {
        allSymbolNames.insert(SymbolRefAttr::get(symbolName));
      }
    }

    MLIRContext *ctx = &getContext();
    GreedyRewriteConfig rewriter_config = GreedyRewriteConfig();
    rewriter_config.setRegionSimplificationLevel(
        GreedySimplifyRegionLevel::Disabled);

    // NOTE(review): applyPatternsGreedily returns failure only when the
    // rewrite does not converge; an individual pattern failure (even one that
    // emitted an error) does not make it fail.
    RewritePatternSet patterns_0(ctx);
    patterns_0.insert<InlineRuntimeCallsPattern>(
        ctx, ssaDefInsertPoint, symbolDefInsertPoint, allSymbolNames);
    if (failed(applyPatternsGreedily(deviceOp, std::move(patterns_0),
                                     rewriter_config))) {
      return signalPassFailure();
    }

    // Insert LoadPDI ops for each aiex.configure op. This is a single walk
    // (not a fixpoint), since the pattern always inserts a new op.
    RewritePatternSet patterns_1(ctx);
    patterns_1.insert<InsertLoadPdiForConfigurePattern>(ctx);
    walkAndApplyPatterns(deviceOp, std::move(patterns_1));

    // Canonicalize to remove duplicate back-to-back load_pdi ops
    RewritePatternSet canonicalize_patterns(ctx);
    AIEX::NpuLoadPdiOp::getCanonicalizationPatterns(canonicalize_patterns,
                                                    ctx);
    if (failed(applyPatternsGreedily(
            deviceOp, std::move(canonicalize_patterns), rewriter_config))) {
      return signalPassFailure();
    }

    // Flatten the IR: hoist all operations inside aiex.configure to be direct
    // children of the runtime sequence, preserving order. Only configure ops
    // that are direct children of a runtime sequence are flattened.
    for (AIE::RuntimeSequenceOp runtimeSequenceOp :
         deviceOp.getOps<AIE::RuntimeSequenceOp>()) {
      // Snapshot the configure ops first, because they are erased below.
      SmallVector<ConfigureOp> configureOps;

      for (ConfigureOp configureOp :
           runtimeSequenceOp.getOps<ConfigureOp>()) {
        configureOps.push_back(configureOp);
      }

      IRRewriter rewriter(ctx);
      for (ConfigureOp configureOp : configureOps) {
        // Assumes the configure body has a block (the LoadPDI step above
        // creates one for top-level configure ops).
        Block &configureBlock = configureOp.getBody().front();

        // Collect all operations in the configure block
        SmallVector<Operation *> opsToHoist;
        for (Operation &op : configureBlock) {
          opsToHoist.push_back(&op);
        }

        // Hoist operations to be right before the configure op
        rewriter.setInsertionPoint(configureOp);
        for (Operation *op : opsToHoist) {
          op->moveBefore(configureOp);
        }

        // Erase the now-empty configure op
        rewriter.eraseOp(configureOp);
      }
    }

  } // end for each device
}
694};
695
696std::unique_ptr<OperationPass<ModuleOp>>
698 return std::make_unique<AIEMaterializeRuntimeSequencesPass>();
699}
std::optional< SubviewTraceResult > traceSubviewToBlockArgument(Value value)
Definition AIEUtils.cpp:19
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createAIEMaterializeRuntimeSequencesPass()
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
mlir::OpBuilder::InsertPoint & ssaDefInsertPoint
InlineRuntimeCallsPattern(MLIRContext *ctx, mlir::OpBuilder::InsertPoint &ssaDefInsertPoint, mlir::OpBuilder::InsertPoint &symbolDefInsertPoint, llvm::SetVector< SymbolRefAttr > &allSymbolNames)
mlir::OpBuilder::InsertPoint & symbolDefInsertPoint
llvm::SetVector< SymbolRefAttr > & allSymbolNames
InsertLoadPdiForConfigurePattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
RuntimeCallGraphCyclicityAnalysis(Operation *op, AnalysisManager &am)