// Program Listing for File predict_api.hpp
//
// Return to documentation for file (/workspace/amdinfer/include/amdinfer/core/predict_api.hpp)

// Copyright 2021 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GUARD_AMDINFER_CORE_PREDICT_API
#define GUARD_AMDINFER_CORE_PREDICT_API

#include <cstddef>           // for size_t, byte
#include <cstdint>           // for uint64_t, int32_t
#include <functional>        // for function, less
#include <future>            // for promise
#include <initializer_list>  // for initializer_list
#include <map>               // for map, operator==, map<>::...
#include <memory>            // for shared_ptr, allocator
#include <sstream>           // for operator<<, ostream, bas...
#include <string>            // for string, operator<<, char...
#include <string_view>       // for string_view
#include <unordered_set>     // for unordered_set
#include <utility>           // for move
#include <variant>           // for operator!=, operator<
#include <vector>            // for vector

#include "amdinfer/build_options.hpp"    // for AMDINFER_ENABLE_TRACING
#include "amdinfer/core/data_types.hpp"  // for DataType, mapTypeToStr
#include "amdinfer/core/mixins.hpp"      // for Serializable
#include "amdinfer/declarations.hpp"     // for InferenceResponseOutput

namespace amdinfer {

using Parameter = std::variant<bool, int32_t, double, std::string>;

class RequestParameters : public Serializable {
 public:
  void put(const std::string &key, bool value);
  void put(const std::string &key, double value);
  void put(const std::string &key, int32_t value);
  void put(const std::string &key, const std::string &value);
  void put(const std::string &key, const char *value);

  template <typename T>
  T get(const std::string &key) {
    auto &value = this->parameters_.at(key);
    return std::get<T>(value);
  }

  bool has(const std::string &key);
  void erase(const std::string &key);
  [[nodiscard]] size_t size() const;
  [[nodiscard]] bool empty() const;
  [[nodiscard]] std::map<std::string, Parameter, std::less<>> data() const;

  auto begin() { return parameters_.begin(); }
  [[nodiscard]] auto cbegin() const { return parameters_.cbegin(); }

  auto end() { return parameters_.end(); }
  [[nodiscard]] auto cend() const { return parameters_.cend(); }

  [[nodiscard]] size_t serializeSize() const override;
  void serialize(std::byte *data_out) const override;
  void deserialize(const std::byte *data_in) override;

  friend std::ostream &operator<<(std::ostream &os,
                                  RequestParameters const &self) {
    std::stringstream ss;
    ss << "RequestParameters(" << &self << "):\n";
    for (const auto &[key, value] : self.parameters_) {
      ss << "  " << key << ": ";
      std::visit([&](const auto &c) { ss << c; }, value);
      ss << "\n";
    }
    auto tmp = ss.str();
    tmp.pop_back();  // delete trailing newline
    os << tmp;
    return os;
  }

 private:
  std::map<std::string, Parameter, std::less<>> parameters_;
};

using RequestParametersPtr = std::shared_ptr<RequestParameters>;

/// Basic identifying information about the inference server.
struct ServerMetadata {
  /// Name of the server
  std::string name;
  /// Version of the server
  std::string version;
  /// Set of extension names the server reports as supported
  std::unordered_set<std::string> extensions;
};

/// Holds one input tensor of an inference request: a name, a datatype, a
/// shape, optional per-tensor parameters, and a pointer to the data.
class InferenceRequestInput : public Serializable {
 public:
  /// Construct a new InferenceRequestInput object
  InferenceRequestInput();

  /**
   * @brief Construct a new InferenceRequestInput object
   *
   * @param data pointer to the tensor's data
   * @param shape shape of the tensor
   * @param data_type datatype of the tensor
   * @param name name of the tensor
   */
  InferenceRequestInput(void *data, std::vector<uint64_t> shape,
                        DataType data_type, std::string name = "");

  /// Sets the request's data to an external buffer
  void setData(void *buffer);
  /// Sets the request's data by moving in a byte buffer
  void setData(std::vector<std::byte> &&buffer);
  /// NOTE(review): presumably reports whether the data pointer refers to an
  /// external buffer rather than the internal shared_data_ storage — confirm
  /// against the definition in the source file.
  [[nodiscard]] bool sharedData() const;

