Program Listing for File grpc_server.cpp

// Copyright 2022 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "amdinfer/servers/grpc_server.hpp"

#include <google/protobuf/repeated_ptr_field.h>  // for RepeatedPtrField
#include <grpc/support/log.h>                    // for GPR_ASSERT, GPR_UNL...
#include <grpcpp/grpcpp.h>                       // for ServerCompletionQueue

#include <cassert>        // for assert
#include <cstddef>        // for size_t, byte
#include <cstdint>        // for uint64_t, int16_t
#include <cstring>        // for memcpy
#include <exception>      // for exception
#include <memory>         // for unique_ptr, shared_ptr
#include <string>         // for allocator, string
#include <thread>         // for thread, yield
#include <unordered_set>  // for unordered_set
#include <utility>        // for move
#include <vector>         // for vector

#include "amdinfer/buffers/buffer.hpp"         // for Buffer
#include "amdinfer/build_options.hpp"          // for AMDINFER_ENABLE_LOGGING
#include "amdinfer/clients/grpc_internal.hpp"  // for mapProtoToParameters
#include "amdinfer/core/api.hpp"               // for hasHardware, modelI...
#include "amdinfer/core/data_types.hpp"        // for DataType, DataType:...
#include "amdinfer/core/exceptions.hpp"        // for invalid_argument
#include "amdinfer/core/interface.hpp"         // for Interface, Interfac...
#include "amdinfer/core/predict_api_internal.hpp"  // for InferenceRequestInput
#include "amdinfer/declarations.hpp"               // for BufferRawPtrs, Infe...
#include "amdinfer/observation/observer.hpp"
#include "amdinfer/util/string.hpp"  // for toLower
#include "amdinfer/util/traits.hpp"  // IWYU pragma: keep
#include "predict_api.grpc.pb.h"     // for GRPCInferenceServic...
#include "predict_api.pb.h"          // for InferTensorContents

namespace amdinfer {
class CallDataModelInfer;
class CallDataModelMetadata;
class CallDataModelLoad;
class CallDataWorkerLoad;
class CallDataModelReady;
class CallDataModelUnload;
class CallDataWorkerUnload;
class CallDataServerLive;
class CallDataServerMetadata;
class CallDataServerReady;
class CallDataHasHardware;
class CallDataModelList;
}  // namespace amdinfer

// use aliases to prevent clashes between grpc:: and amdinfer::grpc::
using ServerBuilder = grpc::ServerBuilder;
using ServerCompletionQueue = grpc::ServerCompletionQueue;
template <typename T>
using ServerAsyncResponseWriter = grpc::ServerAsyncResponseWriter<T>;
using ServerContext = grpc::ServerContext;
using Server = grpc::Server;
using StatusCode = grpc::StatusCode;

// namespace inference {
// using StreamModelInferRequest = ModelInferRequest;
// using StreamModelInferResponse = ModelInferResponse;
// }

namespace amdinfer {

using AsyncService = inference::GRPCInferenceService::AsyncService;

class CallDataBase {
 public:
  virtual void proceed() = 0;
};

template <typename RequestType, typename ReplyType>
class CallData : public CallDataBase {
 public:
  // Take in the "service" instance (in this case representing an asynchronous
  // server) and the completion queue "cq" used for asynchronous communication
  // with the gRPC runtime.
  CallData(AsyncService* service, ::grpc::ServerCompletionQueue* cq)
    : service_(service), cq_(cq), status_(Create) {}

  virtual ~CallData() = default;

  void proceed() override {
    if (status_ == Create) {
      // Make this instance progress to the Process state.
      status_ = Process;

      waitForRequest();
    } else if (status_ == Process) {
      addNewCallData();

      // queue_->enqueue(this);
      // status_ = Wait;
      handleRequest();
      status_ = Wait;
    } else if (status_ == Wait) {
      std::this_thread::yield();
    } else {
      assert(status_ == Finish);
      // Once in the Finish state, deallocate ourselves (CallData).
      delete this;
    }
  }

  virtual void finish(const ::grpc::Status& status) = 0;

 protected:
  // When we handle a request of this type, we need to tell
  // the completion queue to wait for new requests of the same type.
  virtual void addNewCallData() = 0;

  virtual void waitForRequest() = 0;
  virtual void handleRequest() noexcept = 0;

  // The means of communication with the gRPC runtime for an asynchronous
  // server.
  AsyncService* service_;
  // The producer-consumer queue for asynchronous server notifications.
  ::grpc::ServerCompletionQueue* cq_;
  // Context for the RPC, allowing us to tweak aspects of it such as the use
  // of compression or authentication, as well as to send metadata back to the
  // client.
  ::grpc::ServerContext ctx_;

  // What we get from the client.
  RequestType request_;
  // What we send back to the client.
  ReplyType reply_;

  // Let's implement a tiny state machine with the following states.
  enum CallStatus { Create, Process, Wait, Finish };
  CallStatus status_;  // The current serving state.
};
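
// Each concrete CallData registers its own address as the completion-queue tag
// for its RPC: the constructor calls proceed() once to ask gRPC for the next
// request of its type, and GrpcServer::handleRpcs() calls proceed() again each
// time an event with that tag is dequeued. On the request event, the object
// spawns a replacement CallData and handles the request; finish() then marks it
// Finish and enqueues a final event so the object can delete itself.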

template <typename RequestType, typename ReplyType>
class CallDataUnary : public CallData<RequestType, ReplyType> {
 public:
  // Take in the "service" instance (in this case representing an asynchronous
  // server) and the completion queue "cq" used for asynchronous communication
  // with the gRPC runtime.
  CallDataUnary(AsyncService* service, ::grpc::ServerCompletionQueue* cq)
    : CallData<RequestType, ReplyType>(service, cq), responder_(&this->ctx_) {}

  void finish(const ::grpc::Status& status) override {
    // And we are done! Let the gRPC runtime know we've finished, using the
    // memory address of this instance as the uniquely identifying tag for
    // the event.
    this->status_ = this->Finish;
    responder_.Finish(this->reply_, status, this);
  }

 protected:
  // The means to get back to the client.
  ::grpc::ServerAsyncResponseWriter<ReplyType> responder_;
};

template <typename RequestType, typename ReplyType>
class CallDataServerStream : public CallData<RequestType, ReplyType> {
 public:
  // Take in the "service" instance (in this case representing an asynchronous
  // server) and the completion queue "cq" used for asynchronous communication
  // with the gRPC runtime.
  CallDataServerStream(AsyncService* service, ::grpc::ServerCompletionQueue* cq)
    : CallData<RequestType, ReplyType>(service, cq), responder_(&this->ctx_) {}

  void write(const ReplyType& response) { responder_.Write(response, this); }

  void finish(const ::grpc::Status& status) override {
    // And we are done! Let the gRPC runtime know we've finished, using the
    // memory address of this instance as the uniquely identifying tag for
    // the event.
    this->status_ = this->Finish;
    // ServerAsyncWriter::Finish takes only the status and tag; responses are
    // streamed back through write() rather than a final reply message.
    responder_.Finish(status, this);
  }

 protected:
  // The means to get back to the client.
  ::grpc::ServerAsyncWriter<ReplyType> responder_;
};

// Copies the data of one proto tensor into a server-side Buffer at the given
// offset. Used with switchOverTypes: types whose in-memory representation
// matches the repeated proto field are copied with a single memcpy; the
// remaining (narrower or non-native) types are written element-by-element with
// a cast.
struct WriteData {
  template <typename T, typename Tensor>
  void operator()(Buffer* buffer, Tensor* tensor, size_t offset, size_t size,
                  const Observer& observer) const {
    auto* contents = getTensorContents<T>(tensor);
    if constexpr (util::is_any_v<T, bool, uint32_t, uint64_t, int32_t, int64_t,
                                 float, double, char>) {
      auto* dest = static_cast<std::byte*>(buffer->data(offset));
      std::memcpy(dest, contents, size * sizeof(T));
    } else if constexpr (util::is_any_v<T, uint8_t, uint16_t, int8_t, int16_t,
                                        fp16>) {
      for (size_t i = 0; i < size; i++) {
#ifdef AMDINFER_ENABLE_LOGGING
        if (const auto min_size = size > kNumTraceData ? kNumTraceData : size;
            i < min_size) {
          AMDINFER_LOG_TRACE(observer.logger,
                             "Writing data to buffer: " +
                               std::to_string(static_cast<T>(contents[i])));
        }
#endif
        offset = buffer->write(static_cast<T>(contents[i]), offset);
      }
    } else {
      static_assert(!sizeof(T), "Invalid type to WriteData");
    }
  }
};

// Builds an InferenceRequestInput from one proto input tensor, copying the
// tensor's data into the supplied Buffer at the given offset and pointing the
// resulting input's data at that location.
template <>
class InferenceRequestInputBuilder<
  inference::ModelInferRequest_InferInputTensor> {
 public:
  static InferenceRequestInput build(
    const inference::ModelInferRequest_InferInputTensor& req,
    Buffer* input_buffer, size_t offset) {
    Observer observer;
    AMDINFER_IF_LOGGING(observer.logger = Logger{Loggers::Server});

    AMDINFER_LOG_TRACE(observer.logger,
                       "Creating InferenceRequestInput from proto tensor");

    InferenceRequestInput input;
    input.name_ = req.name();
    input.shape_.reserve(req.shape_size());
    for (const auto& index : req.shape()) {
      input.shape_.push_back(static_cast<size_t>(index));
    }
    input.data_type_ = DataType(req.datatype().c_str());

    input.parameters_ = mapProtoToParameters(req.parameters());

    auto size = input.getSize();
    auto* dest = static_cast<std::byte*>(input_buffer->data(offset));
    AMDINFER_LOG_TRACE(observer.logger, "Writing " + std::to_string(size) +
                                          " elements of type " +
                                          input.data_type_.str() + " to " +
                                          util::addressToString(dest));

    switchOverTypes(WriteData(), input.getDatatype(), input_buffer, &req,
                    offset, size, observer);

    input.data_ = dest;
    return input;
  }
};

using InputBuilder =
  InferenceRequestInputBuilder<inference::ModelInferRequest_InferInputTensor>;

#ifdef AMDINFER_ENABLE_LOGGING
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define CALLDATA_IMPL(endpoint, type)                                         \
  class CallData##endpoint                                                    \
    : public CallData##type<inference::endpoint##Request,                     \
                            inference::endpoint##Response> {                  \
   public:                                                                    \
    CallData##endpoint(AsyncService* service, ServerCompletionQueue* cq)      \
      : CallData##type(service, cq) {                                         \
      proceed();                                                              \
    }                                                                         \
                                                                              \
   private:                                                                   \
    Logger logger_{Loggers::Server};                                          \
                                                                              \
   protected:                                                                 \
    void addNewCallData() override { new CallData##endpoint(service_, cq_); } \
    void waitForRequest() override {                                          \
      service_->Request##endpoint(&ctx_, &request_, &responder_, cq_, cq_,    \
                                  this);                                      \
    }                                                                         \
    void handleRequest() noexcept override
#else
#define CALLDATA_IMPL(endpoint, type)                                         \
  class CallData##endpoint                                                    \
    : public CallData##type<inference::endpoint##Request,                     \
                            inference::endpoint##Response> {                  \
   public:                                                                    \
    CallData##endpoint(AsyncService* service, ServerCompletionQueue* cq)      \
      : CallData##type(service, cq) {                                         \
      proceed();                                                              \
    }                                                                         \
                                                                              \
   protected:                                                                 \
    void addNewCallData() override { new CallData##endpoint(service_, cq_); } \
    void waitForRequest() override {                                          \
      service_->Request##endpoint(&ctx_, &request_, &responder_, cq_, cq_,    \
                                  this);                                      \
    }                                                                         \
    void handleRequest() noexcept override
#endif

// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define CALLDATA_IMPL_END \
  }                       \
  ;  // NOLINT
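
// For reference, a handler such as the ServerLive one defined further below,
// written as CALLDATA_IMPL(ServerLive, Unary) { ... } CALLDATA_IMPL_END,
// expands roughly to the following (the logging build also adds a private
// logger_ member):
//
//   class CallDataServerLive
//     : public CallDataUnary<inference::ServerLiveRequest,
//                            inference::ServerLiveResponse> {
//    public:
//     CallDataServerLive(AsyncService* service, ServerCompletionQueue* cq)
//       : CallDataUnary(service, cq) {
//       proceed();
//     }
//
//    protected:
//     void addNewCallData() override { new CallDataServerLive(service_, cq_); }
//     void waitForRequest() override {
//       service_->RequestServerLive(&ctx_, &request_, &responder_, cq_, cq_,
//                                   this);
//     }
//     void handleRequest() noexcept override {
//       reply_.set_live(true);
//       finish(::grpc::Status::OK);
//     }
//   };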

CALLDATA_IMPL(ModelInfer, Unary);

public:
const inference::ModelInferRequest& getRequest() const {
  return this->request_;
}

inference::ModelInferResponse& getReply() { return this->reply_; }
CALLDATA_IMPL_END
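
// Note that CALLDATA_IMPL(ModelInfer, Unary) above is terminated with a
// semicolon, so handleRequest() is only declared here; its definition appears
// out of line below as CallDataModelInfer::handleRequest(). The extra
// getRequest()/getReply() accessors expose the proto messages to the request
// builder and to grpcUnaryCallback().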

template <>
class InferenceRequestBuilder<CallDataModelInfer*> {
 public:
  static InferenceRequestPtr build(const CallDataModelInfer* req,
                                   const BufferRawPtrs& input_buffers,
                                   std::vector<size_t>& input_offsets,
                                   const BufferRawPtrs& output_buffers,
                                   std::vector<size_t>& output_offsets) {
    Observer observer;
    AMDINFER_IF_LOGGING(observer.logger = Logger{Loggers::Server});

    AMDINFER_LOG_TRACE(observer.logger,
                       "Creating InferenceRequest from proto tensor");

    auto request = std::make_shared<InferenceRequest>();
    const auto& grpc_request = req->getRequest();

    request->id_ = grpc_request.id();

    request->parameters_ = mapProtoToParameters(grpc_request.parameters());

    request->callback_ = nullptr;

    for (const auto& input : grpc_request.inputs()) {
      const auto& buffers = input_buffers;
      auto index = 0;
      for (const auto& buffer : buffers) {
        auto& offset = input_offsets[index];

        request->inputs_.push_back(InputBuilder::build(input, buffer, offset));
        const auto& last_input = request->inputs_.back();
        offset += (last_input.getSize() * last_input.getDatatype().size());
        index++;
      }
    }

    // TODO(varunsh): output_offset is currently ignored! The size of the
    // output needs to come from the worker but we have no such information.
    if (grpc_request.outputs_size() != 0) {
      for (const auto& output : grpc_request.outputs()) {
        // TODO(varunsh): we're ignoring incoming output data
        (void)output;
        const auto& buffers = output_buffers;
        auto index = 0;
        for (const auto& buffer : buffers) {
          auto& offset = output_offsets[index];

          request->outputs_.emplace_back();
          request->outputs_.back().setData(
            static_cast<std::byte*>(buffer->data(offset)));
          index++;
        }
      }
    } else {
      for (const auto& input : grpc_request.inputs()) {
        (void)input;  // suppress unused variable warning
        const auto& buffers = output_buffers;
        for (size_t j = 0; j < buffers.size(); j++) {
          const auto& buffer = buffers[j];
          const auto& offset = output_offsets[j];

          request->outputs_.emplace_back();
          request->outputs_.back().setData(
            static_cast<std::byte*>(buffer->data(offset)));
        }
      }
    }

    return request;
  }
};

using RequestBuilder = InferenceRequestBuilder<CallDataModelInfer*>;

// CALLDATA_IMPL(StreamModelInfer, ServerStream);

//  public:
//   const inference::ModelInferRequest& getRequest() const {
//     return this->request_;
//   }
// CALLDATA_IMPL_END

void grpcUnaryCallback(CallDataModelInfer* calldata,
                       const InferenceResponse& response) {
  if (response.isError()) {
    calldata->finish(::grpc::Status(StatusCode::UNKNOWN, response.getError()));
    return;
  }
  try {
    mapResponseToProto(response, calldata->getReply());
  } catch (const invalid_argument& e) {
    calldata->finish(::grpc::Status(StatusCode::UNKNOWN, e.what()));
    return;
  }

  // #ifdef AMDINFER_ENABLE_TRACING
  //   const auto &context = response.getContext();
  //   propagate(resp.get(), context);
  // #endif
  calldata->finish(::grpc::Status::OK);
}

class GrpcApiUnary : public Interface {
 public:
  explicit GrpcApiUnary(CallDataModelInfer* calldata) : calldata_(calldata) {
    this->type_ = InterfaceType::Grpc;
  }

  std::shared_ptr<InferenceRequest> getRequest(
    const BufferRawPtrs& input_buffers, std::vector<size_t>& input_offsets,
    const BufferRawPtrs& output_buffers,
    std::vector<size_t>& output_offsets) override {
#ifdef AMDINFER_ENABLE_LOGGING
    const auto& logger = this->getLogger();
#endif
    try {
      auto request =
        RequestBuilder::build(this->calldata_, input_buffers, input_offsets,
                              output_buffers, output_offsets);
      // Callback callback =
      //   std::bind(grpcUnaryCallback, this->calldata_, std::placeholders::_1);
      Callback callback =
        [calldata = this->calldata_](const InferenceResponse& response) {
          grpcUnaryCallback(calldata, response);
        };
      request->setCallback(std::move(callback));
      return request;
    } catch (const invalid_argument& e) {
      AMDINFER_LOG_INFO(logger, e.what());
      errorHandler(e);
      return nullptr;
    }
  }

  size_t getInputSize() override {
    return calldata_->getRequest().inputs_size();
  }

  void errorHandler(const std::exception& e) override {
    AMDINFER_LOG_INFO(this->getLogger(), e.what());
    calldata_->finish(::grpc::Status(StatusCode::UNKNOWN, e.what()));
  }

 private:
  CallDataModelInfer* calldata_;
};
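
// Flow of a unary inference call: when a ModelInfer RPC arrives,
// CallDataModelInfer::handleRequest() (defined near the bottom of this file)
// wraps the call data in a GrpcApiUnary and submits it with modelInfer(). The
// server later calls getRequest() with the destination buffers, which builds an
// InferenceRequest and copies the proto tensors into those buffers. When the
// worker produces an InferenceResponse, grpcUnaryCallback() maps it back into
// the proto reply and finishes the RPC.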

CALLDATA_IMPL(ServerLive, Unary) {
  reply_.set_live(true);
  finish(::grpc::Status::OK);
}
CALLDATA_IMPL_END

CALLDATA_IMPL(ServerReady, Unary) {
  reply_.set_ready(true);
  finish(::grpc::Status::OK);
}
CALLDATA_IMPL_END

CALLDATA_IMPL(ModelReady, Unary) {
  const auto& model = request_.name();
  try {
    reply_.set_ready(::amdinfer::modelReady(model));
    finish(::grpc::Status::OK);
  } catch (const invalid_argument& e) {
    reply_.set_ready(false);
    finish(::grpc::Status(StatusCode::NOT_FOUND, e.what()));
  } catch (const std::exception& e) {
    reply_.set_ready(false);
    finish(::grpc::Status(StatusCode::UNKNOWN, e.what()));
  }
}
CALLDATA_IMPL_END

CALLDATA_IMPL(ModelMetadata, Unary) {
  const auto& model = request_.name();
  try {
    auto metadata = ::amdinfer::modelMetadata(model);
    mapModelMetadataToProto(metadata, reply_);
    finish(::grpc::Status::OK);
  } catch (const invalid_argument& e) {
    finish(::grpc::Status(StatusCode::NOT_FOUND, e.what()));
  } catch (const std::exception& e) {
    finish(::grpc::Status(StatusCode::UNKNOWN, e.what()));
  }
}
CALLDATA_IMPL_END

CALLDATA_IMPL(ServerMetadata, Unary) {
  auto metadata = serverMetadata();
  reply_.set_name(metadata.name);
  reply_.set_version(metadata.version);
  for (const auto& extension : metadata.extensions) {
    reply_.add_extensions(extension);
  }
  finish(::grpc::Status::OK);
}
CALLDATA_IMPL_END

CALLDATA_IMPL(ModelList, Unary) {
  auto models = ::amdinfer::modelList();
  for (const auto& model : models) {
    reply_.add_models(model);
  }
  finish(::grpc::Status::OK);
}
CALLDATA_IMPL_END

CALLDATA_IMPL(ModelLoad, Unary) {
  auto parameters = mapProtoToParameters(request_.parameters());

  auto* model = request_.mutable_name();
  util::toLower(model);
  try {
    ::amdinfer::modelLoad(*model, parameters.get());
  } catch (const runtime_error& e) {
    AMDINFER_LOG_ERROR(logger_, e.what());
    finish(::grpc::Status(StatusCode::NOT_FOUND, e.what()));
    return;
  } catch (const std::exception& e) {
    finish(::grpc::Status(StatusCode::UNKNOWN, e.what()));
    return;
  }

  finish(::grpc::Status::OK);
}
CALLDATA_IMPL_END

CALLDATA_IMPL(ModelUnload, Unary) {
  auto* model = request_.mutable_name();
  util::toLower(model);
  ::amdinfer::modelUnload(*model);
  finish(::grpc::Status::OK);
}
CALLDATA_IMPL_END

CALLDATA_IMPL(WorkerLoad, Unary) {
  auto parameters = mapProtoToParameters(request_.parameters());

  auto* model = request_.mutable_name();
  util::toLower(model);

  try {
    auto endpoint = ::amdinfer::workerLoad(*model, parameters.get());
    reply_.set_endpoint(endpoint);
    finish(::grpc::Status::OK);
  } catch (const runtime_error& e) {
    AMDINFER_LOG_ERROR(logger_, e.what());
    finish(::grpc::Status(StatusCode::NOT_FOUND, e.what()));
  } catch (const std::exception& e) {
    AMDINFER_LOG_ERROR(logger_, e.what());
    finish(::grpc::Status(StatusCode::UNKNOWN, e.what()));
  }
}
CALLDATA_IMPL_END

CALLDATA_IMPL(WorkerUnload, Unary) {
  auto* worker = request_.mutable_name();
  util::toLower(worker);
  ::amdinfer::workerUnload(*worker);
  finish(::grpc::Status::OK);
}
CALLDATA_IMPL_END

CALLDATA_IMPL(HasHardware, Unary) {
  auto found = ::amdinfer::hasHardware(request_.name(), request_.num());
  reply_.set_found(found);
  finish(::grpc::Status::OK);
}
CALLDATA_IMPL_END

void CallDataModelInfer::handleRequest() noexcept {
  const auto& model = request_.model_name();
#ifdef AMDINFER_ENABLE_TRACING
  auto trace = startTrace(&(__func__[0]));
  trace->setAttribute("model", model);
  trace->startSpan("request_handler");
#endif

  try {
    auto request = std::make_unique<GrpcApiUnary>(this);
#ifdef AMDINFER_ENABLE_TRACING
    trace->endSpan();
    request->setTrace(std::move(trace));
#endif
    ::amdinfer::modelInfer(model, std::move(request));
  } catch (const invalid_argument& e) {
    AMDINFER_LOG_INFO(logger_, e.what());
    finish(::grpc::Status(StatusCode::NOT_FOUND, e.what()));
  } catch (const std::exception& e) {
    AMDINFER_LOG_ERROR(logger_, e.what());
    finish(::grpc::Status(StatusCode::UNKNOWN, e.what()));
  }
}

class GrpcServer final {
 public:
  static GrpcServer& getInstance() { return create("", -1); }

  static GrpcServer& create(const std::string& address, const int cq_count) {
    static GrpcServer server(address, cq_count);
    return server;
  }
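
  // Note: create() constructs the static instance on its first call and
  // returns that same object on every subsequent call, regardless of the
  // arguments. getInstance() therefore only makes sense after a successful
  // create(); if it were called first, it would construct a server with an
  // empty address.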

  GrpcServer(GrpcServer const&) = delete;
  GrpcServer& operator=(const GrpcServer&) = delete;
  GrpcServer(GrpcServer&& other) = delete;
  GrpcServer& operator=(GrpcServer&& other) = delete;

  ~GrpcServer() {
    server_->Shutdown();
    // Always shutdown the completion queues after the server.
    for (const auto& cq : cq_) {
      cq->Shutdown();
      void* tag = nullptr;
      bool ok = false;
      while (cq->Next(&tag, &ok)) {
        // drain the completion queue to prevent assertion errors in grpc
      }
    }
    for (auto& thread : threads_) {
      if (thread.joinable()) {
        thread.join();
      }
    }
  }

 private:
  GrpcServer(const std::string& address, const int cq_count) {
    ServerBuilder builder;
    builder.SetMaxReceiveMessageSize(kMaxGrpcMessageSize);
    builder.SetMaxSendMessageSize(kMaxGrpcMessageSize);
    // Listen on the given address without any authentication mechanism.
    builder.AddListeningPort(address, ::grpc::InsecureServerCredentials());
    // Register "service_" as the instance through which we'll communicate
    // with clients. In this case it corresponds to an *asynchronous* service.
    builder.RegisterService(&service_);
    // Get hold of the completion queue used for the asynchronous
    // communication with the gRPC runtime.
    for (auto i = 0; i < cq_count; i++) {
      cq_.push_back(builder.AddCompletionQueue());
    }
    // Finally assemble the server.
    server_ = builder.BuildAndStart();

    // Start threads to handle incoming RPCs
    for (auto i = 0; i < cq_count; i++) {
      threads_.emplace_back(&GrpcServer::handleRpcs, this, i);
    }
  }

  // This can be run in multiple threads if needed.
  void handleRpcs(int index) {
    const auto& my_cq = cq_.at(index);

    // Spawn a new CallData instance to serve new clients.
    new CallDataServerLive(&service_, my_cq.get());
    new CallDataServerMetadata(&service_, my_cq.get());
    new CallDataModelMetadata(&service_, my_cq.get());
    new CallDataServerReady(&service_, my_cq.get());
    new CallDataModelList(&service_, my_cq.get());
    new CallDataModelReady(&service_, my_cq.get());
    new CallDataModelLoad(&service_, my_cq.get());
    new CallDataModelUnload(&service_, my_cq.get());
    new CallDataWorkerLoad(&service_, my_cq.get());
    new CallDataWorkerUnload(&service_, my_cq.get());
    new CallDataModelInfer(&service_, my_cq.get());
    new CallDataHasHardware(&service_, my_cq.get());
    // new CallDataStreamModelInfer(&service_, my_cq.get());
    void* tag = nullptr;  // uniquely identifies a request.
    bool ok = false;
    while (true) {
      // the gRPC server is shutting down in this case
      if (my_cq == nullptr) {
        return;
      }

      // Block waiting to read the next event from the completion queue. The
      // event is uniquely identified by its tag, which in this case is the
      // memory address of a CallDataBase instance.
      // The return value of Next should always be checked. This return value
      // tells us whether there is any kind of event or cq_ is shutting down.
      auto event_received = my_cq->Next(&tag, &ok);
      if (GPR_UNLIKELY(!(ok && event_received))) {
        break;
      }
      static_cast<CallDataBase*>(tag)->proceed();
    }
  }

  std::vector<std::unique_ptr<::grpc::ServerCompletionQueue>> cq_;
  inference::GRPCInferenceService::AsyncService service_;
  std::unique_ptr<::grpc::Server> server_;
  std::vector<std::thread> threads_;
};

namespace grpc {

void start(int port) {
  const std::string address = "0.0.0.0:" + std::to_string(port);
  GrpcServer::create(address, 1);
}

void stop() {
  // the GrpcServer's destructor is called automatically
  // auto& foo = GrpcServer::getInstance();
  // foo.~GrpcServer();
}

}  // namespace grpc

}  // namespace amdinfer