13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/Operation.h"
16#include "mlir/IR/SymbolTable.h"
17#include "mlir/IR/TypeUtilities.h"
18#include "mlir/Interfaces/DataLayoutInterfaces.h"
19#include "mlir/Interfaces/FoldInterfaces.h"
20#include "mlir/Transforms/InliningUtils.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/TypeSize.h"
31#include "aie/Dialect/AIEX/IR/AIEXDialect.cpp.inc"
33#define GET_TYPEDEF_CLASSES
34#include "aie/Dialect/AIEX/IR/AIEXTypes.cpp.inc"
// Dialect initialization hook. The visible lines textually include the
// generated AIEX.cpp.inc and, with GET_TYPEDEF_LIST defined, AIEXTypes.cpp.inc.
// NOTE(review): the registration calls and macro definitions that surround
// these includes are not visible in this excerpt.
39void AIEXDialect::initialize() {
42#include "aie/Dialect/AIEX/IR/AIEX.cpp.inc"
45#define GET_TYPEDEF_LIST
46#include "aie/Dialect/AIEX/IR/AIEXTypes.cpp.inc"
53#include "aie/Dialect/AIEX/IR/AIEX.cpp.inc"
90 mlir::BaseMemRefType referencedBufType,
91 llvm::SmallVector<int64_t, 4> inputSizes,
92 llvm::SmallVector<int64_t, 4> inputStrides,
93 llvm::SmallVector<int64_t, 4> &sizes,
94 llvm::SmallVector<int64_t, 4> &strides) {
// Converts the caller-supplied inputSizes/inputStrides into the values written
// to the output vectors `sizes` and `strides`. Both outputs must already hold
// exactly four entries (asserted below) and are zero-filled before being set.
// NOTE(review): the start of the signature and the definitions of `elemWidth`
// and `addressGranularity` are not visible in this excerpt; the scaling below
// presumably converts element counts to address-granularity units - confirm.
95 assert(inputSizes.size() == inputStrides.size());
96 assert(sizes.size() == 4);
97 assert(strides.size() == 4);
99 DataLayout dataLayout = DataLayout::closest(op);
101 dataLayout.getTypeSizeInBits(referencedBufType.getElementType());
105 std::fill(sizes.begin(), sizes.end(), 0);
106 std::fill(strides.begin(), strides.end(), 0);
// A zero innermost size is special-cased; the body of that branch is not
// visible here.
108 if (inputSizes[0] == 0) {
// Dimension 0: the size is scaled by elemWidth / addressGranularity.
115 sizes[0] = inputSizes[0] * elemWidth / addressGranularity;
116 if (inputStrides[0] * elemWidth < addressGranularity ||
117 (elemWidth > addressGranularity)) {
// Strides are scaled the same way and then stored minus one.
138 strides[0] = inputStrides[0] * elemWidth / addressGranularity - 1;
// Dimensions 1 and 2: sizes are copied unscaled; a stride is only computed
// when the dimension's size is greater than 1.
142 sizes[1] = inputSizes[1];
143 if (inputSizes[1] > 1) {
145 strides[1] = inputStrides[1] * elemWidth / addressGranularity - 1;
149 sizes[2] = inputSizes[2];
150 if (inputSizes[2] > 1) {
152 strides[2] = inputStrides[2] * elemWidth / addressGranularity - 1;
// Dimension 3: the size is stored minus one (only when it exceeds 1), and the
// stride is only computed when the input stride is positive.
156 if (inputSizes[3] > 1) {
158 sizes[3] = inputSizes[3] - 1;
165 if (inputStrides[3] > 0) {
166 strides[3] = inputStrides[3] * elemWidth / addressGranularity - 1;
173 mlir::BaseMemRefType referencedBufType,
int tileCol,
174 int tileRow, llvm::SmallVector<int64_t, 4> inputSizes,
175 llvm::SmallVector<int64_t, 4> inputStrides,
176 llvm::SmallVector<int64_t, 4> hardwareSizes,
177 llvm::SmallVector<int64_t, 4> hardwareStrides,
178 bool skipTransformationChecks) {
// Validates the input and hardware sizes/strides of a transfer for the tile at
// (tileCol, tileRow). Each violated constraint that is visible here is
// reported by returning an op error emitted on `forOp`.
// NOTE(review): the start of the signature, the definitions of `targetModel`
// and `elemWidth`, the per-tile assignments to wrap_bits/step_bits and the
// final return are not visible in this excerpt.
180 auto addressGranularity = targetModel.getAddressGenGranularity();
181 DataLayout dataLayout = DataLayout::closest(forOp);
183 dataLayout.getTypeSizeInBits(referencedBufType.getElementType());
// Field widths used for the range checks at the end. Only iter_bits has a
// visible non-zero value; the branch bodies below are not shown.
185 uint32_t wrap_bits = 0;
186 uint32_t step_bits = 0;
187 uint32_t iter_bits = 6;
188 if (targetModel.isShimNOCTile(tileCol, tileRow)) {
191 }
else if (targetModel.isMemTile(tileCol, tileRow)) {
194 }
else if (targetModel.isCoreTile(tileCol, tileRow)) {
198 return forOp->emitOpError(
199 "Unsupported tile type at (" + std::to_string(tileCol) +
", " +
200 std::to_string(tileRow) +
") Must be ShimNOC, Mem or Core.");
// Every one of the four input sizes must be strictly positive.
203 for (
int i = 0; i < 4; i++) {
204 if (inputSizes[i] <= 0) {
205 return forOp->emitOpError(
"Size ") << i <<
" must be a positive integer.";
// The innermost size, in bits, must be a multiple of the address granularity.
209 if (inputSizes[0] * elemWidth % addressGranularity != 0) {
210 std::stringstream msg;
211 msg <<
"Transfer sizes must be multiples of " << (addressGranularity / 8)
212 <<
" bytes. " << inputSizes[0] <<
" elements at " << (elemWidth / 8)
213 <<
" bytes each equal " << (inputSizes[0] * elemWidth / 8)
214 <<
" bytes, which is not divisible by " << (addressGranularity / 8)
216 return forOp->emitOpError(msg.str());
// Dimensions 0-2: a stride below 1 is rejected only when that dimension's
// size is greater than 1.
219 for (
int i = 0; i < 3; i++) {
220 if (inputSizes[i] > 1 && inputStrides[i] < 1) {
224 return forOp->emitOpError(
"Stride ")
225 << i <<
" must be a positive integer.";
// Dimension 3 also accepts a stride of 0; only negative values are rejected.
230 if (inputSizes[3] > 1 && inputStrides[3] < 0) {
231 return forOp->emitOpError(
"Stride 3 must be a non-negative integer.");
// Each stride, in bits, must be a multiple of the address granularity. The
// (i == 0, stride == 1) case is singled out; its handling is not visible here.
234 for (
int i = 0; i < 4; i++) {
237 if (i == 0 && inputStrides[i] == 1)
239 if (inputStrides[i] * elemWidth % addressGranularity != 0) {
240 std::stringstream msg;
241 msg <<
"Stride " << i <<
" is " << inputStrides[i] <<
" elements * "
242 << (elemWidth / 8) <<
" bytes = " << (inputStrides[i] * elemWidth / 8)
243 <<
" bytes, which is not divisible by " << (addressGranularity / 8)
245 return forOp->emitOpError(msg.str());
// Range checks on the hardware values. Of the checks visible here, only the
// size-0 check is gated on skipTransformationChecks.
249 if (!skipTransformationChecks && hardwareSizes[0] > (1 << wrap_bits) - 1)
250 return forOp->emitOpError(
251 "Size 0 exceeds the [0:" + std::to_string((1 << wrap_bits) - 1) +
253 if (hardwareSizes[1] > (1 << wrap_bits) - 1)
254 return forOp->emitOpError(
255 "Size 1 exceeds the [0:" + std::to_string((1 << wrap_bits) - 1) +
257 if (hardwareSizes[3] > (1 << iter_bits))
258 return forOp->emitOpError(
259 "Size 3 exceeds the [1:" + std::to_string(1 << iter_bits) +
"] range.");
260 if (hardwareStrides[0] > (1 << step_bits))
261 return forOp->emitOpError(
"Stride 0 exceeds the [1:" +
262 std::to_string(1 << step_bits) +
"] range.");
263 if (hardwareStrides[1] > (1 << step_bits))
264 return forOp->emitOpError(
"Stride 1 exceeds the [1:" +
265 std::to_string(1 << step_bits) +
"] range.");
266 if (hardwareStrides[2] > (1 << step_bits))
267 return forOp->emitOpError(
"Stride 2 exceeds the [1:" +
268 std::to_string(1 << step_bits) +
"] range.");
// The stride-3 range is only enforced when hardwareSizes[3] is greater than 0.
271 if (hardwareStrides[3] > (1 << step_bits) && hardwareSizes[3] > 0)
272 return forOp->emitOpError(
"Stride 3 exceeds the [1:" +
273 std::to_string(1 << step_bits) +
"] range.");
// Verifier for UseTokenOp. The visible condition tests whether the parent
// operation is a func::FuncOp, AIE::CoreOp, AIE::MemOp or AIE::ShimDMAOp.
// NOTE(review): the results returned for either outcome are not visible in
// this excerpt.
282LogicalResult AIEX::UseTokenOp::verify() {
283 auto *parentOp = (*this)->getParentOp();
284 if (isa<func::FuncOp>(parentOp) || isa<AIE::CoreOp>(parentOp) ||
285 isa<AIE::MemOp>(parentOp) || isa<AIE::ShimDMAOp>(parentOp))
// Verifier for MulticastOp: walks the first block of the ports region and
// emits an error on any operation that is neither a MultiDestOp nor an
// AIE::EndOp. The asserts require the op to have a region and the region to
// be non-empty.
// NOTE(review): the success return is not visible in this excerpt.
294LogicalResult AIEX::MulticastOp::verify() {
295 Region &body = getPorts();
296 assert(getOperation()->getNumRegions());
297 assert(!body.empty());
298 for (
auto &ops : body.front())
299 if (!isa<MultiDestOp,
AIE::EndOp>(ops))
300 return ops.emitOpError(
"cannot be contained in a Multicast op");
// Verifier for BroadcastPacketOp: walks the first block of the ports region
// and emits an error on any operation that is neither a BPIDOp nor an
// AIE::EndOp. The asserts require the op to have a region and the region to
// be non-empty.
// NOTE(review): the success return is not visible in this excerpt.
309LogicalResult AIEX::BroadcastPacketOp::verify() {
310 Region &body = getPorts();
311 assert(getOperation()->getNumRegions());
312 assert(!body.empty());
313 for (
auto &ops : body.front())
314 if (!isa<BPIDOp,
AIE::EndOp>(ops))
315 return ops.emitOpError(
"cannot be contained in a BroadcastPacket op");
// Computes a byte offset by summing offsets[i] * strides[i] * S over all
// dimensions, where S is the element width in bytes.
// Offsets and strides are read in reversed order from the op's mixed
// operands. `.value()` is called unconditionally, so every offset and stride
// is expected to be a constant here.
// NOTE(review): the declaration of `offset` and the return statement are not
// visible in this excerpt.
326int64_t AIEX::NpuDmaMemcpyNdOp::getOffsetInBytes() {
327 llvm::SmallVector<int64_t, 4> offsets =
328 llvm::map_to_vector(llvm::reverse(getMixedOffsets()), [](OpFoldResult s) {
329 return getConstantIntValue(s).value();
331 llvm::SmallVector<int64_t, 4> strides =
332 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
333 return getConstantIntValue(s).value();
336 size_t R = offsets.size();
337 size_t el_bit_width = getElementTypeBitwidth();
// Sub-byte or non-byte-multiple element types are rejected by assertion.
338 assert(el_bit_width % 8 == 0 &&
339 "Expected Memref element bitwidth to be multiple of 8.");
340 size_t S = el_bit_width / 8;
341 for (
size_t i = 0; i < R; i++)
342 offset += offsets[i] * strides[i] * S;
350 llvm::ArrayRef<int64_t> strides) {
// The visible part of the predicate requires sizes[1] == 1, sizes[2] == 1,
// strides[0] == 1 and strides[1] == 0.
// NOTE(review): the expression continues past this excerpt (it ends in `&&`);
// the remaining conjuncts are not visible.
351 return sizes[1] == 1 && sizes[2] == 1 && strides[0] == 1 && strides[1] == 0 &&
// Collects the op's sizes and strides as constant integers, in reversed
// order. `.value()` is called unconditionally, so all sizes and strides are
// expected to be constants.
// NOTE(review): how the two vectors are combined into the boolean result is
// not visible in this excerpt.
359bool AIEX::NpuDmaMemcpyNdOp::isLinearTransferWithoutTransformation() {
360 llvm::SmallVector<int64_t, 4> inputSizes =
361 llvm::map_to_vector(llvm::reverse(getMixedSizes()), [](OpFoldResult s) {
362 return getConstantIntValue(s).value();
364 llvm::SmallVector<int64_t, 4> inputStrides =
365 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
366 return getConstantIntValue(s).value();
// File-local helper returning an optional error string for a requested burst
// length. A request of 0 skips the lookup entirely. Otherwise the list `bel`
// is searched for a pair whose second member equals the request; when none is
// found, an error message listing the second member of every pair is built.
// NOTE(review): the function name line, the definition of `bel` and the
// return statements are not visible in this excerpt.
374static std::optional<std::string>
376 uint32_t requestedBurstLength) {
377 if (requestedBurstLength != 0) {
379 auto pair = std::find_if(bel.begin(), bel.end(),
380 [=](
const std::pair<uint32_t, uint32_t> &p) {
381 return p.second == requestedBurstLength;
384 if (pair == bel.end()) {
385 std::string errorMessage =
386 "Requested burst length is not supported by the target. "
387 "Supported burst lengths:";
// Appends " <length>" for each entry to the message prefix above.
390 std::accumulate(bel.begin(), bel.end(), errorMessage,
391 [](
const std::string &a,
auto b) {
392 return a +
" " + std::to_string(b.second);
// Verifier for NpuDmaMemcpyNdOp. Visible checks, in order: element bit width
// versus the address granularity, minimum transfer size for statically shaped
// buffers, constant-only strides/sizes/offsets, burst length, 4-byte offset
// alignment, stride/wrap validation against the shim DMA allocation's tile,
// and packet type / packet ID ranges.
// NOTE(review): the definitions of `targetModel` and `addressGranularity`,
// the heads of the calls ending at lines 441 and 475, and the final return
// are not visible in this excerpt.
402LogicalResult AIEX::NpuDmaMemcpyNdOp::verify() {
403 BaseMemRefType buffer = getMemref().getType();
407 if (getElementTypeBitwidth() > addressGranularity) {
408 return emitOpError(
"Maximum element bit width allowed is ")
409 << addressGranularity <<
"bits. ";
// Only applies when the buffer shape is static.
411 if (buffer.hasStaticShape() &&
412 (buffer.getNumElements() * getElementTypeBitwidth()) <
413 addressGranularity) {
414 return emitOpError(
"Minimum data transfer size required is ")
415 << addressGranularity <<
"bits. ";
// Strides, sizes and offsets must all fold to constants; the `.value()`
// calls further down rely on this.
417 if (!llvm::all_of(getMixedStrides(), [](OpFoldResult s) {
418 return getConstantIntValue(s).has_value();
420 return emitOpError(
"Only constant strides currently supported.");
421 if (!llvm::all_of(getMixedSizes(), [](OpFoldResult s) {
422 return getConstantIntValue(s).has_value();
424 return emitOpError(
"Only constant sizes currently supported.");
425 if (!llvm::all_of(getMixedOffsets(), [](OpFoldResult s) {
426 return getConstantIntValue(s).has_value();
428 return emitOpError(
"Only constant offsets currently supported.");
// Sizes and strides are read in reversed order.
430 llvm::SmallVector<int64_t, 4> inputSizes =
431 llvm::map_to_vector(llvm::reverse(getMixedSizes()), [](OpFoldResult s) {
432 return getConstantIntValue(s).value();
434 llvm::SmallVector<int64_t, 4> inputStrides =
435 llvm::map_to_vector(llvm::reverse(getMixedStrides()), [](OpFoldResult s) {
436 return getConstantIntValue(s).value();
438 llvm::SmallVector<int64_t, 4> hardwareSizes(4);
439 llvm::SmallVector<int64_t, 4> hardwareStrides(4);
441 inputStrides, hardwareSizes, hardwareStrides);
442 int64_t offset = getOffsetInBytes();
444 auto errorMessage = checkBurstLength(targetModel, getBurstLength());
445 if (errorMessage.has_value()) {
446 return emitOpError(errorMessage.value());
454 if (offset % 4 != 0) {
455 return emitOpError(
"Offset must be 4-byte-aligned.");
// The tile coordinates for the stride/wrap checks come from the shim DMA
// allocation that the op's metadata symbol resolves to, when one is found.
464 AIE::DeviceOp dev = getOperation()->getParentOfType<AIE::DeviceOp>();
465 if (
auto allocOp = AIE::ShimDMAAllocationOp::getForSymbol(
466 dev, getMetadata().getRootReference())) {
467 AIE::TileOp tile = allocOp.getTileOp();
469 return emitOpError(
"shim DMA allocation must reference a valid TileOp");
471 int col = tile.getCol();
472 int row = tile.getRow();
473 bool skipTransformationChecks = isLinearTransferWithoutTransformation();
475 inputStrides, hardwareSizes, hardwareStrides,
476 skipTransformationChecks))) {
// Packet header fields: type is limited to 0..7, ID to 0..31.
482 if (
auto packetInfo = getPacket()) {
483 if (packetInfo->getPktType() > 7)
484 return emitOpError(
"Packet type field can only hold 3 bits.");
485 if (packetInfo->getPktId() > 31)
486 return emitOpError(
"Packet ID field can only hold 5 bits.");
// Verifier for NpuDmaWaitOp: when the op has an enclosing AIE::DeviceOp, the
// op's symbol must resolve inside that device. With no enclosing device the
// lookup is skipped.
// NOTE(review): the final return is not visible in this excerpt.
496LogicalResult AIEX::NpuDmaWaitOp::verify() {
497 AIE::DeviceOp dev = (*this)->getParentOfType<AIE::DeviceOp>();
500 if (dev && !dev.lookupSymbol(getSymbol()))
501 return emitOpError(
"couldn't find symbol in parent device");
// Verifier for NpuPushQueueOp: bounds the BD ID by the number of BDs the
// target model reports for (column, row) and limits the repeat count to 255.
// NOTE(review): the definition of `targetModel` and the final return are not
// visible in this excerpt.
509LogicalResult AIEX::NpuPushQueueOp::verify() {
511 auto numBds = targetModel.
getNumBDs(getColumn(), getRow());
// NOTE(review): this comparison lets a BD ID equal to numBds through; if BD
// IDs are zero-based the intended bound may be `>=` - confirm.
512 if (getBdId() > numBds)
513 return emitOpError(
"BD ID exceeds the maximum ID.");
514 if (getRepeatCount() > 255)
515 return emitOpError(
"Repeat count exceeds the [0:255] range.");
// Verifier for NpuWriteBdOp: range-checks the BD ID, packet fields, the
// per-dimension sizes and strides and the iteration fields, then applies
// tile-type restrictions and the burst-length check.
// NOTE(review): the definition of `targetModel`, the declaration that line
// 527 initializes (used below as `isLinearTransfer`), the leading parts of
// the conditions ending at lines 553 and 556, and the final return are not
// visible in this excerpt.
523LogicalResult AIEX::NpuWriteBdOp::verify() {
525 auto numBds = targetModel.
getNumBDs(getColumn(), getRow());
527 (getD0Size() >= 1) && (getD1Size() == 1) && (getIterationSize() == 0);
// NOTE(review): this comparison lets a BD ID equal to numBds through; if BD
// IDs are zero-based the intended bound may be `>=` - confirm.
528 if (getBdId() > numBds)
529 return emitOpError(
"BD ID exceeds the maximum ID.");
530 if (getPacketId() > 31)
531 return emitOpError(
"Packet ID exceeds the maximum supported by 5 bits.");
532 if (getPacketType() > 7)
533 return emitOpError(
"Packet Type exceeds the maximum supported by 3 bits.");
// The D0 size limit is waived for linear transfers.
534 if (!isLinearTransfer && getD0Size() > 0x3FF)
535 return emitOpError(
"D0 Size exceeds the [0:1023] range.");
536 if (getD0Stride() > 0xFFFFF)
537 return emitOpError(
"D0 Stride exceeds the [0:1M-1] range.");
538 if (getD1Size() > 0x3FF)
539 return emitOpError(
"D1 Size exceeds the [0:1023] range.");
540 if (getD1Stride() > 0xFFFFF)
541 return emitOpError(
"D1 Stride exceeds the [0:1M-1] range.");
542 if (getD2Stride() > 0xFFFFF)
543 return emitOpError(
"D2 Stride exceeds the [0:1M-1] range.");
544 if (getIterationSize() > 0x3F)
545 return emitOpError(
"Iteration Size exceeds the [0:63] range.");
546 if (getIterationStride() > 0xFFFFF)
547 return emitOpError(
"Iteration Stride exceeds the [0:1M-1] range.");
// Tile-type restrictions.
548 if (targetModel.
isShimNOCTile(getColumn(), getRow()) && getD2Size() != 0)
549 return emitOpError(
"ShimTile only supports 3 dimensions of sizes.");
551 (getD0ZeroBefore() != 0 || getD0ZeroAfter() != 0 ||
552 getD1ZeroBefore() != 0 || getD1ZeroAfter() != 0 ||
553 getD2ZeroBefore() != 0 || getD2ZeroAfter() != 0))
554 return emitOpError(
"ShimTile doesn't support zero padding.");
556 getBurstLength() != 0)
557 return emitOpError(
"Only ShimTiles support burst length.");
558 auto errorMessage = checkBurstLength(targetModel, getBurstLength());
559 if (errorMessage.has_value()) {
560 return emitOpError(errorMessage.value());
// Shared helper behind the per-op getAbsoluteAddress() methods. Two paths are
// visible:
//  - the op names a buffer: the buffer is looked up in the enclosing device,
//    must have an assigned address, and the op's address is added to it as a
//    32-bit word index (multiplied by sizeof(uint32_t));
//  - otherwise the op's own address, column and row are used.
// In both paths the row is masked to 8 bits and shifted by tm.getRowShift(),
// and the address is masked to its low 20 bits.
// NOTE(review): the template header, the definition of `tm`, the column term
// of the address expression and all return statements are not visible in
// this excerpt.
571static std::optional<uint32_t> getAbsoluteAddress(T *op) {
572 AIE::DeviceOp device =
573 op->getOperation()->template getParentOfType<AIE::DeviceOp>();
575 op->emitError(
"Must be inside a device.");
580 uint32_t address = 0;
584 if (op->getBuffer()) {
585 AIE::BufferOp buffer = device.lookupSymbol<AIE::BufferOp>(*op->getBuffer());
587 op->emitError() <<
"buffer '" << *op->getBuffer()
588 <<
"' not found in device";
592 if (!buffer.getAddress()) {
593 mlir::InFlightDiagnostic err =
594 op->emitError(
"referenced buffer must have address assigned");
595 err.attachNote(buffer.getLoc()) <<
"This buffer must have an address.";
// Tile coordinates come from the buffer's tile, not from the op.
599 uint32_t
col = buffer.getTileOp().getCol();
600 uint32_t
row = buffer.getTileOp().getRow();
601 address =
static_cast<uint32_t
>(*buffer.getAddress()) +
602 op->getAddress() *
sizeof(uint32_t);
604 ((row & 0xff) << tm.
getRowShift()) | (address & 0xfffff);
// No buffer: column and row are optional attributes of the op itself.
606 address = op->getAddress();
607 std::optional<uint32_t>
col = op->getColumn();
608 std::optional<uint32_t>
row = op->getRow();
613 ((*row & 0xff) << tm.
getRowShift()) | (address & 0xfffff);
// Forwards to the file-scope ::getAbsoluteAddress helper for this op.
620std::optional<uint32_t> AIEX::NpuWrite32Op::getAbsoluteAddress() {
621 return ::getAbsoluteAddress(
this);
// Forwards to the file-scope ::getAbsoluteAddress helper for this op.
628std::optional<uint32_t> AIEX::NpuMaskWrite32Op::getAbsoluteAddress() {
629 return ::getAbsoluteAddress(
this);
// Forwards to the file-scope ::getAbsoluteAddress helper for this op.
636std::optional<uint32_t> AIEX::NpuBlockWriteOp::getAbsoluteAddress() {
637 return ::getAbsoluteAddress(
this);
// Resolves the op's data memref to the initial value of the memref.global it
// refers to, as a DenseIntElementsAttr. The memref must be produced by
// memref.get_global, the global must be found in the enclosing device, it
// must have an initial value, and that value must be a dense int array;
// each failure has its own diagnostic below.
// NOTE(review): the conditions guarding each diagnostic and all return
// statements are not visible in this excerpt. No null check on the enclosing
// DeviceOp is visible before lookupSymbol is called on it.
640DenseIntElementsAttr AIEX::NpuBlockWriteOp::getDataWords() {
641 Value memref = this->getData();
642 DataLayout dataLayout = DataLayout::closest(*
this);
643 int64_t width = dataLayout.getTypeSizeInBits(
644 cast<MemRefType>(memref.getType()).getElementType());
// Note this diagnostic is a warning, unlike the errors below.
646 emitWarning(
"Only 32-bit data type is supported for now");
650 memref::GetGlobalOp getGlobal = memref.getDefiningOp<memref::GetGlobalOp>();
652 emitError(
"Only MemRefs from memref.get_global are supported");
656 auto global = dyn_cast_if_present<memref::GlobalOp>(
657 (*this)->getParentOfType<AIE::DeviceOp>().lookupSymbol(
658 getGlobal.getName()));
660 emitError(
"Global symbol not found");
664 auto initVal = global.getInitialValue();
666 emitError(
"Global symbol has no initial value");
670 auto data = dyn_cast<DenseIntElementsAttr>(*initVal);
672 emitError(
"Global symbol initial value is not a dense int array");
// Returns the BD ID of the first AIE::DMABDOp in the task body. The entry
// block is searched first; if it contains no DMABDOp and has exactly one
// successor, that successor block is searched instead.
// NOTE(review): the values returned when no BD op is found or when the BD has
// no ID are not visible in this excerpt.
683std::optional<uint32_t> AIEX::DMAConfigureTaskOp::getFirstBdId() {
684 Region &body = getBody();
688 auto bd_ops = body.front().getOps<AIE::DMABDOp>();
689 if (bd_ops.empty() && body.front().getNumSuccessors() == 1) {
693 Block &chain_entry = *body.front().getSuccessor(0);
694 bd_ops = chain_entry.getOps<AIE::DMABDOp>();
696 if (bd_ops.empty()) {
699 AIE::DMABDOp bd = *bd_ops.begin();
700 if (!bd.getBdId().has_value()) {
703 return bd.getBdId().value();
// Canonicalization for DMAConfigureTaskOp: iterates over the blocks of the
// body and, for a block that contains nothing besides its terminator, erases
// that terminator through the rewriter.
// NOTE(review): the return type, the definition of `block`, any further
// conditions or rewrites, the updates to `did_rewrite` and the return
// statement are not visible in this excerpt.
707AIEX::DMAConfigureTaskOp::canonicalize(AIEX::DMAConfigureTaskOp op,
708 PatternRewriter &rewriter) {
710 Region &body = op.getBody();
711 bool did_rewrite =
false;
712 for (
auto it = body.begin(); it != body.end(); ++it) {
717 auto ops_it = block.without_terminator();
718 if (std::distance(ops_it.begin(), ops_it.end()) == 0) {
719 rewriter.eraseOp(block.getTerminator());
// Verifier for DMAConfigureTaskOp. For each block of the body, the visible
// checks are: a non-entry block without predecessors is reported on its
// terminator, and any AIE::DMABDOp with a non-zero burst length is reported
// unless the task's tile is a Shim NOC tile.
// NOTE(review): the definitions of `block` and `targetModel`, the rest of the
// first error message, how `result` is updated inside the walk and the return
// statements are not visible in this excerpt.
729LogicalResult AIEX::DMAConfigureTaskOp::verify() {
730 Region &body = getBody();
731 for (
auto it = body.begin(); it != body.end(); ++it) {
736 if (block.hasNoPredecessors() && !block.isEntryBlock()) {
737 auto error = block.getTerminator()->emitError(
738 "Block ending in this terminator does not form a chain with "
750 LogicalResult result = success();
751 block.walk([&](AIE::DMABDOp bd) {
752 if (bd.getBurstLength() != 0 &&
753 !targetModel.
isShimNOCTile(getTileID().col, getTileID().row)) {
754 bd.emitOpError(
"Burst length is only supported in Shim NOC tiles that "
755 "are connected to the memory-mapped NOC.");
759 if (failed(result)) {
// Looks up the AIE::BDChainOp that this op's symbol names, inside the
// enclosing AIE::DeviceOp.
// NOTE(review): no null check on `device` is visible before lookupSymbol is
// called on it, and the return statement is not visible in this excerpt.
770AIE::BDChainOp AIEX::DMAStartBdChainOp::getBDChainOp() {
771 AIE::DeviceOp device = (*this)->getParentOfType<AIE::DeviceOp>();
772 AIE::BDChainOp chain = device.lookupSymbol<AIE::BDChainOp>(getSymbol());
// Verifier for DMAStartBdChainOp: the symbol must resolve to a BD chain, and
// the op's argument types must match the chain region's argument types in
// both count and position.
// NOTE(review): the condition guarding the first error and the final return
// are not visible in this excerpt.
776LogicalResult AIEX::DMAStartBdChainOp::verify() {
777 AIE::BDChainOp chain = getBDChainOp();
779 return emitOpError(
"symbol does not reference valid BD chain");
782 auto actualArgTypes = getArgs().getTypes();
783 auto expectedArgTypes = chain.getRegion().getArgumentTypes();
784 if (actualArgTypes.size() != expectedArgTypes.size()) {
785 return emitOpError(
"Number of arguments mismatches.");
// Argument positions are reported 1-based in the diagnostic.
787 for (
unsigned i = 0, n = expectedArgTypes.size(); i < n; i++) {
788 if (actualArgTypes[i] != expectedArgTypes[i]) {
789 return emitOpError(
"Argument ") << (i + 1) <<
" types mismatch: "
790 <<
"expected " << expectedArgTypes[i]
791 <<
" but got " << actualArgTypes[i];
// Extracts the row from the op's address: the address is shifted right by the
// target model's row shift and masked to 5 bits (0x1f).
// NOTE(review): the definition of `targetModel` and the return statement are
// not visible in this excerpt.
801uint32_t AIEX::NpuControlPacketOp::getRowFromAddr() {
803 uint32_t addr = getAddress();
804 uint32_t rowInt = (addr >> targetModel.
getRowShift()) & 0x1f;
// Derives a column value from the op's address.
// NOTE(review): only the address read is visible; the shift/mask applied and
// the return statement are not visible in this excerpt.
808uint32_t AIEX::NpuControlPacketOp::getColumnFromAddr() {
810 uint32_t addr = getAddress();
// Verifier for SetLockOp. The visible diagnostics cover four failures: the op
// being used on AIE1, a lock value above the maximum, a lock ID out of range
// for the lock's tile, and a lock ID / tile pair for which no local lock
// address can be retrieved.
// NOTE(review): the definition of `targetModel`, the conditions guarding each
// error, the remainder of the second message and the final return are not
// visible in this excerpt.
819LogicalResult AIEX::SetLockOp::verify() {
823 return emitOpError(
"SetLockOp is not supported on AIE1.");
826 return emitOpError(
"Lock value exceeds the maximum value of " +
829 auto lockOp = getLockOp();
830 auto lockIDOpt = getLockOp().getLockID();
837 auto col = lockOp.colIndex();
838 auto row = lockOp.rowIndex();
839 uint32_t lockID = lockOp.getLockIDValue();
// The reported maximum ID is the tile's lock count minus one.
842 return emitOpError(
"Lock ID out of range for given tile. Max ID: " +
843 std::to_string(targetModel.
getNumLocks(col, row) - 1));
847 return emitOpError(
"Invalid lock ID and tile combination when trying to "
848 "retrieve the local lock address.");
// Total bit size of one block: blockSize * mantissaBits, plus the exponent
// bits and the subtile shift bits (each added once, not per element).
857uint64_t AIEX::BlockFloatType::getTotalSizeInBits()
const {
858 return getBlockSize() * getMantissaBits() + getExponentBits() +
859 getSubtileShiftBits();
// Data-layout hook: reports getTotalSizeInBits() as a fixed TypeSize. Neither
// `dataLayout` nor `params` is used in the visible body.
862llvm::TypeSize AIEX::BlockFloatType::getTypeSizeInBits(
863 const mlir::DataLayout &dataLayout,
864 mlir::DataLayoutEntryListRef params)
const {
865 return llvm::TypeSize::getFixed(getTotalSizeInBits());
// Data-layout hook returning the ABI alignment for this type.
// NOTE(review): the body is not visible in this excerpt, so the returned
// value cannot be documented here.
868uint64_t AIEX::BlockFloatType::getABIAlignment(
869 const mlir::DataLayout &dataLayout,
870 mlir::DataLayoutEntryListRef params)
const {
// Maps a block type name to its BlockFormat through a function-local static
// string map. Two entries are visible: "v8bfp16ebs8" and "v16bfp16ebs16".
// NOTE(review): the map's variable name line, the meaning of the four
// BlockFormat fields and both return statements are not visible in this
// excerpt.
876std::optional<AIEX::BlockFloatType::BlockFormat>
877AIEX::BlockFloatType::getBlockFormat(StringRef blockType) {
878 static const llvm::StringMap<AIEX::BlockFloatType::BlockFormat>
880 {
"v8bfp16ebs8", {8, 8, 8, 0}},
881 {
"v16bfp16ebs16", {16, 8, 8, 0}},
884 auto it = blockFormatsMap.find(blockType);
885 if (it != blockFormatsMap.end()) {
// Type verifier: rejects any block_type for which getBlockFormat() yields no
// format, reporting the offending name and the known type names.
// NOTE(review): the return type line and the success return are not visible
// in this excerpt.
893AIEX::BlockFloatType::verify(function_ref<InFlightDiagnostic()> emitError,
894 StringRef block_type) {
895 if (!getBlockFormat(block_type))
896 return emitError() <<
"Invalid block type: " << block_type
897 <<
". Known types are: v8bfp16ebs8, v16bfp16ebs16.";
// Resolves the op's symbol attribute to an AIE::DeviceOp by looking it up in
// the enclosing ModuleOp. Diagnostics are emitted when the op is not inside a
// module, when the symbol is not found, and when the symbol does not refer to
// a device.
// NOTE(review): the condition guarding the first diagnostic and the values
// returned on the failure paths are not visible in this excerpt.
906AIE::DeviceOp AIEX::ConfigureOp::getReferencedDeviceOp() {
907 ModuleOp moduleOp = this->getOperation()->getParentOfType<ModuleOp>();
909 emitError(
"aiex.configure must be inside of a module");
912 Operation *maybeReferencedDevice =
913 SymbolTable::lookupSymbolIn(moduleOp.getOperation(), getSymbolAttr());
914 if (!maybeReferencedDevice) {
915 emitError(
"No such device: '") << getSymbolAttr() <<
"'";
// The symbol exists but may name some other kind of operation.
918 AIE::DeviceOp referencedDevice =
919 llvm::dyn_cast<AIE::DeviceOp>(maybeReferencedDevice);
920 if (!referencedDevice) {
921 emitError(
"Not a device: '") << getSymbolAttr() <<
"'";
924 return referencedDevice;
// Verifier for ConfigureOp: the referenced device must resolve, and its
// device type must equal that of the enclosing AIE::DeviceOp; a mismatch is
// reported with both device names.
// NOTE(review): no null check on `parentDev` is visible before
// parentDev.getDevice() is called, and the return statements are not visible
// in this excerpt.
927LogicalResult AIEX::ConfigureOp::verify() {
928 AIE::DeviceOp parentDev = getOperation()->getParentOfType<AIE::DeviceOp>();
929 AIE::DeviceOp referencedDev = getReferencedDeviceOp();
930 if (!referencedDev) {
933 if (parentDev.getDevice() != referencedDev.getDevice()) {
934 emitError(
"Device types do not match: '")
935 << AIE::stringifyAIEDevice(parentDev.getDevice()) <<
"' vs. '"
936 << AIE::stringifyAIEDevice(referencedDev.getDevice()) <<
"'";
// Returns the device referenced by the enclosing AIEX::ConfigureOp.
// NOTE(review): the handling of a missing enclosing ConfigureOp is not
// visible in this excerpt.
946AIE::DeviceOp AIEX::RunOp::getCalleeDeviceOp() {
947 AIEX::ConfigureOp configureOp =
948 getOperation()->getParentOfType<AIEX::ConfigureOp>();
952 AIE::DeviceOp referencedDevice = configureOp.getReferencedDeviceOp();
953 return referencedDevice;
// Resolves the runtime sequence this op runs: takes the device referenced by
// the enclosing AIEX::ConfigureOp, looks up the runtime-sequence symbol inside
// it, and casts the result to AIE::RuntimeSequenceOp.
// NOTE(review): the handling of a missing ConfigureOp and the bodies of the
// three failure branches are not visible in this excerpt.
956AIE::RuntimeSequenceOp AIEX::RunOp::getCalleeRuntimeSequenceOp() {
957 AIEX::ConfigureOp configureOp =
958 getOperation()->getParentOfType<AIEX::ConfigureOp>();
962 AIE::DeviceOp referencedDevice = configureOp.getReferencedDeviceOp();
963 if (!referencedDevice) {
967 Operation *maybeRuntimeSequence =
968 SymbolTable::lookupSymbolIn(referencedDevice, getRuntimeSequenceSymbol());
970 if (!maybeRuntimeSequence) {
973 AIE::RuntimeSequenceOp runtimeSequence =
974 llvm::dyn_cast<AIE::RuntimeSequenceOp>(maybeRuntimeSequence);
975 if (!runtimeSequence) {
979 return runtimeSequence;
// Canonicalization for NpuLoadPdiOp: when the immediately following operation
// is another NpuLoadPdiOp with the same device reference, id, size and
// address, this op (the earlier of the two) is erased.
// NOTE(review): the null checks on `nextOp` / `nextLoadPdi` and the return
// statements are not visible in this excerpt.
986LogicalResult AIEX::NpuLoadPdiOp::canonicalize(AIEX::NpuLoadPdiOp op,
987 PatternRewriter &rewriter) {
989 Operation *nextOp = op->getNextNode();
994 auto nextLoadPdi = dyn_cast<AIEX::NpuLoadPdiOp>(nextOp);
999 if (op.getDeviceRefAttr() == nextLoadPdi.getDeviceRefAttr() &&
1000 op.getId() == nextLoadPdi.getId() &&
1001 op.getSize() == nextLoadPdi.getSize() &&
1002 op.getAddress() == nextLoadPdi.getAddress()) {
1004 rewriter.eraseOp(op);
virtual AIEArch getTargetArch() const =0
Return the target architecture.
virtual uint32_t getNumLocks(AIETileType tileType) const =0
Return the number of lock objects for a given tile type.
virtual std::vector< std::pair< uint32_t, uint32_t > > getShimBurstEncodingsAndLengths() const =0
virtual std::optional< uint32_t > getLocalLockAddress(uint32_t lockId, TileID tile) const =0
bool isShimNOCTile(int col, int row) const
Return true if the given tile is a ShimNOC tile.
virtual uint32_t getNumBDs(AIETileType tileType) const =0
Return the number of buffer descriptors for a given tile type.
virtual uint32_t getMaxLockValue() const =0
Return the maximum value that can be stored in a lock register.
virtual uint32_t getColumnShift() const =0
virtual uint32_t getRowShift() const =0
virtual uint32_t getAddressGenGranularity() const =0
Return the data bus width of the device.
void getHardwareStridesWraps(const AIE::AIETargetModel &targetModel, mlir::Operation *op, mlir::BaseMemRefType referencedBufType, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > &sizes, llvm::SmallVector< int64_t, 4 > &strides)
mlir::LogicalResult verifyStridesWraps(mlir::Operation *forOp, mlir::BaseMemRefType referencedBufType, int tileCol, int tileRow, llvm::SmallVector< int64_t, 4 > inputSizes, llvm::SmallVector< int64_t, 4 > inputStrides, llvm::SmallVector< int64_t, 4 > hardwareSizes, llvm::SmallVector< int64_t, 4 > hardwareStrides, bool skipTransformationChecks=false)
bool isLinearTransfer(llvm::ArrayRef< int64_t > sizes, llvm::ArrayRef< int64_t > strides)
const AIETargetModel & getTargetModel(mlir::Operation *op)