Program Listing for File batcher.hpp

Return to documentation for file (/workspace/amdinfer/src/amdinfer/batching/batcher.hpp)

// Copyright 2021 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GUARD_AMDINFER_BATCHING_BATCHER
#define GUARD_AMDINFER_BATCHING_BATCHER

#include <chrono>   // for system_clock::time_point
#include <cstddef>  // for size_t
#include <memory>   // for unique_ptr, shared_ptr
#include <string>   // for string
#include <thread>   // for thread
#include <vector>   // for vector

#include "amdinfer/build_options.hpp"        // for AMDINFER_ENABLE_LOGGING
#include "amdinfer/core/predict_api.hpp"     // for RequestParameters
#include "amdinfer/declarations.hpp"         // for BufferPtrs, InferenceReq...
#include "amdinfer/observation/logging.hpp"  // for LoggerPtr
#include "amdinfer/observation/tracing.hpp"  // for TracePtr
#include "amdinfer/util/queue.hpp"           // for BlockingConcurrentQueue

// Forward declarations: this header only refers to these types by pointer or
// reference, so the full definitions are not needed here.
namespace amdinfer {
class Buffer;
class WorkerInfo;
}  // namespace amdinfer

namespace amdinfer {

enum class BatcherStatus { New, Run, Inactive, Dead };

/**
 * @brief A Batch groups multiple inference requests together with the
 * input/output buffers that back them so a worker can process the requests as
 * a single unit. Optional per-request traces and timestamps are carried
 * alongside when tracing/metrics are enabled.
 *
 * A Batch is move-only: copying is deleted (it holds unique buffer ownership
 * via BufferPtr), while moves are defaulted so batches can travel through
 * queues.
 */
class Batch {
 public:
  /// Construct an empty batch tied to a worker (pointer is non-owning)
  explicit Batch(const WorkerInfo* worker);
  Batch(Batch const&) = delete;
  Batch& operator=(const Batch&) = delete;
  Batch(Batch&& other) = default;
  Batch& operator=(Batch&& other) = default;
  ~Batch();

  /// Append one request to the batch
  void addRequest(InferenceRequestPtr request);

  /// Get the request at the given index
  [[nodiscard]] const InferenceRequestPtr& getRequest(size_t index);
  /// Get all requests held by the batch
  [[nodiscard]] const std::vector<InferenceRequestPtr>& getRequests() const;
  /// Get the owned input buffers
  [[nodiscard]] const BufferPtrs& getInputBuffers() const;
  /// Get the owned output buffers
  [[nodiscard]] const BufferPtrs& getOutputBuffers() const;
  /// Get non-owning pointers to the input buffers
  [[nodiscard]] std::vector<Buffer*> getRawInputBuffers() const;
  /// Get non-owning pointers to the output buffers
  [[nodiscard]] std::vector<Buffer*> getRawOutputBuffers() const;

  /// Check if the batch holds no requests
  [[nodiscard]] bool empty() const;
  /// Get the number of requests in the batch
  [[nodiscard]] size_t size() const;
  /// Get the number of input buffers
  [[nodiscard]] size_t getInputSize() const;
  /// Get the number of output buffers
  [[nodiscard]] size_t getOutputSize() const;

#ifdef AMDINFER_ENABLE_TRACING
  /// Attach a trace to the batch (presumably one per request — confirm)
  void addTrace(TracePtr trace);
  /// Get the trace at the given index
  /// ([[nodiscard]] added for consistency with the other getters)
  [[nodiscard]] TracePtr& getTrace(size_t index);
#endif
#ifdef AMDINFER_ENABLE_METRICS
  /// Record a start timestamp (presumably one per request — confirm)
  void addTime(std::chrono::high_resolution_clock::time_point timestamp);
  /// Get the recorded timestamp at the given index
  /// ([[nodiscard]] added for consistency with the other getters)
  [[nodiscard]] std::chrono::high_resolution_clock::time_point getTime(
    size_t index);
#endif

  /// Iterate over the requests in the batch
  [[nodiscard]] auto begin() const { return requests_.begin(); }
  [[nodiscard]] auto end() const { return requests_.end(); }

