Program Listing for File batcher.hpp
// Copyright 2021 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GUARD_AMDINFER_BATCHING_BATCHER
#define GUARD_AMDINFER_BATCHING_BATCHER
#include <chrono> // for system_clock::time_point
#include <cstddef> // for size_t
#include <memory> // for unique_ptr, shared_ptr
#include <string> // for string
#include <thread> // for thread
#include <vector> // for vector
#include "amdinfer/build_options.hpp" // for AMDINFER_ENABLE_LOGGING
#include "amdinfer/core/predict_api.hpp" // for RequestParameters
#include "amdinfer/declarations.hpp" // for BufferPtrs, InferenceReq...
#include "amdinfer/observation/logging.hpp" // for LoggerPtr
#include "amdinfer/observation/tracing.hpp" // for TracePtr
#include "amdinfer/util/queue.hpp" // for BlockingConcurrentQueue
namespace amdinfer {
class Buffer;
class WorkerInfo;
} // namespace amdinfer
namespace amdinfer {
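// Lifecycle states for a Batcher. The implementation defines the exact
// transitions; the names suggest: New (constructed), Run (batching thread
// active), Inactive (stopped), Dead (terminated).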
enum class BatcherStatus { New, Run, Inactive, Dead };
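// A Batch bundles a group of inference requests with the input and output
// buffers backing them and, when enabled at build time, per-request traces
// and arrival timestamps. Batches are movable but not copyable.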
class Batch {
public:
explicit Batch(const WorkerInfo* worker);
Batch(Batch const&) = delete;
Batch& operator=(const Batch&) = delete;
Batch(Batch&& other) = default;
Batch& operator=(Batch&& other) = default;
~Batch();
void addRequest(InferenceRequestPtr request);
[[nodiscard]] const InferenceRequestPtr& getRequest(size_t index);
[[nodiscard]] const std::vector<InferenceRequestPtr>& getRequests() const;
[[nodiscard]] const BufferPtrs& getInputBuffers() const;
[[nodiscard]] const BufferPtrs& getOutputBuffers() const;
[[nodiscard]] std::vector<Buffer*> getRawInputBuffers() const;
[[nodiscard]] std::vector<Buffer*> getRawOutputBuffers() const;
[[nodiscard]] bool empty() const;
[[nodiscard]] size_t size() const;
[[nodiscard]] size_t getInputSize() const;
[[nodiscard]] size_t getOutputSize() const;
#ifdef AMDINFER_ENABLE_TRACING
void addTrace(TracePtr trace);
TracePtr& getTrace(size_t index);
#endif
#ifdef AMDINFER_ENABLE_METRICS
void addTime(std::chrono::high_resolution_clock::time_point timestamp);
std::chrono::high_resolution_clock::time_point getTime(size_t index);
#endif
[[nodiscard]] auto begin() const { return requests_.begin(); }
[[nodiscard]] auto end() const { return requests_.end(); }
private:
const WorkerInfo* worker_;
std::vector<InferenceRequestPtr> requests_;
std::vector<BufferPtr> input_buffers_;
std::vector<BufferPtr> output_buffers_;
#ifdef AMDINFER_ENABLE_TRACING
std::vector<TracePtr> traces_;
#endif
#ifdef AMDINFER_ENABLE_METRICS
std::vector<std::chrono::high_resolution_clock::time_point> start_times_;
#endif
};
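// Batches are owned through unique_ptr and handed between threads via a
// blocking queue.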
using BatchPtr = std::unique_ptr<Batch>;
using BatchPtrQueue = BlockingQueue<BatchPtr>;
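// Abstract base class for batchers. A Batcher pulls incoming requests off
// its input queue on a dedicated thread, groups them into Batches (the
// grouping policy lives in the pure-virtual doRun()), and pushes completed
// Batches onto its output queue for a worker to consume.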
class Batcher {
public:
Batcher();
explicit Batcher(RequestParameters* parameters);
// explicit Batcher(const std::string& name);
Batcher(const Batcher& batcher);
Batcher& operator=(const Batcher&) = delete;
Batcher(Batcher&& other) = delete;
Batcher& operator=(Batcher&& other) = delete;
virtual ~Batcher() = default;
void start(WorkerInfo* worker);
void setBatchSize(size_t batch_size);
void setName(const std::string& name);
[[nodiscard]] std::string getName() const;
BlockingQueue<InterfacePtr>* getInputQueue();
BatchPtrQueue* getOutputQueue();
void run(WorkerInfo* worker);
BatcherStatus getStatus();
void enqueue(InterfacePtr request);
void end();
protected:
#ifdef AMDINFER_ENABLE_LOGGING
[[nodiscard]] const Logger& getLogger() const;
#endif
size_t batch_size_ = 1;
std::shared_ptr<BlockingQueue<InterfacePtr>> input_queue_;
std::shared_ptr<BatchPtrQueue> output_queue_;
std::thread thread_;
std::string model_;
RequestParameters parameters_;
private:
virtual void doRun(WorkerInfo* worker) = 0;
BatcherStatus status_;
#ifdef AMDINFER_ENABLE_LOGGING
Logger logger_{Loggers::Server};
#endif
};
} // namespace amdinfer
#endif // GUARD_AMDINFER_BATCHING_BATCHER
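For context, a minimal sketch of a concrete batcher built on this interface; PassthroughBatcher and its trivial doRun() are hypothetical and only illustrate overriding the one pure-virtual hook, assuming the header above compiles in the including project:

#include "amdinfer/batching/batcher.hpp"

namespace amdinfer {

// Hypothetical batcher that leaves the batching policy unimplemented.
class PassthroughBatcher : public Batcher {
 private:
  void doRun(WorkerInfo* worker) override {
    // A real policy would pop requests from *getInputQueue(), pack up to
    // batch_size_ of them into a Batch, and push the result onto
    // *getOutputQueue() until signaled to stop.
    (void)worker;
  }
};

}  // namespace amdinfer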