MLIR-AIE
AIETargetHSA.cpp
Go to the documentation of this file.
1//===- AIETargetHSA.cpp -----------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2021 Xilinx Inc.
8// (c) Copyright 2021-2023, Advanced Micro Devices, Inc.
9//
10//===----------------------------------------------------------------------===//
12
16
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
22
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/IR/Module.h"
25
26using namespace mlir;
27using namespace xilinx;
28using namespace xilinx::AIE;
29using namespace xilinx::AIEX;
30
31namespace xilinx::AIE {
32
// Preamble written verbatim at the very top of every generated HSA C++ file.
// It starts with an empty line, defines the __mlir_aie_verbose(x) macro (which
// expands to nothing when MLIR_AIE_QUIET is defined) and ends with an empty
// line so the generated code that follows is visually separated.
const char *hsa_cpp_file_header = R"code(
// This file was auto-generated by aiecc.py --aie-generate-hsa

#ifndef MLIR_AIE_QUIET
#define __mlir_aie_verbose(x) x
#else
#define __mlir_aie_verbose(x)
#endif

)code";
44
45std::optional<AIE::ShimDMAAllocationOp>
46getAllocOpForSymbol(AIE::DeviceOp dev, StringRef sym_name) {
47 auto sym = dev.lookupSymbol(sym_name);
48 if (!sym)
49 return std::nullopt;
50
51 auto uses = SymbolTable::getSymbolUses(sym, dev);
52 for (auto use : *uses)
53 if (auto infoOp = dyn_cast<AIE::ShimDMAAllocationOp>(use.getUser()))
54 return infoOp;
55
56 return std::nullopt;
57}
58
59mlir::LogicalResult AIETranslateToHSA(ModuleOp module, raw_ostream &output) {
60
61 DenseMap<TileID, Operation *> tiles;
62 DenseMap<Operation *, SmallVector<BufferOp, 4>> buffers;
63
64 if (module.getOps<DeviceOp>().empty())
65 return module.emitOpError("expected AIE.device operation at toplevel");
66 DeviceOp targetOp = *(module.getOps<DeviceOp>().begin());
67
68 // Putting the standard header
69 output << hsa_cpp_file_header;
70
71 // Getting the sequence function op which contains the instructions
72 auto sequenceOps = targetOp.getOps<AIEX::RuntimeSequenceOp>();
73 if (sequenceOps.empty()) {
74 // If no sequenceOp then just return
75 return success();
76 } else if (std::distance(sequenceOps.begin(), sequenceOps.end()) > 1) {
77 return module.emitOpError("expected at most one sequence operation");
78 }
79 AIEX::RuntimeSequenceOp sequenceOp = *sequenceOps.begin();
80
81 collectTiles(targetOp, tiles);
82 collectBuffers(targetOp, buffers);
83
84 // Generate dynamic data movement
85 output << "void invoke_data_movement(hsa_queue_t *q, hsa_agent_t *a";
86
87 // Looping over every Memcpy operation so we take the correct number of
88 // buffers
89 int num_ops = 0;
90 for (auto op : sequenceOp.getOps<NpuDmaMemcpyNdOp>()) {
91 // Getting the IDs of the buffers
92 auto memref = op.getMemref();
93 Block &entryBB =
94 op->getParentOfType<AIEX::RuntimeSequenceOp>().getBody().front();
95 int arg_idx = -1;
96 for (int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
97 if (entryBB.getArgument(i) == memref) {
98 arg_idx = i;
99 break;
100 }
101 }
102 num_ops++;
103
104 output << ", void *buf" << arg_idx;
105 }
106
107 output << ") {\n";
108
109 output << "\tuint64_t wr_idx = 0;\n";
110 output << "\tuint64_t packet_id = 0;\n";
111
112 int op_count = 0;
113 for (auto op : sequenceOp.getOps<NpuDmaMemcpyNdOp>()) {
114 auto dev = sequenceOp->getParentOfType<AIE::DeviceOp>();
115 if (!dev) {
116 op.emitOpError("couldn't get DeviceOp");
117 return failure();
118 }
119
120 auto infoOp = getAllocOpForSymbol(dev, op.getMetadata());
121 if (!infoOp) {
122 op.emitOpError("couldn't find shim_dma_allocation op");
123 return failure();
124 }
125
126 auto channelDir = infoOp->getChannelDir();
127 uint32_t ChannelId = infoOp->getChannelIndex();
128 bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S;
129 int col = infoOp->getCol();
130 bool isPlio = infoOp->getPlio();
131
132 llvm::SmallVector<int64_t, 4> strides = llvm::map_to_vector(
133 llvm::reverse(op.getMixedStrides()),
134 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
135 ::SmallVector<int64_t, 4> sizes = llvm::map_to_vector(
136 llvm::reverse(op.getMixedSizes()),
137 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
138 ::SmallVector<int64_t, 4> offsets = llvm::map_to_vector(
139 llvm::reverse(op.getMixedOffsets()),
140 [](OpFoldResult s) { return getConstantIntValue(s).value(); });
141
142 // buffer_offset
143 size_t stride = 1;
144 size_t offset = 0;
145 BaseMemRefType my_memref = op.getMemref().getType();
146 auto shape = my_memref.getShape();
147 size_t R = shape.size();
148 size_t el_bit_width = my_memref.getElementTypeBitWidth();
149 assert(el_bit_width % 8 == 0 &&
150 "Expected Memref element bitwidth to be multiple of 8.");
151 size_t S = el_bit_width / 8;
152 for (size_t i = 0; i < R; i++) {
153 offset += offsets[i] * stride * S;
154 stride *= shape[R - i - 1];
155 }
156
157 // Getting the ID of the buffer that we are using
158 auto memref = op.getMemref();
159 Block &entryBB =
160 op->getParentOfType<AIEX::RuntimeSequenceOp>().getBody().front();
161 int arg_idx = -1;
162 for (int i = 0, e = entryBB.getNumArguments(); i < e; i++) {
163 if (entryBB.getArgument(i) == memref) {
164 arg_idx = i;
165 break;
166 }
167 }
168
169 if (strides[0] != 1)
170 return module.emitOpError("nd_memcpy inner-dimension stride != 1 is "
171 "unsupported by HSA target");
172
173 // Writing the packet information to perform the DMA
174 output << "\thsa_agent_dispatch_packet_t pkt" << op_count << " ;\n";
175 output << "\twr_idx = hsa_queue_add_write_index_relaxed(q, 1);\n";
176 output << "\tpacket_id = wr_idx % q->size;\n";
177 output << "\tmlir_aie_packet_nd_memcpy(&pkt" << op_count
178 << ", 0 /* herd_id */, " << col << " /* col */, " << isMM2S
179 << " /* dir */, " << ChannelId
180 << "/* channel */, 4 /* Burst length */, " << (isPlio ? 1 : 2)
181 << " /* Memory space */, "
182 "(uint64_t)buf"
183 << arg_idx << " + " << offset << " /* Address */, " << sizes[0] * 4
184 << " /* 1d_length */, " << (strides[1] ? sizes[1] : 1)
185 << " /* 2d_length */, " << (strides[1] ? strides[1] * 4 : 0)
186 << " /* 2d_stride */, " << (strides[2] ? sizes[2] : 1)
187 << " /* 3d_length */, " << (strides[2] ? strides[2] * 4 : 0)
188 << " /* 3d_stride */ , 1 /* 4d_length */, 0 /* 4d_stride */);\n";
189
190 bool last_op = op_count == (num_ops - 1);
191 // Only ring the doorbell on the last packet
192 if (last_op) {
193 output
194 << "\tmlir_aie_queue_dispatch_and_wait(a, q, packet_id, wr_idx, &pkt"
195 << op_count << ", false);\n\n";
196 } else {
197 output << "\thsa_amd_signal_create_on_agent(1, 0, nullptr, a, 0, &pkt"
198 << op_count << ".completion_signal);\n";
199 output << "\tmlir_aie_write_pkt<hsa_agent_dispatch_packet_t>(q, "
200 "packet_id, &pkt"
201 << op_count << ");\n\n";
202 }
203
204 op_count++;
205 }
206
207 // Waiting to make sure each DMA is complete
208 for (int i = 0; i < op_count; i++) {
209 output << "\twhile (hsa_signal_wait_scacquire(pkt" << i
210 << ".completion_signal,\n";
211 output << "\tHSA_SIGNAL_CONDITION_EQ, 0, 0x80000,\n";
212 output << "\tHSA_WAIT_STATE_ACTIVE) != 0);\n";
213 }
214
215 // Destroying every signal that we created
216 for (int i = 0; i < op_count; i++) {
217 output << "\thsa_signal_destroy(pkt" << i << ".completion_signal);\n";
218 }
219
220 output << "}\n";
221
222 return success();
223}
224} // namespace xilinx::AIE
Include the generated interface declarations.
std::optional< AIE::ShimDMAAllocationOp > getAllocOpForSymbol(AIE::DeviceOp dev, StringRef sym_name)
mlir::LogicalResult AIETranslateToHSA(mlir::ModuleOp module, llvm::raw_ostream &output)
const char * hsa_cpp_file_header