MLIR-AIE
AIEDmaToNpu.cpp
//===- AIEDmaToNpu.cpp ------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023 Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//

// NOTE: the original AIE/AIEX include list was lost in extraction; the
// project headers below are assumed.
#include "aie/Dialect/AIE/IR/AIEDialect.h"
#include "aie/Dialect/AIEX/IR/AIEXDialect.h"
#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include <algorithm>
#include <cstdint>

using namespace mlir;
using namespace xilinx;
using namespace xilinx::AIEX;

namespace {

struct Write32SymToAddr : OpConversionPattern<NpuWrite32Op> {
  using OpConversionPattern::OpConversionPattern;

  Write32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit) {}

  LogicalResult
  matchAndRewrite(NpuWrite32Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!op.getBuffer())
      return failure();

    auto device = op->getParentOfType<AIE::DeviceOp>();
    auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
    if (!buffer)
      return op->emitError("buffer '" + *op.getBuffer() +
                           "' not found in device");

    if (!buffer.getAddress())
      return op->emitError("buffer must have address assigned");

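    // The final address places the tile coordinates in the high bits and keeps
    // the low 20 bits as the register offset within the tile:
    //   (col << columnShift) | (row << rowShift) | (offset & 0xFFFFF).
    // For illustration only (the shift amounts come from the target model):
    // with a column shift of 25 and a row shift of 20, a buffer at offset
    // 0x4000 in tile (2, 0) maps to (2 << 25) | 0x4000 = 0x4004000.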
    const AIE::AIETargetModel &tm = device.getTargetModel();
    uint32_t address = static_cast<uint32_t>(*buffer.getAddress()) +
                       op.getAddress() * sizeof(uint32_t);
    auto col = buffer.getTileOp().getCol();
    auto row = buffer.getTileOp().getRow();
    address |= ((col & 0xff) << tm.getColumnShift()) |
               ((row & 0xff) << tm.getRowShift()) | (address & 0xFFFFF);

    rewriter.replaceOpWithNewOp<NpuWrite32Op>(op, address, op.getValue(),
                                              nullptr, nullptr, nullptr);
    return success();
  }
};

struct BlockWriteSymToAddr : OpConversionPattern<NpuBlockWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  BlockWriteSymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit) {}

  LogicalResult
  matchAndRewrite(NpuBlockWriteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!op.getBuffer())
      return failure();

    auto device = op->getParentOfType<AIE::DeviceOp>();

    auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
    if (!buffer)
      return op->emitError("buffer '" + *op.getBuffer() +
                           "' not found in device");

    if (!buffer.getAddress())
      return op->emitError("buffer must have address assigned");

    const AIE::AIETargetModel &tm = device.getTargetModel();
    uint32_t address = static_cast<uint32_t>(*buffer.getAddress()) +
                       op.getAddress() * sizeof(uint32_t);
    auto col = buffer.getTileOp().getCol();
    auto row = buffer.getTileOp().getRow();
    address |= ((col & 0xff) << tm.getColumnShift()) |
               ((row & 0xff) << tm.getRowShift()) | (address & 0xFFFFF);

    rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(op, address, op.getData(),
                                                 nullptr, nullptr, nullptr);
    return success();
  }
};

struct MaskWrite32SymToAddr : OpConversionPattern<NpuMaskWrite32Op> {
  using OpConversionPattern::OpConversionPattern;

  MaskWrite32SymToAddr(MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit) {}

  LogicalResult
  matchAndRewrite(NpuMaskWrite32Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!op.getBuffer())
      return failure();

    auto device = op->getParentOfType<AIE::DeviceOp>();

    auto buffer = device.lookupSymbol<AIE::BufferOp>(*op.getBuffer());
    if (!buffer)
      return op->emitError("buffer '" + *op.getBuffer() +
                           "' not found in device");

    if (!buffer.getAddress())
      return op->emitError("buffer must have address assigned");

    const AIE::AIETargetModel &tm = device.getTargetModel();
    uint32_t address = static_cast<uint32_t>(*buffer.getAddress()) +
                       op.getAddress() * sizeof(uint32_t);
    auto col = buffer.getTileOp().getCol();
    auto row = buffer.getTileOp().getRow();
    address |= ((col & 0xff) << tm.getColumnShift()) |
               ((row & 0xff) << tm.getRowShift()) | (address & 0xFFFFF);

    rewriter.replaceOpWithNewOp<NpuMaskWrite32Op>(
        op, address, op.getValue(), op.getMask(), nullptr, nullptr, nullptr);
    return success();
  }
};

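// Lowers NpuWriteRTPOp against a symbolic buffer into a plain NpuWrite32Op of
// the value at (buffer address + index * sizeof(uint32_t)) on the buffer's
// tile. For illustration, an rtp_write of value 50 at index 1 into a buffer
// placed at address 0x4000 becomes a 32-bit write of 50 to address 0x4004.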
struct RtpToWrite32Pattern : OpConversionPattern<NpuWriteRTPOp> {
  using OpConversionPattern::OpConversionPattern;

  RtpToWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit) {}

  LogicalResult
  matchAndRewrite(NpuWriteRTPOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto device = op->getParentOfType<AIE::DeviceOp>();

    auto buffer = device.lookupSymbol<AIE::BufferOp>(op.getBuffer());
    if (!buffer) {
      op->emitError("buffer '" + op.getBuffer() + "' not found in device");
      return failure();
    }

    if (!buffer.getAddress()) {
      op->emitError("buffer must have address assigned");
      return failure();
    }
    AIE::TileOp tile = buffer.getTileOp();

    uint32_t idx = op.getIndex() * sizeof(uint32_t);
    uint32_t address = buffer.getAddress().value() + idx;

    rewriter.create<NpuWrite32Op>(op->getLoc(), address, op.getValue(), nullptr,
                                  rewriter.getI32IntegerAttr(tile.getCol()),
                                  rewriter.getI32IntegerAttr(tile.getRow()));

    rewriter.eraseOp(op);
    return success();
  }
};

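// Lowers NpuPushQueueOp to register writes on the shim tile: an optional mask
// write that sets the task-complete-token controller ID field in the DMA
// control register, followed by a write of the command word into the MM2S or
// S2MM task queue register of the selected channel.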
struct PushQueuetoWrite32Pattern : OpConversionPattern<NpuPushQueueOp> {

public:
  using OpConversionPattern::OpConversionPattern;

  PushQueuetoWrite32Pattern(MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit) {}

  LogicalResult
  matchAndRewrite(NpuPushQueueOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto column = rewriter.getI32IntegerAttr(op.getColumn());
    auto row = rewriter.getI32IntegerAttr(0);
    bool isMM2S = op.getDirection() == AIE::DMAChannelDir::MM2S;

    // control packet for issuing token
    if (op.getIssueToken()) {
      // set the task-complete-token controller ID field in the dma control
      // register
      AIE::TileOp shimTile = AIE::TileOp::getOrCreate(
          rewriter, op->getParentOfType<AIE::DeviceOp>(), op.getColumn(), 0);
      if (shimTile->hasAttr("controller_id")) {
        uint32_t ctrl_offset = isMM2S ? 0x1D210 : 0x1D200;
        if (op.getChannel() == 1)
          ctrl_offset += 0x8;
        AIE::PacketInfoAttr controller_id_attr =
            shimTile->getAttrOfType<AIE::PacketInfoAttr>("controller_id");
        uint32_t data = controller_id_attr.getPktId() << 8;
        uint32_t mask = 0x00000F00;
        rewriter.create<NpuMaskWrite32Op>(op->getLoc(), ctrl_offset, data, mask,
                                          nullptr, column, row);
      }
    }

    // the offset of the task queue register in the tile
    uint32_t queue_offset = isMM2S ? 0x1D214 : 0x1D204;
    if (op.getChannel() == 1)
      queue_offset += 0x8;

    // the value to write
    uint32_t bd_id = op.getBdId();
    uint32_t repeat_cnt = op.getRepeatCount();
    uint32_t cmd = 0;
    cmd |= bd_id & 0xF;
    cmd |= (repeat_cnt & 0xFF) << 16;
    if (op.getIssueToken())
      cmd |= 0x80000000;
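    // Command word layout, as assembled above: bits [3:0] hold the BD id,
    // bits [23:16] the repeat count, and bit 31 the issue-token flag. For
    // illustration, bd_id = 2 with repeat_cnt = 3 and a token requested
    // encodes to 0x80030002.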

    rewriter.create<NpuWrite32Op>(op->getLoc(), queue_offset, cmd, nullptr,
                                  column, row);
    rewriter.eraseOp(op);
    return success();
  }
};

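// Lowers NpuDmaMemcpyNdOp into three lower-level operations: an NpuWriteBdOp
// carrying the buffer descriptor fields, an NpuAddressPatchOp binding the BD
// base address to the corresponding runtime-sequence argument, and an
// NpuPushQueueOp that enqueues the BD on the selected shim DMA channel
// (optionally requesting a task-complete token).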
struct DmaToNpuPattern : OpConversionPattern<NpuDmaMemcpyNdOp> {
  using OpConversionPattern::OpConversionPattern;

private:
  AIE::ShimDMAllocationGetter &allocGetter;

public:
  DmaToNpuPattern(MLIRContext *context, AIE::ShimDMAllocationGetter &getter,
                  PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit), allocGetter(getter) {}

  LogicalResult
  matchAndRewrite(NpuDmaMemcpyNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &targetModel = AIE::getTargetModel(op);
    BaseMemRefType bufferType = op.getMemref().getType();
    auto *ctx = op->getContext();
    auto i32ty = IntegerType::get(ctx, 32);
    auto zero = IntegerAttr::get(i32ty, 0);
    auto memref = adaptor.getMemref();

    auto dev = op->getParentOfType<AIE::DeviceOp>();
    if (!dev)
      return failure();

    auto infoOp = allocGetter.get(dev, op.getMetadata());
    if (!infoOp) {
      return op->emitOpError("couldn't find shim_dma_allocation op.");
    }

    auto channelDir = infoOp->getChannelDir();
    bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S;
    int col = infoOp->getCol();

    // initialize fields to zero
    auto column = zero;
    auto bd_id = zero;
    auto buffer_length = zero;
    auto buffer_offset = zero;
    auto enable_packet = zero;
    auto out_of_order_id = zero;
    auto packet_id = zero;
    auto packet_type = zero;
    auto d0_size = zero;
    auto d0_stride = zero;
    auto d1_size = zero;
    auto d1_stride = zero;
    auto d2_size = zero;
    auto d2_stride = zero;
    auto iteration_current = zero;
    auto iteration_size = zero;
    auto iteration_stride = zero;
    auto next_bd = zero;
    auto row = zero;
    auto use_next_bd = zero;
    auto valid_bd = zero;
    auto lock_rel_val = zero;
    auto lock_rel_id = zero;
    auto lock_acq_enable = zero;
    auto lock_acq_val = zero;
    auto lock_acq_id = zero;
    auto d0_zero_before = zero;
    auto d1_zero_before = zero;
    auto d2_zero_before = zero;
    auto d0_zero_after = zero;
    auto d1_zero_after = zero;
    auto d2_zero_after = zero;
    auto burst_length = zero;

    auto issue_token = BoolAttr::get(ctx, false);
    auto repeat_count = zero;
    llvm::SmallVector<int64_t, 4> inputSizes = llvm::map_to_vector(
        llvm::reverse(op.getMixedSizes()),
        [](OpFoldResult s) { return getConstantIntValue(s).value(); });
    llvm::SmallVector<int64_t, 4> inputStrides = llvm::map_to_vector(
        llvm::reverse(op.getMixedStrides()),
        [](OpFoldResult s) { return getConstantIntValue(s).value(); });
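    // The op lists sizes and strides outermost-first; after the reverse above,
    // index 0 is the innermost dimension. getHardwareStridesWraps translates
    // these element-granularity values into the wrap/stride encoding expected
    // by the BD registers of the target model.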
    llvm::SmallVector<int64_t, 4> sizes(4);
    llvm::SmallVector<int64_t, 4> strides(4);
    getHardwareStridesWraps(targetModel, bufferType, inputSizes, inputStrides,
                            sizes, strides);
    int64_t offset = op.getOffsetInBytes();

    // column
    column = IntegerAttr::get(i32ty, col);

    // row
    row = IntegerAttr::get(i32ty, 0);

    bool skipTransformationChecks = op.isLinearTransferWithoutTransformation();
    if (failed(verifyStridesWraps(op, bufferType, col, 0, inputSizes,
                                  inputStrides, sizes, strides,
                                  skipTransformationChecks))) {
      return failure();
    }

    // arg_idx
    AIEX::RuntimeSequenceOp seq_op =
        op->getParentOfType<AIEX::RuntimeSequenceOp>();
    if (!seq_op) {
      op->emitOpError("NpuDmaMemcpyNdOps must have RuntimeSequenceOp parent at "
                      "time of lowering.");
      return failure();
    }
    Block &entryBB = seq_op.getBody().front();
    int arg_idx = -1;
    for (int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
      if (entryBB.getArgument(i) == memref) {
        arg_idx = i;
        break;
      }
    }
    if (arg_idx < 0)
      return failure();
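    // arg_idx is the position of the transferred memref among the runtime
    // sequence's block arguments; the NpuAddressPatchOp emitted below uses it
    // so that the actual host buffer address of that argument (plus the static
    // offset) can be patched into the BD at runtime.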

    // bd_id
    bd_id = IntegerAttr::get(i32ty, op.getId());

    // buffer_length
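    // This is a count of address-generation words rather than elements: the
    // innermost size is scaled by element-bits / granularity-bits and then
    // multiplied by up to two further dimension sizes. For illustration, a
    // 2 x 64 transfer of i16 elements on a target with a 32-bit
    // address-generation granularity gives 64 * 16 / 32 * 2 = 64 words.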
    uint64_t buffer_length_val = inputSizes[0] *
                                 bufferType.getElementTypeBitWidth() /
                                 targetModel.getAddressGenGranularity();
    if (inputSizes.size() > 1) {
      for (size_t i = 1; i < std::min(inputSizes.size(), (size_t)3); i++) {
        buffer_length_val *= inputSizes[i];
      }
    }
    buffer_length = IntegerAttr::get(i32ty, buffer_length_val);

    // buffer_offset - zero because the complete address is set by the patch op
    buffer_offset = IntegerAttr::get(i32ty, 0);

    // enable_packet
    if (auto packetInfo = op.getPacket()) {
      enable_packet = IntegerAttr::get(i32ty, 1);
      packet_type = IntegerAttr::get(i32ty, packetInfo->getPktType());
      packet_id = IntegerAttr::get(i32ty, packetInfo->getPktId());
    }

    // out_of_order_id

    if (!op.isLinearTransferWithoutTransformation()) {
      // d0_size, d0_stride
      d0_size = IntegerAttr::get(i32ty, sizes[0]);
      d0_stride = IntegerAttr::get(i32ty, strides[0]);

      // d1_size, d1_stride
      d1_size = IntegerAttr::get(i32ty, sizes[1]);
      d1_stride = IntegerAttr::get(i32ty, strides[1]);

      // d2_stride
      d2_stride = IntegerAttr::get(i32ty, strides[2]);

      // d2_size
      if (targetModel.isMemTile(col, 0)) // Need to be any row
        d2_size = IntegerAttr::get(i32ty, sizes[2]);
      else
        d2_size = IntegerAttr::get(i32ty, 0);
    }
    // iteration_current, iteration_size, iteration_stride, repeat_count
    if (inputSizes[3] > 1) {
      if (inputStrides[3] > 0) {
        iteration_size = IntegerAttr::get(i32ty, sizes[3]);
        iteration_stride = IntegerAttr::get(i32ty, strides[3]);
      } else {
        // We allow users to encode the repeat_count as a dimension 3 stride
        // of 0. This must lower to an iteration wrap of 0, so no stride is
        // ever added. We then repeat the BD using the repeat_count in
        // NpuPushQueueOp.
        iteration_size = zero;
        iteration_stride = zero;
      }
    }
    repeat_count = IntegerAttr::get(i32ty, sizes[3]);

    // next_bd

    // use_next_bd

    // valid_bd
    valid_bd = IntegerAttr::get(i32ty, 1);

    // lock_rel_val

    // lock_rel_id

    // lock_acq_enable

    // lock_acq_val

    // lock_acq_id

    // d0_zero_before
    d0_zero_before = IntegerAttr::get(i32ty, op.getD0ZeroBefore());

    // d1_zero_before
    d1_zero_before = IntegerAttr::get(i32ty, op.getD1ZeroBefore());

    // d2_zero_before
    d2_zero_before = IntegerAttr::get(i32ty, op.getD2ZeroBefore());

    // d0_zero_after
    d0_zero_after = IntegerAttr::get(i32ty, op.getD0ZeroAfter());

    // d1_zero_after
    d1_zero_after = IntegerAttr::get(i32ty, op.getD1ZeroAfter());

    // d2_zero_after
    d2_zero_after = IntegerAttr::get(i32ty, op.getD2ZeroAfter());

    // burst_length
    burst_length = IntegerAttr::get(i32ty, op.getBurstLength());

    // Set the issue_token
    issue_token = BoolAttr::get(ctx, op.getIssueToken());
    // Earlier, all S2MM channels were implicitly assumed to issue a token.
    // This logic is kept for now for backward compatibility.
    if (!isMM2S)
      issue_token = BoolAttr::get(ctx, true);

    if (targetModel.isMemTile(col, 0) && (!isMM2S) &&
        (op.getD0ZeroBefore() != 0 || op.getD0ZeroAfter() != 0 ||
         op.getD1ZeroBefore() != 0 || op.getD1ZeroAfter() != 0 ||
         op.getD2ZeroBefore() != 0 || op.getD2ZeroAfter() != 0))
      op->emitOpError("MemTile supports zero padding only on MM2S direction");

    rewriter.create<NpuWriteBdOp>(
        op->getLoc(), column, bd_id, buffer_length, buffer_offset,
        enable_packet, out_of_order_id, packet_id, packet_type, d0_size,
        d0_stride, d1_size, d1_stride, d2_size, d2_stride, iteration_current,
        iteration_size, iteration_stride, next_bd, row, use_next_bd, valid_bd,
        lock_rel_val, lock_rel_id, lock_acq_enable, lock_acq_val, lock_acq_id,
        d0_zero_before, d1_zero_before, d2_zero_before, d0_zero_after,
        d1_zero_after, d2_zero_after, burst_length);

    uint64_t addr = getBufferDescriptorAddressRegisterAddress(
        targetModel, op.getId(), col, 0);

    rewriter.create<NpuAddressPatchOp>(op->getLoc(), addr, arg_idx, offset);

    rewriter.create<NpuPushQueueOp>(
        op->getLoc(), column, row, infoOp->getChannelDirAttr(),
        infoOp->getChannelIndexAttr(), issue_token, repeat_count, bd_id);

    rewriter.eraseOp(op);
    return success();
  }
};

/// Convert NpuDmaWaitOp into NpuSyncOp by retrieving the necessary
/// information from the ShimDMAAllocationOp referenced through the
/// symbol argument of this op.
struct DmaWaitToSyncPattern : OpConversionPattern<NpuDmaWaitOp> {

private:
  AIE::ShimDMAllocationGetter &allocGetter;

public:
  using OpConversionPattern::OpConversionPattern;

  DmaWaitToSyncPattern(MLIRContext *context,
                       AIE::ShimDMAllocationGetter &getter,
                       PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit), allocGetter(getter) {}

  LogicalResult
  matchAndRewrite(NpuDmaWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
    if (!dev)
      return op->emitError("couldn't find parent of type DeviceOp");

    std::optional<AIE::ShimDMAAllocationOp> shimDmaAllocOp =
        allocGetter.get(dev, op.getSymbol());
    if (!shimDmaAllocOp) {
      return op->emitError("couldn't find shim_dma_allocation op");
    }

    // Create with `column_num == 1` and `row_num == 1` to check for a single
    // column and row. Row is always 0 for shim tiles.
    (void)rewriter.replaceOpWithNewOp<NpuSyncOp>(
        op, shimDmaAllocOp->getCol(), /* row */ 0,
        static_cast<uint32_t>(shimDmaAllocOp->getChannelDir()),
        shimDmaAllocOp->getChannelIndex(), 1, 1);

    return success();
  }
};

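// Expands an NpuWriteBdOp into a block write of the eight 32-bit buffer
// descriptor registers. The register layout and the BD base address differ
// between shim NOC tiles and mem tiles; compute tiles are not yet supported.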
struct WriteBdToBlockWritePattern : OpConversionPattern<NpuWriteBdOp> {
  using OpConversionPattern::OpConversionPattern;

private:
  static int cachedId;

public:
  WriteBdToBlockWritePattern(MLIRContext *context, int &cachedId,
                             PatternBenefit benefit = 1)
      : OpConversionPattern(context, benefit) {}

  LogicalResult
  matchAndRewrite(NpuWriteBdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
    const AIE::AIETargetModel &tm = dev.getTargetModel();

    std::vector<uint32_t> words(8, 0);
    uint32_t bd_id = op.getBdId();
    uint32_t bd_addr;
    if (tm.isShimNOCTile(op.getColumn(), op.getRow())) {
      bd_addr = (op.getColumn() << tm.getColumnShift()) |
                (op.getRow() << tm.getRowShift()) | (0x1D000 + bd_id * 0x20);
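      // Shim BDs are 0x20 bytes apart starting at tile offset 0x1D000, so, for
      // illustration, bd_id 3 lands at 0x1D000 + 3 * 0x20 = 0x1D060 before the
      // column/row bits are OR'd into the upper address bits.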

      // DMA_BDX_0
      words[0] = op.getBufferLength();

      // DMA_BDX_1
      words[1] = op.getBufferOffset();

      // DMA_BDX_2
      // En Packet , OoO BD ID , Packet ID , Packet Type
      words[2] |= (op.getEnablePacket() & 0x1) << 30;
      words[2] |= (op.getOutOfOrderId() & 0x3f) << 24;
      words[2] |= (op.getPacketId() & 0x1f) << 19;
      words[2] |= (op.getPacketType() & 0x7) << 16;

      // DMA_BDX_3
      // TODO: Secure Access
      words[3] |= (op.getD0Size() & 0x3ff) << 20;
      words[3] |= op.getD0Stride() & 0xfffff;

      // DMA_BDX_4
      words[4] = (getShimBurstLengthEncoding(tm, op.getBurstLength()) & 0x3)
                 << 30;
      words[4] |= (op.getD1Size() & 0x3ff) << 20;
      words[4] |= op.getD1Stride() & 0xfffff;

      // DMA_BDX_5
      // TODO: SIMID, AxCache, AXQoS
      words[5] = op.getD2Stride() & 0xfffff;

      // DMA_BDX_6
      words[6] |= (op.getIterationCurrent() & 0x3f) << 26;
      words[6] |= (op.getIterationSize() & 0x3f) << 20;
      words[6] |= op.getIterationStride() & 0xfffff;

      // DMA_BDX_7
      // TODO: TLAST Suppress
      words[7] |= (op.getNextBd() & 0xf) << 27;
      words[7] |= (op.getUseNextBd() & 0x1) << 26;
      words[7] |= (op.getValidBd() & 0x1) << 25;
      words[7] |= (op.getLockRelVal() & 0xef) << 18;
      words[7] |= (op.getLockRelId() & 0xf) << 13;
      words[7] |= (op.getLockAcqEnable() & 0x1) << 12;
      words[7] |= (op.getLockAcqVal() & 0xef) << 5;
      words[7] |= op.getLockAcqId() & 0xf;

      if (op.getD0ZeroBefore() || op.getD1ZeroBefore() ||
          op.getD2ZeroBefore() || op.getD0ZeroAfter() || op.getD1ZeroAfter() ||
          op.getD2ZeroAfter()) {
        op->emitError("Zero padding is only available on MemTile");
      }
    } else if (tm.isMemTile(op.getColumn(), op.getRow())) {
      bd_addr = (op.getColumn() << tm.getColumnShift()) |
                (op.getRow() << tm.getRowShift()) | (0xA0000 + bd_id * 0x20);
      // DMA_BDX_0
      words[0] |= (op.getEnablePacket() & 0x1) << 31;
      words[0] |= (op.getPacketType() & 0x7) << 28;
      words[0] |= (op.getPacketId() & 0x1f) << 23;
      words[0] |= (op.getOutOfOrderId() & 0x3f) << 17;
      words[0] |= op.getBufferLength() & 0x1ffff;

      // DMA_BDX_1
      words[1] |= (op.getD0ZeroBefore() & 0x3F) << 26;
      words[1] |= (op.getNextBd() & 0x3f) << 20;
      words[1] |= (op.getUseNextBd() & 0x1) << 19;
      words[1] |= op.getBufferOffset() & 0x7ffff;

      // DMA_BDX_2
      words[2] |= (op.getD0Size() & 0x3ff) << 17;
      words[2] |= op.getD0Stride() & 0x1ffff;

      // DMA_BDX_3
      // TODO: Secure Access
      words[3] |= (op.getD1ZeroBefore() & 0x1F) << 27;
      words[3] |= (op.getD1Size() & 0x3ff) << 17;
      words[3] |= op.getD1Stride() & 0x1ffff;

      // DMA_BDX_4
      // TODO: D2Size
      words[4] |= (op.getD2ZeroBefore() & 0xF) << 27;
      words[4] |= op.getD2Stride() & 0x1ffff;

      // DMA_BDX_5
      // TODO: D3Stride
      words[5] |= (op.getD2ZeroAfter() & 0xF) << 28;
      words[5] |= (op.getD1ZeroAfter() & 0x1F) << 23;
      words[5] |= (op.getD0ZeroAfter() & 0x3F) << 17;

      // DMA_BDX_6
      words[6] |= (op.getIterationCurrent() & 0x3f) << 23;
      words[6] |= (op.getIterationSize() & 0x3f) << 17;
      words[6] |= op.getIterationStride() & 0x1ffff;

      // DMA_BDX_7
      words[7] |= (op.getValidBd() & 0x1) << 31;
      words[7] |= (op.getLockRelVal() & 0x7f) << 24;
      words[7] |= (op.getLockRelId() & 0xff) << 16;
      words[7] |= (op.getLockAcqEnable() & 0x1) << 15;
      words[7] |= (op.getLockAcqVal() & 0x7f) << 8;
      words[7] |= op.getLockAcqId() & 0xff;
    } else {
      // TODO: DMA BD configuration for Compute Tiles
      op->emitError("Run-time DMA configuration is supported only for "
                    "ShimTiles and MemTiles currently.");
      return failure();
    }

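    // Materialize the eight configuration words as a private memref.global and
    // replace the write_bd with a block write of that payload to the BD's
    // register address. Globals with identical contents are reused so repeated
    // BD configurations share a single payload.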
    MemRefType memrefType = MemRefType::get({8}, rewriter.getI32Type());
    TensorType tensorType = RankedTensorType::get({8}, rewriter.getI32Type());
    memref::GlobalOp global = nullptr;
    auto initVal = DenseElementsAttr::get<uint32_t>(tensorType, words);
    auto otherGlobals = dev.getOps<memref::GlobalOp>();
    for (auto g : otherGlobals) {
      if (g == op)
        continue;
      if (g.getType() != memrefType)
        continue;
      auto otherValue = g.getInitialValue();
      if (!otherValue)
        continue;
      if (*otherValue != initVal)
        continue;
      global = g;
      break;
    }
    if (!global) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(
          op->getParentOfType<AIEX::RuntimeSequenceOp>());
      std::string name = "blockwrite_data_";
      while (dev.lookupSymbol(name + std::to_string(cachedId)))
        cachedId++;
      name += std::to_string(cachedId);
      global = rewriter.create<memref::GlobalOp>(
          op->getLoc(), name, rewriter.getStringAttr("private"), memrefType,
          initVal, true, nullptr);
    }
    auto memref = rewriter.create<memref::GetGlobalOp>(op->getLoc(), memrefType,
                                                       global.getName());
    (void)rewriter.replaceOpWithNewOp<NpuBlockWriteOp>(
        op, rewriter.getUI32IntegerAttr(bd_addr), memref.getResult(), nullptr,
        nullptr, nullptr);
    return success();
  }
};

int WriteBdToBlockWritePattern::cachedId = 0;

struct AIEDmaToNpuPass : AIEDmaToNpuBase<AIEDmaToNpuPass> {

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }

  void runOnOperation() override {

    AIE::ShimDMAllocationGetter cachingGetter;

    AIE::DeviceOp device = getOperation();

    ConversionTarget target(getContext());
    target.addLegalDialect<AIEXDialect>();
    target.addLegalDialect<memref::MemRefDialect>();
    target.addLegalOp<AIE::BufferOp>();
    target.addLegalOp<AIE::ShimDMAAllocationOp>();
    target.addLegalOp<AIE::TileOp>();

    target.addIllegalOp<NpuDmaMemcpyNdOp>();
    target.addIllegalOp<NpuDmaWaitOp>();
    target.addIllegalOp<NpuPushQueueOp>();
    target.addIllegalOp<NpuWriteRTPOp>();
    target.addIllegalOp<NpuWriteBdOp>();
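    // Write32, block-write, and mask-write ops are dynamically legal only once
    // their symbolic buffer reference has been resolved to a physical address,
    // i.e. once the optional buffer attribute has been cleared by the
    // *SymToAddr patterns above.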
    target.addDynamicallyLegalOp<NpuWrite32Op>(
        [&](NpuWrite32Op op) { return !op.getBuffer(); });
    target.addDynamicallyLegalOp<NpuBlockWriteOp>(
        [&](NpuBlockWriteOp op) { return !op.getBuffer(); });
    target.addDynamicallyLegalOp<NpuMaskWrite32Op>(
        [&](NpuMaskWrite32Op op) { return !op.getBuffer(); });

    RewritePatternSet patterns(&getContext());
    patterns.insert<BlockWriteSymToAddr>(&getContext());
    patterns.insert<DmaToNpuPattern>(&getContext(), cachingGetter);
    patterns.insert<DmaWaitToSyncPattern>(&getContext(), cachingGetter);
    patterns.insert<MaskWrite32SymToAddr>(&getContext());
    patterns.insert<PushQueuetoWrite32Pattern>(&getContext());
    patterns.insert<RtpToWrite32Pattern>(&getContext());
    patterns.insert<Write32SymToAddr>(&getContext());
    patterns.insert<WriteBdToBlockWritePattern>(&getContext());

    if (failed(applyPartialConversion(device, target, std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

std::unique_ptr<OperationPass<AIE::DeviceOp>> AIEX::createAIEDmaToNpuPass() {
  return std::make_unique<AIEDmaToNpuPass>();
}