Program Listing for File grpc.cpp

Return to documentation for file (/workspace/amdinfer/src/amdinfer/clients/grpc.cpp)

// Copyright 2022 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "amdinfer/clients/grpc.hpp"

#include <google/protobuf/repeated_ptr_field.h>  // for RepeatedPtrField
#include <grpcpp/grpcpp.h>                       // for Status, ClientContext

#include <future>         // for __forced_unwind, async
#include <memory>         // for unique_ptr, shared_ptr
#include <string>         // for string
#include <unordered_set>  // for unordered_set
#include <vector>         // for vector

#include "amdinfer/clients/grpc_internal.hpp"  // for mapParametersToProto
#include "amdinfer/core/data_types.hpp"        // for DataType
#include "amdinfer/core/exceptions.hpp"        // for bad_status, connecti...
#include "amdinfer/declarations.hpp"           // for InferenceResponseFuture
#include "amdinfer/observation/observer.hpp"   // for Logger, Observer
#include "predict_api.grpc.pb.h"               // for GRPCInferenceService...
#include "predict_api.pb.h"                    // for ModelMetadataRespons...

using grpc::ClientContext;
using grpc::Status;

namespace amdinfer {

// Private implementation of GrpcClient. It owns the gRPC service stub that
// every client call in this file goes through.
class GrpcClient::GrpcClientImpl {
 public:
  // Builds the inference service stub on top of the supplied channel.
  explicit GrpcClientImpl(const std::shared_ptr<::grpc::Channel>& channel)
    : stub_(inference::GRPCInferenceService::NewStub(channel)) {}

  // Returns a non-owning pointer to the stub; this object keeps ownership.
  inference::GRPCInferenceService::Stub* getStub() { return stub_.get(); }

