Make NNAPI countNumberOfConsumers return GeneralResult -- hal

Previously, countNumberOfConsumers would trigger a CHECK if the input
was invalid. This CL makes countNumberOfConsumers gracefully fail on
errors, instead returning the error through the GeneralResult.

Bug: N/A
Test: mma
Change-Id: Iee54f87768e52fdf701c22d94083c053b881733d
Merged-In: Iee54f87768e52fdf701c22d94083c053b881733d
(cherry picked from commit c4d98007fd)
This commit is contained in:
Michael Butler 2021-02-09 15:36:11 -08:00
parent 8548f574ee
commit 68b6926e3c
7 changed files with 17 additions and 15 deletions

View file

@ -162,7 +162,7 @@ GeneralResult<Model> unvalidatedConvert(const hal::V1_0::Model& model) {
// Verify number of consumers.
const auto numberOfConsumers =
hal::utils::countNumberOfConsumers(model.operands.size(), operations);
NN_TRY(hal::utils::countNumberOfConsumers(model.operands.size(), operations));
CHECK(model.operands.size() == numberOfConsumers.size());
for (size_t i = 0; i < model.operands.size(); ++i) {
if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
@ -360,7 +360,7 @@ nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
// Update number of consumers.
const auto numberOfConsumers =
hal::utils::countNumberOfConsumers(operands.size(), model.main.operations);
NN_TRY(hal::utils::countNumberOfConsumers(operands.size(), model.main.operations));
CHECK(operands.size() == numberOfConsumers.size());
for (size_t i = 0; i < operands.size(); ++i) {
operands[i].numberOfConsumers = numberOfConsumers[i];

View file

@ -111,7 +111,7 @@ GeneralResult<Model> unvalidatedConvert(const hal::V1_1::Model& model) {
// Verify number of consumers.
const auto numberOfConsumers =
hal::utils::countNumberOfConsumers(model.operands.size(), operations);
NN_TRY(hal::utils::countNumberOfConsumers(model.operands.size(), operations));
CHECK(model.operands.size() == numberOfConsumers.size());
for (size_t i = 0; i < model.operands.size(); ++i) {
if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
@ -241,7 +241,7 @@ nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
// Update number of consumers.
const auto numberOfConsumers =
hal::utils::countNumberOfConsumers(operands.size(), model.main.operations);
NN_TRY(hal::utils::countNumberOfConsumers(operands.size(), model.main.operations));
CHECK(operands.size() == numberOfConsumers.size());
for (size_t i = 0; i < operands.size(); ++i) {
operands[i].numberOfConsumers = numberOfConsumers[i];

View file

@ -227,7 +227,7 @@ GeneralResult<Model> unvalidatedConvert(const hal::V1_2::Model& model) {
// Verify number of consumers.
const auto numberOfConsumers =
hal::utils::countNumberOfConsumers(model.operands.size(), operations);
NN_TRY(hal::utils::countNumberOfConsumers(model.operands.size(), operations));
CHECK(model.operands.size() == numberOfConsumers.size());
for (size_t i = 0; i < model.operands.size(); ++i) {
if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
@ -529,7 +529,7 @@ nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
// Update number of consumers.
const auto numberOfConsumers =
hal::utils::countNumberOfConsumers(operands.size(), model.main.operations);
NN_TRY(hal::utils::countNumberOfConsumers(operands.size(), model.main.operations));
CHECK(operands.size() == numberOfConsumers.size());
for (size_t i = 0; i < operands.size(); ++i) {
operands[i].numberOfConsumers = numberOfConsumers[i];

View file

@ -217,7 +217,7 @@ GeneralResult<Model::Subgraph> unvalidatedConvert(const hal::V1_3::Subgraph& sub
// Verify number of consumers.
const auto numberOfConsumers =
hal::utils::countNumberOfConsumers(subgraph.operands.size(), operations);
NN_TRY(hal::utils::countNumberOfConsumers(subgraph.operands.size(), operations));
CHECK(subgraph.operands.size() == numberOfConsumers.size());
for (size_t i = 0; i < subgraph.operands.size(); ++i) {
if (subgraph.operands[i].numberOfConsumers != numberOfConsumers[i]) {
@ -559,7 +559,7 @@ nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgra
// Update number of consumers.
const auto numberOfConsumers =
hal::utils::countNumberOfConsumers(operands.size(), subgraph.operations);
NN_TRY(hal::utils::countNumberOfConsumers(operands.size(), subgraph.operations));
CHECK(operands.size() == numberOfConsumers.size());
for (size_t i = 0; i < operands.size(); ++i) {
operands[i].numberOfConsumers = numberOfConsumers[i];

View file

@ -1310,8 +1310,10 @@ static void mutateExecutionPriorityTest(const std::shared_ptr<IDevice>& device,
////////////////////////// ENTRY POINT //////////////////////////////
void validateModel(const std::shared_ptr<IDevice>& device, const Model& model) {
const auto numberOfConsumers = nn::countNumberOfConsumers(
model.main.operands.size(), nn::convert(model.main.operations).value());
const auto numberOfConsumers =
nn::countNumberOfConsumers(model.main.operands.size(),
nn::convert(model.main.operations).value())
.value();
mutateExecutionOrderTest(device, model, numberOfConsumers);
mutateOperandTypeTest(device, model);
mutateOperandRankTest(device, model);

View file

@ -71,8 +71,8 @@ nn::GeneralResult<std::reference_wrapper<const nn::Request>> flushDataFromPointe
nn::GeneralResult<void> unflushDataFromSharedToPointer(
const nn::Request& request, const std::optional<nn::Request>& maybeRequestInShared);
std::vector<uint32_t> countNumberOfConsumers(size_t numberOfOperands,
const std::vector<nn::Operation>& operations);
nn::GeneralResult<std::vector<uint32_t>> countNumberOfConsumers(
size_t numberOfOperands, const std::vector<nn::Operation>& operations);
nn::GeneralResult<hidl_memory> createHidlMemoryFromSharedMemory(const nn::SharedMemory& memory);
nn::GeneralResult<nn::SharedMemory> createSharedMemoryFromHidlMemory(const hidl_memory& memory);

View file

@ -246,9 +246,9 @@ nn::GeneralResult<void> unflushDataFromSharedToPointer(
return {};
}
std::vector<uint32_t> countNumberOfConsumers(size_t numberOfOperands,
const std::vector<nn::Operation>& operations) {
return nn::countNumberOfConsumers(numberOfOperands, operations);
nn::GeneralResult<std::vector<uint32_t>> countNumberOfConsumers(
size_t numberOfOperands, const std::vector<nn::Operation>& operations) {
return makeGeneralFailure(nn::countNumberOfConsumers(numberOfOperands, operations));
}
nn::GeneralResult<hidl_memory> createHidlMemoryFromSharedMemory(const nn::SharedMemory& memory) {