Program Listing for File client.cpp
Return to documentation for file (/workspace/amdinfer/src/amdinfer/clients/client.cpp)
// Copyright 2022 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "amdinfer/clients/client.hpp"
#include <algorithm> // for copy, copy_backward, min
#include <chrono> // for seconds, milliseconds
#include <future> // for future
#include <queue> // for queue
#include <stdexcept> // for invalid_argument
#include <thread> // for sleep_for
#include <unordered_set> // for operator!=, unordered_set
#include "amdinfer/build_options.hpp" // for AMDINFER_ENABLE_LOGGING
#include "amdinfer/core/exceptions.hpp" // for connection_error
#include "amdinfer/observation/logging.hpp" // for getLogDirectory, initLogger
namespace amdinfer {
// Sets up the "client" logger: Debug-level output to a file in the directory
// returned by getLogDirectory(), and Warn-level output to the console.
// When AMDINFER_ENABLE_LOGGING is not defined, this function does nothing.
void initializeClientLogging() {
#ifdef AMDINFER_ENABLE_LOGGING
  // Positional fields: logger name, log directory, file logging enabled,
  // file log level, console logging enabled, console log level.
  LogOptions client_log_options{"client", getLogDirectory(), true,
                                LogLevel::Debug, true, LogLevel::Warn};
  initLogger(client_log_options);
#endif
}
// Each client instance configures client-side logging when it is constructed.
Client::Client() {
  initializeClientLogging();
}
// Returns true if the server's metadata lists `extension` among its
// supported extensions. Performs one serverMetadata() call per invocation.
bool serverHasExtension(const Client* client, const std::string& extension) {
  const auto server_info = client->serverMetadata();
  return server_info.extensions.count(extension) != 0;
}
// Blocks until client->serverReady() returns true.
//
// A connection_error is treated as "not ready yet" (e.g. the server is still
// starting) and is not propagated; any other exception escapes to the caller.
// The function sleeps between unsuccessful probes, whether the server was
// unreachable or merely reported not-ready, so it never busy-spins.
void waitUntilServerReady(const Client* client) {
  const auto poll_interval = std::chrono::seconds(1);
  bool ready = false;
  while (!ready) {
    try {
      ready = client->serverReady();
    } catch (const amdinfer::connection_error&) {
      // Ignore connection errors: `ready` stays false and we retry below.
    }
    if (!ready) {
      std::this_thread::sleep_for(poll_interval);
    }
  }
}
// Blocks until client->modelReady(model) returns true.
//
// The function sleeps briefly between probes instead of spinning on the
// server. Exceptions thrown by modelReady() are not caught and propagate to
// the caller.
void waitUntilModelReady(const Client* client, const std::string& model) {
  const auto poll_interval = std::chrono::milliseconds(100);
  while (!client->modelReady(model)) {
    std::this_thread::sleep_for(poll_interval);
  }
}
// Sends every request asynchronously to `model`, then waits for the responses.
// Responses are returned in the same order as `requests`.
std::vector<InferenceResponse> inferAsyncOrdered(
  Client* client, const std::string& model,
  const std::vector<InferenceRequest>& requests) {
  // Dispatch all requests before waiting on any of them.
  std::vector<InferenceResponseFuture> pending;
  pending.reserve(requests.size());
  for (const auto& request : requests) {
    pending.push_back(client->modelInferAsync(model, request));
  }

  // Resolve the futures in submission order.
  std::vector<InferenceResponse> responses;
  responses.reserve(pending.size());
  for (auto& response_future : pending) {
    responses.push_back(response_future.get());
  }
  return responses;
}
// Sends `requests` to `model` in consecutive batches of at most `batch_size`.
// Each batch is dispatched asynchronously, and all of its responses are
// collected before the next batch is sent. Responses are returned in the same
// order as `requests`; the final batch may be smaller than `batch_size`.
//
// Throws std::invalid_argument if batch_size is 0, since no progress could be
// made.
std::vector<InferenceResponse> inferAsyncOrderedBatched(
  Client* client, const std::string& model,
  const std::vector<InferenceRequest>& requests, size_t batch_size) {
  if (batch_size == 0) {
    throw std::invalid_argument("batch_size must be greater than zero");
  }

  const auto num_requests = requests.size();
  std::vector<InferenceResponse> responses;
  responses.reserve(num_requests);

  std::queue<InferenceResponseFuture> q;
  size_t start_index = 0;
  while (start_index < num_requests) {
    // Clamp the batch to the remaining requests. Computing the end this way
    // avoids overflowing start_index + batch_size.
    const size_t end_index =
      start_index + std::min(batch_size, num_requests - start_index);
    for (size_t i = start_index; i < end_index; ++i) {
      q.push(client->modelInferAsync(model, requests[i]));
    }
    // Drain exactly the futures dispatched for this batch, in order.
    while (!q.empty()) {
      responses.push_back(q.front().get());
      q.pop();
    }
    start_index = end_index;
  }
  return responses;
}
} // namespace amdinfer