MLIR-AIE
ADFGenerateCppGraph.cpp
Go to the documentation of this file.
1//===- ADFGenerateCppGraph.cpp ----------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2021 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10
12
15
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/Pass/Pass.h"
20
21#include "llvm/Support/FileSystem.h"
22
23#include <unordered_map>
24#include <vector>
25
26using namespace mlir;
27using namespace xilinx;
28using namespace xilinx::ADF;
29
/// Nesting depth of the text currently being emitted. Every live Indent
/// object contributes its `indent` amount; the stream operator for Indent
/// reads this value to decide how much leading whitespace to write.
static int currentindent = 0;

/// Scope guard for the nesting depth: constructing one adds `indent` levels to
/// currentindent, destroying it takes the same amount away again.
struct Indent {
  int indent = 1;
  Indent() { currentindent += indent; }
  ~Indent() { currentindent -= indent; }
};

/// Sets the nesting depth back to zero, irrespective of live Indent objects.
static void resetIndent() { currentindent = 0; }
38
39raw_ostream &operator<<(raw_ostream &os, const Indent &indent) {
40 for (int i = 0; i < currentindent; ++i)
41 os << " ";
42 return os;
43}
44
46 raw_ostream &output;
47 GraphWriter(raw_ostream &output) : output(output) {}
48
49 // maps KernelOp to the generated c++ variable name.
50 std::unordered_map<Operation *, std::string> kernelOp2VarName;
51
52 StringRef getCTypeString(const Type &type) {
53 if (llvm::dyn_cast<int8Type>(type))
54 return int8Type::getMnemonic();
55 if (llvm::dyn_cast<int16Type>(type))
56 return int16Type::getMnemonic();
57 if (llvm::dyn_cast<int32Type>(type))
58 return int32Type::getMnemonic();
59 if (llvm::dyn_cast<int64Type>(type))
60 return int64Type::getMnemonic();
61 if (llvm::dyn_cast<uint8Type>(type))
62 return uint8Type::getMnemonic();
63 if (llvm::dyn_cast<uint16Type>(type))
64 return uint16Type::getMnemonic();
65 if (llvm::dyn_cast<uint32Type>(type))
66 return uint32Type::getMnemonic();
67 if (llvm::dyn_cast<uint64Type>(type))
68 return uint64Type::getMnemonic();
69 if (llvm::dyn_cast<floatType>(type))
70 return floatType::getMnemonic();
71 llvm::report_fatal_error("unknown type");
72 }
73
74 std::string getKernelTypeString(const std::string &direction, Type type) {
75 if (auto window = llvm::dyn_cast<WindowType>(type))
76 return (direction + "_window_" + getCTypeString(window.getType()) + " *")
77 .str();
78 if (auto stream = llvm::dyn_cast<StreamType>(type))
79 return (direction + "_stream_" + getCTypeString(stream.getType()) + " *")
80 .str();
81 if (auto stream = llvm::dyn_cast<ParameterType>(type))
82 return std::string(getCTypeString(stream.getType()));
83
84 llvm::report_fatal_error("unknown kernel type");
85 }
86
87 std::string getConnectionTypeString(Type type) {
88 if (auto windowType = llvm::dyn_cast<WindowType>(type))
89 return std::string("window<") + std::to_string(windowType.getSize()) +
90 "> ";
91 if (llvm::dyn_cast<StreamType>(type))
92 return "stream";
93 if (llvm::dyn_cast<ParameterType>(type))
94 return "parameter";
95 llvm::report_fatal_error("unknown connection type");
96 }
97
/// Returns a fresh net name of the form "n<k>". The counter is a function-local
/// static: numbering starts at 0 and keeps increasing with every call made in
/// the process.
std::string getTempNetName() {
  static uint32_t netCnt = 0;
  const uint32_t id = netCnt++;
  return "n" + std::to_string(id);
}
102
103 void visitOpResultUsers(GraphInputOp driverOp) {
104 Indent indent;
105 for (auto indexedResult : llvm::enumerate(driverOp->getResults())) {
106 Value result = indexedResult.value();
107 for (OpOperand &userOperand : result.getUses()) {
108 Operation *userOp = userOperand.getOwner();
109 int targetIndex = userOperand.getOperandNumber();
110 if (auto kernel = dyn_cast<KernelOp>(userOp)) {
111 auto funcOp = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112 driverOp, kernel.getCalleeAttr());
113 Type opType = funcOp.getFunctionType().getInput(targetIndex);
114 std::string targetKernelName = kernelOp2VarName[kernel];
115 output << indent << "connect<" << getConnectionTypeString(opType)
116 << "> ";
117 output << getTempNetName() << " (" << driverOp.getName() << ", "
118 << targetKernelName << ".in[" << targetIndex << "]);\n";
119 }
120 // todo: kernel should not drive graph input, add an mlir verifier
121 // condition
122 }
123 }
124 }
125
126 void visitOpResultUsers(KernelOp source) {
127 Indent indent;
128 std::string sourceKernelName = kernelOp2VarName[source];
129
130 unsigned sourceIndex = 0;
131 for (auto indexedResult : llvm::enumerate(source->getResults())) {
132 Value result = indexedResult.value();
133 for (OpOperand &userOperand : result.getUses()) {
134 Operation *userOp = userOperand.getOwner();
135 int targetIndex = userOperand.getOperandNumber();
136 if (auto kernel = dyn_cast<KernelOp>(userOp)) {
137 auto funcOp = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
138 kernel, kernel.getCalleeAttr());
139 Type opType = funcOp.getFunctionType().getInput(targetIndex);
140 auto targetKernelName = kernelOp2VarName[kernel];
141 output << indent << "connect<" << getConnectionTypeString(opType)
142 << "> ";
143 output << getTempNetName() << " (" << sourceKernelName << ".out["
144 << sourceIndex << "], " << targetKernelName << ".in["
145 << targetIndex << "]);\n";
146 } else if (auto outputOp = dyn_cast<GraphOutputOp>(userOp)) {
147 auto funcOp = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
148 source, source.getCalleeAttr());
149 Type opType = funcOp.getFunctionType().getInput(sourceIndex);
150 output << indent << "connect<" << getConnectionTypeString(opType)
151 << "> ";
152 output << getTempNetName() << " (" << sourceKernelName << ".out["
153 << sourceIndex << "], " << outputOp.getName() << ");\n";
154 }
155 // todo: kernel should not drive graph input, add an mlir verifier
156 // condition
157 }
158 sourceIndex++;
159 }
160 }
161
162 void writeKernelFunctions(ModuleOp module) {
163 output << "#include <adf.h>\n";
164 output << "#ifndef FUNCTION_KERNELS_H\n";
165 output << "#define FUNCTION_KERNELS_H\n\n";
166
167 for (Block &block : module.getBodyRegion())
168 for (auto funcOp : block.getOps<func::FuncOp>()) {
169 output << "void " << funcOp.getSymName() << "(";
170 FunctionType type = funcOp.getFunctionType();
171 for (unsigned i = 0; i < type.getNumInputs(); i++)
172 output << getKernelTypeString("input", type.getInput(i)) << " in" << i
173 << ", ";
174
175 for (unsigned i = 0; i < type.getNumResults(); i++) {
176 output << getKernelTypeString("output", type.getResult(i)) << " out"
177 << i;
178 if (i < type.getNumResults() - 1)
179 output << ", ";
180 else
181 output << ");\n";
182 }
183 }
184 output << "#endif\n\n";
185 }
186
187 void writeClass(GraphOp graph) {
188 output << "#include <adf.h>\n";
189 output << "using namespace adf;\n";
190 output << "class " << graph.getName() << " : public graph {\n";
191 output << "private:\n";
192 int kCnt = 1;
193 {
194 Indent indent;
195 for (Region &region : graph->getRegions())
196 for (Block &block : region.getBlocks())
197 for (const auto kernel : block.getOps<KernelOp>()) {
198 // collect and initialize some kernel info
199 std::string varName = "k" + std::to_string(kCnt);
200 output << indent << "kernel " << varName << ";\n";
201 kernelOp2VarName[kernel] = varName;
202 kCnt++;
203 }
204 }
205
206 output << "\npublic:\n";
207 Indent indent;
208 for (auto op : graph.getBody()->getOps<GraphInputOp>())
209 output << indent << "input_port " << op.getName() << ";\n";
210 for (auto op : graph.getBody()->getOps<GraphOutputOp>())
211 output << indent << "output_port " << op.getName() << ";\n";
212 for (auto op : graph.getBody()->getOps<GraphInOutOp>())
213 output << indent << "inout_port " << op.getName() << ";\n";
214
215 output << "\n" << indent << graph.getName() << "() {\n";
216 // initialize the kernel instances in the adf c++ graph
217 {
218 Indent indent;
219 for (Region &region : graph->getRegions())
220 for (Block &block : region.getBlocks())
221 for (auto kernel : block.getOps<KernelOp>()) {
222 output << indent << kernelOp2VarName[kernel] << " = kernel::create("
223 << kernel.getCallee().str() << ");\n";
224 }
225 }
226
227 output << "\n";
228
229 for (Region &region : graph->getRegions()) {
230 for (Block &block : region.getBlocks()) {
231 for (Operation &op : block.getOperations()) {
232 if (auto port = dyn_cast<GraphInputOp>(op)) {
233 visitOpResultUsers(port);
234 } else if (auto graph = dyn_cast<KernelOp>(op)) {
235 visitOpResultUsers(graph);
236 } else if (auto graph = dyn_cast<GraphOutputOp>(op)) {
237 // the graph output should have no users in adf, do nothing here
238 }
239 } // all op visited
240 }
241 }
242
243 {
244 Indent indent;
245 for (Region &region : graph->getRegions())
246 for (Block &block : region.getBlocks())
247 for (auto kernel : block.getOps<KernelOp>()) {
248 output << indent << "source(" << kernelOp2VarName[kernel] << ") = "
249 << "\"kernels.cc\";\n";
250 output << indent << "runtime<ratio>(" << kernelOp2VarName[kernel]
251 << ") = "
252 << "0.1;\n";
253 }
254 }
255
256 output << indent << "}\n";
257 output << "};\n\n";
258 }
259};
260
261LogicalResult AIE::ADFGenerateCPPGraph(ModuleOp module, raw_ostream &output) {
262 GraphWriter writer(output);
263 resetIndent();
264
265 writer.writeKernelFunctions(module);
266
267 for (Block &block : module.getBodyRegion())
268 for (auto graphOp : block.getOps<GraphOp>())
269 writer.writeClass(graphOp);
270 return success();
271}
raw_ostream & operator<<(raw_ostream &os, const Indent &indent)
mlir::LogicalResult ADFGenerateCPPGraph(mlir::ModuleOp module, llvm::raw_ostream &output)
raw_ostream & output
void visitOpResultUsers(GraphInputOp driverOp)
GraphWriter(raw_ostream &output)
std::string getConnectionTypeString(Type type)
StringRef getCTypeString(const Type &type)
std::string getTempNetName()
void visitOpResultUsers(KernelOp source)
std::string getKernelTypeString(const std::string &direction, Type type)
void writeKernelFunctions(ModuleOp module)
std::unordered_map< Operation *, std::string > kernelOp2VarName
void writeClass(GraphOp graph)