 private:
  std::unique_ptr<inference::GRPCInferenceService::Stub> stub_;
};

// Constructs a client from a server address string. A channel is created for
// that address using insecure channel credentials and then handed to the
// channel-based constructor.
GrpcClient::GrpcClient(const std::string& address)
  : GrpcClient(
      ::grpc::CreateChannel(address, ::grpc::InsecureChannelCredentials())) {}

// Constructs a client that talks over an existing gRPC channel. The channel is
// used to build the private implementation object holding the service stub.
GrpcClient::GrpcClient(const std::shared_ptr<::grpc::Channel>& channel)
  : impl_(std::make_unique<GrpcClientImpl>(channel)) {}

// Defaulted out-of-line in this translation unit, where GrpcClientImpl is a
// complete type. NOTE(review): presumably required because impl_ is a
// std::unique_ptr to a type that is only forward-declared in the header --
// confirm against grpc.hpp.
GrpcClient::~GrpcClient() = default;

// Queries the server for its metadata.
//
// Returns the server's name, version and the set of extensions it reports.
// Throws bad_status (carrying the gRPC error message) if the call fails.
ServerMetadata GrpcClient::serverMetadata() const {
  inference::ServerMetadataRequest request;
  inference::ServerMetadataResponse reply;
  ClientContext context;

  const Status status =
    this->impl_->getStub()->ServerMetadata(&context, request, &reply);
  if (!status.ok()) {
    throw bad_status(status.error_message());
  }

  // Collapse the repeated extension strings into a set
  const auto& reported = reply.extensions();
  std::unordered_set<std::string> extensions(reported.begin(), reported.end());
  return ServerMetadata{reply.name(), reply.version(), extensions};
}

bool GrpcClient::serverLive() const {
  inference::ServerLiveRequest request;
  inference::ServerLiveResponse reply;

  ClientContext context;

  auto* stub = this->impl_->getStub();
  Status status = stub->ServerLive(&context, request, &reply);

  if (status.ok()) {
    return reply.live();
  }
  if (status.error_code() == ::grpc::StatusCode::UNAVAILABLE) {
    return false;
  }
  throw bad_status(status.error_message());
}

bool GrpcClient::serverReady() const {
  inference::ServerReadyRequest request;
  inference::ServerReadyResponse reply;

  ClientContext context;

  auto* stub = this->impl_->getStub();
  Status status = stub->ServerReady(&context, request, &reply);

  if (status.ok()) {
    return reply.ready();
  }
  if (status.error_code() == ::grpc::StatusCode::UNAVAILABLE) {
    throw connection_error(status.error_message());
  }
  throw bad_status(status.error_message());
}

bool GrpcClient::modelReady(const std::string& model) const {
  inference::ModelReadyRequest request;
  inference::ModelReadyResponse reply;

  ClientContext context;

  request.set_name(model);

  auto* stub = this->impl_->getStub();
  Status status = stub->ModelReady(&context, request, &reply);

  if (status.ok()) {
    return reply.ready();
  }
  return false;
}

// Converts a protobuf ModelMetadataResponse into a ModelMetadata object.
//
// The model's name and platform are copied over, then every input tensor in
// the response is registered as an input and every output tensor as an output.
// Each tensor's shape dimensions are narrowed to int with static_cast.
ModelMetadata mapProtoToModelMetadata(
  const inference::ModelMetadataResponse& resp) {
  // Shared by the input and output loops: copies a proto tensor's repeated
  // shape field into a vector<int>.
  const auto to_shape = [](const auto& tensor) {
    std::vector<int> dims;
    dims.reserve(tensor.shape_size());
    for (const auto& dim : tensor.shape()) {
      dims.push_back(static_cast<int>(dim));
    }
    return dims;
  };

  ModelMetadata metadata{resp.name(), resp.platform()};

  for (const auto& input : resp.inputs()) {
    std::vector<int> shape = to_shape(input);
    metadata.addInputTensor(input.name(), DataType(input.datatype().c_str()),
                            shape);
  }

  for (const auto& output : resp.outputs()) {
    std::vector<int> shape = to_shape(output);
    // Outputs must be registered as outputs; adding them as inputs would
    // misreport the model's interface.
    metadata.addOutputTensor(output.name(),
                             DataType(output.datatype().c_str()), shape);
  }

  return metadata;
}

// Fetches the metadata of the named model from the server.
//
// On success the protobuf reply is converted with mapProtoToModelMetadata.
// Throws bad_status (carrying the gRPC error message) if the call fails.
ModelMetadata GrpcClient::modelMetadata(const std::string& model) const {
  inference::ModelMetadataRequest request;
  inference::ModelMetadataResponse reply;
  ClientContext context;

  request.set_name(model);

  const Status status =
    this->impl_->getStub()->ModelMetadata(&context, request, &reply);

  if (!status.ok()) {
    throw bad_status(status.error_message());
  }
  return mapProtoToModelMetadata(reply);
}

std::vector<std::string> GrpcClient::modelList() const {
  inference::ModelListRequest request;
  inference::ModelListResponse reply;

  ClientContext context;

  auto* stub = this->impl_->getStub();
  Status status = stub->ModelList(&context, request, &reply);

  if (status.ok()) {
    auto mods = reply.models();
    std::vector<std::string> models(mods.begin(), mods.end());
    return models;
  }
  throw bad_status(status.error_message());
}

void GrpcClient::modelLoad(const std::string& model,
                           RequestParameters* parameters) const {
  inference::ModelLoadRequest request;
  inference::ModelLoadResponse reply;

  ClientContext context;

  request.set_name(model);
  auto* params = request.mutable_parameters();
  if (parameters != nullptr) {
    mapParametersToProto(parameters->data(), params);
  }

  auto* stub = this->impl_->getStub();
  Status status = stub->ModelLoad(&context, request, &reply);

  if (!status.ok()) {
    throw bad_status(status.error_message());
  }
}

void GrpcClient::modelUnload(const std::string& model) const {
  inference::ModelUnloadRequest request;
  inference::ModelUnloadResponse reply;

  ClientContext context;

  request.set_name(model);

  auto* stub = this->impl_->getStub();
  Status status = stub->ModelUnload(&context, request, &reply);

  if (!status.ok()) {
    throw bad_status(status.error_message());
  }
}

std::string GrpcClient::workerLoad(const std::string& worker,
                                   RequestParameters* parameters) const {
  inference::WorkerLoadRequest request;
  inference::WorkerLoadResponse reply;

  ClientContext context;

  request.set_name(worker);
  auto* params = request.mutable_parameters();
  if (parameters != nullptr) {
    mapParametersToProto(parameters->data(), params);
  }

  auto* stub = this->impl_->getStub();
  Status status = stub->WorkerLoad(&context, request, &reply);

  if (status.ok()) {
    return reply.endpoint();
  }
  throw bad_status(status.error_message());
}

void GrpcClient::workerUnload(const std::string& worker) const {
  inference::WorkerUnloadRequest request;
  inference::WorkerUnloadResponse reply;

  ClientContext context;

  request.set_name(worker);

  auto* stub = this->impl_->getStub();
  Status status = stub->WorkerUnload(&context, request, &reply);

  if (!status.ok()) {
    throw bad_status(status.error_message());
  }
}

// Performs one blocking inference call through the given stub.
//
// The request is converted to its protobuf form, sent to `model`, and the
// protobuf reply is converted back into an InferenceResponse. Throws
// bad_status (carrying the gRPC error message) if the call fails.
InferenceResponse runInference(inference::GRPCInferenceService::Stub* stub,
                               const std::string& model,
                               const InferenceRequest& request) {
  inference::ModelInferRequest proto_request;
  inference::ModelInferResponse proto_reply;

  ClientContext context;

  // The observer is handed to both mapping helpers; its logger is only set
  // when AMDINFER_IF_LOGGING keeps the statement.
  Observer observer;
  AMDINFER_IF_LOGGING(observer.logger = Logger{Loggers::Client});

  proto_request.set_model_name(model);
  mapRequestToProto(request, proto_request, observer);

  const Status status = stub->ModelInfer(&context, proto_request, &proto_reply);

  if (status.ok()) {
    InferenceResponse response;
    mapProtoToResponse(proto_reply, response, observer);
    return response;
  }
  throw bad_status(status.error_message());
}

// Starts an inference request to `model` and returns a future for its result.
//
// The work is handed to std::async without an explicit launch policy, and
// std::async stores its own copies of `model` and `request`. Any exception
// thrown by runInference surfaces when the future's result is retrieved.
// NOTE(review): the stub is passed as a raw, non-owning pointer, so the client
// presumably has to outlive the returned future -- confirm with callers.
InferenceResponseFuture GrpcClient::modelInferAsync(
  const std::string& model, const InferenceRequest& request) const {
  auto* stub = this->impl_->getStub();
  return std::async(runInference, stub, model, request);
}

// Sends an inference request to `model` and blocks until the response
// arrives. Propagates whatever runInference throws (bad_status on a failed
// call).
InferenceResponse GrpcClient::modelInfer(
  const std::string& model, const InferenceRequest& request) const {
  auto* stub = this->impl_->getStub();
  return runInference(stub, model, request);
}

bool GrpcClient::hasHardware(const std::string& name, int num) const {
  inference::HasHardwareRequest grpc_request;
  inference::HasHardwareResponse reply;

  ClientContext context;

  grpc_request.set_name(name);
  grpc_request.set_num(num);

  auto* stub = this->impl_->getStub();
  Status status = stub->HasHardware(&context, grpc_request, &reply);

  if (!status.ok()) {
    throw bad_status(status.error_message());
  }

  return reply.found();
}

}  // namespace amdinfer