diff --git a/automotive/can/1.0/default/libnetdevice/can.cpp b/automotive/can/1.0/default/libnetdevice/can.cpp index b047bc920b..ab107fdbaa 100644 --- a/automotive/can/1.0/default/libnetdevice/can.cpp +++ b/automotive/can/1.0/default/libnetdevice/can.cpp @@ -70,7 +70,7 @@ bool setBitrate(std::string ifname, uint32_t bitrate) { struct can_bittiming bt = {}; bt.bitrate = bitrate; - nl::MessageFactory req(RTM_NEWLINK, NLM_F_REQUEST); + nl::MessageFactory req(RTM_NEWLINK, NLM_F_REQUEST | NLM_F_ACK); const auto ifidx = nametoindex(ifname); if (ifidx == 0) { diff --git a/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp b/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp index f7f5f4dd43..ed2a51e864 100644 --- a/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp +++ b/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp @@ -63,7 +63,7 @@ bool down(std::string ifname) { bool add(std::string dev, std::string type) { nl::MessageFactory req(RTM_NEWLINK, - NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL); + NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK); req.addattr(IFLA_IFNAME, dev); { @@ -76,7 +76,7 @@ bool add(std::string dev, std::string type) { } bool del(std::string dev) { - nl::MessageFactory req(RTM_DELLINK, NLM_F_REQUEST); + nl::MessageFactory req(RTM_DELLINK, NLM_F_REQUEST | NLM_F_ACK); req.addattr(IFLA_IFNAME, dev); nl::Socket sock(NETLINK_ROUTE); diff --git a/automotive/can/1.0/default/libnetdevice/vlan.cpp b/automotive/can/1.0/default/libnetdevice/vlan.cpp index 3e07f670df..3f904f0b02 100644 --- a/automotive/can/1.0/default/libnetdevice/vlan.cpp +++ b/automotive/can/1.0/default/libnetdevice/vlan.cpp @@ -34,7 +34,7 @@ bool add(const std::string& eth, const std::string& vlan, uint16_t id) { } nl::MessageFactory req(RTM_NEWLINK, - NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL); + NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK); req.addattr(IFLA_IFNAME, vlan); req.addattr(IFLA_LINK, ethidx); diff --git a/automotive/can/1.0/default/libnl++/Socket.cpp b/automotive/can/1.0/default/libnl++/Socket.cpp index aac6416f47..56e990c23f 100644 --- a/automotive/can/1.0/default/libnl++/Socket.cpp +++ b/automotive/can/1.0/default/libnl++/Socket.cpp @@ -27,7 +27,7 @@ namespace android::nl { */ static constexpr bool kSuperVerbose = false; -Socket::Socket(int protocol, unsigned int pid, uint32_t groups) : mProtocol(protocol) { +Socket::Socket(int protocol, unsigned pid, uint32_t groups) : mProtocol(protocol) { mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol)); if (!mFd.ok()) { PLOG(ERROR) << "Can't open Netlink socket"; @@ -47,83 +47,60 @@ Socket::Socket(int protocol, unsigned int pid, uint32_t groups) : mProtocol(prot } } -bool Socket::send(nlmsghdr* nlmsg, size_t totalLen) { - if constexpr (kSuperVerbose) { - nlmsg->nlmsg_seq = mSeq; - LOG(VERBOSE) << (mFailed ? "(not) " : "") - << "sending Netlink message: " << toString({nlmsg, totalLen}, mProtocol); - } - - if (mFailed) return false; - - nlmsg->nlmsg_pid = 0; // kernel - nlmsg->nlmsg_seq = mSeq++; - nlmsg->nlmsg_flags |= NLM_F_ACK; - - iovec iov = {nlmsg, nlmsg->nlmsg_len}; - - sockaddr_nl sa = {}; - sa.nl_family = AF_NETLINK; - - msghdr msg = {}; - msg.msg_name = &sa; - msg.msg_namelen = sizeof(sa); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - if (sendmsg(mFd.get(), &msg, 0) < 0) { - PLOG(ERROR) << "Can't send Netlink message"; - return false; - } - return true; -} - bool Socket::send(const Buffer& msg, const sockaddr_nl& sa) { if constexpr (kSuperVerbose) { - LOG(VERBOSE) << (mFailed ? "(not) " : "") - << "sending Netlink message: " << toString(msg, mProtocol); + LOG(VERBOSE) << (mFailed ? "(not) " : "") << "sending Netlink message (" // + << msg->nlmsg_pid << " -> " << sa.nl_pid << "): " << toString(msg, mProtocol); } - if (mFailed) return false; + + mSeq = msg->nlmsg_seq; const auto rawMsg = msg.getRaw(); const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0, reinterpret_cast(&sa), sizeof(sa)); if (bytesSent < 0) { PLOG(ERROR) << "Can't send Netlink message"; return false; + } else if (size_t(bytesSent) != rawMsg.len()) { + LOG(ERROR) << "Can't send Netlink message: truncated message"; + return false; } return true; } -std::optional> Socket::receive(void* buf, size_t bufLen) { - sockaddr_nl sa = {}; - return receive(buf, bufLen, sa); +std::optional> Socket::receive(size_t maxSize) { + return receiveFrom(maxSize).first; } -std::optional> Socket::receive(void* buf, size_t bufLen, sockaddr_nl& sa) { - if (mFailed) return std::nullopt; +std::pair>, sockaddr_nl> Socket::receiveFrom(size_t maxSize) { + if (mFailed) return {std::nullopt, {}}; - socklen_t saLen = sizeof(sa); - if (bufLen == 0) { - LOG(ERROR) << "Receive buffer has zero size!"; - return std::nullopt; + if (maxSize == 0) { + LOG(ERROR) << "Maximum receive size should not be zero"; + return {std::nullopt, {}}; } - const auto bytesReceived = - recvfrom(mFd.get(), buf, bufLen, MSG_TRUNC, reinterpret_cast(&sa), &saLen); + if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize); + + sockaddr_nl sa = {}; + socklen_t saLen = sizeof(sa); + const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), maxSize, MSG_TRUNC, + reinterpret_cast(&sa), &saLen); + if (bytesReceived <= 0) { PLOG(ERROR) << "Failed to receive Netlink message"; - return std::nullopt; - } else if (unsigned(bytesReceived) > bufLen) { - PLOG(ERROR) << "Received data larger than the receive buffer! " << bytesReceived << " > " - << bufLen; - return std::nullopt; + return {std::nullopt, {}}; + } else if (size_t(bytesReceived) > maxSize) { + PLOG(ERROR) << "Received data larger than maximum receive size: " // + << bytesReceived << " > " << maxSize; + return {std::nullopt, {}}; } - Buffer msg(reinterpret_cast(buf), bytesReceived); + Buffer msg(reinterpret_cast(mReceiveBuffer.data()), bytesReceived); if constexpr (kSuperVerbose) { - LOG(VERBOSE) << "received " << toString(msg, mProtocol); + LOG(VERBOSE) << "received (" << sa.nl_pid << " -> " << msg->nlmsg_pid << "):" // + << toString(msg, mProtocol); } - return msg; + return {msg, sa}; } /* TODO(161389935): Migrate receiveAck to use nlmsg<> internally. Possibly reuse @@ -179,11 +156,11 @@ bool Socket::receiveAck() { return false; } -std::optional Socket::getSocketPid() { +std::optional Socket::getPid() { sockaddr_nl sa = {}; socklen_t sasize = sizeof(sa); if (getsockname(mFd.get(), reinterpret_cast(&sa), &sasize) < 0) { - PLOG(ERROR) << "Failed to getsockname() for netlink_fd!"; + PLOG(ERROR) << "Failed to get PID of Netlink socket"; return std::nullopt; } return sa.nl_pid; diff --git a/automotive/can/1.0/default/libnl++/include/libnl++/MessageFactory.h b/automotive/can/1.0/default/libnl++/include/libnl++/MessageFactory.h index e00ca20fcd..5272577b53 100644 --- a/automotive/can/1.0/default/libnl++/include/libnl++/MessageFactory.h +++ b/automotive/can/1.0/default/libnl++/include/libnl++/MessageFactory.h @@ -35,7 +35,6 @@ void addattr_nest_end(struct nlmsghdr* n, struct nlattr* nest); } // namespace impl -// TODO(twasilczyk): rename to NetlinkMessage /** * Wrapper around NETLINK_ROUTE messages, to build them in C++ style. * diff --git a/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h b/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h index 7685733ce5..bc6ad9df54 100644 --- a/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h +++ b/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h @@ -24,6 +24,7 @@ #include #include +#include namespace android::nl { @@ -33,59 +34,88 @@ namespace android::nl { * This class is not thread safe to use a single instance between multiple threads, but it's fine to * use multiple instances over multiple threads. */ -struct Socket { +class Socket { + public: + static constexpr size_t defaultReceiveSize = 8192; + /** * Socket constructor. * * \param protocol the Netlink protocol to use. - * \param pid port id. Default value of 0 allows the kernel to assign us a unique pid. (NOTE: - * this is NOT the same as process id!) + * \param pid port id. Default value of 0 allows the kernel to assign us a unique pid. + * (NOTE: this is NOT the same as process id). * \param groups Netlink multicast groups to listen to. This is a 32-bit bitfield, where each - * bit is a different group. Default value of 0 means no groups are selected. See man netlink.7 + * bit is a different group. Default value of 0 means no groups are selected. + * See man netlink.7. * for more details. */ - Socket(int protocol, unsigned int pid = 0, uint32_t groups = 0); + Socket(int protocol, unsigned pid = 0, uint32_t groups = 0); /** - * Send Netlink message to Kernel. The sequence number will be automatically incremented, and - * the NLM_F_ACK (request ACK) flag will be set. + * Send Netlink message with incremented sequence number to the Kernel. * - * \param msg Message to send. - * \return true, if succeeded + * \param msg Message to send. Its sequence number will be updated. + * \return true, if succeeded. */ - template + template bool send(MessageFactory& req) { - if (!req.isGood()) return false; - return send(req.header(), req.totalLength); + sockaddr_nl sa = {}; + sa.nl_family = AF_NETLINK; + sa.nl_pid = 0; // Kernel + return send(req, sa); } /** - * Send Netlink message. The message will be sent as is, without any modification. + * Send Netlink message with incremented sequence number. + * + * \param msg Message to send. Its sequence number will be updated. + * \param sa Destination address. + * \return true, if succeeded. + */ + template + bool send(MessageFactory& req, const sockaddr_nl& sa) { + if (!req.isGood()) return false; + + const auto nlmsg = req.header(); + nlmsg->nlmsg_seq = mSeq + 1; + + // With MessageFactory<>, we trust nlmsg_len to be correct. + return send({nlmsg, nlmsg->nlmsg_len}, sa); + } + + /** + * Send Netlink message. * * \param msg Message to send. * \param sa Destination address. - * \return true, if succeeded + * \return true, if succeeded. */ bool send(const Buffer& msg, const sockaddr_nl& sa); /** - * Receive Netlink data. + * Receive one or multiple Netlink messages. * - * \param buf buffer to hold message data. - * \param bufLen length of buf. - * \return Buffer with message data, std::nullopt on error. + * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next + * call to the read function or until deallocation of Socket instance. + * + * \param maxSize Maximum total size of received messages + * \return Buffer view with message data, std::nullopt on error. */ - std::optional> receive(void* buf, size_t bufLen); + std::optional> receive(size_t maxSize = defaultReceiveSize); /** - * Receive Netlink data with address info. + * Receive one or multiple Netlink messages and the sender process address. * - * \param buf buffer to hold message data. - * \param bufLen length of buf. - * \param sa Blank struct that recvfrom will populate with address info. - * \return Buffer with message data, std::nullopt on error. + * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next + * call to the read function or until deallocation of Socket instance. + * + * \param maxSize Maximum total size of received messages + * \return A pair (for use with structured binding) containing: + * - buffer view with message data, std::nullopt on error; + * - sender process address. */ - std::optional> receive(void* buf, size_t bufLen, sockaddr_nl& sa); + std::pair>, sockaddr_nl> receiveFrom( + size_t maxSize = defaultReceiveSize); /** * Receive Netlink ACK message from Kernel. @@ -95,11 +125,11 @@ struct Socket { bool receiveAck(); /** - * Gets the PID assigned to mFd. + * Fetches the socket PID. * - * \return pid that mSocket is bound to. + * \return PID that socket is bound to. */ - std::optional getSocketPid(); + std::optional getPid(); private: const int mProtocol; @@ -107,8 +137,7 @@ struct Socket { uint32_t mSeq = 0; base::unique_fd mFd; bool mFailed = false; - - bool send(nlmsghdr* msg, size_t totalLen); + std::vector mReceiveBuffer; DISALLOW_COPY_AND_ASSIGN(Socket); };