  /// Gets a pointer to the request's data
  [[nodiscard]] void *getData() const;

  /// Gets the input tensor's name
  [[nodiscard]] const std::string &getName() const { return this->name_; }
  /// Sets the input tensor's name
  void setName(std::string name);

  /// Gets the input tensor's shape
  [[nodiscard]] const std::vector<uint64_t> &getShape() const {
    return this->shape_;
  }
  /// Sets the tensor's shape
  void setShape(std::initializer_list<uint64_t> shape) { this->shape_ = shape; }
  /// Sets the tensor's shape
  void setShape(const std::vector<uint64_t> &shape) { this->shape_ = shape; }
  /// Sets the tensor's shape from 32-bit dimensions
  void setShape(const std::vector<int32_t> &shape) {
    // Fix: the previous implementation only reserved and appended, so calling
    // setShape repeatedly would concatenate the new shape onto the old one
    // instead of replacing it like the other setShape overloads do. assign()
    // replaces the contents. As before, negative dimensions wrap when
    // converted to uint64_t.
    this->shape_.assign(shape.begin(), shape.end());
  }

  /// Gets the input tensor's datatype
  [[nodiscard]] DataType getDatatype() const { return this->data_type_; }
  /// Sets the tensor's datatype
  void setDatatype(DataType type);

  /// Gets a pointer to this tensor's parameters (may be null)
  [[nodiscard]] RequestParameters *getParameters() const {
    return this->parameters_.get();
  }
  /// Sets this tensor's parameters
  void setParameters(RequestParametersPtr parameters) {
    parameters_ = std::move(parameters);
  }

  /// Gets the tensor's size
  /// NOTE(review): element count vs. byte size is not visible from this
  /// header — confirm against the definition before relying on it.
  [[nodiscard]] size_t getSize() const;

  /// Gets the size in bytes needed to serialize this object
  [[nodiscard]] size_t serializeSize() const override;
  /// Serializes this object into the provided buffer
  void serialize(std::byte *data_out) const override;
  /// Deserializes this object from the provided buffer
  void deserialize(const std::byte *data_in) override;

  /// Streams a human-readable summary of the input tensor
  friend std::ostream &operator<<(std::ostream &os,
                                  InferenceRequestInput const &my_class) {
    os << "InferenceRequestInput:\n";
    os << "  Name: " << my_class.name_ << "\n";
    os << "  Shape: ";
    for (const auto &index : my_class.shape_) {
      os << index << ",";
    }
    os << "\n";
    os << "  Datatype: " << my_class.data_type_.str() << "\n";
    os << "  Parameters:\n";
    // parameters are optional; only print them when present
    if (my_class.parameters_ != nullptr) {
      os << *my_class.parameters_ << "\n";
    }
    os << "  Data: " << my_class.getData() << "\n";
    return os;
  }

 private:
  std::string name_;
  std::vector<uint64_t> shape_;
  DataType data_type_;
  RequestParametersPtr parameters_;
  void *data_;
  std::vector<std::byte> shared_data_;

  template <typename U>
  friend class InferenceRequestInputBuilder;
};

/// Holds one requested output tensor of an inference request: a name,
/// optional parameters, and a pointer to a data buffer.
class InferenceRequestOutput {
 public:
  /// Construct a new InferenceRequestOutput object
  InferenceRequestOutput();

  /// Sets the output's data buffer pointer
  void setData(void *buffer) { data_ = buffer; }

  /// Gets the output's data buffer pointer
  void *getData() { return data_; }

  /// Gets a copy of the output tensor's name
  [[nodiscard]] std::string getName() const { return name_; }
  /// Sets the output tensor's name
  void setName(const std::string &name);

  /// Sets this output's parameters
  void setParameters(RequestParametersPtr parameters) {
    parameters_ = std::move(parameters);
  }
  /// Gets a pointer to this output's parameters (may be null)
  RequestParameters *getParameters() { return parameters_.get(); }

