// MLIR-AIE — AIEXDialect.cpp (source reconstructed from a generated
// documentation page).
//===- AIEXDialect.cpp ------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2019 Xilinx Inc.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/TypeSize.h"

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <optional>
#include <sstream>
#include <string>
28using namespace mlir;
29using namespace xilinx;
30
31#include "aie/Dialect/AIEX/IR/AIEXDialect.cpp.inc"
32
33#define GET_TYPEDEF_CLASSES
34#include "aie/Dialect/AIEX/IR/AIEXTypes.cpp.inc"
35
36namespace xilinx::AIEX {
37
38// FIXME: use Tablegen'd dialect class
39void AIEXDialect::initialize() {
40 addOperations<
41#define GET_OP_LIST
42#include "aie/Dialect/AIEX/IR/AIEX.cpp.inc"
43 >();
44 addTypes<
45#define GET_TYPEDEF_LIST
46#include "aie/Dialect/AIEX/IR/AIEXTypes.cpp.inc"
47 >();
48}
49
50} // namespace xilinx::AIEX
51
52#define GET_OP_CLASSES
53#include "aie/Dialect/AIEX/IR/AIEX.cpp.inc"
54
/* Return the correct values to write to the hardware registers to configure
   strides and wraps given the input user-facing strides and wraps.

   In the IR, we express strides in units of element data type, but the hardware
   requires it in units of address granularity. Address granularity currently is
   4 bytes for all hardware.

   User-facing strides/wraps relate to hardware as follows:

   - By default, stride 0 and size 1 is assumed if unspecified.
   - If only N strides/wraps are defined, those define the lowest N dimensions.

   inputStride[3] == iteration_stride / elemSizeFac + 1
   inputWrap[3]   == iteration_size + 1
       Highest-dimension stride/wrap is iteration count / iteration stride.
   inputStride[2] == d2_stride / elemSizeFac + 1
       Note: d2_size is not specified in hardware as it is
       implicit from the total buffer transfer length
   inputStride[1] == d1_stride / elemSizeFac + 1
   inputSize[1]   == d1_size
   inputStride[0] == d0_stride / elemSizeFac + 1
   inputSize[0]   == d0_size / elemSizeFac

   where elemSizeFac == bufferElementSize / addressGranularity
   where bufferElementSize == size in bytes of elements in buffer,
         e.g. 4 for int32
   where addressGranularity == transfer granularity in hardware, which is
         4 bytes for all current hardware

   Note: strides are expressed offset by one from user input strides, because
   the hardware does not support a 0 stride (repeat).
 */
