Program Listing for File model_repository.cpp
↰ Return to documentation for file (/workspace/amdinfer/src/amdinfer/core/model_repository.cpp)
// Copyright 2022 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "amdinfer/core/model_repository.hpp"
#include <fcntl.h> // for open, O_RDONLY
#include <google/protobuf/io/zero_copy_stream_impl.h> // for FileInputStream
#include <google/protobuf/repeated_ptr_field.h> // for RepeatedPtrField
#include <google/protobuf/text_format.h> // for TextFormat
#include <chrono> // for milliseconds
#include <filesystem> // for path, operator/
#include <thread> // for sleep_for
#include "amdinfer/core/api.hpp" // for modelLoad
#include "amdinfer/core/exceptions.hpp" // for file_not_found...
#include "amdinfer/core/manager.hpp" // for Manager
#include "amdinfer/core/predict_api.hpp" // for RequestParameters
#include "amdinfer/observation/logging.hpp" // for Logger, PROTEU...
#include "model_config.pb.h" // for Config, InferP...
namespace fs = std::filesystem;
namespace amdinfer {
// TODO(varunsh): get rid of this duplicate code with the one in grpc_internal
/// Copy every set entry of a protobuf parameter map into a RequestParameters
/// object, converting each value to its native C++ type. Unset parameters
/// are silently skipped.
void mapProtoToParameters2(
  const google::protobuf::Map<std::string, inference::InferParameter2>& params,
  RequestParameters* parameters) {
  using ChoiceCase = inference::InferParameter2::ParameterChoiceCase;
  for (const auto& [name, proto_param] : params) {
    switch (proto_param.parameter_choice_case()) {
      case ChoiceCase::kBoolParam:
        parameters->put(name, proto_param.bool_param());
        break;
      case ChoiceCase::kInt64Param:
        // TODO(varunsh): parameters should switch to uint64?
        parameters->put(name, static_cast<int>(proto_param.int64_param()));
        break;
      case ChoiceCase::kDoubleParam:
        parameters->put(name, proto_param.double_param());
        break;
      case ChoiceCase::kStringParam:
        parameters->put(name, proto_param.string_param());
        break;
      default:
        // parameter choice not set: nothing to copy for this key
        break;
    }
  }
}
// Facade: forward a model-load request to the repository implementation,
// which reads the model's config.pbtxt and fills `parameters` with the
// worker name and model-specific settings.
void ModelRepository::modelLoad(const std::string& model,
RequestParameters* parameters) {
repo_.modelLoad(model, parameters);
}
// Facade: set the filesystem path of the model repository on the
// implementation object.
void ModelRepository::setRepository(const std::string& repository) {
repo_.setRepository(repository);
}
// Facade: start watching the repository directory for model additions and
// deletions. If use_polling is true, the watcher polls instead of using
// native filesystem events.
void ModelRepository::enableRepositoryMonitoring(bool use_polling) {
repo_.enableRepositoryMonitoring(use_polling);
}
// Store the repository root path; subsequent modelLoad calls resolve model
// directories relative to this path.
void ModelRepository::ModelRepositoryImpl::setRepository(
const std::string& repository_path) {
repository_ = repository_path;
}
/// Resolve a model's directory in the repository, parse its config.pbtxt,
/// and populate `parameters` with the worker to use and the path to the
/// model file, based on the platform declared in the config.
///
/// @param model name of the model directory under the repository root
/// @param parameters output: receives "worker", "model" and any
///        platform-specific settings plus the config's own parameters
/// @throws file_not_found_error if config.pbtxt cannot be opened
/// @throws file_read_error if config.pbtxt cannot be parsed
/// @throws invalid_argument if the config names an unknown platform
void ModelRepository::ModelRepositoryImpl::modelLoad(
    const std::string& model, RequestParameters* parameters) const {
  const fs::path config_file = "config.pbtxt";

  auto model_path = repository_ / model;
  auto config_path = model_path / config_file;

  // KServe can sometimes create directories like model/model/config_file
  // so if model/config_file doesn't exist, try searching a directory lower too
  if (!fs::exists(config_path) &&
      fs::exists(model_path / model / config_file)) {
    model_path /= model;
    config_path = model_path / config_file;
  }

  // TODO(varunsh): support other versions than 1/
  const std::string model_base = model_path / "1/saved_model";

  inference::Config config;

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg, hicpp-vararg)
  int file_descriptor = open(config_path.c_str(), O_RDONLY | O_CLOEXEC);
  if (file_descriptor < 0) {
    throw file_not_found_error("Config file " + config_path.string() +
                               " could not be opened");
  }

  // The stream adopts the descriptor; SetCloseOnDelete ensures it is closed
  // even when parsing fails and an exception unwinds the stack
  google::protobuf::io::FileInputStream file_input(file_descriptor);
  file_input.SetCloseOnDelete(true);

  if (!google::protobuf::TextFormat::Parse(&file_input, &config)) {
    throw file_read_error("Config file " + config_path.string() +
                          " could not be parsed");
  }

