MLIR-AIE
AIETxnToControlPacket.cpp
Go to the documentation of this file.
1//===- AIETxnToControlPacket.cpp --------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2025 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
14
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/DialectConversion.h"
18
19#include <algorithm>
20
21namespace xilinx::AIEX {
22#define GEN_PASS_DEF_AIELEGALIZECONTROLPACKET
23#define GEN_PASS_DEF_AIETXNTOCONTROLPACKET
24#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
25} // namespace xilinx::AIEX
26
27#define DEBUG_TYPE "aie-txn-to-control"
28
29using namespace mlir;
30using namespace xilinx;
31
32namespace {
33
34/// Pattern to convert transaction operations to control operations
35struct BlockWriteToControlPacketPattern
36 : public OpRewritePattern<AIEX::NpuBlockWriteOp> {
37 using OpRewritePattern<AIEX::NpuBlockWriteOp>::OpRewritePattern;
38
39 LogicalResult matchAndRewrite(AIEX::NpuBlockWriteOp op,
40 PatternRewriter &rewriter) const override {
41
42 Value memref = op.getData();
43 int64_t width = cast<MemRefType>(memref.getType()).getElementTypeBitWidth();
44 if (width != 32) {
45 return op.emitWarning("Only 32-bit data type is supported for now");
46 }
47
48 memref::GetGlobalOp getGlobal = memref.getDefiningOp<memref::GetGlobalOp>();
49 if (!getGlobal) {
50 return op.emitError("Only MemRefs from memref.get_global are supported");
51 }
52
53 auto global = dyn_cast_if_present<memref::GlobalOp>(
54 op->getParentOfType<AIE::DeviceOp>().lookupSymbol(getGlobal.getName()));
55 if (!global) {
56 return op.emitError("Global symbol not found");
57 }
58
59 auto initVal = global.getInitialValue();
60 if (!initVal) {
61 return op.emitError("Global symbol has no initial value");
62 }
63
64 auto data = dyn_cast<DenseIntElementsAttr>(*initVal);
65 if (!data) {
66 return op.emitError(
67 "Global symbol initial value is not a dense int array");
68 }
69 std::vector<int32_t> dataVec(data.value_begin<int32_t>(),
70 data.value_end<int32_t>());
71 AIEX::NpuControlPacketOp::create(
72 rewriter, op->getLoc(), op.getAddressAttr(), nullptr,
73 /*opcode*/ rewriter.getI32IntegerAttr(0),
74 /*stream_id*/ rewriter.getI32IntegerAttr(0),
75 DenseI32ArrayAttr::get(op->getContext(), dataVec));
76 rewriter.eraseOp(op);
77 return success();
78 }
79};
80
81/// Pattern to split control packets into smaller control packets
82struct ControlPacketSplitPattern
83 : public OpRewritePattern<AIEX::NpuControlPacketOp> {
84 using OpRewritePattern<AIEX::NpuControlPacketOp>::OpRewritePattern;
85
86 ControlPacketSplitPattern(MLIRContext *ctx, uint32_t max_payload_size)
87 : OpRewritePattern(ctx), max_payload_size(max_payload_size) {}
88
89 LogicalResult matchAndRewrite(AIEX::NpuControlPacketOp op,
90 PatternRewriter &rewriter) const override {
91 auto data = op.getData();
92 if (!data)
93 return failure();
94
95 uint32_t numElements = op.getDataAttr().size();
96
97 if (numElements <= max_payload_size) {
98 return failure(); // No splitting needed
99 }
100
101 auto chunks = llvm::divideCeil(numElements, max_payload_size);
102 auto context = op.getContext();
103 auto loc = op.getLoc();
104 for (unsigned i = 0; i < chunks; ++i) {
105 uint32_t startIdx = i * max_payload_size;
106 uint32_t endIdx = std::min(startIdx + max_payload_size, numElements);
107
108 SmallVector<int32_t, 4> chunkData;
109 for (auto it = data->begin() + startIdx; it != data->begin() + endIdx;
110 ++it) {
111 chunkData.push_back(*it);
112 }
113
114 // Increment the address for each chunk
115 auto incrementedAddress = rewriter.getUI32IntegerAttr(
116 op.getAddress() + (i * 4 * sizeof(uint32_t)));
117
118 AIEX::NpuControlPacketOp::create(
119 rewriter, loc, incrementedAddress, nullptr, op.getOpcodeAttr(),
120 op.getStreamIdAttr(), DenseI32ArrayAttr::get(context, chunkData));
121 }
122
123 rewriter.eraseOp(op);
124 return success();
125 }
126
127private:
128 uint32_t max_payload_size;
129};
130
131struct AIETxnToControlPacketPass
132 : public xilinx::AIEX::impl::AIETxnToControlPacketBase<
133 AIETxnToControlPacketPass> {
134 void runOnOperation() override {
135 AIE::DeviceOp device = getOperation();
136
137 ConversionTarget target(getContext());
138 target.addIllegalOp<AIEX::NpuBlockWriteOp>();
139 target.addLegalOp<AIEX::NpuControlPacketOp>();
140
141 RewritePatternSet patterns(&getContext());
142 patterns.add<BlockWriteToControlPacketPattern>(&getContext());
143
144 if (failed(applyPartialConversion(device, target, std::move(patterns)))) {
145 signalPassFailure();
146 }
147 }
148};
149
150struct AIELegalizeControlPacketPass
151 : public xilinx::AIEX::impl::AIELegalizeControlPacketBase<
152 AIELegalizeControlPacketPass> {
153 void runOnOperation() override {
154 AIE::DeviceOp device = getOperation();
155
156 ConversionTarget target(getContext());
157 target.addDynamicallyLegalOp<AIEX::NpuControlPacketOp>([](Operation *op) {
158 auto packetOp = cast<AIEX::NpuControlPacketOp>(op);
159 // Check the data size
160 return packetOp.getDataAttr().size() <= 4;
161 });
162
163 RewritePatternSet patterns(&getContext());
164 patterns.add<ControlPacketSplitPattern>(&getContext(), 4);
165
166 if (failed(applyPartialConversion(device, target, std::move(patterns)))) {
167 signalPassFailure();
168 }
169 }
170};
171
172} // namespace
173
174std::unique_ptr<OperationPass<AIE::DeviceOp>>
176 return std::make_unique<AIETxnToControlPacketPass>();
177}
178
179std::unique_ptr<OperationPass<AIE::DeviceOp>>
181 return std::make_unique<AIELegalizeControlPacketPass>();
182}
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIETxnToControlPacketPass()
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIELegalizeControlPacketPass()