18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25#include "mlir/Transforms/WalkPatternRewriteDriver.h"
28#define GEN_PASS_DEF_AIEMATERIALIZERUNTIMESEQUENCES
29#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
32#define DEBUG_TYPE "aie-materialize-runtime-sequence"
// Body fragment of the runtime-sequence call-graph cyclicity analysis
// (presumably the RuntimeCallGraphCyclicityAnalysis constructor — its
// signature is not part of this listing). Several original lines (branch
// bodies, closing braces) are absent from the listing.
//
// The analysis only accepts a runtime sequence; any other operation is
// reported through emitError.
49 AIE::RuntimeSequenceOp runtimeSequenceOp =
50 llvm::dyn_cast<AIE::RuntimeSequenceOp>(op);
51 if (!runtimeSequenceOp) {
52 op->emitError(
"RuntimeCallGraphCyclicityAnalysis can only be called on "
53 "aiex.runtime_sequence operations.");
// Bookkeeping for the depth-first search below: `callStack` receives each
// sequence when the search enters it, `visited` is consulted to avoid
// re-exploring a sequence.
59 llvm::DenseSet<AIE::RuntimeSequenceOp> callStack;
60 llvm::DenseSet<AIE::RuntimeSequenceOp> visited;
// Recursive search over the sequences reachable through RunOp callees.
// A std::function is used so the lambda can refer to itself by name.
62 std::function<
bool(AIE::RuntimeSequenceOp)> hasCycle =
63 [&](AIE::RuntimeSequenceOp seq) ->
bool {
// NOTE(review): the bodies of the next two checks are not in this listing;
// presumably "already on the call stack" reports a cycle and "already
// visited" reports none — confirm against the full source.
64 if (callStack.contains(seq)) {
67 if (visited.contains(seq)) {
71 callStack.insert(seq);
75 bool foundCycle =
false;
// Walk every RunOp nested in `seq`; when its callee resolves to a runtime
// sequence, recurse into it and stop the walk as soon as a cycle is found.
76 seq.walk([&](RunOp runOp) {
77 if (AIE::RuntimeSequenceOp callee =
78 runOp.getCalleeRuntimeSequenceOp()) {
79 if (hasCycle(callee)) {
81 return WalkResult::interrupt();
84 return WalkResult::advance();
// Start the search from the sequence this analysis was constructed on.
91 if (hasCycle(runtimeSequenceOp)) {
// Rewrite pattern rooted at ConfigureOp (class header and the start of the
// constructor signature are not part of this listing).
109 PatternBenefit benefit = 1)
110 :
RewritePattern(ConfigureOp::getOperationName(), benefit, context) {}
// matchAndRewrite (start of the signature not in this listing): emits an
// NpuLoadPdiOp at the beginning of the configure body that references the
// configured device by symbol name.
113 PatternRewriter &rewriter)
const override {
114 ConfigureOp configureOp = llvm::dyn_cast<ConfigureOp>(op);
// Checks that the direct parent is a runtime sequence.
// NOTE(review): the branch body is missing from this listing; presumably the
// pattern bails out for other parents — confirm.
121 if (!llvm::isa<AIE::RuntimeSequenceOp>(configureOp->getParentOp())) {
// The referenced symbol must resolve to a device; otherwise an error is
// emitted on the configure op.
125 AIE::DeviceOp referencedDevice = configureOp.getReferencedDeviceOp();
126 if (!referencedDevice) {
127 configureOp.emitError(
"Referenced symbol is not a device");
// Use the existing first block of the body, creating one when the body region
// is empty, so there is always a block to insert into.
131 Block *configureBlock;
132 if (configureOp.getBody().empty()) {
133 configureBlock = rewriter.createBlock(&configureOp.getBody());
135 configureBlock = &configureOp.getBody().front();
// Insert the load-PDI op before anything already in the configure body.
138 rewriter.setInsertionPointToStart(configureBlock);
139 AIEX::NpuLoadPdiOp::create(
140 rewriter, configureOp.getLoc(),
141 FlatSymbolRefAttr::get(referencedDevice.getSymNameAttr()));
// Collects into `referencedValues` the SSA values used by `op` and by the
// operations nested in its regions.
//
// Parameters:
//   op               - operation whose operands (direct and nested) are scanned.
//   argMap           - mapping consulted to recognise values that are already
//                      mapped.
//   referencedValues - output set receiving the collected values.
//
// NOTE(review): the return-type line and several statement lines of this
// function are absent from this listing (the embedded original line numbers
// skip); comments below only describe what the visible lines establish.
155collectReferencedSSAValues(Operation *op,
const IRMapping &argMap,
156 llvm::SetVector<Value> &referencedValues) {
// Per-operand handler shared by the direct-operand loop and the region walk.
158 auto processValue = [&](Value operand) {
// Operands already present in argMap get separate treatment (branch body not
// shown in this listing).
159 if (argMap.contains(operand)) {
// `traceResult` is produced on lines missing from this listing — presumably
// by traceSubviewToBlockArgument; confirm. Its root argument is recorded
// unless argMap already maps it.
166 if (!argMap.contains(traceResult->rootArg)) {
167 referencedValues.insert(traceResult->rootArg);
// Otherwise the operand itself is recorded.
173 referencedValues.insert(operand);
// Direct operands of `op`.
177 for (Value operand : op->getOperands()) {
178 processValue(operand);
// Operands of every operation nested inside the regions of `op`.
// (The reference binding below was mangled to "®ion" by HTML-entity decoding
// in the extracted listing; restored to "&region", matching the use of
// `region` on the next line.)
182 for (Region &region : op->getRegions()) {
183 region.walk([&](Operation *nestedOp) {
184 for (Value operand : nestedOp->getOperands()) {
185 if (argMap.contains(operand)) {
// Operands whose defining op lives strictly inside `op` are checked for here;
// the branch body is not in this listing (presumably they are skipped since
// they travel with `op` — confirm).
190 Operation *defOp = operand.getDefiningOp();
191 if (defOp && op->isProperAncestor(defOp)) {
195 processValue(operand);
// Makes the operations defining `referencedValues` available in
// `callerDevice`, reusing or cloning aie.tile ops and cloning aie.lock ops,
// and redirects uses to the reused/cloned results.
//
// Parameters:
//   rewriter             - rewriter used for cloning and use replacement.
//   referencedValues     - values whose defining operations must be available.
//   callerDevice         - device searched for pre-existing tiles.
//   argMap               - receives the mapping from original tile results to
//                          their reused/cloned counterparts; also used when
//                          cloning locks.
//   clonedSSAInsertPoint - insertion point for clones; updated after each
//                          clone.
//   errorReportOp        - operation on which diagnostics are emitted.
//
// Returns the emitted diagnostic (a failure) when a value has no defining
// operation, is defined by an unsupported operation, or conflicts with an
// existing tile. NOTE(review): the return-type line, the success return and
// several other lines are absent from this listing.
205copyReferencedSSAValues(PatternRewriter &rewriter,
206 const llvm::SetVector<Value> &referencedValues,
207 AIE::DeviceOp callerDevice, IRMapping &argMap,
208 mlir::OpBuilder::InsertPoint &clonedSSAInsertPoint,
209 Operation *errorReportOp) {
// Phase 1: worklist traversal that builds the list of defining operations to
// clone.
211 llvm::SetVector<Value> referencedValuesToVisit = referencedValues;
212 std::vector<Operation *> referencedOpsToClone = {};
213 while (!referencedValuesToVisit.empty()) {
214 Value referencedValue = referencedValuesToVisit.pop_back_val();
215 Operation *definingOp = referencedValue.getDefiningOp();
// NOTE(review): the guarding condition for this error is on a line missing
// from the listing; judging by the message it fires when definingOp is null.
217 return errorReportOp->emitError()
218 <<
"Referenced value is not defined by an operation";
// Linear search to detect operations that were already queued.
220 if (std::find(referencedOpsToClone.begin(), referencedOpsToClone.end(),
221 definingOp) != referencedOpsToClone.end()) {
// Tiles are placed at the front of the list while locks are appended, so in
// phase 2 every tile is processed before any lock.
225 if (
auto tileOp = llvm::dyn_cast<AIE::TileOp>(definingOp)) {
226 referencedOpsToClone.insert(referencedOpsToClone.begin(), definingOp);
227 }
else if (
auto lockOp = llvm::dyn_cast<AIE::LockOp>(definingOp)) {
// A lock's tile is queued as well so that it is handled too.
228 Value lockTile = lockOp.getTile();
230 referencedValuesToVisit.insert(lockTile);
232 referencedOpsToClone.push_back(definingOp);
// Any other defining operation is rejected.
234 return errorReportOp->emitError()
235 <<
"Referenced SSA value defined by unsupported operation type: "
236 << definingOp->getName().getStringRef()
237 <<
". Currently only aie.tile and aie.lock operations are "
// Phase 2: materialise each queued operation in the caller device.
242 for (Operation *definingOp : referencedOpsToClone) {
243 if (
auto tileOp = llvm::dyn_cast<AIE::TileOp>(definingOp)) {
244 int col = tileOp.getCol();
245 int row = tileOp.getRow();
247 rewriter.restoreInsertionPoint(clonedSSAInsertPoint);
248 mlir::Operation *clonedTile =
nullptr;
// Look for a tile with the same coordinates already in the caller device.
252 AIE::TileOp existingTile =
nullptr;
253 for (AIE::TileOp tile : callerDevice.getOps<
AIE::TileOp>()) {
254 if (tile.getCol() == col && tile.getRow() == row) {
// Reuse path: an existing tile stands in for the clone.
261 clonedTile = existingTile.getOperation();
// When the attribute dictionaries differ, compare them again with "col" and
// "row" removed; a remaining difference is reported as an error.
263 if (tileOp->getAttrDictionary() != existingTile->getAttrDictionary()) {
266 auto filterAttrs = [](DictionaryAttr dict) -> DictionaryAttr {
267 SmallVector<NamedAttribute> filteredAttrs;
268 for (
auto namedAttr : dict) {
269 StringRef name = namedAttr.getName().getValue();
270 if (name !=
"col" && name !=
"row") {
271 filteredAttrs.push_back(namedAttr);
274 return DictionaryAttr::get(dict.getContext(), filteredAttrs);
277 DictionaryAttr tileAttrs = filterAttrs(tileOp->getAttrDictionary());
278 DictionaryAttr existingAttrs =
279 filterAttrs(existingTile->getAttrDictionary());
281 if (tileAttrs != existingAttrs) {
282 return errorReportOp->emitError()
283 <<
"aie.tile(" <<
col <<
", " <<
row
284 <<
") already exists in the device with different "
// Clone path: clone the tile at the shared insertion point and advance that
// point past the clone.
290 rewriter.restoreInsertionPoint(clonedSSAInsertPoint);
291 clonedTile = rewriter.clone(*tileOp);
292 clonedSSAInsertPoint = rewriter.saveInsertionPoint();
// Record the tile mapping and redirect uses selected by the predicate, which
// compares the user's enclosing device (the right-hand side of the comparison
// is not in this listing).
295 argMap.map(definingOp->getResult(0), clonedTile->getResult(0));
296 rewriter.replaceOpUsesWithIf(
297 definingOp, clonedTile->getResult(0), [&](OpOperand &operand) {
298 return operand.getOwner()->getParentOfType<AIE::DeviceOp>() ==
302 }
else if (
auto lockOp = llvm::dyn_cast<AIE::LockOp>(definingOp)) {
// Locks are always cloned; cloning through argMap remaps operands that were
// mapped above (e.g. the lock's tile).
303 rewriter.restoreInsertionPoint(clonedSSAInsertPoint);
304 Operation *clonedLock = rewriter.clone(*lockOp, argMap);
305 clonedSSAInsertPoint = rewriter.saveInsertionPoint();
306 rewriter.replaceOpUsesWithIf(
307 definingOp, clonedLock->getResult(0), [&](OpOperand &operand) {
308 return operand.getOwner()->getParentOfType<AIE::DeviceOp>() ==
// Same rejection as in phase 1 for any other operation kind.
312 return errorReportOp->emitError()
313 <<
"Referenced SSA value defined by unsupported operation type: "
314 << definingOp->getName().getStringRef()
315 <<
". Currently only aie.tile and aie.lock operations are "
// Rewrites every SymbolRefAttr found in the attributes of `op` so that it
// refers to a copy of the symbol's definition cloned at
// `clonedDefOpsInsertionPoint`, cloning that definition (and the SSA values it
// references) the first time a symbol is seen.
//
// Parameters:
//   rewriter                   - rewriter used for cloning.
//   op                         - operation whose attributes are rewritten.
//   lookupFrom                 - operation from which symbols are resolved.
//   previouslyInlinedSymbolMap - cache from original symbol reference to the
//                                reference of its clone.
//   callerDevice               - forwarded to copyReferencedSSAValues.
//   clonedDefOpsInsertionPoint - where cloned definitions go; updated after
//                                each clone.
//   allSymbolNames             - names already taken, used to pick unique
//                                names; extended with each new name.
// NOTE(review): one parameter line (original line 330, presumably the
// `argMap` used below) and several body lines are absent from this listing.
328static LogicalResult inlineReferencedSymbolDefinitions(
329 PatternRewriter &rewriter, Operation *op, Operation *lookupFrom,
331 llvm::DenseMap<SymbolRefAttr, SymbolRefAttr> &previouslyInlinedSymbolMap,
332 AIE::DeviceOp callerDevice,
333 mlir::OpBuilder::InsertPoint &clonedDefOpsInsertionPoint,
334 llvm::SetVector<SymbolRefAttr> &allSymbolNames) {
335 MLIRContext *ctx = op->getContext();
336 for (NamedAttribute namedAttr : op->getAttrs()) {
337 Attribute attr = namedAttr.getValue();
// attr.replace invokes the callback for each SymbolRefAttr nested in `attr`.
338 auto newAttr = attr.replace([&](SymbolRefAttr oldSymbolRef) {
339 SymbolRefAttr newSymbolRef;
// First encounter of this symbol: choose a name, clone the definition.
340 if (!previouslyInlinedSymbolMap.count(oldSymbolRef)) {
341 llvm::StringRef oldName = oldSymbolRef.getRootReference().getValue();
342 std::string uniqueName = oldName.str();
343 unsigned uniquingCounter = 0;
// Append "_<counter>" while the candidate name is already taken.
// NOTE(review): the counter update is not visible in this listing (original
// line 346 is missing) — confirm it is incremented in the loop.
344 while (allSymbolNames.count(SymbolRefAttr::get(ctx, uniqueName))) {
345 uniqueName = oldName.str() +
"_" + std::to_string(uniquingCounter);
// Reserve the name and cache the old -> new mapping.
348 newSymbolRef = SymbolRefAttr::get(ctx, uniqueName);
349 allSymbolNames.insert(newSymbolRef);
350 previouslyInlinedSymbolMap[oldSymbolRef] = newSymbolRef;
// Resolve the definition from `lookupFrom`, then with a lookup in the
// enclosing module (the condition guarding the second lookup is not in this
// listing; presumably it is a fallback when the first lookup fails).
356 Operation *symbolDefOp =
357 SymbolTable::lookupNearestSymbolFrom(lookupFrom, oldSymbolRef);
359 if (ModuleOp moduleOp = lookupFrom->getParentOfType<ModuleOp>()) {
360 symbolDefOp = SymbolTable::lookupSymbolIn(moduleOp, oldSymbolRef);
// Interrupting aborts the attribute replacement (guard not in this listing).
364 return std::make_pair(newSymbolRef, WalkResult::interrupt());
// References to devices are kept unchanged; devices are not cloned here.
369 if (llvm::isa<AIE::DeviceOp>(symbolDefOp)) {
370 return std::make_pair(oldSymbolRef, WalkResult::advance());
// Make the SSA values used by the definition available before cloning it.
374 llvm::SetVector<Value> symbolReferencedValues;
375 collectReferencedSSAValues(symbolDefOp, argMap, symbolReferencedValues);
380 if (failed(copyReferencedSSAValues(rewriter, symbolReferencedValues,
381 callerDevice, argMap,
382 clonedDefOpsInsertionPoint, op))) {
383 return std::make_pair(newSymbolRef, WalkResult::interrupt());
// Clone the definition, give the clone the unique name, and advance the
// shared insertion point past it.
388 rewriter.restoreInsertionPoint(clonedDefOpsInsertionPoint);
389 Operation *clonedSymbolDefOp = rewriter.clone(*symbolDefOp, argMap);
390 clonedSymbolDefOp->setAttr(SymbolTable::getSymbolAttrName(),
391 StringAttr::get(ctx, uniqueName));
392 clonedDefOpsInsertionPoint = rewriter.saveInsertionPoint();
// Symbol seen before: reuse the cached reference.
394 newSymbolRef = previouslyInlinedSymbolMap[oldSymbolRef];
397 return std::make_pair(newSymbolRef, WalkResult::advance());
// Store the rewritten attribute back on the operation.
402 op->setAttr(namedAttr.getName(), newAttr);
// Rewrite pattern rooted at RunOp with benefit 1 (presumably
// InlineRuntimeCallsPattern; the class header, the constructor signature and
// the remaining member initialisers are not part of this listing).
417 :
RewritePattern(RunOp::getOperationName(), PatternBenefit(1), ctx),
// matchAndRewrite (start of signature not in this listing): replaces a RunOp
// by clones of the operations in the callee runtime sequence's body, with the
// callee's block arguments mapped to the RunOp's arguments.
423 PatternRewriter &rewriter)
const override {
// Symbol cache local to this rewrite; passed to the symbol-inlining helper.
424 llvm::DenseMap<SymbolRefAttr, SymbolRefAttr> previouslyInlinedSymbolMap;
426 RunOp runOp = llvm::dyn_cast<RunOp>(op);
// Both the callee device and the callee runtime sequence must resolve
// (branch body not in this listing).
431 AIE::DeviceOp calleeDevice = runOp.getCalleeDeviceOp();
432 AIE::RuntimeSequenceOp calleeRuntimeSequence =
433 runOp.getCalleeRuntimeSequenceOp();
434 if (!calleeDevice || !calleeRuntimeSequence) {
442 Region &calleeBody = calleeRuntimeSequence.getBody();
// The RunOp must be nested in a device; otherwise an error is emitted.
443 AIE::DeviceOp callerDevice =
444 runOp.getOperation()->getParentOfType<AIE::DeviceOp>();
446 runOp.emitError() <<
"needs to be in a DeviceOp";
// Map each callee block argument to the RunOp argument at the same index.
// NOTE(review): values[i] assumes the RunOp has at least as many arguments as
// the callee body; the declaration of argMap is not in this listing.
452 ValueRange values = runOp.getArgs();
453 for (
unsigned i = 0, n = calleeBody.getNumArguments(); i < n; i++) {
454 BlockArgument arg = calleeBody.getArgument(i);
455 Value val = values[i];
456 argMap.map(arg, val);
// Gather the values the callee body references, then keep only those whose
// parent region is not the callee body itself.
465 llvm::SetVector<Value> referencedValues;
466 for (Operation &op : calleeBody.getOps()) {
467 collectReferencedSSAValues(&op, argMap, referencedValues);
469 llvm::SetVector<Value> filteredValues;
470 for (Value val : referencedValues) {
471 if (val.getParentRegion() != &calleeBody) {
472 filteredValues.insert(val);
475 referencedValues = std::move(filteredValues);
// Make the defining operations of those values available in the caller
// device (remaining call arguments are not in this listing).
478 if (failed(copyReferencedSSAValues(rewriter, referencedValues, callerDevice,
// Clone the callee's operations immediately before the RunOp, preserving
// their order by saving the insertion point after each clone.
486 rewriter.setInsertionPoint(runOp);
487 mlir::OpBuilder::InsertPoint clonedOpInsertionPoint =
488 rewriter.saveInsertionPoint();
489 for (Operation &op : calleeBody.getOps()) {
// Restore first: the symbol-inlining helper below moves the rewriter's
// insertion point when it clones definitions.
490 rewriter.restoreInsertionPoint(clonedOpInsertionPoint);
491 Operation *clonedOp = rewriter.clone(op, argMap);
492 clonedOpInsertionPoint = rewriter.saveInsertionPoint();
// Pull in the definitions of symbols referenced by the clone, resolved
// relative to the callee runtime sequence.
494 if (failed(inlineReferencedSymbolDefinitions(
495 rewriter, clonedOp, calleeRuntimeSequence.getOperation(), argMap,
// The call has been inlined; remove it.
503 rewriter.eraseOp(runOp);
// Verifies every RunOp in the first block of `configureOp`'s body against
// `referencedDev`: the named runtime sequence must exist in that device, must
// be a runtime sequence, and must match the RunOp's arguments in count and
// types.
//
// Parameters:
//   configureOp   - configure op whose RunOps are checked.
//   referencedDev - device in which the runtime-sequence symbols are looked up.
//
// Returns an emitted diagnostic (a failure) on the first violation.
// NOTE(review): the success returns and some guard conditions are on lines
// absent from this listing.
515static LogicalResult verifyRunOpsInConfigureOp(ConfigureOp configureOp,
516 AIE::DeviceOp referencedDev) {
// An empty body is checked first (its branch body is not in this listing).
517 if (configureOp.getBody().empty())
519 for (RunOp runOp : configureOp.getBody().front().getOps<RunOp>()) {
520 auto seqName = runOp.getRuntimeSequenceSymbol();
521 Operation *maybeSeq = SymbolTable::lookupSymbolIn(referencedDev, seqName);
// Error with a note attached at the device location (the guarding condition
// is not in this listing; judging by the message it fires when the lookup
// found nothing).
523 auto err = runOp.emitError()
524 <<
"No such runtime sequence for device '"
525 << referencedDev.getSymName() <<
"': '" << seqName <<
"'";
526 err.attachNote(referencedDev.getLoc())
527 <<
"This device does not have a '" << seqName <<
"' runtime sequence";
// The symbol must name a runtime sequence.
530 auto runtimeSeq = llvm::dyn_cast<AIE::RuntimeSequenceOp>(maybeSeq);
532 return runOp.emitError()
533 <<
"'" << seqName <<
"' is not a runtime sequence";
// Argument count check; a callee with an empty region counts as taking zero
// arguments.
538 Region &calleeRegion = runtimeSeq.getBody();
539 unsigned numCalleeArgs =
540 calleeRegion.empty() ? 0 : calleeRegion.getNumArguments();
541 ValueRange
args = runOp.getArgs();
542 if (numCalleeArgs !=
args.size())
543 return runOp.emitOpError() <<
"argument count mismatch: expected "
544 << numCalleeArgs <<
" but got " <<
args.size();
// Position-by-position type check; the loop does not run for zero arguments,
// so front() is only reached when the callee region is non-empty.
545 for (
unsigned i = 0; i < numCalleeArgs; i++) {
546 Type expected = calleeRegion.front().getArgument(i).getType();
547 Type actual =
args[i].getType();
548 if (expected != actual)
549 return runOp.emitOpError()
550 <<
"argument " << i <<
" type mismatch: expected " << expected
551 <<
" but got " << actual;
// Pass definition (start of the struct header and the runOnOperation
// signature are not part of this listing). For every device in the module the
// visible code verifies the runtime sequences, checks the call graph for
// cycles, applies the rewrite patterns, and finally hoists the contents of
// each ConfigureOp in front of it before erasing it.
559 AIEMaterializeRuntimeSequencesPass> {
561 ModuleOp moduleOp = getOperation();
564 for (AIE::DeviceOp deviceOp : moduleOp.getOps<AIE::DeviceOp>()) {
// Step 1: per-sequence verification.
567 for (AIE::RuntimeSequenceOp runtimeSequenceOp :
568 deviceOp.getOps<AIE::RuntimeSequenceOp>()) {
569 if (failed(runtimeSequenceOp.verifyBeforeMaterialization())) {
570 return signalPassFailure();
// Each ConfigureOp directly in the sequence must reference a device, and its
// RunOps must match the referenced device's runtime sequences.
578 for (ConfigureOp configureOp :
579 runtimeSequenceOp.getOps<ConfigureOp>()) {
580 AIE::DeviceOp referencedDev = configureOp.getReferencedDeviceOp();
581 if (!referencedDev) {
585 return signalPassFailure();
587 if (failed(verifyRunOpsInConfigureOp(configureOp, referencedDev)))
588 return signalPassFailure();
// Step 2: fail when the call graph of a runtime sequence contains a cycle.
// The analysis is obtained through the analysis manager nested on the device
// and then the sequence (the query itself is not in this listing).
593 for (AIE::RuntimeSequenceOp runtimeSequenceOp :
594 deviceOp.getOps<AIE::RuntimeSequenceOp>()) {
596 getAnalysisManager().nest(deviceOp).nest(runtimeSequenceOp);
600 return signalPassFailure();
603 runtimeSequenceOp.emitError(
604 "Runtime sequence call graph contains a cycle");
605 return signalPassFailure();
// Step 3: set up the insertion points shared by the inlining pattern. Cloned
// SSA definitions go at the start of the device body; cloned symbol
// definitions go right before the first runtime sequence.
613 mlir::Block &deviceBodyFirstBlock = deviceOp.getBodyRegion().front();
614 auto runtimeSequenceOps = deviceOp.getOps<AIE::RuntimeSequenceOp>();
// Devices without runtime sequences are detected here (branch body not in
// this listing).
615 if (runtimeSequenceOps.begin() == runtimeSequenceOps.end()) {
619 AIE::RuntimeSequenceOp firstRuntimeSequenceOp =
620 *runtimeSequenceOps.begin();
621 mlir::OpBuilder::InsertPoint ssaDefInsertPoint(
622 &deviceBodyFirstBlock, deviceBodyFirstBlock.begin());
623 mlir::OpBuilder::InsertPoint symbolDefInsertPoint(
624 &deviceBodyFirstBlock, mlir::Block::iterator(firstRuntimeSequenceOp));
// Seed the set of taken symbol names with every symbol already defined in the
// device body.
625 llvm::SetVector<SymbolRefAttr> allSymbolNames = {};
626 for (Operation &op : deviceBodyFirstBlock) {
627 if (
auto symbolName = op.getAttrOfType<StringAttr>(
628 SymbolTable::getSymbolAttrName())) {
629 allSymbolNames.insert(SymbolRefAttr::get(symbolName));
// Region simplification is disabled for the greedy drivers used below.
633 MLIRContext *ctx = &getContext();
634 GreedyRewriteConfig rewriter_config = GreedyRewriteConfig();
635 rewriter_config.setRegionSimplificationLevel(
636 GreedySimplifyRegionLevel::Disabled);
// Step 4: first pattern set, applied greedily. The pattern type is on a line
// missing from this listing; its constructor receives the insertion points
// and the name set prepared above.
638 RewritePatternSet patterns_0(ctx);
640 ctx, ssaDefInsertPoint, symbolDefInsertPoint, allSymbolNames);
641 if (failed(applyPatternsGreedily(deviceOp, std::move(patterns_0),
643 return signalPassFailure();
// Step 5: second pattern set, applied with the walk driver (the pattern added
// to it is on a line missing from this listing).
647 RewritePatternSet patterns_1(ctx);
649 walkAndApplyPatterns(deviceOp, std::move(patterns_1));
// Step 6: run the canonicalization patterns of NpuLoadPdiOp.
652 RewritePatternSet canonicalize_patterns(ctx);
653 AIEX::NpuLoadPdiOp::getCanonicalizationPatterns(canonicalize_patterns,
655 if (failed(applyPatternsGreedily(
656 deviceOp, std::move(canonicalize_patterns), rewriter_config))) {
657 return signalPassFailure();
// Step 7: flatten ConfigureOps. The ops are collected first so that erasing
// them does not disturb the iteration over the sequence body.
662 for (AIE::RuntimeSequenceOp runtimeSequenceOp :
663 deviceOp.getOps<AIE::RuntimeSequenceOp>()) {
664 SmallVector<ConfigureOp> configureOps;
666 for (ConfigureOp configureOp :
667 runtimeSequenceOp.getOps<ConfigureOp>()) {
668 configureOps.push_back(configureOp);
671 IRRewriter rewriter(ctx);
672 for (ConfigureOp configureOp : configureOps) {
// NOTE(review): front() requires a non-empty body; confirm that every
// ConfigureOp reaching this point has at least one block.
673 Block &configureBlock = configureOp.getBody().front();
// Snapshot the block's operations before moving them, since moving while
// iterating the block would invalidate the iteration.
676 SmallVector<Operation *> opsToHoist;
677 for (Operation &op : configureBlock) {
678 opsToHoist.push_back(&op);
// Move each operation, in order, to just before the ConfigureOp, then erase
// the ConfigureOp.
682 rewriter.setInsertionPoint(configureOp);
683 for (Operation *op : opsToHoist) {
684 op->moveBefore(configureOp);
688 rewriter.eraseOp(configureOp);
// Factory returning a new instance of AIEMaterializeRuntimeSequencesPass as a
// module-level pass. NOTE(review): the line carrying the function name is
// absent from this listing; presumably createAIEMaterializeRuntimeSequencesPass.
696std::unique_ptr<OperationPass<ModuleOp>>
698 return std::make_unique<AIEMaterializeRuntimeSequencesPass>();
std::optional< SubviewTraceResult > traceSubviewToBlockArgument(Value value)
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createAIEMaterializeRuntimeSequencesPass()
void runOnOperation() override
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
mlir::OpBuilder::InsertPoint & ssaDefInsertPoint
InlineRuntimeCallsPattern(MLIRContext *ctx, mlir::OpBuilder::InsertPoint &ssaDefInsertPoint, mlir::OpBuilder::InsertPoint &symbolDefInsertPoint, llvm::SetVector< SymbolRefAttr > &allSymbolNames)
mlir::OpBuilder::InsertPoint & symbolDefInsertPoint
llvm::SetVector< SymbolRefAttr > & allSymbolNames
RuntimeCallGraphCyclicityAnalysis(Operation *op, AnalysisManager &am)
AnalysisManager & analysisManager