Program Listing for File ctpl.hpp

Return to documentation for file (/workspace/amdinfer/src/amdinfer/util/ctpl.hpp)

// Copyright 2014 Vitaliy Vitsentiy
// Copyright 2021 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GUARD_AMDINFER_UTIL_CTPL
#define GUARD_AMDINFER_UTIL_CTPL

#include <concurrentqueue/concurrentqueue.h>  // for ConcurrentQueue

#include <atomic>              // for atomic
#include <condition_variable>  // for condition_variable
#include <functional>          // for function, _1, bind
#include <future>              // for future, packaged_task
#include <memory>              // for make_shared, shared_ptr
#include <mutex>               // for mutex
#include <thread>              // for thread
#include <utility>             // for forward
#include <vector>              // for vector

// thread pool to run user's functors with signature
//      ret func(int id, other_params)
// where id is the index of the thread that runs the functor
// ret is some return type

namespace amdinfer::util {

class ThreadPool {
 public:
  ThreadPool();
  explicit ThreadPool(int thread_num);
  ThreadPool(int thread_num, int queue_size);

  ThreadPool(const ThreadPool &) = delete;
  ThreadPool(ThreadPool &&) = delete;
  ThreadPool &operator=(const ThreadPool &) = delete;
  ThreadPool &operator=(ThreadPool &&) = delete;

  // the destructor waits for all the functions in the queue to be finished
  ~ThreadPool();

  // get the number of running threads in the pool
  int getSize() const;

  // number of idle threads
  int getIdle() const;
  std::thread &getThread(int i);

  // change the number of threads in the pool
  // should be called from one thread, otherwise be careful to not interleave,
  // also with this->stop() nThreads must be >= 0
  void resize(int thread_num);

  // empty the queue
  void clearQueue();

  // pops a functional wrapper to the original function
  std::function<void(int)> pop();

  // wait for all computing threads to finish and stop all threads
  // may be called asynchronously to not pause the calling thread while waiting
  // if wait == true, all the functions in the queue are run, otherwise the
  // queue is cleared without running the functions
  void stop(bool wait = false);

  template <typename F, typename... Rest>
  auto push(F &&f, Rest &&...rest) -> std::future<decltype(f(0, rest...))> {
    auto pck =
      std::make_shared<std::packaged_task<decltype(f(0, rest...))(int)>>(
        std::bind(std::forward<F>(f), std::placeholders::_1,
                  std::forward<Rest>(rest)...));

    // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
    auto *f_new =
      new std::function<void(int id)>([pck](int id) { (*pck)(id); });
    q_.enqueue(f_new);

    // tidy thinks this may leak f_new. Should refactor to use smart pointers
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    std::unique_lock lock(mutex_);
    cv_.notify_one();

    return pck->get_future();
  }

  // run the user's function that excepts argument int - id of the running
  // thread. returned value is templatized operator returns std::future, where
  // the user can get the result and rethrow the caught exceptions
  template <typename F>
  auto push(F &&f) -> std::future<decltype(f(0))> {
    auto pck = std::make_shared<std::packaged_task<decltype(f(0))(int)>>(
      std::forward<F>(f));

    // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
    auto *f_new =
      new std::function<void(int id)>([pck](int id) { (*pck)(id); });
    q_.enqueue(f_new);

    // tidy thinks this may leak f_new. Should refactor to use smart pointers
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    std::unique_lock lock(mutex_);
    cv_.notify_one();

    return pck->get_future();
  }

 private:
  void setThread(int i);

  std::vector<std::unique_ptr<std::thread>> threads_;
  std::vector<std::shared_ptr<std::atomic<bool>>> flags_;
  mutable moodycamel::ConcurrentQueue<std::function<void(int id)> *> q_;
  std::atomic<bool> done_ = false;
  std::atomic<bool> stop_ = false;
  std::atomic<int> waiting_ = 0;  // how many threads are waiting

  std::mutex mutex_;
  std::condition_variable cv_;
};

}  // namespace amdinfer::util

#endif  // GUARD_AMDINFER_UTIL_CTPL