Program Listing for File manager.cpp¶
↰ Return to documentation for file (/workspace/amdinfer/src/amdinfer/core/manager.cpp)
// Copyright 2021 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "amdinfer/core/manager.hpp"
#include <thread> // for yield, thread
#include <utility> // for pair, make_pair, move
#include "amdinfer/build_options.hpp" // for kMaxModelNameSize
#include "amdinfer/core/exceptions.hpp" // for invalid_argument
#include "amdinfer/core/worker_info.hpp" // for WorkerInfo
#include "amdinfer/util/thread.hpp" // for setThreadName
#include "amdinfer/workers/worker.hpp" // for Worker
namespace amdinfer {
Manager::Manager() {
update_queue_ = std::make_unique<UpdateCommandQueue>();
init();
}
Manager::~Manager() { shutdown(); }
void Manager::init() {
// default constructed threads are not joinable
if (!update_thread_.joinable()) {
update_thread_ =
std::thread(&Manager::updateManager, this, update_queue_.get());
}
}
std::string Manager::loadWorker(std::string const& key,
RequestParameters parameters) {
std::shared_ptr<amdinfer::UpdateCommand> request;
std::string retval;
retval.reserve(kMaxModelNameSize);
retval = "";
request = std::make_shared<UpdateCommand>(UpdateCommandType::Add, key,
¶meters, &retval);
update_queue_->enqueue(request);
while (static_cast<std::string*>(request->retval)->empty() &&
request->eptr == nullptr) {
std::this_thread::yield();
}
if (request->eptr != nullptr) {
std::rethrow_exception(request->eptr);
}
auto endpoint = *(static_cast<std::string*>(request->retval));
return endpoint;
}
void Manager::unloadWorker(const std::string& key) {
if (this->endpoints_.exists(key)) {
auto request =
std::make_shared<UpdateCommand>(UpdateCommandType::Delete, key);
update_queue_->enqueue(request);
}
}
WorkerInfo* Manager::getWorker(const std::string& key) const {
return this->endpoints_.get(key);
}
bool Manager::workerReady(const std::string& key) const {
std::shared_ptr<amdinfer::UpdateCommand> request;
int ready = -1;
request = std::make_shared<UpdateCommand>(UpdateCommandType::Ready, key,
nullptr, &ready);
update_queue_->enqueue(request);
while (ready == -1 && request->eptr == nullptr) {
std::this_thread::yield();
}
if (request->eptr != nullptr) {
std::rethrow_exception(request->eptr);
}
return ready != 0;
}
// FIXME(varunsh): potential race condition if the worker is being deleted
ModelMetadata Manager::getWorkerMetadata(const std::string& key) const {
auto* worker = this->getWorker(key);
if (worker == nullptr) {
throw invalid_argument("Worker " + key + " not found");
}
auto* worker_class = worker->workers_.begin()->second;
return worker_class->getMetadata();
}
void Manager::workerAllocate(std::string const& key, int num) {
const auto* worker = this->getWorker(key);
if (worker == nullptr) {
throw invalid_argument("Worker " + key + " not found");
}
auto request =
std::make_shared<UpdateCommand>(UpdateCommandType::Allocate, key, &num);
update_queue_->enqueue(request);
while (!worker->inputSizeValid(num) && request->eptr == nullptr) {
std::this_thread::yield();
}
if (request->eptr != nullptr) {
std::rethrow_exception(request->eptr);
}
}
std::vector<std::string> Manager::getWorkerEndpoints() {
return this->endpoints_.list();
}
// TODO(varunsh): if multiple commands sent post-shutdown, they will linger
// in the queue and may cause problems
void Manager::shutdown() {
if (this->update_thread_.joinable()) {
auto request = std::make_shared<UpdateCommand>(UpdateCommandType::Shutdown);
this->update_queue_->enqueue(request);
this->update_thread_.join();
}
}
void Manager::updateManager(UpdateCommandQueue* input_queue) {
AMDINFER_LOG_DEBUG(logger_, "Starting the Manager update thread");
util::setThreadName("manager");
std::shared_ptr<UpdateCommand> request;
bool run = true;
while (run) {
input_queue->wait_dequeue(request);
AMDINFER_LOG_DEBUG(logger_,
"Got request in Manager update thread with ID " +
std::to_string(static_cast<int>(request->cmd)));
switch (request->cmd) {
case UpdateCommandType::Shutdown:
this->endpoints_.shutdown();
run = false;
break;
case UpdateCommandType::Delete:
this->endpoints_.unload(request->key);
break;
case UpdateCommandType::Allocate:
try {
auto* worker_info = this->endpoints_.get(request->key);
auto num = *static_cast<int*>(request->object);
if (!worker_info->inputSizeValid(num)) {
AMDINFER_LOG_DEBUG(
logger_,
"Allocating more buffers for worker " + request->key);
worker_info->allocate(num);
}
} catch (...) {
request->eptr = std::current_exception();
}
break;
case UpdateCommandType::Add:
try {
auto* parameters = static_cast<RequestParameters*>(request->object);
auto endpoint = this->endpoints_.add(request->key, *parameters);
static_cast<std::string*>(request->retval)
->assign(std::string{endpoint});
} catch (...) {
request->eptr = std::current_exception();
}
break;
case UpdateCommandType::Ready:
try {
auto* worker_info = this->getWorker(request->key);
if (worker_info == nullptr) {
throw invalid_argument("Worker " + request->key + " not found");
}
auto* worker = worker_info->workers_.begin()->second;
auto metadata = worker->getMetadata();
*static_cast<int*>(request->retval) =
static_cast<int>(metadata.isReady());
} catch (...) {
request->eptr = std::current_exception();
}
break;
}
}
AMDINFER_LOG_DEBUG(logger_, "Ending update_thread");
}
std::string Manager::Endpoints::load(const std::string& worker,
RequestParameters* parameters) {
if (worker_endpoints_.find(worker) == worker_endpoints_.end()) {
// this is a brand-new worker we haven't seen before
std::map<RequestParameters, std::string> map;
map.insert(std::make_pair(*parameters, worker));
worker_endpoints_.insert(std::make_pair(worker, map));
worker_parameters_.insert(std::make_pair(worker, *parameters));
return worker;
}
auto& map = worker_endpoints_.at(worker);
// we've seen this worker before but not with these parameters
if (map.find(*parameters) == map.end()) {
int index = 0;
if (worker_indices_.find(worker) != worker_indices_.end()) {
// technically, this can overflow and cause problems but that's unlikely
index = worker_indices_.at(worker) + 1;
}
std::string url = worker + "-" + std::to_string(index);
map.insert(std::make_pair(*parameters, url));
worker_indices_.insert_or_assign(worker, index);
worker_parameters_.insert(std::make_pair(url, *parameters));
return url;
}
return map.at(*parameters);
}
void Manager::Endpoints::unload(const std::string& endpoint) {
auto hyphen_pos = endpoint.find('-');
auto worker =
hyphen_pos != std::string::npos ? endpoint.substr(0, hyphen_pos) : endpoint;
auto* worker_info = this->get(endpoint);
if (worker_info != nullptr) {
worker_info->unload();
}
// if it's a brand-new worker that failed or the last worker being unloaded,
// clean up our parameters and endpoint metadata
if (worker_info == nullptr || worker_info->getGroupSize() == 0) {
this->workers_.erase(endpoint);
if (worker_endpoints_.find(worker) != worker_endpoints_.end()) {
auto& map = worker_endpoints_.at(worker);
if (worker_parameters_.find(endpoint) != worker_parameters_.end()) {
const auto& parameters = worker_parameters_.at(endpoint);
map.erase(parameters);
}
if (map.empty()) {
worker_endpoints_.erase(worker);
worker_indices_.erase(worker);
}
worker_parameters_.erase(endpoint);
}
}
}
bool Manager::Endpoints::exists(const std::string& endpoint) {
return workers_.find(endpoint) != workers_.end();
}
std::vector<std::string> Manager::Endpoints::list() const {
std::vector<std::string> workers;
workers.reserve(this->workers_.size());
for (const auto& [worker, _] : workers_) {
workers.push_back(worker);
}
return workers;
}
WorkerInfo* Manager::Endpoints::get(const std::string& endpoint) const {
auto iterator = workers_.find(endpoint);
if (iterator != workers_.end()) {
return iterator->second.get();
}
return nullptr;
}
std::string Manager::Endpoints::add(const std::string& worker,
RequestParameters parameters) {
bool share = true;
if (parameters.has("share")) {
share = parameters.get<bool>("share");
parameters.erase("share");
}
auto endpoint = this->load(worker, ¶meters);
auto* worker_info = this->get(endpoint);
std::string worker_name = endpoint;
if (parameters.has("worker")) {
worker_name = parameters.get<std::string>("worker");
}
// if the worker doesn't exist yet, we need to create it
try {
if (worker_info == nullptr) {
auto new_worker = std::make_unique<WorkerInfo>(worker_name, ¶meters);
this->workers_.try_emplace(endpoint, std::move(new_worker));
// if the worker exists but the share parameter is false, we need to add
// one
} else if (!share) {
worker_info->addAndStartWorker(worker_name, ¶meters);
}
} catch (...) {
// undo the load if the worker creation fails
this->unload(endpoint);
throw;
}
return endpoint;
}
void Manager::Endpoints::shutdown() {
for (auto const& worker_info : this->workers_) {
worker_info.second->shutdown();
}
this->workers_.clear();
this->worker_endpoints_.clear();
this->worker_indices_.clear();
this->worker_parameters_.clear();
}
} // namespace amdinfer