MLIR-AIE
AIEDmaToNpu.cpp
Go to the documentation of this file.
1//===- AIEDmaToNpu.cpp ------------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
16
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include <algorithm>
21#include <cstdint>
22
23namespace xilinx::AIEX {
24#define GEN_PASS_DEF_AIEDMATONPU
25#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
26} // namespace xilinx::AIEX
27
28using namespace mlir;
29using namespace xilinx;
30using namespace xilinx::AIEX;
31
32namespace {
33
34struct Write32SymToAddr : OpConversionPattern<NpuWrite32Op> {
35 using OpConversionPattern::OpConversionPattern;
36
37 Write32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
38 : OpConversionPattern(context, benefit) {}
39
40 LogicalResult
41 matchAndRewrite(NpuWrite32Op op, OpAdaptor adaptor,
42 ConversionPatternRewriter &rewriter) const override {
43
44 if (!op.getBuffer())
45 return failure();
46
47 std::optional<uint32_t> address = op.getAbsoluteAddress();
48 if (!address.has_value()) {
49 return failure();
50 }
51
52 rewriter.replaceOpWithNewOp<NpuWrite32Op>(op, *address, op.getValue(),
53 nullptr, nullptr, nullptr);
54 return success();
55 }
56};
57
58struct BlockWriteSymToAddr : OpConversionPattern<NpuBlockWriteOp> {
59 using OpConversionPattern::OpConversionPattern;
60
61 BlockWriteSymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
62 : OpConversionPattern(context, benefit) {}
63
64 LogicalResult
65 matchAndRewrite(NpuBlockWriteOp op, OpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter) const override {
67
68 if (!op.getBuffer())
69 return failure();
70
71 std::optional<uint32_t> address = op.getAbsoluteAddress();
72 if (!address.has_value()) {
73 return failure();
74 }
75 rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(op, *address, op.getData(),
76 nullptr, nullptr, nullptr);
77 return success();
78 }
79};
80
81struct MaskWrite32SymToAddr : OpConversionPattern<NpuMaskWrite32Op> {
82 using OpConversionPattern::OpConversionPattern;
83
84 MaskWrite32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
85 : OpConversionPattern(context, benefit) {}
86
87 LogicalResult
88 matchAndRewrite(NpuMaskWrite32Op op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter) const override {
90
91 if (!op.getBuffer())
92 return failure();
93
94 std::optional<uint32_t> absoluteAddress = op.getAbsoluteAddress();
95 if (!absoluteAddress.has_value()) {
96 return failure();
97 }
98
99 rewriter.replaceOpWithNewOp<NpuMaskWrite32Op>(op, *absoluteAddress,
100 op.getValue(), op.getMask(),
101 nullptr, nullptr, nullptr);
102 return success();
103 }
104};
105
106struct RtpToWrite32Pattern : OpConversionPattern<NpuWriteRTPOp> {
107 using OpConversionPattern::OpConversionPattern;
108
109 RtpToWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
110 : OpConversionPattern(context, benefit) {}
111
112 LogicalResult
113 matchAndRewrite(NpuWriteRTPOp op, OpAdaptor adaptor,
114 ConversionPatternRewriter &rewriter) const override {
115
116 auto device = op->getParentOfType<AIE::DeviceOp>();
117
118 auto buffer = device.lookupSymbol<AIE::BufferOp>(op.getBuffer());
119 if (!buffer) {
120 op->emitError("buffer '" + op.getBuffer() + "' not found in device");
121 return failure();
122 }
123
124 if (!buffer.getAddress()) {
125 op->emitError("buffer must have address assigned");
126 return failure();
127 }
128 AIE::TileOp tile = buffer.getTileOp();
129
130 uint32_t idx = op.getIndex() * sizeof(uint32_t);
131 uint32_t address = buffer.getAddress().value() + idx;
132
133 NpuWrite32Op::create(rewriter, op->getLoc(), address, op.getValue(),
134 nullptr, rewriter.getI32IntegerAttr(tile.getCol()),
135 rewriter.getI32IntegerAttr(tile.getRow()));
136
137 rewriter.eraseOp(op);
138 return success();
139 }
140};
141
142struct PushQueuetoWrite32Pattern : OpConversionPattern<NpuPushQueueOp> {
143
144public:
145 using OpConversionPattern::OpConversionPattern;
146
147 PushQueuetoWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
148 : OpConversionPattern(context, benefit) {}
149
150 LogicalResult
151 matchAndRewrite(NpuPushQueueOp op, OpAdaptor adaptor,
152 ConversionPatternRewriter &rewriter) const override {
153
154 const auto &tm = AIE::getTargetModel(op);
155 uint32_t ctrl_offset = tm.getDmaControlAddress(
156 op.getColumn(), op.getRow(), op.getChannel(), op.getDirection());
157
158 // control packet for issuing token
159 if (op.getIssueToken()) {
160 // set the task-complete-token controller ID field in the dma control
161 // register
162 AIE::TileOp shimTile = AIE::TileOp::getOrCreate(
163 rewriter, op->getParentOfType<AIE::DeviceOp>(), op.getColumn(),
164 op.getRow());
165 if (shimTile->hasAttr("controller_id")) {
166 AIE::PacketInfoAttr controller_id_attr =
167 shimTile->getAttrOfType<AIE::PacketInfoAttr>("controller_id");
168 uint32_t data = controller_id_attr.getPktId() << 8;
169 uint32_t mask = 0x00001F00;
170 NpuMaskWrite32Op::create(rewriter, op->getLoc(), ctrl_offset, data,
171 mask, nullptr, nullptr, nullptr);
172 }
173 }
174
175 // the offset of the task queue register in the tile
176 uint32_t queue_offset = ctrl_offset + 0x4;
177
178 // the value to write
179 uint32_t bd_id = op.getBdId();
180 uint32_t repeat_cnt = op.getRepeatCount();
181 uint32_t cmd = 0;
182 cmd |= bd_id & 0xF;
183 cmd |= (repeat_cnt & 0xFF) << 16;
184 if (op.getIssueToken())
185 cmd |= 0x80000000;
186
187 NpuWrite32Op::create(rewriter, op->getLoc(), queue_offset, cmd, nullptr,
188 nullptr, nullptr);
189 rewriter.eraseOp(op);
190 return success();
191 }
192};
193
194struct DmaToNpuPattern : OpConversionPattern<NpuDmaMemcpyNdOp> {
195 using OpConversionPattern::OpConversionPattern;
196
197public:
198 DmaToNpuPattern(MLIRContext *context, PatternBenefit benefit = 1)
199 : OpConversionPattern(context, benefit) {}
200
201 LogicalResult
202 matchAndRewrite(NpuDmaMemcpyNdOp op, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter) const override {
204 const auto &targetModel = AIE::getTargetModel(op);
205 BaseMemRefType bufferType = op.getMemref().getType();
206 auto *ctx = op->getContext();
207 auto i32ty = IntegerType::get(ctx, 32);
208 auto zero = IntegerAttr::get(i32ty, 0);
209 auto memref = adaptor.getMemref();
210
211 auto dev = op->getParentOfType<AIE::DeviceOp>();
212 if (!dev)
213 return failure();
214
215 auto infoOp = AIE::ShimDMAAllocationOp::getForSymbol(
216 dev, op.getMetadata().getRootReference());
217 if (!infoOp) {
218 return op->emitOpError("couldn't find shim_dma_allocation op.");
219 }
220
221 AIE::TileOp shimTile = infoOp.getTileOp();
222 if (!shimTile) {
223 return op->emitOpError(
224 "shim_dma_allocation op must reference a valid TileOp.");
225 }
226
227 auto channelDir = infoOp.getChannelDir();
228 bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S;
229 int tileCol = shimTile.getCol();
230 int tileRow = shimTile.getRow();
231
232 // initialize fields to zero
233 auto column = zero;
234 auto bd_id = zero;
235 auto buffer_length = zero;
236 auto buffer_offset = zero;
237 auto enable_packet = zero;
238 auto out_of_order_id = zero;
239 auto packet_id = zero;
240 auto packet_type = zero;
241 auto d0_size = zero;
242 auto d0_stride = zero;
243 auto d1_size = zero;
244 auto d1_stride = zero;
245 auto d2_size = zero;
246 auto d2_stride = zero;
247 auto iteration_current = zero;
248 auto iteration_size = zero;
249 auto iteration_stride = zero;
250 auto next_bd = zero;
251 auto row = zero;
252 auto use_next_bd = zero;
253 auto valid_bd = zero;
254 auto lock_rel_val = zero;
255 auto lock_rel_id = zero;
256 auto lock_acq_enable = zero;
257 auto lock_acq_val = zero;
258 auto lock_acq_id = zero;
259 auto d0_zero_before = zero;
260 auto d1_zero_before = zero;
261 auto d2_zero_before = zero;
262 auto d0_zero_after = zero;
263 auto d1_zero_after = zero;
264 auto d2_zero_after = zero;
265 auto burst_length = zero;
266
267 auto issue_token = BoolAttr::get(ctx, false);
268 auto repeat_count = zero;
269 llvm::SmallVector<int64_t, 4> inputSizes = llvm::map_to_vector(
270 llvm::reverse(op.getMixedSizes()),
271 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
272 llvm::SmallVector<int64_t, 4> inputStrides = llvm::map_to_vector(
273 llvm::reverse(op.getMixedStrides()),
274 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
275 llvm::SmallVector<int64_t, 4> sizes(4);
276 llvm::SmallVector<int64_t, 4> strides(4);
277 getHardwareStridesWraps(targetModel, op, bufferType, inputSizes,
278 inputStrides, sizes, strides);
279 int64_t offset = op.getOffsetInBytes();
280
281 // column
282 column = IntegerAttr::get(i32ty, tileCol);
283
284 // row
285 row = IntegerAttr::get(i32ty, tileRow);
286
287 bool skipTransformationChecks = op.isLinearTransferWithoutTransformation();
288 if (failed(verifyStridesWraps(op, bufferType, tileCol, tileRow, inputSizes,
289 inputStrides, sizes, strides,
290 skipTransformationChecks))) {
291 return failure();
292 }
293
294 // arg_idx and offset for block arguments
295 AIE::RuntimeSequenceOp seq_op =
296 op->getParentOfType<AIE::RuntimeSequenceOp>();
297 if (!seq_op) {
298 op->emitOpError("NpuDmaMemcpyNdOps must have RuntimeSequenceOp parent at "
299 "time of lowering.");
300 return failure();
301 }
302
303 mlir::Value rootMemref = memref;
304 int64_t subviewOffset = 0;
305
306 // Trace through memref.subview and memref.reinterpret_cast chain, if any,
307 // to find root block argument
308 auto traceResult = traceSubviewToBlockArgument(memref);
309 if (!traceResult) {
310 return op->emitOpError(
311 "memref must be a block argument or subview/cast/reinterpret_cast of "
312 "a block argument with static offsets, sizes, and strides");
313 }
314 rootMemref = traceResult->rootArg;
315 subviewOffset = traceResult->offsetInBytes;
316
317 // Find the argument index of the root memref
318 Block &entryBB = seq_op.getBody().front();
319 int arg_idx = -1;
320 for (int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
321 if (entryBB.getArgument(i) == rootMemref) {
322 arg_idx = i;
323 break;
324 }
325 }
326 if (arg_idx < 0)
327 return failure();
328
329 offset += subviewOffset;
330
331 // bd_id
332 bd_id = IntegerAttr::get(i32ty, op.getId());
333
334 // buffer_length
335 uint64_t buffer_length_val = inputSizes[0] * op.getElementTypeBitwidth() /
336 targetModel.getAddressGenGranularity();
337 if (inputSizes.size() > 1) {
338 for (size_t i = 1; i < std::min(inputSizes.size(), (size_t)3); i++) {
339 buffer_length_val *= inputSizes[i];
340 }
341 }
342 buffer_length = IntegerAttr::get(i32ty, buffer_length_val);
343
344 // buffer_offset - zero because the complete address is set by the patch op
345 buffer_offset = IntegerAttr::get(i32ty, 0);
346
347 // enable_packet
348 if (auto packetInfo = op.getPacket()) {
349 enable_packet = IntegerAttr::get(i32ty, 1);
350 packet_type = IntegerAttr::get(i32ty, packetInfo->getPktType());
351 packet_id = IntegerAttr::get(i32ty, packetInfo->getPktId());
352 }
353
354 // out_of_order_id
355
356 if (!op.isLinearTransferWithoutTransformation()) {
357 // d0_size, d0_stride
358 d0_size = IntegerAttr::get(i32ty, sizes[0]);
359 d0_stride = IntegerAttr::get(i32ty, strides[0]);
360
361 // d1_size, d1_stride
362 d1_size = IntegerAttr::get(i32ty, sizes[1]);
363 d1_stride = IntegerAttr::get(i32ty, strides[1]);
364
365 // d2_stride
366 d2_stride = IntegerAttr::get(i32ty, strides[2]);
367
368 // d2_size
369 if (targetModel.isMemTile(tileCol, 0)) // Need to be any row
370 d2_size = IntegerAttr::get(i32ty, sizes[2]);
371 else
372 d2_size = IntegerAttr::get(i32ty, 0);
373 }
374 // iteration_current, iteration_size, iteration_stride, repeat_count
375 if (inputSizes[3] > 1) {
376 if (inputStrides[3] > 0) {
377 iteration_size = IntegerAttr::get(i32ty, sizes[3]);
378 iteration_stride = IntegerAttr::get(i32ty, strides[3]);
379 } else {
380 // We allow users to encode the repeat_count as a dimension 3 stride
381 // of 0. This must lower to a iteration wrap of 0, so no stride is
382 // ever added. We then repeat the BD using the repeat_count in
383 // NpuPushQueueOp.
384 iteration_size = zero;
385 iteration_stride = zero;
386 }
387 }
388 repeat_count = IntegerAttr::get(i32ty, sizes[3]);
389
390 // next_bd
391
392 // use_next_bd
393
394 // valid_bd
395 valid_bd = IntegerAttr::get(i32ty, 1);
396
397 // lock_rel_val
398
399 // lock_rel_id
400
401 // lock_acq_enable
402
403 // lock_acq_val
404
405 // lock_acq_id
406
407 // d0_zero_before
408 d0_zero_before = IntegerAttr::get(i32ty, op.getD0ZeroBefore());
409
410 // d1_zero_before
411 d1_zero_before = IntegerAttr::get(i32ty, op.getD1ZeroBefore());
412
413 // d2_zero_before
414 d2_zero_before = IntegerAttr::get(i32ty, op.getD2ZeroBefore());
415
416 // d0_zero_after
417 d0_zero_after = IntegerAttr::get(i32ty, op.getD0ZeroAfter());
418
419 // d1_zero_after
420 d1_zero_after = IntegerAttr::get(i32ty, op.getD1ZeroAfter());
421
422 // d2_zero_after
423 d2_zero_after = IntegerAttr::get(i32ty, op.getD2ZeroAfter());
424
425 // burst_size
426 burst_length = IntegerAttr::get(i32ty, op.getBurstLength());
427
428 // Set the issue_token
429 issue_token = BoolAttr::get(ctx, op.getIssueToken());
430 // Earlier, all S2MM channels were implicitly assumed to issue a token.
431 // This logic is kept for now for backward compatibility.
432 if (!isMM2S)
433 issue_token = BoolAttr::get(ctx, true);
434
435 if (targetModel.isMemTile(tileCol, tileRow) && (!isMM2S) &&
436 (op.getD0ZeroBefore() != 0 || op.getD0ZeroAfter() != 0 ||
437 op.getD1ZeroBefore() != 0 || op.getD1ZeroAfter() != 0 ||
438 op.getD2ZeroBefore() != 0 || op.getD2ZeroAfter() != 0)) {
439 op->emitOpError("MemTile supports zero padding only on MM2S direction");
440 return failure();
441 }
442
443 // write the buffer descriptor to the array
444 NpuWriteBdOp::create(
445 rewriter, op->getLoc(), column, bd_id, buffer_length, buffer_offset,
446 enable_packet, out_of_order_id, packet_id, packet_type, d0_size,
447 d0_stride, d1_size, d1_stride, d2_size, d2_stride, iteration_current,
448 iteration_size, iteration_stride, next_bd, row, use_next_bd, valid_bd,
449 lock_rel_val, lock_rel_id, lock_acq_enable, lock_acq_val, lock_acq_id,
450 d0_zero_before, d1_zero_before, d2_zero_before, d0_zero_after,
451 d1_zero_after, d2_zero_after, burst_length);
452
453 // compute the location of the address to patch in the bd and emit patch
454 // instruction to perform the patch.
455 uint64_t addr = targetModel.getDmaBdAddress(tileCol, tileRow, op.getId()) +
456 targetModel.getDmaBdAddressOffset(tileCol, tileRow);
457 NpuAddressPatchOp::create(rewriter, op->getLoc(), addr, arg_idx, offset);
458
459 // push the patched bd onto the dma task queue
460 NpuPushQueueOp::create(
461 rewriter, op->getLoc(), column, row, infoOp.getChannelDirAttr(),
462 infoOp.getChannelIndexAttr(), issue_token, repeat_count, bd_id);
463
464 rewriter.eraseOp(op);
465 return success();
466 }
467};
468
469/// Convert NpuDmaWaitOp into NpuSyncOp by retrieving the necessary
470/// information from the ShimDMAAllocationOp referenced through the
471/// symbol argument of this op.
472struct DmaWaitToSyncPattern : OpConversionPattern<NpuDmaWaitOp> {
473
474public:
475 using OpConversionPattern::OpConversionPattern;
476
477 DmaWaitToSyncPattern(MLIRContext *context, PatternBenefit benefit = 1)
478 : OpConversionPattern(context, benefit) {}
479
480 LogicalResult
481 matchAndRewrite(NpuDmaWaitOp op, OpAdaptor adaptor,
482 ConversionPatternRewriter &rewriter) const override {
483 AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
484 if (!dev)
485 return op->emitError("couldn't find parent of type DeviceOp");
486
487 AIE::ShimDMAAllocationOp shimDmaAllocOp =
488 AIE::ShimDMAAllocationOp::getForSymbol(dev, op.getSymbol());
489 if (!shimDmaAllocOp) {
490 return op->emitError("couldn't find shim_dma_allocation op");
491 }
492
493 AIE::TileOp shimTile = shimDmaAllocOp.getTileOp();
494 if (!shimTile) {
495 return op->emitError(
496 "shim_dma_allocation op must reference a valid TileOp");
497 }
498
499 // Create with `column_num == 1` and `row_num == 1` to check for a single
500 // column and row.
501 (void)rewriter.replaceOpWithNewOp<NpuSyncOp>(
502 op, shimTile.getCol(), shimTile.getRow(),
503 static_cast<uint32_t>(shimDmaAllocOp.getChannelDir()),
504 shimDmaAllocOp.getChannelIndex(), 1, 1);
505
506 return success();
507 }
508};
509
510struct WriteBdToBlockWritePattern : OpConversionPattern<NpuWriteBdOp> {
511 using OpConversionPattern::OpConversionPattern;
512
513public:
514 WriteBdToBlockWritePattern(MLIRContext *context, PatternBenefit benefit = 1)
515 : OpConversionPattern(context, benefit) {}
516
517 LogicalResult
518 matchAndRewrite(NpuWriteBdOp op, OpAdaptor adaptor,
519 ConversionPatternRewriter &rewriter) const override {
520
521 AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
522 const AIE::AIETargetModel &tm = dev.getTargetModel();
523 int col = op.getColumn();
524 int row = op.getRow();
525
526 int num_words = 0;
527 if (isa<AIE::AIE2TargetModel>(tm)) {
528 // Tile DMAs have 6 words, MemTile and Shim have 8 words
529 if (tm.isCoreTile(col, row))
530 num_words = 6;
531 else
532 num_words = 8;
533 } else {
534 llvm_unreachable(
535 "Unsupported AIETargetModel in WriteBdToBlockWritePattern");
536 }
537
538 std::vector<uint32_t> words(num_words, 0);
539
540 uint32_t bd_id = op.getBdId();
541 uint64_t bd_addr = tm.getDmaBdAddress(col, row, bd_id);
542 if (tm.isShimNOCTile(col, row)) {
543 // DMA_BDX_0
544 words[0] = op.getBufferLength();
545
546 // DMA_BDX_1
547 words[1] = op.getBufferOffset();
548
549 // DMA_BDX_2
550 // En Packet , OoO BD ID , Packet ID , Packet Type
551 words[2] |= (op.getEnablePacket() & 0x1) << 30;
552 words[2] |= (op.getOutOfOrderId() & 0x3f) << 24;
553 words[2] |= (op.getPacketId() & 0x1f) << 19;
554 words[2] |= (op.getPacketType() & 0x7) << 16;
555
556 // DMA_BDX_3
557 // TODO: Secure Access
558 words[3] |= (op.getD0Size() & 0x3ff) << 20;
559 words[3] |= op.getD0Stride() & 0xfffff;
560
561 // DMA_BDX_4
562 words[4] = (getShimBurstLengthEncoding(tm, op.getBurstLength()) & 0x3)
563 << 30;
564 words[4] |= (op.getD1Size() & 0x3ff) << 20;
565 words[4] |= op.getD1Stride() & 0xfffff;
566
567 // DMA_BDX_5
568 // TODO: SIMID, AXQoS
569 words[5] |= (2 & 0xf) << 24; // AXCache = 2 to enable upsizing in NoC
570 words[5] |= op.getD2Stride() & 0xfffff;
571
572 // DMA_BDX_6
573 words[6] |= (op.getIterationCurrent() & 0x3f) << 26;
574 words[6] |= (op.getIterationSize() & 0x3f) << 20;
575 words[6] |= op.getIterationStride() & 0xfffff;
576
577 // DMA_BDX_7
578 // TODO: TLAST Suppress
579 words[7] |= (op.getNextBd() & 0xf) << 27;
580 words[7] |= (op.getUseNextBd() & 0x1) << 26;
581 words[7] |= (op.getValidBd() & 0x1) << 25;
582 words[7] |= (op.getLockRelVal() & 0x7f) << 18;
583 words[7] |= (op.getLockRelId() & 0xf) << 13;
584 words[7] |= (op.getLockAcqEnable() & 0x1) << 12;
585 words[7] |= (op.getLockAcqVal() & 0x7f) << 5;
586 words[7] |= op.getLockAcqId() & 0xf;
587 if (op.getD0ZeroBefore() || op.getD1ZeroBefore() ||
588 op.getD2ZeroBefore() || op.getD0ZeroAfter() || op.getD1ZeroAfter() ||
589 op.getD2ZeroAfter()) {
590 op->emitError("Zero padding is only available on MemTile");
591 }
592 } else if (tm.isMemTile(op.getColumn(), op.getRow())) {
593
594 // DMA_BDX_0
595 words[0] |= (op.getEnablePacket() & 0x1) << 31;
596 words[0] |= (op.getPacketType() & 0x7) << 28;
597 words[0] |= (op.getPacketId() & 0x1f) << 23;
598 words[0] |= (op.getOutOfOrderId() & 0x3f) << 17;
599 words[0] |= op.getBufferLength() & 0x1ffff;
600
601 // DMA_BDX_1
602 words[1] |= (op.getD0ZeroBefore() & 0x3F) << 26;
603 words[1] |= (op.getNextBd() & 0x3f) << 20;
604 words[1] |= (op.getUseNextBd() & 0x1) << 19;
605 words[1] |= op.getBufferOffset() & 0x7ffff;
606
607 // DMA_BDX_2
608 words[2] |= (op.getD0Size() & 0x3ff) << 17;
609 words[2] |= op.getD0Stride() & 0x1ffff;
610
611 // DMA_BDX_3
612 // TODO: Secure Access
613 words[3] |= (op.getD1ZeroBefore() & 0x1F) << 27;
614 words[3] |= (op.getD1Size() & 0x3ff) << 17;
615 words[3] |= op.getD1Stride() & 0x1ffff;
616
617 // DMA_BDX_4
618 // TODO: D2Size
619 words[4] |= (op.getD2ZeroBefore() & 0xF) << 27;
620 words[4] |= op.getD2Stride() & 0x1ffff;
621
622 // DMA_BDX_5
623 // ToDO: D3Stride
624 words[5] |= (op.getD2ZeroAfter() & 0xF) << 28;
625 words[5] |= (op.getD1ZeroAfter() & 0x1F) << 23;
626 words[5] |= (op.getD0ZeroAfter() & 0x3F) << 17;
627
628 // DMA_BDX_6
629 words[6] |= (op.getIterationCurrent() & 0x3f) << 23;
630 words[6] |= (op.getIterationSize() & 0x3f) << 17;
631 words[6] |= op.getIterationStride() & 0x1ffff;
632
633 // DMA_BDX_7
634 words[7] |= (op.getValidBd() & 0x1) << 31;
635 words[7] |= (op.getLockRelVal() & 0x7f) << 24;
636 words[7] |= (op.getLockRelId() & 0xff) << 16;
637 words[7] |= (op.getLockAcqEnable() & 0x1) << 15;
638 words[7] |= (op.getLockAcqVal() & 0x7f) << 8;
639 words[7] |= op.getLockAcqId() & 0xff;
640 } else {
641 // AIE2 Tile DMA - 6 words
642 // DMA_BDX_0
643 // Base_Address [27:14], Buffer_Length [13:0]
644 words[0] = ((op.getBufferOffset() / 4) & 0x3fff) << 14;
645 words[0] |= op.getBufferLength() & 0x3fff;
646
647 // DMA_BDX_1
648 // Enable_Compression [31], Enable_Packet [30], Out_Of_Order_BD_ID
649 // [29:24], Packet_ID [23:19], Packet_Type [18:16]
650 words[1] = 0; // Enable_Compression
651 words[1] |= (op.getEnablePacket() & 0x1) << 30;
652 words[1] |= (op.getOutOfOrderId() & 0x3f) << 24;
653 words[1] |= (op.getPacketId() & 0x1f) << 19;
654 words[1] |= (op.getPacketType() & 0x7) << 16;
655
656 // DMA_BDX_2
657 // D1_Stepsize [25:13], D0_Stepsize [12:0]
658 words[2] = (op.getD1Stride() & 0x1fff) << 13;
659 words[2] |= op.getD0Stride() & 0x1fff;
660
661 // DMA_BDX_3
662 // D1_Wrap [28:21], D0_Wrap [20:13], D2_Stepsize [12:0]
663 words[3] = (op.getD1Size() & 0xff) << 21;
664 words[3] |= (op.getD0Size() & 0xff) << 13;
665 words[3] |= op.getD2Stride() & 0x1fff;
666
667 // DMA_BDX_4
668 // Iteration_Current [24:19], Iteration_Wrap [18:13], Iteration_Stepsize
669 // [12:0]
670 words[4] = (op.getIterationCurrent() & 0x3f) << 19;
671 words[4] |= (op.getIterationSize() & 0x3f) << 13;
672 words[4] |= op.getIterationStride() & 0x1fff;
673
674 // DMA_BDX_5
675 // TLAST_Suppress [31], Next_BD [30:27], Use_Next_BD [26], Valid_BD [25],
676 // Lock_Rel_Value [24:18], Lock_Rel_ID [16:13], Lock_Acq_Enable [12],
677 // Lock_Acq_Value [11:5], Lock_Acq_ID [3:0]
678 words[5] = 0; // TLAST_Suppress
679 words[5] |= (op.getNextBd() & 0xf) << 27;
680 words[5] |= (op.getUseNextBd() & 0x1) << 26;
681 words[5] |= (op.getValidBd() & 0x1) << 25;
682 words[5] |= (op.getLockRelVal() & 0x7f) << 18;
683 words[5] |= (op.getLockRelId() & 0xf) << 13;
684 words[5] |= (op.getLockAcqEnable() & 0x1) << 12;
685 words[5] |= (op.getLockAcqVal() & 0x7f) << 5;
686 words[5] |= op.getLockAcqId() & 0xf;
687 }
688
689 memref::GlobalOp global = nullptr;
690 {
691 OpBuilder::InsertionGuard guard(rewriter);
692 rewriter.setInsertionPoint(op->getParentOfType<AIE::RuntimeSequenceOp>());
693 global = getOrCreateDataMemref(rewriter, dev, op.getLoc(), words);
694 }
695 auto memref = memref::GetGlobalOp::create(
696 rewriter, op.getLoc(), global.getType(), global.getName());
697
698 (void)rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(
699 op, rewriter.getUI32IntegerAttr(bd_addr), memref.getResult(), nullptr,
700 nullptr, nullptr);
701 return success();
702 }
703};
704
705struct AIEDmaToNpuPass : xilinx::AIEX::impl::AIEDmaToNpuBase<AIEDmaToNpuPass> {
706
707 void getDependentDialects(DialectRegistry &registry) const override {
708 registry.insert<memref::MemRefDialect>();
709 }
710
711 void runOnOperation() override {
712
713 AIE::DeviceOp device = getOperation();
714
715 ConversionTarget target(getContext());
716 target.addLegalDialect<AIEXDialect>();
717 target.addLegalDialect<memref::MemRefDialect>();
718 target.addLegalOp<AIE::BufferOp>();
719 target.addLegalOp<AIE::ShimDMAAllocationOp>();
720 target.addLegalOp<AIE::TileOp>();
721
722 target.addIllegalOp<NpuDmaMemcpyNdOp>();
723 target.addIllegalOp<NpuDmaWaitOp>();
724 target.addIllegalOp<NpuPushQueueOp>();
725 target.addIllegalOp<NpuWriteRTPOp>();
726 target.addIllegalOp<NpuWriteBdOp>();
727 target.addDynamicallyLegalOp<NpuWrite32Op>(
728 [&](NpuWrite32Op op) { return !op.getBuffer(); });
729 target.addDynamicallyLegalOp<NpuBlockWriteOp>(
730 [&](NpuBlockWriteOp op) { return !op.getBuffer(); });
731 target.addDynamicallyLegalOp<NpuMaskWrite32Op>(
732 [&](NpuMaskWrite32Op op) { return !op.getBuffer(); });
733
734 RewritePatternSet patterns(&getContext());
735 patterns.insert<BlockWriteSymToAddr>(&getContext());
736 patterns.insert<DmaToNpuPattern>(&getContext());
737 patterns.insert<DmaWaitToSyncPattern>(&getContext());
738 patterns.insert<MaskWrite32SymToAddr>(&getContext());
739 patterns.insert<PushQueuetoWrite32Pattern>(&getContext());
740 patterns.insert<RtpToWrite32Pattern>(&getContext());
741 patterns.insert<Write32SymToAddr>(&getContext());
742 patterns.insert<WriteBdToBlockWritePattern>(&getContext());
743
744 if (failed(applyPartialConversion(device, target, std::move(patterns))))
745 signalPassFailure();
746 }
747};
748
749} // namespace
750
751std::unique_ptr<OperationPass<AIE::DeviceOp>> AIEX::createAIEDmaToNpuPass() {
752 return std::make_unique<AIEDmaToNpuPass>();
753}
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIEDmaToNpuPass()
std::optional< SubviewTraceResult > traceSubviewToBlockArgument(Value value)
Definition AIEUtils.cpp:19
void getHardwareStridesWraps(const AIE::AIETargetModel &targetModel, mlir::Operation *op, mlir::BaseMemRefType referencedBufType, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > &sizes, llvm::SmallVector< int64_t, 4 > &strides)
mlir::LogicalResult verifyStridesWraps(mlir::Operation *forOp, mlir::BaseMemRefType referencedBufType, int tileCol, int tileRow, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > hardwareSizes, llvm::SmallVector< int64_t, 4 > hardwareStrides, bool skipTransformationChecks=false)
memref::GlobalOp getOrCreateDataMemref(OpBuilder &builder, AIE::DeviceOp dev, mlir::Location loc, ArrayRef< uint32_t > words)
Definition AIEUtils.cpp:113
uint32_t getShimBurstLengthEncoding(const AIE::AIETargetModel &tm, uint32_t burstLength)
const AIETargetModel & getTargetModel(mlir::Operation *op)