diff --git a/fastboot/Android.mk b/fastboot/Android.mk index e0f7c73f3..a326f55d6 100644 --- a/fastboot/Android.mk +++ b/fastboot/Android.mk @@ -32,6 +32,7 @@ LOCAL_SRC_FILES := \ protocol.cpp \ socket.cpp \ tcp.cpp \ + udp.cpp \ util.cpp \ LOCAL_MODULE := fastboot @@ -114,6 +115,8 @@ LOCAL_SRC_FILES := \ socket_test.cpp \ tcp.cpp \ tcp_test.cpp \ + udp.cpp \ + udp_test.cpp \ LOCAL_STATIC_LIBRARIES := libbase libcutils diff --git a/fastboot/fastboot.cpp b/fastboot/fastboot.cpp index 7c7d417c5..636092e9e 100644 --- a/fastboot/fastboot.cpp +++ b/fastboot/fastboot.cpp @@ -59,6 +59,7 @@ #include "fs.h" #include "tcp.h" #include "transport.h" +#include "udp.h" #include "usb.h" #ifndef O_BINARY @@ -245,22 +246,41 @@ static Transport* open_device() { return transport; } + Socket::Protocol protocol = Socket::Protocol::kTcp; std::string host; - int port = tcp::kDefaultPort; - if (serial != nullptr && android::base::StartsWith(serial, "tcp:")) { - std::string error; - const char* address = serial + strlen("tcp:"); + int port = 0; + if (serial != nullptr) { + const char* net_address = nullptr; - if (!android::base::ParseNetAddress(address, &host, &port, nullptr, &error)) { - fprintf(stderr, "error: Invalid network address '%s': %s\n", address, error.c_str()); - return nullptr; + if (android::base::StartsWith(serial, "tcp:")) { + protocol = Socket::Protocol::kTcp; + port = tcp::kDefaultPort; + net_address = serial + strlen("tcp:"); + } else if (android::base::StartsWith(serial, "udp:")) { + protocol = Socket::Protocol::kUdp; + port = udp::kDefaultPort; + net_address = serial + strlen("udp:"); + } + + if (net_address != nullptr) { + std::string error; + if (!android::base::ParseNetAddress(net_address, &host, &port, nullptr, &error)) { + fprintf(stderr, "error: Invalid network address '%s': %s\n", net_address, + error.c_str()); + return nullptr; + } } } while (true) { if (!host.empty()) { std::string error; - transport = tcp::Connect(host, port, &error).release(); + if (protocol == Socket::Protocol::kTcp) { + transport = tcp::Connect(host, port, &error).release(); + } else if (protocol == Socket::Protocol::kUdp) { + transport = udp::Connect(host, port, &error).release(); + } + if (transport == nullptr && announce) { fprintf(stderr, "error: %s\n", error.c_str()); } @@ -337,8 +357,9 @@ static void usage() { " formatting.\n" " -s Specify a device. For USB, provide either\n" " a serial number or path to device port.\n" - " For TCP, provide an address in the form\n" - " tcp:[:port].\n" + " For ethernet, provide an address in the" + " form :[:port] where" + " is either tcp or udp.\n" " -p Specify product name.\n" " -c Override kernel commandline.\n" " -i Specify a custom USB vendor id.\n" diff --git a/fastboot/fastboot_protocol.txt b/fastboot/fastboot_protocol.txt index 4aa48b1ff..2801703c8 100644 --- a/fastboot/fastboot_protocol.txt +++ b/fastboot/fastboot_protocol.txt @@ -17,9 +17,9 @@ Basic Requirements * The protocol is entirely host-driven and synchronous (unlike the multi-channel, bi-directional, asynchronous ADB protocol) -* TCP +* TCP or UDP * Device must be reachable via IP. - * Device will act as the TCP server, fastboot will be the client. + * Device will act as the server, fastboot will be the client. * Fastboot data is wrapped in a simple protocol; see below for details. @@ -217,3 +217,226 @@ Device [0x00][0x00][0x00][0x00][0x00][0x00][0x00][0x07]OKAY0.4 Host [0x00][0x00][0x00][0x00][0x00][0x00][0x00][0x0B]getvar:none Device [0x00][0x00][0x00][0x00][0x00][0x00][0x00][0x04]OKAY Host + + +UDP Protocol v1 +--------------- + +The UDP protocol is more complex than TCP since we must implement reliability +to ensure no packets are lost, but the general concept of wrapping the fastboot +protocol is the same. + +Overview: + 1. As with TCP, the device will listen on UDP port 5554. + 2. Maximum UDP packet size is negotiated during initialization. + 3. The host drives all communication; the device may only send a packet as a + response to a host packet. + 4. If the host does not receive a response in 500ms it will re-transmit. + +-- UDP Packet format -- + +----------+----+-------+-------+--------------------+ + | Byte # | 0 | 1 | 2 - 3 | 4+ | + +----------+----+-------+-------+--------------------+ + | Contents | ID | Flags | Seq # | Data | + +----------+----+-------+-------+--------------------+ + + ID Packet ID: + 0x00: Error. + 0x01: Query. + 0x02: Initialization. + 0x03: Fastboot. + + Packet types are described in more detail below. + + Flags Packet flags: 0 0 0 0 0 0 0 C + C=1 indicates a continuation packet; the data is too large and will + continue in the next packet. + + Remaining bits are reserved for future use and must be set to 0. + + Seq # 2-byte packet sequence number (big-endian). The host will increment + this by 1 with each new packet, and the device must provide the + corresponding sequence number in the response packets. + + Data Packet data, not present in all packets. + +-- Packet Types -- +Query The host sends a query packet once on startup to sync with the device. + The host will not know the current sequence number, so the device must + respond to all query packets regardless of sequence number. + + The response data field should contain a 2-byte big-endian value + giving the next expected sequence number. + +Init The host sends an init packet once the query response is returned. The + device must abort any in-progress operation and prepare for a new + fastboot session. This message is meant to allow recovery if a + previous session failed, e.g. due to network error or user Ctrl+C. + + The data field contains two big-endian 2-byte values, a protocol + version and the max UDP packet size (including the 4-byte header). + Both the host and device will send these values, and in each case + the minimum of the sent values must be used. + +Fastboot These packets wrap the fastboot protocol. To write, the host will + send a packet with fastboot data, and the device will reply with an + empty packet as an ACK. To read, the host will send an empty packet, + and the device will reply with fastboot data. The device may not give + any data in the ACK packet. + +Error The device may respond to any packet with an error packet to indicate + a UDP protocol error. The data field should contain an ASCII string + describing the error. This is the only case where a device is allowed + to return a packet ID other than the one sent by the host. + +-- Packet Size -- +The maximum packet size is negotiated by the host and device in the Init packet. +Devices must support at least 512-byte packets, but packet size has a direct +correlation with download speed, so devices are strongly suggested to support at +least 1024-byte packets. On a local network with 0.5ms round-trip time this will +provide transfer rates of ~2MB/s. Over WiFi it will likely be significantly +less. + +Query and Initialization packets, which are sent before size negotiation is +complete, must always be 512 bytes or less. + +-- Packet Re-Transmission -- +The host will re-transmit any packet that does not receive a response. The +requirement of exactly one device response packet per host packet is how we +achieve reliability and in-order delivery of packets. + +For simplicity of implementation, there is no windowing of multiple +unacknowledged packets in this version of the protocol. The host will continue +to send the same packet until a response is received. Windowing functionality +may be implemented in future versions if necessary to increase performance. + +The first Query packet will only be attempted a small number of times, but +subsequent packets will attempt to retransmit for at least 1 minute before +giving up. This means a device may safely ignore host UDP packets for up to 1 +minute during long operations, e.g. writing to flash. + +-- Continuation Packets -- +Any packet may set the continuation flag to indicate that the data is +incomplete. Large data such as downloading an image may require many +continuation packets. The receiver should respond to a continuation packet with +an empty packet to acknowledge receipt. See examples below. + +-- Summary -- +The host starts with a Query packet, then an Initialization packet, after +which only Fastboot packets are sent. Fastboot packets may contain data from +the host for writes, or from the device for reads, but not both. + +Given a next expected sequence number S and a received packet P, the device +behavior should be: + if P is a Query packet: + * respond with a Query packet with S in the data field + else if P has sequence == S: + * process P and take any required action + * create a response packet R with the same ID and sequence as P, containing + any response data required. + * transmit R and save it in case of re-transmission + * increment S + else if P has sequence == S - 1: + * re-transmit the saved response packet R from above + else: + * ignore the packet + +-- Examples -- +In the examples below, S indicates the starting client sequence number. + +Host Client +====================================================================== +[Initialization, S = 0x55AA] +[Host: version 1, 2048-byte packets. Client: version 2, 1024-byte packets.] +[Resulting values to use: version = 1, max packet size = 1024] +ID Flag SeqH SeqL Data ID Flag SeqH SeqL Data +---------------------------------------------------------------------- +0x01 0x00 0x00 0x00 + 0x01 0x00 0x00 0x00 0x55 0xAA +0x02 0x00 0x55 0xAA 0x00 0x01 0x08 0x00 + 0x02 0x00 0x55 0xAA 0x00 0x02 0x04 0x00 + +---------------------------------------------------------------------- +[fastboot "getvar" commands, S = 0x0001] +ID Flags SeqH SeqL Data ID Flags SeqH SeqL Data +---------------------------------------------------------------------- +0x03 0x00 0x00 0x01 getvar:version + 0x03 0x00 0x00 0x01 +0x03 0x00 0x00 0x02 + 0x03 0x00 0x00 0x02 OKAY0.4 +0x03 0x00 0x00 0x03 getvar:foo + 0x03 0x00 0x00 0x03 +0x03 0x00 0x00 0x04 + 0x03 0x00 0x00 0x04 OKAY + +---------------------------------------------------------------------- +[fastboot "INFO" responses, S = 0x0000] +ID Flags SeqH SeqL Data ID Flags SeqH SeqL Data +---------------------------------------------------------------------- +0x03 0x00 0x00 0x00 + 0x03 0x00 0x00 0x00 +0x03 0x00 0x00 0x01 + 0x03 0x00 0x00 0x01 INFOWait1 +0x03 0x00 0x00 0x02 + 0x03 0x00 0x00 0x02 INFOWait2 +0x03 0x00 0x00 0x03 + 0x03 0x00 0x00 0x03 OKAY + +---------------------------------------------------------------------- +[Chunking 2100 bytes of data, max packet size = 1024, S = 0xFFFF] +ID Flag SeqH SeqL Data ID Flag SeqH SeqL Data +---------------------------------------------------------------------- +0x03 0x00 0xFF 0xFF download:0000834 + 0x03 0x00 0xFF 0xFF +0x03 0x00 0x00 0x00 + 0x03 0x00 0x00 0x00 DATA0000834 +0x03 0x01 0x00 0x01 <1020 bytes> + 0x03 0x00 0x00 0x01 +0x03 0x01 0x00 0x02 <1020 bytes> + 0x03 0x00 0x00 0x02 +0x03 0x00 0x00 0x03 <60 bytes> + 0x03 0x00 0x00 0x03 +0x03 0x00 0x00 0x04 + 0x03 0x00 0x00 0x04 OKAY + +---------------------------------------------------------------------- +[Unknown ID error, S = 0x0000] +ID Flags SeqH SeqL Data ID Flags SeqH SeqL Data +---------------------------------------------------------------------- +0x10 0x00 0x00 0x00 + 0x00 0x00 0x00 0x00 + +---------------------------------------------------------------------- +[Host packet loss and retransmission, S = 0x0000] +ID Flags SeqH SeqL Data ID Flags SeqH SeqL Data +---------------------------------------------------------------------- +0x03 0x00 0x00 0x00 getvar:version [lost] +0x03 0x00 0x00 0x00 getvar:version [lost] +0x03 0x00 0x00 0x00 getvar:version + 0x03 0x00 0x00 0x00 +0x03 0x00 0x00 0x01 + 0x03 0x00 0x00 0x01 OKAY0.4 + +---------------------------------------------------------------------- +[Client packet loss and retransmission, S = 0x0000] +ID Flags SeqH SeqL Data ID Flags SeqH SeqL Data +---------------------------------------------------------------------- +0x03 0x00 0x00 0x00 getvar:version + 0x03 0x00 0x00 0x00 [lost] +0x03 0x00 0x00 0x00 getvar:version + 0x03 0x00 0x00 0x00 [lost] +0x03 0x00 0x00 0x00 getvar:version + 0x03 0x00 0x00 0x00 +0x03 0x00 0x00 0x01 + 0x03 0x00 0x00 0x01 OKAY0.4 + +---------------------------------------------------------------------- +[Host packet delayed, S = 0x0000] +ID Flags SeqH SeqL Data ID Flags SeqH SeqL Data +---------------------------------------------------------------------- +0x03 0x00 0x00 0x00 getvar:version [delayed] +0x03 0x00 0x00 0x00 getvar:version + 0x03 0x00 0x00 0x00 +0x03 0x00 0x00 0x01 + 0x03 0x00 0x00 0x01 OKAY0.4 +0x03 0x00 0x00 0x00 getvar:version [arrives late with old seq#, is ignored] diff --git a/fastboot/socket.cpp b/fastboot/socket.cpp index d49f47ff2..14ecd937a 100644 --- a/fastboot/socket.cpp +++ b/fastboot/socket.cpp @@ -48,18 +48,6 @@ int Socket::Close() { return ret; } -bool Socket::SetReceiveTimeout(int timeout_ms) { - if (timeout_ms != receive_timeout_ms_) { - if (socket_set_receive_timeout(sock_, timeout_ms) == 0) { - receive_timeout_ms_ = timeout_ms; - return true; - } - return false; - } - - return true; -} - ssize_t Socket::ReceiveAll(void* data, size_t length, int timeout_ms) { size_t total = 0; @@ -82,6 +70,40 @@ int Socket::GetLocalPort() { return socket_get_local_port(sock_); } +// According to Windows setsockopt() documentation, if a Windows socket times out during send() or +// recv() the state is indeterminate and should not be used. Our UDP protocol relies on being able +// to re-send after a timeout, so we must use select() rather than SO_RCVTIMEO. +// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms740476(v=vs.85).aspx. +bool Socket::WaitForRecv(int timeout_ms) { + receive_timed_out_ = false; + + // In our usage |timeout_ms| <= 0 means block forever, so just return true immediately and let + // the subsequent recv() do the blocking. + if (timeout_ms <= 0) { + return true; + } + + // select() doesn't always check this case and will block for |timeout_ms| if we let it. + if (sock_ == INVALID_SOCKET) { + return false; + } + + fd_set read_set; + FD_ZERO(&read_set); + FD_SET(sock_, &read_set); + + timeval timeout; + timeout.tv_sec = timeout_ms / 1000; + timeout.tv_usec = (timeout_ms % 1000) * 1000; + + int result = TEMP_FAILURE_RETRY(select(sock_ + 1, &read_set, nullptr, nullptr, &timeout)); + + if (result == 0) { + receive_timed_out_ = true; + } + return result == 1; +} + // Implements the Socket interface for UDP. class UdpSocket : public Socket { public: @@ -127,7 +149,7 @@ bool UdpSocket::Send(std::vector buffers) { } ssize_t UdpSocket::Receive(void* data, size_t length, int timeout_ms) { - if (!SetReceiveTimeout(timeout_ms)) { + if (!WaitForRecv(timeout_ms)) { return -1; } @@ -206,7 +228,7 @@ bool TcpSocket::Send(std::vector buffers) { } ssize_t TcpSocket::Receive(void* data, size_t length, int timeout_ms) { - if (!SetReceiveTimeout(timeout_ms)) { + if (!WaitForRecv(timeout_ms)) { return -1; } diff --git a/fastboot/socket.h b/fastboot/socket.h index c0bd7c96c..de543dbab 100644 --- a/fastboot/socket.h +++ b/fastboot/socket.h @@ -81,13 +81,17 @@ class Socket { virtual bool Send(std::vector buffers) = 0; // Waits up to |timeout_ms| to receive up to |length| bytes of data. |timout_ms| of 0 will - // block forever. Returns the number of bytes received or -1 on error/timeout. On timeout - // errno will be set to EAGAIN or EWOULDBLOCK. + // block forever. Returns the number of bytes received or -1 on error/timeout; see + // ReceiveTimedOut() to distinguish between the two. virtual ssize_t Receive(void* data, size_t length, int timeout_ms) = 0; // Calls Receive() until exactly |length| bytes have been received or an error occurs. virtual ssize_t ReceiveAll(void* data, size_t length, int timeout_ms); + // Returns true if the last Receive() call timed out normally and can be retried; fatal errors + // or successful reads will return false. + bool ReceiveTimedOut() { return receive_timed_out_; } + // Closes the socket. Returns 0 on success, -1 on error. virtual int Close(); @@ -102,10 +106,13 @@ class Socket { // Protected constructor to force factory function use. Socket(cutils_socket_t sock); - // Update the socket receive timeout if necessary. - bool SetReceiveTimeout(int timeout_ms); + // Blocks up to |timeout_ms| until a read is possible on |sock_|, and sets |receive_timed_out_| + // as appropriate to help distinguish between normal timeouts and fatal errors. Returns true if + // a subsequent recv() on |sock_| will complete without blocking or if |timeout_ms| <= 0. + bool WaitForRecv(int timeout_ms); cutils_socket_t sock_ = INVALID_SOCKET; + bool receive_timed_out_ = false; // Non-class functions we want to override during tests to verify functionality. Implementation // should call this rather than using socket_send_buffers() directly. @@ -113,8 +120,6 @@ class Socket { socket_send_buffers_function_ = &socket_send_buffers; private: - int receive_timeout_ms_ = 0; - FRIEND_TEST(SocketTest, TestTcpSendBuffers); FRIEND_TEST(SocketTest, TestUdpSendBuffers); diff --git a/fastboot/socket_mock.cpp b/fastboot/socket_mock.cpp index c962f303d..2531b53ad 100644 --- a/fastboot/socket_mock.cpp +++ b/fastboot/socket_mock.cpp @@ -55,7 +55,7 @@ bool SocketMock::Send(const void* data, size_t length) { return false; } - bool return_value = events_.front().return_value; + bool return_value = events_.front().status; events_.pop(); return return_value; } @@ -76,21 +76,28 @@ ssize_t SocketMock::Receive(void* data, size_t length, int /*timeout_ms*/) { return -1; } - if (events_.front().type != EventType::kReceive) { + const Event& event = events_.front(); + if (event.type != EventType::kReceive) { ADD_FAILURE() << "Receive() was called out-of-order"; return -1; } - if (events_.front().return_value > static_cast(length)) { - ADD_FAILURE() << "Receive(): not enough bytes (" << length << ") for " - << events_.front().message; + const std::string& message = event.message; + if (message.length() > length) { + ADD_FAILURE() << "Receive(): not enough bytes (" << length << ") for " << message; return -1; } - ssize_t return_value = events_.front().return_value; - if (return_value > 0) { - memcpy(data, events_.front().message.data(), return_value); + receive_timed_out_ = event.status; + ssize_t return_value = message.length(); + + // Empty message indicates failure. + if (message.empty()) { + return_value = -1; + } else { + memcpy(data, message.data(), message.length()); } + events_.pop(); return return_value; } @@ -124,18 +131,21 @@ void SocketMock::ExpectSendFailure(std::string message) { } void SocketMock::AddReceive(std::string message) { - ssize_t return_value = message.length(); - events_.push(Event(EventType::kReceive, std::move(message), return_value, nullptr)); + events_.push(Event(EventType::kReceive, std::move(message), false, nullptr)); +} + +void SocketMock::AddReceiveTimeout() { + events_.push(Event(EventType::kReceive, "", true, nullptr)); } void SocketMock::AddReceiveFailure() { - events_.push(Event(EventType::kReceive, "", -1, nullptr)); + events_.push(Event(EventType::kReceive, "", false, nullptr)); } void SocketMock::AddAccept(std::unique_ptr sock) { - events_.push(Event(EventType::kAccept, "", 0, std::move(sock))); + events_.push(Event(EventType::kAccept, "", false, std::move(sock))); } -SocketMock::Event::Event(EventType _type, std::string _message, ssize_t _return_value, +SocketMock::Event::Event(EventType _type, std::string _message, ssize_t _status, std::unique_ptr _sock) - : type(_type), message(_message), return_value(_return_value), sock(std::move(_sock)) {} + : type(_type), message(_message), status(_status), sock(std::move(_sock)) {} diff --git a/fastboot/socket_mock.h b/fastboot/socket_mock.h index 41fe06db0..eacd6bb6a 100644 --- a/fastboot/socket_mock.h +++ b/fastboot/socket_mock.h @@ -71,7 +71,10 @@ class SocketMock : public Socket { // Adds data to provide for Receive(). void AddReceive(std::string message); - // Adds a Receive() failure. + // Adds a Receive() timeout after which ReceiveTimedOut() will return true. + void AddReceiveTimeout(); + + // Adds a Receive() failure after which ReceiveTimedOut() will return false. void AddReceiveFailure(); // Adds a Socket to return from Accept(). @@ -81,12 +84,12 @@ class SocketMock : public Socket { enum class EventType { kSend, kReceive, kAccept }; struct Event { - Event(EventType _type, std::string _message, ssize_t _return_value, + Event(EventType _type, std::string _message, ssize_t _status, std::unique_ptr _sock); EventType type; std::string message; - ssize_t return_value; + bool status; // Return value for Send() or timeout status for Receive(). std::unique_ptr sock; }; diff --git a/fastboot/socket_test.cpp b/fastboot/socket_test.cpp index cc7107529..affbdfd88 100644 --- a/fastboot/socket_test.cpp +++ b/fastboot/socket_test.cpp @@ -28,7 +28,8 @@ #include #include -enum { kTestTimeoutMs = 3000 }; +static constexpr int kShortTimeoutMs = 10; +static constexpr int kTestTimeoutMs = 3000; // Creates connected sockets |server| and |client|. Returns true on success. bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr* server, @@ -87,6 +88,50 @@ TEST(SocketTest, TestSendAndReceive) { } } +TEST(SocketTest, TestReceiveTimeout) { + std::unique_ptr server, client; + char buffer[16]; + + for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { + ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); + + EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); + EXPECT_TRUE(server->ReceiveTimedOut()); + + EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); + EXPECT_TRUE(client->ReceiveTimedOut()); + } + + // UDP will wait for timeout if the other side closes. + ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client)); + EXPECT_EQ(0, server->Close()); + EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); + EXPECT_TRUE(client->ReceiveTimedOut()); +} + +TEST(SocketTest, TestReceiveFailure) { + std::unique_ptr server, client; + char buffer[16]; + + for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { + ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); + + EXPECT_EQ(0, server->Close()); + EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); + EXPECT_FALSE(server->ReceiveTimedOut()); + + EXPECT_EQ(0, client->Close()); + EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); + EXPECT_FALSE(client->ReceiveTimedOut()); + } + + // TCP knows right away when the other side closes and returns 0 to indicate EOF. + ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kTcp, &server, &client)); + EXPECT_EQ(0, server->Close()); + EXPECT_EQ(0, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); + EXPECT_FALSE(client->ReceiveTimedOut()); +} + // Tests sending and receiving large packets. TEST(SocketTest, TestLargePackets) { std::string message(1024, '\0'); @@ -290,6 +335,11 @@ TEST(SocketMockTest, TestReceiveFailure) { mock->AddReceiveFailure(); EXPECT_FALSE(ReceiveString(mock, "foo")); + EXPECT_FALSE(mock->ReceiveTimedOut()); + + mock->AddReceiveTimeout(); + EXPECT_FALSE(ReceiveString(mock, "foo")); + EXPECT_TRUE(mock->ReceiveTimedOut()); mock->AddReceive("foo"); mock->AddReceiveFailure(); diff --git a/fastboot/tcp.cpp b/fastboot/tcp.cpp index da2880a5e..e42c4e1af 100644 --- a/fastboot/tcp.cpp +++ b/fastboot/tcp.cpp @@ -28,6 +28,7 @@ #include "tcp.h" +#include #include namespace tcp { @@ -98,7 +99,8 @@ bool TcpTransport::InitializeProtocol(std::string* error) { return false; } - char buffer[kHandshakeLength]; + char buffer[kHandshakeLength + 1]; + buffer[kHandshakeLength] = '\0'; if (socket_->ReceiveAll(buffer, kHandshakeLength, kHandshakeTimeoutMs) != kHandshakeLength) { *error = android::base::StringPrintf( "No initialization message received (%s). Target may not support TCP fastboot", @@ -111,9 +113,10 @@ bool TcpTransport::InitializeProtocol(std::string* error) { return false; } - if (memcmp(buffer + 2, "01", 2) != 0) { + int version = 0; + if (!android::base::ParseInt(buffer + 2, &version) || version < kProtocolVersion) { *error = android::base::StringPrintf("Unknown TCP protocol version %s (host version %02d)", - std::string(buffer + 2, 2).c_str(), kProtocolVersion); + buffer + 2, kProtocolVersion); return false; } diff --git a/fastboot/tcp_test.cpp b/fastboot/tcp_test.cpp index 7d80d76f5..6e867ae85 100644 --- a/fastboot/tcp_test.cpp +++ b/fastboot/tcp_test.cpp @@ -42,6 +42,16 @@ TEST(TcpConnectTest, TestSuccess) { EXPECT_EQ("", error); } +TEST(TcpConnectTest, TestNewerVersionSuccess) { + std::unique_ptr mock(new SocketMock); + mock->ExpectSend("FB01"); + mock->AddReceive("FB99"); + + std::string error; + EXPECT_NE(nullptr, tcp::internal::Connect(std::move(mock), &error)); + EXPECT_EQ("", error); +} + TEST(TcpConnectTest, TestSendFailure) { std::unique_ptr mock(new SocketMock); mock->ExpectSendFailure("FB01"); @@ -74,11 +84,11 @@ TEST(TcpConnectTest, TestBadResponseFailure) { TEST(TcpConnectTest, TestUnknownVersionFailure) { std::unique_ptr mock(new SocketMock); mock->ExpectSend("FB01"); - mock->AddReceive("FB02"); + mock->AddReceive("FB00"); std::string error; EXPECT_EQ(nullptr, tcp::internal::Connect(std::move(mock), &error)); - EXPECT_EQ("Unknown TCP protocol version 02 (host version 01)", error); + EXPECT_EQ("Unknown TCP protocol version 00 (host version 01)", error); } // Fixture to configure a SocketMock for a successful TCP connection. diff --git a/fastboot/udp.cpp b/fastboot/udp.cpp new file mode 100644 index 000000000..b36bd605c --- /dev/null +++ b/fastboot/udp.cpp @@ -0,0 +1,391 @@ +/* + * Copyright (C) 2015 The Android Open Source Project + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE + * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS + * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED + * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT + * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +// This file implements the fastboot UDP protocol; see fastboot_protocol.txt for documentation. + +#include "udp.h" + +#include +#include + +#include +#include +#include + +#include +#include + +#include "socket.h" + +namespace udp { + +using namespace internal; + +constexpr size_t kMinPacketSize = 512; +constexpr size_t kHeaderSize = 4; + +enum Index { + kIndexId = 0, + kIndexFlags = 1, + kIndexSeqH = 2, + kIndexSeqL = 3, +}; + +// Extracts a big-endian uint16_t from a byte array. +static uint16_t ExtractUint16(const uint8_t* bytes) { + return (static_cast(bytes[0]) << 8) | bytes[1]; +} + +// Packet header handling. +class Header { + public: + Header(); + ~Header() = default; + + uint8_t id() const { return bytes_[kIndexId]; } + const uint8_t* bytes() const { return bytes_; } + + void Set(uint8_t id, uint16_t sequence, Flag flag); + + // Checks whether |response| is a match for this header. + bool Matches(const uint8_t* response); + + private: + uint8_t bytes_[kHeaderSize]; +}; + +Header::Header() { + Set(kIdError, 0, kFlagNone); +} + +void Header::Set(uint8_t id, uint16_t sequence, Flag flag) { + bytes_[kIndexId] = id; + bytes_[kIndexFlags] = flag; + bytes_[kIndexSeqH] = sequence >> 8; + bytes_[kIndexSeqL] = sequence; +} + +bool Header::Matches(const uint8_t* response) { + // Sequence numbers must be the same to match, but the response ID can either be the same + // or an error response which is always accepted. + return bytes_[kIndexSeqH] == response[kIndexSeqH] && + bytes_[kIndexSeqL] == response[kIndexSeqL] && + (bytes_[kIndexId] == response[kIndexId] || response[kIndexId] == kIdError); +} + +// Implements the Transport interface to work with the fastboot engine. +class UdpTransport : public Transport { + public: + // Factory function so we can return nullptr if initialization fails. + static std::unique_ptr NewTransport(std::unique_ptr socket, + std::string* error); + ~UdpTransport() override = default; + + ssize_t Read(void* data, size_t length) override; + ssize_t Write(const void* data, size_t length) override; + int Close() override; + + private: + UdpTransport(std::unique_ptr socket) : socket_(std::move(socket)) {} + + // Performs the UDP initialization procedure. Returns true on success. + bool InitializeProtocol(std::string* error); + + // Sends |length| bytes from |data| and waits for the response packet up to |attempts| times. + // Continuation packets are handled automatically and any return data is written to |rx_data|. + // Excess bytes that cannot fit in |rx_data| are dropped. + // On success, returns the number of response data bytes received, which may be greater than + // |rx_length|. On failure, returns -1 and fills |error| on failure. + ssize_t SendData(Id id, const uint8_t* tx_data, size_t tx_length, uint8_t* rx_data, + size_t rx_length, int attempts, std::string* error); + + // Helper for SendData(); sends a single packet and handles the response. |header| specifies + // the initial outgoing packet information but may be modified by this function. + ssize_t SendSinglePacketHelper(Header* header, const uint8_t* tx_data, size_t tx_length, + uint8_t* rx_data, size_t rx_length, int attempts, + std::string* error); + + std::unique_ptr socket_; + int sequence_ = -1; + size_t max_data_length_ = kMinPacketSize - kHeaderSize; + std::vector rx_packet_; + + DISALLOW_COPY_AND_ASSIGN(UdpTransport); +}; + +std::unique_ptr UdpTransport::NewTransport(std::unique_ptr socket, + std::string* error) { + std::unique_ptr transport(new UdpTransport(std::move(socket))); + + if (!transport->InitializeProtocol(error)) { + return nullptr; + } + + return transport; +} + +bool UdpTransport::InitializeProtocol(std::string* error) { + uint8_t rx_data[4]; + + sequence_ = 0; + rx_packet_.resize(kMinPacketSize); + + // First send the query packet to sync with the target. Only attempt this a small number of + // times so we can fail out quickly if the target isn't available. + ssize_t rx_bytes = SendData(kIdDeviceQuery, nullptr, 0, rx_data, sizeof(rx_data), + kMaxConnectAttempts, error); + if (rx_bytes == -1) { + return false; + } else if (rx_bytes < 2) { + *error = "invalid query response from target"; + return false; + } + // The first two bytes contain the next expected sequence number. + sequence_ = ExtractUint16(rx_data); + + // Now send the initialization packet with our version and maximum packet size. + uint8_t init_data[] = {kProtocolVersion >> 8, kProtocolVersion & 0xFF, + kHostMaxPacketSize >> 8, kHostMaxPacketSize & 0xFF}; + rx_bytes = SendData(kIdInitialization, init_data, sizeof(init_data), rx_data, sizeof(rx_data), + kMaxTransmissionAttempts, error); + if (rx_bytes == -1) { + return false; + } else if (rx_bytes < 4) { + *error = "invalid initialization response from target"; + return false; + } + + // The first two data bytes contain the version, the second two bytes contain the target max + // supported packet size, which must be at least 512 bytes. + uint16_t version = ExtractUint16(rx_data); + if (version < kProtocolVersion) { + *error = android::base::StringPrintf("target reported invalid protocol version %d", + version); + return false; + } + uint16_t packet_size = ExtractUint16(rx_data + 2); + if (packet_size < kMinPacketSize) { + *error = android::base::StringPrintf("target reported invalid packet size %d", packet_size); + return false; + } + + packet_size = std::min(kHostMaxPacketSize, packet_size); + max_data_length_ = packet_size - kHeaderSize; + rx_packet_.resize(packet_size); + + return true; +} + +// SendData() is just responsible for chunking |data| into packets until it's all been sent. +// Per-packet timeout/retransmission logic is done in SendSinglePacketHelper(). +ssize_t UdpTransport::SendData(Id id, const uint8_t* tx_data, size_t tx_length, uint8_t* rx_data, + size_t rx_length, int attempts, std::string* error) { + if (socket_ == nullptr) { + *error = "socket is closed"; + return -1; + } + + Header header; + size_t packet_data_length; + ssize_t ret = 0; + // We often send header-only packets with no data as part of the protocol, so always send at + // least once even if |length| == 0, then repeat until we've sent all of |data|. + do { + // Set the continuation flag and truncate packet data if needed. + if (tx_length > max_data_length_) { + packet_data_length = max_data_length_; + header.Set(id, sequence_, kFlagContinuation); + } else { + packet_data_length = tx_length; + header.Set(id, sequence_, kFlagNone); + } + + ssize_t bytes = SendSinglePacketHelper(&header, tx_data, packet_data_length, rx_data, + rx_length, attempts, error); + + // Advance our read and write buffers for the next packet. Keep going even if we run out + // of receive buffer space so we can detect overflows. + if (bytes == -1) { + return -1; + } else if (static_cast(bytes) < rx_length) { + rx_data += bytes; + rx_length -= bytes; + } else { + rx_data = nullptr; + rx_length = 0; + } + + tx_length -= packet_data_length; + tx_data += packet_data_length; + + ret += bytes; + } while (tx_length > 0); + + return ret; +} + +ssize_t UdpTransport::SendSinglePacketHelper( + Header* header, const uint8_t* tx_data, size_t tx_length, uint8_t* rx_data, + size_t rx_length, const int attempts, std::string* error) { + ssize_t total_data_bytes = 0; + error->clear(); + + int attempts_left = attempts; + while (attempts_left > 0) { + if (!socket_->Send({{header->bytes(), kHeaderSize}, {tx_data, tx_length}})) { + *error = Socket::GetErrorMessage(); + return -1; + } + + // Keep receiving until we get a matching response or we timeout. + ssize_t bytes = 0; + do { + bytes = socket_->Receive(rx_packet_.data(), rx_packet_.size(), kResponseTimeoutMs); + if (bytes == -1) { + if (socket_->ReceiveTimedOut()) { + break; + } + *error = Socket::GetErrorMessage(); + return -1; + } else if (bytes < static_cast(kHeaderSize)) { + *error = "protocol error: incomplete header"; + return -1; + } + } while (!header->Matches(rx_packet_.data())); + + if (socket_->ReceiveTimedOut()) { + --attempts_left; + continue; + } + ++sequence_; + + // Save to |error| or |rx_data| as appropriate. + if (rx_packet_[kIndexId] == kIdError) { + error->append(rx_packet_.data() + kHeaderSize, rx_packet_.data() + bytes); + } else { + total_data_bytes += bytes - kHeaderSize; + size_t rx_data_bytes = std::min(bytes - kHeaderSize, rx_length); + if (rx_data_bytes > 0) { + memcpy(rx_data, rx_packet_.data() + kHeaderSize, rx_data_bytes); + rx_data += rx_data_bytes; + rx_length -= rx_data_bytes; + } + } + + // If the response has a continuation flag we need to prompt for more data by sending + // an empty packet. + if (rx_packet_[kIndexFlags] & kFlagContinuation) { + // We got a valid response so reset our attempt counter. + attempts_left = attempts; + header->Set(rx_packet_[kIndexId], sequence_, kFlagNone); + tx_data = nullptr; + tx_length = 0; + continue; + } + + break; + } + + if (attempts_left <= 0) { + *error = "no response from target"; + return -1; + } + + if (rx_packet_[kIndexId] == kIdError) { + *error = "target reported error: " + *error; + return -1; + } + + return total_data_bytes; +} + +ssize_t UdpTransport::Read(void* data, size_t length) { + // Read from the target by sending an empty packet. + std::string error; + ssize_t bytes = SendData(kIdFastboot, nullptr, 0, reinterpret_cast(data), length, + kMaxTransmissionAttempts, &error); + + if (bytes == -1) { + fprintf(stderr, "UDP error: %s\n", error.c_str()); + return -1; + } else if (static_cast(bytes) > length) { + // Fastboot protocol error: the target sent more data than our fastboot engine was prepared + // to receive. + fprintf(stderr, "UDP error: receive overflow, target sent too much fastboot data\n"); + return -1; + } + + return bytes; +} + +ssize_t UdpTransport::Write(const void* data, size_t length) { + std::string error; + ssize_t bytes = SendData(kIdFastboot, reinterpret_cast(data), length, nullptr, + 0, kMaxTransmissionAttempts, &error); + + if (bytes == -1) { + fprintf(stderr, "UDP error: %s\n", error.c_str()); + return -1; + } else if (bytes > 0) { + // UDP protocol error: only empty ACK packets are allowed when writing to a device. + fprintf(stderr, "UDP error: target sent fastboot data out-of-turn\n"); + return -1; + } + + return length; +} + +int UdpTransport::Close() { + if (socket_ == nullptr) { + return 0; + } + + int result = socket_->Close(); + socket_.reset(); + return result; +} + +std::unique_ptr Connect(const std::string& hostname, int port, std::string* error) { + return internal::Connect(Socket::NewClient(Socket::Protocol::kUdp, hostname, port, error), + error); +} + +namespace internal { + +std::unique_ptr Connect(std::unique_ptr sock, std::string* error) { + if (sock == nullptr) { + // If Socket creation failed |error| is already set. + return nullptr; + } + + return UdpTransport::NewTransport(std::move(sock), error); +} + +} // namespace internal + +} // namespace udp diff --git a/fastboot/udp.h b/fastboot/udp.h new file mode 100644 index 000000000..14f5b3547 --- /dev/null +++ b/fastboot/udp.h @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2015 The Android Open Source Project + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE + * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS + * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED + * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT + * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#ifndef UDP_H_ +#define UDP_H_ + +#include +#include + +#include "socket.h" +#include "transport.h" + +namespace udp { + +constexpr int kDefaultPort = 5554; + +// Returns a newly allocated Transport object connected to |hostname|:|port|. On failure, |error| is +// filled and nullptr is returned. +std::unique_ptr Connect(const std::string& hostname, int port, std::string* error); + +// Internal namespace for test use only. +namespace internal { + +constexpr uint16_t kProtocolVersion = 1; + +// This will be negotiated with the device so may end up being smaller. +constexpr uint16_t kHostMaxPacketSize = 8192; + +// Retransmission constants. Retransmission timeout must be at least 500ms, and the host must +// attempt to send packets for at least 1 minute once the device has connected. See +// fastboot_protocol.txt for more information. +constexpr int kResponseTimeoutMs = 500; +constexpr int kMaxConnectAttempts = 4; +constexpr int kMaxTransmissionAttempts = 60 * 1000 / kResponseTimeoutMs; + +enum Id : uint8_t { + kIdError = 0x00, + kIdDeviceQuery = 0x01, + kIdInitialization = 0x02, + kIdFastboot = 0x03 +}; + +enum Flag : uint8_t { + kFlagNone = 0x00, + kFlagContinuation = 0x01 +}; + +// Creates a UDP Transport object using a given Socket. Used for unit tests to create a Transport +// object that uses a SocketMock. +std::unique_ptr Connect(std::unique_ptr sock, std::string* error); + +} // namespace internal + +} // namespace udp + +#endif // UDP_H_ diff --git a/fastboot/udp_test.cpp b/fastboot/udp_test.cpp new file mode 100644 index 000000000..ff8cf0fc8 --- /dev/null +++ b/fastboot/udp_test.cpp @@ -0,0 +1,531 @@ +/* + * Copyright (C) 2015 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "udp.h" + +#include + +#include "socket.h" +#include "socket_mock.h" + +using namespace udp; +using namespace udp::internal; + +// Some possible corner case sequence numbers we want to check. +static const uint16_t kTestSequenceNumbers[] = {0x0000, 0x0001, 0x00FF, 0x0100, + 0x7FFF, 0x8000, 0xFFFF}; + +// Converts |value| to a binary big-endian string. +static std::string PacketValue(uint16_t value) { + return std::string{static_cast(value >> 8), static_cast(value)}; +} + +// Returns an Error packet. +static std::string ErrorPacket(uint16_t sequence, const std::string& message = "", + char flags = kFlagNone) { + return std::string{kIdError, flags} + PacketValue(sequence) + message; +} + +// Returns a Query packet with no data. +static std::string QueryPacket(uint16_t sequence) { + return std::string{kIdDeviceQuery, kFlagNone} + PacketValue(sequence); +} + +// Returns a Query packet with a 2-byte |new_sequence|. +static std::string QueryPacket(uint16_t sequence, uint16_t new_sequence) { + return std::string{kIdDeviceQuery, kFlagNone} + PacketValue(sequence) + + PacketValue(new_sequence); +} + +// Returns an Init packet with a 2-byte |version| and |max_packet_size|. +static std::string InitPacket(uint16_t sequence, uint16_t version, uint16_t max_packet_size) { + return std::string{kIdInitialization, kFlagNone} + PacketValue(sequence) + + PacketValue(version) + PacketValue(max_packet_size); +} + +// Returns a Fastboot packet with |data|. +static std::string FastbootPacket(uint16_t sequence, const std::string& data = "", + char flags = kFlagNone) { + return std::string{kIdFastboot, flags} + PacketValue(sequence) + data; +} + +// Fixture class to test protocol initialization. Usage is to set up the expected calls to the +// SocketMock object then call UdpConnect() and check the result. +class UdpConnectTest : public ::testing::Test { + public: + UdpConnectTest() : mock_socket_(new SocketMock) {} + + // Run the initialization, return whether it was successful or not. This passes ownership of + // the current |mock_socket_| but allocates a new one for re-use. + bool UdpConnect(std::string* error = nullptr) { + std::string local_error; + if (error == nullptr) { + error = &local_error; + } + std::unique_ptr transport(Connect(std::move(mock_socket_), error)); + mock_socket_.reset(new SocketMock); + return transport != nullptr && error->empty(); + } + + protected: + std::unique_ptr mock_socket_; +}; + +// Tests a successful protocol initialization with various starting sequence numbers. +TEST_F(UdpConnectTest, InitializationSuccess) { + for (uint16_t seq : kTestSequenceNumbers) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, seq)); + mock_socket_->ExpectSend(InitPacket(seq, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(InitPacket(seq, kProtocolVersion, 1024)); + + EXPECT_TRUE(UdpConnect()); + } +} + +// Tests continuation packets during initialization. +TEST_F(UdpConnectTest, InitializationContinuationSuccess) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(std::string{kIdDeviceQuery, kFlagContinuation, 0, 0, 0x44}); + mock_socket_->ExpectSend(std::string{kIdDeviceQuery, kFlagNone, 0, 1}); + mock_socket_->AddReceive(std::string{kIdDeviceQuery, kFlagNone, 0, 1, 0x55}); + + mock_socket_->ExpectSend(InitPacket(0x4455, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(std::string{kIdInitialization, kFlagContinuation, 0x44, 0x55, 0}); + mock_socket_->ExpectSend(std::string{kIdInitialization, kFlagNone, 0x44, 0x56}); + mock_socket_->AddReceive(std::string{kIdInitialization, kFlagContinuation, 0x44, 0x56, 1}); + mock_socket_->ExpectSend(std::string{kIdInitialization, kFlagNone, 0x44, 0x57}); + mock_socket_->AddReceive(std::string{kIdInitialization, kFlagContinuation, 0x44, 0x57, 2}); + mock_socket_->ExpectSend(std::string{kIdInitialization, kFlagNone, 0x44, 0x58}); + mock_socket_->AddReceive(std::string{kIdInitialization, kFlagNone, 0x44, 0x58, 0}); + + EXPECT_TRUE(UdpConnect()); +} + + +// Tests a mismatched version number; as long as the minimum of the two versions is supported +// we should allow the connection. +TEST_F(UdpConnectTest, InitializationVersionMismatch) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(InitPacket(0, 2, 1024)); + + EXPECT_TRUE(UdpConnect()); + + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(InitPacket(0, 0, 1024)); + + EXPECT_FALSE(UdpConnect()); +} + +TEST_F(UdpConnectTest, QueryResponseTimeoutFailure) { + for (int i = 0; i < kMaxConnectAttempts; ++i) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceiveTimeout(); + } + + EXPECT_FALSE(UdpConnect()); +} + +TEST_F(UdpConnectTest, QueryResponseReceiveFailure) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceiveFailure(); + + EXPECT_FALSE(UdpConnect()); +} + +TEST_F(UdpConnectTest, InitResponseTimeoutFailure) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + for (int i = 0; i < kMaxTransmissionAttempts; ++i) { + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceiveTimeout(); + } + + EXPECT_FALSE(UdpConnect()); +} + +TEST_F(UdpConnectTest, InitResponseReceiveFailure) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceiveFailure(); + + EXPECT_FALSE(UdpConnect()); +} + +// Tests that we can recover up to the maximum number of allowed retries. +TEST_F(UdpConnectTest, ResponseRecovery) { + // The device query packet can recover from up to (kMaxConnectAttempts - 1) timeouts. + for (int i = 0; i < kMaxConnectAttempts - 1; ++i) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceiveTimeout(); + } + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + + // Subsequent packets try up to (kMaxTransmissionAttempts - 1) times. + for (int i = 0; i < kMaxTransmissionAttempts - 1; ++i) { + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceiveTimeout(); + } + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(InitPacket(0, kProtocolVersion, 1024)); + + EXPECT_TRUE(UdpConnect()); +} + +// Tests that the host can handle receiving additional bytes for forward compatibility. +TEST_F(UdpConnectTest, ExtraResponseDataSuccess) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0) + "foo"); + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(InitPacket(0, kProtocolVersion, 1024) + "bar"); + + EXPECT_TRUE(UdpConnect()); +} + +// Tests mismatched response sequence numbers. A wrong sequence number is interpreted as a previous +// retransmission and just ignored so we should be able to recover. +TEST_F(UdpConnectTest, WrongSequenceRecovery) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(1, 0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(InitPacket(1, kProtocolVersion, 1024)); + mock_socket_->AddReceive(InitPacket(0, kProtocolVersion, 1024)); + + EXPECT_TRUE(UdpConnect()); +} + +// Tests mismatched response IDs. This should also be interpreted as a retransmission and ignored. +TEST_F(UdpConnectTest, WrongIdRecovery) { + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(FastbootPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(FastbootPacket(0)); + mock_socket_->AddReceive(InitPacket(0, kProtocolVersion, 1024)); + + EXPECT_TRUE(UdpConnect()); +} + +// Tests an invalid query response. Query responses must have at least 2 bytes of data. +TEST_F(UdpConnectTest, InvalidQueryResponseFailure) { + std::string error; + + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0)); + + EXPECT_FALSE(UdpConnect(&error)); + EXPECT_EQ("invalid query response from target", error); + + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0) + std::string{0x00}); + + EXPECT_FALSE(UdpConnect(&error)); + EXPECT_EQ("invalid query response from target", error); +} + +// Tests an invalid initialization response. Max packet size must be at least 512 bytes. +TEST_F(UdpConnectTest, InvalidInitResponseFailure) { + std::string error; + + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(InitPacket(0, kProtocolVersion, 511)); + + EXPECT_FALSE(UdpConnect(&error)); + EXPECT_EQ("target reported invalid packet size 511", error); + + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(InitPacket(0, 0, 1024)); + + EXPECT_FALSE(UdpConnect(&error)); + EXPECT_EQ("target reported invalid protocol version 0", error); +} + +TEST_F(UdpConnectTest, ErrorResponseFailure) { + std::string error; + + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(ErrorPacket(0, "error1")); + + EXPECT_FALSE(UdpConnect(&error)); + EXPECT_NE(std::string::npos, error.find("error1")); + + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, 0)); + mock_socket_->ExpectSend(InitPacket(0, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive(ErrorPacket(0, "error2")); + + EXPECT_FALSE(UdpConnect(&error)); + EXPECT_NE(std::string::npos, error.find("error2")); +} + +// Tests an error response with continuation flag. +TEST_F(UdpConnectTest, ErrorContinuationFailure) { + std::string error; + + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(ErrorPacket(0, "error1", kFlagContinuation)); + mock_socket_->ExpectSend(ErrorPacket(1)); + mock_socket_->AddReceive(ErrorPacket(1, " ", kFlagContinuation)); + mock_socket_->ExpectSend(ErrorPacket(2)); + mock_socket_->AddReceive(ErrorPacket(2, "error2")); + + EXPECT_FALSE(UdpConnect(&error)); + EXPECT_NE(std::string::npos, error.find("error1 error2")); +} + +// Fixture class to test UDP Transport read/write functionality. +class UdpTest : public ::testing::Test { + public: + void SetUp() override { + // Create |transport_| starting at sequence 0 with 512 byte max packet size. Tests can call + // InitializeTransport() again to change settings. + ASSERT_TRUE(InitializeTransport(0, 512)); + } + + // Sets up |mock_socket_| to correctly initialize the protocol and creates |transport_|. This + // can be called multiple times in a test if needed. + bool InitializeTransport(uint16_t starting_sequence, int device_max_packet_size = 512) { + mock_socket_ = new SocketMock; + mock_socket_->ExpectSend(QueryPacket(0)); + mock_socket_->AddReceive(QueryPacket(0, starting_sequence)); + mock_socket_->ExpectSend( + InitPacket(starting_sequence, kProtocolVersion, kHostMaxPacketSize)); + mock_socket_->AddReceive( + InitPacket(starting_sequence, kProtocolVersion, device_max_packet_size)); + + std::string error; + transport_ = Connect(std::unique_ptr(mock_socket_), &error); + return transport_ != nullptr && error.empty(); + } + + // Writes |message| to |transport_|, returns true on success. + bool Write(const std::string& message) { + return transport_->Write(message.data(), message.length()) == + static_cast(message.length()); + } + + // Reads from |transport_|, returns true if it matches |message|. + bool Read(const std::string& message) { + std::string buffer(message.length(), '\0'); + return transport_->Read(&buffer[0], buffer.length()) == + static_cast(message.length()) && buffer == message; + } + + protected: + // |mock_socket_| is a raw pointer here because we transfer ownership to |transport_| but we + // need to retain a pointer to set send and receive expectations. + SocketMock* mock_socket_ = nullptr; + std::unique_ptr transport_; +}; + +// Tests sequence behavior with various starting sequence numbers. +TEST_F(UdpTest, SequenceIncrementCheck) { + for (uint16_t seq : kTestSequenceNumbers) { + ASSERT_TRUE(InitializeTransport(seq)); + + for (int i = 0; i < 10; ++i) { + mock_socket_->ExpectSend(FastbootPacket(++seq, "foo")); + mock_socket_->AddReceive(FastbootPacket(seq, "")); + mock_socket_->ExpectSend(FastbootPacket(++seq, "")); + mock_socket_->AddReceive(FastbootPacket(seq, "bar")); + + EXPECT_TRUE(Write("foo")); + EXPECT_TRUE(Read("bar")); + } + } +} + +// Tests sending and receiving a few small packets. +TEST_F(UdpTest, ReadAndWriteSmallPackets) { + mock_socket_->ExpectSend(FastbootPacket(1, "foo")); + mock_socket_->AddReceive(FastbootPacket(1, "")); + mock_socket_->ExpectSend(FastbootPacket(2, "")); + mock_socket_->AddReceive(FastbootPacket(2, "bar")); + + EXPECT_TRUE(Write("foo")); + EXPECT_TRUE(Read("bar")); + + mock_socket_->ExpectSend(FastbootPacket(3, "12345 67890")); + mock_socket_->AddReceive(FastbootPacket(3)); + mock_socket_->ExpectSend(FastbootPacket(4, "\x01\x02\x03\x04\x05")); + mock_socket_->AddReceive(FastbootPacket(4)); + + EXPECT_TRUE(Write("12345 67890")); + EXPECT_TRUE(Write("\x01\x02\x03\x04\x05")); + + // Reads are done by sending empty packets. + mock_socket_->ExpectSend(FastbootPacket(5)); + mock_socket_->AddReceive(FastbootPacket(5, "foo bar baz")); + mock_socket_->ExpectSend(FastbootPacket(6)); + mock_socket_->AddReceive(FastbootPacket(6, "\x01\x02\x03\x04\x05")); + + EXPECT_TRUE(Read("foo bar baz")); + EXPECT_TRUE(Read("\x01\x02\x03\x04\x05")); +} + +TEST_F(UdpTest, ResponseTimeoutFailure) { + for (int i = 0; i < kMaxTransmissionAttempts; ++i) { + mock_socket_->ExpectSend(FastbootPacket(1, "foo")); + mock_socket_->AddReceiveTimeout(); + } + + EXPECT_FALSE(Write("foo")); +} + +TEST_F(UdpTest, ResponseReceiveFailure) { + mock_socket_->ExpectSend(FastbootPacket(1, "foo")); + mock_socket_->AddReceiveFailure(); + + EXPECT_FALSE(Write("foo")); +} + +TEST_F(UdpTest, ResponseTimeoutRecovery) { + for (int i = 0; i < kMaxTransmissionAttempts - 1; ++i) { + mock_socket_->ExpectSend(FastbootPacket(1, "foo")); + mock_socket_->AddReceiveTimeout(); + } + mock_socket_->ExpectSend(FastbootPacket(1, "foo")); + mock_socket_->AddReceive(FastbootPacket(1, "")); + + EXPECT_TRUE(Write("foo")); +} + +// Tests continuation packets for various max packet sizes. +// The important part of this test is that regardless of what kind of packet fragmentation happens +// at the socket layer, a single call to Transport::Read() and Transport::Write() is all the +// fastboot code needs to do. +TEST_F(UdpTest, ContinuationPackets) { + for (uint16_t max_packet_size : {512, 1024, 1200}) { + ASSERT_TRUE(InitializeTransport(0, max_packet_size)); + + // Initialize the data we want to send. Use (size - 4) to leave room for the header. + size_t max_data_size = max_packet_size - 4; + std::string data(max_data_size * 3, '\0'); + for (size_t i = 0; i < data.length(); ++i) { + data[i] = i; + } + std::string chunks[] = {data.substr(0, max_data_size), + data.substr(max_data_size, max_data_size), + data.substr(max_data_size * 2, max_data_size)}; + + // Write data: split into 3 UDP packets, each of which will be ACKed. + mock_socket_->ExpectSend(FastbootPacket(1, chunks[0], kFlagContinuation)); + mock_socket_->AddReceive(FastbootPacket(1)); + mock_socket_->ExpectSend(FastbootPacket(2, chunks[1], kFlagContinuation)); + mock_socket_->AddReceive(FastbootPacket(2)); + mock_socket_->ExpectSend(FastbootPacket(3, chunks[2])); + mock_socket_->AddReceive(FastbootPacket(3)); + EXPECT_TRUE(Write(data)); + + // Same thing for reading the data. + mock_socket_->ExpectSend(FastbootPacket(4)); + mock_socket_->AddReceive(FastbootPacket(4, chunks[0], kFlagContinuation)); + mock_socket_->ExpectSend(FastbootPacket(5)); + mock_socket_->AddReceive(FastbootPacket(5, chunks[1], kFlagContinuation)); + mock_socket_->ExpectSend(FastbootPacket(6)); + mock_socket_->AddReceive(FastbootPacket(6, chunks[2])); + EXPECT_TRUE(Read(data)); + } +} + +// Tests that the continuation bit is respected even if the packet isn't max size. +TEST_F(UdpTest, SmallContinuationPackets) { + mock_socket_->ExpectSend(FastbootPacket(1)); + mock_socket_->AddReceive(FastbootPacket(1, "foo", kFlagContinuation)); + mock_socket_->ExpectSend(FastbootPacket(2)); + mock_socket_->AddReceive(FastbootPacket(2, "bar")); + + EXPECT_TRUE(Read("foobar")); +} + +// Tests receiving an error packet mid-continuation. +TEST_F(UdpTest, ContinuationPacketError) { + mock_socket_->ExpectSend(FastbootPacket(1)); + mock_socket_->AddReceive(FastbootPacket(1, "foo", kFlagContinuation)); + mock_socket_->ExpectSend(FastbootPacket(2)); + mock_socket_->AddReceive(ErrorPacket(2, "test error")); + + EXPECT_FALSE(Read("foo")); +} + +// Tests timeout during a continuation sequence. +TEST_F(UdpTest, ContinuationTimeoutRecovery) { + mock_socket_->ExpectSend(FastbootPacket(1)); + mock_socket_->AddReceive(FastbootPacket(1, "foo", kFlagContinuation)); + mock_socket_->ExpectSend(FastbootPacket(2)); + mock_socket_->AddReceiveTimeout(); + mock_socket_->ExpectSend(FastbootPacket(2)); + mock_socket_->AddReceive(FastbootPacket(2, "bar")); + + EXPECT_TRUE(Read("foobar")); +} + +// Tests read overflow returns -1 to indicate the failure. +TEST_F(UdpTest, MultipleReadPacket) { + mock_socket_->ExpectSend(FastbootPacket(1)); + mock_socket_->AddReceive(FastbootPacket(1, "foobarbaz")); + + char buffer[3]; + EXPECT_EQ(-1, transport_->Read(buffer, 3)); +} + +// Tests that packets arriving out-of-order are ignored. +TEST_F(UdpTest, IgnoreOutOfOrderPackets) { + mock_socket_->ExpectSend(FastbootPacket(1)); + mock_socket_->AddReceive(FastbootPacket(0, "sequence too low")); + mock_socket_->AddReceive(FastbootPacket(2, "sequence too high")); + mock_socket_->AddReceive(QueryPacket(1)); + mock_socket_->AddReceive(FastbootPacket(1, "correct")); + + EXPECT_TRUE(Read("correct")); +} + +// Tests that an error response with the correct sequence number causes immediate failure. +TEST_F(UdpTest, ErrorResponse) { + // Error packets with the wrong sequence number should be ignored like any other packet. + mock_socket_->ExpectSend(FastbootPacket(1, "foo")); + mock_socket_->AddReceive(ErrorPacket(0, "ignored error")); + mock_socket_->AddReceive(FastbootPacket(1)); + + EXPECT_TRUE(Write("foo")); + + // Error packets with the correct sequence should abort immediately without retransmission. + mock_socket_->ExpectSend(FastbootPacket(2, "foo")); + mock_socket_->AddReceive(ErrorPacket(2, "test error")); + + EXPECT_FALSE(Write("foo")); +} + +// Tests that attempting to use a closed transport returns -1 without making any socket calls. +TEST_F(UdpTest, CloseTransport) { + char buffer[32]; + EXPECT_EQ(0, transport_->Close()); + EXPECT_EQ(-1, transport_->Write("foo", 3)); + EXPECT_EQ(-1, transport_->Read(buffer, sizeof(buffer))); +}