Program Listing for File predict_api.hpp
// Copyright 2021 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GUARD_AMDINFER_CORE_PREDICT_API
#define GUARD_AMDINFER_CORE_PREDICT_API
#include <cstddef> // for size_t, byte
#include <cstdint> // for uint64_t, int32_t
#include <functional> // for function, less
#include <future> // for promise
#include <initializer_list> // for initializer_list
#include <map> // for map, operator==, map<>::...
#include <memory> // for shared_ptr, allocator
#include <sstream> // for operator<<, ostream, bas...
#include <string> // for string, operator<<, char...
#include <string_view> // for string_view
#include <unordered_set> // for unordered_set
#include <utility> // for move
#include <variant> // for operator!=, operator<
#include <vector> // for vector
#include "amdinfer/build_options.hpp" // for AMDINFER_ENABLE_TRACING
#include "amdinfer/core/data_types.hpp" // for DataType, mapTypeToStr
#include "amdinfer/core/mixins.hpp" // for Serializable
#include "amdinfer/declarations.hpp" // for InferenceResponseOutput
namespace amdinfer {
using Parameter = std::variant<bool, int32_t, double, std::string>;
class RequestParameters : public Serializable {
public:
void put(const std::string &key, bool value);
void put(const std::string &key, double value);
void put(const std::string &key, int32_t value);
void put(const std::string &key, const std::string &value);
void put(const std::string &key, const char *value);
template <typename T>
T get(const std::string &key) {
auto &value = this->parameters_.at(key);
return std::get<T>(value);
}
bool has(const std::string &key);
void erase(const std::string &key);
[[nodiscard]] size_t size() const;
[[nodiscard]] bool empty() const;
[[nodiscard]] std::map<std::string, Parameter, std::less<>> data() const;
auto begin() { return parameters_.begin(); }
[[nodiscard]] auto cbegin() const { return parameters_.cbegin(); }
auto end() { return parameters_.end(); }
[[nodiscard]] auto cend() const { return parameters_.cend(); }
[[nodiscard]] size_t serializeSize() const override;
void serialize(std::byte *data_out) const override;
void deserialize(const std::byte *data_in) override;
friend std::ostream &operator<<(std::ostream &os,
RequestParameters const &self) {
std::stringstream ss;
ss << "RequestParameters(" << &self << "):\n";
for (const auto &[key, value] : self.parameters_) {
ss << " " << key << ": ";
std::visit([&](const auto &c) { ss << c; }, value);
ss << "\n";
}
auto tmp = ss.str();
tmp.pop_back(); // delete trailing newline
os << tmp;
return os;
}
private:
std::map<std::string, Parameter, std::less<>> parameters_;
};
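// Usage sketch (illustrative, not part of this header): storing and
// retrieving typed parameters. get<T>() throws std::out_of_range for a
// missing key and std::bad_variant_access if T does not match the stored
// alternative. The keys and values below are example data.
//
//   amdinfer::RequestParameters params;
//   params.put("batch_size", 4);     // stored as int32_t
//   params.put("threshold", 0.5);    // stored as double
//   params.put("mode", "latency");   // const char* overload -> std::string
//   if (params.has("batch_size")) {
//     auto batch = params.get<int32_t>("batch_size");
//   }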
using RequestParametersPtr = std::shared_ptr<RequestParameters>;
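// Basic server identification: the server's name, its version, and the
// set of extensions it supports.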
struct ServerMetadata {
std::string name;
std::string version;
std::unordered_set<std::string> extensions;
};
class InferenceRequestInput : public Serializable {
public:
InferenceRequestInput();
InferenceRequestInput(void *data, std::vector<uint64_t> shape,
DataType data_type, std::string name = "");
void setData(void *buffer);
void setData(std::vector<std::byte> &&buffer);
[[nodiscard]] bool sharedData() const;
[[nodiscard]] void *getData() const;
[[nodiscard]] const std::string &getName() const { return this->name_; }
void setName(std::string name);
[[nodiscard]] const std::vector<uint64_t> &getShape() const {
return this->shape_;
}
void setShape(std::initializer_list<uint64_t> shape) { this->shape_ = shape; }
void setShape(const std::vector<uint64_t> &shape) { this->shape_ = shape; }
void setShape(const std::vector<int32_t> &shape) {
this->shape_.reserve(shape.size());
for (const auto &index : shape) {
this->shape_.push_back(index);
}
}
[[nodiscard]] DataType getDatatype() const { return this->data_type_; }
void setDatatype(DataType type);
[[nodiscard]] RequestParameters *getParameters() const {
return this->parameters_.get();
}
void setParameters(RequestParametersPtr parameters) {
parameters_ = std::move(parameters);
}
[[nodiscard]] size_t getSize() const;
[[nodiscard]] size_t serializeSize() const override;
void serialize(std::byte *data_out) const override;
void deserialize(const std::byte *data_in) override;
friend std::ostream &operator<<(std::ostream &os,
InferenceRequestInput const &my_class) {
os << "InferenceRequestInput:\n";
os << " Name: " << my_class.name_ << "\n";
os << " Shape: ";
for (const auto &index : my_class.shape_) {
os << index << ",";
}
os << "\n";
os << " Datatype: " << my_class.data_type_.str() << "\n";
os << " Parameters:\n";
if (my_class.parameters_ != nullptr) {
os << *(my_class.parameters_.get()) << "\n";
}
os << " Data: " << my_class.getData() << "\n";
return os;
}
private:
std::string name_;
std::vector<uint64_t> shape_;
DataType data_type_;
RequestParametersPtr parameters_;
void *data_;
std::vector<std::byte> shared_data_;
template <typename U>
friend class InferenceRequestInputBuilder;
};
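// Usage sketch (illustrative): wrapping an existing buffer as an input
// tensor. The void* constructor and setData(void*) are non-owning, so the
// caller's buffer must stay alive, while setData(std::vector<std::byte>&&)
// moves ownership into the object; sharedData() reports which mode is in
// use. DataType::FP32 is assumed to be a valid value from data_types.hpp.
//
//   std::vector<float> image(224 * 224 * 3);
//   amdinfer::InferenceRequestInput input(image.data(), {224, 224, 3},
//                                         amdinfer::DataType::FP32,
//                                         "input0");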
class InferenceRequestOutput {
public:
InferenceRequestOutput();
void setData(void *buffer) { this->data_ = buffer; }
void *getData() { return this->data_; }
[[nodiscard]] std::string getName() const { return this->name_; }
void setName(const std::string &name);
void setParameters(RequestParametersPtr parameters) {
parameters_ = std::move(parameters);
}
RequestParameters *getParameters() { return parameters_.get(); }
private:
std::string name_;
RequestParametersPtr parameters_;
void *data_;
template <typename U>
friend class InferenceRequestOutputBuilder;
};
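// An InferenceRequestOutput names a requested output tensor; setData()
// can attach a caller-provided buffer for the result, and parameters can
// carry per-output options.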
class InferenceResponse {
public:
InferenceResponse();
explicit InferenceResponse(const std::string &error);
[[nodiscard]] std::vector<InferenceResponseOutput> getOutputs() const;
void addOutput(const InferenceResponseOutput &output);
std::string getID() const { return id_; }
void setID(const std::string &id);
void setModel(const std::string &model);
std::string getModel();
bool isError() const;
std::string getError() const;
#ifdef AMDINFER_ENABLE_TRACING
void setContext(StringMap &&context);
const StringMap &getContext() const;
#endif
RequestParameters *getParameters() { return this->parameters_.get(); }
friend std::ostream &operator<<(std::ostream &os,
InferenceResponse const &my_class) {
os << "Inference Response:\n";
os << " Model: " << my_class.model_ << "\n";
os << " ID: " << my_class.id_ << "\n";
os << " Parameters:\n";
os << " " << *(my_class.parameters_.get()) << "\n";
os << " Outputs:\n";
for (const auto &output : my_class.outputs_) {
os << " " << output << "\n";
}
os << " Error Message: " << my_class.error_msg_ << "\n";
return os;
}
private:
std::string model_;
std::string id_;
std::shared_ptr<RequestParameters> parameters_;
std::vector<InferenceResponseOutput> outputs_;
std::string error_msg_;
#ifdef AMDINFER_ENABLE_TRACING
StringMap context_;
#endif
};
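// Usage sketch (illustrative): inspecting a response, e.g. inside a
// request callback. Check isError() before reading outputs.
//
//   void handleResponse(const amdinfer::InferenceResponse &resp) {
//     if (resp.isError()) {
//       std::cerr << resp.getError() << "\n";
//       return;
//     }
//     for (const auto &output : resp.getOutputs()) {
//       // each output is an InferenceResponseOutput (see declarations.hpp)
//     }
//   }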
using Callback = std::function<void(const InferenceResponse &)>;
class InferenceRequest {
public:
// Construct a new InferenceRequest object
InferenceRequest() = default;
void setCallback(Callback &&callback);
void runCallback(const InferenceResponse &response);
void runCallbackOnce(const InferenceResponse &response);
void runCallbackError(std::string_view error_msg);
void addInputTensor(void *data, const std::vector<uint64_t> &shape,
DataType data_type, const std::string &name = "");
void addInputTensor(InferenceRequestInput input);
void addOutputTensor(const InferenceRequestOutput &output);
[[nodiscard]] const std::vector<InferenceRequestInput> &getInputs() const;
[[nodiscard]] size_t getInputSize() const;
[[nodiscard]] const std::vector<InferenceRequestOutput> &getOutputs() const;
[[nodiscard]] const std::string &getID() const { return id_; }
void setID(std::string_view id) { id_ = id; }
[[nodiscard]] RequestParameters *getParameters() const {
return this->parameters_.get();
}
void setParameters(RequestParametersPtr parameters) {
parameters_ = std::move(parameters);
}
private:
std::string id_;
RequestParametersPtr parameters_;
std::vector<InferenceRequestInput> inputs_;
std::vector<InferenceRequestOutput> outputs_;
Callback callback_;
// TODO(varunsh): do we need this still?
friend class FakeInferenceRequest;
template <typename U>
friend class InferenceRequestBuilder;
};
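// Usage sketch (illustrative): assembling a request. As with
// InferenceRequestInput, the void* passed to addInputTensor is non-owning,
// so the data buffer must outlive the request. DataType::FP32 is assumed
// to be a valid value from data_types.hpp.
//
//   std::vector<float> data(3 * 224 * 224);
//   amdinfer::InferenceRequest request;
//   request.addInputTensor(data.data(), {3, 224, 224},
//                          amdinfer::DataType::FP32, "input0");
//   amdinfer::InferenceRequestOutput output;
//   output.setName("output0");
//   request.addOutputTensor(output);
//   request.setCallback([](const amdinfer::InferenceResponse &resp) {
//     // invoked via runCallback()/runCallbackOnce() once a response exists
//   });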
using InferenceResponsePromisePtr =
std::shared_ptr<std::promise<InferenceResponse>>;
class ModelMetadataTensor final {
public:
ModelMetadataTensor(const std::string &name, DataType datatype,
std::vector<uint64_t> shape);
[[nodiscard]] const std::string &getName() const;
[[nodiscard]] const DataType &getDataType() const;
[[nodiscard]] const std::vector<uint64_t> &getShape() const;
private:
std::string name_;
DataType datatype_;
std::vector<uint64_t> shape_;
};
class ModelMetadata final {
public:
ModelMetadata(const std::string &name, const std::string &platform);
void addInputTensor(const std::string &name, DataType datatype,
std::initializer_list<uint64_t> shape);
void addInputTensor(const std::string &name, DataType datatype,
std::vector<int> shape);
[[nodiscard]] const std::vector<ModelMetadataTensor> &getInputs() const;
void addOutputTensor(const std::string &name, DataType datatype,
std::initializer_list<uint64_t> shape);
void addOutputTensor(const std::string &name, DataType datatype,
std::vector<int> shape);
[[nodiscard]] const std::vector<ModelMetadataTensor> &getOutputs() const;
void setName(const std::string &name);
[[nodiscard]] const std::string &getName() const;
[[nodiscard]] const std::string &getPlatform() const;
void setReady(bool ready);
[[nodiscard]] bool isReady() const;
private:
std::string name_;
std::vector<std::string> versions_;
std::string platform_;
std::vector<ModelMetadataTensor> inputs_;
std::vector<ModelMetadataTensor> outputs_;
bool ready_;
};
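// Usage sketch (illustrative; the platform string "vitis_xmodel" is an
// example value): declaring a model's input and output tensors for the
// model-metadata endpoint.
//
//   amdinfer::ModelMetadata metadata("resnet50", "vitis_xmodel");
//   metadata.addInputTensor("input0", amdinfer::DataType::FP32,
//                           {1, 224, 224, 3});
//   metadata.addOutputTensor("output0", amdinfer::DataType::FP32,
//                            {1, 1000});
//   metadata.setReady(true);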
} // namespace amdinfer
namespace std {
template <>
struct less<amdinfer::RequestParameters> {
bool operator()(const amdinfer::RequestParameters &lhs,
const amdinfer::RequestParameters &rhs) const {
auto lhs_size = lhs.size();
auto rhs_size = rhs.size();
auto lhs_map = lhs.data();
auto rhs_map = rhs.data();
if (lhs_size == rhs_size) {
for (const auto &[key, lhs_value] : lhs_map) {
if (rhs_map.find(key) == rhs_map.end()) {
return true;
}
const auto &rhs_value = rhs_map.at(key);
if (lhs_value != rhs_value) {
return lhs_value < rhs_value;
}
}
return false;
}
return lhs_size < rhs_size;
}
};
} // namespace std
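// The std::less specialization above orders RequestParameters by size
// first, then by the first differing key/value pair, which lets them be
// used as keys in ordered containers. Illustrative sketch:
//
//   std::map<amdinfer::RequestParameters, int> cache;
//   amdinfer::RequestParameters key;
//   key.put("batch_size", 4);
//   cache[key] = 1;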
#endif // GUARD_AMDINFER_CORE_PREDICT_API