 private:
  std::string name_;
  RequestParametersPtr parameters_;
  void *data_;

  template <typename U>
  friend class InferenceRequestOutputBuilder;
};

class InferenceResponse {
 public:
  InferenceResponse();

  explicit InferenceResponse(const std::string &error);

  [[nodiscard]] std::vector<InferenceResponseOutput> getOutputs() const;
  void addOutput(const InferenceResponseOutput &output);

  std::string getID() const { return id_; }
  void setID(const std::string &id);
  void setModel(const std::string &model);
  std::string getModel();

  bool isError() const;
  std::string getError() const;

#ifdef AMDINFER_ENABLE_TRACING

  void setContext(StringMap &&context);
  const StringMap &getContext() const;
#endif

  RequestParameters *getParameters() { return this->parameters_.get(); }

  friend std::ostream &operator<<(std::ostream &os,
                                  InferenceResponse const &my_class) {
    os << "Inference Response:\n";
    os << "  Model: " << my_class.model_ << "\n";
    os << "  ID: " << my_class.id_ << "\n";
    os << "  Parameters:\n";
    os << "    " << *(my_class.parameters_.get()) << "\n";
    os << "  Outputs:\n";
    for (const auto &output : my_class.outputs_) {
      os << "    " << output << "\n";
    }
    os << "  Error Message: " << my_class.error_msg_ << "\n";
    return os;
  }

 private:
  std::string model_;
  std::string id_;
  std::shared_ptr<RequestParameters> parameters_;
  std::vector<InferenceResponseOutput> outputs_;
  std::string error_msg_;
#ifdef AMDINFER_ENABLE_TRACING
  StringMap context_;
#endif
};

using Callback = std::function<void(const InferenceResponse &)>;

/// Represents one request to the inference server: its input tensors, its
/// requested outputs, an ID, optional parameters, and a completion callback.
class InferenceRequest {
 public:
  /// Construct a new InferenceRequest object
  InferenceRequest() = default;

  /// Sets the callback to invoke with the response
  void setCallback(Callback &&callback);
  /// Runs the stored callback with the given response
  void runCallback(const InferenceResponse &response);
  /// Runs the stored callback with the given response; presumably guards
  /// against invoking it more than once — confirm against the definition
  void runCallbackOnce(const InferenceResponse &response);
  /// Runs the callback with an error response built from error_msg
  void runCallbackError(std::string_view error_msg);

  /// Constructs and adds a new input tensor to this request
  void addInputTensor(void *data, const std::vector<uint64_t> &shape,
                      DataType data_type, const std::string &name = "");

  /// Adds an existing input tensor to this request
  void addInputTensor(InferenceRequestInput input);
  /// Adds a requested output tensor to this request
  void addOutputTensor(const InferenceRequestOutput &output);

  /// Gets the request's input tensors
  [[nodiscard]] const std::vector<InferenceRequestInput> &getInputs() const;
  /// Gets the request's input size (presumably the number of input tensors —
  /// confirm against the definition)
  [[nodiscard]] size_t getInputSize() const;

  /// Gets the request's requested output tensors
  [[nodiscard]] const std::vector<InferenceRequestOutput> &getOutputs() const;

  /// Gets the request's ID
  [[nodiscard]] const std::string &getID() const { return id_; }
  /// Sets the request's ID
  void setID(std::string_view id) { id_ = id; }

  /// Gets a pointer to the request's parameters (may be null)
  [[nodiscard]] RequestParameters *getParameters() const {
    return parameters_.get();
  }
  /// Sets the request's parameters
  void setParameters(RequestParametersPtr parameters) {
    parameters_ = std::move(parameters);
  }

 private:
  std::string id_;
  RequestParametersPtr parameters_;
  std::vector<InferenceRequestInput> inputs_;
  std::vector<InferenceRequestOutput> outputs_;
  Callback callback_;

