13#include "mlir/IR/DialectImplementation.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Interfaces/FoldInterfaces.h"
16#include "mlir/Transforms/InliningUtils.h"
18#include "llvm/ADT/TypeSwitch.h"
25#include "aie/Dialect/AIEX/IR/AIEXDialect.cpp.inc"
27#define GET_TYPEDEF_CLASSES
28#include "aie/Dialect/AIEX/IR/AIEXTypes.cpp.inc"
// Dialect initialization hook for the AIEX dialect. The generated AIEX.cpp.inc
// and AIEXTypes.cpp.inc fragments are textually included inside its body, the
// latter after defining GET_TYPEDEF_LIST.
// NOTE(review): the statements wrapping these includes and the closing brace
// are not visible in this listing - presumably the usual op/type registration
// calls; confirm against the real file.
33void AIEXDialect::initialize() {
36#include "aie/Dialect/AIEX/IR/AIEX.cpp.inc"
39#define GET_TYPEDEF_LIST
40#include "aie/Dialect/AIEX/IR/AIEXTypes.cpp.inc"
47#include "aie/Dialect/AIEX/IR/AIEX.cpp.inc"
// Tail of an address computation (presumably the end of
// getBufferDescriptorAddressRegisterAddress; its signature is not visible in
// this listing). The visible part ORs the low 8 bits of `row`, shifted left by
// the target model's row shift, with the offset 0x1D004 + bd_id * 0x20.
// NOTE(review): 0x1D004 and 0x20 look like a register base and a per-BD
// spacing - confirm against the hardware register map.
53 ((row & 0xff) << tm.
getRowShift()) | (0x1D004 + bd_id * 0x20);
// Parameter list and body of a size/stride conversion routine (presumably
// getHardwareStridesWraps; the line carrying the function name is not visible
// in this listing). It reads `inputSizes` / `inputStrides` and writes the
// converted values into the caller-provided `sizes` / `strides` vectors.
90 mlir::BaseMemRefType referencedBufType,
91 llvm::SmallVector<int64_t, 4> inputSizes,
92 llvm::SmallVector<int64_t, 4> inputStrides,
93 llvm::SmallVector<int64_t, 4> &sizes,
94 llvm::SmallVector<int64_t, 4> &strides) {
// Preconditions: inputs have matching lengths and both outputs hold exactly
// four entries. NOTE(review): indices 0..3 of the inputs are read below, but
// the visible asserts do not check that the inputs have four entries.
95 assert(inputSizes.size() == inputStrides.size());
96 assert(sizes.size() == 4);
97 assert(strides.size() == 4);
// Bit width of one element of the referenced buffer.
99 auto elemWidth = referencedBufType.getElementTypeBitWidth();
// Start from all-zero outputs; entries not assigned below stay 0.
103 std::fill(sizes.begin(), sizes.end(), 0);
104 std::fill(strides.begin(), strides.end(), 0);
// Special case for a zero size in dimension 0 (its body is not visible here).
106 if (inputSizes[0] == 0) {
// Dimension 0: size is rescaled from elements to multiples of
// `addressGranularity` bits (that variable is declared on a line not visible
// in this listing).
113 sizes[0] = inputSizes[0] * elemWidth / addressGranularity;
// Branch for a dimension-0 stride smaller than one granule (body not visible).
114 if (inputStrides[0] * elemWidth < addressGranularity) {
// Strides are rescaled to granules and stored minus one.
124 strides[0] = inputStrides[0] * elemWidth / addressGranularity - 1;
// Dimensions 1 and 2: the size is copied unchanged; the stride assignment
// follows an `inputSizes[i] > 1` test.
128 sizes[1] = inputSizes[1];
129 if (inputSizes[1] > 1) {
131 strides[1] = inputStrides[1] * elemWidth / addressGranularity - 1;
135 sizes[2] = inputSizes[2];
136 if (inputSizes[2] > 1) {
138 strides[2] = inputStrides[2] * elemWidth / addressGranularity - 1;
// Dimension 3: the size is stored minus one, following a test that it
// exceeds 1; the stride assignment follows a test that the input stride is
// positive.
142 if (inputSizes[3] > 1) {
144 sizes[3] = inputSizes[3] - 1;
151 if (inputStrides[3] > 0) {
152 strides[3] = inputStrides[3] * elemWidth / addressGranularity - 1;
// Parameter list and body of a size/stride validation routine (presumably
// verifyStridesWraps; the line carrying the function name and the leading
// parameters is not visible in this listing). Every failed check reports an
// error through `forOp->emitOpError(...)` and returns that diagnostic.
159 mlir::BaseMemRefType referencedBufType,
int tileCol,
160 int tileRow, llvm::SmallVector<int64_t, 4> inputSizes,
161 llvm::SmallVector<int64_t, 4> inputStrides,
162 llvm::SmallVector<int64_t, 4> hardwareSizes,
163 llvm::SmallVector<int64_t, 4> hardwareStrides,
164 bool skipTransformationChecks) {
// Address granularity (in bits, judging by the "/ 8" byte conversions below)
// and the element bit width of the referenced buffer.
166 auto addressGranularity = targetModel.getAddressGenGranularity();
167 auto elemWidth = referencedBufType.getElementTypeBitWidth();
// Field widths used for the range checks further down. wrap_bits and
// step_bits start at 0; iter_bits is fixed at 6.
// NOTE(review): the bodies of the tile-type branches below are not visible -
// presumably they set wrap_bits / step_bits per tile type; confirm.
169 uint32_t wrap_bits = 0;
170 uint32_t step_bits = 0;
171 uint32_t iter_bits = 6;
172 if (targetModel.isShimNOCTile(tileCol, tileRow)) {
175 }
else if (targetModel.isMemTile(tileCol, tileRow)) {
178 }
else if (targetModel.isCoreTile(tileCol, tileRow)) {
// Error reported for a tile that is none of the three supported kinds.
182 return forOp->emitOpError(
183 "Unsupported tile type at (" + std::to_string(tileCol) +
", " +
184 std::to_string(tileRow) +
") Must be ShimNOC, Mem or Core.");
// All four input sizes must be strictly positive.
187 for (
int i = 0; i < 4; i++) {
188 if (inputSizes[i] <= 0) {
189 return forOp->emitOpError(
"Size ") << i <<
" must be a positive integer.";
// The dimension-0 transfer length in bits must be a whole number of granules;
// the message reports the quantities in bytes.
193 if (inputSizes[0] * elemWidth % addressGranularity != 0) {
194 std::stringstream msg;
195 msg <<
"Transfer sizes must be multiples of " << (addressGranularity / 8)
196 <<
" bytes. " << inputSizes[0] <<
" elements at " << (elemWidth / 8)
197 <<
" bytes each equal " << (inputSizes[0] * elemWidth / 8)
198 <<
" bytes, which is not divisible by " << (addressGranularity / 8)
200 return forOp->emitOpError(msg.str());
// Dimensions 0..2: a dimension whose size exceeds 1 needs a stride >= 1.
203 for (
int i = 0; i < 3; i++) {
204 if (inputSizes[i] > 1 && inputStrides[i] < 1) {
208 return forOp->emitOpError(
"Stride ")
209 << i <<
" must be a positive integer.";
// Dimension 3 differs: a stride of 0 is accepted, only negative is rejected.
214 if (inputSizes[3] > 1 && inputStrides[3] < 0) {
215 return forOp->emitOpError(
"Stride 3 must be a non-negative integer.");
// Each stride, in bits, must be a multiple of the address granularity.
// Dimension 0 with stride 1 is tested separately first (the statement that
// follows that test is not visible here).
218 for (
int i = 0; i < 4; i++) {
221 if (i == 0 && inputStrides[i] == 1)
223 if (inputStrides[i] * elemWidth % addressGranularity != 0) {
224 std::stringstream msg;
225 msg <<
"Stride " << i <<
" is " << inputStrides[i] <<
" elements * "
226 << (elemWidth / 8) <<
" bytes = " << (inputStrides[i] * elemWidth / 8)
227 <<
" bytes, which is not divisible by " << (addressGranularity / 8)
229 return forOp->emitOpError(msg.str());
// Range checks on the hardware-encoded values. Of the checks visible here,
// only the size-0 check is bypassed by skipTransformationChecks.
233 if (!skipTransformationChecks && hardwareSizes[0] > (1 << wrap_bits) - 1)
234 return forOp->emitOpError(
235 "Size 0 exceeds the [0:" + std::to_string((1 << wrap_bits) - 1) +
237 if (hardwareSizes[1] > (1 << wrap_bits) - 1)
238 return forOp->emitOpError(
239 "Size 1 exceeds the [0:" + std::to_string((1 << wrap_bits) - 1) +
241 if (hardwareSizes[3] > (1 << iter_bits))
242 return forOp->emitOpError(
243 "Size 3 exceeds the [1:" + std::to_string(1 << iter_bits) +
"] range.");
244 if (hardwareStrides[0] > (1 << step_bits))
245 return forOp->emitOpError(
"Stride 0 exceeds the [1:" +
246 std::to_string(1 << step_bits) +
"] range.");
247 if (hardwareStrides[1] > (1 << step_bits))
248 return forOp->emitOpError(
"Stride 1 exceeds the [1:" +
249 std::to_string(1 << step_bits) +
"] range.");
250 if (hardwareStrides[2] > (1 << step_bits))
251 return forOp->emitOpError(
"Stride 2 exceeds the [1:" +
252 std::to_string(1 << step_bits) +
"] range.");
// The stride-3 range is only enforced when hardware size 3 is non-zero.
255 if (hardwareStrides[3] > (1 << step_bits) && hardwareSizes[3] > 0)
256 return forOp->emitOpError(
"Stride 3 exceeds the [1:" +
257 std::to_string(1 << step_bits) +
"] range.");
// Verifier for UseTokenOp. It inspects the immediate parent operation and
// tests whether it is a func::FuncOp, AIE::CoreOp, AIE::MemOp or
// AIE::ShimDMAOp.
// NOTE(review): the statements following this test are not visible in this
// listing - presumably these four parents are the accepted ones; confirm.
266LogicalResult AIEX::UseTokenOp::verify() {
267 auto *parentOp = (*this)->getParentOp();
268 if (isa<func::FuncOp>(parentOp) || isa<AIE::CoreOp>(parentOp) ||
269 isa<AIE::MemOp>(parentOp) || isa<AIE::ShimDMAOp>(parentOp))
// Verifier for MulticastOp: every operation in the first block of the ports
// region must be a MultiDestOp or an AIE::EndOp.
278LogicalResult AIEX::MulticastOp::verify() {
279 Region &body = getPorts();
// Asserted (not diagnosed): the op has at least one region and the ports
// region contains a block.
280 assert(getOperation()->getNumRegions());
281 assert(!body.empty());
// The error is emitted on the offending nested op, not on the multicast op.
282 for (
auto &ops : body.front())
283 if (!isa<MultiDestOp,
AIE::EndOp>(ops))
284 return ops.emitOpError(
"cannot be contained in a Multicast op");
// Verifier for BroadcastPacketOp: every operation in the first block of the
// ports region must be a BPIDOp or an AIE::EndOp.
293LogicalResult AIEX::BroadcastPacketOp::verify() {
294 Region &body = getPorts();
// Asserted (not diagnosed): the op has at least one region and the ports
// region contains a block.
295 assert(getOperation()->getNumRegions());
296 assert(!body.empty());
// The error is emitted on the offending nested op, not on the broadcast op.
297 for (
auto &ops : body.front())
298 if (!isa<BPIDOp,
AIE::EndOp>(ops))
299 return ops.emitOpError(
"cannot be contained in a BroadcastPacket op");
// Computes a byte offset from the op's offsets and strides: the sum over all
// dimensions of offset * stride * element size in bytes.
// NOTE(review): the declaration of `offset` and the return statement are not
// visible in this listing.
310int64_t AIEX::NpuDmaMemcpyNdOp::getOffsetInBytes() {
// Offsets and strides are read in reversed order. `.value()` is called
// unconditionally, so every entry is required to fold to a constant integer.
311 llvm::SmallVector<int64_t, 4> offsets =
312 llvm::map_to_vector(llvm::reverse(getMixedOffsets()), [](OpFoldResult s) {
313 return getConstantIntValue(s).value();
315 llvm::SmallVector<int64_t, 4> strides =
316 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
317 return getConstantIntValue(s).value();
320 BaseMemRefType my_memref = getMemref().getType();
// R is the number of offset entries; strides[i] is indexed with the same
// range below.
321 size_t R = offsets.size();
322 size_t el_bit_width = my_memref.getElementTypeBitWidth();
// Only whole-byte element types are supported here (asserted).
323 assert(el_bit_width % 8 == 0 &&
324 "Expected Memref element bitwidth to be multiple of 8.");
// S is the element size in bytes.
325 size_t S = el_bit_width / 8;
326 for (
size_t i = 0; i < R; i++)
327 offset += offsets[i] * strides[i] * S;
// Returns true when, indexing the reversed size/stride lists, sizes 1 and 2
// are both 1, stride 0 is 1 and strides 1 and 2 are both 0. Size 0, size 3 and
// stride 3 are not examined.
335bool AIEX::NpuDmaMemcpyNdOp::isLinearTransferWithoutTransformation() {
// `.value()` is called unconditionally, so every size and stride is required
// to fold to a constant integer.
336 llvm::SmallVector<int64_t, 4> inputSizes =
337 llvm::map_to_vector(llvm::reverse(getMixedSizes()), [](OpFoldResult s) {
338 return getConstantIntValue(s).value();
340 llvm::SmallVector<int64_t, 4> inputStrides =
341 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
342 return getConstantIntValue(s).value();
344 return (inputSizes[1] == 1 && inputSizes[2] == 1 && inputStrides[0] == 1 &&
345 inputStrides[1] == 0 && inputStrides[2] == 0);
// File-local helper returning an optional error string for a requested burst
// length (presumably checkBurstLength, as called from the verifiers in this
// file; the line carrying the name is not visible in this listing).
// A requested length of 0 skips the lookup entirely.
351static std::optional<std::string>
353 uint32_t requestedBurstLength) {
354 if (requestedBurstLength != 0) {
// `bel` is a sequence of (uint32_t, uint32_t) pairs declared on a line not
// visible here - presumably the target model's shim burst encodings and
// lengths; confirm. The requested length is matched against pair.second.
356 auto pair = std::find_if(bel.begin(), bel.end(),
357 [=](
const std::pair<uint32_t, uint32_t> &p) {
358 return p.second == requestedBurstLength;
// No match: build an error message that lists every pair.second value,
// each preceded by a space.
361 if (pair == bel.end()) {
362 std::string errorMessage =
363 "Requested burst length is not supported by the target. "
364 "Supported burst lengths:";
// NOTE(review): what consumes the std::accumulate result is on a line not
// visible in this listing.
367 std::accumulate(bel.begin(), bel.end(), errorMessage,
368 [](
const std::string &a,
auto b) {
369 return a +
" " + std::to_string(b.second);
// Verifier for NpuDmaMemcpyNdOp. The checks visible in this listing cover the
// element width and minimum transfer size, constant strides/sizes/offsets,
// burst length, offset alignment, size/stride validation against the shim
// allocation, and packet header field ranges.
// NOTE(review): `targetModel`, `addressGranularity` and `allocGetter` are
// declared on lines not visible in this listing.
379LogicalResult AIEX::NpuDmaMemcpyNdOp::verify() {
380 BaseMemRefType buffer = getMemref().getType();
// An element may not be wider than one address granule.
384 if (buffer.getElementTypeBitWidth() > addressGranularity) {
385 return emitOpError(
"Maximum element bit width allowed is ")
386 << addressGranularity <<
"bits. ";
387 }
// A statically shaped buffer must hold at least one granule of data.
else if (buffer.hasStaticShape() &&
388 (buffer.getNumElements() * buffer.getElementTypeBitWidth()) <
389 addressGranularity) {
390 return emitOpError(
"Minimum data transfer size required is ")
391 << addressGranularity <<
"bits. ";
// Strides, sizes and offsets must all fold to constant integers. These checks
// precede the unconditional `.value()` calls below.
393 if (!llvm::all_of(getMixedStrides(), [](OpFoldResult s) {
394 return getConstantIntValue(s).has_value();
396 return emitOpError(
"Only constant strides currently supported.");
397 if (!llvm::all_of(getMixedSizes(), [](OpFoldResult s) {
398 return getConstantIntValue(s).has_value();
400 return emitOpError(
"Only constant sizes currently supported.");
401 if (!llvm::all_of(getMixedOffsets(), [](OpFoldResult s) {
402 return getConstantIntValue(s).has_value();
404 return emitOpError(
"Only constant offsets currently supported.");
// Sizes and strides are read in reversed order.
406 llvm::SmallVector<int64_t, 4> inputSizes =
407 llvm::map_to_vector(llvm::reverse(getMixedSizes()), [](OpFoldResult s) {
408 return getConstantIntValue(s).value();
410 llvm::SmallVector<int64_t, 4> inputStrides =
411 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
412 return getConstantIntValue(s).value();
// Four-entry output vectors; the line below is the tail of a call that
// receives them (presumably getHardwareStridesWraps - its start is not
// visible here).
414 llvm::SmallVector<int64_t, 4> hardwareSizes(4);
415 llvm::SmallVector<int64_t, 4> hardwareStrides(4);
417 hardwareSizes, hardwareStrides);
418 int64_t offset = getOffsetInBytes();
// Burst length: any message returned by the helper becomes the op error.
420 auto errorMessage = checkBurstLength(targetModel, getBurstLength());
421 if (errorMessage.has_value()) {
422 return emitOpError(errorMessage.value());
// The byte offset must be a multiple of 4.
430 if (offset % 4 != 0) {
431 return emitOpError(
"Offset must be 4-byte-aligned.");
// The size/stride validation runs only when an allocation is found for the
// op's metadata symbol in the enclosing device; the column comes from that
// allocation.
441 AIE::DeviceOp dev = getOperation()->getParentOfType<AIE::DeviceOp>();
442 if (
auto allocOp = allocGetter.
get(dev, getMetadata())) {
443 int col = allocOp->getCol();
444 bool skipTransformationChecks = isLinearTransferWithoutTransformation();
// Tail of a call wrapped in a failure test (presumably verifyStridesWraps;
// its start is not visible here).
446 inputStrides, hardwareSizes, hardwareStrides,
447 skipTransformationChecks))) {
// Optional packet header: the type must fit in 3 bits, the ID in 5 bits.
453 if (
auto packetInfo = getPacket()) {
454 if (packetInfo->getPktType() > 7)
455 return emitOpError(
"Packet type field can only hold 3 bits.");
456 if (packetInfo->getPktId() > 31)
457 return emitOpError(
"Packet ID field can only hold 5 bits.");
// Verifier for NpuDmaWaitOp: when the op is nested in an AIE::DeviceOp, the
// referenced symbol must resolve within that device.
467LogicalResult AIEX::NpuDmaWaitOp::verify() {
468 AIE::DeviceOp dev = (*this)->getParentOfType<AIE::DeviceOp>();
// The lookup is skipped entirely when there is no enclosing device.
471 if (dev && !dev.lookupSymbol(getSymbol()))
472 return emitOpError(
"couldn't find symbol in parent device");
// Verifier for NpuPushQueueOp: range-checks the BD ID against the number of
// BDs the target model reports for (column, row), and the repeat count
// against 255.
// NOTE(review): `targetModel` is declared on a line not visible here.
480LogicalResult AIEX::NpuPushQueueOp::verify() {
482 auto numBds = targetModel.
getNumBDs(getColumn(), getRow());
// NOTE(review): this accepts getBdId() == numBds. If BD IDs are zero-based
// (0 .. numBds-1), the comparison should probably be `>=` - confirm.
483 if (getBdId() > numBds)
484 return emitOpError(
"BD ID exceeds the maximum ID.");
485 if (getRepeatCount() > 255)
486 return emitOpError(
"Repeat count exceeds the [0:255] range.");
// Verifier for NpuWriteBdOp: range-checks the BD ID and the size/stride
// fields, applies shim-tile restrictions, and validates the burst length.
// NOTE(review): `targetModel` is declared on a line not visible here.
494LogicalResult AIEX::NpuWriteBdOp::verify() {
496 auto numBds = targetModel.
getNumBDs(getColumn(), getRow());
// "Linear" here means D0 size >= 1, D1 size == 1 and iteration size == 0.
// Such transfers are exempt from the D0 size limit below.
497 bool isLinearTransfer =
498 (getD0Size() >= 1) && (getD1Size() == 1) && (getIterationSize() == 0);
// NOTE(review): this accepts getBdId() == numBds. If BD IDs are zero-based
// (0 .. numBds-1), the comparison should probably be `>=` - confirm.
499 if (getBdId() > numBds)
500 return emitOpError(
"BD ID exceeds the maximum ID.");
// Field limits: sizes up to 0x3FF, strides up to 0xFFFFF, iteration size up
// to 0x3F.
501 if (!isLinearTransfer && getD0Size() > 0x3FF)
502 return emitOpError(
"D0 Size exceeds the [0:1023] range.");
503 if (getD0Stride() > 0xFFFFF)
504 return emitOpError(
"D0 Stride exceeds the [0:1M-1] range.");
505 if (getD1Size() > 0x3FF)
506 return emitOpError(
"D1 Size exceeds the [0:1023] range.");
507 if (getD1Stride() > 0xFFFFF)
508 return emitOpError(
"D1 Stride exceeds the [0:1M-1] range.");
509 if (getD2Stride() > 0xFFFFF)
510 return emitOpError(
"D2 Stride exceeds the [0:1M-1] range.");
511 if (getIterationSize() > 0x3F)
512 return emitOpError(
"Iteration Size exceeds the [0:63] range.");
513 if (getIterationStride() > 0xFFFFF)
514 return emitOpError(
"Iteration Stride exceeds the [0:1M-1] range.");
// On a Shim NOC tile the D2 size must be 0.
515 if (targetModel.
isShimNOCTile(getColumn(), getRow()) && getD2Size() != 0)
516 return emitOpError(
"ShimTile only supports 3 dimensions of sizes.");
// Zero-padding fields must all be 0 under a condition whose first half is on
// a line not visible here (presumably "tile is a Shim NOC tile"; confirm).
518 (getD0ZeroBefore() != 0 || getD0ZeroAfter() != 0 ||
519 getD1ZeroBefore() != 0 || getD1ZeroAfter() != 0 ||
520 getD2ZeroBefore() != 0 || getD2ZeroAfter() != 0))
521 return emitOpError(
"ShimTile doesn't support zero padding.");
// A non-zero burst length is rejected under a condition whose first half is
// on a line not visible here (presumably "tile is not a Shim NOC tile").
523 getBurstLength() != 0)
524 return emitOpError(
"Only ShimTiles support burst length.");
// Any message returned by the burst-length helper becomes the op error.
525 auto errorMessage = checkBurstLength(targetModel, getBurstLength());
526 if (errorMessage.has_value()) {
527 return emitOpError(errorMessage.value());
// Custom assembly parser for RuntimeSequenceOp: an optional symbol name, a
// parenthesized comma-separated argument list, then a region whose entry
// block takes those arguments.
537ParseResult AIEX::RuntimeSequenceOp::parse(OpAsmParser &parser,
538 OperationState &result) {
// The symbol name is optional; the parse result is deliberately discarded
// and, when present, the name is stored under the symbol-name attribute.
// NOTE(review): `nameAttr` is declared on a line not visible here.
541 (void)parser.parseOptionalSymbolName(
542 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
544 SmallVector<OpAsmParser::Argument> entryArgs;
// Each list element is parsed as one region argument and appended to
// entryArgs.
// NOTE(review): the two `true` flags are presumably parseArgument's
// allowType / allowAttrs parameters - confirm against the OpAsmParser API.
547 ParseResult argParseResult = parser.parseCommaSeparatedList(
548 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
549 OpAsmParser::Argument argument;
550 if (parser.parseArgument(argument,
true,
true)) {
553 entryArgs.push_back(argument);
// A failure in the argument list is propagated to the caller.
556 if (argParseResult) {
557 return argParseResult;
// The body region is added to the result state and parsed with entryArgs as
// its entry-block arguments; a failure is propagated.
561 auto *body = result.addRegion();
562 ParseResult bodyParseResult = parser.parseRegion(*body, entryArgs,
false);
563 if (bodyParseResult) {
564 return bodyParseResult;
// Custom assembly printer for RuntimeSequenceOp: prints the symbol name, each
// region argument, then the region itself.
// NOTE(review): the separators and any conditions around these calls are on
// lines not visible in this listing.
570void AIEX::RuntimeSequenceOp::print(OpAsmPrinter &printer) {
571 Region &body = getRegion();
// The name is read from the standard symbol-name attribute.
573 auto nameAttr = (*this)->getAttrOfType<StringAttr>(
574 mlir::SymbolTable::getSymbolAttrName());
577 printer.printSymbolName(nameAttr);
// One printRegionArgument call per argument of the body region.
581 for (
unsigned i = 0, n = body.getNumArguments(); i < n; i++) {
585 printer.printRegionArgument(body.getArgument(i));
// NOTE(review): the flags are presumably printEntryBlockArgs = false and
// printBlockTerminators = true - confirm against the OpAsmPrinter API.
590 printer.printRegion(body,
false,
true);
// Verifier for RuntimeSequenceOp: looks up the enclosing AIE::DeviceOp and
// emits an error stating the op must be inside a device operation.
// NOTE(review): the condition guarding the error and the return statements
// are on lines not visible here - presumably the error fires when no
// enclosing device is found; confirm.
593LogicalResult AIEX::RuntimeSequenceOp::verify() {
594 AIE::DeviceOp device = (*this)->getParentOfType<AIE::DeviceOp>();
597 (*this)->emitOpError() <<
"must be inside AIE device operation.";
// Returns the BD ID of the first AIE::DMABDOp found in the task body.
// NOTE(review): the statements inside the empty / no-ID branches are on lines
// not visible here - presumably they return std::nullopt; confirm.
607std::optional<uint32_t> AIEX::DMAConfigureTaskOp::getFirstBdId() {
608 Region &body = getBody();
// Look for BD ops in the first block of the body.
612 auto bd_ops = body.front().getOps<AIE::DMABDOp>();
// If that block has none and has exactly one successor, look in the
// successor block instead.
613 if (bd_ops.empty() && body.front().getNumSuccessors() == 1) {
617 Block &chain_entry = *body.front().getSuccessor(0);
618 bd_ops = chain_entry.getOps<AIE::DMABDOp>();
620 if (bd_ops.empty()) {
// Take the first BD op; its ID is optional and is tested before `.value()`.
623 AIE::DMABDOp bd = *bd_ops.begin();
624 if (!bd.getBdId().has_value()) {
627 return bd.getBdId().value();
// Canonicalization pattern for DMAConfigureTaskOp: walks the blocks of the
// body region and erases the terminator of any block that contains no other
// operations.
// NOTE(review): the return type, the declaration of `block`, any update to
// `did_rewrite` and the final return are on lines not visible here.
631AIEX::DMAConfigureTaskOp::canonicalize(AIEX::DMAConfigureTaskOp op,
632 PatternRewriter &rewriter) {
634 Region &body = op.getBody();
635 bool did_rewrite =
false;
636 for (
auto it = body.begin(); it != body.end(); ++it) {
// A block is treated as empty when its non-terminator range has zero length.
641 auto ops_it = block.without_terminator();
642 if (std::distance(ops_it.begin(), ops_it.end()) == 0) {
643 rewriter.eraseOp(block.getTerminator());
// Verifier for DMAConfigureTaskOp: checks every block of the body region.
// NOTE(review): the declarations of `block` and `targetModel`, the rest of
// the first error message and the return statements are on lines not visible
// in this listing.
653LogicalResult AIEX::DMAConfigureTaskOp::verify() {
654 Region &body = getBody();
655 for (
auto it = body.begin(); it != body.end(); ++it) {
// A non-entry block with no predecessors is reported; the error is attached
// to that block's terminator.
660 if (block.hasNoPredecessors() && !block.isEntryBlock()) {
661 auto error = block.getTerminator()->emitError(
662 "Block ending in this terminator does not form a chain with "
// Walk the BD ops in the block: a non-zero burst length is reported when the
// task's tile is not a Shim NOC tile. `result` is tested after the walk
// (presumably set to failure inside the lambda on a line not visible here).
674 LogicalResult result = success();
675 block.walk([&](AIE::DMABDOp bd) {
676 if (bd.getBurstLength() != 0 &&
677 !targetModel.
isShimNOCTile(getTileID().col, getTileID().row)) {
678 bd.emitOpError(
"Burst length is only supported in Shim NOC tiles that "
679 "are connected to the memory-mapped NOC.");
683 if (failed(result)) {
// Resolves this op's symbol to an AIE::BDChainOp by looking it up in the
// enclosing AIE::DeviceOp.
// NOTE(review): in the lines visible here, `device` is used without a null
// check, and the return statement is not visible.
694AIE::BDChainOp AIEX::DMAStartBdChainOp::getBDChainOp() {
695 AIE::DeviceOp device = (*this)->getParentOfType<AIE::DeviceOp>();
696 AIE::BDChainOp chain = device.lookupSymbol<AIE::BDChainOp>(getSymbol());
// Verifier for DMAStartBdChainOp: the symbol must resolve to a BD chain, and
// the op's arguments must match the chain region's arguments in count and
// type.
700LogicalResult AIEX::DMAStartBdChainOp::verify() {
701 AIE::BDChainOp chain = getBDChainOp();
// NOTE(review): the condition guarding this error is on a line not visible
// here - presumably a null check on `chain`; confirm.
703 return emitOpError(
"symbol does not reference valid BD chain");
706 auto actualArgTypes = getArgs().getTypes();
707 auto expectedArgTypes = chain.getRegion().getArgumentTypes();
// The argument count is checked first, so indexing both lists by the same i
// below is in range.
708 if (actualArgTypes.size() != expectedArgTypes.size()) {
709 return emitOpError(
"Number of arguments mismatches.");
// Types are compared pairwise; the message uses a 1-based argument position.
711 for (
unsigned i = 0, n = expectedArgTypes.size(); i < n; i++) {
712 if (actualArgTypes[i] != expectedArgTypes[i]) {
713 return emitOpError(
"Argument ") << (i + 1) <<
" types mismatch: "
714 <<
"expected " << expectedArgTypes[i]
715 <<
" but got " << actualArgTypes[i];
// Extracts a row index from the op's address: the address is shifted right by
// the target model's row shift and masked to 5 bits (0x1f).
// NOTE(review): the declaration of `targetModel` and the return statement are
// on lines not visible in this listing.
725uint32_t AIEX::NpuControlPacketOp::getRowFromAddr() {
727 uint32_t addr = getAddress();
728 uint32_t rowInt = (addr >> targetModel.
getRowShift()) & 0x1f;
// Column counterpart of the row extraction: reads the op's address.
// NOTE(review): the shift/mask computation and the return are on lines not
// visible here - presumably it uses the target model's column shift; confirm.
732uint32_t AIEX::NpuControlPacketOp::getColumnFromAddr() {
734 uint32_t addr = getAddress();
// Verifier for SetLockOp. The errors visible in this listing cover an
// unsupported architecture (AIE1), a lock value above the maximum, a lock ID
// out of range for the tile, and a failed local lock address lookup.
// NOTE(review): the conditions guarding each error, and the declaration of
// `targetModel`, are on lines not visible here.
743LogicalResult AIEX::SetLockOp::verify() {
747 return emitOpError(
"SetLockOp is not supported on AIE1.");
750 return emitOpError(
"Lock value exceeds the maximum value of " +
// The referenced lock op supplies the lock ID and the tile coordinates.
753 auto lockOp = getLockOp();
754 auto lockIDOpt = getLockOp().getLockID();
761 auto col = lockOp.colIndex();
762 auto row = lockOp.rowIndex();
763 uint32_t lockID = lockOp.getLockIDValue();
// The reported maximum ID is the tile's lock count minus one.
766 return emitOpError(
"Lock ID out of range for given tile. Max ID: " +
767 std::to_string(targetModel.
getNumLocks(col, row) - 1));
771 return emitOpError(
"Invalid lock ID and tile combination when trying to "
772 "retrieve the local lock address.");
virtual AIEArch getTargetArch() const =0
Return the target architecture.
virtual uint32_t getNumBDs(int col, int row) const =0
Return the number of buffer descriptors supported by the DMA in the given tile.
virtual std::vector< std::pair< uint32_t, uint32_t > > getShimBurstEncodingsAndLengths() const =0
virtual std::optional< uint32_t > getLocalLockAddress(uint32_t lockId, TileID tile) const =0
virtual uint32_t getNumLocks(int col, int row) const =0
Return the number of lock objects.
virtual bool isShimNOCTile(int col, int row) const =0
Return true if the given tile is a Shim NOC tile.
virtual uint32_t getMaxLockValue() const =0
Return the maximum value that can be stored in a lock register.
virtual uint32_t getColumnShift() const =0
virtual uint32_t getRowShift() const =0
virtual uint32_t getAddressGenGranularity() const =0
Return the data bus width of the device.
void getHardwareStridesWraps(const AIE::AIETargetModel &targetModel, mlir::BaseMemRefType referencedBufType, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > &sizes, llvm::SmallVector< int64_t, 4 > &strides)
uint64_t getBufferDescriptorAddressRegisterAddress(const AIE::AIETargetModel &tm, unsigned bd_id, unsigned col, unsigned row)
mlir::LogicalResult verifyStridesWraps(mlir::Operation *forOp, mlir::BaseMemRefType referencedBufType, int tileCol, int tileRow, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > hardwareSizes, llvm::SmallVector< int64_t, 4 > hardwareStrides, bool skipTransformationChecks=false)
const AIETargetModel & getTargetModel(mlir::Operation *op)
std::optional< AIE::ShimDMAAllocationOp > get(DeviceOp dev, mlir::StringRef sym_name)