adb: don't close sockets before hitting EOF.

The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
socket while we have pending data, a TCP RST should be sent to the
other end to notify it that we didn't read all of its data. However,
this can result in data that we've succesfully written out to be dropped
on the other end. To avoid this, instead of immediately closing a
socket, call shutdown on it instead, and then read from the file
descriptor until we hit EOF or an error before closing.

Bug: http://b/74616284
Test: ./test_adb.py
Test: ./test_device.py
Change-Id: I36f72bd14965821dc23de82774b0806b2db24f13
This commit is contained in:
Josh Gao 2018-09-20 17:38:14 -07:00
parent a3303fd21b
commit ffc11d3cf3
3 changed files with 211 additions and 4 deletions

View file

@ -24,6 +24,7 @@ cc_defaults {
"-Wno-missing-field-initializers",
"-Wvla",
],
cpp_std: "gnu++17",
rtti: true,
use_version_lib: true,

View file

@ -26,10 +26,14 @@
#include <unistd.h>
#include <algorithm>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include <android-base/thread_annotations.h>
#if !ADB_HOST
#include <android-base/properties.h>
#include <log/log_properties.h>
@ -37,9 +41,150 @@
#include "adb.h"
#include "adb_io.h"
#include "adb_utils.h"
#include "sysdeps/chrono.h"
#include "transport.h"
#include "types.h"
// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
// socket while we have pending data, a TCP RST should be sent to the
// other end to notify it that we didn't read all of its data. However,
// this can result in data that we've successfully written out to be dropped
// on the other end. To avoid this, instead of immediately closing a
// socket, call shutdown on it instead, and then read from the file
// descriptor until we hit EOF or an error before closing.
struct LingeringSocketCloser {
LingeringSocketCloser() = default;
~LingeringSocketCloser() = delete;
// Defer thread creation until it's needed, because we need for there to
// only be one thread when dropping privileges in adbd.
void Start() {
CHECK(!thread_.joinable());
int fds[2];
if (adb_socketpair(fds) != 0) {
PLOG(FATAL) << "adb_socketpair failed";
}
set_file_block_mode(fds[0], false);
set_file_block_mode(fds[1], false);
notify_fd_read_.reset(fds[0]);
notify_fd_write_.reset(fds[1]);
thread_ = std::thread([this]() { Run(); });
}
void EnqueueSocket(unique_fd socket) {
// Shutdown the socket in the outgoing direction only, so that
// we don't have the same problem on the opposite end.
adb_shutdown(socket.get(), SHUT_WR);
set_file_block_mode(socket.get(), false);
std::lock_guard<std::mutex> lock(mutex_);
int fd = socket.get();
SocketInfo info = {
.fd = std::move(socket),
.deadline = std::chrono::steady_clock::now() + 1s,
};
D("LingeringSocketCloser received fd %d", fd);
fds_.emplace(fd, std::move(info));
if (adb_write(notify_fd_write_, "", 1) == -1 && errno != EAGAIN) {
PLOG(FATAL) << "failed to write to LingeringSocketCloser notify fd";
}
}
private:
std::vector<adb_pollfd> GeneratePollFds() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<adb_pollfd> result;
result.push_back(adb_pollfd{.fd = notify_fd_read_, .events = POLLIN});
for (auto& [fd, _] : fds_) {
result.push_back(adb_pollfd{.fd = fd, .events = POLLIN});
}
return result;
}
void Run() {
while (true) {
std::vector<adb_pollfd> pfds = GeneratePollFds();
int rc = adb_poll(pfds.data(), pfds.size(), 1000);
if (rc == -1) {
PLOG(FATAL) << "poll failed in LingeringSocketCloser";
}
std::lock_guard<std::mutex> lock(mutex_);
if (rc == 0) {
// Check deadlines.
auto now = std::chrono::steady_clock::now();
for (auto it = fds_.begin(); it != fds_.end();) {
if (now > it->second.deadline) {
D("LingeringSocketCloser closing fd %d due to deadline", it->first);
it = fds_.erase(it);
} else {
D("deadline still not expired for fd %d", it->first);
++it;
}
}
continue;
}
for (auto& pfd : pfds) {
if ((pfd.revents & POLLIN) == 0) {
continue;
}
// Empty the fd.
ssize_t rc;
char buf[32768];
while ((rc = adb_read(pfd.fd, buf, sizeof(buf))) > 0) {
continue;
}
if (pfd.fd == notify_fd_read_) {
continue;
}
auto it = fds_.find(pfd.fd);
if (it == fds_.end()) {
LOG(FATAL) << "fd is missing";
}
if (rc == -1 && errno == EAGAIN) {
if (std::chrono::steady_clock::now() > it->second.deadline) {
D("LingeringSocketCloser closing fd %d due to deadline", pfd.fd);
} else {
continue;
}
} else if (rc == -1) {
D("LingeringSocketCloser closing fd %d due to error %d", pfd.fd, errno);
} else {
D("LingeringSocketCloser closing fd %d due to EOF", pfd.fd);
}
fds_.erase(it);
}
}
}
std::thread thread_;
unique_fd notify_fd_read_;
unique_fd notify_fd_write_;
struct SocketInfo {
unique_fd fd;
std::chrono::steady_clock::time_point deadline;
};
std::mutex mutex_;
std::map<int, SocketInfo> fds_ GUARDED_BY(mutex_);
};
static auto& socket_closer = *new LingeringSocketCloser();
static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
static unsigned local_socket_next_id = 1;
@ -243,10 +388,12 @@ static void local_socket_destroy(asocket* s) {
D("LS(%d): destroying fde.fd=%d", s->id, s->fd);
/* IMPORTANT: the remove closes the fd
** that belongs to this socket
*/
fdevent_destroy(s->fde);
// Defer thread creation until it's needed, because we need for there to
// only be one thread when dropping privileges in adbd.
static std::once_flag once;
std::call_once(once, []() { socket_closer.Start(); });
socket_closer.EnqueueSocket(fdevent_release(s->fde));
remove_socket(s);
delete s;

View file

@ -35,6 +35,8 @@ import threading
import time
import unittest
from datetime import datetime
import adb
def requires_root(func):
@ -1335,6 +1337,63 @@ class DeviceOfflineTest(DeviceTest):
self.device.forward_remove("tcp:{}".format(local_port))
class SocketTest(DeviceTest):
def test_socket_flush(self):
"""Test that we handle socket closure properly.
If we're done writing to a socket, closing before the other end has
closed will send a TCP_RST if we have incoming data queued up, which
may result in data that we've written being discarded.
Bug: http://b/74616284
"""
s = socket.create_connection(("localhost", 5037))
def adb_length_prefixed(string):
encoded = string.encode("utf8")
result = b"%04x%s" % (len(encoded), encoded)
return result
if "ANDROID_SERIAL" in os.environ:
transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"]
else:
transport_string = "host:transport-any"
s.sendall(adb_length_prefixed(transport_string))
response = s.recv(4)
self.assertEquals(b"OKAY", response)
shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo"
s.sendall(adb_length_prefixed(shell_string))
response = s.recv(4)
self.assertEquals(b"OKAY", response)
# Spawn a thread that dumps garbage into the socket until failure.
def spam():
buf = b"\0" * 16384
try:
while True:
s.sendall(buf)
except Exception as ex:
print(ex)
thread = threading.Thread(target=spam)
thread.start()
time.sleep(1)
received = b""
while True:
read = s.recv(512)
if len(read) == 0:
break
received += read
self.assertEquals(1024 * 1024 + len("foo\n"), len(received))
thread.join()
if sys.platform == "win32":
# From https://stackoverflow.com/a/38749458
import os