[adbwifi] Add A_STLS command.

This command will be sent by adbd to notify the client that the
connection will be over TLS.

When client connects, it will send the CNXN packet, as usual. If the
server connection has TLS enabled, it will send the A_STLS packet
(regardless of whether auth is required). At this point, the client's
only valid response is to send a A_STLS packet. Once both sides have
exchanged the A_STLS packet, both will start the TLS handshake.

If auth is required, then the client will receive a CertificateRequest
with a list of known public keys (SHA256 hash) that it can use in its
certificate. Otherwise, the list will be empty and the client can assume
that either any key will work, or none will work.

If the handshake was successful, the server will send the CNXN packet
and the usual adb protocol is resumed over TLS. If the handshake failed,
both sides will disconnect, as there's no point to retry because the
server's known keys have already been communicated.

Bug: 111434128

Test: WIP; will add to adb_test.py/adb_device.py.

Enable wireless debugging in the Settings, then 'adb connect
<ip>:<port>'. Connection should succeed if key is in keystore. Used
wireshark to check for packet encryption.

Change-Id: I3d60647491c6c6b92297e4f628707a6457fa9420
This commit is contained in:
Joshua Duong 2020-01-21 13:19:42 -08:00
parent d85f5c0130
commit 5cf7868b7e
14 changed files with 503 additions and 33 deletions

View file