  // TODO(varunsh): do we need this still?
  friend class FakeInferenceRequest;
  template <typename U>
  friend class InferenceRequestBuilder;
};
using InferenceResponsePromisePtr =
  std::shared_ptr<std::promise<InferenceResponse>>;

/// Describes a single input or output tensor in a model's metadata: its
/// name, its datatype, and its shape.
class ModelMetadataTensor final {
 public:
  /**
   * @brief Construct a new ModelMetadataTensor object
   *
   * @param name name of the tensor
   * @param datatype datatype of the tensor
   * @param shape shape of the tensor
   */
  ModelMetadataTensor(const std::string &name, DataType datatype,
                      std::vector<uint64_t> shape);

  /// Gets the name of the tensor
  [[nodiscard]] const std::string &getName() const;
  /// Gets the datatype of the tensor
  [[nodiscard]] const DataType &getDataType() const;
  /// Gets the shape of the tensor
  [[nodiscard]] const std::vector<uint64_t> &getShape() const;

 private:
  std::string name_;
  DataType datatype_;
  std::vector<uint64_t> shape_;
};

/// Describes the metadata associated with a model: its name, the platform it
/// runs on, its input and output tensors, and whether it is ready.
class ModelMetadata final {
 public:
  /**
   * @brief Construct a new ModelMetadata object
   *
   * @param name name of the model
   * @param platform the platform this model runs on
   */
  ModelMetadata(const std::string &name, const std::string &platform);

  /// Adds an input tensor to this model's metadata
  void addInputTensor(const std::string &name, DataType datatype,
                      std::initializer_list<uint64_t> shape);
  /// Adds an input tensor to this model's metadata
  void addInputTensor(const std::string &name, DataType datatype,
                      std::vector<int> shape);

  /// Gets the model's input tensor metadata
  [[nodiscard]] const std::vector<ModelMetadataTensor> &getInputs() const;

  /// Adds an output tensor to this model's metadata
  void addOutputTensor(const std::string &name, DataType datatype,
                       std::initializer_list<uint64_t> shape);
  /// Adds an output tensor to this model's metadata
  void addOutputTensor(const std::string &name, DataType datatype,
                       std::vector<int> shape);

  /// Gets the model's output tensor metadata
  [[nodiscard]] const std::vector<ModelMetadataTensor> &getOutputs() const;

  /// Sets the model's name
  void setName(const std::string &name);
  /// Gets the model's name
  [[nodiscard]] const std::string &getName() const;

  /// Gets the model's platform
  [[nodiscard]] const std::string &getPlatform() const;

  /// Marks this model as ready (true) or not ready (false)
  void setReady(bool ready);
  /// Checks if this model is ready
  [[nodiscard]] bool isReady() const;

 private:
  std::string name_;
  // NOTE(review): versions_ is declared but no accessor is visible in this
  // header — its use is defined elsewhere.
  std::vector<std::string> versions_;
  std::string platform_;
  std::vector<ModelMetadataTensor> inputs_;
  std::vector<ModelMetadataTensor> outputs_;
  bool ready_;
};

}  // namespace amdinfer

namespace std {
/// Template specialization so RequestParameters can be used as a key in
/// ordered containers such as std::map and std::set.
template <>
struct less<amdinfer::RequestParameters> {
  /**
   * @brief Compares two RequestParameters objects: first by number of
   * parameters, then lexicographically by their sorted (key, value) pairs.
   *
   * Fix: the previous implementation returned true whenever lhs held a key
   * absent from rhs, regardless of how the keys themselves compared. For
   * same-sized maps with disjoint keys this made less(a, b) and less(b, a)
   * both true, violating the strict-weak-ordering requirement of the
   * Compare concept and producing undefined behavior in ordered containers.
   * Delegating to std::map's lexicographic operator< yields a valid
   * ordering.
   */
  bool operator()(const amdinfer::RequestParameters &lhs,
                  const amdinfer::RequestParameters &rhs) const {
    const auto lhs_size = lhs.size();
    const auto rhs_size = rhs.size();
    if (lhs_size != rhs_size) {
      return lhs_size < rhs_size;
    }
    // std::map::operator< compares (key, value) pairs lexicographically;
    // std::variant's operator< orders by alternative index, then by the
    // held value.
    return lhs.data() < rhs.data();
  }
};
}  // namespace std

#endif  // GUARD_AMDINFER_CORE_PREDICT_API