15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/DialectConversion.h"
22#define GEN_PASS_DEF_AIELEGALIZECONTROLPACKET
23#define GEN_PASS_DEF_AIETXNTOCONTROLPACKET
24#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
27#define DEBUG_TYPE "aie-txn-to-control"
35struct BlockWriteToControlPacketPattern
39 LogicalResult matchAndRewrite(AIEX::NpuBlockWriteOp op,
40 PatternRewriter &rewriter)
const override {
42 Value memref = op.getData();
43 int64_t width = cast<MemRefType>(memref.getType()).getElementTypeBitWidth();
45 return op.emitWarning(
"Only 32-bit data type is supported for now");
48 memref::GetGlobalOp getGlobal = memref.getDefiningOp<memref::GetGlobalOp>();
50 return op.emitError(
"Only MemRefs from memref.get_global are supported");
53 auto global = dyn_cast_if_present<memref::GlobalOp>(
54 op->getParentOfType<AIE::DeviceOp>().lookupSymbol(getGlobal.getName()));
56 return op.emitError(
"Global symbol not found");
59 auto initVal = global.getInitialValue();
61 return op.emitError(
"Global symbol has no initial value");
64 auto data = dyn_cast<DenseIntElementsAttr>(*initVal);
67 "Global symbol initial value is not a dense int array");
69 std::vector<int32_t> dataVec(data.value_begin<int32_t>(),
70 data.value_end<int32_t>());
71 AIEX::NpuControlPacketOp::create(
72 rewriter, op->getLoc(), op.getAddressAttr(),
nullptr,
73 rewriter.getI32IntegerAttr(0),
74 rewriter.getI32IntegerAttr(0),
75 DenseI32ArrayAttr::get(op->getContext(), dataVec));
82struct ControlPacketSplitPattern
86 ControlPacketSplitPattern(MLIRContext *ctx, uint32_t max_payload_size)
89 LogicalResult matchAndRewrite(AIEX::NpuControlPacketOp op,
90 PatternRewriter &rewriter)
const override {
91 auto data = op.getData();
95 uint32_t numElements = op.getDataAttr().size();
97 if (numElements <= max_payload_size) {
101 auto chunks = llvm::divideCeil(numElements, max_payload_size);
102 auto context = op.getContext();
103 auto loc = op.getLoc();
104 for (
unsigned i = 0; i < chunks; ++i) {
105 uint32_t startIdx = i * max_payload_size;
106 uint32_t endIdx = std::min(startIdx + max_payload_size, numElements);
108 SmallVector<int32_t, 4> chunkData;
109 for (
auto it = data->begin() + startIdx; it != data->begin() + endIdx;
111 chunkData.push_back(*it);
115 auto incrementedAddress = rewriter.getUI32IntegerAttr(
116 op.getAddress() + (i * 4 *
sizeof(uint32_t)));
118 AIEX::NpuControlPacketOp::create(
119 rewriter, loc, incrementedAddress,
nullptr, op.getOpcodeAttr(),
120 op.getStreamIdAttr(), DenseI32ArrayAttr::get(context, chunkData));
123 rewriter.eraseOp(op);
128 uint32_t max_payload_size;
131struct AIETxnToControlPacketPass
132 :
public xilinx::AIEX::impl::AIETxnToControlPacketBase<
133 AIETxnToControlPacketPass> {
134 void runOnOperation()
override {
135 AIE::DeviceOp device = getOperation();
137 ConversionTarget target(getContext());
138 target.addIllegalOp<AIEX::NpuBlockWriteOp>();
139 target.addLegalOp<AIEX::NpuControlPacketOp>();
141 RewritePatternSet patterns(&getContext());
142 patterns.add<BlockWriteToControlPacketPattern>(&getContext());
144 if (failed(applyPartialConversion(device, target, std::move(patterns)))) {
150struct AIELegalizeControlPacketPass
151 :
public xilinx::AIEX::impl::AIELegalizeControlPacketBase<
152 AIELegalizeControlPacketPass> {
153 void runOnOperation()
override {
154 AIE::DeviceOp device = getOperation();
156 ConversionTarget target(getContext());
157 target.addDynamicallyLegalOp<AIEX::NpuControlPacketOp>([](Operation *op) {
158 auto packetOp = cast<AIEX::NpuControlPacketOp>(op);
160 return packetOp.getDataAttr().size() <= 4;
163 RewritePatternSet patterns(&getContext());
164 patterns.add<ControlPacketSplitPattern>(&getContext(), 4);
166 if (failed(applyPartialConversion(device, target, std::move(patterns)))) {
174std::unique_ptr<OperationPass<AIE::DeviceOp>>
176 return std::make_unique<AIETxnToControlPacketPass>();
179std::unique_ptr<OperationPass<AIE::DeviceOp>>
181 return std::make_unique<AIELegalizeControlPacketPass>();
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIETxnToControlPacketPass()
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIELegalizeControlPacketPass()