17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
24#define GEN_PASS_DEF_AIEDMATONPU
25#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
// Conversion pattern over NpuWrite32Op (constructor name: Write32SymToAddr).
// NOTE(review): this listing is missing lines -- the struct header, the
// constructor body, the result type of matchAndRewrite and the body of the
// `!address.has_value()` branch are not visible here.
35 using OpConversionPattern::OpConversionPattern;
37 Write32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
41 matchAndRewrite(NpuWrite32Op op, OpAdaptor adaptor,
42 ConversionPatternRewriter &rewriter)
const override {
// Ask the op for its absolute address; the result is an optional.
47 std::optional<uint32_t> address = op.getAbsoluteAddress();
48 if (!address.has_value()) {
// Replace the op with a new NpuWrite32Op built from the resolved address and
// the original value; the three trailing arguments are passed as nullptr.
52 rewriter.replaceOpWithNewOp<NpuWrite32Op>(op, *address, op.getValue(),
53 nullptr,
nullptr,
nullptr);
// Conversion pattern over NpuBlockWriteOp (constructor name:
// BlockWriteSymToAddr).
// NOTE(review): this listing is missing lines -- the struct header, the
// constructor body, the result type of matchAndRewrite and the body of the
// `!address.has_value()` branch are not visible here.
59 using OpConversionPattern::OpConversionPattern;
61 BlockWriteSymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
65 matchAndRewrite(NpuBlockWriteOp op, OpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter)
const override {
// Ask the op for its absolute address; the result is an optional.
71 std::optional<uint32_t> address = op.getAbsoluteAddress();
72 if (!address.has_value()) {
// Replace the op with a new NpuBlockWriteOp built from the resolved address
// and the original data operand; the three trailing arguments are nullptr.
75 rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(op, *address, op.getData(),
76 nullptr,
nullptr,
nullptr);
// Conversion pattern over NpuMaskWrite32Op (constructor name:
// MaskWrite32SymToAddr).
// NOTE(review): this listing is missing lines -- the struct header, the
// constructor body, the result type of matchAndRewrite and the body of the
// `!absoluteAddress.has_value()` branch are not visible here.
82 using OpConversionPattern::OpConversionPattern;
84 MaskWrite32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
88 matchAndRewrite(NpuMaskWrite32Op op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter)
const override {
// Ask the op for its absolute address; the result is an optional.
94 std::optional<uint32_t> absoluteAddress = op.getAbsoluteAddress();
95 if (!absoluteAddress.has_value()) {
// Replace the op with a new NpuMaskWrite32Op that keeps the original value and
// mask but uses the resolved address; the three trailing arguments are nullptr.
99 rewriter.replaceOpWithNewOp<NpuMaskWrite32Op>(op, *absoluteAddress,
100 op.getValue(), op.getMask(),
101 nullptr,
nullptr,
nullptr);
// Conversion pattern over NpuWriteRTPOp (constructor name:
// RtpToWrite32Pattern): creates an NpuWrite32Op and erases the original op.
// NOTE(review): this listing is missing lines -- the struct header, the
// constructor body, the conditions guarding the first emitError and all
// return statements are not visible here.
107 using OpConversionPattern::OpConversionPattern;
109 RtpToWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
113 matchAndRewrite(NpuWriteRTPOp op, OpAdaptor adaptor,
114 ConversionPatternRewriter &rewriter)
const override {
// Look up the AIE::BufferOp named by op.getBuffer() in the enclosing DeviceOp.
116 auto device = op->getParentOfType<AIE::DeviceOp>();
118 auto buffer = device.lookupSymbol<AIE::BufferOp>(op.getBuffer());
120 op->emitError(
"buffer '" + op.getBuffer() +
"' not found in device");
// The buffer must already carry an address.
124 if (!buffer.getAddress()) {
125 op->emitError(
"buffer must have address assigned");
128 AIE::TileOp tile = buffer.getTileOp();
// The op's index is scaled by sizeof(uint32_t) and added to the buffer address.
130 uint32_t idx = op.getIndex() *
sizeof(uint32_t);
131 uint32_t address = buffer.getAddress().value() + idx;
// Emit the write with that address and the op's value; the next argument is
// nullptr and the last two are i32 attributes holding the buffer tile's
// column and row.
133 NpuWrite32Op::create(rewriter, op->getLoc(), address, op.getValue(),
134 nullptr, rewriter.getI32IntegerAttr(tile.getCol()),
135 rewriter.getI32IntegerAttr(tile.getRow()));
137 rewriter.eraseOp(op);
// Conversion pattern over NpuPushQueueOp (constructor name:
// PushQueuetoWrite32Pattern): emits register writes and erases the op.
// NOTE(review): this listing is missing lines -- the struct header, the
// declarations of `tm` and `cmd`, the tail of the TileOp::getOrCreate call,
// the body of the second `if (op.getIssueToken())` and the end of the final
// NpuWrite32Op::create call are not visible here.
145 using OpConversionPattern::OpConversionPattern;
147 PushQueuetoWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
151 matchAndRewrite(NpuPushQueueOp op, OpAdaptor adaptor,
152 ConversionPatternRewriter &rewriter)
const override {
// DMA control address for this (column, row, channel, direction).
155 uint32_t ctrl_offset = tm.getDmaControlAddress(
156 op.getColumn(), op.getRow(), op.getChannel(), op.getDirection());
159 if (op.getIssueToken()) {
162 AIE::TileOp shimTile = AIE::TileOp::getOrCreate(
163 rewriter, op->getParentOfType<AIE::DeviceOp>(), op.getColumn(),
// When the tile carries a "controller_id" attribute, its packet id is shifted
// left by 8 and written to ctrl_offset under mask 0x00001F00.
165 if (shimTile->hasAttr(
"controller_id")) {
166 AIE::PacketInfoAttr controller_id_attr =
167 shimTile->getAttrOfType<AIE::PacketInfoAttr>(
"controller_id");
168 uint32_t data = controller_id_attr.getPktId() << 8;
169 uint32_t mask = 0x00001F00;
170 NpuMaskWrite32Op::create(rewriter, op->getLoc(), ctrl_offset, data,
171 mask,
nullptr,
nullptr,
nullptr);
// The second write targets the word 4 bytes after ctrl_offset.
176 uint32_t queue_offset = ctrl_offset + 0x4;
179 uint32_t bd_id = op.getBdId();
180 uint32_t repeat_cnt = op.getRepeatCount();
// The low 8 bits of the repeat count are placed at bits 16..23 of `cmd`.
183 cmd |= (repeat_cnt & 0xFF) << 16;
184 if (op.getIssueToken())
187 NpuWrite32Op::create(rewriter, op->getLoc(), queue_offset, cmd,
nullptr,
189 rewriter.eraseOp(op);
// Conversion pattern over NpuDmaMemcpyNdOp (constructor name: DmaToNpuPattern).
// The visible code creates an NpuWriteBdOp, an NpuAddressPatchOp and an
// NpuPushQueueOp from the op, then erases it.
// NOTE(review): this listing is missing many lines -- the struct header, the
// declarations of targetModel, column, row, bd_id, next_bd, arg_idx and
// d0_size/d1_size/d2_size, several `if`/`else` headers, return statements and
// closing braces are not visible here.
195 using OpConversionPattern::OpConversionPattern;
198 DmaToNpuPattern(MLIRContext *context, PatternBenefit benefit = 1)
202 matchAndRewrite(NpuDmaMemcpyNdOp op, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter)
const override {
205 BaseMemRefType bufferType = op.getMemref().getType();
206 auto *ctx = op->getContext();
207 auto i32ty = IntegerType::get(ctx, 32);
208 auto zero = IntegerAttr::get(i32ty, 0);
209 auto memref = adaptor.getMemref();
211 auto dev = op->getParentOfType<AIE::DeviceOp>();
// Resolve the shim DMA allocation named by the op's metadata symbol and the
// tile it refers to.
215 auto infoOp = AIE::ShimDMAAllocationOp::getForSymbol(
216 dev, op.getMetadata().getRootReference());
218 return op->emitOpError(
"couldn't find shim_dma_allocation op.");
221 AIE::TileOp shimTile = infoOp.getTileOp();
223 return op->emitOpError(
224 "shim_dma_allocation op must reference a valid TileOp.");
227 auto channelDir = infoOp.getChannelDir();
228 bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S;
229 int tileCol = shimTile.getCol();
230 int tileRow = shimTile.getRow();
// Each of the following fields starts out as the i32 zero attribute; some are
// overwritten further down.
235 auto buffer_length = zero;
236 auto buffer_offset = zero;
237 auto enable_packet = zero;
238 auto out_of_order_id = zero;
239 auto packet_id = zero;
240 auto packet_type = zero;
242 auto d0_stride = zero;
244 auto d1_stride = zero;
246 auto d2_stride = zero;
247 auto iteration_current = zero;
248 auto iteration_size = zero;
249 auto iteration_stride = zero;
252 auto use_next_bd = zero;
253 auto valid_bd = zero;
254 auto lock_rel_val = zero;
255 auto lock_rel_id = zero;
256 auto lock_acq_enable = zero;
257 auto lock_acq_val = zero;
258 auto lock_acq_id = zero;
259 auto d0_zero_before = zero;
260 auto d1_zero_before = zero;
261 auto d2_zero_before = zero;
262 auto d0_zero_after = zero;
263 auto d1_zero_after = zero;
264 auto d2_zero_after = zero;
265 auto burst_length = zero;
267 auto issue_token = BoolAttr::get(ctx,
false);
268 auto repeat_count = zero;
// The op's mixed sizes and strides are read in reverse order as constants;
// `.value()` is called on each, so every entry is expected to be a constant.
269 llvm::SmallVector<int64_t, 4> inputSizes = llvm::map_to_vector(
270 llvm::reverse(op.getMixedSizes()),
271 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
272 llvm::SmallVector<int64_t, 4> inputStrides = llvm::map_to_vector(
273 llvm::reverse(op.getMixedStrides()),
274 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
// `sizes` / `strides` are 4-element vectors passed to a call whose first line
// is not visible here (presumably getHardwareStridesWraps -- confirm).
275 llvm::SmallVector<int64_t, 4> sizes(4);
276 llvm::SmallVector<int64_t, 4> strides(4);
278 inputStrides, sizes, strides);
279 int64_t offset = op.getOffsetInBytes();
282 column = IntegerAttr::get(i32ty, tileCol);
285 row = IntegerAttr::get(i32ty, tileRow);
287 bool skipTransformationChecks = op.isLinearTransferWithoutTransformation();
289 inputStrides, sizes, strides,
290 skipTransformationChecks))) {
// The op needs an enclosing RuntimeSequenceOp at this point.
295 AIE::RuntimeSequenceOp seq_op =
296 op->getParentOfType<AIE::RuntimeSequenceOp>();
298 op->emitOpError(
"NpuDmaMemcpyNdOps must have RuntimeSequenceOp parent at "
299 "time of lowering.");
// rootMemref / subviewOffset default to the adapted memref and 0 and are then
// overwritten from `traceResult` (its computation is not visible here).
303 mlir::Value rootMemref = memref;
304 int64_t subviewOffset = 0;
310 return op->emitOpError(
311 "memref must be a block argument or subview/cast/reinterpret_cast of "
312 "a block argument with static offsets, sizes, and strides");
314 rootMemref = traceResult->rootArg;
315 subviewOffset = traceResult->offsetInBytes;
// Scan the entry block arguments of the runtime sequence for rootMemref.
318 Block &entryBB = seq_op.getBody().front();
320 for (
int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
321 if (entryBB.getArgument(i) == rootMemref) {
329 offset += subviewOffset;
332 bd_id = IntegerAttr::get(i32ty, op.getId());
// Buffer length: innermost size times element bit width divided by the target
// model's address generation granularity, then multiplied by the sizes at
// indices 1 and 2 when present (the loop bound is min(size, 3)).
335 uint64_t buffer_length_val = inputSizes[0] * op.getElementTypeBitwidth() /
336 targetModel.getAddressGenGranularity();
337 if (inputSizes.size() > 1) {
338 for (
size_t i = 1; i < std::min(inputSizes.size(), (
size_t)3); i++) {
339 buffer_length_val *= inputSizes[i];
342 buffer_length = IntegerAttr::get(i32ty, buffer_length_val);
345 buffer_offset = IntegerAttr::get(i32ty, 0);
// Packet fields are filled in only when the op carries packet info.
348 if (
auto packetInfo = op.getPacket()) {
349 enable_packet = IntegerAttr::get(i32ty, 1);
350 packet_type = IntegerAttr::get(i32ty, packetInfo->getPktType());
351 packet_id = IntegerAttr::get(i32ty, packetInfo->getPktId());
// Dimension sizes/strides are taken from the `sizes` / `strides` vectors; the
// extent of this `if` is not visible in this listing.
356 if (!op.isLinearTransferWithoutTransformation()) {
358 d0_size = IntegerAttr::get(i32ty, sizes[0]);
359 d0_stride = IntegerAttr::get(i32ty, strides[0]);
362 d1_size = IntegerAttr::get(i32ty, sizes[1]);
363 d1_stride = IntegerAttr::get(i32ty, strides[1]);
366 d2_stride = IntegerAttr::get(i32ty, strides[2]);
// NOTE(review): the row argument here is the literal 0, whereas the padding
// check further down uses tileRow -- confirm this is intended.
369 if (targetModel.isMemTile(tileCol, 0))
370 d2_size = IntegerAttr::get(i32ty, sizes[2]);
372 d2_size = IntegerAttr::get(i32ty, 0);
// Outermost dimension: with inputSizes[3] > 1 and inputStrides[3] > 0 the
// iteration size/stride come from sizes[3]/strides[3]. The branch headers for
// the zeroing and repeat_count assignments below are not visible here.
375 if (inputSizes[3] > 1) {
376 if (inputStrides[3] > 0) {
377 iteration_size = IntegerAttr::get(i32ty, sizes[3]);
378 iteration_stride = IntegerAttr::get(i32ty, strides[3]);
384 iteration_size = zero;
385 iteration_stride = zero;
388 repeat_count = IntegerAttr::get(i32ty, sizes[3]);
395 valid_bd = IntegerAttr::get(i32ty, 1);
// Zero-padding counts, burst length and issue token are copied from the op.
408 d0_zero_before = IntegerAttr::get(i32ty, op.getD0ZeroBefore());
411 d1_zero_before = IntegerAttr::get(i32ty, op.getD1ZeroBefore());
414 d2_zero_before = IntegerAttr::get(i32ty, op.getD2ZeroBefore());
417 d0_zero_after = IntegerAttr::get(i32ty, op.getD0ZeroAfter());
420 d1_zero_after = IntegerAttr::get(i32ty, op.getD1ZeroAfter());
423 d2_zero_after = IntegerAttr::get(i32ty, op.getD2ZeroAfter());
426 burst_length = IntegerAttr::get(i32ty, op.getBurstLength());
429 issue_token = BoolAttr::get(ctx, op.getIssueToken());
// NOTE(review): the condition under which issue_token is forced to true is
// not visible in this listing.
433 issue_token = BoolAttr::get(ctx,
true);
// Error out when the tile is a MemTile, the direction is not MM2S and any of
// the six zero-padding values is non-zero.
435 if (targetModel.isMemTile(tileCol, tileRow) && (!isMM2S) &&
436 (op.getD0ZeroBefore() != 0 || op.getD0ZeroAfter() != 0 ||
437 op.getD1ZeroBefore() != 0 || op.getD1ZeroAfter() != 0 ||
438 op.getD2ZeroBefore() != 0 || op.getD2ZeroAfter() != 0)) {
439 op->emitOpError(
"MemTile supports zero padding only on MM2S direction");
// Emit the buffer descriptor write with all fields gathered above.
444 NpuWriteBdOp::create(
445 rewriter, op->getLoc(), column, bd_id, buffer_length, buffer_offset,
446 enable_packet, out_of_order_id, packet_id, packet_type, d0_size,
447 d0_stride, d1_size, d1_stride, d2_size, d2_stride, iteration_current,
448 iteration_size, iteration_stride, next_bd, row, use_next_bd, valid_bd,
449 lock_rel_val, lock_rel_id, lock_acq_enable, lock_acq_val, lock_acq_id,
450 d0_zero_before, d1_zero_before, d2_zero_before, d0_zero_after,
451 d1_zero_after, d2_zero_after, burst_length);
// Address patch: BD address for op.getId() plus the target model's BD address
// offset, together with arg_idx and the accumulated byte offset.
455 uint64_t addr = targetModel.getDmaBdAddress(tileCol, tileRow, op.getId()) +
456 targetModel.getDmaBdAddressOffset(tileCol, tileRow);
457 NpuAddressPatchOp::create(rewriter, op->getLoc(), addr, arg_idx, offset);
// Push the BD onto the queue of the channel described by the allocation op.
460 NpuPushQueueOp::create(
461 rewriter, op->getLoc(), column, row, infoOp.getChannelDirAttr(),
462 infoOp.getChannelIndexAttr(), issue_token, repeat_count, bd_id);
464 rewriter.eraseOp(op);
// Conversion pattern over NpuDmaWaitOp (constructor name:
// DmaWaitToSyncPattern): replaces the op with an NpuSyncOp.
// NOTE(review): this listing is missing lines -- the struct header, the
// constructor body, the conditions guarding the first and third emitError and
// the final return are not visible here.
475 using OpConversionPattern::OpConversionPattern;
477 DmaWaitToSyncPattern(MLIRContext *context, PatternBenefit benefit = 1)
481 matchAndRewrite(NpuDmaWaitOp op, OpAdaptor adaptor,
482 ConversionPatternRewriter &rewriter)
const override {
483 AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
485 return op->emitError(
"couldn't find parent of type DeviceOp");
// Resolve the shim DMA allocation named by the op's symbol.
487 AIE::ShimDMAAllocationOp shimDmaAllocOp =
488 AIE::ShimDMAAllocationOp::getForSymbol(dev, op.getSymbol());
489 if (!shimDmaAllocOp) {
490 return op->emitError(
"couldn't find shim_dma_allocation op");
493 AIE::TileOp shimTile = shimDmaAllocOp.getTileOp();
495 return op->emitError(
496 "shim_dma_allocation op must reference a valid TileOp");
// The sync op takes the tile's column and row, the channel direction cast to
// uint32_t, the channel index, and two trailing literal 1 arguments.
501 (void)rewriter.replaceOpWithNewOp<NpuSyncOp>(
502 op, shimTile.getCol(), shimTile.getRow(),
503 static_cast<uint32_t
>(shimDmaAllocOp.getChannelDir()),
504 shimDmaAllocOp.getChannelIndex(), 1, 1);
// Conversion pattern over NpuWriteBdOp (constructor name:
// WriteBdToBlockWritePattern): packs the op's fields into a vector of 32-bit
// words and replaces the op with an NpuBlockWriteOp at the BD address.
// NOTE(review): this listing is missing lines -- the struct header, the
// declarations of `tm` and `num_words`, the branch header that selects the
// third word layout, the call that creates `global`, return statements and
// the end of the final replaceOpWithNewOp call are not visible here.
511 using OpConversionPattern::OpConversionPattern;
514 WriteBdToBlockWritePattern(MLIRContext *context, PatternBenefit benefit = 1)
518 matchAndRewrite(NpuWriteBdOp op, OpAdaptor adaptor,
519 ConversionPatternRewriter &rewriter)
const override {
521 AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
523 int col = op.getColumn();
524 int row = op.getRow();
527 if (isa<AIE::AIE2TargetModel>(tm)) {
529 if (tm.isCoreTile(col, row))
535 "Unsupported AIETargetModel in WriteBdToBlockWritePattern");
// All words start at zero; each branch below ORs or assigns fields into them.
538 std::vector<uint32_t> words(num_words, 0);
540 uint32_t bd_id = op.getBdId();
541 uint64_t bd_addr = tm.getDmaBdAddress(col, row, bd_id);
// Layout used when (col, row) is a shim NOC tile: eight words, words[0..7].
542 if (tm.isShimNOCTile(col, row)) {
544 words[0] = op.getBufferLength();
547 words[1] = op.getBufferOffset();
551 words[2] |= (op.getEnablePacket() & 0x1) << 30;
552 words[2] |= (op.getOutOfOrderId() & 0x3f) << 24;
553 words[2] |= (op.getPacketId() & 0x1f) << 19;
554 words[2] |= (op.getPacketType() & 0x7) << 16;
558 words[3] |= (op.getD0Size() & 0x3ff) << 20;
559 words[3] |= op.getD0Stride() & 0xfffff;
564 words[4] |= (op.getD1Size() & 0x3ff) << 20;
565 words[4] |= op.getD1Stride() & 0xfffff;
// NOTE(review): a literal 2 is placed at bits 24..27 of words[5]; its meaning
// is not stated in the visible code.
569 words[5] |= (2 & 0xf) << 24;
570 words[5] |= op.getD2Stride() & 0xfffff;
573 words[6] |= (op.getIterationCurrent() & 0x3f) << 26;
574 words[6] |= (op.getIterationSize() & 0x3f) << 20;
575 words[6] |= op.getIterationStride() & 0xfffff;
579 words[7] |= (op.getNextBd() & 0xf) << 27;
580 words[7] |= (op.getUseNextBd() & 0x1) << 26;
581 words[7] |= (op.getValidBd() & 0x1) << 25;
582 words[7] |= (op.getLockRelVal() & 0x7f) << 18;
583 words[7] |= (op.getLockRelId() & 0xf) << 13;
584 words[7] |= (op.getLockAcqEnable() & 0x1) << 12;
585 words[7] |= (op.getLockAcqVal() & 0x7f) << 5;
586 words[7] |= op.getLockAcqId() & 0xf;
// Any non-zero zero-padding field is reported as an error in this branch.
587 if (op.getD0ZeroBefore() || op.getD1ZeroBefore() ||
588 op.getD2ZeroBefore() || op.getD0ZeroAfter() || op.getD1ZeroAfter() ||
589 op.getD2ZeroAfter()) {
590 op->emitError(
"Zero padding is only available on MemTile");
592 }
else if (tm.isMemTile(op.getColumn(), op.getRow())) {
// Layout used when the tile is a MemTile: eight words, including the
// zero-before / zero-after padding fields.
595 words[0] |= (op.getEnablePacket() & 0x1) << 31;
596 words[0] |= (op.getPacketType() & 0x7) << 28;
597 words[0] |= (op.getPacketId() & 0x1f) << 23;
598 words[0] |= (op.getOutOfOrderId() & 0x3f) << 17;
599 words[0] |= op.getBufferLength() & 0x1ffff;
602 words[1] |= (op.getD0ZeroBefore() & 0x3F) << 26;
603 words[1] |= (op.getNextBd() & 0x3f) << 20;
604 words[1] |= (op.getUseNextBd() & 0x1) << 19;
605 words[1] |= op.getBufferOffset() & 0x7ffff;
608 words[2] |= (op.getD0Size() & 0x3ff) << 17;
609 words[2] |= op.getD0Stride() & 0x1ffff;
613 words[3] |= (op.getD1ZeroBefore() & 0x1F) << 27;
614 words[3] |= (op.getD1Size() & 0x3ff) << 17;
615 words[3] |= op.getD1Stride() & 0x1ffff;
619 words[4] |= (op.getD2ZeroBefore() & 0xF) << 27;
620 words[4] |= op.getD2Stride() & 0x1ffff;
624 words[5] |= (op.getD2ZeroAfter() & 0xF) << 28;
625 words[5] |= (op.getD1ZeroAfter() & 0x1F) << 23;
626 words[5] |= (op.getD0ZeroAfter() & 0x3F) << 17;
629 words[6] |= (op.getIterationCurrent() & 0x3f) << 23;
630 words[6] |= (op.getIterationSize() & 0x3f) << 17;
631 words[6] |= op.getIterationStride() & 0x1ffff;
634 words[7] |= (op.getValidBd() & 0x1) << 31;
635 words[7] |= (op.getLockRelVal() & 0x7f) << 24;
636 words[7] |= (op.getLockRelId() & 0xff) << 16;
637 words[7] |= (op.getLockAcqEnable() & 0x1) << 15;
638 words[7] |= (op.getLockAcqVal() & 0x7f) << 8;
639 words[7] |= op.getLockAcqId() & 0xff;
// Third layout, using words[0..5]; the buffer offset is divided by 4 before
// being packed. NOTE(review): the branch header selecting this layout is not
// visible here (presumably the core-tile case -- confirm).
644 words[0] = ((op.getBufferOffset() / 4) & 0x3fff) << 14;
645 words[0] |= op.getBufferLength() & 0x3fff;
651 words[1] |= (op.getEnablePacket() & 0x1) << 30;
652 words[1] |= (op.getOutOfOrderId() & 0x3f) << 24;
653 words[1] |= (op.getPacketId() & 0x1f) << 19;
654 words[1] |= (op.getPacketType() & 0x7) << 16;
658 words[2] = (op.getD1Stride() & 0x1fff) << 13;
659 words[2] |= op.getD0Stride() & 0x1fff;
663 words[3] = (op.getD1Size() & 0xff) << 21;
664 words[3] |= (op.getD0Size() & 0xff) << 13;
665 words[3] |= op.getD2Stride() & 0x1fff;
670 words[4] = (op.getIterationCurrent() & 0x3f) << 19;
671 words[4] |= (op.getIterationSize() & 0x3f) << 13;
672 words[4] |= op.getIterationStride() & 0x1fff;
679 words[5] |= (op.getNextBd() & 0xf) << 27;
680 words[5] |= (op.getUseNextBd() & 0x1) << 26;
681 words[5] |= (op.getValidBd() & 0x1) << 25;
682 words[5] |= (op.getLockRelVal() & 0x7f) << 18;
683 words[5] |= (op.getLockRelId() & 0xf) << 13;
684 words[5] |= (op.getLockAcqEnable() & 0x1) << 12;
685 words[5] |= (op.getLockAcqVal() & 0x7f) << 5;
686 words[5] |= op.getLockAcqId() & 0xf;
// `global` starts as nullptr; the insertion point is moved (under an
// InsertionGuard) to just before the enclosing RuntimeSequenceOp. The call
// that assigns `global` is not visible here.
689 memref::GlobalOp global =
nullptr;
691 OpBuilder::InsertionGuard guard(rewriter);
692 rewriter.setInsertionPoint(op->getParentOfType<AIE::RuntimeSequenceOp>());
// Reference the global via memref.get_global and use its result as the data
// operand of the block write at bd_addr.
695 auto memref = memref::GetGlobalOp::create(
696 rewriter, op.getLoc(), global.getType(), global.getName());
698 (void)rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(
699 op, rewriter.getUI32IntegerAttr(bd_addr), memref.getResult(),
nullptr,
// Pass that runs a partial dialect conversion over an AIE::DeviceOp using the
// eight patterns registered below.
// NOTE(review): this listing is missing lines -- closing braces and the body
// of the `if (failed(...))` at the end of runOnOperation are not visible here.
705struct AIEDmaToNpuPass : xilinx::AIEX::impl::AIEDmaToNpuBase<AIEDmaToNpuPass> {
// Registers the memref dialect as a dependent dialect.
// NOTE(review): `®istry` below looks like a mis-encoded `&registry`.
707 void getDependentDialects(DialectRegistry ®istry)
const override {
708 registry.insert<memref::MemRefDialect>();
711 void runOnOperation()
override {
713 AIE::DeviceOp device = getOperation();
// Legality: the AIEX and memref dialects plus three AIE ops are legal.
715 ConversionTarget target(getContext());
716 target.addLegalDialect<AIEXDialect>();
717 target.addLegalDialect<memref::MemRefDialect>();
718 target.addLegalOp<AIE::BufferOp>();
719 target.addLegalOp<AIE::ShimDMAAllocationOp>();
720 target.addLegalOp<AIE::TileOp>();
// These five ops are marked illegal for the conversion.
722 target.addIllegalOp<NpuDmaMemcpyNdOp>();
723 target.addIllegalOp<NpuDmaWaitOp>();
724 target.addIllegalOp<NpuPushQueueOp>();
725 target.addIllegalOp<NpuWriteRTPOp>();
726 target.addIllegalOp<NpuWriteBdOp>();
// The three write ops are legal only when op.getBuffer() is falsy.
727 target.addDynamicallyLegalOp<NpuWrite32Op>(
728 [&](NpuWrite32Op op) {
return !op.getBuffer(); });
729 target.addDynamicallyLegalOp<NpuBlockWriteOp>(
730 [&](NpuBlockWriteOp op) {
return !op.getBuffer(); });
731 target.addDynamicallyLegalOp<NpuMaskWrite32Op>(
732 [&](NpuMaskWrite32Op op) {
return !op.getBuffer(); });
// Register every conversion pattern, then apply the partial conversion.
734 RewritePatternSet patterns(&getContext());
735 patterns.insert<BlockWriteSymToAddr>(&getContext());
736 patterns.insert<DmaToNpuPattern>(&getContext());
737 patterns.insert<DmaWaitToSyncPattern>(&getContext());
738 patterns.insert<MaskWrite32SymToAddr>(&getContext());
739 patterns.insert<PushQueuetoWrite32Pattern>(&getContext());
740 patterns.insert<RtpToWrite32Pattern>(&getContext());
741 patterns.insert<Write32SymToAddr>(&getContext());
742 patterns.insert<WriteBdToBlockWritePattern>(&getContext());
744 if (failed(applyPartialConversion(device, target, std::move(patterns))))
// Returns a newly allocated AIEDmaToNpuPass owned by a std::unique_ptr.
// NOTE(review): the enclosing function header is not visible at this point in
// the listing (presumably createAIEDmaToNpuPass -- confirm).
752 return std::make_unique<AIEDmaToNpuPass>();
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIEDmaToNpuPass()
std::optional< SubviewTraceResult > traceSubviewToBlockArgument(Value value)
void getHardwareStridesWraps(const AIE::AIETargetModel &targetModel, mlir::Operation *op, mlir::BaseMemRefType referencedBufType, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > &sizes, llvm::SmallVector< int64_t, 4 > &strides)
mlir::LogicalResult verifyStridesWraps(mlir::Operation *forOp, mlir::BaseMemRefType referencedBufType, int tileCol, int tileRow, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > hardwareSizes, llvm::SmallVector< int64_t, 4 > hardwareStrides, bool skipTransformationChecks=false)
memref::GlobalOp getOrCreateDataMemref(OpBuilder &builder, AIE::DeviceOp dev, mlir::Location loc, ArrayRef< uint32_t > words)
uint32_t getShimBurstLengthEncoding(const AIE::AIETargetModel &tm, uint32_t burstLength)
const AIETargetModel & getTargetModel(mlir::Operation *op)