MLIR-AIE
test_utils.cpp
Go to the documentation of this file.
1//===- test_utils.cpp ----------------------------000---*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// Copyright (C) 2024, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
11// This file contains common helper functions for the generic host code
12
13#include "test_utils.h"
14#include <cassert>
15#include <filesystem>
16
17#ifdef TEST_UTILS_USE_XRT
18#include "xrt/xrt_device.h"
19#include "xrt/xrt_kernel.h"
20#endif
21
22// --------------------------------------------------------------------------
23// Command Line Argument Handling
24// --------------------------------------------------------------------------
25
27 std::string name) {
28 if (!result.count(name)) {
29 throw std::runtime_error("Missing required argument: " + name);
30 }
31 std::string path = result[name].as<std::string>();
32 if (!std::filesystem::exists(path)) {
33 throw std::runtime_error("File does not exist: " + path);
34 }
35}
36
38 options.add_options()("help,h", "produce help message")(
39 "xclbin,x", "the input xclbin path", cxxopts::value<std::string>())(
40 "kernel,k", "the kernel name in the XCLBIN (for instance PP_PRE_FD)",
41 cxxopts::value<std::string>())("verbosity,v",
42 "the verbosity of the output",
43 cxxopts::value<int>()->default_value("0"))(
44 "instr,i",
45 "path of file containing userspace instructions sent to the NPU",
46 cxxopts::value<std::string>())(
47 "verify", "whether to verify the AIE computed output",
48 cxxopts::value<bool>()->default_value("true"))(
49 "iters", "number of iterations",
50 cxxopts::value<int>()->default_value("1"))(
51 "warmup", "number of warmup iterations",
52 cxxopts::value<int>()->default_value("0"))(
53 "trace_sz,t", "trace size", cxxopts::value<int>()->default_value("0"))(
54 "trace_file", "where to store trace output",
55 cxxopts::value<std::string>()->default_value("trace.txt"));
56}
57
58void test_utils::parse_options(int argc, const char *argv[],
59 cxxopts::Options &options,
61 try {
62 vm = options.parse(argc, argv);
63
64 if (vm.count("help")) {
65 std::cout << options.help() << "\n";
66 std::exit(1);
67 }
68 } catch (const cxxopts::exceptions::parsing &e) {
69 std::cerr << e.what() << "\n\n";
70 std::cerr << "Usage:\n" << options.help() << "\n";
71 std::exit(1);
72 }
73
74 try {
75 check_arg_file_exists(vm, "xclbin");
76 check_arg_file_exists(vm, "instr");
77 } catch (const std::exception &ex) {
78 std::cerr << ex.what() << "\n\n";
79 }
80}
81
82// --------------------------------------------------------------------------
83// AIE Specifics
84// --------------------------------------------------------------------------
85
86std::vector<uint32_t> test_utils::load_instr_sequence(std::string instr_path) {
87 std::ifstream instr_file(instr_path);
88 std::string line;
89 std::vector<uint32_t> instr_v;
90 while (std::getline(instr_file, line)) {
91 std::istringstream iss(line);
92 uint32_t a;
93 if (!(iss >> std::hex >> a)) {
94 throw std::runtime_error("Unable to parse instruction file\n");
95 }
96 instr_v.push_back(a);
97 }
98 return instr_v;
99}
100
101std::vector<uint32_t> test_utils::load_instr_binary(std::string instr_path) {
102 // Open file in binary mode
103 std::ifstream instr_file(instr_path, std::ios::binary);
104 if (!instr_file.is_open()) {
105 throw std::runtime_error("Unable to open instruction file\n");
106 }
107
108 // Get the size of the file
109 instr_file.seekg(0, std::ios::end);
110 std::streamsize size = instr_file.tellg();
111 instr_file.seekg(0, std::ios::beg);
112
113 // Check that the file size is a multiple of 4 bytes (size of uint32_t)
114 if (size % 4 != 0) {
115 throw std::runtime_error("File size is not a multiple of 4 bytes\n");
116 }
117
118 // Allocate vector and read the binary data
119 std::vector<uint32_t> instr_v(size / 4);
120 if (!instr_file.read(reinterpret_cast<char *>(instr_v.data()), size)) {
121 throw std::runtime_error("Failed to read instruction file\n");
122 }
123 return instr_v;
124}
125
126#ifdef TEST_UTILS_USE_XRT
127
128// --------------------------------------------------------------------------
129// XRT
130// --------------------------------------------------------------------------
131void test_utils::init_xrt_load_kernel(xrt::device &device, xrt::kernel &kernel,
132 int verbosity, std::string xclbinFileName,
133 std::string kernelNameInXclbin) {
134 // Get a device handle
135 unsigned int device_index = 0;
136 device = xrt::device(device_index);
137
138 // Load the xclbin
139 if (verbosity >= 1)
140 std::cout << "Loading xclbin: " << xclbinFileName << "\n";
141 auto xclbin = xrt::xclbin(xclbinFileName);
142
143 if (verbosity >= 1)
144 std::cout << "Kernel opcode: " << kernelNameInXclbin << "\n";
145
146 // Get the kernel from the xclbin
147 auto xkernels = xclbin.get_kernels();
148 auto xkernel =
149 *std::find_if(xkernels.begin(), xkernels.end(),
150 [kernelNameInXclbin, verbosity](xrt::xclbin::kernel &k) {
151 auto name = k.get_name();
152 if (verbosity >= 1) {
153 std::cout << "Name: " << name << std::endl;
154 }
155 return name.rfind(kernelNameInXclbin, 0) == 0;
156 });
157 auto kernelName = xkernel.get_name();
158 // Register xclbin
159 if (verbosity >= 1)
160 std::cout << "Registering xclbin: " << xclbinFileName << "\n";
161
162 device.register_xclbin(xclbin);
163
164 // Get a hardware context
165 if (verbosity >= 1)
166 std::cout << "Getting hardware context.\n";
167 xrt::hw_context context(device, xclbin.get_uuid());
168
169 // Get a kernel handle
170 if (verbosity >= 1)
171 std::cout << "Getting handle to kernel:" << kernelName << "\n";
172 kernel = xrt::kernel(context, kernelName);
173
174 return;
175}
176
177#endif // TEST_UTILS_USE_XRT
178
179// --------------------------------------------------------------------------
180// Matrix / Float / Math
181// --------------------------------------------------------------------------
182
183// nearly_equal function adapted from Stack Overflow, License CC BY-SA 4.0
184// Original author: P-Gn
185// Source: https://stackoverflow.com/a/32334103
186bool test_utils::nearly_equal(float a, float b, float epsilon, float abs_th)
187// those defaults are arbitrary and could be removed
188{
189 assert(std::numeric_limits<float>::epsilon() <= epsilon);
190 assert(epsilon < 1.f);
191
192 if (a == b)
193 return true;
194
195 auto diff = std::abs(a - b);
196 auto norm =
197 std::min((std::abs(a) + std::abs(b)), std::numeric_limits<float>::max());
198 // or even faster: std::min(std::abs(a + b),
199 // std::numeric_limits<float>::max()); keeping this commented out until I
200 // update figures below
201 return diff < std::max(abs_th, epsilon * norm);
202}
203
204// --------------------------------------------------------------------------
205// Tracing
206// --------------------------------------------------------------------------
207void test_utils::write_out_trace(char *traceOutPtr, size_t trace_size,
208 std::string path) {
209 std::ofstream fout(path);
210 uint32_t *traceOut = (uint32_t *)traceOutPtr;
211 for (int i = 0; i < trace_size / sizeof(traceOut[0]); i++) {
212 fout << std::setfill('0') << std::setw(8) << std::hex << (int)traceOut[i];
213 fout << std::endl;
214 }
215}
std::string help(const std::vector< std::string > &groups={}, bool print_usage=true) const
Definition cxxopts.hpp:2120
ParseResult parse(int argc, const char *const *argv)
Definition cxxopts.hpp:1835
OptionAdder add_options(std::string group="")
Definition cxxopts.hpp:1707
std::size_t count(const std::string &o) const
Definition cxxopts.hpp:1302
CXXOPTS_NODISCARD const char * what() const noexcept override
Definition cxxopts.hpp:338
bool nearly_equal(float a, float b, float epsilon=128 *FLT_EPSILON, float abs_th=FLT_MIN)
std::vector< uint32_t > load_instr_binary(std::string instr_path)
void init_xrt_load_kernel(xrt::device &device, xrt::kernel &kernel, int verbosity, std::string xclbinFileName, std::string kernelNameInXclbin)
void parse_options(int argc, const char *argv[], cxxopts::Options &options, cxxopts::ParseResult &result)
void write_out_trace(char *traceOutPtr, size_t trace_size, std::string path)
void check_arg_file_exists(const cxxopts::ParseResult &result, std::string name)
std::vector< uint32_t > load_instr_sequence(std::string instr_path)
void add_default_options(cxxopts::Options &options)