diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h index 92ed1cda5d..9a7fe5e1c7 100644 --- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h +++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h @@ -56,6 +56,8 @@ class IProtectedCallback { // Thread safe class class DeathMonitor final { public: + explicit DeathMonitor(uintptr_t cookieKey) : kCookieKey(cookieKey) {} + static void serviceDied(void* cookie); void serviceDied(); // Precondition: `killable` must be non-null. @@ -63,9 +65,18 @@ class DeathMonitor final { // Precondition: `killable` must be non-null. void remove(IProtectedCallback* killable) const; + uintptr_t getCookieKey() const { return kCookieKey; } + + ~DeathMonitor(); + DeathMonitor(const DeathMonitor&) = delete; + DeathMonitor(DeathMonitor&&) noexcept = delete; + DeathMonitor& operator=(const DeathMonitor&) = delete; + DeathMonitor& operator=(DeathMonitor&&) noexcept = delete; + private: mutable std::mutex mMutex; mutable std::vector mObjects GUARDED_BY(mMutex); + const uintptr_t kCookieKey; }; class DeathHandler final { diff --git a/neuralnetworks/aidl/utils/src/ProtectCallback.cpp b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp index 54a673caf5..4a7ac08895 100644 --- a/neuralnetworks/aidl/utils/src/ProtectCallback.cpp +++ b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp @@ -25,6 +25,7 @@ #include #include +#include #include #include #include @@ -33,6 +34,16 @@ namespace aidl::android::hardware::neuralnetworks::utils { +namespace { + +// Only dereference the cookie if it's valid (if it's in this set) +// Only used with ndk +std::mutex sCookiesMutex; +uintptr_t sCookieKeyCounter GUARDED_BY(sCookiesMutex) = 0; +std::map> sCookies GUARDED_BY(sCookiesMutex); + +} // namespace + void DeathMonitor::serviceDied() { std::lock_guard guard(mMutex); std::for_each(mObjects.begin(), mObjects.end(), @@ -40,8 +51,24 @@ void DeathMonitor::serviceDied() { } void DeathMonitor::serviceDied(void* cookie) { - auto deathMonitor = static_cast(cookie); - deathMonitor->serviceDied(); + std::shared_ptr monitor; + { + std::lock_guard guard(sCookiesMutex); + if (auto it = sCookies.find(reinterpret_cast(cookie)); it != sCookies.end()) { + monitor = it->second.lock(); + sCookies.erase(it); + } else { + LOG(INFO) + << "Service died, but cookie is no longer valid so there is nothing to notify."; + return; + } + } + if (monitor) { + LOG(INFO) << "Notifying DeathMonitor from serviceDied."; + monitor->serviceDied(); + } else { + LOG(INFO) << "Tried to notify DeathMonitor from serviceDied but could not promote."; + } } void DeathMonitor::add(IProtectedCallback* killable) const { @@ -57,12 +84,25 @@ void DeathMonitor::remove(IProtectedCallback* killable) const { mObjects.erase(removedIter); } +DeathMonitor::~DeathMonitor() { + // lock must be taken so object is not used in OnBinderDied" + std::lock_guard guard(sCookiesMutex); + sCookies.erase(kCookieKey); +} + nn::GeneralResult DeathHandler::create(std::shared_ptr object) { if (object == nullptr) { return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "utils::DeathHandler::create must have non-null object"; } - auto deathMonitor = std::make_shared(); + + std::shared_ptr deathMonitor; + { + std::lock_guard guard(sCookiesMutex); + deathMonitor = std::make_shared(sCookieKeyCounter++); + sCookies[deathMonitor->getCookieKey()] = deathMonitor; + } + auto deathRecipient = ndk::ScopedAIBinder_DeathRecipient( AIBinder_DeathRecipient_new(DeathMonitor::serviceDied)); @@ -70,8 +110,9 @@ nn::GeneralResult DeathHandler::create(std::shared_ptrisRemote()) { - const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_linkToDeath( - object->asBinder().get(), deathRecipient.get(), deathMonitor.get())); + const auto ret = ndk::ScopedAStatus::fromStatus( + AIBinder_linkToDeath(object->asBinder().get(), deathRecipient.get(), + reinterpret_cast(deathMonitor->getCookieKey()))); HANDLE_ASTATUS(ret) << "AIBinder_linkToDeath failed"; } @@ -91,8 +132,9 @@ DeathHandler::DeathHandler(std::shared_ptr object, DeathHandler::~DeathHandler() { if (kObject != nullptr && kDeathRecipient.get() != nullptr && kDeathMonitor != nullptr) { - const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_unlinkToDeath( - kObject->asBinder().get(), kDeathRecipient.get(), kDeathMonitor.get())); + const auto ret = ndk::ScopedAStatus::fromStatus( + AIBinder_unlinkToDeath(kObject->asBinder().get(), kDeathRecipient.get(), + reinterpret_cast(kDeathMonitor->getCookieKey()))); const auto maybeSuccess = handleTransportError(ret); if (!maybeSuccess.ok()) { LOG(ERROR) << maybeSuccess.error().message; diff --git a/neuralnetworks/aidl/utils/test/DeviceTest.cpp b/neuralnetworks/aidl/utils/test/DeviceTest.cpp index 73727b3974..ffd3b8e5f1 100644 --- a/neuralnetworks/aidl/utils/test/DeviceTest.cpp +++ b/neuralnetworks/aidl/utils/test/DeviceTest.cpp @@ -697,7 +697,8 @@ TEST_P(DeviceTest, prepareModelAsyncCrash) { const auto mockDevice = createMockDevice(); const auto device = Device::create(kName, mockDevice, kVersion).value(); const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast(device->getDeathMonitor()->getCookieKey())); return ndk::ScopedAStatus::ok(); }; EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _)) @@ -846,7 +847,8 @@ TEST_P(DeviceTest, prepareModelWithConfigAsyncCrash) { const auto mockDevice = createMockDevice(); const auto device = Device::create(kName, mockDevice, kVersion).value(); const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast(device->getDeathMonitor()->getCookieKey())); return ndk::ScopedAStatus::ok(); }; EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _)) @@ -970,7 +972,8 @@ TEST_P(DeviceTest, prepareModelFromCacheAsyncCrash) { const auto mockDevice = createMockDevice(); const auto device = Device::create(kName, mockDevice, kVersion).value(); const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast(device->getDeathMonitor()->getCookieKey())); return ndk::ScopedAStatus::ok(); }; EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))