13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/Operation.h"
16#include "mlir/IR/TypeUtilities.h"
17#include "mlir/Interfaces/DataLayoutInterfaces.h"
18#include "mlir/Interfaces/FoldInterfaces.h"
19#include "mlir/Transforms/InliningUtils.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/TypeSize.h"
30#include "aie/Dialect/AIEX/IR/AIEXDialect.cpp.inc"
32#define GET_TYPEDEF_CLASSES
33#include "aie/Dialect/AIEX/IR/AIEXTypes.cpp.inc"
// Dialect initialization hook. The visible body textually includes the
// generated op file and, with GET_TYPEDEF_LIST defined, the generated type
// file.
// NOTE(review): the statements that consume these included lists are not
// visible here — presumably op/type registration calls; confirm.
38void AIEXDialect::initialize() {
41#include "aie/Dialect/AIEX/IR/AIEX.cpp.inc"
// Select the type-list section of the generated types file before including it.
44#define GET_TYPEDEF_LIST
45#include "aie/Dialect/AIEX/IR/AIEXTypes.cpp.inc"
52#include "aie/Dialect/AIEX/IR/AIEX.cpp.inc"
// Remaining parameters and body of the size/stride conversion helper.
// Translates the user-facing `inputSizes`/`inputStrides` into the values
// written to the caller-provided `sizes`/`strides` vectors: dimension-0 sizes
// and all strides are rescaled by elemWidth / addressGranularity, and strides
// are stored minus one.
// NOTE(review): `elemWidth` and `addressGranularity` are defined on lines that
// are not visible here; the getTypeSizeInBits() call below is presumably the
// initializer of `elemWidth` — confirm.
89 mlir::BaseMemRefType referencedBufType,
90 llvm::SmallVector<int64_t, 4> inputSizes,
91 llvm::SmallVector<int64_t, 4> inputStrides,
92 llvm::SmallVector<int64_t, 4> &sizes,
93 llvm::SmallVector<int64_t, 4> &strides) {
// Preconditions: equal-length inputs, and both outputs already sized to 4.
94 assert(inputSizes.size() == inputStrides.size());
95 assert(sizes.size() == 4);
96 assert(strides.size() == 4);
// Query the element type's bit width through the data layout closest to `op`.
98 DataLayout dataLayout = DataLayout::closest(op);
100 dataLayout.getTypeSizeInBits(referencedBufType.getElementType());
// Start from all-zero outputs; only entries assigned below become non-zero.
104 std::fill(sizes.begin(), sizes.end(), 0);
105 std::fill(strides.begin(), strides.end(), 0);
// Special case for an innermost size of zero (its handling is not visible here).
107 if (inputSizes[0] == 0) {
// Dimension 0: size converted from elements to address-granularity units.
114 sizes[0] = inputSizes[0] * elemWidth / addressGranularity;
// Branch taken when the dimension-0 stride spans less than one granule, or a
// single element is wider than one granule (branch body not visible here).
115 if (inputStrides[0] * elemWidth < addressGranularity ||
116 (elemWidth > addressGranularity)) {
// Stride in granularity units, stored minus one.
137 strides[0] = inputStrides[0] * elemWidth / addressGranularity - 1;
// Dimension 1: size is copied as-is; stride is only computed when the
// dimension actually repeats (size > 1).
141 sizes[1] = inputSizes[1];
142 if (inputSizes[1] > 1) {
144 strides[1] = inputStrides[1] * elemWidth / addressGranularity - 1;
// Dimension 2: same treatment as dimension 1.
148 sizes[2] = inputSizes[2];
149 if (inputSizes[2] > 1) {
151 strides[2] = inputStrides[2] * elemWidth / addressGranularity - 1;
// Dimension 3: unlike dimensions 1 and 2, the size itself is stored minus one,
// and only when it is greater than 1 (otherwise it stays 0 from the fill).
155 if (inputSizes[3] > 1) {
157 sizes[3] = inputSizes[3] - 1;
// Dimension 3 stride is only encoded when it is strictly positive.
164 if (inputStrides[3] > 0) {
165 strides[3] = inputStrides[3] * elemWidth / addressGranularity - 1;
// Remaining parameters and body of the size/stride verifier.
// Checks the user-facing `inputSizes`/`inputStrides` for sign and alignment,
// then checks the hardware-encoded `hardwareSizes`/`hardwareStrides` against
// per-tile bit-width limits. The first violation is reported with
// forOp->emitOpError() and returned.
// NOTE(review): `targetModel` and `elemWidth` are defined on lines that are
// not visible here.
172 mlir::BaseMemRefType referencedBufType,
int tileCol,
173 int tileRow, llvm::SmallVector<int64_t, 4> inputSizes,
174 llvm::SmallVector<int64_t, 4> inputStrides,
175 llvm::SmallVector<int64_t, 4> hardwareSizes,
176 llvm::SmallVector<int64_t, 4> hardwareStrides,
177 bool skipTransformationChecks) {
179 auto addressGranularity = targetModel.getAddressGenGranularity();
180 DataLayout dataLayout = DataLayout::closest(forOp);
182 dataLayout.getTypeSizeInBits(referencedBufType.getElementType());
// Field widths used for the range checks at the end. wrap_bits and step_bits
// start at 0 and iter_bits at 6; the per-tile-type branches below presumably
// overwrite them (branch bodies are not visible here — confirm).
184 uint32_t wrap_bits = 0;
185 uint32_t step_bits = 0;
186 uint32_t iter_bits = 6;
187 if (targetModel.isShimNOCTile(tileCol, tileRow)) {
190 }
else if (targetModel.isMemTile(tileCol, tileRow)) {
193 }
else if (targetModel.isCoreTile(tileCol, tileRow)) {
// Error reported for a tile that is none of ShimNOC, Mem or Core.
197 return forOp->emitOpError(
198 "Unsupported tile type at (" + std::to_string(tileCol) +
", " +
199 std::to_string(tileRow) +
") Must be ShimNOC, Mem or Core.");
// All four user-facing sizes must be strictly positive.
202 for (
int i = 0; i < 4; i++) {
203 if (inputSizes[i] <= 0) {
204 return forOp->emitOpError(
"Size ") << i <<
" must be a positive integer.";
// The innermost transfer length in bits must be a whole number of granules.
// The message reports the quantities in bytes (bits / 8).
208 if (inputSizes[0] * elemWidth % addressGranularity != 0) {
209 std::stringstream msg;
210 msg <<
"Transfer sizes must be multiples of " << (addressGranularity / 8)
211 <<
" bytes. " << inputSizes[0] <<
" elements at " << (elemWidth / 8)
212 <<
" bytes each equal " << (inputSizes[0] * elemWidth / 8)
213 <<
" bytes, which is not divisible by " << (addressGranularity / 8)
215 return forOp->emitOpError(msg.str());
// Dimensions 0..2: a dimension that repeats (size > 1) needs a stride >= 1.
218 for (
int i = 0; i < 3; i++) {
219 if (inputSizes[i] > 1 && inputStrides[i] < 1) {
223 return forOp->emitOpError(
"Stride ")
224 << i <<
" must be a positive integer.";
// Dimension 3 is more lenient: a zero stride is accepted, a negative one is not.
229 if (inputSizes[3] > 1 && inputStrides[3] < 0) {
230 return forOp->emitOpError(
"Stride 3 must be a non-negative integer.");
// Every stride, expressed in bits, must be a whole number of granules.
233 for (
int i = 0; i < 4; i++) {
// A unit stride in dimension 0 is singled out here (the statement it guards is
// not visible — presumably it exempts this case from the check; confirm).
236 if (i == 0 && inputStrides[i] == 1)
238 if (inputStrides[i] * elemWidth % addressGranularity != 0) {
239 std::stringstream msg;
240 msg <<
"Stride " << i <<
" is " << inputStrides[i] <<
" elements * "
241 << (elemWidth / 8) <<
" bytes = " << (inputStrides[i] * elemWidth / 8)
242 <<
" bytes, which is not divisible by " << (addressGranularity / 8)
244 return forOp->emitOpError(msg.str());
// Range checks on the hardware-encoded values. The size-0 check is the only
// one that `skipTransformationChecks` disables.
248 if (!skipTransformationChecks && hardwareSizes[0] > (1 << wrap_bits) - 1)
249 return forOp->emitOpError(
250 "Size 0 exceeds the [0:" + std::to_string((1 << wrap_bits) - 1) +
252 if (hardwareSizes[1] > (1 << wrap_bits) - 1)
253 return forOp->emitOpError(
254 "Size 1 exceeds the [0:" + std::to_string((1 << wrap_bits) - 1) +
256 if (hardwareSizes[3] > (1 << iter_bits))
257 return forOp->emitOpError(
258 "Size 3 exceeds the [1:" + std::to_string(1 << iter_bits) +
"] range.");
// Strides are compared against 1 << step_bits (inclusive upper bound).
259 if (hardwareStrides[0] > (1 << step_bits))
260 return forOp->emitOpError(
"Stride 0 exceeds the [1:" +
261 std::to_string(1 << step_bits) +
"] range.");
262 if (hardwareStrides[1] > (1 << step_bits))
263 return forOp->emitOpError(
"Stride 1 exceeds the [1:" +
264 std::to_string(1 << step_bits) +
"] range.");
265 if (hardwareStrides[2] > (1 << step_bits))
266 return forOp->emitOpError(
"Stride 2 exceeds the [1:" +
267 std::to_string(1 << step_bits) +
"] range.");
// The dimension-3 stride is only range-checked when hardwareSizes[3] > 0.
270 if (hardwareStrides[3] > (1 << step_bits) && hardwareSizes[3] > 0)
271 return forOp->emitOpError(
"Stride 3 exceeds the [1:" +
272 std::to_string(1 << step_bits) +
"] range.");
// Verifier for UseTokenOp. Looks at the op's immediate parent and tests
// whether it is one of func::FuncOp, AIE::CoreOp, AIE::MemOp or
// AIE::ShimDMAOp. The result returned for each outcome is not visible here.
281LogicalResult AIEX::UseTokenOp::verify() {
282 auto *parentOp = (*this)->getParentOp();
283 if (isa<func::FuncOp>(parentOp) || isa<AIE::CoreOp>(parentOp) ||
284 isa<AIE::MemOp>(parentOp) || isa<AIE::ShimDMAOp>(parentOp))
// Verifier for MulticastOp. Every operation in the first block of the ports
// region must be a MultiDestOp or an AIE::EndOp; the first offender gets the
// error attached to it.
293LogicalResult AIEX::MulticastOp::verify() {
294 Region &body = getPorts();
// Structural invariants: the op has at least one region and it is non-empty.
295 assert(getOperation()->getNumRegions());
296 assert(!body.empty());
297 for (
auto &ops : body.front())
298 if (!isa<MultiDestOp,
AIE::EndOp>(ops))
299 return ops.emitOpError(
"cannot be contained in a Multicast op");
// Verifier for BroadcastPacketOp. Every operation in the first block of the
// ports region must be a BPIDOp or an AIE::EndOp; the first offender gets the
// error attached to it.
308LogicalResult AIEX::BroadcastPacketOp::verify() {
309 Region &body = getPorts();
// Structural invariants: the op has at least one region and it is non-empty.
310 assert(getOperation()->getNumRegions());
311 assert(!body.empty());
312 for (
auto &ops : body.front())
313 if (!isa<BPIDOp,
AIE::EndOp>(ops))
314 return ops.emitOpError(
"cannot be contained in a BroadcastPacket op");
// Computes the byte offset described by the op's offsets and strides: the sum
// over all dimensions of offset[i] * stride[i] * (element size in bytes).
// Both lists are read in reversed order and every entry is unwrapped with
// .value(), so this assumes all offsets and strides are constants.
// NOTE(review): the declaration of `offset` and the return statement are not
// visible here.
325int64_t AIEX::NpuDmaMemcpyNdOp::getOffsetInBytes() {
326 llvm::SmallVector<int64_t, 4> offsets =
327 llvm::map_to_vector(llvm::reverse(getMixedOffsets()), [](OpFoldResult s) {
328 return getConstantIntValue(s).value();
330 llvm::SmallVector<int64_t, 4> strides =
331 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
332 return getConstantIntValue(s).value();
// R = number of dimensions; S = element size in bytes.
335 size_t R = offsets.size();
336 size_t el_bit_width = getElementTypeBitwidth();
337 assert(el_bit_width % 8 == 0 &&
338 "Expected Memref element bitwidth to be multiple of 8.");
339 size_t S = el_bit_width / 8;
340 for (
size_t i = 0; i < R; i++)
341 offset += offsets[i] * strides[i] * S;
// Returns true when, in reversed (innermost-first) order, sizes 1 and 2 are
// both 1, stride 0 is 1, and strides 1 and 2 are both 0. Size 0, size 3 and
// stride 3 are not inspected. Entries are unwrapped with .value(), so this
// assumes all sizes and strides are constants.
349bool AIEX::NpuDmaMemcpyNdOp::isLinearTransferWithoutTransformation() {
350 llvm::SmallVector<int64_t, 4> inputSizes =
351 llvm::map_to_vector(llvm::reverse(getMixedSizes()), [](OpFoldResult s) {
352 return getConstantIntValue(s).value();
354 llvm::SmallVector<int64_t, 4> inputStrides =
355 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
356 return getConstantIntValue(s).value();
358 return (inputSizes[1] == 1 && inputSizes[2] == 1 && inputStrides[0] == 1 &&
359 inputStrides[1] == 0 && inputStrides[2] == 0);
// File-local helper that validates a requested burst length and yields an
// optional error string. A requested length of 0 skips the lookup entirely.
// NOTE(review): the function-name line and the definition of `bel` are not
// visible here; `bel` is iterated as (uint32_t, uint32_t) pairs whose `.second`
// is compared with the request — presumably the target's supported
// encoding/length pairs; confirm.
365static std::optional<std::string>
367 uint32_t requestedBurstLength) {
368 if (requestedBurstLength != 0) {
// Look for an entry whose length (pair.second) equals the request.
370 auto pair = std::find_if(bel.begin(), bel.end(),
371 [=](
const std::pair<uint32_t, uint32_t> &p) {
372 return p.second == requestedBurstLength;
// No match: build a message that lists every supported length, space-separated.
375 if (pair == bel.end()) {
376 std::string errorMessage =
377 "Requested burst length is not supported by the target. "
378 "Supported burst lengths:";
381 std::accumulate(bel.begin(), bel.end(), errorMessage,
382 [](
const std::string &a,
auto b) {
383 return a +
" " + std::to_string(b.second);
// Verifier for NpuDmaMemcpyNdOp. Visible checks, in order: element width and
// minimum transfer size against the address granularity, constant-only
// strides/sizes/offsets, burst length, 4-byte offset alignment, size/stride
// verification for the tile of the referenced allocation, and packet header
// field ranges.
// NOTE(review): `targetModel`, `addressGranularity` and `allocGetter` are
// defined on lines that are not visible here.
393LogicalResult AIEX::NpuDmaMemcpyNdOp::verify() {
394 BaseMemRefType buffer = getMemref().getType();
// An element may not be wider than one address granule.
398 if (getElementTypeBitwidth() > addressGranularity) {
399 return emitOpError(
"Maximum element bit width allowed is ")
400 << addressGranularity <<
"bits. ";
// For statically shaped buffers, the whole buffer must hold at least one granule.
402 if (buffer.hasStaticShape() &&
403 (buffer.getNumElements() * getElementTypeBitwidth()) <
404 addressGranularity) {
405 return emitOpError(
"Minimum data transfer size required is ")
406 << addressGranularity <<
"bits. ";
// Strides, sizes and offsets must all fold to constants; the unconditional
// .value() calls further down rely on this.
408 if (!llvm::all_of(getMixedStrides(), [](OpFoldResult s) {
409 return getConstantIntValue(s).has_value();
411 return emitOpError(
"Only constant strides currently supported.");
412 if (!llvm::all_of(getMixedSizes(), [](OpFoldResult s) {
413 return getConstantIntValue(s).has_value();
415 return emitOpError(
"Only constant sizes currently supported.");
416 if (!llvm::all_of(getMixedOffsets(), [](OpFoldResult s) {
417 return getConstantIntValue(s).has_value();
419 return emitOpError(
"Only constant offsets currently supported.");
// Collect sizes/strides in reversed (innermost-first) order.
421 llvm::SmallVector<int64_t, 4> inputSizes =
422 llvm::map_to_vector(llvm::reverse(getMixedSizes()), [](OpFoldResult s) {
423 return getConstantIntValue(s).value();
425 llvm::SmallVector<int64_t, 4> inputStrides =
426 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
427 return getConstantIntValue(s).value();
// Output vectors pre-sized to 4 entries for the hardware-encoded values.
429 llvm::SmallVector<int64_t, 4> hardwareSizes(4);
430 llvm::SmallVector<int64_t, 4> hardwareStrides(4);
432 inputStrides, hardwareSizes, hardwareStrides);
433 int64_t offset = getOffsetInBytes();
// Burst length: the helper yields an error string when the request is rejected.
435 auto errorMessage = checkBurstLength(targetModel, getBurstLength());
436 if (errorMessage.has_value()) {
437 return emitOpError(errorMessage.value());
445 if (offset % 4 != 0) {
446 return emitOpError(
"Offset must be 4-byte-aligned.");
// Size/stride verification runs only when the metadata symbol resolves to an
// allocation inside the enclosing device; that allocation supplies the column.
456 AIE::DeviceOp dev = getOperation()->getParentOfType<AIE::DeviceOp>();
457 if (
auto allocOp = allocGetter.
get(dev, getMetadata())) {
458 int col = allocOp->getCol();
459 bool skipTransformationChecks = isLinearTransferWithoutTransformation();
461 inputStrides, hardwareSizes, hardwareStrides,
462 skipTransformationChecks))) {
// Optional packet header: type is limited to 3 bits (0..7), ID to 5 bits (0..31).
468 if (
auto packetInfo = getPacket()) {
469 if (packetInfo->getPktType() > 7)
470 return emitOpError(
"Packet type field can only hold 3 bits.");
471 if (packetInfo->getPktId() > 31)
472 return emitOpError(
"Packet ID field can only hold 5 bits.");
// Verifier for NpuDmaWaitOp. When the op is nested in an AIE::DeviceOp, the
// referenced symbol must resolve inside that device. If there is no enclosing
// device, the lookup is skipped (`dev &&` short-circuits).
482LogicalResult AIEX::NpuDmaWaitOp::verify() {
483 AIE::DeviceOp dev = (*this)->getParentOfType<AIE::DeviceOp>();
486 if (dev && !dev.lookupSymbol(getSymbol()))
487 return emitOpError(
"couldn't find symbol in parent device");
// Verifier for NpuPushQueueOp. Rejects a BD ID outside the tile's BD range
// and a repeat count above 255.
// NOTE(review): `targetModel` is defined on a line that is not visible here.
495LogicalResult AIEX::NpuPushQueueOp::verify() {
// numBds is the number of buffer descriptors of the DMA in the tile at
// (column, row).
497 auto numBds = targetModel.
getNumBDs(getColumn(), getRow());
// BD IDs are taken to be zero-based, so the largest valid ID is numBds - 1.
// Use >= so that an ID equal to the BD count is rejected as well (the former
// `>` comparison let it through).
498 if (getBdId() >= numBds)
499 return emitOpError(
"BD ID exceeds the maximum ID.");
500 if (getRepeatCount() > 255)
501 return emitOpError(
"Repeat count exceeds the [0:255] range.");
// Verifier for NpuWriteBdOp. Range-checks the BD ID, the packet header fields
// and every size/stride field against its register width, then applies the
// tile-type-specific restrictions and the burst-length check.
// NOTE(review): `targetModel` is defined on a line that is not visible here,
// and the first line of two of the conditions below is also not visible.
509LogicalResult AIEX::NpuWriteBdOp::verify() {
// numBds is the number of buffer descriptors of the DMA in the tile at
// (column, row).
511 auto numBds = targetModel.
getNumBDs(getColumn(), getRow());
// A transfer counts as linear when D0 size >= 1, D1 size == 1 and the
// iteration size is 0; only the D0 size limit is relaxed for it.
512 bool isLinearTransfer =
513 (getD0Size() >= 1) && (getD1Size() == 1) && (getIterationSize() == 0);
// BD IDs are taken to be zero-based, so the largest valid ID is numBds - 1.
// Use >= so that an ID equal to the BD count is rejected as well (the former
// `>` comparison let it through).
514 if (getBdId() >= numBds)
515 return emitOpError(
"BD ID exceeds the maximum ID.");
516 if (getPacketId() > 31)
517 return emitOpError(
"Packet ID exceeds the maximum supported by 5 bits.");
518 if (getPacketType() > 7)
519 return emitOpError(
"Packet Type exceeds the maximum supported by 3 bits.");
// Sizes are limited to 10 bits (0x3FF), strides to 20 bits (0xFFFFF), and the
// iteration size to 6 bits (0x3F).
520 if (!isLinearTransfer && getD0Size() > 0x3FF)
521 return emitOpError(
"D0 Size exceeds the [0:1023] range.");
522 if (getD0Stride() > 0xFFFFF)
523 return emitOpError(
"D0 Stride exceeds the [0:1M-1] range.");
524 if (getD1Size() > 0x3FF)
525 return emitOpError(
"D1 Size exceeds the [0:1023] range.");
526 if (getD1Stride() > 0xFFFFF)
527 return emitOpError(
"D1 Stride exceeds the [0:1M-1] range.");
528 if (getD2Stride() > 0xFFFFF)
529 return emitOpError(
"D2 Stride exceeds the [0:1M-1] range.");
530 if (getIterationSize() > 0x3F)
531 return emitOpError(
"Iteration Size exceeds the [0:63] range.");
532 if (getIterationStride() > 0xFFFFF)
533 return emitOpError(
"Iteration Stride exceeds the [0:1M-1] range.");
// Shim NOC tiles must leave the D2 size at 0.
534 if (targetModel.
isShimNOCTile(getColumn(), getRow()) && getD2Size() != 0)
535 return emitOpError(
"ShimTile only supports 3 dimensions of sizes.");
// Zero-padding fields must all be 0 (the leading part of this condition is not
// visible here — presumably restricted to shim tiles, per the message).
537 (getD0ZeroBefore() != 0 || getD0ZeroAfter() != 0 ||
538 getD1ZeroBefore() != 0 || getD1ZeroAfter() != 0 ||
539 getD2ZeroBefore() != 0 || getD2ZeroAfter() != 0))
540 return emitOpError(
"ShimTile doesn't support zero padding.");
// A non-zero burst length is rejected under a condition whose first line is
// not visible here — presumably "not a shim tile", per the message.
542 getBurstLength() != 0)
543 return emitOpError(
"Only ShimTiles support burst length.");
// Finally, the requested burst length must be one the target supports.
544 auto errorMessage = checkBurstLength(targetModel, getBurstLength());
545 if (errorMessage.has_value()) {
546 return emitOpError(errorMessage.value());
// Custom parser for RuntimeSequenceOp: an optional symbol name, a
// parenthesized comma-separated argument list, then the body region, which
// receives those arguments as its entry arguments.
// NOTE(review): the declaration of `nameAttr` is not visible here.
556ParseResult AIEX::RuntimeSequenceOp::parse(OpAsmParser &parser,
557 OperationState &result) {
// The name is optional, so the parse result is deliberately discarded.
560 (void)parser.parseOptionalSymbolName(
561 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
563 SmallVector<OpAsmParser::Argument> entryArgs;
// Parse "(arg, arg, ...)" and collect each argument. The two `true` flags are
// presumably allowType / allowAttrs — confirm against OpAsmParser.
566 ParseResult argParseResult = parser.parseCommaSeparatedList(
567 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
568 OpAsmParser::Argument argument;
569 if (parser.parseArgument(argument,
true,
true)) {
572 entryArgs.push_back(argument);
// Propagate an argument-list parse failure to the caller.
575 if (argParseResult) {
576 return argParseResult;
// Add the body region to the op state and parse it with the collected arguments.
580 auto *body = result.addRegion();
581 ParseResult bodyParseResult = parser.parseRegion(*body, entryArgs,
false);
582 if (bodyParseResult) {
583 return bodyParseResult;
// Custom printer for RuntimeSequenceOp: emits the symbol name, each region
// argument, and then the region itself.
// NOTE(review): the condition guarding the name and the separators printed
// between arguments are on lines that are not visible here.
589void AIEX::RuntimeSequenceOp::print(OpAsmPrinter &printer) {
590 Region &body = getRegion();
// The name is stored under the standard symbol-name attribute.
592 auto nameAttr = (*this)->getAttrOfType<StringAttr>(
593 mlir::SymbolTable::getSymbolAttrName());
596 printer.printSymbolName(nameAttr);
600 for (
unsigned i = 0, n = body.getNumArguments(); i < n; i++) {
604 printer.printRegionArgument(body.getArgument(i));
// The two flags are presumably printEntryBlockArgs=false (arguments were
// already printed above) and printBlockTerminators=true — confirm.
609 printer.printRegion(body,
false,
true);
// Verifier for RuntimeSequenceOp. Looks up the enclosing AIE::DeviceOp and
// emits an error about the op needing to be inside a device (the condition
// guarding the error is not visible here — presumably a null `device`).
612LogicalResult AIEX::RuntimeSequenceOp::verify() {
613 AIE::DeviceOp device = (*this)->getParentOfType<AIE::DeviceOp>();
616 (*this)->emitOpError() <<
"must be inside AIE device operation.";
// Returns the BD ID of the first AIE::DMABDOp of this task. The entry block
// is searched first; if it has no BD ops and exactly one successor, that
// successor block is searched instead.
// NOTE(review): what is returned when no BD op exists, or when the BD op has
// no ID, is not visible here — presumably an empty optional; confirm.
626std::optional<uint32_t> AIEX::DMAConfigureTaskOp::getFirstBdId() {
627 Region &body = getBody();
631 auto bd_ops = body.front().getOps<AIE::DMABDOp>();
632 if (bd_ops.empty() && body.front().getNumSuccessors() == 1) {
// Fall back to the single successor of the entry block.
636 Block &chain_entry = *body.front().getSuccessor(0);
637 bd_ops = chain_entry.getOps<AIE::DMABDOp>();
639 if (bd_ops.empty()) {
642 AIE::DMABDOp bd = *bd_ops.begin();
643 if (!bd.getBdId().has_value()) {
646 return bd.getBdId().value();
// Canonicalization for DMAConfigureTaskOp. Walks the blocks of the body and,
// for a block containing nothing but its terminator, erases that terminator
// through the rewriter.
// NOTE(review): the return-type line, the definition of `block`, and the use
// of `did_rewrite` are not visible here.
650AIEX::DMAConfigureTaskOp::canonicalize(AIEX::DMAConfigureTaskOp op,
651 PatternRewriter &rewriter) {
653 Region &body = op.getBody();
654 bool did_rewrite =
false;
655 for (
auto it = body.begin(); it != body.end(); ++it) {
// Zero non-terminator operations means the block holds only its terminator.
660 auto ops_it = block.without_terminator();
661 if (std::distance(ops_it.begin(), ops_it.end()) == 0) {
662 rewriter.eraseOp(block.getTerminator());
// Verifier for DMAConfigureTaskOp. For every block of the body it (1) flags
// non-entry blocks that have no predecessors, and (2) flags any AIE::DMABDOp
// with a non-zero burst length when this op's tile is not a Shim NOC tile.
// NOTE(review): the definitions of `block` and `targetModel`, and how `result`
// is updated inside the walk, are not visible here.
672LogicalResult AIEX::DMAConfigureTaskOp::verify() {
673 Region &body = getBody();
674 for (
auto it = body.begin(); it != body.end(); ++it) {
// A block other than the entry block that nothing branches to.
679 if (block.hasNoPredecessors() && !block.isEntryBlock()) {
680 auto error = block.getTerminator()->emitError(
681 "Block ending in this terminator does not form a chain with "
// Check every BD op nested in this block.
693 LogicalResult result = success();
694 block.walk([&](AIE::DMABDOp bd) {
695 if (bd.getBurstLength() != 0 &&
696 !targetModel.
isShimNOCTile(getTileID().col, getTileID().row)) {
697 bd.emitOpError(
"Burst length is only supported in Shim NOC tiles that "
698 "are connected to the memory-mapped NOC.");
702 if (failed(result)) {
// Resolves this op's symbol to an AIE::BDChainOp by looking it up in the
// enclosing AIE::DeviceOp.
// NOTE(review): `device` is used without a visible null check; if this op can
// ever appear outside a device, lookupSymbol would be called on a null op —
// confirm that a parent constraint rules this out.
713AIE::BDChainOp AIEX::DMAStartBdChainOp::getBDChainOp() {
714 AIE::DeviceOp device = (*this)->getParentOfType<AIE::DeviceOp>();
715 AIE::BDChainOp chain = device.lookupSymbol<AIE::BDChainOp>(getSymbol());
// Verifier for DMAStartBdChainOp. The symbol must reference a BD chain, and
// the op's arguments must match the chain region's arguments in both count
// and type, position by position.
// NOTE(review): the condition guarding the first error is not visible here —
// presumably a null `chain`; confirm.
719LogicalResult AIEX::DMAStartBdChainOp::verify() {
720 AIE::BDChainOp chain = getBDChainOp();
722 return emitOpError(
"symbol does not reference valid BD chain");
// Compare the call-site argument types with the chain's region argument types.
725 auto actualArgTypes = getArgs().getTypes();
726 auto expectedArgTypes = chain.getRegion().getArgumentTypes();
727 if (actualArgTypes.size() != expectedArgTypes.size()) {
728 return emitOpError(
"Number of arguments mismatches.");
// The argument position is reported 1-based in the message.
730 for (
unsigned i = 0, n = expectedArgTypes.size(); i < n; i++) {
731 if (actualArgTypes[i] != expectedArgTypes[i]) {
732 return emitOpError(
"Argument ") << (i + 1) <<
" types mismatch: "
733 <<
"expected " << expectedArgTypes[i]
734 <<
" but got " << actualArgTypes[i];
// Extracts the row field from this op's address: shift right by the target
// model's row shift, then keep the low 5 bits (mask 0x1f).
// NOTE(review): the definition of `targetModel` and the return statement are
// not visible here.
744uint32_t AIEX::NpuControlPacketOp::getRowFromAddr() {
746 uint32_t addr = getAddress();
747 uint32_t rowInt = (addr >> targetModel.
getRowShift()) & 0x1f;
// Column counterpart of getRowFromAddr(): starts from this op's address.
// NOTE(review): the extraction itself is not visible here — presumably a shift
// by the target model's column shift followed by a mask; confirm.
751uint32_t AIEX::NpuControlPacketOp::getColumnFromAddr() {
753 uint32_t addr = getAddress();
// Verifier for SetLockOp. The visible error paths cover: use on AIE1, a lock
// value above the allowed maximum, a lock ID out of range for the lock's tile,
// and a lock ID / tile pair for which no local lock address can be retrieved.
// NOTE(review): the conditions guarding each error and the definition of
// `targetModel` are on lines that are not visible here.
762LogicalResult AIEX::SetLockOp::verify() {
766 return emitOpError(
"SetLockOp is not supported on AIE1.");
769 return emitOpError(
"Lock value exceeds the maximum value of " +
// The referenced lock op supplies the lock ID and the tile coordinates.
772 auto lockOp = getLockOp();
773 auto lockIDOpt = getLockOp().getLockID();
780 auto col = lockOp.colIndex();
781 auto row = lockOp.rowIndex();
782 uint32_t lockID = lockOp.getLockIDValue();
// The reported maximum ID is (number of locks in the tile) - 1.
785 return emitOpError(
"Lock ID out of range for given tile. Max ID: " +
786 std::to_string(targetModel.
getNumLocks(col, row) - 1));
790 return emitOpError(
"Invalid lock ID and tile combination when trying to "
791 "retrieve the local lock address.");
// Total bit size of one block:
//   blockSize * mantissaBits + exponentBits + subtileShiftBits.
// Only the mantissa width is multiplied by the block size; the exponent and
// subtile-shift widths are added once.
800uint64_t AIEX::BlockFloatType::getTotalSizeInBits()
const {
801 return getBlockSize() * getMantissaBits() + getExponentBits() +
802 getSubtileShiftBits();
// Data-layout hook: reports the type's size as a fixed (non-scalable)
// llvm::TypeSize equal to getTotalSizeInBits(). Neither `dataLayout` nor
// `params` is used in the visible body.
805llvm::TypeSize AIEX::BlockFloatType::getTypeSizeInBits(
806 const mlir::DataLayout &dataLayout,
807 mlir::DataLayoutEntryListRef params)
const {
808 return llvm::TypeSize::getFixed(getTotalSizeInBits());
// Data-layout hook returning the ABI alignment of the block float type.
// NOTE(review): the body is not visible here, so the returned value and its
// unit cannot be stated from this view.
811uint64_t AIEX::BlockFloatType::getABIAlignment(
812 const mlir::DataLayout &dataLayout,
813 mlir::DataLayoutEntryListRef params)
const {
// Looks up the BlockFormat registered for a block type name in a
// function-local static string map.
// NOTE(review): the map's name line and both return statements are not visible
// here — presumably the found entry on a hit and an empty optional on a miss.
// The meaning of the four initializer fields is defined by BlockFormat, which
// is declared elsewhere.
819std::optional<AIEX::BlockFloatType::BlockFormat>
820AIEX::BlockFloatType::getBlockFormat(StringRef blockType) {
821 static const llvm::StringMap<AIEX::BlockFloatType::BlockFormat>
// Known formats, keyed by type name.
823 {
"v8bfp16ebs8", {8, 8, 8, 0}},
824 {
"v16bfp16ebs16", {16, 8, 8, 0}},
827 auto it = blockFormatsMap.find(blockType);
828 if (it != blockFormatsMap.end()) {
// Type verifier for BlockFloatType: a block type name is valid only if
// getBlockFormat() recognizes it; otherwise a diagnostic naming the known
// types is emitted through `emitError`.
// NOTE(review): the return-type line and the success return are not visible
// here. The list of known types in the message is hard-coded and has to be
// kept in sync with the map in getBlockFormat().
836AIEX::BlockFloatType::verify(function_ref<InFlightDiagnostic()> emitError,
837 StringRef block_type) {
838 if (!getBlockFormat(block_type))
839 return emitError() <<
"Invalid block type: " << block_type
840 <<
". Known types are: v8bfp16ebs8, v16bfp16ebs16.";
virtual AIEArch getTargetArch() const =0
Return the target architecture.
virtual uint32_t getNumBDs(int col, int row) const =0
Return the number of buffer descriptors supported by the DMA in the given tile.
virtual std::vector< std::pair< uint32_t, uint32_t > > getShimBurstEncodingsAndLengths() const =0
virtual std::optional< uint32_t > getLocalLockAddress(uint32_t lockId, TileID tile) const =0
virtual uint32_t getNumLocks(int col, int row) const =0
Return the number of lock objects.
virtual bool isShimNOCTile(int col, int row) const =0
Return true if the given tile is a Shim NOC tile.
virtual uint32_t getMaxLockValue() const =0
Return the maximum value that can be stored in a lock register.
virtual uint32_t getColumnShift() const =0
virtual uint32_t getRowShift() const =0
virtual uint32_t getAddressGenGranularity() const =0
Return the data bus width of the device.
void getHardwareStridesWraps(const AIE::AIETargetModel &targetModel, mlir::Operation *op, mlir::BaseMemRefType referencedBufType, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > &sizes, llvm::SmallVector< int64_t, 4 > &strides)
mlir::LogicalResult verifyStridesWraps(mlir::Operation *forOp, mlir::BaseMemRefType referencedBufType, int tileCol, int tileRow, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > hardwareSizes, llvm::SmallVector< int64_t, 4 > hardwareStrides, bool skipTransformationChecks=false)
const AIETargetModel & getTargetModel(mlir::Operation *op)
std::optional< AIE::ShimDMAAllocationOp > get(DeviceOp dev, mlir::StringRef sym_name)