16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/DialectConversion.h"
// Fragment of the Write32SymToAddr conversion pattern for NpuWrite32Op.
// NOTE(review): this listing carries embedded source line numbers and skips
// some lines, so the struct header, some guard conditions and the closing
// braces are not visible here.
28 using OpConversionPattern::OpConversionPattern;
30 Write32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
// Replaces an NpuWrite32Op that references a buffer by symbol with a new
// NpuWrite32Op built from an address computed from that buffer.
34 matchAndRewrite(NpuWrite32Op op, OpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter)
const override {
// Resolve the buffer symbol inside the enclosing AIE device.
40 auto device = op->getParentOfType<AIE::DeviceOp>();
41 auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
// NOTE(review): the condition guarding this diagnostic is not visible in
// this listing (presumably a null check on `buffer` - confirm).
43 return op->emitError(
"buffer '" + *op.getBuffer() +
44 "' not found in device");
// The buffer must already have an address; otherwise report an error.
46 if (!buffer.getAddress())
47 return op->emitError(
"buffer must have address assigned");
// Buffer address plus op.getAddress() multiplied by sizeof(uint32_t).
50 uint32_t address =
static_cast<uint32_t
>(*buffer.getAddress()) +
51 op.getAddress() *
sizeof(uint32_t);
// Coordinates of the tile that owns the buffer.
52 auto col = buffer.getTileOp().getCol();
53 auto row = buffer.getTileOp().getRow();
// Tail of an expression: the low 8 bits of the row shifted by
// tm.getRowShift(), OR-ed with the low 20 bits of the address. The start of
// the statement (and the declaration of `tm`) is not visible here.
55 ((row & 0xff) << tm.
getRowShift()) | (address & 0xFFFFF);
// Swap the op for a new NpuWrite32Op carrying the computed address and the
// original value; the three trailing arguments are passed as nullptr.
57 rewriter.replaceOpWithNewOp<NpuWrite32Op>(op, address, op.getValue(),
58 nullptr,
nullptr,
nullptr);
// Fragment of the BlockWriteSymToAddr conversion pattern for NpuBlockWriteOp.
// NOTE(review): this listing carries embedded source line numbers and skips
// some lines, so the struct header, some guard conditions and the closing
// braces are not visible here.
64 using OpConversionPattern::OpConversionPattern;
66 BlockWriteSymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
// Replaces an NpuBlockWriteOp that references a buffer by symbol with a new
// NpuBlockWriteOp built from an address computed from that buffer.
70 matchAndRewrite(NpuBlockWriteOp op, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter)
const override {
// Resolve the buffer symbol inside the enclosing AIE device.
76 auto device = op->getParentOfType<AIE::DeviceOp>();
78 auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
// NOTE(review): the condition guarding this diagnostic is not visible in
// this listing (presumably a null check on `buffer` - confirm).
80 return op->emitError(
"buffer '" + *op.getBuffer() +
81 "' not found in device");
// The buffer must already have an address; otherwise report an error.
83 if (!buffer.getAddress())
84 return op->emitError(
"buffer must have address assigned");
// Buffer address plus op.getAddress() multiplied by sizeof(uint32_t).
87 uint32_t address =
static_cast<uint32_t
>(*buffer.getAddress()) +
88 op.getAddress() *
sizeof(uint32_t);
// Coordinates of the tile that owns the buffer.
89 auto col = buffer.getTileOp().getCol();
90 auto row = buffer.getTileOp().getRow();
// Tail of an expression: the low 8 bits of the row shifted by
// tm.getRowShift(), OR-ed with the low 20 bits of the address. The start of
// the statement (and the declaration of `tm`) is not visible here.
92 ((row & 0xff) << tm.
getRowShift()) | (address & 0xFFFFF);
// Swap the op for a new NpuBlockWriteOp carrying the computed address and the
// original data operand; the three trailing arguments are passed as nullptr.
94 rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(op, address, op.getData(),
95 nullptr,
nullptr,
nullptr);
// Fragment of the MaskWrite32SymToAddr conversion pattern for
// NpuMaskWrite32Op.
// NOTE(review): this listing carries embedded source line numbers and skips
// some lines, so the struct header, some guard conditions and the closing
// braces are not visible here.
101 using OpConversionPattern::OpConversionPattern;
103 MaskWrite32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
// Replaces an NpuMaskWrite32Op that references a buffer by symbol with a new
// NpuMaskWrite32Op built from an address computed from that buffer.
107 matchAndRewrite(NpuMaskWrite32Op op, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter)
const override {
// Resolve the buffer symbol inside the enclosing AIE device.
113 auto device = op->getParentOfType<AIE::DeviceOp>();
115 auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
// NOTE(review): the condition guarding this diagnostic is not visible in
// this listing (presumably a null check on `buffer` - confirm).
117 return op->emitError(
"buffer '" + *op.getBuffer() +
118 "' not found in device");
// The buffer must already have an address; otherwise report an error.
120 if (!buffer.getAddress())
121 return op->emitError(
"buffer must have address assigned");
// Buffer address plus op.getAddress() multiplied by sizeof(uint32_t).
124 uint32_t address =
static_cast<uint32_t
>(*buffer.getAddress()) +
125 op.getAddress() *
sizeof(uint32_t);
// Coordinates of the tile that owns the buffer.
126 auto col = buffer.getTileOp().getCol();
127 auto row = buffer.getTileOp().getRow();
// Tail of an expression: the low 8 bits of the row shifted by
// tm.getRowShift(), OR-ed with the low 20 bits of the address. The start of
// the statement (and the declaration of `tm`) is not visible here.
129 ((row & 0xff) << tm.
getRowShift()) | (address & 0xFFFFF);
// Swap the op for a new NpuMaskWrite32Op carrying the computed address plus
// the original value and mask; the three trailing arguments are nullptr.
131 rewriter.replaceOpWithNewOp<NpuMaskWrite32Op>(
132 op, address, op.getValue(), op.getMask(),
nullptr,
nullptr,
nullptr);
// Fragment of the RtpToWrite32Pattern conversion pattern: lowers an
// NpuWriteRTPOp to an NpuWrite32Op aimed at the referenced buffer.
// NOTE(review): this listing carries embedded source line numbers and skips
// some lines, so the struct header, some guard conditions, return statements
// and closing braces are not visible here.
138 using OpConversionPattern::OpConversionPattern;
140 RtpToWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
144 matchAndRewrite(NpuWriteRTPOp op, OpAdaptor adaptor,
145 ConversionPatternRewriter &rewriter)
const override {
// Resolve the buffer symbol inside the enclosing AIE device.
147 auto device = op->getParentOfType<AIE::DeviceOp>();
149 auto buffer = device.lookupSymbol<AIE::BufferOp>(op.getBuffer());
// NOTE(review): unlike the *SymToAddr patterns, the diagnostics below are not
// returned directly; the guarding condition and any following failure return
// are not visible in this listing - confirm the pattern bails out here.
151 op->emitError(
"buffer '" + op.getBuffer() +
"' not found in device");
// The buffer must already have an address assigned.
155 if (!buffer.getAddress()) {
156 op->emitError(
"buffer must have address assigned");
// Tile owning the buffer; its column/row are attached to the new write.
159 AIE::TileOp tile = buffer.getTileOp();
// Offset = RTP index multiplied by sizeof(uint32_t), added to the buffer
// address.
161 uint32_t idx = op.getIndex() *
sizeof(uint32_t);
162 uint32_t address = buffer.getAddress().value() + idx;
// Emit the replacement NpuWrite32Op with the computed address, the original
// value, a nullptr argument, and the tile column/row as i32 attributes.
164 rewriter.create<NpuWrite32Op>(op->getLoc(), address, op.getValue(),
nullptr,
165 rewriter.getI32IntegerAttr(tile.getCol()),
166 rewriter.getI32IntegerAttr(tile.getRow()));
// The original RTP write is no longer needed.
168 rewriter.eraseOp(op);
// Fragment of the PushQueuetoWrite32Pattern conversion pattern: lowers an
// NpuPushQueueOp to register writes (an optional NpuMaskWrite32Op followed by
// an NpuWrite32Op).
// NOTE(review): this listing carries embedded source line numbers and skips
// some lines; declarations of `ctrl_offset` and `cmd`, among others, are not
// visible here.
176 using OpConversionPattern::OpConversionPattern;
178 PushQueuetoWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
182 matchAndRewrite(NpuPushQueueOp op, OpAdaptor adaptor,
183 ConversionPatternRewriter &rewriter)
const override {
// Tail of a call taking the op's column, row, channel and direction
// (presumably the lookup that yields `ctrl_offset` - confirm).
187 op.getColumn(), op.getRow(), op.getChannel(), op.getDirection());
// When a completion token is requested, optionally program the controller
// packet id first.
190 if (op.getIssueToken()) {
// Tile at (column, row 0) of the enclosing device, created if absent.
193 AIE::TileOp shimTile = AIE::TileOp::getOrCreate(
194 rewriter, op->getParentOfType<AIE::DeviceOp>(), op.getColumn(), 0);
// Only tiles that carry a "controller_id" attribute get the masked write.
195 if (shimTile->hasAttr(
"controller_id")) {
196 AIE::PacketInfoAttr controller_id_attr =
197 shimTile->getAttrOfType<AIE::PacketInfoAttr>(
"controller_id");
// Packet id placed at bit 8; the mask selects bits 8..12.
198 uint32_t data = controller_id_attr.getPktId() << 8;
199 uint32_t mask = 0x00001F00;
200 rewriter.create<NpuMaskWrite32Op>(op->getLoc(), ctrl_offset, data, mask,
201 nullptr,
nullptr,
nullptr);
// The queue register is the word following the control register.
206 uint32_t queue_offset = ctrl_offset + 0x4;
// Build the queue command word from the BD id and repeat count.
209 uint32_t bd_id = op.getBdId();
210 uint32_t repeat_cnt = op.getRepeatCount();
// Repeat count occupies 8 bits starting at bit 16.
213 cmd |= (repeat_cnt & 0xFF) << 16;
// NOTE(review): the statement controlled by this `if` is not visible here.
214 if (op.getIssueToken())
// Write the command word to the queue register, then drop the original op.
217 rewriter.create<NpuWrite32Op>(op->getLoc(), queue_offset, cmd,
nullptr,
219 rewriter.eraseOp(op);
// Fragment of the conversion pattern that lowers an NpuDmaMemcpyNdOp into an
// NpuWriteBdOp, an NpuAddressPatchOp and an NpuPushQueueOp.
// NOTE(review): this listing carries embedded source line numbers and skips
// many lines; the struct header, the constructor's leading parameters, and
// the declarations of `allocGetter`, `targetModel`, `column`, `row`, `bd_id`,
// `next_bd`, `d0_size`/`d1_size`/`d2_size` and `arg_idx` are not visible here.
225 using OpConversionPattern::OpConversionPattern;
232 PatternBenefit benefit = 1)
236 matchAndRewrite(NpuDmaMemcpyNdOp op, OpAdaptor adaptor,
237 ConversionPatternRewriter &rewriter)
const override {
// Commonly used handles: memref type, context, an i32 type, an i32 zero
// attribute used as the default for BD fields, and the adapted memref.
239 BaseMemRefType bufferType = op.getMemref().getType();
240 auto *ctx = op->getContext();
241 auto i32ty = IntegerType::get(ctx, 32);
242 auto zero = IntegerAttr::get(i32ty, 0);
243 auto memref = adaptor.getMemref();
245 auto dev = op->getParentOfType<AIE::DeviceOp>();
// Look up the allocation info named by the op's metadata symbol.
249 auto infoOp = allocGetter.
get(dev, op.getMetadata());
// NOTE(review): the condition guarding this diagnostic is not visible here.
251 return op->emitOpError(
"couldn't find shim_dma_allocation op.");
// Channel direction and column come from the allocation info.
254 auto channelDir = infoOp->getChannelDir();
255 bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S;
256 int col = infoOp->getCol();
// Every BD field starts out as zero and is overwritten further down where
// applicable.
261 auto buffer_length = zero;
262 auto buffer_offset = zero;
263 auto enable_packet = zero;
264 auto out_of_order_id = zero;
265 auto packet_id = zero;
266 auto packet_type = zero;
268 auto d0_stride = zero;
270 auto d1_stride = zero;
272 auto d2_stride = zero;
273 auto iteration_current = zero;
274 auto iteration_size = zero;
275 auto iteration_stride = zero;
278 auto use_next_bd = zero;
279 auto valid_bd = zero;
280 auto lock_rel_val = zero;
281 auto lock_rel_id = zero;
282 auto lock_acq_enable = zero;
283 auto lock_acq_val = zero;
284 auto lock_acq_id = zero;
285 auto d0_zero_before = zero;
286 auto d1_zero_before = zero;
287 auto d2_zero_before = zero;
288 auto d0_zero_after = zero;
289 auto d1_zero_after = zero;
290 auto d2_zero_after = zero;
291 auto burst_length = zero;
293 auto issue_token = BoolAttr::get(ctx,
false);
294 auto repeat_count = zero;
// Sizes and strides are collected in reversed order.
// NOTE(review): `.value()` is called unchecked, so every size/stride is
// assumed to be a constant - confirm this is verified earlier.
295 llvm::SmallVector<int64_t, 4> inputSizes = llvm::map_to_vector(
296 llvm::reverse(op.getMixedSizes()),
297 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
298 llvm::SmallVector<int64_t, 4> inputStrides = llvm::map_to_vector(
299 llvm::reverse(op.getMixedStrides()),
300 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
// Four-entry output vectors that receive the hardware sizes/strides.
301 llvm::SmallVector<int64_t, 4> sizes(4);
302 llvm::SmallVector<int64_t, 4> strides(4);
// Tail of a call that fills `sizes`/`strides` from the inputs (presumably
// getHardwareStridesWraps - the head of the call is not visible here).
304 inputStrides, sizes, strides);
305 int64_t offset = op.getOffsetInBytes();
// Column comes from the allocation; row is fixed to 0.
308 column = IntegerAttr::get(i32ty, col);
311 row = IntegerAttr::get(i32ty, 0);
// Tail of a validation call (presumably verifyStridesWraps - the head of the
// call and the body of the failure branch are not visible here).
313 bool skipTransformationChecks = op.isLinearTransferWithoutTransformation();
315 inputStrides, sizes, strides,
316 skipTransformationChecks))) {
// The op is expected to sit inside a RuntimeSequenceOp.
321 AIEX::RuntimeSequenceOp seq_op =
322 op->getParentOfType<AIEX::RuntimeSequenceOp>();
// NOTE(review): the guarding condition and any failure return for this
// diagnostic are not visible here.
324 op->emitOpError(
"NpuDmaMemcpyNdOps must have RuntimeSequenceOp parent at "
325 "time of lowering.");
// Scan the sequence's entry-block arguments for the one that is the memref
// (the loop body past the comparison is not visible here).
328 Block &entryBB = seq_op.getBody().front();
330 for (
int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
331 if (entryBB.getArgument(i) == memref) {
340 bd_id = IntegerAttr::get(i32ty, op.getId());
// Buffer length: innermost size scaled by element bit width over the target's
// address-generation granularity, then multiplied by the sizes at index 1 and
// 2 when present.
343 uint64_t buffer_length_val = inputSizes[0] * op.getElementTypeBitwidth() /
344 targetModel.getAddressGenGranularity();
345 if (inputSizes.size() > 1) {
346 for (
size_t i = 1; i < std::min(inputSizes.size(), (
size_t)3); i++) {
347 buffer_length_val *= inputSizes[i];
350 buffer_length = IntegerAttr::get(i32ty, buffer_length_val);
353 buffer_offset = IntegerAttr::get(i32ty, 0);
// Packet routing fields are only set when the op carries packet info.
356 if (
auto packetInfo = op.getPacket()) {
357 enable_packet = IntegerAttr::get(i32ty, 1);
358 packet_type = IntegerAttr::get(i32ty, packetInfo->getPktType());
359 packet_id = IntegerAttr::get(i32ty, packetInfo->getPktId());
// Per-dimension sizes/strides are only programmed for transfers that are not
// plain linear copies.
364 if (!op.isLinearTransferWithoutTransformation()) {
366 d0_size = IntegerAttr::get(i32ty, sizes[0]);
367 d0_stride = IntegerAttr::get(i32ty, strides[0]);
370 d1_size = IntegerAttr::get(i32ty, sizes[1]);
371 d1_stride = IntegerAttr::get(i32ty, strides[1]);
374 d2_stride = IntegerAttr::get(i32ty, strides[2]);
// d2_size is taken from sizes[2] for mem tiles; the other assignment sets it
// to 0 (the `else` line is not visible here).
377 if (targetModel.isMemTile(col, 0))
378 d2_size = IntegerAttr::get(i32ty, sizes[2]);
380 d2_size = IntegerAttr::get(i32ty, 0);
// Outermost dimension: with a positive stride it is programmed as the BD
// iteration; the remaining visible assignments zero the iteration fields.
// NOTE(review): indexes entry 3 unconditionally - assumes at least four
// sizes/strides; the branch structure is only partly visible here.
383 if (inputSizes[3] > 1) {
384 if (inputStrides[3] > 0) {
385 iteration_size = IntegerAttr::get(i32ty, sizes[3]);
386 iteration_stride = IntegerAttr::get(i32ty, strides[3]);
392 iteration_size = zero;
393 iteration_stride = zero;
396 repeat_count = IntegerAttr::get(i32ty, sizes[3]);
403 valid_bd = IntegerAttr::get(i32ty, 1);
// Zero-padding and burst-length fields are copied from the op.
416 d0_zero_before = IntegerAttr::get(i32ty, op.getD0ZeroBefore());
419 d1_zero_before = IntegerAttr::get(i32ty, op.getD1ZeroBefore());
422 d2_zero_before = IntegerAttr::get(i32ty, op.getD2ZeroBefore());
425 d0_zero_after = IntegerAttr::get(i32ty, op.getD0ZeroAfter());
428 d1_zero_after = IntegerAttr::get(i32ty, op.getD1ZeroAfter());
431 d2_zero_after = IntegerAttr::get(i32ty, op.getD2ZeroAfter());
434 burst_length = IntegerAttr::get(i32ty, op.getBurstLength());
// Token request comes from the op; the second assignment forces it to true
// under a condition that is not visible in this listing.
437 issue_token = BoolAttr::get(ctx, op.getIssueToken());
441 issue_token = BoolAttr::get(ctx,
true);
// Diagnose zero padding on a mem tile when the direction is not MM2S.
// NOTE(review): no failure return is visible after this diagnostic - confirm
// whether lowering is meant to continue.
443 if (targetModel.isMemTile(col, 0) && (!isMM2S) &&
444 (op.getD0ZeroBefore() != 0 || op.getD0ZeroAfter() != 0 ||
445 op.getD1ZeroBefore() != 0 || op.getD1ZeroAfter() != 0 ||
446 op.getD2ZeroBefore() != 0 || op.getD2ZeroAfter() != 0))
447 op->emitOpError(
"MemTile supports zero padding only on MM2S direction");
// Emit the buffer-descriptor write with all fields gathered above.
450 rewriter.create<NpuWriteBdOp>(
451 op->getLoc(), column, bd_id, buffer_length, buffer_offset,
452 enable_packet, out_of_order_id, packet_id, packet_type, d0_size,
453 d0_stride, d1_size, d1_stride, d2_size, d2_stride, iteration_current,
454 iteration_size, iteration_stride, next_bd,
row, use_next_bd, valid_bd,
455 lock_rel_val, lock_rel_id, lock_acq_enable, lock_acq_val, lock_acq_id,
456 d0_zero_before, d1_zero_before, d2_zero_before, d0_zero_after,
457 d1_zero_after, d2_zero_after, burst_length);
// Emit an address patch at the BD address (plus the target's BD address
// offset) for argument `arg_idx` with the op's byte offset.
461 uint64_t addr = targetModel.getDmaBdAddress(col, 0, op.getId()) +
462 targetModel.getDmaBdAddressOffset(col, 0);
463 rewriter.create<NpuAddressPatchOp>(op->getLoc(), addr, arg_idx, offset);
// Queue the BD on the allocation's channel.
466 rewriter.create<NpuPushQueueOp>(
467 op->getLoc(), column,
row, infoOp->getChannelDirAttr(),
468 infoOp->getChannelIndexAttr(), issue_token, repeat_count, bd_id);
// The original memcpy op has been fully replaced by the ops above.
470 rewriter.eraseOp(op);
// Fragment of the DmaWaitToSyncPattern conversion pattern: replaces an
// NpuDmaWaitOp with an NpuSyncOp on the channel named by the op's symbol.
// NOTE(review): this listing carries embedded source line numbers and skips
// some lines (the struct header, one constructor parameter line, the
// declaration of `allocGetter`, guard conditions and closing braces).
484 using OpConversionPattern::OpConversionPattern;
486 DmaWaitToSyncPattern(MLIRContext *context,
488 PatternBenefit benefit = 1)
492 matchAndRewrite(NpuDmaWaitOp op, OpAdaptor adaptor,
493 ConversionPatternRewriter &rewriter)
const override {
494 AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
// NOTE(review): the condition guarding this diagnostic is not visible here.
496 return op->emitError(
"couldn't find parent of type DeviceOp");
// Look up the shim DMA allocation named by the op's symbol.
498 std::optional<AIE::ShimDMAAllocationOp> shimDmaAllocOp =
499 allocGetter.
get(dev, op.getSymbol());
500 if (!shimDmaAllocOp) {
501 return op->emitError(
"couldn't find shim_dma_allocation op");
// Replace the wait with a sync built from the allocation's column, a literal
// 0, the channel direction cast to uint32_t, the channel index, and two
// literal 1 arguments.
506 (void)rewriter.replaceOpWithNewOp<NpuSyncOp>(
507 op, shimDmaAllocOp->getCol(), 0,
508 static_cast<uint32_t
>(shimDmaAllocOp->getChannelDir()),
509 shimDmaAllocOp->getChannelIndex(), 1, 1);
// Fragment of the WriteBdToBlockWritePattern conversion pattern: packs the
// fields of an NpuWriteBdOp into an array of 32-bit words, stores them in a
// memref global, and replaces the op with an NpuBlockWriteOp at the BD
// address.
// NOTE(review): this listing carries embedded source line numbers and skips
// many lines; the struct header, the declarations of `tm` and `num_words`,
// several branch bodies and the closing braces are not visible here.
516 using OpConversionPattern::OpConversionPattern;
522 WriteBdToBlockWritePattern(MLIRContext *context,
int &cachedId,
523 PatternBenefit benefit = 1)
527 matchAndRewrite(NpuWriteBdOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter)
const override {
530 AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
// Target-model check; the statement it controls is not visible, only the
// message used for unsupported models.
534 if (isa<AIE::AIE2TargetModel>(tm))
538 "Unsupported AIETargetModel in WriteBdToBlockWritePattern");
// Zero-initialised word buffer that receives the packed BD fields.
540 std::vector<uint32_t> words(num_words, 0);
542 uint32_t bd_id = op.getBdId();
543 int col = op.getColumn();
544 int row = op.getRow();
// Address of this BD as reported by the target model.
545 uint64_t bd_addr = tm.getDmaBdAddress(col, row, bd_id);
// Layout used when the tile is a shim NOC tile: each field is masked and
// shifted into its word as shown below.
546 if (tm.isShimNOCTile(col, row)) {
548 words[0] = op.getBufferLength();
551 words[1] = op.getBufferOffset();
555 words[2] |= (op.getEnablePacket() & 0x1) << 30;
556 words[2] |= (op.getOutOfOrderId() & 0x3f) << 24;
557 words[2] |= (op.getPacketId() & 0x1f) << 19;
558 words[2] |= (op.getPacketType() & 0x7) << 16;
562 words[3] |= (op.getD0Size() & 0x3ff) << 20;
563 words[3] |= op.getD0Stride() & 0xfffff;
568 words[4] |= (op.getD1Size() & 0x3ff) << 20;
569 words[4] |= op.getD1Stride() & 0xfffff;
573 words[5] = op.getD2Stride() & 0xfffff;
576 words[6] |= (op.getIterationCurrent() & 0x3f) << 26;
577 words[6] |= (op.getIterationSize() & 0x3f) << 20;
578 words[6] |= op.getIterationStride() & 0xfffff;
582 words[7] |= (op.getNextBd() & 0xf) << 27;
583 words[7] |= (op.getUseNextBd() & 0x1) << 26;
584 words[7] |= (op.getValidBd() & 0x1) << 25;
585 words[7] |= (op.getLockRelVal() & 0x7f) << 18;
586 words[7] |= (op.getLockRelId() & 0xf) << 13;
587 words[7] |= (op.getLockAcqEnable() & 0x1) << 12;
588 words[7] |= (op.getLockAcqVal() & 0x7f) << 5;
589 words[7] |= op.getLockAcqId() & 0xf;
// Any non-zero zero-padding field is diagnosed in this branch.
// NOTE(review): no failure return is visible after this diagnostic.
590 if (op.getD0ZeroBefore() || op.getD1ZeroBefore() ||
591 op.getD2ZeroBefore() || op.getD0ZeroAfter() || op.getD1ZeroAfter() ||
592 op.getD2ZeroAfter()) {
593 op->emitError(
"Zero padding is only available on MemTile");
595 }
// Layout used when the tile is a mem tile; this one also encodes the
// zero-padding fields.
else if (tm.isMemTile(op.getColumn(), op.getRow())) {
598 words[0] |= (op.getEnablePacket() & 0x1) << 31;
599 words[0] |= (op.getPacketType() & 0x7) << 28;
600 words[0] |= (op.getPacketId() & 0x1f) << 23;
601 words[0] |= (op.getOutOfOrderId() & 0x3f) << 17;
602 words[0] |= op.getBufferLength() & 0x1ffff;
605 words[1] |= (op.getD0ZeroBefore() & 0x3F) << 26;
606 words[1] |= (op.getNextBd() & 0x3f) << 20;
607 words[1] |= (op.getUseNextBd() & 0x1) << 19;
608 words[1] |= op.getBufferOffset() & 0x7ffff;
611 words[2] |= (op.getD0Size() & 0x3ff) << 17;
612 words[2] |= op.getD0Stride() & 0x1ffff;
616 words[3] |= (op.getD1ZeroBefore() & 0x1F) << 27;
617 words[3] |= (op.getD1Size() & 0x3ff) << 17;
618 words[3] |= op.getD1Stride() & 0x1ffff;
622 words[4] |= (op.getD2ZeroBefore() & 0xF) << 27;
623 words[4] |= op.getD2Stride() & 0x1ffff;
627 words[5] |= (op.getD2ZeroAfter() & 0xF) << 28;
628 words[5] |= (op.getD1ZeroAfter() & 0x1F) << 23;
629 words[5] |= (op.getD0ZeroAfter() & 0x3F) << 17;
632 words[6] |= (op.getIterationCurrent() & 0x3f) << 23;
633 words[6] |= (op.getIterationSize() & 0x3f) << 17;
634 words[6] |= op.getIterationStride() & 0x1ffff;
637 words[7] |= (op.getValidBd() & 0x1) << 31;
638 words[7] |= (op.getLockRelVal() & 0x7f) << 24;
639 words[7] |= (op.getLockRelId() & 0xff) << 16;
640 words[7] |= (op.getLockAcqEnable() & 0x1) << 15;
641 words[7] |= (op.getLockAcqVal() & 0x7f) << 8;
642 words[7] |= op.getLockAcqId() & 0xff;
// Diagnostic for tiles handled by neither layout above (the enclosing `else`
// is not visible here).
645 op->emitError(
"Run-time DMA configuration is supported only for "
646 "ShimTiles and MemTiles currently.");
// Types and initial value for the global that holds the packed words.
650 MemRefType memrefType = MemRefType::get({num_words}, rewriter.getI32Type());
651 TensorType tensorType =
652 RankedTensorType::get({num_words}, rewriter.getI32Type());
653 memref::GlobalOp global =
nullptr;
654 auto initVal = DenseElementsAttr::get<uint32_t>(tensorType, words);
// Walk the device's existing memref globals, comparing type and initial value
// against the new data (presumably to reuse a matching global - the
// statements controlled by these checks are not visible here).
655 auto otherGlobals = dev.getOps<memref::GlobalOp>();
656 for (
auto g : otherGlobals) {
659 if (g.getType() != memrefType)
661 auto otherValue = g.getInitialValue();
664 if (*otherValue != initVal)
// Create a new global just before the enclosing RuntimeSequenceOp; the guard
// restores the rewriter's insertion point afterwards.
670 OpBuilder::InsertionGuard guard(rewriter);
671 rewriter.setInsertionPoint(
672 op->getParentOfType<AIEX::RuntimeSequenceOp>());
// Name is "blockwrite_data_" followed by cachedId; the loop tests whether a
// symbol with that name already exists (its body is not visible here).
673 std::string name =
"blockwrite_data_";
674 while (dev.lookupSymbol(name + std::to_string(cachedId)))
676 name += std::to_string(cachedId);
677 global = rewriter.create<memref::GlobalOp>(
678 op->getLoc(), name, rewriter.getStringAttr(
"private"), memrefType,
679 initVal,
true,
nullptr);
// Reference the global and replace the BD write with a block write at the BD
// address (the trailing arguments of both calls are cut off in this listing).
681 auto memref = rewriter.create<memref::GetGlobalOp>(op->getLoc(), memrefType,
683 (void)rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(
684 op, rewriter.getUI32IntegerAttr(bd_addr), memref.getResult(),
nullptr,
// Out-of-class definition of the static counter used when naming
// "blockwrite_data_<N>" globals; starts at 0.
690int WriteBdToBlockWritePattern::cachedId = 0;
// Pass that lowers the DMA-level AIEX ops to lower-level NPU ops by running a
// partial dialect conversion over the device with the patterns listed below.
// NOTE(review): this listing carries embedded source line numbers and skips
// some lines; the declaration of `cachingGetter`, the failure handling after
// applyPartialConversion and the closing braces are not visible here.
692struct AIEDmaToNpuPass : AIEDmaToNpuBase<AIEDmaToNpuPass> {
// NOTE(review): `®istry` looks like a mis-encoded `&registry` (the `&reg`
// prefix rendered as an HTML entity) - confirm against the real source.
694 void getDependentDialects(DialectRegistry ®istry)
const override {
// The lowering creates memref ops, so the memref dialect must be loaded.
695 registry.insert<memref::MemRefDialect>();
698 void runOnOperation()
override {
702 AIE::DeviceOp device = getOperation();
// Ops that may remain after conversion.
704 ConversionTarget target(getContext());
705 target.addLegalDialect<AIEXDialect>();
706 target.addLegalDialect<memref::MemRefDialect>();
707 target.addLegalOp<AIE::BufferOp>();
708 target.addLegalOp<AIE::ShimDMAAllocationOp>();
709 target.addLegalOp<AIE::TileOp>();
// Ops that must be rewritten away.
711 target.addIllegalOp<NpuDmaMemcpyNdOp>();
712 target.addIllegalOp<NpuDmaWaitOp>();
713 target.addIllegalOp<NpuPushQueueOp>();
714 target.addIllegalOp<NpuWriteRTPOp>();
715 target.addIllegalOp<NpuWriteBdOp>();
// Write ops are legal only once they no longer reference a buffer symbol.
716 target.addDynamicallyLegalOp<NpuWrite32Op>(
717 [&](NpuWrite32Op op) {
return !op.getBuffer(); });
718 target.addDynamicallyLegalOp<NpuBlockWriteOp>(
719 [&](NpuBlockWriteOp op) {
return !op.getBuffer(); });
720 target.addDynamicallyLegalOp<NpuMaskWrite32Op>(
721 [&](NpuMaskWrite32Op op) {
return !op.getBuffer(); });
// Register every lowering pattern; two of them also receive `cachingGetter`.
723 RewritePatternSet patterns(&getContext());
724 patterns.insert<BlockWriteSymToAddr>(&getContext());
725 patterns.insert<DmaToNpuPattern>(&getContext(), cachingGetter);
726 patterns.insert<DmaWaitToSyncPattern>(&getContext(), cachingGetter);
727 patterns.insert<MaskWrite32SymToAddr>(&getContext());
728 patterns.insert<PushQueuetoWrite32Pattern>(&getContext());
729 patterns.insert<RtpToWrite32Pattern>(&getContext());
730 patterns.insert<Write32SymToAddr>(&getContext());
731 patterns.insert<WriteBdToBlockWritePattern>(&getContext());
// Run the conversion; the statement executed on failure is not visible here.
733 if (failed(applyPartialConversion(device, target, std::move(patterns))))
// Body of the pass factory: returns a newly allocated AIEDmaToNpuPass (the
// function header is not visible in this listing).
741 return std::make_unique<AIEDmaToNpuPass>();
virtual uint32_t getDmaControlAddress(int col, int row, int channel, AIE::DMAChannelDir direction) const =0
Return the array address of the DMA task queue register for the given col, row, channel, and direction...
virtual uint32_t getColumnShift() const =0
virtual uint32_t getRowShift() const =0
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIEDmaToNpuPass()
void getHardwareStridesWraps(const AIE::AIETargetModel &targetModel, mlir::Operation *op, mlir::BaseMemRefType referencedBufType, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > &sizes, llvm::SmallVector< int64_t, 4 > &strides)
mlir::LogicalResult verifyStridesWraps(mlir::Operation *forOp, mlir::BaseMemRefType referencedBufType, int tileCol, int tileRow, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > hardwareSizes, llvm::SmallVector< int64_t, 4 > hardwareStrides, bool skipTransformationChecks=false)
uint32_t getShimBurstLengthEncoding(const AIE::AIETargetModel &tm, uint32_t burstLength)
const AIETargetModel & getTargetModel(mlir::Operation *op)
std::optional< AIE::ShimDMAAllocationOp > get(DeviceOp dev, mlir::StringRef sym_name)