Merge "Some minor fixes to libadb_tls_connection."
This commit is contained in:
commit
8f04b0ca58
4 changed files with 131 additions and 127 deletions
|
@ -32,13 +32,13 @@ namespace {
|
|||
// CA issuer identifier to distinguished embedded keys. Also has version
|
||||
// information appended to the end of the string (e.g. "AdbKey-0").
|
||||
static constexpr int kAdbKeyIdentifierNid = NID_organizationName;
|
||||
static constexpr char kAdbKeyIdentifierPrefix[] = "AdbKey-";
|
||||
static constexpr int kAdbKeyVersion = 0;
|
||||
static constexpr char kAdbKeyIdentifierV0[] = "AdbKey-0";
|
||||
|
||||
// Where we store the actual data
|
||||
static constexpr int kAdbKeyValueNid = NID_commonName;
|
||||
|
||||
// TODO: Remove this once X509_NAME_add_entry_by_NID is fixed to use const unsigned char*
|
||||
// https://boringssl-review.googlesource.com/c/boringssl/+/39764
|
||||
int X509_NAME_add_entry_by_NID_const(X509_NAME* name, int nid, int type, const unsigned char* bytes,
|
||||
int len, int loc, int set) {
|
||||
return X509_NAME_add_entry_by_NID(name, nid, type, const_cast<unsigned char*>(bytes), len, loc,
|
||||
|
@ -55,13 +55,13 @@ std::optional<std::string> GetX509NameTextByNid(X509_NAME* name, int nid) {
|
|||
// |len| is the len of the text excluding the final null
|
||||
int len = X509_NAME_get_text_by_NID(name, nid, nullptr, -1);
|
||||
if (len <= 0) {
|
||||
return {};
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Include the space for the final null byte
|
||||
std::vector<char> buf(len + 1, '\0');
|
||||
CHECK(X509_NAME_get_text_by_NID(name, nid, buf.data(), buf.size()));
|
||||
return buf.data();
|
||||
return std::make_optional(std::string(buf.data()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -73,8 +73,7 @@ bssl::UniquePtr<X509_NAME> CreateCAIssuerFromEncodedKey(std::string_view key) {
|
|||
// "O=AdbKey-0;CN=<key>;"
|
||||
CHECK(!key.empty());
|
||||
|
||||
std::string identifier = kAdbKeyIdentifierPrefix;
|
||||
identifier += std::to_string(kAdbKeyVersion);
|
||||
std::string identifier = kAdbKeyIdentifierV0;
|
||||
bssl::UniquePtr<X509_NAME> name(X509_NAME_new());
|
||||
CHECK(X509_NAME_add_entry_by_NID_const(name.get(), kAdbKeyIdentifierNid, MBSTRING_ASC,
|
||||
reinterpret_cast<const uint8_t*>(identifier.data()),
|
||||
|
@ -91,27 +90,34 @@ std::optional<std::string> ParseEncodedKeyFromCAIssuer(X509_NAME* issuer) {
|
|||
CHECK(issuer);
|
||||
|
||||
auto buf = GetX509NameTextByNid(issuer, kAdbKeyIdentifierNid);
|
||||
if (!buf || !android::base::StartsWith(*buf, kAdbKeyIdentifierPrefix)) {
|
||||
return {};
|
||||
if (!buf) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return GetX509NameTextByNid(issuer, kAdbKeyValueNid);
|
||||
// Check for supported versions
|
||||
if (*buf == kAdbKeyIdentifierV0) {
|
||||
return GetX509NameTextByNid(issuer, kAdbKeyValueNid);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::string SHA256BitsToHexString(std::string_view sha256) {
|
||||
CHECK_EQ(sha256.size(), static_cast<size_t>(SHA256_DIGEST_LENGTH));
|
||||
std::stringstream ss;
|
||||
auto* u8 = reinterpret_cast<const uint8_t*>(sha256.data());
|
||||
ss << std::uppercase << std::setfill('0') << std::hex;
|
||||
// Convert to hex-string representation
|
||||
for (size_t i = 0; i < SHA256_DIGEST_LENGTH; ++i) {
|
||||
ss << std::setw(2) << (0x00FF & sha256[i]);
|
||||
// Need to cast to something bigger than one byte, or
|
||||
// stringstream will interpret it as a char value.
|
||||
ss << std::setw(2) << static_cast<uint16_t>(u8[i]);
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::optional<std::string> SHA256HexStringToBits(std::string_view sha256_str) {
|
||||
if (sha256_str.size() != SHA256_DIGEST_LENGTH * 2) {
|
||||
return {};
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::string result;
|
||||
|
@ -119,7 +125,7 @@ std::optional<std::string> SHA256HexStringToBits(std::string_view sha256_str) {
|
|||
auto bytestr = std::string(sha256_str.substr(i * 2, 2));
|
||||
if (!IsHexDigit(bytestr[0]) || !IsHexDigit(bytestr[1])) {
|
||||
LOG(ERROR) << "SHA256 string has invalid non-hex chars";
|
||||
return {};
|
||||
return std::nullopt;
|
||||
}
|
||||
result += static_cast<char>(std::stol(bytestr, nullptr, 16));
|
||||
}
|
||||
|
|
|
@ -55,16 +55,15 @@ class TlsConnection {
|
|||
|
||||
// Adds a trusted certificate to the list for the SSL connection.
|
||||
// During the handshake phase, it will check the list of trusted certificates.
|
||||
// The connection will fail if the peer's certificate is not in the list. Use
|
||||
// |EnableCertificateVerification(false)| to disable certificate
|
||||
// verification.
|
||||
// The connection will fail if the peer's certificate is not in the list. If
|
||||
// you would like to accept any certificate, use #SetCertVerifyCallback and
|
||||
// set your callback to always return 1.
|
||||
//
|
||||
// Returns true if |cert| was successfully added, false otherwise.
|
||||
virtual bool AddTrustedCertificate(std::string_view cert) = 0;
|
||||
|
||||
// Sets a custom certificate verify callback. |cb| must return 1 if the
|
||||
// certificate is trusted. Otherwise, return 0 if not. Note that |cb| is
|
||||
// only used if EnableCertificateVerification(false).
|
||||
// certificate is trusted. Otherwise, return 0 if not.
|
||||
virtual void SetCertVerifyCallback(CertVerifyCb cb) = 0;
|
||||
|
||||
// Configures a client |ca_list| that the server sends to the client in the
|
||||
|
|
|
@ -199,24 +199,10 @@ using CAIssuer = std::vector<CAIssuerField>;
|
|||
static std::vector<CAIssuer> kCAIssuers = {
|
||||
{
|
||||
{NID_commonName, {'a', 'b', 'c', 'd', 'e'}},
|
||||
{NID_organizationName,
|
||||
{
|
||||
'd',
|
||||
'e',
|
||||
'f',
|
||||
'g',
|
||||
}},
|
||||
{NID_organizationName, {'d', 'e', 'f', 'g'}},
|
||||
},
|
||||
{
|
||||
{NID_commonName,
|
||||
{
|
||||
'h',
|
||||
'i',
|
||||
'j',
|
||||
'k',
|
||||
'l',
|
||||
'm',
|
||||
}},
|
||||
{NID_commonName, {'h', 'i', 'j', 'k', 'l', 'm'}},
|
||||
{NID_countryName, {'n', 'o'}},
|
||||
},
|
||||
};
|
||||
|
@ -224,8 +210,6 @@ static std::vector<CAIssuer> kCAIssuers = {
|
|||
class AdbWifiTlsConnectionTest : public testing::Test {
|
||||
protected:
|
||||
virtual void SetUp() override {
|
||||
// TODO: move client code in each test into its own thread, as the
|
||||
// socket pair buffer is limited.
|
||||
android::base::Socketpair(SOCK_STREAM, &server_fd_, &client_fd_);
|
||||
server_ = TlsConnection::Create(TlsConnection::Role::Server, kTestRsa2048ServerCert,
|
||||
kTestRsa2048ServerPrivKey, server_fd_);
|
||||
|
@ -257,14 +241,8 @@ class AdbWifiTlsConnectionTest : public testing::Test {
|
|||
return ret;
|
||||
}
|
||||
|
||||
void StartClientHandshakeAsync(bool expect_success) {
|
||||
client_thread_ = std::thread([=]() {
|
||||
if (expect_success) {
|
||||
EXPECT_EQ(client_->DoHandshake(), TlsError::Success);
|
||||
} else {
|
||||
EXPECT_NE(client_->DoHandshake(), TlsError::Success);
|
||||
}
|
||||
});
|
||||
void StartClientHandshakeAsync(TlsError expected) {
|
||||
client_thread_ = std::thread([=]() { EXPECT_EQ(client_->DoHandshake(), expected); });
|
||||
}
|
||||
|
||||
void WaitForClientConnection() {
|
||||
|
@ -313,45 +291,52 @@ TEST_F(AdbWifiTlsConnectionTest, NoCertificateVerification) {
|
|||
// Allow any certificate
|
||||
server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
|
||||
client_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
|
||||
StartClientHandshakeAsync(true);
|
||||
StartClientHandshakeAsync(TlsError::Success);
|
||||
|
||||
// Handshake should succeed
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
WaitForClientConnection();
|
||||
|
||||
// Client write, server read
|
||||
EXPECT_TRUE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
// Test client/server read and writes
|
||||
client_thread_ = std::thread([&]() {
|
||||
EXPECT_TRUE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
// Try with overloaded ReadFully
|
||||
std::vector<uint8_t> buf(msg_.size());
|
||||
ASSERT_TRUE(client_->ReadFully(buf.data(), msg_.size()));
|
||||
EXPECT_EQ(buf, msg_);
|
||||
});
|
||||
|
||||
auto data = server_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data, msg_);
|
||||
|
||||
// Client read, server write
|
||||
EXPECT_TRUE(server_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
// Try with overloaded ReadFully
|
||||
std::vector<uint8_t> buf(msg_.size());
|
||||
ASSERT_TRUE(client_->ReadFully(buf.data(), msg_.size()));
|
||||
EXPECT_EQ(buf, msg_);
|
||||
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
TEST_F(AdbWifiTlsConnectionTest, NoTrustedCertificates) {
|
||||
StartClientHandshakeAsync(false);
|
||||
StartClientHandshakeAsync(TlsError::CertificateRejected);
|
||||
|
||||
// Handshake should not succeed
|
||||
EXPECT_NE(server_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
|
||||
WaitForClientConnection();
|
||||
|
||||
// Client write, server read should fail
|
||||
EXPECT_FALSE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
// All writes and reads should fail
|
||||
client_thread_ = std::thread([&]() {
|
||||
// Client write, server read should fail
|
||||
EXPECT_FALSE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
auto data = client_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
});
|
||||
|
||||
auto data = server_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
|
||||
// Client read, server write should fail
|
||||
EXPECT_FALSE(server_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
data = client_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
TEST_F(AdbWifiTlsConnectionTest, AddTrustedCertificates) {
|
||||
|
@ -359,23 +344,26 @@ TEST_F(AdbWifiTlsConnectionTest, AddTrustedCertificates) {
|
|||
EXPECT_TRUE(client_->AddTrustedCertificate(kTestRsa2048ServerCert));
|
||||
EXPECT_TRUE(server_->AddTrustedCertificate(kTestRsa2048ClientCert));
|
||||
|
||||
StartClientHandshakeAsync(true);
|
||||
StartClientHandshakeAsync(TlsError::Success);
|
||||
|
||||
// Handshake should succeed
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
WaitForClientConnection();
|
||||
|
||||
// Client write, server read
|
||||
EXPECT_TRUE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
// All read writes should succeed
|
||||
client_thread_ = std::thread([&]() {
|
||||
EXPECT_TRUE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
auto data = client_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data, msg_);
|
||||
});
|
||||
|
||||
auto data = server_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data, msg_);
|
||||
|
||||
// Client read, server write
|
||||
EXPECT_TRUE(server_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
data = client_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data, msg_);
|
||||
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
TEST_F(AdbWifiTlsConnectionTest, AddTrustedCertificates_ClientWrongCert) {
|
||||
|
@ -387,23 +375,26 @@ TEST_F(AdbWifiTlsConnectionTest, AddTrustedCertificates_ClientWrongCert) {
|
|||
// Without enabling EnableClientPostHandshakeCheck(), DoHandshake() will
|
||||
// succeed, because in TLS 1.3, the client doesn't get notified if the
|
||||
// server rejected the certificate until a read operation is called.
|
||||
StartClientHandshakeAsync(true);
|
||||
StartClientHandshakeAsync(TlsError::Success);
|
||||
|
||||
// Handshake should fail for server, succeed for client
|
||||
EXPECT_NE(server_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
|
||||
WaitForClientConnection();
|
||||
|
||||
// Client write succeeds, server read should fail
|
||||
EXPECT_TRUE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
// Client writes will succeed, everything else will fail.
|
||||
client_thread_ = std::thread([&]() {
|
||||
EXPECT_TRUE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
auto data = client_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
});
|
||||
|
||||
auto data = server_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
|
||||
// Client read, server write should fail
|
||||
EXPECT_FALSE(server_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
data = client_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
TEST_F(AdbWifiTlsConnectionTest, ExportKeyingMaterial) {
|
||||
|
@ -415,10 +406,10 @@ TEST_F(AdbWifiTlsConnectionTest, ExportKeyingMaterial) {
|
|||
EXPECT_TRUE(client_->AddTrustedCertificate(kTestRsa2048ServerCert));
|
||||
EXPECT_TRUE(server_->AddTrustedCertificate(kTestRsa2048ClientCert));
|
||||
|
||||
StartClientHandshakeAsync(true);
|
||||
StartClientHandshakeAsync(TlsError::Success);
|
||||
|
||||
// Handshake should succeed
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
WaitForClientConnection();
|
||||
|
||||
// Verify the client and server's exported key material match.
|
||||
|
@ -439,10 +430,10 @@ TEST_F(AdbWifiTlsConnectionTest, SetCertVerifyCallback_ClientAcceptsServerReject
|
|||
// Client handshake should succeed, because in TLS 1.3, client does not
|
||||
// realize that the peer rejected the certificate until after a read
|
||||
// operation.
|
||||
client_thread_ = std::thread([&]() { EXPECT_EQ(client_->DoHandshake(), TlsError::Success); });
|
||||
StartClientHandshakeAsync(TlsError::Success);
|
||||
|
||||
// Server handshake should fail
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
|
@ -455,11 +446,10 @@ TEST_F(AdbWifiTlsConnectionTest, SetCertVerifyCallback_ClientAcceptsServerReject
|
|||
server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 0; });
|
||||
|
||||
// Client handshake should fail because server rejects everything
|
||||
client_thread_ = std::thread(
|
||||
[&]() { EXPECT_EQ(client_->DoHandshake(), TlsError::PeerRejectedCertificate); });
|
||||
StartClientHandshakeAsync(TlsError::PeerRejectedCertificate);
|
||||
|
||||
// Server handshake should fail
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
|
@ -469,11 +459,10 @@ TEST_F(AdbWifiTlsConnectionTest, SetCertVerifyCallback_ClientRejectsServerAccept
|
|||
// Server accepts all
|
||||
server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
|
||||
// Client handshake should fail
|
||||
client_thread_ = std::thread(
|
||||
[&]() { EXPECT_EQ(client_->DoHandshake(), TlsError::CertificateRejected); });
|
||||
StartClientHandshakeAsync(TlsError::CertificateRejected);
|
||||
|
||||
// Server handshake should fail
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
|
@ -488,15 +477,15 @@ TEST_F(AdbWifiTlsConnectionTest, SetCertVerifyCallback_ClientRejectsServerAccept
|
|||
server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
|
||||
|
||||
// Client handshake should fail
|
||||
client_thread_ = std::thread(
|
||||
[&]() { EXPECT_EQ(client_->DoHandshake(), TlsError::CertificateRejected); });
|
||||
StartClientHandshakeAsync(TlsError::CertificateRejected);
|
||||
|
||||
// Server handshake should fail
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
TEST_F(AdbWifiTlsConnectionTest, EnableClientPostHandshakeCheck_ClientWrongCert) {
|
||||
client_->AddTrustedCertificate(kTestRsa2048ServerCert);
|
||||
// client's DoHandshake() will fail if the server rejected the certificate
|
||||
client_->EnableClientPostHandshakeCheck(true);
|
||||
|
||||
|
@ -504,23 +493,26 @@ TEST_F(AdbWifiTlsConnectionTest, EnableClientPostHandshakeCheck_ClientWrongCert)
|
|||
EXPECT_TRUE(server_->AddTrustedCertificate(kTestRsa2048UnknownCert));
|
||||
|
||||
// Handshake should fail for client
|
||||
StartClientHandshakeAsync(false);
|
||||
StartClientHandshakeAsync(TlsError::PeerRejectedCertificate);
|
||||
|
||||
// Handshake should fail for server
|
||||
EXPECT_NE(server_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
|
||||
WaitForClientConnection();
|
||||
|
||||
// Client write fails, server read should fail
|
||||
EXPECT_FALSE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
// All read writes should fail
|
||||
client_thread_ = std::thread([&]() {
|
||||
EXPECT_FALSE(client_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
auto data = client_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
});
|
||||
|
||||
auto data = server_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
|
||||
// Client read, server write should fail
|
||||
EXPECT_FALSE(server_->WriteFully(
|
||||
std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
|
||||
data = client_->ReadFully(msg_.size());
|
||||
EXPECT_EQ(data.size(), 0);
|
||||
|
||||
WaitForClientConnection();
|
||||
}
|
||||
|
||||
TEST_F(AdbWifiTlsConnectionTest, SetClientCAList_Empty) {
|
||||
|
@ -569,12 +561,12 @@ TEST_F(AdbWifiTlsConnectionTest, SetClientCAList_Smoke) {
|
|||
return 1;
|
||||
});
|
||||
// Client handshake should succeed
|
||||
EXPECT_EQ(client_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(client_->DoHandshake(), TlsError::Success);
|
||||
});
|
||||
|
||||
EXPECT_TRUE(server_->AddTrustedCertificate(kTestRsa2048UnknownCert));
|
||||
// Server handshake should succeed
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
client_thread_.join();
|
||||
}
|
||||
|
||||
|
@ -604,12 +596,12 @@ TEST_F(AdbWifiTlsConnectionTest, SetClientCAList_AdbCAList) {
|
|||
return 1;
|
||||
});
|
||||
// Client handshake should succeed
|
||||
EXPECT_EQ(client_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(client_->DoHandshake(), TlsError::Success);
|
||||
});
|
||||
|
||||
server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
|
||||
// Server handshake should succeed
|
||||
EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
|
||||
client_thread_.join();
|
||||
}
|
||||
} // namespace tls
|
||||
|
|
|
@ -61,6 +61,7 @@ class TlsConnectionImpl : public TlsConnection {
|
|||
static const char* SSLErrorString();
|
||||
void Invalidate();
|
||||
TlsError GetFailureReason(int err);
|
||||
const char* RoleToString() { return role_ == Role::Server ? kServerRoleStr : kClientRoleStr; }
|
||||
|
||||
Role role_;
|
||||
bssl::UniquePtr<EVP_PKEY> priv_key_;
|
||||
|
@ -75,15 +76,19 @@ class TlsConnectionImpl : public TlsConnection {
|
|||
CertVerifyCb cert_verify_cb_;
|
||||
SetCertCb set_cert_cb_;
|
||||
borrowed_fd fd_;
|
||||
static constexpr char kClientRoleStr[] = "[client]: ";
|
||||
static constexpr char kServerRoleStr[] = "[server]: ";
|
||||
}; // TlsConnectionImpl
|
||||
|
||||
TlsConnectionImpl::TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
|
||||
borrowed_fd fd)
|
||||
: role_(role), fd_(fd) {
|
||||
CHECK(!cert.empty() && !priv_key.empty());
|
||||
LOG(INFO) << "Initializing adbwifi TlsConnection";
|
||||
LOG(INFO) << RoleToString() << "Initializing adbwifi TlsConnection";
|
||||
cert_ = BufferFromPEM(cert);
|
||||
CHECK(cert_);
|
||||
priv_key_ = EvpPkeyFromPEM(priv_key);
|
||||
CHECK(priv_key_);
|
||||
}
|
||||
|
||||
TlsConnectionImpl::~TlsConnectionImpl() {
|
||||
|
@ -149,7 +154,7 @@ bool TlsConnectionImpl::AddTrustedCertificate(std::string_view cert) {
|
|||
// Create X509 buffer from the certificate string
|
||||
auto buf = X509FromBuffer(BufferFromPEM(cert));
|
||||
if (buf == nullptr) {
|
||||
LOG(ERROR) << "Failed to create a X509 buffer for the certificate.";
|
||||
LOG(ERROR) << RoleToString() << "Failed to create a X509 buffer for the certificate.";
|
||||
return false;
|
||||
}
|
||||
known_certificates_.push_back(std::move(buf));
|
||||
|
@ -205,8 +210,7 @@ TlsConnection::TlsError TlsConnectionImpl::GetFailureReason(int err) {
|
|||
}
|
||||
|
||||
TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
|
||||
int err = -1;
|
||||
LOG(INFO) << "Starting adbwifi tls handshake";
|
||||
LOG(INFO) << RoleToString() << "Starting adbwifi tls handshake";
|
||||
ssl_ctx_.reset(SSL_CTX_new(TLS_method()));
|
||||
// TODO: Remove set_max_proto_version() once external/boringssl is updated
|
||||
// past
|
||||
|
@ -214,14 +218,14 @@ TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
|
|||
if (ssl_ctx_.get() == nullptr ||
|
||||
!SSL_CTX_set_min_proto_version(ssl_ctx_.get(), TLS1_3_VERSION) ||
|
||||
!SSL_CTX_set_max_proto_version(ssl_ctx_.get(), TLS1_3_VERSION)) {
|
||||
LOG(ERROR) << "Failed to create SSL context";
|
||||
LOG(ERROR) << RoleToString() << "Failed to create SSL context";
|
||||
return TlsError::UnknownFailure;
|
||||
}
|
||||
|
||||
// Register user-supplied known certificates
|
||||
for (auto const& cert : known_certificates_) {
|
||||
if (X509_STORE_add_cert(SSL_CTX_get_cert_store(ssl_ctx_.get()), cert.get()) == 0) {
|
||||
LOG(ERROR) << "Unable to add certificates into the X509_STORE";
|
||||
LOG(ERROR) << RoleToString() << "Unable to add certificates into the X509_STORE";
|
||||
return TlsError::UnknownFailure;
|
||||
}
|
||||
}
|
||||
|
@ -248,7 +252,8 @@ TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
|
|||
};
|
||||
if (!SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_chain.data(), cert_chain.size(),
|
||||
priv_key_.get(), nullptr)) {
|
||||
LOG(ERROR) << "Unable to register the certificate chain file and private key ["
|
||||
LOG(ERROR) << RoleToString()
|
||||
<< "Unable to register the certificate chain file and private key ["
|
||||
<< SSLErrorString() << "]";
|
||||
Invalidate();
|
||||
return TlsError::UnknownFailure;
|
||||
|
@ -259,19 +264,21 @@ TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
|
|||
// Okay! Let's try to do the handshake!
|
||||
ssl_.reset(SSL_new(ssl_ctx_.get()));
|
||||
if (!SSL_set_fd(ssl_.get(), fd_.get())) {
|
||||
LOG(ERROR) << "SSL_set_fd failed. [" << SSLErrorString() << "]";
|
||||
LOG(ERROR) << RoleToString() << "SSL_set_fd failed. [" << SSLErrorString() << "]";
|
||||
return TlsError::UnknownFailure;
|
||||
}
|
||||
|
||||
switch (role_) {
|
||||
case Role::Server:
|
||||
err = SSL_accept(ssl_.get());
|
||||
SSL_set_accept_state(ssl_.get());
|
||||
break;
|
||||
case Role::Client:
|
||||
err = SSL_connect(ssl_.get());
|
||||
SSL_set_connect_state(ssl_.get());
|
||||
break;
|
||||
}
|
||||
if (err != 1) {
|
||||
LOG(ERROR) << "Handshake failed in SSL_accept/SSL_connect [" << SSLErrorString() << "]";
|
||||
if (SSL_do_handshake(ssl_.get()) != 1) {
|
||||
LOG(ERROR) << RoleToString() << "Handshake failed in SSL_accept/SSL_connect ["
|
||||
<< SSLErrorString() << "]";
|
||||
auto sslerr = ERR_get_error();
|
||||
Invalidate();
|
||||
return GetFailureReason(sslerr);
|
||||
|
@ -281,16 +288,16 @@ TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
|
|||
uint8_t check;
|
||||
// Try to peek one byte for any failures. This assumes on success that
|
||||
// the server actually sends something.
|
||||
err = SSL_peek(ssl_.get(), &check, 1);
|
||||
if (err <= 0) {
|
||||
LOG(ERROR) << "Post-handshake SSL_peek failed [" << SSLErrorString() << "]";
|
||||
if (SSL_peek(ssl_.get(), &check, 1) <= 0) {
|
||||
LOG(ERROR) << RoleToString() << "Post-handshake SSL_peek failed [" << SSLErrorString()
|
||||
<< "]";
|
||||
auto sslerr = ERR_get_error();
|
||||
Invalidate();
|
||||
return GetFailureReason(sslerr);
|
||||
}
|
||||
}
|
||||
|
||||
LOG(INFO) << "Handshake succeeded.";
|
||||
LOG(INFO) << RoleToString() << "Handshake succeeded.";
|
||||
return TlsError::Success;
|
||||
}
|
||||
|
||||
|
@ -311,7 +318,7 @@ std::vector<uint8_t> TlsConnectionImpl::ReadFully(size_t size) {
|
|||
bool TlsConnectionImpl::ReadFully(void* buf, size_t size) {
|
||||
CHECK_GT(size, 0U);
|
||||
if (!ssl_) {
|
||||
LOG(ERROR) << "Tried to read on a null SSL connection";
|
||||
LOG(ERROR) << RoleToString() << "Tried to read on a null SSL connection";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -321,7 +328,7 @@ bool TlsConnectionImpl::ReadFully(void* buf, size_t size) {
|
|||
int bytes_read =
|
||||
SSL_read(ssl_.get(), p8 + offset, std::min(static_cast<size_t>(INT_MAX), size));
|
||||
if (bytes_read <= 0) {
|
||||
LOG(WARNING) << "SSL_read failed [" << SSLErrorString() << "]";
|
||||
LOG(ERROR) << RoleToString() << "SSL_read failed [" << SSLErrorString() << "]";
|
||||
return false;
|
||||
}
|
||||
size -= bytes_read;
|
||||
|
@ -333,7 +340,7 @@ bool TlsConnectionImpl::ReadFully(void* buf, size_t size) {
|
|||
bool TlsConnectionImpl::WriteFully(std::string_view data) {
|
||||
CHECK(!data.empty());
|
||||
if (!ssl_) {
|
||||
LOG(ERROR) << "Tried to read on a null SSL connection";
|
||||
LOG(ERROR) << RoleToString() << "Tried to read on a null SSL connection";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -341,7 +348,7 @@ bool TlsConnectionImpl::WriteFully(std::string_view data) {
|
|||
int bytes_out = SSL_write(ssl_.get(), data.data(),
|
||||
std::min(static_cast<size_t>(INT_MAX), data.size()));
|
||||
if (bytes_out <= 0) {
|
||||
LOG(WARNING) << "SSL_write failed [" << SSLErrorString() << "]";
|
||||
LOG(ERROR) << RoleToString() << "SSL_write failed [" << SSLErrorString() << "]";
|
||||
return false;
|
||||
}
|
||||
data = data.substr(bytes_out);
|
||||
|
|
Loading…
Reference in a new issue