diff --git a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp index 88830574da..31638c425f 100644 --- a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp +++ b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp @@ -68,6 +68,11 @@ void NeuralnetworksHidlTest::TearDown() { ::testing::VtsHalHidlTargetTestBase::TearDown(); } +void ValidationTest::validateEverything(const Model& model, const std::vector& request) { + validateModel(model); + validateRequests(model, request); +} + } // namespace functional } // namespace vts } // namespace V1_0 diff --git a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h index d4c114d3a2..559d678ea1 100644 --- a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h +++ b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h @@ -63,8 +63,11 @@ class NeuralnetworksHidlTest : public ::testing::VtsHalHidlTargetTestBase { // Tag for the validation tests class ValidationTest : public NeuralnetworksHidlTest { protected: - void validateModel(const Model& model); - void validateRequests(const Model& model, const std::vector& request); + void validateEverything(const Model& model, const std::vector& request); + + private: + void validateModel(const Model& model); + void validateRequests(const Model& model, const std::vector& request); }; // Tag for the generated tests diff --git a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp index 224a51d149..11fa693ddc 100644 --- a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp +++ b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp @@ -68,6 +68,11 @@ void NeuralnetworksHidlTest::TearDown() { ::testing::VtsHalHidlTargetTestBase::TearDown(); } +void ValidationTest::validateEverything(const Model& model, const std::vector& request) { + validateModel(model); + validateRequests(model, request); +} + } // namespace functional } // namespace vts } // namespace V1_1 diff --git a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h index 1c8c0e18cb..cea2b54c2d 100644 --- a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h +++ b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h @@ -72,8 +72,11 @@ class NeuralnetworksHidlTest : public ::testing::VtsHalHidlTargetTestBase { // Tag for the validation tests class ValidationTest : public NeuralnetworksHidlTest { protected: - void validateModel(const Model& model); - void validateRequests(const Model& model, const std::vector& request); + void validateEverything(const Model& model, const std::vector& request); + + private: + void validateModel(const Model& model); + void validateRequests(const Model& model, const std::vector& request); }; // Tag for the generated tests diff --git a/neuralnetworks/1.2/vts/functional/Android.bp b/neuralnetworks/1.2/vts/functional/Android.bp index 891b414480..6c26820b27 100644 --- a/neuralnetworks/1.2/vts/functional/Android.bp +++ b/neuralnetworks/1.2/vts/functional/Android.bp @@ -20,6 +20,7 @@ cc_test { defaults: ["VtsHalNeuralNetworksTargetTestDefaults"], srcs: [ "GeneratedTestsV1_0.cpp", + "ValidateBurst.cpp", ], cflags: [ "-DNN_TEST_DYNAMIC_OUTPUT_SHAPE" @@ -32,6 +33,7 @@ cc_test { defaults: ["VtsHalNeuralNetworksTargetTestDefaults"], srcs: [ "GeneratedTestsV1_1.cpp", + "ValidateBurst.cpp", ], cflags: [ "-DNN_TEST_DYNAMIC_OUTPUT_SHAPE" @@ -46,6 +48,7 @@ cc_test { "BasicTests.cpp", "CompilationCachingTests.cpp", "GeneratedTests.cpp", + "ValidateBurst.cpp", ], cflags: [ "-DNN_TEST_DYNAMIC_OUTPUT_SHAPE" @@ -58,6 +61,7 @@ cc_test { srcs: [ "BasicTests.cpp", "GeneratedTests.cpp", + "ValidateBurst.cpp", ], cflags: [ "-DNN_TEST_DYNAMIC_OUTPUT_SHAPE", diff --git a/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp b/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp new file mode 100644 index 0000000000..386c141f80 --- /dev/null +++ b/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp @@ -0,0 +1,333 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "neuralnetworks_hidl_hal_test" + +#include "VtsHalNeuralnetworks.h" + +#include "Callbacks.h" +#include "ExecutionBurstController.h" +#include "ExecutionBurstServer.h" +#include "TestHarness.h" +#include "Utils.h" + +#include + +namespace android { +namespace hardware { +namespace neuralnetworks { +namespace V1_2 { +namespace vts { +namespace functional { + +using ::android::nn::ExecutionBurstController; +using ::android::nn::RequestChannelSender; +using ::android::nn::ResultChannelReceiver; +using ExecutionBurstCallback = ::android::nn::ExecutionBurstController::ExecutionBurstCallback; + +constexpr size_t kExecutionBurstChannelLength = 1024; +constexpr size_t kExecutionBurstChannelSmallLength = 8; + +///////////////////////// UTILITY FUNCTIONS ///////////////////////// + +static bool badTiming(Timing timing) { + return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX; +} + +static void createBurst(const sp& preparedModel, const sp& callback, + std::unique_ptr* sender, + std::unique_ptr* receiver, + sp* context) { + ASSERT_NE(nullptr, preparedModel.get()); + ASSERT_NE(nullptr, sender); + ASSERT_NE(nullptr, receiver); + ASSERT_NE(nullptr, context); + + // create FMQ objects + auto [fmqRequestChannel, fmqRequestDescriptor] = + RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true); + auto [fmqResultChannel, fmqResultDescriptor] = + ResultChannelReceiver::create(kExecutionBurstChannelLength, /*blocking=*/true); + ASSERT_NE(nullptr, fmqRequestChannel.get()); + ASSERT_NE(nullptr, fmqResultChannel.get()); + ASSERT_NE(nullptr, fmqRequestDescriptor); + ASSERT_NE(nullptr, fmqResultDescriptor); + + // configure burst + ErrorStatus errorStatus; + sp burstContext; + const Return ret = preparedModel->configureExecutionBurst( + callback, *fmqRequestDescriptor, *fmqResultDescriptor, + [&errorStatus, &burstContext](ErrorStatus status, const sp& context) { + errorStatus = status; + burstContext = context; + }); + ASSERT_TRUE(ret.isOk()); + ASSERT_EQ(ErrorStatus::NONE, errorStatus); + ASSERT_NE(nullptr, burstContext.get()); + + // return values + *sender = std::move(fmqRequestChannel); + *receiver = std::move(fmqResultChannel); + *context = burstContext; +} + +static void createBurstWithResultChannelLength( + const sp& preparedModel, + std::shared_ptr* controller, size_t resultChannelLength) { + ASSERT_NE(nullptr, preparedModel.get()); + ASSERT_NE(nullptr, controller); + + // create FMQ objects + auto [fmqRequestChannel, fmqRequestDescriptor] = + RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true); + auto [fmqResultChannel, fmqResultDescriptor] = + ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true); + ASSERT_NE(nullptr, fmqRequestChannel.get()); + ASSERT_NE(nullptr, fmqResultChannel.get()); + ASSERT_NE(nullptr, fmqRequestDescriptor); + ASSERT_NE(nullptr, fmqResultDescriptor); + + // configure burst + sp callback = new ExecutionBurstCallback(); + ErrorStatus errorStatus; + sp burstContext; + const Return ret = preparedModel->configureExecutionBurst( + callback, *fmqRequestDescriptor, *fmqResultDescriptor, + [&errorStatus, &burstContext](ErrorStatus status, const sp& context) { + errorStatus = status; + burstContext = context; + }); + ASSERT_TRUE(ret.isOk()); + ASSERT_EQ(ErrorStatus::NONE, errorStatus); + ASSERT_NE(nullptr, burstContext.get()); + + // return values + *controller = std::make_shared( + std::move(fmqRequestChannel), std::move(fmqResultChannel), burstContext, callback); +} + +// Primary validation function. This function will take a valid serialized +// request, apply a mutation to it to invalidate the serialized request, then +// pass it to interface calls that use the serialized request. Note that the +// serialized request here is passed by value, and any mutation to the +// serialized request does not leave this function. +static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver, + const std::string& message, std::vector serialized, + const std::function*)>& mutation) { + mutation(&serialized); + + // skip if packet is too large to send + if (serialized.size() > kExecutionBurstChannelLength) { + return; + } + + SCOPED_TRACE(message); + + // send invalid packet + sender->sendPacket(serialized); + + // receive error + auto results = receiver->getBlocking(); + ASSERT_TRUE(results.has_value()); + const auto [status, outputShapes, timing] = std::move(*results); + EXPECT_NE(ErrorStatus::NONE, status); + EXPECT_EQ(0u, outputShapes.size()); + EXPECT_TRUE(badTiming(timing)); +} + +static std::vector createUniqueDatum() { + const FmqRequestDatum::PacketInformation packetInformation = { + /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10, + /*.numberOfPools=*/10}; + const FmqRequestDatum::OperandInformation operandInformation = { + /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10}; + const int32_t invalidPoolIdentifier = std::numeric_limits::max(); + std::vector unique(7); + unique[0].packetInformation(packetInformation); + unique[1].inputOperandInformation(operandInformation); + unique[2].inputOperandDimensionValue(0); + unique[3].outputOperandInformation(operandInformation); + unique[4].outputOperandDimensionValue(0); + unique[5].poolIdentifier(invalidPoolIdentifier); + unique[6].measureTiming(MeasureTiming::YES); + return unique; +} + +static const std::vector& getUniqueDatum() { + static const std::vector unique = createUniqueDatum(); + return unique; +} + +///////////////////////// REMOVE DATUM //////////////////////////////////// + +static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, + const std::vector& serialized) { + for (size_t index = 0; index < serialized.size(); ++index) { + const std::string message = "removeDatum: removed datum at index " + std::to_string(index); + validate(sender, receiver, message, serialized, + [index](std::vector* serialized) { + serialized->erase(serialized->begin() + index); + }); + } +} + +///////////////////////// ADD DATUM //////////////////////////////////// + +static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, + const std::vector& serialized) { + const std::vector& extra = getUniqueDatum(); + for (size_t index = 0; index <= serialized.size(); ++index) { + for (size_t type = 0; type < extra.size(); ++type) { + const std::string message = "addDatum: added datum type " + std::to_string(type) + + " at index " + std::to_string(index); + validate(sender, receiver, message, serialized, + [index, type, &extra](std::vector* serialized) { + serialized->insert(serialized->begin() + index, extra[type]); + }); + } + } +} + +///////////////////////// MUTATE DATUM //////////////////////////////////// + +static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) { + using Discriminator = FmqRequestDatum::hidl_discriminator; + + const bool differentValues = (lhs != rhs); + const bool sameSumType = (lhs.getDiscriminator() == rhs.getDiscriminator()); + const auto discriminator = rhs.getDiscriminator(); + const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue || + discriminator == Discriminator::outputOperandDimensionValue); + + return differentValues && !(sameSumType && isDimensionValue); +} + +static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, + const std::vector& serialized) { + const std::vector& change = getUniqueDatum(); + for (size_t index = 0; index < serialized.size(); ++index) { + for (size_t type = 0; type < change.size(); ++type) { + if (interestingCase(serialized[index], change[type])) { + const std::string message = "mutateDatum: changed datum at index " + + std::to_string(index) + " to datum type " + + std::to_string(type); + validate(sender, receiver, message, serialized, + [index, type, &change](std::vector* serialized) { + (*serialized)[index] = change[type]; + }); + } + } + } +} + +///////////////////////// BURST VALIATION TESTS //////////////////////////////////// + +static void validateBurstSerialization(const sp& preparedModel, + const std::vector& requests) { + // create burst + std::unique_ptr sender; + std::unique_ptr receiver; + sp callback = new ExecutionBurstCallback(); + sp context; + ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context)); + ASSERT_NE(nullptr, sender.get()); + ASSERT_NE(nullptr, receiver.get()); + ASSERT_NE(nullptr, context.get()); + + // validate each request + for (const Request& request : requests) { + // load memory into callback slots + std::vector keys(request.pools.size()); + for (size_t i = 0; i < keys.size(); ++i) { + keys[i] = reinterpret_cast(&request.pools[i]); + } + const std::vector slots = callback->getSlots(request.pools, keys); + + // ensure slot std::numeric_limits::max() doesn't exist (for + // subsequent slot validation testing) + const auto maxElement = std::max_element(slots.begin(), slots.end()); + ASSERT_NE(slots.end(), maxElement); + ASSERT_NE(std::numeric_limits::max(), *maxElement); + + // serialize the request + const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots); + + // validations + removeDatumTest(sender.get(), receiver.get(), serialized); + addDatumTest(sender.get(), receiver.get(), serialized); + mutateDatumTest(sender.get(), receiver.get(), serialized); + } +} + +static void validateBurstFmqLength(const sp& preparedModel, + const std::vector& requests) { + // create regular burst + std::shared_ptr controllerRegular; + ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(preparedModel, &controllerRegular, + kExecutionBurstChannelLength)); + ASSERT_NE(nullptr, controllerRegular.get()); + + // create burst with small output channel + std::shared_ptr controllerSmall; + ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(preparedModel, &controllerSmall, + kExecutionBurstChannelSmallLength)); + ASSERT_NE(nullptr, controllerSmall.get()); + + // validate each request + for (const Request& request : requests) { + // load memory into callback slots + std::vector keys(request.pools.size()); + for (size_t i = 0; i < keys.size(); ++i) { + keys[i] = reinterpret_cast(&request.pools[i]); + } + + // collect serialized result by running regular burst + const auto [status1, outputShapes1, timing1] = + controllerRegular->compute(request, MeasureTiming::NO, keys); + + // skip test if synchronous output isn't useful + const std::vector serialized = + ::android::nn::serialize(status1, outputShapes1, timing1); + if (status1 != ErrorStatus::NONE || + serialized.size() <= kExecutionBurstChannelSmallLength) { + continue; + } + + // by this point, execution should fail because the result channel isn't + // large enough to return the serialized result + const auto [status2, outputShapes2, timing2] = + controllerSmall->compute(request, MeasureTiming::NO, keys); + EXPECT_NE(ErrorStatus::NONE, status2); + EXPECT_EQ(0u, outputShapes2.size()); + EXPECT_TRUE(badTiming(timing2)); + } +} + +///////////////////////////// ENTRY POINT ////////////////////////////////// + +void ValidationTest::validateBurst(const sp& preparedModel, + const std::vector& requests) { + ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, requests)); + ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, requests)); +} + +} // namespace functional +} // namespace vts +} // namespace V1_2 +} // namespace neuralnetworks +} // namespace hardware +} // namespace android diff --git a/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp b/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp index 870d01748a..9703c2d765 100644 --- a/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp +++ b/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp @@ -35,9 +35,7 @@ namespace vts { namespace functional { using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback; -using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback; using ::android::hidl::memory::V1_0::IMemory; -using HidlToken = hidl_array(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>; using test_helper::for_all; using test_helper::MixedTyped; using test_helper::MixedTypedExample; @@ -48,55 +46,6 @@ static bool badTiming(Timing timing) { return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX; } -static void createPreparedModel(const sp& device, const Model& model, - sp* preparedModel) { - ASSERT_NE(nullptr, preparedModel); - - // see if service can handle model - bool fullySupportsModel = false; - Return supportedOpsLaunchStatus = device->getSupportedOperations_1_2( - model, [&fullySupportsModel](ErrorStatus status, const hidl_vec& supported) { - ASSERT_EQ(ErrorStatus::NONE, status); - ASSERT_NE(0ul, supported.size()); - fullySupportsModel = - std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; }); - }); - ASSERT_TRUE(supportedOpsLaunchStatus.isOk()); - - // launch prepare model - sp preparedModelCallback = new PreparedModelCallback(); - ASSERT_NE(nullptr, preparedModelCallback.get()); - Return prepareLaunchStatus = device->prepareModel_1_2( - model, ExecutionPreference::FAST_SINGLE_ANSWER, hidl_vec(), - hidl_vec(), HidlToken(), preparedModelCallback); - ASSERT_TRUE(prepareLaunchStatus.isOk()); - ASSERT_EQ(ErrorStatus::NONE, static_cast(prepareLaunchStatus)); - - // retrieve prepared model - preparedModelCallback->wait(); - ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus(); - *preparedModel = getPreparedModel_1_2(preparedModelCallback); - - // The getSupportedOperations_1_2 call returns a list of operations that are - // guaranteed not to fail if prepareModel_1_2 is called, and - // 'fullySupportsModel' is true i.f.f. the entire model is guaranteed. - // If a driver has any doubt that it can prepare an operation, it must - // return false. So here, if a driver isn't sure if it can support an - // operation, but reports that it successfully prepared the model, the test - // can continue. - if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) { - ASSERT_EQ(nullptr, preparedModel->get()); - LOG(INFO) << "NN VTS: Unable to test Request validation because vendor service cannot " - "prepare model that it does not support."; - std::cout << "[ ] Unable to test Request validation because vendor service " - "cannot prepare model that it does not support." - << std::endl; - return; - } - ASSERT_EQ(ErrorStatus::NONE, prepareReturnStatus); - ASSERT_NE(nullptr, preparedModel->get()); -} - // Primary validation function. This function will take a valid request, apply a // mutation to it to invalidate the request, then pass it to interface calls // that use the request. Note that the request here is passed by value, and any @@ -316,14 +265,8 @@ std::vector createRequests(const std::vector& exampl return requests; } -void ValidationTest::validateRequests(const Model& model, const std::vector& requests) { - // create IPreparedModel - sp preparedModel; - ASSERT_NO_FATAL_FAILURE(createPreparedModel(device, model, &preparedModel)); - if (preparedModel == nullptr) { - return; - } - +void ValidationTest::validateRequests(const sp& preparedModel, + const std::vector& requests) { // validate each request for (const Request& request : requests) { removeInputTest(preparedModel, request); diff --git a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp index 4728c28e87..93182f1da2 100644 --- a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp +++ b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp @@ -18,6 +18,10 @@ #include "VtsHalNeuralnetworks.h" +#include + +#include "Callbacks.h" + namespace android { namespace hardware { namespace neuralnetworks { @@ -25,6 +29,61 @@ namespace V1_2 { namespace vts { namespace functional { +using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback; +using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback; +using HidlToken = hidl_array(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>; +using V1_1::ExecutionPreference; + +// internal helper function +static void createPreparedModel(const sp& device, const Model& model, + sp* preparedModel) { + ASSERT_NE(nullptr, preparedModel); + + // see if service can handle model + bool fullySupportsModel = false; + Return supportedOpsLaunchStatus = device->getSupportedOperations_1_2( + model, [&fullySupportsModel](ErrorStatus status, const hidl_vec& supported) { + ASSERT_EQ(ErrorStatus::NONE, status); + ASSERT_NE(0ul, supported.size()); + fullySupportsModel = std::all_of(supported.begin(), supported.end(), + [](bool valid) { return valid; }); + }); + ASSERT_TRUE(supportedOpsLaunchStatus.isOk()); + + // launch prepare model + sp preparedModelCallback = new PreparedModelCallback(); + ASSERT_NE(nullptr, preparedModelCallback.get()); + Return prepareLaunchStatus = device->prepareModel_1_2( + model, ExecutionPreference::FAST_SINGLE_ANSWER, hidl_vec(), + hidl_vec(), HidlToken(), preparedModelCallback); + ASSERT_TRUE(prepareLaunchStatus.isOk()); + ASSERT_EQ(ErrorStatus::NONE, static_cast(prepareLaunchStatus)); + + // retrieve prepared model + preparedModelCallback->wait(); + ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus(); + *preparedModel = getPreparedModel_1_2(preparedModelCallback); + + // The getSupportedOperations_1_2 call returns a list of operations that are + // guaranteed not to fail if prepareModel_1_2 is called, and + // 'fullySupportsModel' is true i.f.f. the entire model is guaranteed. + // If a driver has any doubt that it can prepare an operation, it must + // return false. So here, if a driver isn't sure if it can support an + // operation, but reports that it successfully prepared the model, the test + // can continue. + if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) { + ASSERT_EQ(nullptr, preparedModel->get()); + LOG(INFO) << "NN VTS: Unable to test Request validation because vendor service cannot " + "prepare model that it does not support."; + std::cout << "[ ] Unable to test Request validation because vendor service " + "cannot prepare model that it does not support." + << std::endl; + return; + } + ASSERT_EQ(ErrorStatus::NONE, prepareReturnStatus); + ASSERT_NE(nullptr, preparedModel->get()); +} + // A class for test environment setup NeuralnetworksHidlEnvironment::NeuralnetworksHidlEnvironment() {} @@ -68,6 +127,20 @@ void NeuralnetworksHidlTest::TearDown() { ::testing::VtsHalHidlTargetTestBase::TearDown(); } +void ValidationTest::validateEverything(const Model& model, const std::vector& request) { + validateModel(model); + + // create IPreparedModel + sp preparedModel; + ASSERT_NO_FATAL_FAILURE(createPreparedModel(device, model, &preparedModel)); + if (preparedModel == nullptr) { + return; + } + + validateRequests(preparedModel, request); + validateBurst(preparedModel, request); +} + sp getPreparedModel_1_2( const sp& callback) { sp preparedModelV1_0 = callback->getPreparedModel(); diff --git a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h index 404eec06db..36e73a4fb0 100644 --- a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h +++ b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h @@ -72,8 +72,14 @@ class NeuralnetworksHidlTest : public ::testing::VtsHalHidlTargetTestBase { // Tag for the validation tests class ValidationTest : public NeuralnetworksHidlTest { protected: - void validateModel(const Model& model); - void validateRequests(const Model& model, const std::vector& request); + void validateEverything(const Model& model, const std::vector& request); + + private: + void validateModel(const Model& model); + void validateRequests(const sp& preparedModel, + const std::vector& requests); + void validateBurst(const sp& preparedModel, + const std::vector& requests); }; // Tag for the generated tests