init/epoll: Make Epoll::Wait() easier to use

Invoke the callback functions from inside Epoll::Wait() instead of
returning a vector with pointers to callback functions. Remove handlers
after handler invocation finished to prevent that self-removal triggers
a use-after-free.

The CL that made Epoll::Wait() return a vector is available at
https://android-review.googlesource.com/c/platform/system/core/+/1112042.

Bug: 213617178
Change-Id: I52c6ade5746a911510746f83802684f2d9cfb429
Signed-off-by: Bart Van Assche <bvanassche@google.com>
This commit is contained in:
Bart Van Assche 2022-10-14 09:13:19 -07:00
parent a1c8a622b2
commit bc5c4a4659
6 changed files with 25 additions and 37 deletions

View file

@ -69,9 +69,11 @@ Result<void> Epoll::UnregisterHandler(int fd) {
if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == -1) {
return ErrnoError() << "epoll_ctl failed to remove fd";
}
if (epoll_handlers_.erase(fd) != 1) {
auto it = epoll_handlers_.find(fd);
if (it == epoll_handlers_.end()) {
return Error() << "Attempting to remove epoll handler for FD without an existing handler";
}
to_remove_.insert(it->first);
return {};
}
@ -79,8 +81,7 @@ void Epoll::SetFirstCallback(std::function<void()> first_callback) {
first_callback_ = std::move(first_callback);
}
Result<std::vector<std::shared_ptr<Epoll::Handler>>> Epoll::Wait(
std::optional<std::chrono::milliseconds> timeout) {
Result<int> Epoll::Wait(std::optional<std::chrono::milliseconds> timeout) {
int timeout_ms = -1;
if (timeout && timeout->count() < INT_MAX) {
timeout_ms = timeout->count();
@ -94,7 +95,6 @@ Result<std::vector<std::shared_ptr<Epoll::Handler>>> Epoll::Wait(
if (num_events > 0 && first_callback_) {
first_callback_();
}
std::vector<std::shared_ptr<Handler>> pending_functions;
for (int i = 0; i < num_events; ++i) {
const auto it = epoll_handlers_.find(ev[i].data.fd);
if (it == epoll_handlers_.end()) {
@ -107,10 +107,13 @@ Result<std::vector<std::shared_ptr<Epoll::Handler>>> Epoll::Wait(
// Log something informational.
LOG(ERROR) << "Received unexpected epoll event set: " << ev[i].events;
}
pending_functions.emplace_back(info.handler);
(*info.handler)();
for (auto fd : to_remove_) {
epoll_handlers_.erase(fd);
}
to_remove_.clear();
}
return pending_functions;
return num_events;
}
} // namespace init

View file

@ -24,6 +24,7 @@
#include <map>
#include <memory>
#include <optional>
#include <unordered_set>
#include <vector>
#include <android-base/unique_fd.h>
@ -43,8 +44,7 @@ class Epoll {
Result<void> RegisterHandler(int fd, Handler handler, uint32_t events = EPOLLIN);
Result<void> UnregisterHandler(int fd);
void SetFirstCallback(std::function<void()> first_callback);
Result<std::vector<std::shared_ptr<Handler>>> Wait(
std::optional<std::chrono::milliseconds> timeout);
Result<int> Wait(std::optional<std::chrono::milliseconds> timeout);
private:
struct Info {
@ -55,6 +55,7 @@ class Epoll {
android::base::unique_fd epoll_fd_;
std::map<int, Info> epoll_handlers_;
std::function<void()> first_callback_;
std::unordered_set<int> to_remove_;
};
} // namespace init

View file

@ -60,14 +60,9 @@ TEST(epoll, UnregisterHandler) {
uint8_t byte = 0xee;
ASSERT_TRUE(android::base::WriteFully(fds[1], &byte, sizeof(byte)));
auto results = epoll.Wait({});
ASSERT_RESULT_OK(results);
ASSERT_EQ(results->size(), size_t(1));
for (const auto& function : *results) {
(*function)();
(*function)();
}
auto epoll_result = epoll.Wait({});
ASSERT_RESULT_OK(epoll_result);
ASSERT_EQ(*epoll_result, 1);
ASSERT_TRUE(handler_invoked);
}

View file

@ -1177,14 +1177,10 @@ int SecondStageMain(int argc, char** argv) {
if (am.HasMoreCommands()) epoll_timeout = 0ms;
}
auto pending_functions = epoll.Wait(epoll_timeout);
if (!pending_functions.ok()) {
LOG(ERROR) << pending_functions.error();
} else if (!pending_functions->empty()) {
for (const auto& function : *pending_functions) {
(*function)();
}
} else if (Service::is_exec_service_running()) {
auto epoll_result = epoll.Wait(epoll_timeout);
if (!epoll_result.ok()) {
LOG(ERROR) << epoll_result.error();
} else if (*epoll_result <= 0 && Service::is_exec_service_running()) {
static bool dumped_diagnostics = false;
std::chrono::duration<double> waited =
std::chrono::steady_clock::now() - Service::exec_service_started();

View file

@ -212,11 +212,8 @@ TestFrame::TestFrame(const std::vector<const std::vector<int>>& chords, EventHan
}
void TestFrame::RelaxForMs(std::chrono::milliseconds wait) {
auto pending_functions = epoll_.Wait(wait);
ASSERT_RESULT_OK(pending_functions);
for (const auto& function : *pending_functions) {
(*function)();
}
auto epoll_result = epoll_.Wait(wait);
ASSERT_RESULT_OK(epoll_result);
}
void TestFrame::SetChord(int key, bool value) {

View file

@ -1381,13 +1381,9 @@ static void PropertyServiceThread() {
}
while (true) {
auto pending_functions = epoll.Wait(std::nullopt);
if (!pending_functions.ok()) {
LOG(ERROR) << pending_functions.error();
} else {
for (const auto& function : *pending_functions) {
(*function)();
}
auto epoll_result = epoll.Wait(std::nullopt);
if (!epoll_result.ok()) {
LOG(ERROR) << epoll_result.error();
}
}
}