fastboot: get rid of manual transport memory management

Existing code has transport memory leaks. Use smart pointers
for transport to get rid of those cases and manual memory
management

Test: atest fastboot_test
Test: manually checked transport isn't leaking anymore
Bug: 296629925
Change-Id: Ifdf162d5084f61ae5c1d2b56a897464af58100da
Signed-off-by: Dmitrii Merkurev <dimorinny@google.com>
This commit is contained in:
Dmitrii Merkurev 2023-09-03 17:30:46 +01:00
parent b5f51166e7
commit 0b627d92c4
10 changed files with 92 additions and 77 deletions

View file

@ -350,23 +350,22 @@ Result<NetworkSerial, FastbootError> ParseNetworkSerial(const std::string& seria
//
// The returned Transport is a singleton, so multiple calls to this function will return the same
// object, and the caller should not attempt to delete the returned Transport.
static Transport* open_device(const char* local_serial, bool wait_for_device = true,
bool announce = true) {
static std::unique_ptr<Transport> open_device(const char* local_serial,
bool wait_for_device = true,
bool announce = true) {
const Result<NetworkSerial, FastbootError> network_serial = ParseNetworkSerial(local_serial);
Transport* transport = nullptr;
std::unique_ptr<Transport> transport;
while (true) {
if (network_serial.ok()) {
std::string error;
if (network_serial->protocol == Socket::Protocol::kTcp) {
transport = tcp::Connect(network_serial->address, network_serial->port, &error)
.release();
transport = tcp::Connect(network_serial->address, network_serial->port, &error);
} else if (network_serial->protocol == Socket::Protocol::kUdp) {
transport = udp::Connect(network_serial->address, network_serial->port, &error)
.release();
transport = udp::Connect(network_serial->address, network_serial->port, &error);
}
if (transport == nullptr && announce) {
if (!transport && announce) {
LOG(ERROR) << "error: " << error;
}
} else if (network_serial.error().code() ==
@ -378,12 +377,12 @@ static Transport* open_device(const char* local_serial, bool wait_for_device = t
Expect(network_serial);
}
if (transport != nullptr) {
if (transport) {
return transport;
}
if (!wait_for_device) {
return nullptr;
return transport;
}
if (announce) {
@ -394,9 +393,9 @@ static Transport* open_device(const char* local_serial, bool wait_for_device = t
}
}
static Transport* NetworkDeviceConnected(bool print = false) {
Transport* transport = nullptr;
Transport* result = nullptr;
static std::unique_ptr<Transport> NetworkDeviceConnected(bool print = false) {
std::unique_ptr<Transport> transport;
std::unique_ptr<Transport> result;
ConnectedDevicesStorage storage;
std::set<std::string> devices;
@ -409,11 +408,11 @@ static Transport* NetworkDeviceConnected(bool print = false) {
transport = open_device(device.c_str(), false, false);
if (print) {
PrintDevice(device.c_str(), transport == nullptr ? "offline" : "fastboot");
PrintDevice(device.c_str(), transport ? "offline" : "fastboot");
}
if (transport != nullptr) {
result = transport;
if (transport) {
result = std::move(transport);
}
}
@ -431,21 +430,21 @@ static Transport* NetworkDeviceConnected(bool print = false) {
//
// The returned Transport is a singleton, so multiple calls to this function will return the same
// object, and the caller should not attempt to delete the returned Transport.
static Transport* open_device() {
static std::unique_ptr<Transport> open_device() {
if (serial != nullptr) {
return open_device(serial);
}
bool announce = true;
Transport* transport = nullptr;
std::unique_ptr<Transport> transport;
while (true) {
transport = usb_open(match_fastboot(nullptr));
if (transport != nullptr) {
if (transport) {
return transport;
}
transport = NetworkDeviceConnected();
if (transport != nullptr) {
if (transport) {
return transport;
}
@ -455,6 +454,8 @@ static Transport* open_device() {
}
std::this_thread::sleep_for(std::chrono::seconds(1));
}
return transport;
}
static int Connect(int argc, char* argv[]) {
@ -466,8 +467,7 @@ static int Connect(int argc, char* argv[]) {
const char* local_serial = *argv;
Expect(ParseNetworkSerial(local_serial));
const Transport* transport = open_device(local_serial, false);
if (transport == nullptr) {
if (!open_device(local_serial, false)) {
return 1;
}
@ -531,6 +531,7 @@ static void list_devices() {
usb_open(list_devices_callback);
NetworkDeviceConnected(/* print */ true);
}
void syntax_error(const char* fmt, ...) {
fprintf(stderr, "fastboot: usage: ");
@ -1547,9 +1548,7 @@ bool is_userspace_fastboot() {
void reboot_to_userspace_fastboot() {
fb->RebootTo("fastboot");
auto* old_transport = fb->set_transport(nullptr);
delete old_transport;
fb->set_transport(nullptr);
// Give the current connection time to close.
std::this_thread::sleep_for(std::chrono::seconds(1));
@ -2377,8 +2376,8 @@ int FastBootTool::Main(int argc, char* argv[]) {
return show_help();
}
Transport* transport = open_device();
if (transport == nullptr) {
std::unique_ptr<Transport> transport = open_device();
if (!transport) {
return 1;
}
fastboot::DriverCallbacks driver_callbacks = {
@ -2388,7 +2387,7 @@ int FastBootTool::Main(int argc, char* argv[]) {
.text = TextMessage,
};
fastboot::FastBootDriver fastboot_driver(transport, driver_callbacks, false);
fastboot::FastBootDriver fastboot_driver(std::move(transport), driver_callbacks, false);
fb = &fastboot_driver;
fp->fb = &fastboot_driver;
@ -2633,9 +2632,6 @@ int FastBootTool::Main(int argc, char* argv[]) {
}
fprintf(stderr, "Finished. Total time: %.3fs\n", (now() - start));
auto* old_transport = fb->set_transport(nullptr);
delete old_transport;
return 0;
}

View file

@ -58,9 +58,10 @@ using namespace android::storage_literals;
namespace fastboot {
/*************************** PUBLIC *******************************/
FastBootDriver::FastBootDriver(Transport* transport, DriverCallbacks driver_callbacks,
FastBootDriver::FastBootDriver(std::unique_ptr<Transport> transport,
DriverCallbacks driver_callbacks,
bool no_checks)
: transport_(transport),
: transport_(std::move(transport)),
prolog_(std::move(driver_callbacks.prolog)),
epilog_(std::move(driver_callbacks.epilog)),
info_(std::move(driver_callbacks.info)),
@ -627,9 +628,8 @@ int FastBootDriver::SparseWriteCallback(std::vector<char>& tpbuf, const char* da
return 0;
}
Transport* FastBootDriver::set_transport(Transport* transport) {
std::swap(transport_, transport);
return transport;
void FastBootDriver::set_transport(std::unique_ptr<Transport> transport) {
transport_ = std::move(transport);
}
} // End namespace fastboot

View file

@ -30,6 +30,7 @@
#include <deque>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <vector>
@ -63,7 +64,7 @@ class FastBootDriver : public IFastBootDriver {
static constexpr uint32_t MAX_DOWNLOAD_SIZE = std::numeric_limits<uint32_t>::max();
static constexpr size_t TRANSPORT_CHUNK_SIZE = 1024;
FastBootDriver(Transport* transport, DriverCallbacks driver_callbacks = {},
FastBootDriver(std::unique_ptr<Transport> transport, DriverCallbacks driver_callbacks = {},
bool no_checks = false);
~FastBootDriver();
@ -124,9 +125,7 @@ class FastBootDriver : public IFastBootDriver {
std::string Error();
RetCode WaitForDisconnect() override;
// Note: set_transport will return the previous transport.
Transport* set_transport(Transport* transport);
Transport* transport() const { return transport_; }
void set_transport(std::unique_ptr<Transport> transport);
RetCode RawCommand(const std::string& cmd, const std::string& message,
std::string* response = nullptr, std::vector<std::string>* info = nullptr,
@ -143,7 +142,7 @@ class FastBootDriver : public IFastBootDriver {
std::string ErrnoStr(const std::string& msg);
Transport* transport_;
std::unique_ptr<Transport> transport_;
private:
RetCode SendBuffer(android::base::borrowed_fd fd, size_t size);

View file

@ -16,6 +16,7 @@
#include "fastboot_driver.h"
#include <memory>
#include <optional>
#include <gtest/gtest.h>
@ -30,13 +31,14 @@ class DriverTest : public ::testing::Test {
};
TEST_F(DriverTest, GetVar) {
MockTransport transport;
FastBootDriver driver(&transport);
std::unique_ptr<MockTransport> transport_pointer = std::make_unique<MockTransport>();
MockTransport* transport = transport_pointer.get();
FastBootDriver driver(std::move(transport_pointer));
EXPECT_CALL(transport, Write(_, _))
EXPECT_CALL(*transport, Write(_, _))
.With(AllArgs(RawData("getvar:version")))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY0.4")));
EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY0.4")));
std::string output;
ASSERT_EQ(driver.GetVar("version", &output), SUCCESS) << driver.Error();
@ -44,14 +46,15 @@ TEST_F(DriverTest, GetVar) {
}
TEST_F(DriverTest, InfoMessage) {
MockTransport transport;
FastBootDriver driver(&transport);
std::unique_ptr<MockTransport> transport_pointer = std::make_unique<MockTransport>();
MockTransport* transport = transport_pointer.get();
FastBootDriver driver(std::move(transport_pointer));
EXPECT_CALL(transport, Write(_, _))
EXPECT_CALL(*transport, Write(_, _))
.With(AllArgs(RawData("oem dmesg")))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("INFOthis is an info line")));
EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY")));
EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("INFOthis is an info line")));
EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY")));
std::vector<std::string> info;
ASSERT_EQ(driver.RawCommand("oem dmesg", "", nullptr, &info), SUCCESS) << driver.Error();
@ -60,28 +63,29 @@ TEST_F(DriverTest, InfoMessage) {
}
TEST_F(DriverTest, TextMessage) {
MockTransport transport;
std::string text;
std::unique_ptr<MockTransport> transport_pointer = std::make_unique<MockTransport>();
MockTransport* transport = transport_pointer.get();
DriverCallbacks callbacks{[](const std::string&) {}, [](int) {}, [](const std::string&) {},
[&text](const std::string& extra_text) { text += extra_text; }};
FastBootDriver driver(&transport, callbacks);
FastBootDriver driver(std::move(transport_pointer), callbacks);
EXPECT_CALL(transport, Write(_, _))
EXPECT_CALL(*transport, Write(_, _))
.With(AllArgs(RawData("oem trusty runtest trusty.hwaes.bench")))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("TEXTthis is a text line")));
EXPECT_CALL(transport, Read(_, _))
EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("TEXTthis is a text line")));
EXPECT_CALL(*transport, Read(_, _))
.WillOnce(Invoke(
CopyData("TEXT, albeit very long and split over multiple TEXT messages.")));
EXPECT_CALL(transport, Read(_, _))
EXPECT_CALL(*transport, Read(_, _))
.WillOnce(Invoke(CopyData("TEXT Indeed we can do that now with a TEXT message whenever "
"we feel like it.")));
EXPECT_CALL(transport, Read(_, _))
EXPECT_CALL(*transport, Read(_, _))
.WillOnce(Invoke(CopyData("TEXT Isn't that truly super cool?")));
EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY")));
EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY")));
std::vector<std::string> info;
ASSERT_EQ(driver.RawCommand("oem trusty runtest trusty.hwaes.bench", "", nullptr, &info),

View file

@ -128,7 +128,7 @@ void FastBootTest::SetUp() {
return MatchFastboot(info, device_serial);
};
for (int i = 0; i < MAX_USB_TRIES && !transport; i++) {
std::unique_ptr<UsbTransport> usb(usb_open(matcher, USB_TIMEOUT));
std::unique_ptr<UsbTransport> usb = usb_open(matcher, USB_TIMEOUT);
if (usb)
transport = std::unique_ptr<TransportSniffer>(
new TransportSniffer(std::move(usb), serial_port));
@ -143,7 +143,7 @@ void FastBootTest::SetUp() {
} else {
ASSERT_EQ(device_path, cb_scratch); // The path can not change
}
fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(transport.get(), {}, true));
fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(std::move(transport), {}, true));
// No error checking since non-A/B devices may not support the command
fb->GetVar("current-slot", &initial_slot);
}
@ -200,7 +200,7 @@ void FastBootTest::ReconnectFastbootDevice() {
if (IsFastbootOverTcp()) {
ConnectTcpFastbootDevice();
device_path = cb_scratch;
fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(transport.get(), {}, true));
fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(std::move(transport), {}, true));
return;
}
@ -212,7 +212,7 @@ void FastBootTest::ReconnectFastbootDevice() {
return MatchFastboot(info, device_serial);
};
while (!transport) {
std::unique_ptr<UsbTransport> usb(usb_open(matcher, USB_TIMEOUT));
std::unique_ptr<UsbTransport> usb = usb_open(matcher, USB_TIMEOUT);
if (usb) {
transport = std::unique_ptr<TransportSniffer>(
new TransportSniffer(std::move(usb), serial_port));
@ -220,7 +220,7 @@ void FastBootTest::ReconnectFastbootDevice() {
std::this_thread::sleep_for(1s);
}
device_path = cb_scratch;
fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(transport.get(), {}, true));
fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(std::move(transport), {}, true));
}
void FastBootTest::SetLockState(bool unlock, bool assert_change) {

View file

@ -166,16 +166,15 @@ TEST(USBFunctionality, USBConnect) {
const auto matcher = [](usb_ifc_info* info) -> int {
return FastBootTest::MatchFastboot(info, fastboot::FastBootTest::device_serial);
};
Transport* transport = nullptr;
std::unique_ptr<Transport> transport;
for (int i = 0; i < FastBootTest::MAX_USB_TRIES && !transport; i++) {
transport = usb_open(matcher);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
ASSERT_NE(transport, nullptr) << "Could not find the fastboot device after: "
<< 10 * FastBootTest::MAX_USB_TRIES << "ms";
ASSERT_NE(transport.get(), nullptr) << "Could not find the fastboot device after: "
<< 10 * FastBootTest::MAX_USB_TRIES << "ms";
if (transport) {
transport->Close();
delete transport;
}
}
@ -1897,7 +1896,7 @@ int main(int argc, char** argv) {
const auto matcher = [](usb_ifc_info* info) -> int {
return fastboot::FastBootTest::MatchFastboot(info, fastboot::FastBootTest::device_serial);
};
Transport* transport = nullptr;
std::unique_ptr<Transport> transport;
while (!transport) {
transport = usb_open(matcher);
std::this_thread::sleep_for(std::chrono::milliseconds(10));

View file

@ -29,6 +29,7 @@
#pragma once
#include <functional>
#include <memory>
#include "transport.h"
@ -66,4 +67,4 @@ class UsbTransport : public Transport {
typedef std::function<int(usb_ifc_info*)> ifc_match_func;
// 0 is non blocking
UsbTransport* usb_open(ifc_match_func callback, uint32_t timeout_ms = 0);
std::unique_ptr<UsbTransport> usb_open(ifc_match_func callback, uint32_t timeout_ms = 0);

View file

@ -503,9 +503,15 @@ int LinuxUsbTransport::Reset() {
return 0;
}
UsbTransport* usb_open(ifc_match_func callback, uint32_t timeout_ms) {
std::unique_ptr<UsbTransport> usb_open(ifc_match_func callback, uint32_t timeout_ms) {
std::unique_ptr<UsbTransport> result;
std::unique_ptr<usb_handle> handle = find_usb_device("/sys/bus/usb/devices", callback);
return handle ? new LinuxUsbTransport(std::move(handle), timeout_ms) : nullptr;
if (handle) {
result = std::make_unique<LinuxUsbTransport>(std::move(handle), timeout_ms);
}
return result;
}
/* Wait for the system to notice the device is gone, so that a subsequent

View file

@ -469,16 +469,20 @@ static int init_usb(ifc_match_func callback, std::unique_ptr<usb_handle>* handle
/*
* Definitions of this file's public functions.
*/
UsbTransport* usb_open(ifc_match_func callback, uint32_t timeout_ms) {
std::unique_ptr<UsbTransport> usb_open(ifc_match_func callback, uint32_t timeout_ms) {
std::unique_ptr<UsbTransport> result;
std::unique_ptr<usb_handle> handle;
if (init_usb(callback, &handle) < 0) {
/* Something went wrong initializing USB. */
return nullptr;
return result;
}
return new OsxUsbTransport(std::move(handle), timeout_ms);
if (handle) {
result = std::make_unique<OsxUsbTransport>(std::move(handle), timeout_ms);
}
return result;
}
OsxUsbTransport::~OsxUsbTransport() {

View file

@ -381,7 +381,13 @@ static std::unique_ptr<usb_handle> find_usb_device(ifc_match_func callback) {
return handle;
}
UsbTransport* usb_open(ifc_match_func callback, uint32_t) {
std::unique_ptr<UsbTransport> usb_open(ifc_match_func callback, uint32_t) {
std::unique_ptr<UsbTransport> result;
std::unique_ptr<usb_handle> handle = find_usb_device(callback);
return handle ? new WindowsUsbTransport(std::move(handle)) : nullptr;
if (handle) {
result = std::make_unique<WindowsUsbTransport>(std::move(handle));
}
return result;
}