Program Listing for File websocket.cpp

Return to documentation for file (/workspace/amdinfer/src/amdinfer/clients/websocket.cpp)

// Copyright 2022 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "amdinfer/clients/websocket.hpp"

#include <concurrentqueue/blockingconcurrentqueue.h>  // for BlockingConcurr...
#include <drogon/HttpRequest.h>                       // for HttpRequest
#include <drogon/HttpResponse.h>                      // for HttpResponsePtr
#include <drogon/HttpTypes.h>                         // for WebSocketMessag...
#include <drogon/WebSocketClient.h>                   // for WebSocketClientPtr
#include <drogon/WebSocketConnection.h>               // for WebSocketConnec...
#include <json/value.h>                               // for Value
#include <json/writer.h>                              // for StreamWriterBui...
#include <trantor/net/EventLoop.h>                    // for EventLoop
#include <trantor/net/EventLoopThread.h>              // for EventLoopThread

#include <atomic>     // for atomic
#include <cassert>    // for assert
#include <chrono>     // for milliseconds
#include <memory>     // for make_shared, make_unique, unique_ptr
#include <stdexcept>  // for runtime_error
#include <string>     // for string
#include <thread>     // for sleep_for

#include "amdinfer/clients/http.hpp"           // for HttpClient
#include "amdinfer/clients/http_internal.hpp"  // for mapRequestToJson

namespace amdinfer {

class WebSocketClient::WebSocketClientImpl {
 public:
  /**
   * @brief Start the event loop thread, then create the websocket client on
   * that loop and an HTTP client used for all non-websocket requests.
   *
   * @param ws_address address of the server's websocket endpoint
   * @param http_address address of the server's HTTP endpoint
   */
  WebSocketClientImpl(const std::string& ws_address,
                      const std::string& http_address) {
    using drogon::WebSocketMessageType;

    loop_.run();
    ws_client_ =
      drogon::WebSocketClient::newWebSocketClient(ws_address, loop_.getLoop());
    http_client_ = std::make_unique<HttpClient>(http_address);

    // This handler is invoked on loop_'s thread. It only touches queue_ and
    // the client object drogon passes in, so it doesn't depend on the
    // ws_client_ member still being set while this object is torn down.
    ws_client_->setMessageHandler(
      [this](const std::string& message,
             const drogon::WebSocketClientPtr& client,
             const drogon::WebSocketMessageType& type) {
        switch (type) {
          case WebSocketMessageType::Text: {
            // the raw, unparsed text is queued; parsing is left to whoever
            // consumes it via recv()
            queue_.enqueue(message);
            break;
          }
          case WebSocketMessageType::Close: {
            client->stop();
            break;
          }
          default: {
            break;
          }
        }
      });

    ws_client_->setConnectionClosedHandler(
      [](const drogon::WebSocketClientPtr& /* client */) {});
  }

  /**
   * @brief Shut down any open connection, stop the websocket client and ask
   * the event loop to quit.
   */
  ~WebSocketClientImpl() {
    if (auto connection = ws_client_->getConnection(); connection != nullptr) {
      connection->shutdown();
    }
    ws_client_->stop();
    loop_.getLoop()->quit();
  }

  WebSocketClientImpl(WebSocketClientImpl const&) = delete;
  WebSocketClientImpl& operator=(const WebSocketClientImpl&) = delete;
  WebSocketClientImpl(WebSocketClientImpl&& other) = delete;
  WebSocketClientImpl& operator=(WebSocketClientImpl&& other) = delete;

