61 DenseMap<TileID, Operation *> tiles;
62 DenseMap<Operation *, SmallVector<BufferOp, 4>> buffers;
64 if (module.getOps<DeviceOp>().empty())
65 return module.emitOpError("expected AIE.device operation at toplevel");
66 DeviceOp targetOp = *(
module.getOps<DeviceOp>().begin());
72 auto sequenceOps = targetOp.getOps<AIEX::RuntimeSequenceOp>();
73 if (sequenceOps.empty()) {
76 }
else if (std::distance(sequenceOps.begin(), sequenceOps.end()) > 1) {
77 return module.emitOpError("expected at most one sequence operation");
79 AIEX::RuntimeSequenceOp sequenceOp = *sequenceOps.begin();
81 collectTiles(targetOp, tiles);
82 collectBuffers(targetOp, buffers);
85 output <<
"void invoke_data_movement(hsa_queue_t *q, hsa_agent_t *a";
90 for (
auto op : sequenceOp.getOps<NpuDmaMemcpyNdOp>()) {
92 auto memref = op.getMemref();
94 op->getParentOfType<AIEX::RuntimeSequenceOp>().getBody().front();
96 for (
int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
97 if (entryBB.getArgument(i) == memref) {
104 output <<
", void *buf" << arg_idx;
109 output <<
"\tuint64_t wr_idx = 0;\n";
110 output <<
"\tuint64_t packet_id = 0;\n";
113 for (
auto op : sequenceOp.getOps<NpuDmaMemcpyNdOp>()) {
114 auto dev = sequenceOp->getParentOfType<AIE::DeviceOp>();
116 op.emitOpError(
"couldn't get DeviceOp");
122 op.emitOpError(
"couldn't find shim_dma_allocation op");
126 auto channelDir = infoOp->getChannelDir();
127 uint32_t ChannelId = infoOp->getChannelIndex();
128 bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S;
129 int col = infoOp->getCol();
130 bool isPlio = infoOp->getPlio();
132 llvm::SmallVector<int64_t, 4> strides = llvm::map_to_vector(
133 llvm::reverse(op.getMixedStrides()),
134 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
135 ::SmallVector<int64_t, 4> sizes = llvm::map_to_vector(
136 llvm::reverse(op.getMixedSizes()),
137 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
138 ::SmallVector<int64_t, 4> offsets = llvm::map_to_vector(
139 llvm::reverse(op.getMixedOffsets()),
140 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
145 BaseMemRefType my_memref = op.getMemref().getType();
146 auto shape = my_memref.getShape();
147 size_t R = shape.size();
148 size_t el_bit_width = my_memref.getElementTypeBitWidth();
149 assert(el_bit_width % 8 == 0 &&
150 "Expected Memref element bitwidth to be multiple of 8.");
151 size_t S = el_bit_width / 8;
152 for (
size_t i = 0; i < R; i++) {
153 offset += offsets[i] * stride * S;
154 stride *= shape[R - i - 1];
158 auto memref = op.getMemref();
160 op->getParentOfType<AIEX::RuntimeSequenceOp>().getBody().front();
162 for (
int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
163 if (entryBB.getArgument(i) == memref) {
170 return module.emitOpError("nd_memcpy inner-dimension stride != 1 is "
171 "unsupported by HSA target");
174 output <<
"\thsa_agent_dispatch_packet_t pkt" << op_count <<
" ;\n";
175 output <<
"\twr_idx = hsa_queue_add_write_index_relaxed(q, 1);\n";
176 output <<
"\tpacket_id = wr_idx % q->size;\n";
177 output <<
"\tmlir_aie_packet_nd_memcpy(&pkt" << op_count
178 <<
", 0 /* herd_id */, " <<
col <<
" /* col */, " << isMM2S
179 <<
" /* dir */, " << ChannelId
180 <<
"/* channel */, 4 /* Burst length */, " << (isPlio ? 1 : 2)
181 <<
" /* Memory space */, "
183 << arg_idx <<
" + " << offset <<
" /* Address */, " << sizes[0] * 4
184 <<
" /* 1d_length */, " << (strides[1] ? sizes[1] : 1)
185 <<
" /* 2d_length */, " << (strides[1] ? strides[1] * 4 : 0)
186 <<
" /* 2d_stride */, " << (strides[2] ? sizes[2] : 1)
187 <<
" /* 3d_length */, " << (strides[2] ? strides[2] * 4 : 0)
188 <<
" /* 3d_stride */ , 1 /* 4d_length */, 0 /* 4d_stride */);\n";
190 bool last_op = op_count == (num_ops - 1);
194 <<
"\tmlir_aie_queue_dispatch_and_wait(a, q, packet_id, wr_idx, &pkt"
195 << op_count <<
", false);\n\n";
197 output <<
"\thsa_amd_signal_create_on_agent(1, 0, nullptr, a, 0, &pkt"
198 << op_count <<
".completion_signal);\n";
199 output <<
"\tmlir_aie_write_pkt<hsa_agent_dispatch_packet_t>(q, "
201 << op_count <<
");\n\n";
208 for (
int i = 0; i < op_count; i++) {
209 output <<
"\twhile (hsa_signal_wait_scacquire(pkt" << i
210 <<
".completion_signal,\n";
211 output <<
"\tHSA_SIGNAL_CONDITION_EQ, 0, 0x80000,\n";
212 output <<
"\tHSA_WAIT_STATE_ACTIVE) != 0);\n";
216 for (
int i = 0; i < op_count; i++) {
217 output <<
"\thsa_signal_destroy(pkt" << i <<
".completion_signal);\n";