  // Map the declared platform to a worker and the model file's extension
  if (config.platform() == "tensorflow_graphdef") {
    const auto& inputs = config.inputs();
    // currently supporting one input tensor
    for (const auto& input : inputs) {
      parameters->put("input_node", input.name());
      const auto& shape = input.shape();
      // ZenDNN assumes square image in HWC format
      parameters->put("input_size", static_cast<int>(shape.at(0)));
      parameters->put("image_channels",
                      static_cast<int>(shape.at(shape.size() - 1)));
    }
    const auto& outputs = config.outputs();
    // currently supporting one output tensor
    for (const auto& output : outputs) {
      parameters->put("output_node", output.name());
      const auto& shape = output.shape();
      // ZenDNN assumes [X] classes as output
      parameters->put("output_classes", static_cast<int>(shape.at(0)));
    }
    parameters->put("worker", "tfzendnn");
    parameters->put("model", model_base + ".pb");
  } else if (config.platform() == "pytorch_torchscript") {
    parameters->put("worker", "ptzendnn");
    parameters->put("model", model_base + ".pt");
  } else if (config.platform() == "onnx_onnxv1") {
    parameters->put("worker", "migraphx");
    parameters->put("model", model_base + ".onnx");
  } else if (config.platform() == "migraphx_mxr") {
    parameters->put("worker", "migraphx");
    parameters->put("model", model_base + ".mxr");
  } else if (config.platform() == "vitis_xmodel") {
    parameters->put("worker", "xmodel");
    parameters->put("model", model_base + ".xmodel");
  } else {
    throw invalid_argument("Unknown platform: " + config.platform());
  }

  // Forward any free-form parameters from the config as-is
  mapProtoToParameters2(config.parameters(), parameters);
}
/// efsw callback: react to filesystem events in the watched repository.
/// Adding a config.pbtxt loads the corresponding model; deleting one unloads
/// it. All events are also logged at debug level.
///
/// NOTE(review): `old_filename` is taken by value to match the efsw
/// FileWatchListener interface.
void UpdateListener::handleFileAction([[maybe_unused]] efsw::WatchID watchid,
                                      const std::string& dir,
                                      const std::string& filename,
                                      efsw::Action action,
                                      std::string old_filename) {
  Logger logger{Loggers::Server};
  // arbitrary delay to make sure filesystem has settled
  const std::chrono::milliseconds delay{100};
  if (filename == "config.pbtxt") {
    if (action == efsw::Actions::Add) {
      std::this_thread::sleep_for(delay);
      // the model name is the directory containing the config file
      auto model = fs::path(dir).parent_path().filename();
      // TODO(varunsh): replace with native client
      RequestParameters params;
      try {
        // was garbled to "&para;ms" in the listing; the call passes &params
        ModelRepository::modelLoad(model, &params);
        Manager::getInstance().loadWorker(model, params);
      } catch (const runtime_error&) {
        AMDINFER_LOG_INFO(logger, "Error loading " + model.string());
      }
    } else if (action == efsw::Actions::Delete) {
      // arbitrary delay to make sure filesystem has settled
      std::this_thread::sleep_for(delay);
      auto model = fs::path(dir).parent_path().filename();
      // TODO(varunsh): replace with native client
      Manager::getInstance().unloadWorker(model);
    }
  }

  switch (action) {
    case efsw::Actions::Add:
      AMDINFER_LOG_DEBUG(
        logger, "DIR (" + dir + ") FILE (" + filename + ") has event Added");
      break;
    case efsw::Actions::Delete:
      AMDINFER_LOG_DEBUG(
        logger, "DIR (" + dir + ") FILE (" + filename + ") has event Delete");
      break;
    case efsw::Actions::Modified:
      AMDINFER_LOG_DEBUG(
        logger, "DIR (" + dir + ") FILE (" + filename + ") has event Modified");
      break;
    case efsw::Actions::Moved:
      AMDINFER_LOG_DEBUG(logger, "DIR (" + dir + ") FILE (" + filename +
                                   ") has event Moved from (" + old_filename +
                                   ")");
      break;
    default:
      AMDINFER_LOG_ERROR(logger, "Should never happen");
  }
}
void ModelRepository::ModelRepositoryImpl::enableRepositoryMonitoring(
bool use_polling) {
file_watcher_ = std::make_unique<efsw::FileWatcher>(use_polling);
listener_ = std::make_unique<amdinfer::UpdateListener>();
file_watcher_->addWatch(repository_.string(), listener_.get(), true);
file_watcher_->watch();
Logger logger{Loggers::Server};
for (const auto& path : fs::directory_iterator(repository_)) {
if (path.is_directory()) {
auto model = path.path().filename();
try {
RequestParameters params;
amdinfer::modelLoad(model, ¶ms);
} catch (const amdinfer::runtime_error& e) {
AMDINFER_LOG_INFO(logger,
"Error loading " + model.string() + ": " + e.what());
}
}
}
}
} // namespace amdinfer