/* * Copyright (C) 2019 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define LOG_TAG "neuralnetworks_hidl_hal_test" #include "VtsHalNeuralnetworks.h" #include "1.2/Callbacks.h" #include "ExecutionBurstController.h" #include "ExecutionBurstServer.h" #include "GeneratedTestHarness.h" #include "TestHarness.h" #include #include #include namespace android::hardware::neuralnetworks::V1_2::vts::functional { using nn::ExecutionBurstController; using nn::RequestChannelSender; using nn::ResultChannelReceiver; using V1_0::ErrorStatus; using V1_0::Request; using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback; using BurstExecutionMutation = std::function*)>; // This constant value represents the length of an FMQ that is large enough to // return a result from a burst execution for all of the generated test cases. constexpr size_t kExecutionBurstChannelLength = 1024; // This constant value represents a length of an FMQ that is not large enough // to return a result from a burst execution for some of the generated test // cases. constexpr size_t kExecutionBurstChannelSmallLength = 8; ///////////////////////// UTILITY FUNCTIONS ///////////////////////// static bool badTiming(Timing timing) { return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX; } static void createBurst(const sp& preparedModel, const sp& callback, std::unique_ptr* sender, std::unique_ptr* receiver, sp* context, size_t resultChannelLength = kExecutionBurstChannelLength) { ASSERT_NE(nullptr, preparedModel.get()); ASSERT_NE(nullptr, sender); ASSERT_NE(nullptr, receiver); ASSERT_NE(nullptr, context); // create FMQ objects auto [fmqRequestChannel, fmqRequestDescriptor] = RequestChannelSender::create(kExecutionBurstChannelLength); auto [fmqResultChannel, fmqResultDescriptor] = ResultChannelReceiver::create(resultChannelLength, std::chrono::microseconds{0}); ASSERT_NE(nullptr, fmqRequestChannel.get()); ASSERT_NE(nullptr, fmqResultChannel.get()); ASSERT_NE(nullptr, fmqRequestDescriptor); ASSERT_NE(nullptr, fmqResultDescriptor); // configure burst ErrorStatus errorStatus; sp burstContext; const Return ret = preparedModel->configureExecutionBurst( callback, *fmqRequestDescriptor, *fmqResultDescriptor, [&errorStatus, &burstContext](ErrorStatus status, const sp& context) { errorStatus = status; burstContext = context; }); ASSERT_TRUE(ret.isOk()); ASSERT_EQ(ErrorStatus::NONE, errorStatus); ASSERT_NE(nullptr, burstContext.get()); // return values *sender = std::move(fmqRequestChannel); *receiver = std::move(fmqResultChannel); *context = burstContext; } static void createBurstWithResultChannelLength( const sp& preparedModel, size_t resultChannelLength, std::shared_ptr* controller) { ASSERT_NE(nullptr, preparedModel.get()); ASSERT_NE(nullptr, controller); // create FMQ objects std::unique_ptr sender; std::unique_ptr receiver; sp callback = new ExecutionBurstCallback(); sp context; ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context, resultChannelLength)); ASSERT_NE(nullptr, sender.get()); ASSERT_NE(nullptr, receiver.get()); ASSERT_NE(nullptr, context.get()); // return values *controller = std::make_shared(std::move(sender), std::move(receiver), context, callback); } // Primary validation function. This function will take a valid serialized // request, apply a mutation to it to invalidate the serialized request, then // pass it to interface calls that use the serialized request. static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver, const std::string& message, const std::vector& originalSerialized, const BurstExecutionMutation& mutate) { std::vector serialized = originalSerialized; mutate(&serialized); // skip if packet is too large to send if (serialized.size() > kExecutionBurstChannelLength) { return; } SCOPED_TRACE(message); // send invalid packet ASSERT_TRUE(sender->sendPacket(serialized)); // receive error auto results = receiver->getBlocking(); ASSERT_TRUE(results.has_value()); const auto [status, outputShapes, timing] = std::move(*results); EXPECT_NE(ErrorStatus::NONE, status); EXPECT_EQ(0u, outputShapes.size()); EXPECT_TRUE(badTiming(timing)); } // For validation, valid packet entries are mutated to invalid packet entries, // or invalid packet entries are inserted into valid packets. This function // creates pre-set invalid packet entries for convenience. static std::vector createBadRequestPacketEntries() { const FmqRequestDatum::PacketInformation packetInformation = { /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10, /*.numberOfPools=*/10}; const FmqRequestDatum::OperandInformation operandInformation = { /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10}; const int32_t invalidPoolIdentifier = std::numeric_limits::max(); std::vector bad(7); bad[0].packetInformation(packetInformation); bad[1].inputOperandInformation(operandInformation); bad[2].inputOperandDimensionValue(0); bad[3].outputOperandInformation(operandInformation); bad[4].outputOperandDimensionValue(0); bad[5].poolIdentifier(invalidPoolIdentifier); bad[6].measureTiming(MeasureTiming::YES); return bad; } // For validation, valid packet entries are mutated to invalid packet entries, // or invalid packet entries are inserted into valid packets. This function // retrieves pre-set invalid packet entries for convenience. This function // caches these data so they can be reused on subsequent validation checks. static const std::vector& getBadRequestPacketEntries() { static const std::vector bad = createBadRequestPacketEntries(); return bad; } ///////////////////////// REMOVE DATUM //////////////////////////////////// static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, const std::vector& serialized) { for (size_t index = 0; index < serialized.size(); ++index) { const std::string message = "removeDatum: removed datum at index " + std::to_string(index); validate(sender, receiver, message, serialized, [index](std::vector* serialized) { serialized->erase(serialized->begin() + index); }); } } ///////////////////////// ADD DATUM //////////////////////////////////// static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, const std::vector& serialized) { const std::vector& extra = getBadRequestPacketEntries(); for (size_t index = 0; index <= serialized.size(); ++index) { for (size_t type = 0; type < extra.size(); ++type) { const std::string message = "addDatum: added datum type " + std::to_string(type) + " at index " + std::to_string(index); validate(sender, receiver, message, serialized, [index, type, &extra](std::vector* serialized) { serialized->insert(serialized->begin() + index, extra[type]); }); } } } ///////////////////////// MUTATE DATUM //////////////////////////////////// static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) { using Discriminator = FmqRequestDatum::hidl_discriminator; const bool differentValues = (lhs != rhs); const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator()); const auto discriminator = rhs.getDiscriminator(); const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue || discriminator == Discriminator::outputOperandDimensionValue); return differentValues && !(sameDiscriminator && isDimensionValue); } static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, const std::vector& serialized) { const std::vector& change = getBadRequestPacketEntries(); for (size_t index = 0; index < serialized.size(); ++index) { for (size_t type = 0; type < change.size(); ++type) { if (interestingCase(serialized[index], change[type])) { const std::string message = "mutateDatum: changed datum at index " + std::to_string(index) + " to datum type " + std::to_string(type); validate(sender, receiver, message, serialized, [index, type, &change](std::vector* serialized) { (*serialized)[index] = change[type]; }); } } } } ///////////////////////// BURST VALIATION TESTS //////////////////////////////////// static void validateBurstSerialization(const sp& preparedModel, const Request& request) { // create burst std::unique_ptr sender; std::unique_ptr receiver; sp callback = new ExecutionBurstCallback(); sp context; ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context)); ASSERT_NE(nullptr, sender.get()); ASSERT_NE(nullptr, receiver.get()); ASSERT_NE(nullptr, context.get()); // load memory into callback slots std::vector keys; keys.reserve(request.pools.size()); std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys), [](const auto& pool) { return reinterpret_cast(&pool); }); const std::vector slots = callback->getSlots(request.pools, keys); // ensure slot std::numeric_limits::max() doesn't exist (for // subsequent slot validation testing) ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) { return slot != std::numeric_limits::max(); })); // serialize the request const auto serialized = android::nn::serialize(request, MeasureTiming::YES, slots); // validations removeDatumTest(sender.get(), receiver.get(), serialized); addDatumTest(sender.get(), receiver.get(), serialized); mutateDatumTest(sender.get(), receiver.get(), serialized); } // This test validates that when the Result message size exceeds length of the // result FMQ, the service instance gracefully fails and returns an error. static void validateBurstFmqLength(const sp& preparedModel, const Request& request) { // create regular burst std::shared_ptr controllerRegular; ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength( preparedModel, kExecutionBurstChannelLength, &controllerRegular)); ASSERT_NE(nullptr, controllerRegular.get()); // create burst with small output channel std::shared_ptr controllerSmall; ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength( preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall)); ASSERT_NE(nullptr, controllerSmall.get()); // load memory into callback slots std::vector keys(request.pools.size()); for (size_t i = 0; i < keys.size(); ++i) { keys[i] = reinterpret_cast(&request.pools[i]); } // collect serialized result by running regular burst const auto [nRegular, outputShapesRegular, timingRegular, fallbackRegular] = controllerRegular->compute(request, MeasureTiming::NO, keys); const ErrorStatus statusRegular = nn::legacyConvertResultCodeToErrorStatus(nRegular); EXPECT_FALSE(fallbackRegular); // skip test if regular burst output isn't useful for testing a failure // caused by having too small of a length for the result FMQ const std::vector serialized = android::nn::serialize(statusRegular, outputShapesRegular, timingRegular); if (statusRegular != ErrorStatus::NONE || serialized.size() <= kExecutionBurstChannelSmallLength) { return; } // by this point, execution should fail because the result channel isn't // large enough to return the serialized result const auto [nSmall, outputShapesSmall, timingSmall, fallbackSmall] = controllerSmall->compute(request, MeasureTiming::NO, keys); const ErrorStatus statusSmall = nn::legacyConvertResultCodeToErrorStatus(nSmall); EXPECT_NE(ErrorStatus::NONE, statusSmall); EXPECT_EQ(0u, outputShapesSmall.size()); EXPECT_TRUE(badTiming(timingSmall)); EXPECT_FALSE(fallbackSmall); } static bool isSanitized(const FmqResultDatum& datum) { using Discriminator = FmqResultDatum::hidl_discriminator; // check to ensure the padding values in the returned // FmqResultDatum::OperandInformation are initialized to 0 if (datum.getDiscriminator() == Discriminator::operandInformation) { static_assert( offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0, "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient"); static_assert( sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1, "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient"); static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4, "unexpected value for offset of " "FmqResultDatum::OperandInformation::numberOfDimensions"); static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4, "unexpected value for size of " "FmqResultDatum::OperandInformation::numberOfDimensions"); static_assert(sizeof(FmqResultDatum::OperandInformation) == 8, "unexpected value for size of " "FmqResultDatum::OperandInformation"); constexpr size_t paddingOffset = offsetof(FmqResultDatum::OperandInformation, isSufficient) + sizeof(FmqResultDatum::OperandInformation::isSufficient); constexpr size_t paddingSize = offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset; FmqResultDatum::OperandInformation initialized{}; std::memset(&initialized, 0, sizeof(initialized)); const char* initializedPaddingStart = reinterpret_cast(&initialized) + paddingOffset; const char* datumPaddingStart = reinterpret_cast(&datum.operandInformation()) + paddingOffset; return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0; } // there are no other padding initialization checks required, so return true // for any sum-type that isn't FmqResultDatum::OperandInformation return true; } static void validateBurstSanitized(const sp& preparedModel, const Request& request) { // create burst std::unique_ptr sender; std::unique_ptr receiver; sp callback = new ExecutionBurstCallback(); sp context; ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context)); ASSERT_NE(nullptr, sender.get()); ASSERT_NE(nullptr, receiver.get()); ASSERT_NE(nullptr, context.get()); // load memory into callback slots std::vector keys; keys.reserve(request.pools.size()); std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys), [](const auto& pool) { return reinterpret_cast(&pool); }); const std::vector slots = callback->getSlots(request.pools, keys); // send valid request ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots)); // receive valid result auto serialized = receiver->getPacketBlocking(); ASSERT_TRUE(serialized.has_value()); // sanitize result ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized)) << "The result serialized data is not properly sanitized"; } ///////////////////////////// ENTRY POINT ////////////////////////////////// void validateBurst(const sp& preparedModel, const Request& request) { ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request)); ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request)); ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request)); } } // namespace android::hardware::neuralnetworks::V1_2::vts::functional