  /**
   * @brief Connect to the server's "/models/infer" websocket endpoint if not
   * already connected, blocking until a live connection exists.
   *
   * Returns immediately if a connection exists and is not disconnected.
   *
   * @throws std::runtime_error if drogon reports that the connection attempt
   * did not succeed
   */
  void connect() {
    auto connection = ws_client_->getConnection();
    if (connection != nullptr && !connection->disconnected()) {
      return;
    }

    // Set from the connect callback (on the loop thread) when the attempt
    // fails. Shared ownership keeps it valid even if this function has
    // already returned/thrown by the time the callback runs.
    auto failed = std::make_shared<std::atomic<bool>>(false);

    auto req = drogon::HttpRequest::newHttpRequest();
    req->setMethod(drogon::Get);
    req->setPath("/models/infer");
    ws_client_->connectToServer(
      req, [failed](drogon::ReqResult r, const drogon::HttpResponsePtr& /*resp*/,
                    const drogon::WebSocketClientPtr& wsptr) {
        if (r != drogon::ReqResult::Ok) {
          wsptr->stop();
          failed->store(true);
        }
      });

    // Without checking `failed`, a rejected connection would leave this loop
    // spinning forever.
    while (connection == nullptr || connection->disconnected()) {
      if (failed->load()) {
        throw std::runtime_error("Failed to connect to the websocket server");
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
      connection = ws_client_->getConnection();
    }
  }

  /**
   * @brief Block until a text message has been received and return it.
   *
   * @return std::string the raw message text
   */
  std::string recv() {
    std::string response;
    queue_.wait_dequeue(response);
    return response;
  }

  drogon::WebSocketClient* getWsClient() { return ws_client_.get(); }
  HttpClient* getHttpClient() { return http_client_.get(); }

 private:
  // Declared first so it is destroyed last: the message handler running on
  // loop_'s thread enqueues into it, and that thread is only joined when
  // loop_ is destroyed (trantor's EventLoopThread joins in its destructor).
  moodycamel::BlockingConcurrentQueue<std::string> queue_;
  trantor::EventLoopThread loop_;
  std::unique_ptr<HttpClient> http_client_;
  drogon::WebSocketClientPtr ws_client_;
};

/**
 * @brief Create a client that sends inference requests over the websocket at
 * ws_address and all other requests over HTTP at http_address.
 */
WebSocketClient::WebSocketClient(const std::string& ws_address,
                                 const std::string& http_address)
  : impl_(std::make_unique<WebSocketClientImpl>(ws_address, http_address)) {}

// Defaulted here, where WebSocketClientImpl is a complete type, so the pimpl
// pointer can destroy it.
WebSocketClient::~WebSocketClient() = default;

void WebSocketClient::close() const {
  auto* client = this->impl_->getWsClient();
  if (auto connection = client->getConnection(); connection != nullptr) {
    connection->shutdown();
    // client->stop();
  }
}

/// @brief Get the server's metadata; delegated to the HTTP client.
ServerMetadata WebSocketClient::serverMetadata() const {
  const HttpClient& http = *impl_->getHttpClient();
  return http.serverMetadata();
}

bool WebSocketClient::serverLive() const {
  const auto* client = this->impl_->getHttpClient();
  return client->serverLive();
}

bool WebSocketClient::serverReady() const {
  const auto* client = this->impl_->getHttpClient();
  return client->serverReady();
}

bool WebSocketClient::modelReady(const std::string& model) const {
  const auto* client = this->impl_->getHttpClient();
  return client->modelReady(model);
}

/// @brief Get a model's metadata; delegated to the HTTP client.
/// @param model name of the model to query
ModelMetadata WebSocketClient::modelMetadata(const std::string& model) const {
  const HttpClient& http = *impl_->getHttpClient();
  return http.modelMetadata(model);
}

void WebSocketClient::modelLoad(const std::string& model,
                                RequestParameters* parameters) const {
  const auto* client = this->impl_->getHttpClient();
  client->modelLoad(model, parameters);
}

void WebSocketClient::modelUnload(const std::string& model) const {
  const auto* client = this->impl_->getHttpClient();
  client->modelUnload(model);
}

std::string WebSocketClient::workerLoad(const std::string& worker,
                                        RequestParameters* parameters) const {
  const auto* client = this->impl_->getHttpClient();
  return client->workerLoad(worker, parameters);
}

void WebSocketClient::workerUnload(const std::string& worker) const {
  const auto* client = this->impl_->getHttpClient();
  client->workerUnload(worker);
}

/// @brief Asynchronous inference over HTTP (not the websocket); delegated to
/// the HTTP client.
InferenceResponseFuture WebSocketClient::modelInferAsync(
  const std::string& model, const InferenceRequest& request) const {
  const HttpClient& http = *impl_->getHttpClient();
  return http.modelInferAsync(model, request);
}

/// @brief Synchronous inference over HTTP (not the websocket); delegated to
/// the HTTP client.
InferenceResponse WebSocketClient::modelInfer(
  const std::string& model, const InferenceRequest& request) const {
  const HttpClient& http = *impl_->getHttpClient();
  return http.modelInfer(model, request);
}

std::vector<std::string> WebSocketClient::modelList() const {
  const auto* client = this->impl_->getHttpClient();
  return client->modelList();
}

bool WebSocketClient::hasHardware(const std::string& name, int num) const {
  const auto* client = this->impl_->getHttpClient();
  return client->hasHardware(name, num);
}

/**
 * @brief Serialize an inference request (with the model name added under the
 * "model" key) to compact JSON and send it over the websocket, connecting
 * first if needed. The response is not awaited here; read it with
 * modelRecv().
 *
 * @param model name of the model to run the request on
 * @param request the inference request to send
 * @throws std::runtime_error if no websocket connection is available after
 * connecting
 */
void WebSocketClient::modelInferWs(const std::string& model,
                                   const InferenceRequest& request) const {
  auto json = mapRequestToJson(request);
  json["model"] = model;
  Json::StreamWriterBuilder builder;
  builder["indentation"] = "";  // remove whitespace
  const std::string message = Json::writeString(builder, json);

  // connect() returns immediately when a live connection already exists
  impl_->connect();
  auto connection = impl_->getWsClient()->getConnection();
  // An assert alone compiles out under NDEBUG and would leave a null
  // dereference on the send below.
  if (connection == nullptr) {
    throw std::runtime_error("WebSocket connection is not available");
  }
  connection->send(message);
}

std::string WebSocketClient::modelRecv() const { return impl_->recv(); }

}  // namespace amdinfer