Add missing validation for NN canonical types

Bug: 177669661
Test: mma
Test: NeuralNetworksTest_static
Change-Id: Ic05c177f61a906a69bf82ff9c4d5bb8b0556d5ca
This commit is contained in:
Michael Butler 2021-02-03 15:15:43 -08:00
parent 89d463fc3c
commit 08ee3f9287
12 changed files with 176 additions and 264 deletions

View file

@ -22,10 +22,15 @@
#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include <nnapi/hal/HandleError.h>
namespace android::hardware::neuralnetworks::V1_0::utils {
constexpr auto kVersion = nn::Version::ANDROID_OC_MR1;
template <typename Type>
nn::Result<void> validate(const Type& halObject) {
const auto maybeCanonical = nn::convert(halObject);
@ -44,6 +49,15 @@ bool valid(const Type& halObject) {
return result.has_value();
}
template <typename Type>
nn::GeneralResult<void> compliantVersion(const Type& canonical) {
const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(canonical)));
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
return {};
}
template <typename Type>
auto convertFromNonCanonical(const Type& nonCanonicalObject)
-> decltype(convert(nn::convert(nonCanonicalObject).value())) {

View file

@ -35,6 +35,8 @@
#include <utility>
#include <variant>
#include "Utils.h"
namespace {
template <typename Type>
@ -42,8 +44,6 @@ constexpr std::underlying_type_t<Type> underlyingType(Type value) {
return static_cast<std::underlying_type_t<Type>>(value);
}
constexpr auto kVersion = android::nn::Version::ANDROID_OC_MR1;
} // namespace
namespace android::nn {
@ -53,13 +53,13 @@ using hardware::hidl_memory;
using hardware::hidl_vec;
template <typename Input>
using unvalidatedConvertOutput =
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const hidl_vec<Type>& arguments) {
std::vector<unvalidatedConvertOutput<Type>> canonical;
std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
@ -68,16 +68,9 @@ GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
}
template <typename Type>
decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) {
GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
const auto maybeVersion = validate(canonical);
if (!maybeVersion.has_value()) {
return error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
NN_TRY(hal::V1_0::utils::compliantVersion(canonical));
return canonical;
}
@ -248,13 +241,13 @@ namespace android::hardware::neuralnetworks::V1_0::utils {
namespace {
template <typename Input>
using unvalidatedConvertOutput =
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const std::vector<Type>& arguments) {
hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(utils::unvalidatedConvert(arguments[i]));
}
@ -262,15 +255,8 @@ nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
}
template <typename Type>
decltype(utils::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) {
const auto maybeVersion = nn::validate(canonical);
if (!maybeVersion.has_value()) {
return nn::error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
NN_TRY(compliantVersion(canonical));
return utils::unvalidatedConvert(canonical);
}

View file

@ -22,12 +22,16 @@
#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.1/types.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include <nnapi/hal/1.0/Conversions.h>
#include <nnapi/hal/HandleError.h>
namespace android::hardware::neuralnetworks::V1_1::utils {
constexpr auto kDefaultExecutionPreference = ExecutionPreference::FAST_SINGLE_ANSWER;
constexpr auto kVersion = nn::Version::ANDROID_P;
template <typename Type>
nn::Result<void> validate(const Type& halObject) {
@ -47,6 +51,15 @@ bool valid(const Type& halObject) {
return result.has_value();
}
template <typename Type>
nn::GeneralResult<void> compliantVersion(const Type& canonical) {
const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(canonical)));
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
return {};
}
template <typename Type>
auto convertFromNonCanonical(const Type& nonCanonicalObject)
-> decltype(convert(nn::convert(nonCanonicalObject).value())) {

View file

@ -35,11 +35,7 @@
#include <type_traits>
#include <utility>
namespace {
constexpr auto kVersion = android::nn::Version::ANDROID_P;
} // namespace
#include "Utils.h"
namespace android::nn {
namespace {
@ -47,13 +43,13 @@ namespace {
using hardware::hidl_vec;
template <typename Input>
using unvalidatedConvertOutput =
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const hidl_vec<Type>& arguments) {
std::vector<unvalidatedConvertOutput<Type>> canonical;
std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
@ -62,16 +58,9 @@ GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
}
template <typename Type>
decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) {
GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
const auto maybeVersion = validate(canonical);
if (!maybeVersion.has_value()) {
return error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
NN_TRY(hal::V1_1::utils::compliantVersion(canonical));
return canonical;
}
@ -180,13 +169,13 @@ nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory
}
template <typename Input>
using unvalidatedConvertOutput =
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const std::vector<Type>& arguments) {
hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
}
@ -194,16 +183,9 @@ nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
}
template <typename Type>
decltype(utils::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) {
const auto maybeVersion = nn::validate(canonical);
if (!maybeVersion.has_value()) {
return nn::error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
return utils::unvalidatedConvert(canonical);
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
NN_TRY(compliantVersion(canonical));
return unvalidatedConvert(canonical);
}
} // anonymous namespace

View file

@ -22,19 +22,25 @@
#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include <nnapi/hal/1.0/Conversions.h>
#include <nnapi/hal/1.1/Conversions.h>
#include <nnapi/hal/1.1/Utils.h>
#include <nnapi/hal/HandleError.h>
#include <limits>
namespace android::hardware::neuralnetworks::V1_2::utils {
using CacheToken = hidl_array<uint8_t, static_cast<size_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
using V1_1::utils::kDefaultExecutionPreference;
constexpr auto kDefaultMesaureTiming = MeasureTiming::NO;
constexpr auto kNoTiming = Timing{.timeOnDevice = std::numeric_limits<uint64_t>::max(),
.timeInDriver = std::numeric_limits<uint64_t>::max()};
constexpr auto kVersion = nn::Version::ANDROID_Q;
template <typename Type>
nn::Result<void> validate(const Type& halObject) {
@ -54,6 +60,15 @@ bool valid(const Type& halObject) {
return result.has_value();
}
template <typename Type>
nn::GeneralResult<void> compliantVersion(const Type& canonical) {
const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(canonical)));
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
return {};
}
template <typename Type>
auto convertFromNonCanonical(const Type& nonCanonicalObject)
-> decltype(convert(nn::convert(nonCanonicalObject).value())) {

View file

@ -37,6 +37,8 @@
#include <type_traits>
#include <utility>
#include "Utils.h"
namespace {
template <typename Type>
@ -45,50 +47,23 @@ constexpr std::underlying_type_t<Type> underlyingType(Type value) {
}
using HalDuration = std::chrono::duration<uint64_t, std::micro>;
constexpr auto kVersion = android::nn::Version::ANDROID_Q;
constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
} // namespace
namespace android::nn {
namespace {
constexpr bool validOperandType(OperandType operandType) {
switch (operandType) {
case OperandType::FLOAT32:
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::TENSOR_FLOAT32:
case OperandType::TENSOR_INT32:
case OperandType::TENSOR_QUANT8_ASYMM:
case OperandType::BOOL:
case OperandType::TENSOR_QUANT16_SYMM:
case OperandType::TENSOR_FLOAT16:
case OperandType::TENSOR_BOOL8:
case OperandType::FLOAT16:
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case OperandType::TENSOR_QUANT16_ASYMM:
case OperandType::TENSOR_QUANT8_SYMM:
case OperandType::OEM:
case OperandType::TENSOR_OEM_BYTE:
return true;
default:
break;
}
return isExtension(operandType);
}
using hardware::hidl_handle;
using hardware::hidl_vec;
template <typename Input>
using unvalidatedConvertOutput =
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const hidl_vec<Type>& arguments) {
std::vector<unvalidatedConvertOutput<Type>> canonical;
std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
@ -97,29 +72,16 @@ GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec
}
template <typename Type>
GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
const hidl_vec<Type>& arguments) {
return unvalidatedConvertVec(arguments);
}
template <typename Type>
decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) {
GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
const auto maybeVersion = validate(canonical);
if (!maybeVersion.has_value()) {
return error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
NN_TRY(hal::V1_2::utils::compliantVersion(canonical));
return canonical;
}
template <typename Type>
GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> validatedConvert(
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
const hidl_vec<Type>& arguments) {
std::vector<unvalidatedConvertOutput<Type>> canonical;
std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(validatedConvert(argument)));
@ -145,8 +107,7 @@ GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_2::Capabilities& ca
const bool validOperandTypes = std::all_of(
capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
[](const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
const auto maybeType = unvalidatedConvert(operandPerformance.type);
return !maybeType.has_value() ? false : validOperandType(maybeType.value());
return validatedConvert(operandPerformance.type).has_value();
});
if (!validOperandTypes) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
@ -275,6 +236,7 @@ GeneralResult<MeasureTiming> unvalidatedConvert(const hal::V1_2::MeasureTiming&
GeneralResult<Timing> unvalidatedConvert(const hal::V1_2::Timing& timing) {
constexpr uint64_t kMaxTiming = std::chrono::floor<HalDuration>(Duration::max()).count();
constexpr auto convertTiming = [](uint64_t halTiming) -> OptionalDuration {
constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
if (halTiming == kNoTiming) {
return {};
}
@ -378,25 +340,19 @@ nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory
}
template <typename Input>
using unvalidatedConvertOutput =
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const std::vector<Type>& arguments) {
hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
}
return halObject;
}
template <typename Type>
nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
const std::vector<Type>& arguments) {
return unvalidatedConvertVec(arguments);
}
nn::GeneralResult<Operand::ExtraParams> makeExtraParams(nn::Operand::NoParams /*noParams*/) {
return Operand::ExtraParams{};
}
@ -416,22 +372,15 @@ nn::GeneralResult<Operand::ExtraParams> makeExtraParams(
}
template <typename Type>
decltype(utils::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) {
const auto maybeVersion = nn::validate(canonical);
if (!maybeVersion.has_value()) {
return nn::error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
return utils::unvalidatedConvert(canonical);
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
NN_TRY(compliantVersion(canonical));
return unvalidatedConvert(canonical);
}
template <typename Type>
nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> validatedConvert(
nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
const std::vector<Type>& arguments) {
hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(validatedConvert(arguments[i]));
}
@ -469,7 +418,7 @@ nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capab
capabilities.operandPerformance.asVector().end(),
std::back_inserter(operandPerformance),
[](const nn::Capabilities::OperandPerformance& operandPerformance) {
return nn::validOperandType(operandPerformance.type);
return compliantVersion(operandPerformance.type).has_value();
});
return Capabilities{
@ -570,6 +519,7 @@ nn::GeneralResult<MeasureTiming> unvalidatedConvert(const nn::MeasureTiming& mea
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
constexpr auto convertTiming = [](nn::OptionalDuration canonicalTiming) -> uint64_t {
constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
if (!canonicalTiming.has_value()) {
return kNoTiming;
}

View file

@ -22,14 +22,25 @@
#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.3/types.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include <nnapi/hal/1.0/Conversions.h>
#include <nnapi/hal/1.1/Conversions.h>
#include <nnapi/hal/1.1/Utils.h>
#include <nnapi/hal/1.2/Conversions.h>
#include <nnapi/hal/1.2/Utils.h>
#include <nnapi/hal/HandleError.h>
namespace android::hardware::neuralnetworks::V1_3::utils {
using V1_1::utils::kDefaultExecutionPreference;
using V1_2::utils::CacheToken;
using V1_2::utils::kDefaultMesaureTiming;
using V1_2::utils::kNoTiming;
constexpr auto kDefaultPriority = Priority::MEDIUM;
constexpr auto kVersion = nn::Version::ANDROID_R;
template <typename Type>
nn::Result<void> validate(const Type& halObject) {
@ -49,6 +60,15 @@ bool valid(const Type& halObject) {
return result.has_value();
}
template <typename Type>
nn::GeneralResult<void> compliantVersion(const Type& canonical) {
const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(canonical)));
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
return {};
}
template <typename Type>
auto convertFromNonCanonical(const Type& nonCanonicalObject)
-> decltype(convert(nn::convert(nonCanonicalObject).value())) {

View file

@ -38,6 +38,8 @@
#include <type_traits>
#include <utility>
#include "Utils.h"
namespace {
template <typename Type>
@ -45,48 +47,21 @@ constexpr std::underlying_type_t<Type> underlyingType(Type value) {
return static_cast<std::underlying_type_t<Type>>(value);
}
constexpr auto kVersion = android::nn::Version::ANDROID_R;
} // namespace
namespace android::nn {
namespace {
constexpr auto validOperandType(nn::OperandType operandType) {
switch (operandType) {
case nn::OperandType::FLOAT32:
case nn::OperandType::INT32:
case nn::OperandType::UINT32:
case nn::OperandType::TENSOR_FLOAT32:
case nn::OperandType::TENSOR_INT32:
case nn::OperandType::TENSOR_QUANT8_ASYMM:
case nn::OperandType::BOOL:
case nn::OperandType::TENSOR_QUANT16_SYMM:
case nn::OperandType::TENSOR_FLOAT16:
case nn::OperandType::TENSOR_BOOL8:
case nn::OperandType::FLOAT16:
case nn::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case nn::OperandType::TENSOR_QUANT16_ASYMM:
case nn::OperandType::TENSOR_QUANT8_SYMM:
case nn::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
case nn::OperandType::SUBGRAPH:
case nn::OperandType::OEM:
case nn::OperandType::TENSOR_OEM_BYTE:
return true;
}
return nn::isExtension(operandType);
}
using hardware::hidl_vec;
template <typename Input>
using unvalidatedConvertOutput =
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const hidl_vec<Type>& arguments) {
std::vector<unvalidatedConvertOutput<Type>> canonical;
std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
@ -95,29 +70,16 @@ GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec
}
template <typename Type>
GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
const hidl_vec<Type>& arguments) {
return unvalidatedConvertVec(arguments);
}
template <typename Type>
decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) {
GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
const auto maybeVersion = validate(canonical);
if (!maybeVersion.has_value()) {
return error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
NN_TRY(hal::V1_3::utils::compliantVersion(canonical));
return canonical;
}
template <typename Type>
GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> validatedConvert(
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
const hidl_vec<Type>& arguments) {
std::vector<unvalidatedConvertOutput<Type>> canonical;
std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(validatedConvert(argument)));
@ -143,8 +105,7 @@ GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_3::Capabilities& ca
const bool validOperandTypes = std::all_of(
capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
[](const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
const auto maybeType = unvalidatedConvert(operandPerformance.type);
return !maybeType.has_value() ? false : validOperandType(maybeType.value());
return validatedConvert(operandPerformance.type).has_value();
});
if (!validOperandTypes) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
@ -401,25 +362,19 @@ nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> unvalidatedConvert(
}
template <typename Input>
using unvalidatedConvertOutput =
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const std::vector<Type>& arguments) {
hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
}
return halObject;
}
template <typename Type>
nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
const std::vector<Type>& arguments) {
return unvalidatedConvertVec(arguments);
}
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedMemory& memory) {
Request::MemoryPool ret;
ret.hidlMemory(NN_TRY(unvalidatedConvert(memory)));
@ -439,22 +394,15 @@ nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedBuffer& /*
using utils::unvalidatedConvert;
template <typename Type>
decltype(unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) {
const auto maybeVersion = nn::validate(canonical);
if (!maybeVersion.has_value()) {
return nn::error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
NN_TRY(compliantVersion(canonical));
return unvalidatedConvert(canonical);
}
template <typename Type>
nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> validatedConvert(
nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
const std::vector<Type>& arguments) {
hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(validatedConvert(arguments[i]));
}
@ -482,7 +430,7 @@ nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capab
capabilities.operandPerformance.asVector().end(),
std::back_inserter(operandPerformance),
[](const nn::Capabilities::OperandPerformance& operandPerformance) {
return nn::validOperandType(operandPerformance.type);
return compliantVersion(operandPerformance.type).has_value();
});
return Capabilities{

View file

@ -99,6 +99,9 @@ GeneralResult<SharedHandle> unvalidatedConvert(
const ::aidl::android::hardware::common::NativeHandle& handle);
GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence);
GeneralResult<std::vector<Operation>> unvalidatedConvert(
const std::vector<aidl_hal::Operation>& operations);
GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities);
GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType);
GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus);
@ -106,16 +109,13 @@ GeneralResult<ExecutionPreference> convert(
const aidl_hal::ExecutionPreference& executionPreference);
GeneralResult<SharedMemory> convert(const aidl_hal::Memory& memory);
GeneralResult<Model> convert(const aidl_hal::Model& model);
GeneralResult<Operand> convert(const aidl_hal::Operand& operand);
GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType);
GeneralResult<Priority> convert(const aidl_hal::Priority& priority);
GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& memoryPool);
GeneralResult<Request> convert(const aidl_hal::Request& request);
GeneralResult<Timing> convert(const aidl_hal::Timing& timing);
GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence);
GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension);
GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& outputShapes);
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories);
GeneralResult<std::vector<OutputShape>> convert(
const std::vector<aidl_hal::OutputShape>& outputShapes);

View file

@ -21,6 +21,7 @@
#include <android-base/logging.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include <nnapi/hal/HandleError.h>
@ -48,6 +49,22 @@ bool valid(const Type& halObject) {
return result.has_value();
}
template <typename Type>
nn::GeneralResult<void> compliantVersion(const Type& canonical) {
const auto version = NN_TRY(::android::hardware::neuralnetworks::utils::makeGeneralFailure(
nn::validate(canonical)));
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
return {};
}
template <typename Type>
auto convertFromNonCanonical(const Type& nonCanonicalObject)
-> decltype(convert(nn::convert(nonCanonicalObject).value())) {
return convert(NN_TRY(nn::convert(nonCanonicalObject)));
}
nn::GeneralResult<Memory> clone(const Memory& memory);
nn::GeneralResult<Request> clone(const Request& request);
nn::GeneralResult<RequestMemoryPool> clone(const RequestMemoryPool& requestPool);

View file

@ -41,6 +41,8 @@
#include <type_traits>
#include <utility>
#include "Utils.h"
#define VERIFY_NON_NEGATIVE(value) \
while (UNLIKELY(value < 0)) return NN_ERROR()
@ -53,7 +55,6 @@ constexpr std::underlying_type_t<Type> underlyingType(Type value) {
return static_cast<std::underlying_type_t<Type>>(value);
}
constexpr auto kVersion = android::nn::Version::ANDROID_S;
constexpr int64_t kNoTiming = -1;
} // namespace
@ -63,32 +64,6 @@ namespace {
using ::aidl::android::hardware::common::NativeHandle;
constexpr auto validOperandType(nn::OperandType operandType) {
switch (operandType) {
case nn::OperandType::FLOAT32:
case nn::OperandType::INT32:
case nn::OperandType::UINT32:
case nn::OperandType::TENSOR_FLOAT32:
case nn::OperandType::TENSOR_INT32:
case nn::OperandType::TENSOR_QUANT8_ASYMM:
case nn::OperandType::BOOL:
case nn::OperandType::TENSOR_QUANT16_SYMM:
case nn::OperandType::TENSOR_FLOAT16:
case nn::OperandType::TENSOR_BOOL8:
case nn::OperandType::FLOAT16:
case nn::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case nn::OperandType::TENSOR_QUANT16_ASYMM:
case nn::OperandType::TENSOR_QUANT8_SYMM:
case nn::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
case nn::OperandType::SUBGRAPH:
return true;
case nn::OperandType::OEM:
case nn::OperandType::TENSOR_OEM_BYTE:
return false;
}
return nn::isExtension(operandType);
}
template <typename Input>
using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
@ -113,14 +88,7 @@ GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
template <typename Type>
GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
const auto maybeVersion = validate(canonical);
if (!maybeVersion.has_value()) {
return error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
NN_TRY(aidl_hal::utils::compliantVersion(canonical));
return canonical;
}
@ -185,13 +153,21 @@ static GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const Native
GeneralResult<OperandType> unvalidatedConvert(const aidl_hal::OperandType& operandType) {
VERIFY_NON_NEGATIVE(underlyingType(operandType)) << "Negative operand types are not allowed.";
return static_cast<OperandType>(operandType);
const auto canonical = static_cast<OperandType>(operandType);
if (canonical == OperandType::OEM || canonical == OperandType::TENSOR_OEM_BYTE) {
return NN_ERROR() << "Unable to convert invalid OperandType " << canonical;
}
return canonical;
}
GeneralResult<OperationType> unvalidatedConvert(const aidl_hal::OperationType& operationType) {
VERIFY_NON_NEGATIVE(underlyingType(operationType))
<< "Negative operation types are not allowed.";
return static_cast<OperationType>(operationType);
const auto canonical = static_cast<OperationType>(operationType);
if (canonical == OperationType::OEM_OPERATION) {
return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION";
}
return canonical;
}
GeneralResult<DeviceType> unvalidatedConvert(const aidl_hal::DeviceType& deviceType) {
@ -206,8 +182,7 @@ GeneralResult<Capabilities> unvalidatedConvert(const aidl_hal::Capabilities& cap
const bool validOperandTypes = std::all_of(
capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
[](const aidl_hal::OperandPerformance& operandPerformance) {
const auto maybeType = unvalidatedConvert(operandPerformance.type);
return !maybeType.has_value() ? false : validOperandType(maybeType.value());
return validatedConvert(operandPerformance.type).has_value();
});
if (!validOperandTypes) {
return NN_ERROR() << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
@ -534,6 +509,11 @@ GeneralResult<SharedHandle> unvalidatedConvert(const NativeHandle& aidlNativeHan
return std::make_shared<const Handle>(NN_TRY(unvalidatedConvertHelper(aidlNativeHandle)));
}
GeneralResult<std::vector<Operation>> unvalidatedConvert(
const std::vector<aidl_hal::Operation>& operations) {
return unvalidatedConvertVec(operations);
}
GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence) {
auto duplicatedFd = NN_TRY(dupFd(syncFence.get()));
return SyncFence::create(std::move(duplicatedFd));
@ -564,22 +544,14 @@ GeneralResult<Model> convert(const aidl_hal::Model& model) {
return validatedConvert(model);
}
GeneralResult<Operand> convert(const aidl_hal::Operand& operand) {
return unvalidatedConvert(operand);
}
GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType) {
return unvalidatedConvert(operandType);
return validatedConvert(operandType);
}
GeneralResult<Priority> convert(const aidl_hal::Priority& priority) {
return validatedConvert(priority);
}
GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& memoryPool) {
return unvalidatedConvert(memoryPool);
}
GeneralResult<Request> convert(const aidl_hal::Request& request) {
return validatedConvert(request);
}
@ -589,17 +561,13 @@ GeneralResult<Timing> convert(const aidl_hal::Timing& timing) {
}
GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence) {
return unvalidatedConvert(syncFence);
return validatedConvert(syncFence);
}
GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension) {
return validatedConvert(extension);
}
GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& operations) {
return unvalidatedConvert(operations);
}
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
return validatedConvert(memories);
}
@ -644,14 +612,7 @@ nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConver
template <typename Type>
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
const auto maybeVersion = nn::validate(canonical);
if (!maybeVersion.has_value()) {
return nn::error() << maybeVersion.error();
}
const auto version = maybeVersion.value();
if (version > kVersion) {
return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
NN_TRY(compliantVersion(canonical));
return utils::unvalidatedConvert(canonical);
}
@ -797,6 +758,9 @@ nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
}
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
if (operandType == nn::OperandType::OEM || operandType == nn::OperandType::TENSOR_OEM_BYTE) {
return NN_ERROR() << "Unable to convert invalid OperandType " << operandType;
}
return static_cast<OperandType>(operandType);
}
@ -864,6 +828,9 @@ nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
}
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
if (operationType == nn::OperationType::OEM_OPERATION) {
return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION";
}
return static_cast<OperationType>(operationType);
}
@ -1004,7 +971,7 @@ nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache(
}
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
return unvalidatedConvert(cacheToken);
return validatedConvert(cacheToken);
}
nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
@ -1076,7 +1043,7 @@ nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
const std::vector<nn::SyncFence>& syncFences) {
return unvalidatedConvert(syncFences);
return validatedConvert(syncFences);
}
nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) {

View file

@ -1312,7 +1312,7 @@ static void mutateExecutionPriorityTest(const std::shared_ptr<IDevice>& device,
void validateModel(const std::shared_ptr<IDevice>& device, const Model& model) {
const auto numberOfConsumers =
nn::countNumberOfConsumers(model.main.operands.size(),
nn::convert(model.main.operations).value())
nn::unvalidatedConvert(model.main.operations).value())
.value();
mutateExecutionOrderTest(device, model, numberOfConsumers);
mutateOperandTypeTest(device, model);