MLIR-AIE
AIEDmaToNpu.cpp
Go to the documentation of this file.
1//===- AIEDmaToNpu.cpp ------------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
15
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include <algorithm>
19#include <cstdint>
20
21using namespace mlir;
22using namespace xilinx;
23using namespace xilinx::AIEX;
24
25namespace {
26
27struct Write32SymToAddr : OpConversionPattern<NpuWrite32Op> {
28 using OpConversionPattern::OpConversionPattern;
29
30 Write32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
31 : OpConversionPattern(context, benefit) {}
32
33 LogicalResult
34 matchAndRewrite(NpuWrite32Op op, OpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36
37 if (!op.getBuffer())
38 return failure();
39
40 auto device = op->getParentOfType<AIE::DeviceOp>();
41 auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
42 if (!buffer)
43 return op->emitError("buffer '" + *op.getBuffer() +
44 "' not found in device");
45
46 if (!buffer.getAddress())
47 return op->emitError("buffer must have address assigned");
48
49 const AIE::AIETargetModel &tm = device.getTargetModel();
50 uint32_t address = static_cast<uint32_t>(*buffer.getAddress()) +
51 op.getAddress() * sizeof(uint32_t);
52 auto col = buffer.getTileOp().getCol();
53 auto row = buffer.getTileOp().getRow();
54 address |= ((col & 0xff) << tm.getColumnShift()) |
55 ((row & 0xff) << tm.getRowShift()) | (address & 0xFFFFF);
56
57 rewriter.replaceOpWithNewOp<NpuWrite32Op>(op, address, op.getValue(),
58 nullptr, nullptr, nullptr);
59 return success();
60 }
61};
62
63struct BlockWriteSymToAddr : OpConversionPattern<NpuBlockWriteOp> {
64 using OpConversionPattern::OpConversionPattern;
65
66 BlockWriteSymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
67 : OpConversionPattern(context, benefit) {}
68
69 LogicalResult
70 matchAndRewrite(NpuBlockWriteOp op, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter) const override {
72
73 if (!op.getBuffer())
74 return failure();
75
76 auto device = op->getParentOfType<AIE::DeviceOp>();
77
78 auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
79 if (!buffer)
80 return op->emitError("buffer '" + *op.getBuffer() +
81 "' not found in device");
82
83 if (!buffer.getAddress())
84 return op->emitError("buffer must have address assigned");
85
86 const AIE::AIETargetModel &tm = device.getTargetModel();
87 uint32_t address = static_cast<uint32_t>(*buffer.getAddress()) +
88 op.getAddress() * sizeof(uint32_t);
89 auto col = buffer.getTileOp().getCol();
90 auto row = buffer.getTileOp().getRow();
91 address |= ((col & 0xff) << tm.getColumnShift()) |
92 ((row & 0xff) << tm.getRowShift()) | (address & 0xFFFFF);
93
94 rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(op, address, op.getData(),
95 nullptr, nullptr, nullptr);
96 return success();
97 }
98};
99
100struct MaskWrite32SymToAddr : OpConversionPattern<NpuMaskWrite32Op> {
101 using OpConversionPattern::OpConversionPattern;
102
103 MaskWrite32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
104 : OpConversionPattern(context, benefit) {}
105
106 LogicalResult
107 matchAndRewrite(NpuMaskWrite32Op op, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter) const override {
109
110 if (!op.getBuffer())
111 return failure();
112
113 auto device = op->getParentOfType<AIE::DeviceOp>();
114
115 auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
116 if (!buffer)
117 return op->emitError("buffer '" + *op.getBuffer() +
118 "' not found in device");
119
120 if (!buffer.getAddress())
121 return op->emitError("buffer must have address assigned");
122
123 const AIE::AIETargetModel &tm = device.getTargetModel();
124 uint32_t address = static_cast<uint32_t>(*buffer.getAddress()) +
125 op.getAddress() * sizeof(uint32_t);
126 auto col = buffer.getTileOp().getCol();
127 auto row = buffer.getTileOp().getRow();
128 address |= ((col & 0xff) << tm.getColumnShift()) |
129 ((row & 0xff) << tm.getRowShift()) | (address & 0xFFFFF);
130
131 rewriter.replaceOpWithNewOp<NpuMaskWrite32Op>(
132 op, address, op.getValue(), op.getMask(), nullptr, nullptr, nullptr);
133 return success();
134 }
135};
136
137struct RtpToWrite32Pattern : OpConversionPattern<NpuWriteRTPOp> {
138 using OpConversionPattern::OpConversionPattern;
139
140 RtpToWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
141 : OpConversionPattern(context, benefit) {}
142
143 LogicalResult
144 matchAndRewrite(NpuWriteRTPOp op, OpAdaptor adaptor,
145 ConversionPatternRewriter &rewriter) const override {
146
147 auto device = op->getParentOfType<AIE::DeviceOp>();
148
149 auto buffer = device.lookupSymbol<AIE::BufferOp>(op.getBuffer());
150 if (!buffer) {
151 op->emitError("buffer '" + op.getBuffer() + "' not found in device");
152 return failure();
153 }
154
155 if (!buffer.getAddress()) {
156 op->emitError("buffer must have address assigned");
157 return failure();
158 }
159 AIE::TileOp tile = buffer.getTileOp();
160
161 uint32_t idx = op.getIndex() * sizeof(uint32_t);
162 uint32_t address = buffer.getAddress().value() + idx;
163
164 rewriter.create<NpuWrite32Op>(op->getLoc(), address, op.getValue(), nullptr,
165 rewriter.getI32IntegerAttr(tile.getCol()),
166 rewriter.getI32IntegerAttr(tile.getRow()));
167
168 rewriter.eraseOp(op);
169 return success();
170 }
171};
172
173struct PushQueuetoWrite32Pattern : OpConversionPattern<NpuPushQueueOp> {
174
175public:
176 using OpConversionPattern::OpConversionPattern;
177
178 PushQueuetoWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
179 : OpConversionPattern(context, benefit) {}
180
181 LogicalResult
182 matchAndRewrite(NpuPushQueueOp op, OpAdaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184
185 const auto &tm = AIE::getTargetModel(op);
186 uint32_t ctrl_offset = tm.getDmaControlAddress(
187 op.getColumn(), op.getRow(), op.getChannel(), op.getDirection());
188
189 // control packet for issuing token
190 if (op.getIssueToken()) {
191 // set the task-complete-token controller ID field in the dma control
192 // register
193 AIE::TileOp shimTile = AIE::TileOp::getOrCreate(
194 rewriter, op->getParentOfType<AIE::DeviceOp>(), op.getColumn(), 0);
195 if (shimTile->hasAttr("controller_id")) {
196 AIE::PacketInfoAttr controller_id_attr =
197 shimTile->getAttrOfType<AIE::PacketInfoAttr>("controller_id");
198 uint32_t data = controller_id_attr.getPktId() << 8;
199 uint32_t mask = 0x00001F00;
200 rewriter.create<NpuMaskWrite32Op>(op->getLoc(), ctrl_offset, data, mask,
201 nullptr, nullptr, nullptr);
202 }
203 }
204
205 // the offset of the task queue register in the tile
206 uint32_t queue_offset = ctrl_offset + 0x4;
207
208 // the value to write
209 uint32_t bd_id = op.getBdId();
210 uint32_t repeat_cnt = op.getRepeatCount();
211 uint32_t cmd = 0;
212 cmd |= bd_id & 0xF;
213 cmd |= (repeat_cnt & 0xFF) << 16;
214 if (op.getIssueToken())
215 cmd |= 0x80000000;
216
217 rewriter.create<NpuWrite32Op>(op->getLoc(), queue_offset, cmd, nullptr,
218 nullptr, nullptr);
219 rewriter.eraseOp(op);
220 return success();
221 }
222};
223
// Lowers NpuDmaMemcpyNdOp into the three-op sequence that programs a shim
// DMA transfer: an NpuWriteBdOp carrying the full buffer-descriptor field
// set, an NpuAddressPatchOp that patches the host buffer's address into the
// BD at load time, and an NpuPushQueueOp that enqueues the BD on the DMA
// task queue.
struct DmaToNpuPattern : OpConversionPattern<NpuDmaMemcpyNdOp> {
  using OpConversionPattern::OpConversionPattern;

private:
  // Caching lookup of shim_dma_allocation ops by symbol name.
  AIE::ShimDMAllocationGetter &allocGetter;

public:
  DmaToNpuPattern(MLIRContext *context, AIE::ShimDMAllocationGetter &getter,
                  PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit), allocGetter(getter) {}

  LogicalResult
  matchAndRewrite(NpuDmaMemcpyNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &targetModel = AIE::getTargetModel(op);
    BaseMemRefType bufferType = op.getMemref().getType();
    auto *ctx = op->getContext();
    auto i32ty = IntegerType::get(ctx, 32);
    auto zero = IntegerAttr::get(i32ty, 0);
    auto memref = adaptor.getMemref();

    auto dev = op->getParentOfType<AIE::DeviceOp>();
    if (!dev)
      return failure();

    // The op's metadata symbol names the shim_dma_allocation that supplies
    // the channel direction, channel index, and shim column.
    auto infoOp = allocGetter.get(dev, op.getMetadata());
    if (!infoOp) {
      return op->emitOpError("couldn't find shim_dma_allocation op.");
    }

    auto channelDir = infoOp->getChannelDir();
    bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S;
    int col = infoOp->getCol();

    // initialize fields to zero; each attribute below mirrors one operand of
    // NpuWriteBdOp and is overwritten further down only where needed.
    auto column = zero;
    auto bd_id = zero;
    auto buffer_length = zero;
    auto buffer_offset = zero;
    auto enable_packet = zero;
    auto out_of_order_id = zero;
    auto packet_id = zero;
    auto packet_type = zero;
    auto d0_size = zero;
    auto d0_stride = zero;
    auto d1_size = zero;
    auto d1_stride = zero;
    auto d2_size = zero;
    auto d2_stride = zero;
    auto iteration_current = zero;
    auto iteration_size = zero;
    auto iteration_stride = zero;
    auto next_bd = zero;
    auto row = zero;
    auto use_next_bd = zero;
    auto valid_bd = zero;
    auto lock_rel_val = zero;
    auto lock_rel_id = zero;
    auto lock_acq_enable = zero;
    auto lock_acq_val = zero;
    auto lock_acq_id = zero;
    auto d0_zero_before = zero;
    auto d1_zero_before = zero;
    auto d2_zero_before = zero;
    auto d0_zero_after = zero;
    auto d1_zero_after = zero;
    auto d2_zero_after = zero;
    auto burst_length = zero;

    auto issue_token = BoolAttr::get(ctx, false);
    auto repeat_count = zero;
    // Sizes/strides are reversed so index 0 is the innermost dimension.
    // NOTE(review): getConstantIntValue(s).value() asserts that every
    // size/stride is a compile-time constant — confirm the op verifier
    // guarantees this before lowering.
    llvm::SmallVector<int64_t, 4> inputSizes = llvm::map_to_vector(
        llvm::reverse(op.getMixedSizes()),
        [](OpFoldResult s) { return getConstantIntValue(s).value(); });
    llvm::SmallVector<int64_t, 4> inputStrides = llvm::map_to_vector(
        llvm::reverse(op.getMixedStrides()),
        [](OpFoldResult s) { return getConstantIntValue(s).value(); });
    llvm::SmallVector<int64_t, 4> sizes(4);
    llvm::SmallVector<int64_t, 4> strides(4);
    // Translate user-facing sizes/strides into the hardware's 4-D encoding.
    getHardwareStridesWraps(targetModel, op, bufferType, inputSizes,
                            inputStrides, sizes, strides);
    int64_t offset = op.getOffsetInBytes();

    // column
    column = IntegerAttr::get(i32ty, col);

    // row (shim DMAs always live in row 0)
    row = IntegerAttr::get(i32ty, 0);

    bool skipTransformationChecks = op.isLinearTransferWithoutTransformation();
    if (failed(verifyStridesWraps(op, bufferType, col, 0, inputSizes,
                                  inputStrides, sizes, strides,
                                  skipTransformationChecks))) {
      return failure();
    }

    // arg_idx: position of the transferred memref among the runtime
    // sequence's entry-block arguments; NpuAddressPatchOp uses it to locate
    // the host buffer whose address gets patched into the BD.
    AIEX::RuntimeSequenceOp seq_op =
        op->getParentOfType<AIEX::RuntimeSequenceOp>();
    if (!seq_op) {
      op->emitOpError("NpuDmaMemcpyNdOps must have RuntimeSequenceOp parent at "
                      "time of lowering.");
      return failure();
    }
    Block &entryBB = seq_op.getBody().front();
    int arg_idx = -1;
    for (int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
      if (entryBB.getArgument(i) == memref) {
        arg_idx = i;
        break;
      }
    }
    if (arg_idx < 0)
      return failure();

    // bd_id
    bd_id = IntegerAttr::get(i32ty, op.getId());

    // buffer_length: product of the three innermost sizes, expressed in the
    // target's address-generation granularity (bits per transfer unit).
    uint64_t buffer_length_val = inputSizes[0] * op.getElementTypeBitwidth() /
                                 targetModel.getAddressGenGranularity();
    if (inputSizes.size() > 1) {
      for (size_t i = 1; i < std::min(inputSizes.size(), (size_t)3); i++) {
        buffer_length_val *= inputSizes[i];
      }
    }
    buffer_length = IntegerAttr::get(i32ty, buffer_length_val);

    // buffer_offset - zero because the complete address is set by the patch op
    buffer_offset = IntegerAttr::get(i32ty, 0);

    // enable_packet: forward packet routing info when present on the op.
    if (auto packetInfo = op.getPacket()) {
      enable_packet = IntegerAttr::get(i32ty, 1);
      packet_type = IntegerAttr::get(i32ty, packetInfo->getPktType());
      packet_id = IntegerAttr::get(i32ty, packetInfo->getPktId());
    }

    // out_of_order_id (left at zero)

    if (!op.isLinearTransferWithoutTransformation()) {
      // d0_size, d0_stride
      d0_size = IntegerAttr::get(i32ty, sizes[0]);
      d0_stride = IntegerAttr::get(i32ty, strides[0]);

      // d1_size, d1_stride
      d1_size = IntegerAttr::get(i32ty, sizes[1]);
      d1_stride = IntegerAttr::get(i32ty, strides[1]);

      // d2_stride
      d2_stride = IntegerAttr::get(i32ty, strides[2]);

      // d2_size
      if (targetModel.isMemTile(col, 0)) // Need to be any row
        d2_size = IntegerAttr::get(i32ty, sizes[2]);
      else
        d2_size = IntegerAttr::get(i32ty, 0);
    }
    // iteration_current, iteration_size, iteration_stride, repeat_count
    // NOTE(review): inputSizes[3]/inputStrides[3] are read unconditionally —
    // confirm the op verifier guarantees a full 4-D size/stride description.
    if (inputSizes[3] > 1) {
      if (inputStrides[3] > 0) {
        iteration_size = IntegerAttr::get(i32ty, sizes[3]);
        iteration_stride = IntegerAttr::get(i32ty, strides[3]);
      } else {
        // We allow users to encode the repeat_count as a dimension 3 stride
        // of 0. This must lower to a iteration wrap of 0, so no stride is
        // ever added. We then repeat the BD using the repeat_count in
        // NpuPushQueueOp.
        iteration_size = zero;
        iteration_stride = zero;
      }
    }
    repeat_count = IntegerAttr::get(i32ty, sizes[3]);

    // next_bd (left at zero)

    // use_next_bd (left at zero)

    // valid_bd
    valid_bd = IntegerAttr::get(i32ty, 1);

    // lock_rel_val (left at zero)

    // lock_rel_id (left at zero)

    // lock_acq_enable (left at zero)

    // lock_acq_val (left at zero)

    // lock_acq_id (left at zero)

    // d0_zero_before
    d0_zero_before = IntegerAttr::get(i32ty, op.getD0ZeroBefore());

    // d1_zero_before
    d1_zero_before = IntegerAttr::get(i32ty, op.getD1ZeroBefore());

    // d2_zero_before
    d2_zero_before = IntegerAttr::get(i32ty, op.getD2ZeroBefore());

    // d0_zero_after
    d0_zero_after = IntegerAttr::get(i32ty, op.getD0ZeroAfter());

    // d1_zero_after
    d1_zero_after = IntegerAttr::get(i32ty, op.getD1ZeroAfter());

    // d2_zero_after
    d2_zero_after = IntegerAttr::get(i32ty, op.getD2ZeroAfter());

    // burst_size
    burst_length = IntegerAttr::get(i32ty, op.getBurstLength());

    // Set the issue_token
    issue_token = BoolAttr::get(ctx, op.getIssueToken());
    // Earlier, all S2MM channels were implicitly assumed to issue a token.
    // This logic is kept for now for backward compatibility.
    if (!isMM2S)
      issue_token = BoolAttr::get(ctx, true);

    // NOTE(review): this emits an error diagnostic but still lowers the op
    // (no `return failure()`) — confirm the non-fatal behavior is intended.
    if (targetModel.isMemTile(col, 0) && (!isMM2S) &&
        (op.getD0ZeroBefore() != 0 || op.getD0ZeroAfter() != 0 ||
         op.getD1ZeroBefore() != 0 || op.getD1ZeroAfter() != 0 ||
         op.getD2ZeroBefore() != 0 || op.getD2ZeroAfter() != 0))
      op->emitOpError("MemTile supports zero padding only on MM2S direction");

    // write the buffer descriptor to the array
    rewriter.create<NpuWriteBdOp>(
        op->getLoc(), column, bd_id, buffer_length, buffer_offset,
        enable_packet, out_of_order_id, packet_id, packet_type, d0_size,
        d0_stride, d1_size, d1_stride, d2_size, d2_stride, iteration_current,
        iteration_size, iteration_stride, next_bd, row, use_next_bd, valid_bd,
        lock_rel_val, lock_rel_id, lock_acq_enable, lock_acq_val, lock_acq_id,
        d0_zero_before, d1_zero_before, d2_zero_before, d0_zero_after,
        d1_zero_after, d2_zero_after, burst_length);

    // compute the location of the address to patch in the bd and emit patch
    // instruction to perform the patch.
    uint64_t addr = targetModel.getDmaBdAddress(col, 0, op.getId()) +
                    targetModel.getDmaBdAddressOffset(col, 0);
    rewriter.create<NpuAddressPatchOp>(op->getLoc(), addr, arg_idx, offset);

    // push the patched bd onto the dma task queue
    rewriter.create<NpuPushQueueOp>(
        op->getLoc(), column, row, infoOp->getChannelDirAttr(),
        infoOp->getChannelIndexAttr(), issue_token, repeat_count, bd_id);

    rewriter.eraseOp(op);
    return success();
  }
};
474
475/// Convert NpuDmaWaitOp into NpuSyncOp by retrieving the necessary
476/// information from the ShimDMAAllocationOp referenced through the
477/// symbol argument of this op.
478struct DmaWaitToSyncPattern : OpConversionPattern<NpuDmaWaitOp> {
479
480private:
481 AIE::ShimDMAllocationGetter &allocGetter;
482
483public:
484 using OpConversionPattern::OpConversionPattern;
485
486 DmaWaitToSyncPattern(MLIRContext *context,
488 PatternBenefit benefit = 1)
489 : OpConversionPattern(context, benefit), allocGetter(getter) {}
490
491 LogicalResult
492 matchAndRewrite(NpuDmaWaitOp op, OpAdaptor adaptor,
493 ConversionPatternRewriter &rewriter) const override {
494 AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
495 if (!dev)
496 return op->emitError("couldn't find parent of type DeviceOp");
497
498 std::optional<AIE::ShimDMAAllocationOp> shimDmaAllocOp =
499 allocGetter.get(dev, op.getSymbol());
500 if (!shimDmaAllocOp) {
501 return op->emitError("couldn't find shim_dma_allocation op");
502 }
503
504 // Create with `column_num == 1` and `row_num == 1` to check for a single
505 // column and row. Row is always 0 for shim tiles.
506 (void)rewriter.replaceOpWithNewOp<NpuSyncOp>(
507 op, shimDmaAllocOp->getCol(), /* row */ 0,
508 static_cast<uint32_t>(shimDmaAllocOp->getChannelDir()),
509 shimDmaAllocOp->getChannelIndex(), 1, 1);
510
511 return success();
512 }
513};
514
515struct WriteBdToBlockWritePattern : OpConversionPattern<NpuWriteBdOp> {
516 using OpConversionPattern::OpConversionPattern;
517
518private:
519 static int cachedId;
520
521public:
522 WriteBdToBlockWritePattern(MLIRContext *context, int &cachedId,
523 PatternBenefit benefit = 1)
524 : OpConversionPattern(context, benefit) {}
525
526 LogicalResult
527 matchAndRewrite(NpuWriteBdOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter) const override {
529
530 AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
531 const AIE::AIETargetModel &tm = dev.getTargetModel();
532
533 int num_words = 0;
534 if (isa<AIE::AIE2TargetModel>(tm))
535 num_words = 8;
536 else
537 llvm_unreachable(
538 "Unsupported AIETargetModel in WriteBdToBlockWritePattern");
539
540 std::vector<uint32_t> words(num_words, 0);
541
542 uint32_t bd_id = op.getBdId();
543 int col = op.getColumn();
544 int row = op.getRow();
545 uint64_t bd_addr = tm.getDmaBdAddress(col, row, bd_id);
546 if (tm.isShimNOCTile(col, row)) {
547 // DMA_BDX_0
548 words[0] = op.getBufferLength();
549
550 // DMA_BDX_1
551 words[1] = op.getBufferOffset();
552
553 // DMA_BDX_2
554 // En Packet , OoO BD ID , Packet ID , Packet Type
555 words[2] |= (op.getEnablePacket() & 0x1) << 30;
556 words[2] |= (op.getOutOfOrderId() & 0x3f) << 24;
557 words[2] |= (op.getPacketId() & 0x1f) << 19;
558 words[2] |= (op.getPacketType() & 0x7) << 16;
559
560 // DMA_BDX_3
561 // TODO: Secure Access
562 words[3] |= (op.getD0Size() & 0x3ff) << 20;
563 words[3] |= op.getD0Stride() & 0xfffff;
564
565 // DMA_BDX_4
566 words[4] = (getShimBurstLengthEncoding(tm, op.getBurstLength()) & 0x3)
567 << 30;
568 words[4] |= (op.getD1Size() & 0x3ff) << 20;
569 words[4] |= op.getD1Stride() & 0xfffff;
570
571 // DMA_BDX_5
572 // TODO: SIMID, AxCache, AXQoS
573 words[5] = op.getD2Stride() & 0xfffff;
574
575 // DMA_BDX_6
576 words[6] |= (op.getIterationCurrent() & 0x3f) << 26;
577 words[6] |= (op.getIterationSize() & 0x3f) << 20;
578 words[6] |= op.getIterationStride() & 0xfffff;
579
580 // DMA_BDX_7
581 // TODO: TLAST Suppress
582 words[7] |= (op.getNextBd() & 0xf) << 27;
583 words[7] |= (op.getUseNextBd() & 0x1) << 26;
584 words[7] |= (op.getValidBd() & 0x1) << 25;
585 words[7] |= (op.getLockRelVal() & 0x7f) << 18;
586 words[7] |= (op.getLockRelId() & 0xf) << 13;
587 words[7] |= (op.getLockAcqEnable() & 0x1) << 12;
588 words[7] |= (op.getLockAcqVal() & 0x7f) << 5;
589 words[7] |= op.getLockAcqId() & 0xf;
590 if (op.getD0ZeroBefore() || op.getD1ZeroBefore() ||
591 op.getD2ZeroBefore() || op.getD0ZeroAfter() || op.getD1ZeroAfter() ||
592 op.getD2ZeroAfter()) {
593 op->emitError("Zero padding is only available on MemTile");
594 }
595 } else if (tm.isMemTile(op.getColumn(), op.getRow())) {
596
597 // DMA_BDX_0
598 words[0] |= (op.getEnablePacket() & 0x1) << 31;
599 words[0] |= (op.getPacketType() & 0x7) << 28;
600 words[0] |= (op.getPacketId() & 0x1f) << 23;
601 words[0] |= (op.getOutOfOrderId() & 0x3f) << 17;
602 words[0] |= op.getBufferLength() & 0x1ffff;
603
604 // DMA_BDX_1
605 words[1] |= (op.getD0ZeroBefore() & 0x3F) << 26;
606 words[1] |= (op.getNextBd() & 0x3f) << 20;
607 words[1] |= (op.getUseNextBd() & 0x1) << 19;
608 words[1] |= op.getBufferOffset() & 0x7ffff;
609
610 // DMA_BDX_2
611 words[2] |= (op.getD0Size() & 0x3ff) << 17;
612 words[2] |= op.getD0Stride() & 0x1ffff;
613
614 // DMA_BDX_3
615 // TODO: Secure Access
616 words[3] |= (op.getD1ZeroBefore() & 0x1F) << 27;
617 words[3] |= (op.getD1Size() & 0x3ff) << 17;
618 words[3] |= op.getD1Stride() & 0x1ffff;
619
620 // DMA_BDX_4
621 // TODO: D2Size
622 words[4] |= (op.getD2ZeroBefore() & 0xF) << 27;
623 words[4] |= op.getD2Stride() & 0x1ffff;
624
625 // DMA_BDX_5
626 // ToDO: D3Stride
627 words[5] |= (op.getD2ZeroAfter() & 0xF) << 28;
628 words[5] |= (op.getD1ZeroAfter() & 0x1F) << 23;
629 words[5] |= (op.getD0ZeroAfter() & 0x3F) << 17;
630
631 // DMA_BDX_6
632 words[6] |= (op.getIterationCurrent() & 0x3f) << 23;
633 words[6] |= (op.getIterationSize() & 0x3f) << 17;
634 words[6] |= op.getIterationStride() & 0x1ffff;
635
636 // DMA_BDX_7
637 words[7] |= (op.getValidBd() & 0x1) << 31;
638 words[7] |= (op.getLockRelVal() & 0x7f) << 24;
639 words[7] |= (op.getLockRelId() & 0xff) << 16;
640 words[7] |= (op.getLockAcqEnable() & 0x1) << 15;
641 words[7] |= (op.getLockAcqVal() & 0x7f) << 8;
642 words[7] |= op.getLockAcqId() & 0xff;
643 } else {
644 // TODO: DMA BD configuration for Compute Tiles
645 op->emitError("Run-time DMA configuration is supported only for "
646 "ShimTiles and MemTiles currently.");
647 return failure();
648 }
649
650 MemRefType memrefType = MemRefType::get({num_words}, rewriter.getI32Type());
651 TensorType tensorType =
652 RankedTensorType::get({num_words}, rewriter.getI32Type());
653 memref::GlobalOp global = nullptr;
654 auto initVal = DenseElementsAttr::get<uint32_t>(tensorType, words);
655 auto otherGlobals = dev.getOps<memref::GlobalOp>();
656 for (auto g : otherGlobals) {
657 if (g == op)
658 continue;
659 if (g.getType() != memrefType)
660 continue;
661 auto otherValue = g.getInitialValue();
662 if (!otherValue)
663 continue;
664 if (*otherValue != initVal)
665 continue;
666 global = g;
667 break;
668 }
669 if (!global) {
670 OpBuilder::InsertionGuard guard(rewriter);
671 rewriter.setInsertionPoint(
672 op->getParentOfType<AIEX::RuntimeSequenceOp>());
673 std::string name = "blockwrite_data_";
674 while (dev.lookupSymbol(name + std::to_string(cachedId)))
675 cachedId++;
676 name += std::to_string(cachedId);
677 global = rewriter.create<memref::GlobalOp>(
678 op->getLoc(), name, rewriter.getStringAttr("private"), memrefType,
679 initVal, true, nullptr);
680 }
681 auto memref = rewriter.create<memref::GetGlobalOp>(op->getLoc(), memrefType,
682 global.getName());
683 (void)rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(
684 op, rewriter.getUI32IntegerAttr(bd_addr), memref.getResult(), nullptr,
685 nullptr, nullptr);
686 return success();
687 }
688};
689
690int WriteBdToBlockWritePattern::cachedId = 0;
691
692struct AIEDmaToNpuPass : AIEDmaToNpuBase<AIEDmaToNpuPass> {
693
694 void getDependentDialects(DialectRegistry &registry) const override {
695 registry.insert<memref::MemRefDialect>();
696 }
697
698 void runOnOperation() override {
699
700 AIE::ShimDMAllocationGetter cachingGetter;
701
702 AIE::DeviceOp device = getOperation();
703
704 ConversionTarget target(getContext());
705 target.addLegalDialect<AIEXDialect>();
706 target.addLegalDialect<memref::MemRefDialect>();
707 target.addLegalOp<AIE::BufferOp>();
708 target.addLegalOp<AIE::ShimDMAAllocationOp>();
709 target.addLegalOp<AIE::TileOp>();
710
711 target.addIllegalOp<NpuDmaMemcpyNdOp>();
712 target.addIllegalOp<NpuDmaWaitOp>();
713 target.addIllegalOp<NpuPushQueueOp>();
714 target.addIllegalOp<NpuWriteRTPOp>();
715 target.addIllegalOp<NpuWriteBdOp>();
716 target.addDynamicallyLegalOp<NpuWrite32Op>(
717 [&](NpuWrite32Op op) { return !op.getBuffer(); });
718 target.addDynamicallyLegalOp<NpuBlockWriteOp>(
719 [&](NpuBlockWriteOp op) { return !op.getBuffer(); });
720 target.addDynamicallyLegalOp<NpuMaskWrite32Op>(
721 [&](NpuMaskWrite32Op op) { return !op.getBuffer(); });
722
723 RewritePatternSet patterns(&getContext());
724 patterns.insert<BlockWriteSymToAddr>(&getContext());
725 patterns.insert<DmaToNpuPattern>(&getContext(), cachingGetter);
726 patterns.insert<DmaWaitToSyncPattern>(&getContext(), cachingGetter);
727 patterns.insert<MaskWrite32SymToAddr>(&getContext());
728 patterns.insert<PushQueuetoWrite32Pattern>(&getContext());
729 patterns.insert<RtpToWrite32Pattern>(&getContext());
730 patterns.insert<Write32SymToAddr>(&getContext());
731 patterns.insert<WriteBdToBlockWritePattern>(&getContext());
732
733 if (failed(applyPartialConversion(device, target, std::move(patterns))))
734 signalPassFailure();
735 }
736};
737
738} // namespace
739
740std::unique_ptr<OperationPass<AIE::DeviceOp>> AIEX::createAIEDmaToNpuPass() {
741 return std::make_unique<AIEDmaToNpuPass>();
742}
virtual uint32_t getDmaControlAddress(int col, int row, int channel, AIE::DMAChannelDir direction) const =0
Return the array address of the dma task queue register for the given col, row, channel and direction...
virtual uint32_t getColumnShift() const =0
virtual uint32_t getRowShift() const =0
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIEDmaToNpuPass()
void getHardwareStridesWraps(const AIE::AIETargetModel &targetModel, mlir::Operation *op, mlir::BaseMemRefType referencedBufType, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > &sizes, llvm::SmallVector< int64_t, 4 > &strides)
mlir::LogicalResult verifyStridesWraps(mlir::Operation *forOp, mlir::BaseMemRefType referencedBufType, int tileCol, int tileRow, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > hardwareSizes, llvm::SmallVector< int64_t, 4 > hardwareStrides, bool skipTransformationChecks=false)
uint32_t getShimBurstLengthEncoding(const AIE::AIETargetModel &tm, uint32_t burstLength)
const AIETargetModel & getTargetModel(mlir::Operation *op)
std::optional< AIE::ShimDMAAllocationOp > get(DeviceOp dev, mlir::StringRef sym_name)