15#include "mlir/IR/IRMapping.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24#define GEN_PASS_DEF_AIENPUTOCERT
25#define GEN_PASS_DEF_AIECERTPAGES
26#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
32#define DEBUG_TYPE "npu-to-cert"
// Size budget shared by the chain-merging, block-write-splitting and
// job-splitting logic below. It is compared against sums of
// element-count * sizeof(int) and element-count * 4, so it is presumably
// measured in bytes — TODO confirm against the CERT page format.
38static constexpr uint32_t cert_page_size = 8000;
// Conversion pattern: replaces an AIE::RuntimeSequenceOp with an
// AIEX::CertJobOp whose body is a clone of the sequence's region, then makes
// sure the job body is terminated. (The struct header and some body lines are
// not present in this listing.)
41 using OpConversionPattern::OpConversionPattern;
43 RuntimeSequenceToCertJob(MLIRContext *context, PatternBenefit benefit = 1)
47 matchAndRewrite(AIE::RuntimeSequenceOp op, OpAdaptor adaptor,
48 ConversionPatternRewriter &rewriter)
const override {
50 auto symName = op.getSymName();
// The sequence named "configure" always becomes job 1. Any other sequence
// gets one more than the largest CertJobOp id already present under the
// parent op (with 1 as the floor), so ids depend on conversion order.
51 uint32_t newJobId = 1;
52 if (symName !=
"configure") {
53 uint32_t maxJobId = 1;
54 op->getParentOp()->walk([&](AIEX::CertJobOp certJobOp) {
55 maxJobId = std::max(maxJobId, certJobOp.getJobId());
57 newJobId = maxJobId + 1;
59 auto jobOp = rewriter.replaceOpWithNewOp<AIEX::CertJobOp>(
60 op, op->getResultTypes(), newJobId);
// NOTE(review): `remap` is declared on a line not shown here — presumably an
// IRMapping, given the IRMapping.h include; confirm.
62 op.getRegion().cloneInto(&jobOp.getBody(), remap);
63 AIEX::CertJobOp::ensureTerminator(jobOp.getBody(), rewriter, op->getLoc());
// Conversion pattern: rewrites an AIEX::NpuWrite32Op into an
// AIEX::CertWrite32Op, forwarding the write address. The remaining
// constructor arguments are on lines not shown in this listing.
70 using OpConversionPattern::OpConversionPattern;
73 matchAndRewrite(AIEX::NpuWrite32Op op, OpAdaptor adaptor,
74 ConversionPatternRewriter &rewriter)
const override {
75 rewriter.replaceOpWithNewOp<AIEX::CertWrite32Op>(op, op.getAddress(),
// Conversion pattern: rewrites an AIEX::NpuMaskWrite32Op into an
// AIEX::CertMaskWrite32Op with the same address, mask and value.
81struct NpuMaskWrite32ToCertMaskWrite32
83 using OpConversionPattern::OpConversionPattern;
86 matchAndRewrite(AIEX::NpuMaskWrite32Op op, OpAdaptor adaptor,
87 ConversionPatternRewriter &rewriter)
const override {
88 rewriter.replaceOpWithNewOp<AIEX::CertMaskWrite32Op>(
89 op, op.getAddress(), op.getMask(), op.getValue());
// Conversion pattern (registered as NpuBlockWriteToCertUcDma by
// AIENpuToCertPass): replaces an AIEX::NpuBlockWriteOp whose data comes from a
// memref.get_global with an AIEX::CertUcDmaWriteDesSyncOp. That op names a new
// AIEX::CertUcDmaChainOp, which is created just before the enclosing
// CertJobOp and holds a single CertUcDmaBdOp describing the whole payload.
95 using OpConversionPattern::OpConversionPattern;
98 matchAndRewrite(AIEX::NpuBlockWriteOp op, OpAdaptor adaptor,
99 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): lines 103-104 are not shown; presumably they reject a null
// dataOperand before the cast below dereferences it.
101 memref::GetGlobalOp dataOperand =
102 dyn_cast_or_null<memref::GetGlobalOp>(op.getData().getDefiningOp());
105 MemRefType dataType = cast<MemRefType>(dataOperand.getResult().getType());
// BD length is recorded in elements, not bytes (the merge pattern later
// multiplies BD lengths by sizeof(int)).
106 uint32_t dataSize = dataType.getNumElements();
// Pick a "chain_<N>" symbol that is not yet defined in the parent DeviceOp.
// `id` is declared on a line not shown here — presumably a running counter.
109 std::string symbolName =
"chain_" + std::to_string(
id);
110 while (op->getParentOfType<AIE::DeviceOp>().lookupSymbol(symbolName))
111 symbolName =
"chain_" + std::to_string(++
id);
114 rewriter.replaceOpWithNewOp<AIEX::CertUcDmaWriteDesSyncOp>(op, symbolName);
// The chain op is emitted at the DeviceOp level, immediately before the
// CertJobOp that contains the original block write.
117 rewriter.setInsertionPoint(op->getParentOfType<AIEX::CertJobOp>());
118 auto symbolAttr = rewriter.getStringAttr(symbolName);
120 AIEX::CertUcDmaChainOp::create(rewriter, op.getLoc(), symbolAttr);
// The new Block is handed to the chain's region via push_back.
// (`chainOp` is bound on a line not shown.)
122 Block *bb =
new Block();
123 chainOp.getRegion().push_back(bb);
124 rewriter.setInsertionPointToStart(bb);
// The trailing bool is `false` here; the merge pattern re-creates prepended
// BDs with `true` — presumably a "more BDs follow" flag; confirm against the
// CertUcDmaBdOp definition.
125 AIEX::CertUcDmaBdOp::create(rewriter, op.getLoc(), dataOperand.getName(),
126 op.getAddress(), dataSize,
false);
128 AIEX::CertUcDmaChainOp::ensureTerminator(chainOp.getBody(), rewriter,
// Conversion pattern: lowers an AIEX::NpuSyncOp to an AIEX::CertWaitTCTSOp on
// the (tile_id, actor_id) pair derived from the op's column, row, DMA channel
// and direction.
135 using OpConversionPattern::OpConversionPattern;
138 matchAndRewrite(AIEX::NpuSyncOp op, OpAdaptor adaptor,
139 ConversionPatternRewriter &rewriter)
const override {
140 uint32_t row = op.getRow();
141 uint32_t col = op.getColumn();
// Pack the tile coordinates: row in the low bits, column shifted left by
// (col_id_shift - row_id_shift) = 5 bits.
// NOTE(review): a row value >= 32 would spill into the column bits; no range
// check on row/col is visible here.
145 const int row_id_shift = 16;
146 const int col_id_shift = 21;
147 uint16_t tile_id = col << (col_id_shift - row_id_shift) | row;
148 uint32_t channel = op.getChannel();
149 uint32_t direction = op.getDirection();
// Lookup tables from DMA channel index to actor id, one per tile kind and
// direction. (Part of chan2actor_mem_mm2s is on a line not shown.)
151 const std::vector<int> chan2actor_shim_s2mm = {0, 2};
152 const std::vector<int> chan2actor_shim_mm2s = {6, 7, 8, 9};
154 const std::vector<int> chan2actor_mem_s2mm = {1, 2, 3, 4, 5, 6, 7};
155 const std::vector<int> chan2actor_mem_mm2s = {16, 17, 18, 19, 20,
157 const std::vector<int> chan2actor_tile_s2mm = {0, 1};
158 const std::vector<int> chan2actor_tile_mm2s = {6};
// `isS2MM` is declared on a line not shown; it compares `direction` with the
// underlying value of AIE::DMAChannelDir::S2MM.
161 direction ==
static_cast<std::underlying_type_t<AIE::DMAChannelDir>
>(
162 AIE::DMAChannelDir::S2MM);
// Select the table by tile kind. `tm` is not declared in the visible lines —
// presumably the device's AIETargetModel. Core and mem tiles are tested
// explicitly; the last assignment (its `else` is on a line not shown) uses
// the shim tables.
164 const std::vector<int> *chan2actor =
nullptr;
165 if (tm.isCoreTile(col, row))
166 chan2actor = isS2MM ? &chan2actor_tile_s2mm : &chan2actor_tile_mm2s;
167 else if (tm.isMemTile(col, row))
168 chan2actor = isS2MM ? &chan2actor_mem_s2mm : &chan2actor_mem_mm2s;
170 chan2actor = isS2MM ? &chan2actor_shim_s2mm : &chan2actor_shim_mm2s;
// Reject channels the selected table has no entry for.
172 size_t chanIdx =
static_cast<size_t>(channel);
173 if (!chan2actor || chanIdx >= chan2actor->size()) {
174 op.emitError(
"invalid DMA channel ")
175 << channel <<
" for " << (isS2MM ?
"S2MM" :
"MM2S")
176 <<
" direction in NpuSyncToCertWaitTCTS conversion";
180 uint8_t actor_id =
static_cast<uint8_t
>((*chan2actor)[chanIdx]);
// num_tcts is fixed at 1 (presumably one task-completion token per sync —
// confirm).
181 uint8_t num_tcts = 1;
182 rewriter.replaceOpWithNewOp<AIEX::CertWaitTCTSOp>(op, tile_id, actor_id,
// Conversion pattern: turns an AIEX::NpuAddressPatchOp into an
// AIEX::CertApplyOffset57Op that names the memref.global of an earlier
// AIEX::NpuBlockWriteOp in the same block. The backward scan loop is only
// partly shown in this listing.
188struct NpuAddressPatchToCertApplyOffset57
190 using OpConversionPattern::OpConversionPattern;
193 matchAndRewrite(AIEX::NpuAddressPatchOp op, OpAdaptor adaptor,
194 ConversionPatternRewriter &rewriter)
const override {
196 Block::iterator it(op);
197 while (it != op->getBlock()->begin()) {
199 auto blockWriteOp = dyn_cast<AIEX::NpuBlockWriteOp>(*it);
// Decode the tile from the patched address using the target model's shifts;
// both fields are masked to 5 bits. `tm` is declared on a line not shown.
204 uint32_t addr = op.getAddr();
205 int col = (addr >> tm.getColumnShift()) & 0x1f;
206 int row = (addr >> tm.getRowShift()) & 0x1f;
207 if (!tm.isValidTile({col, row}))
// The block write's address plus the tile's DMA BD address offset is
// compared with a value on a line not shown (presumably the patch address).
211 if (blockWriteOp.getAddress() + tm.getDmaBdAddressOffset(col, row) !=
215 Value data = blockWriteOp.getData();
// NOTE(review): `data.getDefiningOp()` is null when `data` is a block
// argument, and `dyn_cast` asserts on a null pointer. The block-write
// patterns in this file use `dyn_cast_or_null` for the same lookup; consider
// `dyn_cast_if_present` here as well.
216 auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(data.getDefiningOp());
// NOTE(review): recent MLIR's replaceOpWithNewOp resets the insertion point
// to `op` before creating the replacement, which would make this call
// ineffective — confirm where the new op is meant to land. The meaning of the
// literal `1` argument is not visible here — TODO document.
221 rewriter.setInsertionPoint(blockWriteOp);
222 rewriter.replaceOpWithNewOp<AIEX::CertApplyOffset57Op>(
223 op, getGlobalOp.getName(), 1, op.getArgIdx());
// Greedy rewrite pattern: folds the DMA chain of an earlier
// AIEX::CertUcDmaWriteDesSyncOp into the chain referenced by `op`, so a
// single descriptor write covers both. It applies only while the combined
// chain stays below cert_page_size bytes.
231struct MergeConsecutiveCertUcDmaWriteDesSyncOps
233 using OpRewritePattern::OpRewritePattern;
235 LogicalResult matchAndRewrite(AIEX::CertUcDmaWriteDesSyncOp op,
236 PatternRewriter &rewriter)
const override {
// Walk backwards from `op`. CertWrite32, CertMaskWrite32, CertApplyOffset57
// and CertWaitTCTS ops are stepped over. The iterator decrement and the
// handling of other op kinds are on lines not shown.
// NOTE(review): merging moves the earlier chain's BDs past any stepped-over
// CertWaitTCTS ops; this assumes those waits do not depend on the earlier
// descriptors — TODO confirm.
238 Block::iterator it(op);
239 AIEX::CertUcDmaWriteDesSyncOp prevWriteDesSync =
nullptr;
240 while (it != op->getBlock()->begin() && !prevWriteDesSync) {
242 Operation *prevOp = &*it;
243 if (isa<AIEX::CertWrite32Op, AIEX::CertMaskWrite32Op,
244 AIEX::CertApplyOffset57Op, AIEX::CertWaitTCTSOp>(prevOp))
246 prevWriteDesSync = dyn_cast<AIEX::CertUcDmaWriteDesSyncOp>(prevOp);
248 if (!prevWriteDesSync)
// Both chains are looked up by symbol in the enclosing DeviceOp.
252 StringRef sym_name = op.getSymbol();
253 StringRef prev_sym_name = prevWriteDesSync.getSymbol();
254 auto chain = dyn_cast_if_present<AIEX::CertUcDmaChainOp>(
255 op->getParentOfType<AIE::DeviceOp>().lookupSymbol(sym_name));
256 auto prevChain = dyn_cast_if_present<AIEX::CertUcDmaChainOp>(
257 prevWriteDesSync->getParentOfType<AIE::DeviceOp>().lookupSymbol(
259 if (!chain || !prevChain)
// Chain sizes in bytes: BD length (in elements) times sizeof(int). Handling
// of non-BD ops in these loops is on lines not shown.
264 uint32_t prevChainSize = 0;
265 for (
auto &o : prevChain.getBody().front().getOperations()) {
266 auto bdOp = dyn_cast<AIEX::CertUcDmaBdOp>(o);
269 prevChainSize += bdOp.getLength() *
sizeof(int);
271 uint32_t currChainSize = 0;
272 for (
auto &o : chain.getBody().front().getOperations()) {
273 auto bdOp = dyn_cast<AIEX::CertUcDmaBdOp>(o);
276 currChainSize += bdOp.getLength() *
sizeof(int);
278 if ((currChainSize + prevChainSize) >= cert_page_size)
// Prepend copies of the earlier chain's BDs, in their original order, to this
// chain. The copies are created with the trailing flag set to `true`
// (presumably "more BDs follow" — confirm). Then erase the earlier chain and
// its write-des-sync op.
282 rewriter.setInsertionPointToStart(&chain.getBody().front());
283 for (
auto &o : prevChain.getBody().front().getOperations()) {
284 auto bdOp = dyn_cast<AIEX::CertUcDmaBdOp>(o);
287 AIEX::CertUcDmaBdOp::create(
288 rewriter, bdOp.getLoc(), bdOp.getRemoteAddress(),
289 bdOp.getLocalAddress(), bdOp.getLength(),
true);
291 rewriter.eraseOp(prevChain);
292 rewriter.eraseOp(prevWriteDesSync);
// Greedy rewrite pattern: splits an AIEX::NpuBlockWriteOp whose payload is at
// least cert_page_size bytes into two block writes. Each new write reads one
// of two new private memref.globals holding the first and second halves of
// the data. Both passes apply it with the greedy driver, so the new writes are
// matched again and oversized writes keep halving until every piece is below
// the page size.
297struct SplitNpuBlockWriteOpPattern :
OpRewritePattern<AIEX::NpuBlockWriteOp> {
298 using OpRewritePattern::OpRewritePattern;
300 LogicalResult matchAndRewrite(AIEX::NpuBlockWriteOp op,
301 PatternRewriter &rewriter)
const override {
303 memref::GetGlobalOp dataOperand =
304 dyn_cast_or_null<memref::GetGlobalOp>(op.getData().getDefiningOp());
308 MemRefType dataType = cast<MemRefType>(dataOperand.getResult().getType());
309 uint32_t dataSize = dataType.getNumElements();
// Payload size in bytes assumes sizeof(int)-sized elements.
311 uint32_t dataSizeBytes = dataSize *
sizeof(int);
312 if (dataSizeBytes < cert_page_size)
315 auto loc = op.getLoc();
// Halve by element count; with an odd count the extra element goes into the
// second chunk.
318 uint32_t splitElements = dataSize / 2;
319 uint32_t firstChunkSize = splitElements;
320 uint32_t secondChunkSize = dataSize - splitElements;
323 auto deviceOp = op->getParentOfType<AIE::DeviceOp>();
324 auto originalGlobal = dyn_cast_if_present<memref::GlobalOp>(
325 deviceOp.lookupSymbol(dataOperand.getName()));
330 auto originalData = originalGlobal.getInitialValue();
334 auto denseData = dyn_cast<DenseIntElementsAttr>(*originalData);
339 auto dataValues = denseData.getValues<APInt>();
340 std::vector<APInt> firstChunkData(dataValues.begin(),
341 dataValues.begin() + firstChunkSize);
342 std::vector<APInt> secondChunkData(dataValues.begin() + firstChunkSize,
// NOTE(review): the split globals are always typed i32, whatever the original
// global's element type. If the source data has a different integer width,
// building DenseIntElementsAttrs from its APInts may not match — TODO confirm
// the input is always i32.
346 auto elementType = rewriter.getI32Type();
347 auto firstChunkType = MemRefType::get({firstChunkSize}, elementType);
348 auto secondChunkType = MemRefType::get({secondChunkSize}, elementType);
349 TensorType firstTensorType =
350 RankedTensorType::get({firstChunkSize}, elementType);
351 TensorType secondTensorType =
352 RankedTensorType::get({secondChunkSize}, elementType);
354 auto firstChunkAttr =
355 DenseIntElementsAttr::get(firstTensorType, firstChunkData);
356 auto secondChunkAttr =
357 DenseIntElementsAttr::get(secondTensorType, secondChunkData);
// Derive "<name>_split_0/1" symbols and uniquify them within the DeviceOp.
// `counter` and the assignment targets are on lines not shown.
360 std::string firstName = dataOperand.getName().str() +
"_split_0";
361 std::string secondName = dataOperand.getName().str() +
"_split_1";
365 while (deviceOp.lookupSymbol(firstName)) {
367 dataOperand.getName().str() +
"_split_0_" + std::to_string(counter++);
370 while (deviceOp.lookupSymbol(secondName)) {
372 dataOperand.getName().str() +
"_split_1_" + std::to_string(counter++);
// New globals are placed right before the original one. The original global
// is not erased in the visible lines. The `true` argument is presumably the
// `constant` flag.
376 rewriter.setInsertionPoint(originalGlobal);
377 memref::GlobalOp::create(rewriter, loc, firstName,
378 rewriter.getStringAttr(
"private"), firstChunkType,
379 firstChunkAttr,
true,
nullptr);
381 memref::GlobalOp::create(rewriter, loc, secondName,
382 rewriter.getStringAttr(
"private"), secondChunkType,
383 secondChunkAttr,
true,
nullptr);
386 rewriter.setInsertionPoint(op);
388 auto firstGetGlobal =
389 memref::GetGlobalOp::create(rewriter, loc, firstChunkType, firstName);
390 auto secondGetGlobal =
391 memref::GetGlobalOp::create(rewriter, loc, secondChunkType, secondName);
// The second write starts where the first ends: firstChunkSize elements of
// 4 bytes each past the original base address.
393 uint32_t baseAddr = op.getAddress();
395 AIEX::NpuBlockWriteOp::create(rewriter, loc, baseAddr,
396 firstGetGlobal.getResult(),
nullptr,
nullptr,
399 AIEX::NpuBlockWriteOp::create(rewriter, loc, baseAddr + firstChunkSize * 4,
400 secondGetGlobal.getResult(),
nullptr,
nullptr,
404 rewriter.eraseOp(op);
// NOTE(review): LLVM_DEBUG output conventionally goes to llvm::dbgs(), not
// llvm::outs().
406 LLVM_DEBUG(llvm::outs()
407 <<
"Split NpuBlockWriteOp with data size: " << dataSizeBytes
408 <<
" bytes into chunks of " << firstChunkSize <<
" and "
409 << secondChunkSize <<
" elements\n");
// Pass that lowers NPU instruction ops inside AIE runtime sequences to CERT
// ops. It runs in phases that share one ConversionTarget and tightens legality
// between phases. The failure handling after each driver call, and the scope
// braces around the reused `p` names, are on lines not shown.
415struct AIENpuToCertPass
416 : xilinx::AIEX::impl::AIENpuToCertBase<AIENpuToCertPass> {
417 void runOnOperation()
override {
418 ConversionTarget target(getContext());
419 target.addIllegalOp<AIE::RuntimeSequenceOp>();
421 target.addLegalOp<AIEX::CertApplyOffset57Op>();
422 target.addLegalOp<AIEX::CertJobOp>();
423 target.addLegalOp<AIEX::CertMaskWrite32Op>();
424 target.addLegalOp<AIEX::CertUcDmaWriteDesSyncOp>();
425 target.addLegalOp<AIEX::CertUcDmaChainOp>();
426 target.addLegalOp<AIEX::CertUcDmaBdOp>();
427 target.addLegalOp<AIEX::CertWrite32Op>();
428 target.addLegalOp<AIEX::CertWaitTCTSOp>();
429 target.addLegalDialect<AIE::AIEDialect>();
// Phase 1: RuntimeSequenceOp -> CertJobOp. The address-patch pattern is also
// included, but NpuAddressPatchOp is not yet marked illegal. Under partial
// conversion, patch ops left unconverted here are therefore tolerated.
431 RewritePatternSet p0(&getContext());
432 p0.insert<RuntimeSequenceToCertJob>(&getContext());
433 p0.insert<NpuAddressPatchToCertApplyOffset57>(&getContext());
435 if (failed(applyPartialConversion(getOperation(), target, std::move(p0))))
// Phase 2: address patches are now illegal and must convert.
438 target.addIllegalOp<AIEX::NpuAddressPatchOp>();
441 RewritePatternSet p1(&getContext());
442 p1.insert<NpuAddressPatchToCertApplyOffset57>(&getContext());
444 if (failed(applyPartialConversion(getOperation(), target, std::move(p1))))
// Phase 3: greedily split block writes whose payload reaches cert_page_size.
449 RewritePatternSet p(&getContext());
450 p.insert<SplitNpuBlockWriteOpPattern>(&getContext());
451 if (failed(applyPatternsGreedily(getOperation(), std::move(p))))
// Phase 4: the remaining NPU block-write / mask-write / sync / write ops must
// convert to their CERT equivalents.
455 target.addIllegalOp<AIEX::NpuBlockWriteOp>();
456 target.addIllegalOp<AIEX::NpuMaskWrite32Op>();
457 target.addIllegalOp<AIEX::NpuSyncOp>();
458 target.addIllegalOp<AIEX::NpuWrite32Op>();
462 RewritePatternSet p(&getContext());
463 p.insert<NpuBlockWriteToCertUcDma>(&getContext());
464 p.insert<NpuMaskWrite32ToCertMaskWrite32>(&getContext());
465 p.insert<NpuWrite32ToCertWrite32>(&getContext());
466 p.insert<NpuSyncToCertWaitTCTS>(&getContext());
468 if (failed(applyPartialConversion(getOperation(), target, std::move(p))))
// Phase 5: greedily merge consecutive uC DMA write-des-sync ops.
474 RewritePatternSet p(&getContext());
475 p.insert<MergeConsecutiveCertUcDmaWriteDesSyncOps>(&getContext());
476 if (failed(applyPatternsGreedily(getOperation(), std::move(p))))
// Estimates the serialized size of a CertJobOp's body, in the same units as
// cert_page_size, and returns text_cost + data_cost.
//
// - text_cost starts at 32. Each recognised op kind adds to it; the
//   increments are on lines not shown in this listing.
// - For each CertUcDmaWriteDesSyncOp, data_cost adds 4 per element of the
//   initializer of every memref.global referenced by the BDs in its chain.
//   It counts the whole global, not the BD length.
//
// split_iter is set to the first op at which the running total, measured
// before that op's own cost is added, reaches split_target.
// NOTE(review): if split_target is only reached by the last op's own cost,
// split_iter is never assigned. If it is reached just before a trailing
// terminator, it points at that terminator. Callers should guard both cases —
// TODO confirm.
484static uint32_t estimateCost(AIEX::CertJobOp op, uint32_t split_target,
485 Block::iterator &split_iter) {
487 uint32_t text_cost = 32;
488 uint32_t data_cost = 0;
// Non-zero once split_iter has been recorded (the total is always >= 32).
489 uint32_t split_cost = 0;
490 for (
auto &o : op.getBody().front().getOperations()) {
491 if (!split_cost && (text_cost + data_cost) >= split_target) {
492 split_iter = Block::iterator(&o);
493 split_cost = text_cost + data_cost;
495 if (isa<AIEX::CertLocalBarrierOp>(o)) {
497 }
else if (isa<AIEX::CertRemoteBarrierOp>(o)) {
499 }
else if (isa<AIEX::CertWaitTCTSOp>(o)) {
501 }
else if (isa<AIEX::CertMaskWrite32Op>(o)) {
503 }
else if (isa<AIEX::CertWrite32Op>(o)) {
505 }
else if (isa<AIEX::CertApplyOffset57Op>(o)) {
507 }
else if (
auto syncOp = dyn_cast<AIEX::CertUcDmaWriteDesSyncOp>(o)) {
510 StringRef sym_name = syncOp.getSymbol();
511 auto chain = dyn_cast_if_present<AIEX::CertUcDmaChainOp>(
512 op->getParentOfType<AIE::DeviceOp>().lookupSymbol(sym_name));
515 for (
auto bdOp : chain.getBody().front().getOps<AIEX::CertUcDmaBdOp>()) {
517 StringRef data_sym_name = bdOp.getRemoteAddress();
518 auto global = dyn_cast_if_present<memref::GlobalOp>(
519 op->getParentOfType<AIE::DeviceOp>().lookupSymbol(data_sym_name));
522 auto initVal = global.getInitialValue();
525 auto data = dyn_cast<DenseIntElementsAttr>(*initVal);
528 data_cost += data.getNumElements() * 4;
532 return text_cost + data_cost;
// Greedy rewrite pattern body (its struct header is not shown; presumably
// this is SplitCertJobOpPattern, registered by AIECertPagesPass). It splits a
// CertJobOp whose estimated cost reaches cert_page_size into two consecutive
// jobs at the point where the running cost first reaches half a page.
537 using OpRewritePattern::OpRewritePattern;
539 LogicalResult matchAndRewrite(AIEX::CertJobOp op,
540 PatternRewriter &rewriter)
const override {
542 constexpr uint32_t split_threshold = cert_page_size;
// NOTE(review): estimateCost may leave split_iter unassigned or point it at
// the terminator. In the latter case the first new job keeps essentially all
// of the work and could be split again indefinitely — TODO guard.
544 Block::iterator split_iter;
545 uint32_t cost = estimateCost(op, cert_page_size / 2, split_iter);
// NOTE(review): LLVM_DEBUG output conventionally goes to llvm::dbgs().
546 LLVM_DEBUG(llvm::outs() <<
"Estimate cost for job: " << op.getJobId()
547 <<
" is " << cost <<
"\n");
549 if (cost < split_threshold)
552 auto loc = op.getLoc();
// Shift every later job id up by one to make room for jobId + 1.
// NOTE(review): these setJobId calls mutate ops in place without
// rewriter.modifyOpInPlace, so the greedy driver is not notified of them.
553 op->getParentOfType<AIE::DeviceOp>().walk([&](AIEX::CertJobOp certJobOp) {
554 if (certJobOp.getJobId() > op.getJobId())
555 certJobOp.setJobId(certJobOp.getJobId() + 1);
559 auto jobId = op.getJobId();
560 auto newJobOp0 = AIEX::CertJobOp::create(rewriter, loc, jobId);
561 auto newJobOp1 = AIEX::CertJobOp::create(rewriter, loc, jobId + 1);
// First job: ops from the start of the body up to (not including) split_iter;
// the loop body is on lines not shown.
563 newJobOp0.getBody().push_back(
new Block());
564 rewriter.setInsertionPointToStart(&newJobOp0.getBody().front());
565 for (Block::iterator oi = op.getBody().front().getOperations().begin();
566 oi != split_iter; ++oi) {
569 AIEX::CertJobOp::ensureTerminator(newJobOp0.getBody(), rewriter, loc);
// Second job: ops from split_iter to the end of the body.
571 newJobOp1.getBody().push_back(
new Block());
572 rewriter.setInsertionPointToStart(&newJobOp1.getBody().front());
573 for (Block::iterator oi = split_iter;
574 oi != op.getBody().front().getOperations().end(); ++oi) {
578 rewriter.eraseOp(op);
// Pass that enforces the CERT page budget. It first splits oversized
// NpuBlockWriteOps, then splits CertJobOps whose estimated cost reaches
// cert_page_size, each step using the greedy driver.
// NOTE(review): AIENpuToCertPass already runs the same block-write split and
// then converts every NpuBlockWriteOp. If this pass runs after it, the first
// phase presumably finds nothing to match — confirm the intended pipeline
// order.
583struct AIECertPagesPass
584 : xilinx::AIEX::impl::AIECertPagesBase<AIECertPagesPass> {
585 void runOnOperation()
override {
587 RewritePatternSet p0(&getContext());
588 p0.insert<SplitNpuBlockWriteOpPattern>(&getContext());
589 if (failed(applyPatternsGreedily(getOperation(), std::move(p0))))
593 RewritePatternSet p1(&getContext());
594 p1.insert<SplitCertJobOpPattern>(&getContext());
595 if (failed(applyPatternsGreedily(getOperation(), std::move(p1))))
// Bodies of the two pass factory functions; their signatures appear in the
// extracted reference lines that follow.
603 return std::make_unique<AIENpuToCertPass>();
607 return std::make_unique<AIECertPagesPass>();
// Extracted cross-reference signatures: the two pass factories above, and
// getTargetModel, which is presumably how the patterns obtain `tm`.
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIENpuToCertPass()
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIECertPagesPass()
const AIETargetModel & getTargetModel(mlir::Operation *op)