Program Listing for File data_types.hpp

Return to documentation for file (/workspace/amdinfer/include/amdinfer/core/data_types.hpp)

// Copyright 2021 Xilinx, Inc.
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GUARD_AMDINFER_CORE_DATA_TYPES
#define GUARD_AMDINFER_CORE_DATA_TYPES

#include <cstddef>      // for size_t
#include <cstdint>      // for uint8_t, int16_t, int32_t
#include <iostream>     // for ostream
#include <string>       // for string
#include <string_view>  // for string_view

#include "amdinfer/core/exceptions.hpp"  // for invalid_argument
#include "half/half.hpp"                 // for half

namespace amdinfer {

namespace detail {
// used for hashing strings for switch statements. Assuming lowercase, uppercase
// and numbers (26 + 26 + 10 chars) < 2^7 options per char and string length <
// 9 chars
constexpr uint64_t hash(std::string_view str) {
  const int shift = 7;    // <2^7 options per char
  const int offset = 47;  // 0 in ascii corresponds to 48 so map it to 1

  const auto* const data = str.data();
  const auto size = str.size();

  uint64_t hash = static_cast<int>(*data);
  for (const char* c = data + 1; c < data + size; ++c) {
    hash = (hash << shift) + static_cast<int>(*c) - offset;
  }
  return hash;
}
}  // namespace detail

// this is kept lower-case for visual consistency with other POD types
using fp16 = half_float::half;  // NOLINT(readability-identifier-naming)

class DataType {
 public:
  enum Value : uint8_t {
    Bool,
    BOOL = Bool,
    Uint8,
    UINT8 = Uint8,
    Uint16,
    UINT16 = Uint16,
    Uint32,
    UINT32 = Uint32,
    Uint64,
    UINT64 = Uint64,
    Int8,
    INT8 = Int8,
    Int16,
    INT16 = Int16,
    Int32,
    INT32 = Int32,
    Int64,
    INT64 = Int64,
    Fp16,
    Float16 = Fp16,
    FP16 = Fp16,
    Fp32,
    Float32 = Fp32,
    FP32 = Fp32,
    Fp64,
    Float64 = Fp64,
    FP64 = Fp64,
    String,
    STRING = String,
    Unknown,
    UNKNOWN = Unknown
  };

  constexpr DataType() = default;
  constexpr explicit DataType(const char* value)
    : value_(mapStrToType(value)) {}
  // NOLINTNEXTLINE(google-explicit-constructor, hicpp-explicit-conversions)
  constexpr DataType(DataType::Value value) : value_(value) {}

  // NOLINTNEXTLINE(google-explicit-constructor, hicpp-explicit-conversions)
  constexpr operator Value() const { return value_; }

  friend std::ostream& operator<<(std::ostream& os, const DataType& value);

  [[nodiscard]] constexpr size_t size() const {
    switch (value_) {
      case DataType::Bool:
        return sizeof(bool);
      case DataType::Uint8:
        return sizeof(uint8_t);
      case DataType::Uint16:
        return sizeof(uint16_t);
      case DataType::Uint32:
        return sizeof(uint32_t);
      case DataType::Uint64:
        return sizeof(uint64_t);
      case DataType::Int8:
        return sizeof(int8_t);
      case DataType::Int16:
        return sizeof(int16_t);
      case DataType::Int32:
        return sizeof(int32_t);
      case DataType::Int64:
        return sizeof(int64_t);
      case DataType::Fp16:
        return sizeof(fp16);
      case DataType::Fp32:
        return sizeof(float);
      case DataType::Fp64:
        return sizeof(double);
      case DataType::String:
        return sizeof(std::string);
      default:
        throw invalid_argument("Unknown datatype passed");
    }
  }

  [[nodiscard]] constexpr const char* str() const {
    switch (value_) {
      case DataType::Bool:
        return "BOOL";
      case DataType::Uint8:
        return "UINT8";
      case DataType::Uint16:
        return "UINT16";
      case DataType::Uint32:
        return "UINT32";
      case DataType::Uint64:
        return "UINT64";
      case DataType::Int8:
        return "INT8";
      case DataType::Int16:
        return "INT16";
      case DataType::Int32:
        return "INT32";
      case DataType::Int64:
        return "INT64";
      case DataType::Fp16:
        return "FP16";
      case DataType::Fp32:
        return "FP32";
      case DataType::Fp64:
        return "FP64";
      case DataType::String:
        return "STRING";
      default:
        throw invalid_argument("Unknown datatype passed");
    }
  }

 private:
  constexpr DataType::Value static mapStrToType(const char* value) {
    switch (detail::hash(value)) {
      case detail::hash("BOOL"):
      case detail::hash("Bool"):
        return DataType::Bool;
      case detail::hash("UINT8"):
      case detail::hash("Uint8"):
        return DataType::Uint8;
      case detail::hash("UINT16"):
      case detail::hash("Uint16"):
        return DataType::Uint16;
      case detail::hash("UINT32"):
      case detail::hash("Uint32"):
        return DataType::Uint32;
      case detail::hash("UINT64"):
      case detail::hash("Uint64"):
        return DataType::Uint64;
      case detail::hash("INT8"):
      case detail::hash("Int8"):
        return DataType::Int8;
      case detail::hash("INT16"):
      case detail::hash("Int16"):
        return DataType::Int16;
      case detail::hash("INT32"):
      case detail::hash("Int32"):
        return DataType::Int32;
      case detail::hash("INT64"):
      case detail::hash("Int64"):
        return DataType::Int64;
      case detail::hash("FP16"):
      case detail::hash("Fp16"):
        return DataType::Fp16;
      case detail::hash("FP32"):
      case detail::hash("Fp32"):
        return DataType::Fp32;
      case detail::hash("FP64"):
      case detail::hash("Fp64"):
        return DataType::Fp64;
      case detail::hash("STRING"):
      case detail::hash("String"):
        return DataType::String;
      default:
        throw invalid_argument("Unknown datatype passed");
    }
  }

  Value value_ = Value::Unknown;
};

template <typename F, typename... Args>
auto switchOverTypes(F f, DataType type, [[maybe_unused]] const Args&... args) {
  switch (type) {
    case DataType::Bool: {
      return f.template operator()<bool>(args...);
    }
    case DataType::Uint8: {
      return f.template operator()<uint8_t>(args...);
    }
    case DataType::Uint16: {
      return f.template operator()<uint16_t>(args...);
    }
    case DataType::Uint32: {
      return f.template operator()<uint32_t>(args...);
    }
    case DataType::Uint64: {
      return f.template operator()<uint64_t>(args...);
    }
    case DataType::Int8: {
      return f.template operator()<int8_t>(args...);
    }
    case DataType::Int16: {
      return f.template operator()<int16_t>(args...);
    }
    case DataType::Int32: {
      return f.template operator()<int32_t>(args...);
    }
    case DataType::Int64: {
      return f.template operator()<int64_t>(args...);
    }
    case DataType::Fp16: {
      return f.template operator()<fp16>(args...);
    }
    case DataType::Fp32: {
      return f.template operator()<float>(args...);
    }
    case DataType::Fp64: {
      return f.template operator()<double>(args...);
    }
    case DataType::String: {
      return f.template operator()<char>(args...);
    }
    default:
      throw invalid_argument("Unknown datatype passed");
  }
}

}  // namespace amdinfer
#endif  // GUARD_AMDINFER_CORE_DATA_TYPES