@ -52,6 +52,7 @@
#include "adb_listeners.h"
#include "adb_unique_fd.h"
#include "adb_utils.h"
#include "adb_wifi.h"
#include "sysdeps/chrono.h"
#include "transport.h"
@ -140,6 +141,9 @@ void print_packet(const char *label, apacket *p)
case A_CLSE: tag = "CLSE"; break;
case A_WRTE: tag = "WRTE"; break;
case A_AUTH: tag = "AUTH"; break;
case A_STLS:
tag = "ATLS";
break;
default: tag = "????"; break;
}
@ -209,6 +213,15 @@ std::string get_connection_string() {
android::base::Join(connection_properties, ';').c_str());
}
void send_tls_request(atransport* t) {
D("Calling send_tls_request");
apacket* p = get_apacket();
p->msg.command = A_STLS;
p->msg.arg0 = A_STLS_VERSION;
p->msg.data_length = 0;
send_packet(p, t);
}
void send_connect(atransport* t) {
D("Calling send_connect");
apacket* cp = get_apacket();
@ -299,7 +312,12 @@ static void handle_new_connection(atransport* t, apacket* p) {
#if ADB_HOST
handle_online(t);
#else
if (!auth_required) {
if (t->use_tls) {
// We still handshake in TLS mode. If auth_required is disabled,
// we'll just not verify the client's certificate. This should be the
// first packet the client receives to indicate the new protocol.
send_tls_request(t);
} else if (!auth_required) {
LOG(INFO) << "authentication not required";
handle_online(t);
send_connect(t);
@ -324,8 +342,21 @@ void handle_packet(apacket *p, atransport *t)
case A_CNXN: // CONNECT(version, maxdata, "system-id-string")
handle_new_connection(t, p);
break;
case A_STLS: // TLS(version, "")
t->use_tls = true;
#if ADB_HOST
send_tls_request(t);
adb_auth_tls_handshake(t);
#else
adbd_auth_tls_handshake(t);
#endif
break;
case A_AUTH:
// All AUTH commands are ignored in TLS mode
if (t->use_tls) {
break;
}
switch (p->msg.arg0) {
#if ADB_HOST
case ADB_AUTH_TOKEN:

View file

@ -44,6 +44,7 @@ constexpr size_t LINUX_MAX_SOCKET_SIZE = 4194304;
#define A_CLSE 0x45534c43
#define A_WRTE 0x45545257
#define A_AUTH 0x48545541
#define A_STLS 0x534C5453
// ADB protocol version.
// Version revision:
@ -53,6 +54,10 @@ constexpr size_t LINUX_MAX_SOCKET_SIZE = 4194304;
#define A_VERSION_SKIP_CHECKSUM 0x01000001
#define A_VERSION 0x01000001
// Stream-based TLS protocol version
#define A_STLS_VERSION_MIN 0x01000000
#define A_STLS_VERSION 0x01000000
// Used for help/version information.
#define ADB_VERSION_MAJOR 1
#define ADB_VERSION_MINOR 0
@ -229,6 +234,7 @@ void handle_online(atransport* t);
void handle_offline(atransport* t);
void send_connect(atransport* t);
void send_tls_request(atransport* t);
void parse_banner(const std::string&, atransport* t);

View file

@ -43,6 +43,9 @@ std::deque<std::shared_ptr<RSA>> adb_auth_get_private_keys();
void send_auth_response(const char* token, size_t token_size, atransport* t);
int adb_tls_set_certificate(SSL* ssl);
void adb_auth_tls_handshake(atransport* t);
#else // !ADB_HOST
extern bool auth_required;
@ -58,6 +61,10 @@ void adbd_notify_framework_connected_key(atransport* t);
void send_auth_request(atransport *t);
void adbd_auth_tls_handshake(atransport* t);
int adbd_tls_verify_cert(X509_STORE_CTX* ctx, std::string* auth_key);
bssl::UniquePtr<STACK_OF(X509_NAME)> adbd_tls_client_ca_list();
#endif // ADB_HOST
#endif // __ADB_AUTH_H

View file

@ -30,6 +30,9 @@
#include <string>
#include <adb/crypto/rsa_2048_key.h>
#include <adb/crypto/x509_generator.h>
#include <adb/tls/adb_ca_list.h>
#include <adb/tls/tls_connection.h>
#include <android-base/errors.h>
#include <android-base/file.h>
#include <android-base/stringprintf.h>
@ -55,6 +58,7 @@ static std::map<std::string, std::shared_ptr<RSA>>& g_keys =
static std::map<int, std::string>& g_monitored_paths = *new std::map<int, std::string>;
using namespace adb::crypto;
using namespace adb::tls;
static bool generate_key(const std::string& file) {
LOG(INFO) << "generate_key(" << file << ")...";
@ -144,6 +148,7 @@ static bool load_key(const std::string& file) {
if (g_keys.find(fingerprint) != g_keys.end()) {
LOG(INFO) << "ignoring already-loaded key: " << file;
} else {
LOG(INFO) << "Loaded fingerprint=[" << SHA256BitsToHexString(fingerprint) << "]";
g_keys[fingerprint] = std::move(key);
}
return true;
@ -475,3 +480,72 @@ void send_auth_response(const char* token, size_t token_size, atransport* t) {
p->msg.data_length = p->payload.size();
send_packet(p, t);
}
void adb_auth_tls_handshake(atransport* t) {
std::thread([t]() {
std::shared_ptr<RSA> key = t->Key();
if (key == nullptr) {
// Can happen if !auth_required
LOG(INFO) << "t->auth_key not set before handshake";
key = t->NextKey();
CHECK(key);
}
LOG(INFO) << "Attempting to TLS handshake";
bool success = t->connection()->DoTlsHandshake(key.get());
if (success) {
LOG(INFO) << "Handshake succeeded. Waiting for CNXN packet...";
} else {
LOG(INFO) << "Handshake failed. Kicking transport";
t->Kick();
}
}).detach();
}
int adb_tls_set_certificate(SSL* ssl) {
LOG(INFO) << __func__;
const STACK_OF(X509_NAME)* ca_list = SSL_get_client_CA_list(ssl);
if (ca_list == nullptr) {
// Either the device doesn't know any keys, or !auth_required.
// So let's just try with the default certificate and see what happens.
LOG(INFO) << "No client CA list. Trying with default certificate.";
return 1;
}
const size_t num_cas = sk_X509_NAME_num(ca_list);
for (size_t i = 0; i < num_cas; ++i) {
auto* x509_name = sk_X509_NAME_value(ca_list, i);
auto adbFingerprint = ParseEncodedKeyFromCAIssuer(x509_name);
if (!adbFingerprint.has_value()) {
// This could be a real CA issuer. Unfortunately, we don't support
// it ATM.
continue;
}
LOG(INFO) << "Checking for fingerprint match [" << *adbFingerprint << "]";
auto encoded_key = SHA256HexStringToBits(*adbFingerprint);
if (!encoded_key.has_value()) {
continue;
}
// Check against our list of encoded keys for a match
std::lock_guard<std::mutex> lock(g_keys_mutex);
auto rsa_priv_key = g_keys.find(*encoded_key);
if (rsa_priv_key != g_keys.end()) {
LOG(INFO) << "Got SHA256 match on a key";
bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
CHECK(EVP_PKEY_set1_RSA(evp_pkey.get(), rsa_priv_key->second.get()));
auto x509 = GenerateX509Certificate(evp_pkey.get());
auto x509_str = X509ToPEMString(x509.get());
auto evp_str = Key::ToPEMString(evp_pkey.get());
TlsConnection::SetCertAndKey(ssl, x509_str, evp_str);
return 1;
} else {
LOG(INFO) << "No match for [" << *adbFingerprint << "]";
}
}
// Let's just try with the default certificate anyways, because daemon might
// not require auth, even though it has a list of keys.
return 1;
}

View file

@ -142,9 +142,9 @@ void TlsServer::OnFdEvent(int fd, unsigned ev) {
close_on_exec(new_fd.get());
disable_tcp_nagle(new_fd.get());
std::string serial = android::base::StringPrintf("host-%d", new_fd.get());
// TODO: register a tls transport
// register_socket_transport(std::move(new_fd), std::move(serial), port_, 1,
// [](atransport*) { return ReconnectResult::Abort; });
register_socket_transport(
std::move(new_fd), std::move(serial), port_, 1,
[](atransport*) { return ReconnectResult::Abort; }, true);
}
}
@ -224,4 +224,5 @@ void adbd_wifi_secure_connect(atransport* t) {
t->auth_id = adbd_auth_tls_device_connected(auth_ctx, kAdbTransportTypeWifi, t->auth_key.data(),
t->auth_key.size());
}
#endif /* !HOST */

View file

@ -23,10 +23,14 @@
#include <string.h>
#include <algorithm>
#include <chrono>
#include <iomanip>
#include <map>
#include <memory>
#include <thread>
#include <adb/crypto/rsa_2048_key.h>
#include <adb/tls/adb_ca_list.h>
#include <adbd_auth.h>
#include <android-base/file.h>
#include <android-base/no_destructor.h>
@ -45,8 +49,14 @@
#include "transport.h"
#include "types.h"
using namespace adb::crypto;
using namespace adb::tls;
using namespace std::chrono_literals;
static AdbdAuthContext* auth_ctx;
static RSA* rsa_pkey = nullptr;
static void adb_disconnected(void* unused, atransport* t);
static struct adisconnect adb_disconnect = {adb_disconnected, nullptr};
@ -93,6 +103,55 @@ static void IteratePublicKeys(std::function<bool(std::string_view public_key)> f
&f);
}
bssl::UniquePtr<STACK_OF(X509_NAME)> adbd_tls_client_ca_list() {
if (!auth_required) {
return nullptr;
}
bssl::UniquePtr<STACK_OF(X509_NAME)> ca_list(sk_X509_NAME_new_null());
IteratePublicKeys([&](std::string_view public_key) {
// TODO: do we really have to support both ' ' and '\t'?
std::vector<std::string> split = android::base::Split(std::string(public_key), " \t");
uint8_t keybuf[ANDROID_PUBKEY_ENCODED_SIZE + 1];
const std::string& pubkey = split[0];
if (b64_pton(pubkey.c_str(), keybuf, sizeof(keybuf)) != ANDROID_PUBKEY_ENCODED_SIZE) {
LOG(ERROR) << "Invalid base64 key " << pubkey;
return true;
}
RSA* key = nullptr;
if (!android_pubkey_decode(keybuf, ANDROID_PUBKEY_ENCODED_SIZE, &key)) {
LOG(ERROR) << "Failed to parse key " << pubkey;
return true;
}
bssl::UniquePtr<RSA> rsa_key(key);
unsigned char* dkey = nullptr;
int len = i2d_RSA_PUBKEY(rsa_key.get(), &dkey);
if (len <= 0 || dkey == nullptr) {
LOG(ERROR) << "Failed to encode RSA public key";
return true;
}
uint8_t digest[SHA256_DIGEST_LENGTH];
// Put the encoded key in the commonName attribute of the issuer name.
// Note that the commonName has a max length of 64 bytes, which is less
// than the SHA256_DIGEST_LENGTH.
SHA256(dkey, len, digest);
OPENSSL_free(dkey);
auto digest_str = SHA256BitsToHexString(
std::string_view(reinterpret_cast<const char*>(&digest[0]), sizeof(digest)));
LOG(INFO) << "fingerprint=[" << digest_str << "]";
auto issuer = CreateCAIssuerFromEncodedKey(digest_str);
CHECK(bssl::PushToStack(ca_list.get(), std::move(issuer)));
return true;
});
return ca_list;
}
bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig,
std::string* auth_key) {
bool authorized = false;
@ -217,5 +276,89 @@ void adbd_auth_confirm_key(atransport* t) {
}
void adbd_notify_framework_connected_key(atransport* t) {
adbd_auth_notify_auth(auth_ctx, t->auth_key.data(), t->auth_key.size());
t->auth_id = adbd_auth_notify_auth(auth_ctx, t->auth_key.data(), t->auth_key.size());
}
int adbd_tls_verify_cert(X509_STORE_CTX* ctx, std::string* auth_key) {
if (!auth_required) {
// Any key will do.
LOG(INFO) << __func__ << ": auth not required";
return 1;
}
bool authorized = false;
X509* cert = X509_STORE_CTX_get0_cert(ctx);
if (cert == nullptr) {
LOG(INFO) << "got null x509 certificate";
return 0;
}
bssl::UniquePtr<EVP_PKEY> evp_pkey(X509_get_pubkey(cert));
if (evp_pkey == nullptr) {
LOG(INFO) << "got null evp_pkey from x509 certificate";
return 0;
}
IteratePublicKeys([&](std::string_view public_key) {
// TODO: do we really have to support both ' ' and '\t'?
std::vector<std::string> split = android::base::Split(std::string(public_key), " \t");
uint8_t keybuf[ANDROID_PUBKEY_ENCODED_SIZE + 1];
const std::string& pubkey = split[0];
if (b64_pton(pubkey.c_str(), keybuf, sizeof(keybuf)) != ANDROID_PUBKEY_ENCODED_SIZE) {
LOG(ERROR) << "Invalid base64 key " << pubkey;
return true;
}
RSA* key = nullptr;
if (!android_pubkey_decode(keybuf, ANDROID_PUBKEY_ENCODED_SIZE, &key)) {
LOG(ERROR) << "Failed to parse key " << pubkey;
return true;
}
bool verified = false;
bssl::UniquePtr<EVP_PKEY> known_evp(EVP_PKEY_new());
EVP_PKEY_set1_RSA(known_evp.get(), key);
if (EVP_PKEY_cmp(known_evp.get(), evp_pkey.get())) {
LOG(INFO) << "Matched auth_key=" << public_key;
verified = true;
} else {
LOG(INFO) << "auth_key doesn't match [" << public_key << "]";
}
RSA_free(key);
if (verified) {
*auth_key = public_key;
authorized = true;
return false;
}
return true;
});
return authorized ? 1 : 0;
}
void adbd_auth_tls_handshake(atransport* t) {
if (rsa_pkey == nullptr) {
// Generate a random RSA key to feed into the X509 certificate
auto rsa_2048 = CreateRSA2048Key();
CHECK(rsa_2048.has_value());
rsa_pkey = EVP_PKEY_get1_RSA(rsa_2048->GetEvpPkey());
CHECK(rsa_pkey);
}
std::thread([t]() {
std::string auth_key;
if (t->connection()->DoTlsHandshake(rsa_pkey, &auth_key)) {
LOG(INFO) << "auth_key=" << auth_key;
if (t->IsTcpDevice()) {
t->auth_key = auth_key;
adbd_wifi_secure_connect(t);
} else {
adbd_auth_verified(t);
adbd_notify_framework_connected_key(t);
}
} else {
// Only allow one attempt at the handshake.
t->Kick();
}
}).detach();
}

View file

@ -105,8 +105,9 @@ void qemu_socket_thread(std::string_view addr) {
* exchange. */
std::string serial = android::base::StringPrintf("host-%d", fd.get());
WriteFdExactly(fd.get(), _start_req, strlen(_start_req));
register_socket_transport(std::move(fd), std::move(serial), port, 1,
[](atransport*) { return ReconnectResult::Abort; });
register_socket_transport(
std::move(fd), std::move(serial), port, 1,
[](atransport*) { return ReconnectResult::Abort; }, false);
}
/* Prepare for accepting of the next ADB host connection. */

View file

@ -260,6 +260,12 @@ struct UsbFfsConnection : public Connection {
CHECK_EQ(static_cast<size_t>(rc), sizeof(notify));
}
virtual bool DoTlsHandshake(RSA* key, std::string* auth_key) override final {
// TODO: support TLS for usb connections.
LOG(FATAL) << "Not supported yet.";
return false;
}
private:
void StartMonitor() {
// This is a bit of a mess.

View file

@ -79,6 +79,14 @@ where systemtype is "bootloader", "device", or "host", serialno is some
kind of unique ID (or empty), and banner is a human-readable version
or identifier string. The banner is used to transmit useful properties.
--- STLS(type, version, "") --------------------------------------------
Command constant: A_STLS
The TLS message informs the recipient that the connection will be encrypted
and will need to perform a TLS handshake. version is the current version of
the protocol.
--- AUTH(type, 0, "data") ----------------------------------------------
@ -207,6 +215,7 @@ to send across the wire.
#define A_OKAY 0x59414b4f
#define A_CLSE 0x45534c43
#define A_WRTE 0x45545257
#define A_STLS 0x534C5453

View file

@ -36,6 +36,9 @@
#include <set>
#include <thread>
#include <adb/crypto/rsa_2048_key.h>
#include <adb/crypto/x509_generator.h>
#include <adb/tls/tls_connection.h>
#include <android-base/logging.h>
#include <android-base/parsenetaddress.h>
#include <android-base/stringprintf.h>
@ -52,7 +55,10 @@
#include "fdevent/fdevent.h"
#include "sysdeps/chrono.h"
using namespace adb::crypto;
using namespace adb::tls;
using android::base::ScopedLockAssertion;
using TlsError = TlsConnection::TlsError;
static void remove_transport(atransport* transport);
static void transport_destroy(atransport* transport);
@ -279,18 +285,7 @@ void BlockingConnectionAdapter::Start() {
<< "): started multiple times";
}
read_thread_ = std::thread([this]() {
LOG(INFO) << this->transport_name_ << ": read thread spawning";
while (true) {
auto packet = std::make_unique<apacket>();
if (!underlying_->Read(packet.get())) {
PLOG(INFO) << this->transport_name_ << ": read failed";
break;
}
read_callback_(this, std::move(packet));
}
std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
});
StartReadThread();
write_thread_ = std::thread([this]() {
LOG(INFO) << this->transport_name_ << ": write thread spawning";
@ -319,6 +314,46 @@ void BlockingConnectionAdapter::Start() {
started_ = true;
}
void BlockingConnectionAdapter::StartReadThread() {
read_thread_ = std::thread([this]() {
LOG(INFO) << this->transport_name_ << ": read thread spawning";
while (true) {
auto packet = std::make_unique<apacket>();
if (!underlying_->Read(packet.get())) {
PLOG(INFO) << this->transport_name_ << ": read failed";
break;
}
bool got_stls_cmd = false;
if (packet->msg.command == A_STLS) {
got_stls_cmd = true;
}
read_callback_(this, std::move(packet));
// If we received the STLS packet, we are about to perform the TLS
// handshake. So this read thread must stop and resume after the
// handshake completes otherwise this will interfere in the process.
if (got_stls_cmd) {
LOG(INFO) << this->transport_name_
<< ": Received STLS packet. Stopping read thread.";
return;
}
}
std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
});
}
bool BlockingConnectionAdapter::DoTlsHandshake(RSA* key, std::string* auth_key) {
std::lock_guard<std::mutex> lock(mutex_);
if (read_thread_.joinable()) {
read_thread_.join();
}
bool success = this->underlying_->DoTlsHandshake(key, auth_key);
StartReadThread();
return success;
}
void BlockingConnectionAdapter::Reset() {
{
std::lock_guard<std::mutex> lock(mutex_);
@ -388,8 +423,36 @@ bool BlockingConnectionAdapter::Write(std::unique_ptr<apacket> packet) {
return true;
}
FdConnection::FdConnection(unique_fd fd) : fd_(std::move(fd)) {}
FdConnection::~FdConnection() {}
bool FdConnection::DispatchRead(void* buf, size_t len) {
if (tls_ != nullptr) {
// The TlsConnection doesn't allow 0 byte reads
if (len == 0) {
return true;
}
return tls_->ReadFully(buf, len);
}
return ReadFdExactly(fd_.get(), buf, len);
}
bool FdConnection::DispatchWrite(void* buf, size_t len) {
if (tls_ != nullptr) {
// The TlsConnection doesn't allow 0 byte writes
if (len == 0) {
return true;
}
return tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(buf), len));
}
return WriteFdExactly(fd_.get(), buf, len);
}
bool FdConnection::Read(apacket* packet) {
if (!ReadFdExactly(fd_.get(), &packet->msg, sizeof(amessage))) {
if (!DispatchRead(&packet->msg, sizeof(amessage))) {
D("remote local: read terminated (message)");
return false;
}
@ -401,7 +464,7 @@ bool FdConnection::Read(apacket* packet) {
packet->payload.resize(packet->msg.data_length);
if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) {
if (!DispatchRead(&packet->payload[0], packet->payload.size())) {
D("remote local: terminated (data)");
return false;
}
@ -410,13 +473,13 @@ bool FdConnection::Read(apacket* packet) {
}
bool FdConnection::Write(apacket* packet) {
if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) {
if (!DispatchWrite(&packet->msg, sizeof(packet->msg))) {
D("remote local: write terminated");
return false;
}
if (packet->msg.data_length) {
if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) {
if (!DispatchWrite(&packet->payload[0], packet->msg.data_length)) {
D("remote local: write terminated");
return false;
}
@ -425,6 +488,51 @@ bool FdConnection::Write(apacket* packet) {
return true;
}
bool FdConnection::DoTlsHandshake(RSA* key, std::string* auth_key) {
bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
if (!EVP_PKEY_set1_RSA(evp_pkey.get(), key)) {
LOG(ERROR) << "EVP_PKEY_set1_RSA failed";
return false;
}
auto x509 = GenerateX509Certificate(evp_pkey.get());
auto x509_str = X509ToPEMString(x509.get());
auto evp_str = Key::ToPEMString(evp_pkey.get());
#if ADB_HOST
tls_ = TlsConnection::Create(TlsConnection::Role::Client,
#else
tls_ = TlsConnection::Create(TlsConnection::Role::Server,
#endif
x509_str, evp_str, fd_);
CHECK(tls_);
#if ADB_HOST
// TLS 1.3 gives the client no message if the server rejected the
// certificate. This will enable a check in the tls connection to check
// whether the client certificate got rejected. Note that this assumes
// that, on handshake success, the server speaks first.
tls_->EnableClientPostHandshakeCheck(true);
// Add callback to set the certificate when server issues the
// CertificateRequest.
tls_->SetCertificateCallback(adb_tls_set_certificate);
// Allow any server certificate
tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
#else
// Add callback to check certificate against a list of known public keys
tls_->SetCertVerifyCallback(
[auth_key](X509_STORE_CTX* ctx) { return adbd_tls_verify_cert(ctx, auth_key); });
// Add the list of allowed client CA issuers
auto ca_list = adbd_tls_client_ca_list();
tls_->SetClientCAList(ca_list.get());
#endif
auto err = tls_->DoHandshake();
if (err == TlsError::Success) {
return true;
}
tls_.reset();
return false;
}
void FdConnection::Close() {
adb_shutdown(fd_.get());
fd_.reset();
@ -750,6 +858,26 @@ void kick_all_transports() {
}
}
void kick_all_tcp_tls_transports() {
std::lock_guard<std::recursive_mutex> lock(transport_lock);
for (auto t : transport_list) {
if (t->IsTcpDevice() && t->use_tls) {
t->Kick();
}
}
}
#if !ADB_HOST
void kick_all_transports_by_auth_key(std::string_view auth_key) {
std::lock_guard<std::recursive_mutex> lock(transport_lock);
for (auto t : transport_list) {
if (auth_key == t->auth_key) {
t->Kick();
}
}
}
#endif
/* the fdevent select pump is single threaded */
void register_transport(atransport* transport) {
tmsg m;
@ -1026,6 +1154,10 @@ int atransport::get_protocol_version() const {
return protocol_version;
}
int atransport::get_tls_version() const {
return tls_version;
}
size_t atransport::get_max_payload() const {
return max_payload;
}
@ -1221,8 +1353,9 @@ void close_usb_devices(bool reset) {
#endif // ADB_HOST
bool register_socket_transport(unique_fd s, std::string serial, int port, int local,
atransport::ReconnectCallback reconnect, int* error) {
atransport::ReconnectCallback reconnect, bool use_tls, int* error) {
atransport* t = new atransport(std::move(reconnect), kCsOffline);
t->use_tls = use_tls;
D("transport: %s init'ing for socket %d, on port %d", serial.c_str(), s.get(), port);
if (init_socket_transport(t, std::move(s), port, local) < 0) {
@ -1360,6 +1493,15 @@ bool check_header(apacket* p, atransport* t) {
}
#if ADB_HOST
std::shared_ptr<RSA> atransport::Key() {
if (keys_.empty()) {
return nullptr;
}
std::shared_ptr<RSA> result = keys_[0];
return result;
}
std::shared_ptr<RSA> atransport::NextKey() {
if (keys_.empty()) {
LOG(INFO) << "fetching keys for transport " << this->serial_name();
@ -1367,10 +1509,11 @@ std::shared_ptr<RSA> atransport::NextKey() {
// We should have gotten at least one key: the one that's automatically generated.
CHECK(!keys_.empty());
} else {
keys_.pop_front();
}
std::shared_ptr<RSA> result = keys_[0];
keys_.pop_front();
return result;
}

View file

@ -43,6 +43,14 @@
typedef std::unordered_set<std::string> FeatureSet;
namespace adb {
namespace tls {
class TlsConnection;
} // namespace tls
} // namespace adb
const FeatureSet& supported_features();
// Encodes and decodes FeatureSet objects into human-readable strings.
@ -104,6 +112,8 @@ struct Connection {
virtual void Start() = 0;
virtual void Stop() = 0;
virtual bool DoTlsHandshake(RSA* key, std::string* auth_key = nullptr) = 0;
// Stop, and reset the device if it's a USB connection.
virtual void Reset();
@ -128,6 +138,8 @@ struct BlockingConnection {
virtual bool Read(apacket* packet) = 0;
virtual bool Write(apacket* packet) = 0;
virtual bool DoTlsHandshake(RSA* key, std::string* auth_key = nullptr) = 0;
// Terminate a connection.
// This method must be thread-safe, and must cause concurrent Reads/Writes to terminate.
// Formerly known as 'Kick' in atransport.
@ -146,9 +158,12 @@ struct BlockingConnectionAdapter : public Connection {
virtual void Start() override final;
virtual void Stop() override final;
virtual bool DoTlsHandshake(RSA* key, std::string* auth_key) override final;
virtual void Reset() override final;
private:
void StartReadThread() REQUIRES(mutex_);
bool started_ GUARDED_BY(mutex_) = false;
bool stopped_ GUARDED_BY(mutex_) = false;
@ -164,16 +179,22 @@ struct BlockingConnectionAdapter : public Connection {
};
struct FdConnection : public BlockingConnection {
explicit FdConnection(unique_fd fd) : fd_(std::move(fd)) {}
explicit FdConnection(unique_fd fd);
~FdConnection();
bool Read(apacket* packet) override final;
bool Write(apacket* packet) override final;
bool DoTlsHandshake(RSA* key, std::string* auth_key) override final;
void Close() override;
virtual void Reset() override final { Close(); }
private:
bool DispatchRead(void* buf, size_t len);
bool DispatchWrite(void* buf, size_t len);
unique_fd fd_;
std::unique_ptr<adb::tls::TlsConnection> tls_;
};
struct UsbConnection : public BlockingConnection {
@ -182,6 +203,7 @@ struct UsbConnection : public BlockingConnection {
bool Read(apacket* packet) override final;
bool Write(apacket* packet) override final;
bool DoTlsHandshake(RSA* key, std::string* auth_key) override final;
void Close() override final;
virtual void Reset() override final;
@ -279,6 +301,12 @@ class atransport : public enable_weak_from_this<atransport> {
std::string device;
std::string devpath;
// If this is set, the transport will initiate the connection with a
// START_TLS command, instead of AUTH.
bool use_tls = false;
int tls_version = A_STLS_VERSION;
int get_tls_version() const;
#if !ADB_HOST
// Used to provide the key to the framework.
std::string auth_key;
@ -288,6 +316,8 @@ class atransport : public enable_weak_from_this<atransport> {
bool IsTcpDevice() const { return type == kTransportLocal; }
#if ADB_HOST
// The current key being authorized.
std::shared_ptr<RSA> Key();
std::shared_ptr<RSA> NextKey();
void ResetKeys();
#endif
@ -400,6 +430,10 @@ std::string list_transports(bool long_listing);
atransport* find_transport(const char* serial);
void kick_all_tcp_devices();
void kick_all_transports();
void kick_all_tcp_tls_transports();
#if !ADB_HOST
void kick_all_transports_by_auth_key(std::string_view auth_key);
#endif
void register_transport(atransport* transport);
void register_usb_transport(usb_handle* h, const char* serial,
@ -410,7 +444,8 @@ void connect_device(const std::string& address, std::string* response);
/* cause new transports to be init'd and added to the list */
bool register_socket_transport(unique_fd s, std::string serial, int port, int local,
atransport::ReconnectCallback reconnect, int* error = nullptr);
atransport::ReconnectCallback reconnect, bool use_tls,
int* error = nullptr);
// This should only be used for transports with connection_state == kCsNoPerm.
void unregister_usb_transport(usb_handle* usb);

View file

@ -155,6 +155,11 @@ struct NonblockingFdConnection : public Connection {
thread_.join();
}
bool DoTlsHandshake(RSA* key, std::string* auth_key) override final {
LOG(FATAL) << "Not supported yet";
return false;
}
void WakeThread() {
uint64_t buf = 0;
if (TEMP_FAILURE_RETRY(adb_write(wake_fd_write_.get(), &buf, sizeof(buf))) != sizeof(buf)) {

View file

@ -126,7 +126,8 @@ void connect_device(const std::string& address, std::string* response) {
};
int error;
if (!register_socket_transport(std::move(fd), serial, port, 0, std::move(reconnect), &error)) {
if (!register_socket_transport(std::move(fd), serial, port, 0, std::move(reconnect), false,
&error)) {
if (error == EALREADY) {
*response = android::base::StringPrintf("already connected to %s", serial.c_str());
} else if (error == EPERM) {
@ -163,8 +164,9 @@ int local_connect_arbitrary_ports(int console_port, int adb_port, std::string* e
close_on_exec(fd.get());
disable_tcp_nagle(fd.get());
std::string serial = getEmulatorSerialString(console_port);
if (register_socket_transport(std::move(fd), std::move(serial), adb_port, 1,
[](atransport*) { return ReconnectResult::Abort; })) {
if (register_socket_transport(
std::move(fd), std::move(serial), adb_port, 1,
[](atransport*) { return ReconnectResult::Abort; }, false)) {
return 0;
}
}
@ -271,8 +273,9 @@ void server_socket_thread(std::function<unique_fd(std::string_view, std::string*
std::string serial = android::base::StringPrintf("host-%d", fd.get());
// We don't care about port value in "register_socket_transport" as it is used
// only from ADB_HOST. "server_socket_thread" is never called from ADB_HOST.
register_socket_transport(std::move(fd), std::move(serial), 0, 1,
[](atransport*) { return ReconnectResult::Abort; });
register_socket_transport(
std::move(fd), std::move(serial), 0, 1,
[](atransport*) { return ReconnectResult::Abort; }, false);
}
}
D("transport: server_socket_thread() exiting");
@ -365,7 +368,7 @@ int init_socket_transport(atransport* t, unique_fd fd, int adb_port, int local)
if (local) {
auto emulator_connection = std::make_unique<EmulatorConnection>(std::move(fd), adb_port);
t->SetConnection(
std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection)));
std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection)));
std::lock_guard<std::mutex> lock(local_transports_lock);
atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port);
if (existing_transport != nullptr) {

View file

@ -171,6 +171,12 @@ bool UsbConnection::Write(apacket* packet) {
return true;
}
bool UsbConnection::DoTlsHandshake(RSA* key, std::string* auth_key) {
// TODO: support TLS for usb connections
LOG(FATAL) << "Not supported yet.";
return false;
}
void UsbConnection::Reset() {
usb_reset(handle_);
usb_kick(handle_);