adb: don't close sockets before hitting EOF.
The standard (RFC 1122 - 4.2.2.13) says that if we call close on a socket while we have pending data, a TCP RST should be sent to the other end to notify it that we didn't read all of its data. However, this can result in data that we've succesfully written out to be dropped on the other end. To avoid this, instead of immediately closing a socket, call shutdown on it instead, and then read from the file descriptor until we hit EOF or an error before closing. Bug: http://b/74616284 Test: ./test_adb.py Test: ./test_device.py Change-Id: I36f72bd14965821dc23de82774b0806b2db24f13
This commit is contained in:
parent
a3303fd21b
commit
ffc11d3cf3
3 changed files with 211 additions and 4 deletions
|
@ -24,6 +24,7 @@ cc_defaults {
|
|||
"-Wno-missing-field-initializers",
|
||||
"-Wvla",
|
||||
],
|
||||
cpp_std: "gnu++17",
|
||||
rtti: true,
|
||||
|
||||
use_version_lib: true,
|
||||
|
|
155
adb/sockets.cpp
155
adb/sockets.cpp
|
@ -26,10 +26,14 @@
|
|||
#include <unistd.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include <android-base/thread_annotations.h>
|
||||
|
||||
#if !ADB_HOST
|
||||
#include <android-base/properties.h>
|
||||
#include <log/log_properties.h>
|
||||
|
@ -37,9 +41,150 @@
|
|||
|
||||
#include "adb.h"
|
||||
#include "adb_io.h"
|
||||
#include "adb_utils.h"
|
||||
#include "sysdeps/chrono.h"
|
||||
#include "transport.h"
|
||||
#include "types.h"
|
||||
|
||||
// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
|
||||
// socket while we have pending data, a TCP RST should be sent to the
|
||||
// other end to notify it that we didn't read all of its data. However,
|
||||
// this can result in data that we've successfully written out to be dropped
|
||||
// on the other end. To avoid this, instead of immediately closing a
|
||||
// socket, call shutdown on it instead, and then read from the file
|
||||
// descriptor until we hit EOF or an error before closing.
|
||||
struct LingeringSocketCloser {
|
||||
LingeringSocketCloser() = default;
|
||||
~LingeringSocketCloser() = delete;
|
||||
|
||||
// Defer thread creation until it's needed, because we need for there to
|
||||
// only be one thread when dropping privileges in adbd.
|
||||
void Start() {
|
||||
CHECK(!thread_.joinable());
|
||||
|
||||
int fds[2];
|
||||
if (adb_socketpair(fds) != 0) {
|
||||
PLOG(FATAL) << "adb_socketpair failed";
|
||||
}
|
||||
|
||||
set_file_block_mode(fds[0], false);
|
||||
set_file_block_mode(fds[1], false);
|
||||
|
||||
notify_fd_read_.reset(fds[0]);
|
||||
notify_fd_write_.reset(fds[1]);
|
||||
|
||||
thread_ = std::thread([this]() { Run(); });
|
||||
}
|
||||
|
||||
void EnqueueSocket(unique_fd socket) {
|
||||
// Shutdown the socket in the outgoing direction only, so that
|
||||
// we don't have the same problem on the opposite end.
|
||||
adb_shutdown(socket.get(), SHUT_WR);
|
||||
set_file_block_mode(socket.get(), false);
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
int fd = socket.get();
|
||||
SocketInfo info = {
|
||||
.fd = std::move(socket),
|
||||
.deadline = std::chrono::steady_clock::now() + 1s,
|
||||
};
|
||||
|
||||
D("LingeringSocketCloser received fd %d", fd);
|
||||
|
||||
fds_.emplace(fd, std::move(info));
|
||||
if (adb_write(notify_fd_write_, "", 1) == -1 && errno != EAGAIN) {
|
||||
PLOG(FATAL) << "failed to write to LingeringSocketCloser notify fd";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<adb_pollfd> GeneratePollFds() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<adb_pollfd> result;
|
||||
result.push_back(adb_pollfd{.fd = notify_fd_read_, .events = POLLIN});
|
||||
for (auto& [fd, _] : fds_) {
|
||||
result.push_back(adb_pollfd{.fd = fd, .events = POLLIN});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void Run() {
|
||||
while (true) {
|
||||
std::vector<adb_pollfd> pfds = GeneratePollFds();
|
||||
int rc = adb_poll(pfds.data(), pfds.size(), 1000);
|
||||
if (rc == -1) {
|
||||
PLOG(FATAL) << "poll failed in LingeringSocketCloser";
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (rc == 0) {
|
||||
// Check deadlines.
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
for (auto it = fds_.begin(); it != fds_.end();) {
|
||||
if (now > it->second.deadline) {
|
||||
D("LingeringSocketCloser closing fd %d due to deadline", it->first);
|
||||
it = fds_.erase(it);
|
||||
} else {
|
||||
D("deadline still not expired for fd %d", it->first);
|
||||
++it;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto& pfd : pfds) {
|
||||
if ((pfd.revents & POLLIN) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Empty the fd.
|
||||
ssize_t rc;
|
||||
char buf[32768];
|
||||
while ((rc = adb_read(pfd.fd, buf, sizeof(buf))) > 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (pfd.fd == notify_fd_read_) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = fds_.find(pfd.fd);
|
||||
if (it == fds_.end()) {
|
||||
LOG(FATAL) << "fd is missing";
|
||||
}
|
||||
|
||||
if (rc == -1 && errno == EAGAIN) {
|
||||
if (std::chrono::steady_clock::now() > it->second.deadline) {
|
||||
D("LingeringSocketCloser closing fd %d due to deadline", pfd.fd);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
} else if (rc == -1) {
|
||||
D("LingeringSocketCloser closing fd %d due to error %d", pfd.fd, errno);
|
||||
} else {
|
||||
D("LingeringSocketCloser closing fd %d due to EOF", pfd.fd);
|
||||
}
|
||||
|
||||
fds_.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::thread thread_;
|
||||
unique_fd notify_fd_read_;
|
||||
unique_fd notify_fd_write_;
|
||||
|
||||
struct SocketInfo {
|
||||
unique_fd fd;
|
||||
std::chrono::steady_clock::time_point deadline;
|
||||
};
|
||||
|
||||
std::mutex mutex_;
|
||||
std::map<int, SocketInfo> fds_ GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
static auto& socket_closer = *new LingeringSocketCloser();
|
||||
|
||||
static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
|
||||
static unsigned local_socket_next_id = 1;
|
||||
|
||||
|
@ -243,10 +388,12 @@ static void local_socket_destroy(asocket* s) {
|
|||
|
||||
D("LS(%d): destroying fde.fd=%d", s->id, s->fd);
|
||||
|
||||
/* IMPORTANT: the remove closes the fd
|
||||
** that belongs to this socket
|
||||
*/
|
||||
fdevent_destroy(s->fde);
|
||||
// Defer thread creation until it's needed, because we need for there to
|
||||
// only be one thread when dropping privileges in adbd.
|
||||
static std::once_flag once;
|
||||
std::call_once(once, []() { socket_closer.Start(); });
|
||||
|
||||
socket_closer.EnqueueSocket(fdevent_release(s->fde));
|
||||
|
||||
remove_socket(s);
|
||||
delete s;
|
||||
|
|
|
@ -35,6 +35,8 @@ import threading
|
|||
import time
|
||||
import unittest
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import adb
|
||||
|
||||
def requires_root(func):
|
||||
|
@ -1335,6 +1337,63 @@ class DeviceOfflineTest(DeviceTest):
|
|||
self.device.forward_remove("tcp:{}".format(local_port))
|
||||
|
||||
|
||||
class SocketTest(DeviceTest):
|
||||
def test_socket_flush(self):
|
||||
"""Test that we handle socket closure properly.
|
||||
|
||||
If we're done writing to a socket, closing before the other end has
|
||||
closed will send a TCP_RST if we have incoming data queued up, which
|
||||
may result in data that we've written being discarded.
|
||||
|
||||
Bug: http://b/74616284
|
||||
"""
|
||||
s = socket.create_connection(("localhost", 5037))
|
||||
|
||||
def adb_length_prefixed(string):
|
||||
encoded = string.encode("utf8")
|
||||
result = b"%04x%s" % (len(encoded), encoded)
|
||||
return result
|
||||
|
||||
if "ANDROID_SERIAL" in os.environ:
|
||||
transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"]
|
||||
else:
|
||||
transport_string = "host:transport-any"
|
||||
|
||||
s.sendall(adb_length_prefixed(transport_string))
|
||||
response = s.recv(4)
|
||||
self.assertEquals(b"OKAY", response)
|
||||
|
||||
shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo"
|
||||
s.sendall(adb_length_prefixed(shell_string))
|
||||
|
||||
response = s.recv(4)
|
||||
self.assertEquals(b"OKAY", response)
|
||||
|
||||
# Spawn a thread that dumps garbage into the socket until failure.
|
||||
def spam():
|
||||
buf = b"\0" * 16384
|
||||
try:
|
||||
while True:
|
||||
s.sendall(buf)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
|
||||
thread = threading.Thread(target=spam)
|
||||
thread.start()
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
received = b""
|
||||
while True:
|
||||
read = s.recv(512)
|
||||
if len(read) == 0:
|
||||
break
|
||||
received += read
|
||||
|
||||
self.assertEquals(1024 * 1024 + len("foo\n"), len(received))
|
||||
thread.join()
|
||||
|
||||
|
||||
if sys.platform == "win32":
|
||||
# From https://stackoverflow.com/a/38749458
|
||||
import os
|
||||
|
|
Loading…
Reference in a new issue