Add ReadOrAgain and WriteOrAgain methods to FuseMessage.
These methods return kAgain if operation cannot be done without blocking the current thread. The CL also introduecs new helper function SetupMessageSockets so that FuseMessages are always transfered via sockets that save message boundaries. Bug: 34903085 Test: libappfuse_test Change-Id: I34544372cc1b0c7bc9622e581ae16c018a123fa9
This commit is contained in:
parent
cc9d94ce04
commit
57b780fbc3
5 changed files with 149 additions and 91 deletions
|
@ -23,77 +23,132 @@
|
|||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
|
||||
#include <sys/socket.h>
|
||||
|
||||
#include <android-base/file.h>
|
||||
#include <android-base/logging.h>
|
||||
#include <android-base/macros.h>
|
||||
|
||||
namespace android {
|
||||
namespace fuse {
|
||||
|
||||
static_assert(
|
||||
std::is_standard_layout<FuseBuffer>::value,
|
||||
"FuseBuffer must be standard layout union.");
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
bool FuseMessage<T>::CheckHeaderLength(const char* name) const {
|
||||
const auto& header = static_cast<const T*>(this)->header;
|
||||
if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
|
||||
bool CheckHeaderLength(const FuseMessage<T>* self, const char* name) {
|
||||
const auto& header = static_cast<const T*>(self)->header;
|
||||
if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
|
||||
return true;
|
||||
} else {
|
||||
LOG(ERROR) << "Invalid header length is found in " << name << ": " << header.len;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ResultOrAgain ReadInternal(FuseMessage<T>* self, int fd, int sockflag) {
|
||||
char* const buf = reinterpret_cast<char*>(self);
|
||||
const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(recv(fd, buf, sizeof(T), sockflag))
|
||||
: TEMP_FAILURE_RETRY(read(fd, buf, sizeof(T)));
|
||||
|
||||
switch (result) {
|
||||
case 0:
|
||||
// Expected EOF.
|
||||
return ResultOrAgain::kFailure;
|
||||
case -1:
|
||||
if (errno == EAGAIN) {
|
||||
return ResultOrAgain::kAgain;
|
||||
}
|
||||
PLOG(ERROR) << "Failed to read a FUSE message";
|
||||
return ResultOrAgain::kFailure;
|
||||
}
|
||||
|
||||
const auto& header = static_cast<const T*>(self)->header;
|
||||
if (result < static_cast<ssize_t>(sizeof(header))) {
|
||||
LOG(ERROR) << "Read bytes " << result << " are shorter than header size " << sizeof(header);
|
||||
return ResultOrAgain::kFailure;
|
||||
}
|
||||
|
||||
if (!CheckHeaderLength<T>(self, "Read")) {
|
||||
return ResultOrAgain::kFailure;
|
||||
}
|
||||
|
||||
if (static_cast<uint32_t>(result) != header.len) {
|
||||
LOG(ERROR) << "Read bytes " << result << " are different from header.len " << header.len;
|
||||
return ResultOrAgain::kFailure;
|
||||
}
|
||||
|
||||
return ResultOrAgain::kSuccess;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ResultOrAgain WriteInternal(const FuseMessage<T>* self, int fd, int sockflag) {
|
||||
if (!CheckHeaderLength<T>(self, "Write")) {
|
||||
return ResultOrAgain::kFailure;
|
||||
}
|
||||
|
||||
const char* const buf = reinterpret_cast<const char*>(self);
|
||||
const auto& header = static_cast<const T*>(self)->header;
|
||||
const int result = sockflag ? TEMP_FAILURE_RETRY(send(fd, buf, header.len, sockflag))
|
||||
: TEMP_FAILURE_RETRY(write(fd, buf, header.len));
|
||||
|
||||
if (result == -1) {
|
||||
if (errno == EAGAIN) {
|
||||
return ResultOrAgain::kAgain;
|
||||
}
|
||||
PLOG(ERROR) << "Failed to write a FUSE message";
|
||||
return ResultOrAgain::kFailure;
|
||||
}
|
||||
|
||||
CHECK(static_cast<uint32_t>(result) == header.len);
|
||||
return ResultOrAgain::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
static_assert(std::is_standard_layout<FuseBuffer>::value,
|
||||
"FuseBuffer must be standard layout union.");
|
||||
|
||||
bool SetupMessageSockets(base::unique_fd (*result)[2]) {
|
||||
base::unique_fd fds[2];
|
||||
{
|
||||
int raw_fds[2];
|
||||
if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, raw_fds) == -1) {
|
||||
PLOG(ERROR) << "Failed to create sockets for proxy";
|
||||
return false;
|
||||
}
|
||||
fds[0].reset(raw_fds[0]);
|
||||
fds[1].reset(raw_fds[1]);
|
||||
}
|
||||
|
||||
constexpr int kMaxMessageSize = sizeof(FuseBuffer);
|
||||
if (setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0 ||
|
||||
setsockopt(fds[1], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0) {
|
||||
PLOG(ERROR) << "Failed to update buffer size for socket";
|
||||
return false;
|
||||
}
|
||||
|
||||
(*result)[0] = std::move(fds[0]);
|
||||
(*result)[1] = std::move(fds[1]);
|
||||
return true;
|
||||
} else {
|
||||
LOG(ERROR) << "Invalid header length is found in " << name << ": " <<
|
||||
header.len;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool FuseMessage<T>::Read(int fd) {
|
||||
char* const buf = reinterpret_cast<char*>(this);
|
||||
const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, buf, sizeof(T)));
|
||||
if (result < 0) {
|
||||
PLOG(ERROR) << "Failed to read a FUSE message";
|
||||
return false;
|
||||
}
|
||||
return ReadInternal(this, fd, 0) == ResultOrAgain::kSuccess;
|
||||
}
|
||||
|
||||
const auto& header = static_cast<const T*>(this)->header;
|
||||
if (result < static_cast<ssize_t>(sizeof(header))) {
|
||||
LOG(ERROR) << "Read bytes " << result << " are shorter than header size " <<
|
||||
sizeof(header);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!CheckHeaderLength("Read")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (static_cast<uint32_t>(result) > header.len) {
|
||||
LOG(ERROR) << "Read bytes " << result << " are longer than header.len " <<
|
||||
header.len;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!base::ReadFully(fd, buf + result, header.len - result)) {
|
||||
PLOG(ERROR) << "ReadFully failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
template <typename T>
|
||||
ResultOrAgain FuseMessage<T>::ReadOrAgain(int fd) {
|
||||
return ReadInternal(this, fd, MSG_DONTWAIT);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool FuseMessage<T>::Write(int fd) const {
|
||||
if (!CheckHeaderLength("Write")) {
|
||||
return false;
|
||||
}
|
||||
return WriteInternal(this, fd, 0) == ResultOrAgain::kSuccess;
|
||||
}
|
||||
|
||||
const char* const buf = reinterpret_cast<const char*>(this);
|
||||
const auto& header = static_cast<const T*>(this)->header;
|
||||
if (!base::WriteFully(fd, buf, header.len)) {
|
||||
PLOG(ERROR) << "WriteFully failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
template <typename T>
|
||||
ResultOrAgain FuseMessage<T>::WriteOrAgain(int fd) const {
|
||||
return WriteInternal(this, fd, MSG_DONTWAIT);
|
||||
}
|
||||
|
||||
template class FuseMessage<FuseRequest>;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef ANDROID_LIBAPPFUSE_FUSEBUFFER_H_
|
||||
#define ANDROID_LIBAPPFUSE_FUSEBUFFER_H_
|
||||
|
||||
#include <android-base/unique_fd.h>
|
||||
#include <linux/fuse.h>
|
||||
|
||||
namespace android {
|
||||
|
@ -28,12 +29,24 @@ constexpr size_t kFuseMaxWrite = 256 * 1024;
|
|||
constexpr size_t kFuseMaxRead = 128 * 1024;
|
||||
constexpr int32_t kFuseSuccess = 0;
|
||||
|
||||
// Setup sockets to transfer FuseMessage.
|
||||
bool SetupMessageSockets(base::unique_fd (*sockets)[2]);
|
||||
|
||||
enum class ResultOrAgain {
|
||||
kSuccess,
|
||||
kFailure,
|
||||
kAgain,
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class FuseMessage {
|
||||
public:
|
||||
bool Read(int fd);
|
||||
bool Write(int fd) const;
|
||||
private:
|
||||
ResultOrAgain ReadOrAgain(int fd);
|
||||
ResultOrAgain WriteOrAgain(int fd) const;
|
||||
|
||||
private:
|
||||
bool CheckHeaderLength(const char* name) const;
|
||||
};
|
||||
|
||||
|
@ -54,7 +67,7 @@ struct FuseRequest : public FuseMessage<FuseRequest> {
|
|||
// for FUSE_READ
|
||||
fuse_read_in read_in;
|
||||
// for FUSE_LOOKUP
|
||||
char lookup_name[0];
|
||||
char lookup_name[kFuseMaxWrite];
|
||||
};
|
||||
void Reset(uint32_t data_length, uint32_t opcode, uint64_t unique);
|
||||
};
|
||||
|
|
|
@ -109,10 +109,7 @@ class FuseAppLoopTest : public ::testing::Test {
|
|||
|
||||
void SetUp() override {
|
||||
base::SetMinimumLogSeverity(base::VERBOSE);
|
||||
int sockets[2];
|
||||
ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, sockets));
|
||||
sockets_[0].reset(sockets[0]);
|
||||
sockets_[1].reset(sockets[1]);
|
||||
ASSERT_TRUE(SetupMessageSockets(&sockets_));
|
||||
thread_ = std::thread([this] {
|
||||
StartFuseAppLoop(sockets_[1].release(), &callback_);
|
||||
});
|
||||
|
|
|
@ -50,15 +50,8 @@ class FuseBridgeLoopTest : public ::testing::Test {
|
|||
|
||||
void SetUp() override {
|
||||
base::SetMinimumLogSeverity(base::VERBOSE);
|
||||
int dev_sockets[2];
|
||||
int proxy_sockets[2];
|
||||
ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, dev_sockets));
|
||||
ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, proxy_sockets));
|
||||
dev_sockets_[0].reset(dev_sockets[0]);
|
||||
dev_sockets_[1].reset(dev_sockets[1]);
|
||||
proxy_sockets_[0].reset(proxy_sockets[0]);
|
||||
proxy_sockets_[1].reset(proxy_sockets[1]);
|
||||
|
||||
ASSERT_TRUE(SetupMessageSockets(&dev_sockets_));
|
||||
ASSERT_TRUE(SetupMessageSockets(&proxy_sockets_));
|
||||
thread_ = std::thread([this] {
|
||||
StartFuseBridgeLoop(
|
||||
dev_sockets_[1].release(), proxy_sockets_[0].release(), &callback_);
|
||||
|
|
|
@ -112,30 +112,6 @@ TEST(FuseMessageTest, Write_TooShort) {
|
|||
TestWriteInvalidLength(sizeof(fuse_in_header) - 1);
|
||||
}
|
||||
|
||||
TEST(FuseMessageTest, ShortWriteAndRead) {
|
||||
int raw_fds[2];
|
||||
ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, raw_fds));
|
||||
|
||||
android::base::unique_fd fds[2];
|
||||
fds[0].reset(raw_fds[0]);
|
||||
fds[1].reset(raw_fds[1]);
|
||||
|
||||
const int send_buffer_size = 1024;
|
||||
ASSERT_EQ(0, setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &send_buffer_size,
|
||||
sizeof(int)));
|
||||
|
||||
bool succeed = false;
|
||||
const int sender_fd = fds[0].get();
|
||||
std::thread thread([sender_fd, &succeed] {
|
||||
FuseRequest request;
|
||||
request.header.len = 1024 * 4;
|
||||
succeed = request.Write(sender_fd);
|
||||
});
|
||||
thread.detach();
|
||||
FuseRequest request;
|
||||
ASSERT_TRUE(request.Read(fds[1]));
|
||||
}
|
||||
|
||||
TEST(FuseResponseTest, Reset) {
|
||||
FuseResponse response;
|
||||
// Write 1 to the first ten bytes.
|
||||
|
@ -211,5 +187,29 @@ TEST(FuseBufferTest, HandleNotImpl) {
|
|||
EXPECT_EQ(-ENOSYS, buffer.response.header.error);
|
||||
}
|
||||
|
||||
TEST(SetupMessageSocketsTest, Stress) {
|
||||
constexpr int kCount = 1000;
|
||||
|
||||
FuseRequest request;
|
||||
request.header.len = sizeof(FuseRequest);
|
||||
|
||||
base::unique_fd fds[2];
|
||||
SetupMessageSockets(&fds);
|
||||
|
||||
std::thread thread([&fds] {
|
||||
FuseRequest request;
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
ASSERT_TRUE(request.Read(fds[1]));
|
||||
usleep(1000);
|
||||
}
|
||||
});
|
||||
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
ASSERT_TRUE(request.Write(fds[0]));
|
||||
}
|
||||
|
||||
thread.join();
|
||||
}
|
||||
|
||||
} // namespace fuse
|
||||
} // namespace android
|
||||
|
|
Loading…
Reference in a new issue