46 llvm::StringRef deviceName) {
48 DenseMap<TileID, Operation *> tiles;
49 DenseMap<Operation *, SmallVector<BufferOp, 4>> buffers;
51 DeviceOp targetOp = AIE::DeviceOp::getForSymbolInModule(module, deviceName);
53 return module.emitOpError("expected AIE.device operation at toplevel");
59 auto sequenceOps = targetOp.getOps<AIE::RuntimeSequenceOp>();
60 if (sequenceOps.empty()) {
63 }
else if (std::distance(sequenceOps.begin(), sequenceOps.end()) > 1) {
64 return module.emitOpError("expected at most one sequence operation");
66 AIE::RuntimeSequenceOp sequenceOp = *sequenceOps.begin();
72 output <<
"void invoke_data_movement(hsa_queue_t *q, hsa_agent_t *a";
77 for (
auto op : sequenceOp.getOps<NpuDmaMemcpyNdOp>()) {
79 auto memref = op.getMemref();
81 op->getParentOfType<AIE::RuntimeSequenceOp>().getBody().front();
83 for (
int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
84 if (entryBB.getArgument(i) == memref) {
91 output <<
", void *buf" << arg_idx;
96 output <<
"\tuint64_t wr_idx = 0;\n";
97 output <<
"\tuint64_t packet_id = 0;\n";
100 for (
auto op : sequenceOp.getOps<NpuDmaMemcpyNdOp>()) {
101 auto dev = sequenceOp->getParentOfType<AIE::DeviceOp>();
103 op.emitOpError(
"couldn't get DeviceOp");
107 AIE::ShimDMAAllocationOp infoOp = AIE::ShimDMAAllocationOp::getForSymbol(
108 dev, op.getMetadata().getRootReference());
110 op.emitOpError(
"couldn't find shim_dma_allocation op");
114 AIE::TileOp tile = infoOp.getTileOp();
116 op.emitOpError(
"shim_dma_allocation op must reference a valid TileOp");
120 auto channelDir = infoOp.getChannelDir();
121 uint32_t ChannelId = infoOp.getChannelIndex();
122 bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S;
123 int col = tile.getCol();
124 bool isPlio = infoOp.getPlio();
126 llvm::SmallVector<int64_t, 4> strides = llvm::map_to_vector(
127 llvm::reverse(op.getMixedStrides()),
128 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
129 ::SmallVector<int64_t, 4> sizes = llvm::map_to_vector(
130 llvm::reverse(op.getMixedSizes()),
131 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
132 ::SmallVector<int64_t, 4> offsets = llvm::map_to_vector(
133 llvm::reverse(op.getMixedOffsets()),
134 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
139 BaseMemRefType my_memref = op.getMemref().getType();
140 auto shape = my_memref.getShape();
141 size_t R = shape.size();
142 size_t el_bit_width = op.getElementTypeBitwidth();
143 assert(el_bit_width % 8 == 0 &&
144 "Expected Memref element bitwidth to be multiple of 8.");
145 size_t S = el_bit_width / 8;
146 for (
size_t i = 0; i < R; i++) {
147 offset += offsets[i] * stride * S;
148 stride *= shape[R - i - 1];
152 auto memref = op.getMemref();
154 op->getParentOfType<AIE::RuntimeSequenceOp>().getBody().front();
156 for (
int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
157 if (entryBB.getArgument(i) == memref) {
164 return module.emitOpError("nd_memcpy inner-dimension stride != 1 is "
165 "unsupported by HSA target");
168 output <<
"\thsa_agent_dispatch_packet_t pkt" << op_count <<
" ;\n";
169 output <<
"\twr_idx = hsa_queue_add_write_index_relaxed(q, 1);\n";
170 output <<
"\tpacket_id = wr_idx % q->size;\n";
171 output <<
"\tmlir_aie_packet_nd_memcpy(&pkt" << op_count
172 <<
", 0 /* herd_id */, " <<
col <<
" /* col */, " << isMM2S
173 <<
" /* dir */, " << ChannelId
174 <<
"/* channel */, 4 /* Burst length */, " << (isPlio ? 1 : 2)
175 <<
" /* Memory space */, "
177 << arg_idx <<
" + " << offset <<
" /* Address */, " << sizes[0] * 4
178 <<
" /* 1d_length */, " << (strides[1] ? sizes[1] : 1)
179 <<
" /* 2d_length */, " << (strides[1] ? strides[1] * 4 : 0)
180 <<
" /* 2d_stride */, " << (strides[2] ? sizes[2] : 1)
181 <<
" /* 3d_length */, " << (strides[2] ? strides[2] * 4 : 0)
182 <<
" /* 3d_stride */ , 1 /* 4d_length */, 0 /* 4d_stride */);\n";
184 bool last_op = op_count == (num_ops - 1);
188 <<
"\tmlir_aie_queue_dispatch_and_wait(a, q, packet_id, wr_idx, &pkt"
189 << op_count <<
", false);\n\n";
191 output <<
"\thsa_amd_signal_create_on_agent(1, 0, nullptr, a, 0, &pkt"
192 << op_count <<
".completion_signal);\n";
193 output <<
"\tmlir_aie_write_pkt<hsa_agent_dispatch_packet_t>(q, "
195 << op_count <<
");\n\n";
202 for (
int i = 0; i < op_count; i++) {
203 output <<
"\twhile (hsa_signal_wait_scacquire(pkt" << i
204 <<
".completion_signal,\n";
205 output <<
"\tHSA_SIGNAL_CONDITION_EQ, 0, 0x80000,\n";
206 output <<
"\tHSA_WAIT_STATE_ACTIVE) != 0);\n";
210 for (
int i = 0; i < op_count; i++) {
211 output <<
"\thsa_signal_destroy(pkt" << i <<
".completion_signal);\n";