Netlink socket refactoring

- merge two send() methods into one
- use internal receive buffer instead of asking user to supply one
- move setting sequence number to MessageFactory sending code
- don't limit send function to Kernel as a recipient
- move adding NLM_F_ACK to the caller side
- getSocketPid -> getPid
- unsigned int -> unsigned

One part missing is refactoring receiveAck (b/161389935).

Bug: 162032964
Test: canhalctrl up test virtual vcan3
Change-Id: Ie3d460dbc2ea1251469bf08504cfe2c6e80bbe75
This commit is contained in:
Tomasz Wasilczyk 2020-08-03 15:06:23 -07:00
parent 66fc939023
commit b428f77f6c
6 changed files with 96 additions and 91 deletions

View file

@ -70,7 +70,7 @@ bool setBitrate(std::string ifname, uint32_t bitrate) {
struct can_bittiming bt = {};
bt.bitrate = bitrate;
nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK, NLM_F_REQUEST);
nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK, NLM_F_REQUEST | NLM_F_ACK);
const auto ifidx = nametoindex(ifname);
if (ifidx == 0) {

View file

@ -63,7 +63,7 @@ bool down(std::string ifname) {
bool add(std::string dev, std::string type) {
nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK,
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL);
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK);
req.addattr(IFLA_IFNAME, dev);
{
@ -76,7 +76,7 @@ bool add(std::string dev, std::string type) {
}
bool del(std::string dev) {
nl::MessageFactory<struct ifinfomsg> req(RTM_DELLINK, NLM_F_REQUEST);
nl::MessageFactory<struct ifinfomsg> req(RTM_DELLINK, NLM_F_REQUEST | NLM_F_ACK);
req.addattr(IFLA_IFNAME, dev);
nl::Socket sock(NETLINK_ROUTE);

View file

@ -34,7 +34,7 @@ bool add(const std::string& eth, const std::string& vlan, uint16_t id) {
}
nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK,
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL);
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK);
req.addattr(IFLA_IFNAME, vlan);
req.addattr<uint32_t>(IFLA_LINK, ethidx);

View file

@ -27,7 +27,7 @@ namespace android::nl {
*/
static constexpr bool kSuperVerbose = false;
Socket::Socket(int protocol, unsigned int pid, uint32_t groups) : mProtocol(protocol) {
Socket::Socket(int protocol, unsigned pid, uint32_t groups) : mProtocol(protocol) {
mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol));
if (!mFd.ok()) {
PLOG(ERROR) << "Can't open Netlink socket";
@ -47,83 +47,60 @@ Socket::Socket(int protocol, unsigned int pid, uint32_t groups) : mProtocol(prot
}
}
bool Socket::send(nlmsghdr* nlmsg, size_t totalLen) {
if constexpr (kSuperVerbose) {
nlmsg->nlmsg_seq = mSeq;
LOG(VERBOSE) << (mFailed ? "(not) " : "")
<< "sending Netlink message: " << toString({nlmsg, totalLen}, mProtocol);
}
if (mFailed) return false;
nlmsg->nlmsg_pid = 0; // kernel
nlmsg->nlmsg_seq = mSeq++;
nlmsg->nlmsg_flags |= NLM_F_ACK;
iovec iov = {nlmsg, nlmsg->nlmsg_len};
sockaddr_nl sa = {};
sa.nl_family = AF_NETLINK;
msghdr msg = {};
msg.msg_name = &sa;
msg.msg_namelen = sizeof(sa);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
if (sendmsg(mFd.get(), &msg, 0) < 0) {
PLOG(ERROR) << "Can't send Netlink message";
return false;
}
return true;
}
bool Socket::send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa) {
if constexpr (kSuperVerbose) {
LOG(VERBOSE) << (mFailed ? "(not) " : "")
<< "sending Netlink message: " << toString(msg, mProtocol);
LOG(VERBOSE) << (mFailed ? "(not) " : "") << "sending Netlink message (" //
<< msg->nlmsg_pid << " -> " << sa.nl_pid << "): " << toString(msg, mProtocol);
}
if (mFailed) return false;
mSeq = msg->nlmsg_seq;
const auto rawMsg = msg.getRaw();
const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0,
reinterpret_cast<const sockaddr*>(&sa), sizeof(sa));
if (bytesSent < 0) {
PLOG(ERROR) << "Can't send Netlink message";
return false;
} else if (size_t(bytesSent) != rawMsg.len()) {
LOG(ERROR) << "Can't send Netlink message: truncated message";
return false;
}
return true;
}
std::optional<Buffer<nlmsghdr>> Socket::receive(void* buf, size_t bufLen) {
sockaddr_nl sa = {};
return receive(buf, bufLen, sa);
std::optional<Buffer<nlmsghdr>> Socket::receive(size_t maxSize) {
return receiveFrom(maxSize).first;
}
std::optional<Buffer<nlmsghdr>> Socket::receive(void* buf, size_t bufLen, sockaddr_nl& sa) {
if (mFailed) return std::nullopt;
std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> Socket::receiveFrom(size_t maxSize) {
if (mFailed) return {std::nullopt, {}};
socklen_t saLen = sizeof(sa);
if (bufLen == 0) {
LOG(ERROR) << "Receive buffer has zero size!";
return std::nullopt;
if (maxSize == 0) {
LOG(ERROR) << "Maximum receive size should not be zero";
return {std::nullopt, {}};
}
const auto bytesReceived =
recvfrom(mFd.get(), buf, bufLen, MSG_TRUNC, reinterpret_cast<sockaddr*>(&sa), &saLen);
if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize);
sockaddr_nl sa = {};
socklen_t saLen = sizeof(sa);
const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), maxSize, MSG_TRUNC,
reinterpret_cast<sockaddr*>(&sa), &saLen);
if (bytesReceived <= 0) {
PLOG(ERROR) << "Failed to receive Netlink message";
return std::nullopt;
} else if (unsigned(bytesReceived) > bufLen) {
PLOG(ERROR) << "Received data larger than the receive buffer! " << bytesReceived << " > "
<< bufLen;
return std::nullopt;
return {std::nullopt, {}};
} else if (size_t(bytesReceived) > maxSize) {
PLOG(ERROR) << "Received data larger than maximum receive size: " //
<< bytesReceived << " > " << maxSize;
return {std::nullopt, {}};
}
Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(buf), bytesReceived);
Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(mReceiveBuffer.data()), bytesReceived);
if constexpr (kSuperVerbose) {
LOG(VERBOSE) << "received " << toString(msg, mProtocol);
LOG(VERBOSE) << "received (" << sa.nl_pid << " -> " << msg->nlmsg_pid << "):" //
<< toString(msg, mProtocol);
}
return msg;
return {msg, sa};
}
/* TODO(161389935): Migrate receiveAck to use nlmsg<> internally. Possibly reuse
@ -179,11 +156,11 @@ bool Socket::receiveAck() {
return false;
}
std::optional<unsigned int> Socket::getSocketPid() {
std::optional<unsigned> Socket::getPid() {
sockaddr_nl sa = {};
socklen_t sasize = sizeof(sa);
if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
PLOG(ERROR) << "Failed to getsockname() for netlink_fd!";
PLOG(ERROR) << "Failed to get PID of Netlink socket";
return std::nullopt;
}
return sa.nl_pid;

View file

@ -35,7 +35,6 @@ void addattr_nest_end(struct nlmsghdr* n, struct nlattr* nest);
} // namespace impl
// TODO(twasilczyk): rename to NetlinkMessage
/**
* Wrapper around NETLINK_ROUTE messages, to build them in C++ style.
*

View file

@ -24,6 +24,7 @@
#include <linux/netlink.h>
#include <optional>
#include <vector>
namespace android::nl {
@ -33,59 +34,88 @@ namespace android::nl {
* This class is not thread safe to use a single instance between multiple threads, but it's fine to
* use multiple instances over multiple threads.
*/
struct Socket {
class Socket {
public:
static constexpr size_t defaultReceiveSize = 8192;
/**
* Socket constructor.
*
* \param protocol the Netlink protocol to use.
* \param pid port id. Default value of 0 allows the kernel to assign us a unique pid. (NOTE:
* this is NOT the same as process id!)
* \param pid port id. Default value of 0 allows the kernel to assign us a unique pid.
* (NOTE: this is NOT the same as process id).
* \param groups Netlink multicast groups to listen to. This is a 32-bit bitfield, where each
* bit is a different group. Default value of 0 means no groups are selected. See man netlink.7
* bit is a different group. Default value of 0 means no groups are selected.
* See man netlink.7.
* for more details.
*/
Socket(int protocol, unsigned int pid = 0, uint32_t groups = 0);
Socket(int protocol, unsigned pid = 0, uint32_t groups = 0);
/**
* Send Netlink message to Kernel. The sequence number will be automatically incremented, and
* the NLM_F_ACK (request ACK) flag will be set.
* Send Netlink message with incremented sequence number to the Kernel.
*
* \param msg Message to send.
* \return true, if succeeded
* \param msg Message to send. Its sequence number will be updated.
* \return true, if succeeded.
*/
template <class T, unsigned int BUFSIZE>
template <class T, unsigned BUFSIZE>
bool send(MessageFactory<T, BUFSIZE>& req) {
if (!req.isGood()) return false;
return send(req.header(), req.totalLength);
sockaddr_nl sa = {};
sa.nl_family = AF_NETLINK;
sa.nl_pid = 0; // Kernel
return send(req, sa);
}
/**
* Send Netlink message. The message will be sent as is, without any modification.
* Send Netlink message with incremented sequence number.
*
* \param msg Message to send. Its sequence number will be updated.
* \param sa Destination address.
* \return true, if succeeded.
*/
template <class T, unsigned BUFSIZE>
bool send(MessageFactory<T, BUFSIZE>& req, const sockaddr_nl& sa) {
if (!req.isGood()) return false;
const auto nlmsg = req.header();
nlmsg->nlmsg_seq = mSeq + 1;
// With MessageFactory<>, we trust nlmsg_len to be correct.
return send({nlmsg, nlmsg->nlmsg_len}, sa);
}
/**
* Send Netlink message.
*
* \param msg Message to send.
* \param sa Destination address.
* \return true, if succeeded
* \return true, if succeeded.
*/
bool send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa);
/**
* Receive Netlink data.
* Receive one or multiple Netlink messages.
*
* \param buf buffer to hold message data.
* \param bufLen length of buf.
* \return Buffer with message data, std::nullopt on error.
* WARNING: the underlying buffer is owned by Socket class and the data is valid until the next
* call to the read function or until deallocation of Socket instance.
*
* \param maxSize Maximum total size of received messages
* \return Buffer view with message data, std::nullopt on error.
*/
std::optional<Buffer<nlmsghdr>> receive(void* buf, size_t bufLen);
std::optional<Buffer<nlmsghdr>> receive(size_t maxSize = defaultReceiveSize);
/**
* Receive Netlink data with address info.
* Receive one or multiple Netlink messages and the sender process address.
*
* \param buf buffer to hold message data.
* \param bufLen length of buf.
* \param sa Blank struct that recvfrom will populate with address info.
* \return Buffer with message data, std::nullopt on error.
* WARNING: the underlying buffer is owned by Socket class and the data is valid until the next
* call to the read function or until deallocation of Socket instance.
*
* \param maxSize Maximum total size of received messages
* \return A pair (for use with structured binding) containing:
* - buffer view with message data, std::nullopt on error;
* - sender process address.
*/
std::optional<Buffer<nlmsghdr>> receive(void* buf, size_t bufLen, sockaddr_nl& sa);
std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> receiveFrom(
size_t maxSize = defaultReceiveSize);
/**
* Receive Netlink ACK message from Kernel.
@ -95,11 +125,11 @@ struct Socket {
bool receiveAck();
/**
* Gets the PID assigned to mFd.
* Fetches the socket PID.
*
* \return pid that mSocket is bound to.
* \return PID that socket is bound to.
*/
std::optional<unsigned int> getSocketPid();
std::optional<unsigned> getPid();
private:
const int mProtocol;
@ -107,8 +137,7 @@ struct Socket {
uint32_t mSeq = 0;
base::unique_fd mFd;
bool mFailed = false;
bool send(nlmsghdr* msg, size_t totalLen);
std::vector<uint8_t> mReceiveBuffer;
DISALLOW_COPY_AND_ASSIGN(Socket);
};