89 mlir::Operation *op,
90 mlir::BaseMemRefType referencedBufType,
91 llvm::SmallVector<int64_t, 4> inputSizes,
92 llvm::SmallVector<int64_t, 4> inputStrides,
93 llvm::SmallVector<int64_t, 4> &sizes,
94 llvm::SmallVector<int64_t, 4> &strides) {
95 assert(inputSizes.size() == inputStrides.size());
96 assert(sizes.size() == 4);
97 assert(strides.size() == 4);
98
99 DataLayout dataLayout = DataLayout::closest(op);
100 auto elemWidth =
101 dataLayout.getTypeSizeInBits(referencedBufType.getElementType());
102 auto addressGranularity = targetModel.getAddressGenGranularity();
103
104 // Output strides and sizes are default-initialized to 0
105 std::fill(sizes.begin(), sizes.end(), 0);
106 std::fill(strides.begin(), strides.end(), 0);
107
108 if (inputSizes[0] == 0) {
109 // Illegal input, this won't transfer anything at all.
110 // Leave it to the verification functions to complain to the user.
111 return;
112 }
113
114 // d0_size, d0_stride
115 sizes[0] = inputSizes[0] * elemWidth / addressGranularity;
116 if (inputStrides[0] * elemWidth < addressGranularity ||
117 (elemWidth > addressGranularity)) {
118 // First check:
119 // While the hardware cannot transfer less than addressGranularity bits at
120 // a time, the user may expresses a contiguous transfer of multiple
121 // elements with a stride smaller than addressGranularity. We can thus set
122 // the stride to 1 (encoded in hardware as 0) here to allow such transfers.
123 // The verification function should ensure that
124 // inputStrides[0] * elemWidth < addressGranularity
125 // iff. inputSize[0] * elemWidth > addressGranularity.
126 // Second check:
127 // If the element width is larger than addressGranularity, we need to make
128 // sure that all bytes are properly copied and therefore the stride must be
129 // set to 1 (encoded in hardware as 0).
130 // The verification function should ensure that
131 // inputStrides[0] * elemWidth % addressGranularity == 0
132 // && inputStrides[0] == 1 if elemWidth > addressGranularity
133 // This makes it impossible to have a stride greater than 1 for
134 // elemWidths bigger than addressGranularity, even if they are a multiple of
135 // it. Such operations should make use of an additional dimension instead.
136 strides[0] = 0;
137 } else {
138 strides[0] = inputStrides[0] * elemWidth / addressGranularity - 1;
139 }
140
141 // d1_size, d1_stride
142 sizes[1] = inputSizes[1];
143 if (inputSizes[1] > 1) {
144 // Stride only matters if we have more than one iteration.
145 strides[1] = inputStrides[1] * elemWidth / addressGranularity - 1;
146 }
147
148 // d2_size, d2_stride
149 sizes[2] = inputSizes[2];
150 if (inputSizes[2] > 1) {
151 // Stride only matters if we have more than one iteration.
152 strides[2] = inputStrides[2] * elemWidth / addressGranularity - 1;
153 }
154
155 // iteration_size, iteration_stride
156 if (inputSizes[3] > 1) {
157 // Stride only matters if we have more than one iteration.
158 sizes[3] = inputSizes[3] - 1;
159 // Note that the iteration_stride must be positive, just like the other
160 // dimensions. However, one can encode a zero-stride "repeat" of the same
161 // transfer by setting a positive repeat_count on the pushToQueue instr,
162 // and setting the size here to 1. This causes the BD to "wrap" at every
163 // single iteration, effectively never adding the specified stride, in turn
164 // equalling a repeat without stride.
165 if (inputStrides[3] > 0) {
166 strides[3] = inputStrides[3] * elemWidth / addressGranularity - 1;
167 }
168 }
169}
170
171mlir::LogicalResult
172AIEX::verifyStridesWraps(mlir::Operation *forOp,
173 mlir::BaseMemRefType referencedBufType, int tileCol,
174 int tileRow, llvm::SmallVector<int64_t, 4> inputSizes,
175 llvm::SmallVector<int64_t, 4> inputStrides,
176 llvm::SmallVector<int64_t, 4> hardwareSizes,
177 llvm::SmallVector<int64_t, 4> hardwareStrides,
178 bool skipTransformationChecks) {
179 const auto &targetModel = AIE::getTargetModel(forOp);
180 auto addressGranularity = targetModel.getAddressGenGranularity();
181 DataLayout dataLayout = DataLayout::closest(forOp);
182 auto elemWidth =
183 dataLayout.getTypeSizeInBits(referencedBufType.getElementType());
184
185 uint32_t wrap_bits = 0;
186 uint32_t step_bits = 0;
187 uint32_t iter_bits = 6;
188 if (targetModel.isShimNOCTile(tileCol, tileRow)) {
189 step_bits = 20; // XAIEMLGBL_NOC_MODULE_DMA_BD0_3_D0_STEPSIZE_WIDTH
190 wrap_bits = 10; // XAIEMLGBL_NOC_MODULE_DMA_BD0_3_D0_WRAP_WIDTH
191 } else if (targetModel.isMemTile(tileCol, tileRow)) {
192 step_bits = 17; // XAIEMLGBL_MEM_TILE_MODULE_DMA_BD0_2_D0_STEPSIZE_WIDTH
193 wrap_bits = 10; // XAIEMLGBL_MEM_TILE_MODULE_DMA_BD0_2_D0_WRAP_WIDTH
194 } else if (targetModel.isCoreTile(tileCol, tileRow)) {
195 step_bits = 13; // XAIEMLGBL_MEMORY_MODULE_DMA_BD0_2_D0_STEPSIZE_WIDTH
196 wrap_bits = 8; // XAIEMLGBL_MEMORY_MODULE_DMA_BD0_3_D0_WRAP_WIDTH
197 } else {
198 return forOp->emitOpError(
199 "Unsupported tile type at (" + std::to_string(tileCol) + ", " +
200 std::to_string(tileRow) + ") Must be ShimNOC, Mem or Core.");
201 }
202
203 for (int i = 0; i < 4; i++) {
204 if (inputSizes[i] <= 0) {
205 return forOp->emitOpError("Size ") << i << " must be a positive integer.";
206 }
207 }
208
209 if (inputSizes[0] * elemWidth % addressGranularity != 0) {
210 std::stringstream msg;
211 msg << "Transfer sizes must be multiples of " << (addressGranularity / 8)
212 << " bytes. " << inputSizes[0] << " elements at " << (elemWidth / 8)
213 << " bytes each equal " << (inputSizes[0] * elemWidth / 8)
214 << " bytes, which is not divisible by " << (addressGranularity / 8)
215 << ". ";
216 return forOp->emitOpError(msg.str());
217 }
218
219 for (int i = 0; i < 3; i++) {
220 if (inputSizes[i] > 1 && inputStrides[i] < 1) {
221 // If inputSize[i] == 1, anything is allowable in the stride, since that
222 // stride will never be applied. For any larger size, we must verify that
223 // the stride is positive.
224 return forOp->emitOpError("Stride ")
225 << i << " must be a positive integer.";
226 }
227 }
228 // A value of zero is allowable for the fourth-dimension stride
229 // (this indicates an interation stride for the repeat of 0)
230 if (inputSizes[3] > 1 && inputStrides[3] < 0) {
231 return forOp->emitOpError("Stride 3 must be a non-negative integer.");
232 }
233
234 for (int i = 0; i < 4; i++) {
235 // strides[0] == 1 is ok iff the transfer size is a multiple of
236 // addressGranularity, which is checked below
237 if (i == 0 && inputStrides[i] == 1)
238 continue;
239 if (inputStrides[i] * elemWidth % addressGranularity != 0) {
240 std::stringstream msg;
241 msg << "Stride " << i << " is " << inputStrides[i] << " elements * "
242 << (elemWidth / 8) << " bytes = " << (inputStrides[i] * elemWidth / 8)
243 << " bytes, which is not divisible by " << (addressGranularity / 8)
244 << ". ";
245 return forOp->emitOpError(msg.str());
246 }
247 }
248
249 if (!skipTransformationChecks && hardwareSizes[0] > (1 << wrap_bits) - 1)
250 return forOp->emitOpError(
251 "Size 0 exceeds the [0:" + std::to_string((1 << wrap_bits) - 1) +
252 "] range.");
253 if (hardwareSizes[1] > (1 << wrap_bits) - 1)
254 return forOp->emitOpError(
255 "Size 1 exceeds the [0:" + std::to_string((1 << wrap_bits) - 1) +
256 "] range.");
257 if (hardwareSizes[3] > (1 << iter_bits))
258 return forOp->emitOpError(
259 "Size 3 exceeds the [1:" + std::to_string(1 << iter_bits) + "] range.");
260 if (hardwareStrides[0] > (1 << step_bits))
261 return forOp->emitOpError("Stride 0 exceeds the [1:" +
262 std::to_string(1 << step_bits) + "] range.");
263 if (hardwareStrides[1] > (1 << step_bits))
264 return forOp->emitOpError("Stride 1 exceeds the [1:" +
265 std::to_string(1 << step_bits) + "] range.");
266 if (hardwareStrides[2] > (1 << step_bits))
267 return forOp->emitOpError("Stride 2 exceeds the [1:" +
268 std::to_string(1 << step_bits) + "] range.");
269 // strides[3] exceeding the range is ok iff the sizes[3] is one, which is
270 // checked below
271 if (hardwareStrides[3] > (1 << step_bits) && hardwareSizes[3] > 0)
272 return forOp->emitOpError("Stride 3 exceeds the [1:" +
273 std::to_string(1 << step_bits) + "] range.");
274
275 return success();
276}
277
278//===----------------------------------------------------------------------===//
279// UseTokenOp
280//===----------------------------------------------------------------------===//
281
282LogicalResult AIEX::UseTokenOp::verify() {
283 auto *parentOp = (*this)->getParentOp();
284 if (isa<func::FuncOp>(parentOp) || isa<AIE::CoreOp>(parentOp) ||
285 isa<AIE::MemOp>(parentOp) || isa<AIE::ShimDMAOp>(parentOp))
286 return success();
287 return failure();
288}
289
290//===----------------------------------------------------------------------===//
291// MulticastOp
292//===----------------------------------------------------------------------===//
293
294LogicalResult AIEX::MulticastOp::verify() {
295 Region &body = getPorts();
296 assert(getOperation()->getNumRegions());
297 assert(!body.empty());
298 for (auto &ops : body.front())
299 if (!isa<MultiDestOp, AIE::EndOp>(ops))
300 return ops.emitOpError("cannot be contained in a Multicast op");
301
302 return success();
303}
304
305//===----------------------------------------------------------------------===//
306// BroadcastPacketOp
307//===----------------------------------------------------------------------===//
308
309LogicalResult AIEX::BroadcastPacketOp::verify() {
310 Region &body = getPorts();
311 assert(getOperation()->getNumRegions());
312 assert(!body.empty());
313 for (auto &ops : body.front())
314 if (!isa<BPIDOp, AIE::EndOp>(ops))
315 return ops.emitOpError("cannot be contained in a BroadcastPacket op");
316
317 return success();
318}
319
320//===----------------------------------------------------------------------===//
321// NpuDmaMemcpyNdOp
322//===----------------------------------------------------------------------===//
323
324/* Calculates the offset value to be written to the
325 */
326int64_t AIEX::NpuDmaMemcpyNdOp::getOffsetInBytes() {
327 llvm::SmallVector<int64_t, 4> offsets =
328 llvm::map_to_vector(llvm::reverse(getMixedOffsets()), [](OpFoldResult s) {
329 return getConstantIntValue(s).value();
330 });
331 llvm::SmallVector<int64_t, 4> strides =
332 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
333 return getConstantIntValue(s).value();
334 });
335 size_t offset = 0;
336 size_t R = offsets.size();
337 size_t el_bit_width = getElementTypeBitwidth();
338 assert(el_bit_width % 8 == 0 &&
339 "Expected Memref element bitwidth to be multiple of 8.");
340 size_t S = el_bit_width / 8;
341 for (size_t i = 0; i < R; i++)
342 offset += offsets[i] * strides[i] * S;
343 return offset;
344}
345
346// Returns true when sizes/strides describe a plain contiguous transfer with
347// no data layout transformation (d1/d2 sizes == 1, d0 stride == 1).
348// d3 (repeat) is intentionally excluded.
349bool AIEX::isLinearTransfer(llvm::ArrayRef<int64_t> sizes,
350 llvm::ArrayRef<int64_t> strides) {
351 return sizes[1] == 1 && sizes[2] == 1 && strides[0] == 1 && strides[1] == 0 &&
352 strides[2] == 0;
353}
354
355// dma_memcpy_nd transfers of the form [*, 1, 1, len][*, 0, 0, 1] do not
356// specify any data layout transformation, but simply express a contiguous
357// transfer of `len`. The 4th dimension is excluded because a repeat count
358// is still compatible with a linear transfer.
359bool AIEX::NpuDmaMemcpyNdOp::isLinearTransferWithoutTransformation() {
360 llvm::SmallVector<int64_t, 4> inputSizes =
361 llvm::map_to_vector(llvm::reverse(getMixedSizes()), [](OpFoldResult s) {
362 return getConstantIntValue(s).value();
363 });
364 llvm::SmallVector<int64_t, 4> inputStrides =
365 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
366 return getConstantIntValue(s).value();
367 });
368 return isLinearTransfer(inputSizes, inputStrides);
369}
370
371// Helper method to check if a requested burst length is supported by the target
372// model. Returns an error message if the burst length is not supported or an
373// empty option otherwise.
374static std::optional<std::string>
375checkBurstLength(const xilinx::AIE::AIETargetModel &targetModel,
376 uint32_t requestedBurstLength) {
377 if (requestedBurstLength != 0) {
378 auto bel = targetModel.getShimBurstEncodingsAndLengths();
379 auto pair = std::find_if(bel.begin(), bel.end(),
380 [=](const std::pair<uint32_t, uint32_t> &p) {
381 return p.second == requestedBurstLength;
382 });
383
384 if (pair == bel.end()) {
385 std::string errorMessage =
386 "Requested burst length is not supported by the target. "
387 "Supported burst lengths:";
388
389 errorMessage =
390 std::accumulate(bel.begin(), bel.end(), errorMessage,
391 [](const std::string &a, auto b) {
392 return a + " " + std::to_string(b.second);
393 });
394
395 return errorMessage;
396 }
397 }
398
399 return std::nullopt;
400}
401
402LogicalResult AIEX::NpuDmaMemcpyNdOp::verify() {
403 BaseMemRefType buffer = getMemref().getType();
404 const auto &targetModel = AIE::getTargetModel(*this);
405 auto addressGranularity = targetModel.getAddressGenGranularity();
406
407 if (getElementTypeBitwidth() > addressGranularity) {
408 return emitOpError("Maximum element bit width allowed is ")
409 << addressGranularity << "bits. ";
410 }
411 if (buffer.hasStaticShape() &&
412 (buffer.getNumElements() * getElementTypeBitwidth()) <
413 addressGranularity) {
414 return emitOpError("Minimum data transfer size required is ")
415 << addressGranularity << "bits. ";
416 }
417 if (!llvm::all_of(getMixedStrides(), [](OpFoldResult s) {
418 return getConstantIntValue(s).has_value();
419 }))
420 return emitOpError("Only constant strides currently supported.");
421 if (!llvm::all_of(getMixedSizes(), [](OpFoldResult s) {
422 return getConstantIntValue(s).has_value();
423 }))
424 return emitOpError("Only constant sizes currently supported.");
425 if (!llvm::all_of(getMixedOffsets(), [](OpFoldResult s) {
426 return getConstantIntValue(s).has_value();
427 }))
428 return emitOpError("Only constant offsets currently supported.");
429
430 llvm::SmallVector<int64_t, 4> inputSizes =
431 llvm::map_to_vector(llvm::reverse(getMixedSizes()), [](OpFoldResult s) {
432 return getConstantIntValue(s).value();
433 });
434 llvm::SmallVector<int64_t, 4> inputStrides =
435 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
436 return getConstantIntValue(s).value();
437 });
438 llvm::SmallVector<int64_t, 4> hardwareSizes(4);
439 llvm::SmallVector<int64_t, 4> hardwareStrides(4);
440 getHardwareStridesWraps(targetModel, getOperation(), buffer, inputSizes,
441 inputStrides, hardwareSizes, hardwareStrides);
442 int64_t offset = getOffsetInBytes();
443
444 auto errorMessage = checkBurstLength(targetModel, getBurstLength());
445 if (errorMessage.has_value()) {
446 return emitOpError(errorMessage.value());
447 }
448
449 // The experimental HSA target uses this op on AIE1, skip all the AIE2
450 // specific checks
451 if (targetModel.getTargetArch() == AIE::AIEArch::AIE1)
452 return success();
453
454 if (offset % 4 != 0) {
455 return emitOpError("Offset must be 4-byte-aligned.");
456 }
457
458 // dma_memcpy_nd transfers of the form [1, 1, 1, len][0, 0, 0, 1] do not
459 // specify any data layout transformation, but simply express a contiguous
460 // transfer of `len`. For backwards compatibility, we allow this to proceed
461 // even if it exceeds the maximum stride/wrap size of any one dimension,
462 // and simply do not lower any data layout transformations, since there is
463 // no other way to express this at the dma_memcpy_nd interface otherwise.
464 AIE::DeviceOp dev = getOperation()->getParentOfType<AIE::DeviceOp>();
465 if (auto allocOp = AIE::ShimDMAAllocationOp::getForSymbol(
466 dev, getMetadata().getRootReference())) {
467 AIE::TileOp tile = allocOp.getTileOp();
468 if (!tile) {
469 return emitOpError("shim DMA allocation must reference a valid TileOp");
470 }
471 int col = tile.getCol();
472 int row = tile.getRow();
473 bool skipTransformationChecks = isLinearTransferWithoutTransformation();
474 if (failed(verifyStridesWraps(*this, buffer, col, row, inputSizes,
475 inputStrides, hardwareSizes, hardwareStrides,
476 skipTransformationChecks))) {
477 return failure();
478 }
479 }
480
481 // packet header
482 if (auto packetInfo = getPacket()) {
483 if (packetInfo->getPktType() > 7)
484 return emitOpError("Packet type field can only hold 3 bits.");
485 if (packetInfo->getPktId() > 31)
486 return emitOpError("Packet ID field can only hold 5 bits.");
487 }
488
489 return success();
490}
491
492//===----------------------------------------------------------------------===//
493// NpuDmaWaitOp
494//===----------------------------------------------------------------------===//
495
496LogicalResult AIEX::NpuDmaWaitOp::verify() {
497 AIE::DeviceOp dev = (*this)->getParentOfType<AIE::DeviceOp>();
498 // Some passes (e.g. aie-standard-lowering) use aiex ops outside a DeviceOp,
499 // so we can't expect the device to always exist.
500 if (dev && !dev.lookupSymbol(getSymbol()))
501 return emitOpError("couldn't find symbol in parent device");
502 return success();
503}
504
505//===----------------------------------------------------------------------===//
506// NpuPushQueueOp
507//===----------------------------------------------------------------------===//
508
509LogicalResult AIEX::NpuPushQueueOp::verify() {
510 const auto &targetModel = AIE::getTargetModel(*this);
511 auto numBds = targetModel.getNumBDs(getColumn(), getRow());
512 if (getBdId() > numBds)
513 return emitOpError("BD ID exceeds the maximum ID.");
514 if (getRepeatCount() > 255)
515 return emitOpError("Repeat count exceeds the [0:255] range.");
516 return success();
517}
518
519//===----------------------------------------------------------------------===//
520// NpuWriteBdOp
521//===----------------------------------------------------------------------===//
522
523LogicalResult AIEX::NpuWriteBdOp::verify() {
524 const auto &targetModel = AIE::getTargetModel(*this);
525 auto numBds = targetModel.getNumBDs(getColumn(), getRow());
526 bool isLinearTransfer =
527 (getD0Size() >= 1) && (getD1Size() == 1) && (getIterationSize() == 0);
528 if (getBdId() > numBds)
529 return emitOpError("BD ID exceeds the maximum ID.");
530 if (getPacketId() > 31)
531 return emitOpError("Packet ID exceeds the maximum supported by 5 bits.");
532 if (getPacketType() > 7)
533 return emitOpError("Packet Type exceeds the maximum supported by 3 bits.");
534 if (!isLinearTransfer && getD0Size() > 0x3FF)
535 return emitOpError("D0 Size exceeds the [0:1023] range.");
536 if (getD0Stride() > 0xFFFFF)
537 return emitOpError("D0 Stride exceeds the [0:1M-1] range.");
538 if (getD1Size() > 0x3FF)
539 return emitOpError("D1 Size exceeds the [0:1023] range.");
540 if (getD1Stride() > 0xFFFFF)
541 return emitOpError("D1 Stride exceeds the [0:1M-1] range.");
542 if (getD2Stride() > 0xFFFFF)
543 return emitOpError("D2 Stride exceeds the [0:1M-1] range.");
544 if (getIterationSize() > 0x3F)
545 return emitOpError("Iteration Size exceeds the [0:63] range.");
546 if (getIterationStride() > 0xFFFFF)
547 return emitOpError("Iteration Stride exceeds the [0:1M-1] range.");
548 if (targetModel.isShimNOCTile(getColumn(), getRow()) && getD2Size() != 0)
549 return emitOpError("ShimTile only supports 3 dimensions of sizes.");
550 if (targetModel.isShimNOCTile(getColumn(), getRow()) &&
551 (getD0ZeroBefore() != 0 || getD0ZeroAfter() != 0 ||
552 getD1ZeroBefore() != 0 || getD1ZeroAfter() != 0 ||
553 getD2ZeroBefore() != 0 || getD2ZeroAfter() != 0))
554 return emitOpError("ShimTile doesn't support zero padding.");
555 if (!targetModel.isShimNOCTile(getColumn(), getRow()) &&
556 getBurstLength() != 0)
557 return emitOpError("Only ShimTiles support burst length.");
558 auto errorMessage = checkBurstLength(targetModel, getBurstLength());
559 if (errorMessage.has_value()) {
560 return emitOpError(errorMessage.value());
561 }
562
563 return success();
564}
565
566//===----------------------------------------------------------------------===//
567// NpuWrite32Op
568//===----------------------------------------------------------------------===//
569
570template <typename T>
571static std::optional<uint32_t> getAbsoluteAddress(T *op) {
572 AIE::DeviceOp device =
573 op->getOperation()->template getParentOfType<AIE::DeviceOp>();
574 if (!device) {
575 op->emitError("Must be inside a device.");
576 return std::nullopt;
577 }
578 const AIE::AIETargetModel &tm = device.getTargetModel();
579
580 uint32_t address = 0;
581
582 // If blockwrite references a buffer, the given address is understood to be
583 // relative to the buffer's start address.
584 if (op->getBuffer()) {
585 AIE::BufferOp buffer = device.lookupSymbol<AIE::BufferOp>(*op->getBuffer());
586 if (!buffer) {
587 op->emitError() << "buffer '" << *op->getBuffer()
588 << "' not found in device";
589 return std::nullopt;
590 }
591
592 if (!buffer.getAddress()) {
593 mlir::InFlightDiagnostic err =
594 op->emitError("referenced buffer must have address assigned");
595 err.attachNote(buffer.getLoc()) << "This buffer must have an address.";
596 return std::nullopt;
597 }
598
599 uint32_t col = buffer.getTileOp().getCol();
600 uint32_t row = buffer.getTileOp().getRow();
601 address = static_cast<uint32_t>(*buffer.getAddress()) +
602 op->getAddress() * sizeof(uint32_t);
603 address = ((col & 0xff) << tm.getColumnShift()) |
604 ((row & 0xff) << tm.getRowShift()) | (address & 0xfffff);
605 } else { // otherwise, the given address is absolute
606 address = op->getAddress();
607 std::optional<uint32_t> col = op->getColumn();
608 std::optional<uint32_t> row = op->getRow();
609 if (col && row) {
610 // If col and row are set, only the lower 20 bits of the address are
611 // used, and col and row dictate the upper bits (ignored)
612 address = ((*col & 0xff) << tm.getColumnShift()) |
613 ((*row & 0xff) << tm.getRowShift()) | (address & 0xfffff);
614 }
615 }
616
617 return address;
618}
619
620std::optional<uint32_t> AIEX::NpuWrite32Op::getAbsoluteAddress() {
621 return ::getAbsoluteAddress(this);
622}
623
624//===----------------------------------------------------------------------===//
625// NpuMaskWrite32Op
626//===----------------------------------------------------------------------===//
627
628std::optional<uint32_t> AIEX::NpuMaskWrite32Op::getAbsoluteAddress() {
629 return ::getAbsoluteAddress(this);
630}
631
632//===----------------------------------------------------------------------===//
633// NpuBlockWriteOp
634//===----------------------------------------------------------------------===//
635
636std::optional<uint32_t> AIEX::NpuBlockWriteOp::getAbsoluteAddress() {
637 return ::getAbsoluteAddress(this);
638}
639
640DenseIntElementsAttr AIEX::NpuBlockWriteOp::getDataWords() {
641 Value memref = this->getData();
642 DataLayout dataLayout = DataLayout::closest(*this);
643 int64_t width = dataLayout.getTypeSizeInBits(
644 cast<MemRefType>(memref.getType()).getElementType());
645 if (width != 32) {
646 emitWarning("Only 32-bit data type is supported for now");
647 return nullptr;
648 }
649
650 memref::GetGlobalOp getGlobal = memref.getDefiningOp<memref::GetGlobalOp>();
651 if (!getGlobal) {
652 emitError("Only MemRefs from memref.get_global are supported");
653 return nullptr;
654 }
655
656 auto global = dyn_cast_if_present<memref::GlobalOp>(
657 (*this)->getParentOfType<AIE::DeviceOp>().lookupSymbol(
658 getGlobal.getName()));
659 if (!global) {
660 emitError("Global symbol not found");
661 return nullptr;
662 }
663
664 auto initVal = global.getInitialValue();
665 if (!initVal) {
666 emitError("Global symbol has no initial value");
667 return nullptr;
668 }
669
670 auto data = dyn_cast<DenseIntElementsAttr>(*initVal);
671 if (!data) {
672 emitError("Global symbol initial value is not a dense int array");
673 return nullptr;
674 }
675
676 return data;
677}
678
679//===----------------------------------------------------------------------===//
680// DMAConfigureTaskOp
681//===----------------------------------------------------------------------===//
682
683std::optional<uint32_t> AIEX::DMAConfigureTaskOp::getFirstBdId() {
684 Region &body = getBody();
685 if (body.empty()) {
686 return std::nullopt;
687 }
688 auto bd_ops = body.front().getOps<AIE::DMABDOp>();
689 if (bd_ops.empty() && body.front().getNumSuccessors() == 1) {
690 // Allow the first block to be empty and point to the entry point of the
691 // chain. This allows for specifying cyclying BD chains (infinite loops)
692 // within the constraints of MLIR syntax.
693 Block &chain_entry = *body.front().getSuccessor(0);
694 bd_ops = chain_entry.getOps<AIE::DMABDOp>();
695 }
696 if (bd_ops.empty()) {
697 return std::nullopt;
698 }
699 AIE::DMABDOp bd = *bd_ops.begin();
700 if (!bd.getBdId().has_value()) {
701 return std::nullopt;
702 }
703 return bd.getBdId().value();
704}
705
706LogicalResult
707AIEX::DMAConfigureTaskOp::canonicalize(AIEX::DMAConfigureTaskOp op,
708 PatternRewriter &rewriter) {
709 // Remove blocks that contain nothing but a terminator
710 Region &body = op.getBody();
711 bool did_rewrite = false;
712 for (auto it = body.begin(); it != body.end(); ++it) {
713 Block &block = *it;
714 if (block.empty()) {
715 continue;
716 }
717 auto ops_it = block.without_terminator();
718 if (std::distance(ops_it.begin(), ops_it.end()) == 0) {
719 rewriter.eraseOp(block.getTerminator());
720 did_rewrite = true;
721 }
722 }
723 if (did_rewrite) {
724 return success();
725 }
726 return failure();
727}
728
729LogicalResult AIEX::DMAConfigureTaskOp::verify() {
730 Region &body = getBody();
731 for (auto it = body.begin(); it != body.end(); ++it) {
732 Block &block = *it;
733 if (block.empty()) {
734 continue;
735 }
736 if (block.hasNoPredecessors() && !block.isEntryBlock()) {
737 auto error = block.getTerminator()->emitError(
738 "Block ending in this terminator does not form a chain with "
739 "entry block.");
740 return failure();
741 }
742
743 const AIE::AIETargetModel &targetModel =
744 AIE::getTargetModel(getOperation());
745
746 // This is a layering violation on the DMABDOps, but they are never verified
747 // otherwise Because DMAConfigureTaskOps are not yet merged into the AIE
748 // dialect. The normal DMABDOp verify operation will skip over any BD inside
749 // a DMAConfigureTaskOp
750 LogicalResult result = success();
751 block.walk([&](AIE::DMABDOp bd) {
752 if (bd.getBurstLength() != 0 &&
753 !targetModel.isShimNOCTile(getTileID().col, getTileID().row)) {
754 bd.emitOpError("Burst length is only supported in Shim NOC tiles that "
755 "are connected to the memory-mapped NOC.");
756 result = failure();
757 }
758 });
759 if (failed(result)) {
760 return result;
761 }
762 }
763 return success();
764}
765
766//===----------------------------------------------------------------------===//
767// DMAStartBdChainOp
768//===----------------------------------------------------------------------===//
769
770AIE::BDChainOp AIEX::DMAStartBdChainOp::getBDChainOp() {
771 AIE::DeviceOp device = (*this)->getParentOfType<AIE::DeviceOp>();
772 AIE::BDChainOp chain = device.lookupSymbol<AIE::BDChainOp>(getSymbol());
773 return chain;
774}
775
776LogicalResult AIEX::DMAStartBdChainOp::verify() {
777 AIE::BDChainOp chain = getBDChainOp();
778 if (!chain) {
779 return emitOpError("symbol does not reference valid BD chain");
780 }
781
782 auto actualArgTypes = getArgs().getTypes();
783 auto expectedArgTypes = chain.getRegion().getArgumentTypes();
784 if (actualArgTypes.size() != expectedArgTypes.size()) {
785 return emitOpError("Number of arguments mismatches.");
786 }
787 for (unsigned i = 0, n = expectedArgTypes.size(); i < n; i++) {
788 if (actualArgTypes[i] != expectedArgTypes[i]) {
789 return emitOpError("Argument ") << (i + 1) << " types mismatch: "
790 << "expected " << expectedArgTypes[i]
791 << " but got " << actualArgTypes[i];
792 }
793 }
794 return success();
795}
796
797//===----------------------------------------------------------------------===//
798// NpuControlPacketOp
799//===----------------------------------------------------------------------===//
800
801uint32_t AIEX::NpuControlPacketOp::getRowFromAddr() {
802 const auto &targetModel = AIE::getTargetModel(*this);
803 uint32_t addr = getAddress();
804 uint32_t rowInt = (addr >> targetModel.getRowShift()) & 0x1f;
805 return rowInt;
806}
807
808uint32_t AIEX::NpuControlPacketOp::getColumnFromAddr() {
809 const auto &targetModel = AIE::getTargetModel(*this);
810 uint32_t addr = getAddress();
811 uint32_t colInt = (addr >> targetModel.getColumnShift()) & 0x1f;
812 return colInt;
813}
814
815//===----------------------------------------------------------------------===//
816// SetLockOp
817//===----------------------------------------------------------------------===//
818
819LogicalResult AIEX::SetLockOp::verify() {
820 const auto &targetModel = AIE::getTargetModel(*this);
821
822 if (targetModel.getTargetArch() == AIE::AIEArch::AIE1)
823 return emitOpError("SetLockOp is not supported on AIE1.");
824
825 if (getValue() > targetModel.getMaxLockValue())
826 return emitOpError("Lock value exceeds the maximum value of " +
827 std::to_string(targetModel.getMaxLockValue()));
828
829 auto lockOp = getLockOp();
830 auto lockIDOpt = getLockOp().getLockID();
831 // Note that the lockID may not be assigned initially, so lets wait until it
832 // is to verify the lockID dependent conditions
833 if (!lockIDOpt) {
834 return success();
835 }
836
837 auto col = lockOp.colIndex();
838 auto row = lockOp.rowIndex();
839 uint32_t lockID = lockOp.getLockIDValue();
840
841 if (lockID >= targetModel.getNumLocks(col, row)) {
842 return emitOpError("Lock ID out of range for given tile. Max ID: " +
843 std::to_string(targetModel.getNumLocks(col, row) - 1));
844 }
845
846 if (!targetModel.getLocalLockAddress(lockID, lockOp.getTileID())) {
847 return emitOpError("Invalid lock ID and tile combination when trying to "
848 "retrieve the local lock address.");
849 }
850
851 return success();
852}
853
854//===----------------------------------------------------------------------===//
855// BlockFloatingPointType
856//===----------------------------------------------------------------------===//
857uint64_t AIEX::BlockFloatType::getTotalSizeInBits() const {
858 return getBlockSize() * getMantissaBits() + getExponentBits() +
859 getSubtileShiftBits();
860}
861
862llvm::TypeSize AIEX::BlockFloatType::getTypeSizeInBits(
863 const mlir::DataLayout &dataLayout,
864 mlir::DataLayoutEntryListRef params) const {
865 return llvm::TypeSize::getFixed(getTotalSizeInBits());
866}
867
868uint64_t AIEX::BlockFloatType::getABIAlignment(
869 const mlir::DataLayout &dataLayout,
870 mlir::DataLayoutEntryListRef params) const {
871 // For the purposes of the data movement operations, we want all types to be
872 // packed <=> ABI alignment is 1.
873 return 1;
874}
875
876std::optional<AIEX::BlockFloatType::BlockFormat>
877AIEX::BlockFloatType::getBlockFormat(StringRef blockType) {
878 static const llvm::StringMap<AIEX::BlockFloatType::BlockFormat>
879 blockFormatsMap = {
880 {"v8bfp16ebs8", {8, 8, 8, 0}},
881 {"v16bfp16ebs16", {16, 8, 8, 0}},
882 };
883
884 auto it = blockFormatsMap.find(blockType);
885 if (it != blockFormatsMap.end()) {
886 return it->second;
887 }
888
889 return std::nullopt;
890}
891
892LogicalResult
893AIEX::BlockFloatType::verify(function_ref<InFlightDiagnostic()> emitError,
894 StringRef block_type) {
895 if (!getBlockFormat(block_type))
896 return emitError() << "Invalid block type: " << block_type
897 << ". Known types are: v8bfp16ebs8, v16bfp16ebs16.";
898
899 return success();
900}
901
902//===----------------------------------------------------------------------===//
903// ConfigureOp
904//===----------------------------------------------------------------------===//
905
906AIE::DeviceOp AIEX::ConfigureOp::getReferencedDeviceOp() {
907 ModuleOp moduleOp = this->getOperation()->getParentOfType<ModuleOp>();
908 if (!moduleOp) {
909 emitError("aiex.configure must be inside of a module");
910 return nullptr;
911 }
912 Operation *maybeReferencedDevice =
913 SymbolTable::lookupSymbolIn(moduleOp.getOperation(), getSymbolAttr());
914 if (!maybeReferencedDevice) {
915 emitError("No such device: '") << getSymbolAttr() << "'";
916 return nullptr;
917 }
918 AIE::DeviceOp referencedDevice =
919 llvm::dyn_cast<AIE::DeviceOp>(maybeReferencedDevice);
920 if (!referencedDevice) {
921 emitError("Not a device: '") << getSymbolAttr() << "'";
922 return nullptr;
923 }
924 return referencedDevice;
925}
926
927LogicalResult AIEX::ConfigureOp::verify() {
928 AIE::DeviceOp parentDev = getOperation()->getParentOfType<AIE::DeviceOp>();
929 AIE::DeviceOp referencedDev = getReferencedDeviceOp();
930 if (!referencedDev) {
931 return failure();
932 }
933 if (parentDev.getDevice() != referencedDev.getDevice()) {
934 emitError("Device types do not match: '")
935 << AIE::stringifyAIEDevice(parentDev.getDevice()) << "' vs. '"
936 << AIE::stringifyAIEDevice(referencedDev.getDevice()) << "'";
937 return failure();
938 }
939 return success();
940}
941
942//===----------------------------------------------------------------------===//
943// RunOp
944//===----------------------------------------------------------------------===//
945
946AIE::DeviceOp AIEX::RunOp::getCalleeDeviceOp() {
947 AIEX::ConfigureOp configureOp =
948 getOperation()->getParentOfType<AIEX::ConfigureOp>();
949 if (!configureOp) {
950 return nullptr;
951 }
952 AIE::DeviceOp referencedDevice = configureOp.getReferencedDeviceOp();
953 return referencedDevice;
954}
955
956AIE::RuntimeSequenceOp AIEX::RunOp::getCalleeRuntimeSequenceOp() {
957 AIEX::ConfigureOp configureOp =
958 getOperation()->getParentOfType<AIEX::ConfigureOp>();
959 if (!configureOp) {
960 return nullptr;
961 }
962 AIE::DeviceOp referencedDevice = configureOp.getReferencedDeviceOp();
963 if (!referencedDevice) {
964 return nullptr;
965 }
966
967 Operation *maybeRuntimeSequence =
968 SymbolTable::lookupSymbolIn(referencedDevice, getRuntimeSequenceSymbol());
969
970 if (!maybeRuntimeSequence) {
971 return nullptr;
972 }
973 AIE::RuntimeSequenceOp runtimeSequence =
974 llvm::dyn_cast<AIE::RuntimeSequenceOp>(maybeRuntimeSequence);
975 if (!runtimeSequence) {
976 return nullptr;
977 }
978
979 return runtimeSequence;
980}
981
982//===----------------------------------------------------------------------===//
983// NpuLoadPdiOp
984//===----------------------------------------------------------------------===//
985
986LogicalResult AIEX::NpuLoadPdiOp::canonicalize(AIEX::NpuLoadPdiOp op,
987 PatternRewriter &rewriter) {
988 // Check for back-to-back identical load_pdi ops and remove duplicates
989 Operation *nextOp = op->getNextNode();
990 if (!nextOp)
991 return failure();
992
993 // Check if next op is also a NpuLoadPdiOp
994 auto nextLoadPdi = dyn_cast<AIEX::NpuLoadPdiOp>(nextOp);
995 if (!nextLoadPdi)
996 return failure();
997
998 // Check if they are identical (all attributes match)
999 if (op.getDeviceRefAttr() == nextLoadPdi.getDeviceRefAttr() &&
1000 op.getId() == nextLoadPdi.getId() &&
1001 op.getSize() == nextLoadPdi.getSize() &&
1002 op.getAddress() == nextLoadPdi.getAddress()) {
1003 // Erase the first one, keeping the second
1004 rewriter.eraseOp(op);
1005 return success();
1006 }
1007
1008 return failure();
1009}
virtual AIEArch getTargetArch() const =0
Return the target architecture.
virtual uint32_t getNumLocks(AIETileType tileType) const =0
Return the number of lock objects for a given tile type.
virtual std::vector< std::pair< uint32_t, uint32_t > > getShimBurstEncodingsAndLengths() const =0
virtual std::optional< uint32_t > getLocalLockAddress(uint32_t lockId, TileID tile) const =0
bool isShimNOCTile(int col, int row) const
Return true if the given tile is a ShimNOC tile.
virtual uint32_t getNumBDs(AIETileType tileType) const =0
Return the number of buffer descriptors for a given tile type.
virtual uint32_t getMaxLockValue() const =0
Return the maximum value that can be stored in a lock register.
virtual uint32_t getColumnShift() const =0
virtual uint32_t getRowShift() const =0
virtual uint32_t getAddressGenGranularity() const =0
Return the data bus width of the device.
void getHardwareStridesWraps(const AIE::AIETargetModel &targetModel, mlir::Operation *op, mlir::BaseMemRefType referencedBufType, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > &sizes, llvm::SmallVector< int64_t, 4 > &strides)
mlir::LogicalResult verifyStridesWraps(mlir::Operation *forOp, mlir::BaseMemRefType referencedBufType, int tileCol, int tileRow, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > hardwareSizes, llvm::SmallVector< int64_t, 4 > hardwareStrides, bool skipTransformationChecks=false)
bool isLinearTransfer(llvm::ArrayRef< int64_t > sizes, llvm::ArrayRef< int64_t > strides)
const AIETargetModel & getTargetModel(mlir::Operation *op)