Program Listing for File data_types_internal.cpp
↰ Return to documentation for file (/workspace/amdinfer/src/amdinfer/core/data_types_internal.cpp)
// Copyright 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "amdinfer/core/data_types_internal.hpp"
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <string> // for operator+, to_string
#include "amdinfer/core/exceptions.hpp" // for invalid_argument
#ifdef AMDINFER_ENABLE_VITIS
#include <xir/util/data_type.hpp> // for DataType, DataType::FLOAT, DataTyp...
#endif
#ifdef AMDINFER_ENABLE_VITIS
namespace amdinfer {
const auto kBitsInByte = 8;
DataType mapXirToType(xir::DataType type) {
auto data_type = type.type;
size_t width = type.bit_width / kBitsInByte;
if (data_type == xir::DataType::FLOAT) {
if (width == DataType("FP32").size()) {
return DataType::Fp32;
}
if (width == DataType("FP64").size()) {
return DataType::Fp64;
}
throw invalid_argument("Unsupported XIR float width: " +
std::to_string(width));
}
if (data_type == xir::DataType::INT || data_type == xir::DataType::XINT) {
if (width == DataType("INT8").size()) {
return DataType::Int8;
}
if (width == DataType("INT16").size()) {
return DataType::Int16;
}
if (width == DataType("INT32").size()) {
return DataType::Int32;
}
if (width == DataType("INT64").size()) {
return DataType::Int64;
}
throw invalid_argument("Unsupported XIR int width: " +
std::to_string(width));
}
if (data_type == xir::DataType::UINT || data_type == xir::DataType::XUINT) {
if (width == DataType("UINT8").size()) {
return DataType::Uint8;
}
if (width == DataType("UINT16").size()) {
return DataType::Uint16;
}
if (width == DataType("UINT32").size()) {
return DataType::Uint32;
}
if (width == DataType("UINT64").size()) {
return DataType::Uint64;
}
throw invalid_argument("Unsupported XIR uint width: " +
std::to_string(width));
}
throw invalid_argument("Unsupported XIR type: " + std::to_string(data_type));
}
xir::DataType mapTypeToXir(DataType type) {
xir::DataType retval;
auto bit_width = static_cast<int32_t>(type.size()) * kBitsInByte;
switch (type) {
case DataType::Bool:
case DataType::Uint8:
case DataType::Uint16:
case DataType::Uint32:
case DataType::Uint64:
retval.type = xir::DataType::UINT;
break;
case DataType::Int8:
case DataType::Int16:
case DataType::Int32:
case DataType::Int64:
retval.type = xir::DataType::INT;
break;
// case DataType::Fp16 fall through to default handler
case DataType::Fp32:
case DataType::Fp64:
retval.type = xir::DataType::FLOAT;
break;
default:
throw invalid_argument("Unsupported type conversion to XIR");
}
retval.bit_width = bit_width;
return retval;
}
} // namespace amdinfer
#endif