 private:
  const WorkerInfo* worker_;  ///< non-owning; the worker this batch is for
  std::vector<InferenceRequestPtr> requests_;  ///< requests in this batch
  std::vector<BufferPtr> input_buffers_;       ///< buffers backing inputs
  std::vector<BufferPtr> output_buffers_;      ///< buffers backing outputs
#ifdef AMDINFER_ENABLE_TRACING
  std::vector<TracePtr> traces_;  ///< traces added via addTrace()
#endif
#ifdef AMDINFER_ENABLE_METRICS
  std::vector<std::chrono::high_resolution_clock::time_point>
    start_times_;  ///< timestamps added via addTime()
#endif
};

/// Batches are passed around by unique pointer (single ownership)
using BatchPtr = std::unique_ptr<Batch>;
/// Queue type used to hand completed batches from a batcher to a worker
using BatchPtrQueue = BlockingQueue<BatchPtr>;

/**
 * @brief Base class for batchers. A batcher consumes incoming requests from
 * its input queue, groups them into Batch objects, and pushes the batches
 * onto its output queue; the grouping strategy is supplied by subclasses via
 * doRun(). The batching work runs on the batcher's own thread (thread_).
 *
 * NOTE(review): the copy constructor is user-declared while copy-assignment
 * and both moves are deleted — presumably the copy exists so worker groups
 * can share queues; confirm against batcher.cpp.
 */
class Batcher {
 public:
  Batcher();
  /// Construct with configuration parameters (stored by value in parameters_)
  explicit Batcher(RequestParameters* parameters);
  // explicit Batcher(const std::string& name);
  Batcher(const Batcher& batcher);
  Batcher& operator=(const Batcher&) = delete;
  Batcher(Batcher&& other) = delete;
  Batcher& operator=(Batcher&& other) = delete;
  virtual ~Batcher() = default;

  /// Start the batcher (presumably launches thread_ running run() — confirm)
  void start(WorkerInfo* worker);
  /// Set the maximum number of requests to group into one batch
  void setBatchSize(size_t batch_size);
  /// Set the model name associated with this batcher
  void setName(const std::string& name);
  /// Get the model name associated with this batcher
  [[nodiscard]] std::string getName() const;

  /// Get the queue that incoming requests are pulled from
  /// ([[nodiscard]] added for consistency with the other getters)
  [[nodiscard]] BlockingQueue<InterfacePtr>* getInputQueue();
  /// Get the queue that completed batches are pushed onto
  /// ([[nodiscard]] added for consistency with the other getters)
  [[nodiscard]] BatchPtrQueue* getOutputQueue();

  /// Body of the batching thread; delegates the strategy to doRun()
  void run(WorkerInfo* worker);

  /// Get the batcher's current lifecycle status
  /// ([[nodiscard]] added for consistency with the other getters)
  [[nodiscard]] BatcherStatus getStatus();

  /// Submit one request to the batcher's input queue
  void enqueue(InterfacePtr request);

  /// Signal the batcher to shut down (see batcher.cpp for exact semantics)
  void end();

 protected:
#ifdef AMDINFER_ENABLE_LOGGING
  /// Get the logger owned by this batcher
  [[nodiscard]] const Logger& getLogger() const;
#endif

  size_t batch_size_ = 1;  ///< maximum requests collected into one batch
  std::shared_ptr<BlockingQueue<InterfacePtr>> input_queue_;  ///< requests in
  std::shared_ptr<BatchPtrQueue> output_queue_;  ///< completed batches out
  std::thread thread_;  ///< thread executing run()
  std::string model_;   ///< model name this batcher serves
  RequestParameters parameters_;  ///< configuration parameters (by value)

 private:
  /// Batching strategy; implemented by concrete batcher subclasses
  virtual void doRun(WorkerInfo* worker) = 0;

  // In-class initializer added: the member previously had none, so any
  // constructor path that forgot to set it would leave it indeterminate
  // (UB to read). Constructor member-init lists still override this.
  BatcherStatus status_ = BatcherStatus::New;

#ifdef AMDINFER_ENABLE_LOGGING
  Logger logger_{Loggers::Server};
#endif
};

}  // namespace amdinfer

#endif  // GUARD_AMDINFER_BATCHING_BATCHER