diff --git a/src/ucs/sys/netlink.c b/src/ucs/sys/netlink.c index 71510b17d41..23885742425 100644 --- a/src/ucs/sys/netlink.c +++ b/src/ucs/sys/netlink.c @@ -44,6 +44,11 @@ KHASH_INIT(ucs_netlink_rt_cache, khint32_t, ucs_netlink_rt_rules_t, 1, kh_int_hash_func, kh_int_hash_equal); static khash_t(ucs_netlink_rt_cache) ucs_netlink_routing_table_cache; +static inline int ucs_netlink_is_msg_done(const struct nlmsghdr *nlh) +{ + return (nlh->nlmsg_type == NLMSG_DONE); +} + static ucs_status_t ucs_netlink_socket_init(int *fd_p, int protocol) { struct sockaddr_nl sa = {.nl_family = AF_NETLINK}; @@ -78,7 +83,7 @@ ucs_netlink_parse_msg(const void *msg, size_t msg_len, const struct nlmsghdr *nlh = (const struct nlmsghdr *)msg; while ((status == UCS_INPROGRESS) && NLMSG_OK(nlh, msg_len) && - (nlh->nlmsg_type != NLMSG_DONE)) { + !ucs_netlink_is_msg_done(nlh)) { if (nlh->nlmsg_type == NLMSG_ERROR) { struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh); ucs_error("received error response from netlink err=%d: %s\n", @@ -100,9 +105,10 @@ ucs_netlink_send_request(int protocol, unsigned short nlmsg_type, ucs_netlink_parse_cb_t parse_cb, void *arg) { struct nlmsghdr nlh = {0}; - char *recv_msg = NULL; - size_t recv_msg_len = 0; int netlink_fd = -1; + size_t recv_msg_len; + char *recv_msg; + int msg_done; ucs_status_t status; struct iovec iov[2]; size_t bytes_sent; @@ -131,33 +137,38 @@ ucs_netlink_send_request(int protocol, unsigned short nlmsg_type, } /* get message size */ - status = ucs_socket_recv_nb(netlink_fd, NULL, MSG_PEEK | MSG_TRUNC, - &recv_msg_len); - if (status != UCS_OK) { - ucs_error("failed to get netlink message size %d (%s)", - status, ucs_status_string(status)); - goto out; - } + do { + recv_msg_len = 0; + status = ucs_socket_recv_nb(netlink_fd, NULL, MSG_PEEK | MSG_TRUNC, + &recv_msg_len); + if (status != UCS_OK) { + ucs_error("failed to get netlink message size %d (%s)", + status, ucs_status_string(status)); + goto out; + } - recv_msg = ucs_malloc(recv_msg_len, "netlink recv message"); - if (recv_msg == NULL) { - ucs_error("failed to allocate a buffer for netlink receive message of" - " size %zu", recv_msg_len); - goto out; - } + recv_msg = ucs_malloc(recv_msg_len, "netlink recv message"); + if (recv_msg == NULL) { + ucs_error("failed to allocate a buffer for netlink receive message" + " of size %zu", recv_msg_len); + goto out; + } - status = ucs_socket_recv(netlink_fd, recv_msg, recv_msg_len); - if (status != UCS_OK) { - ucs_error("failed to receive netlink message on fd=%d: %s", - netlink_fd, ucs_status_string(status)); - goto out; - } + status = ucs_socket_recv(netlink_fd, recv_msg, recv_msg_len); + if (status != UCS_OK) { + ucs_error("failed to receive netlink message on fd=%d: %s", + netlink_fd, ucs_status_string(status)); + ucs_free(recv_msg); + goto out; + } - status = ucs_netlink_parse_msg(recv_msg, recv_msg_len, parse_cb, arg); + status = ucs_netlink_parse_msg(recv_msg, recv_msg_len, parse_cb, arg); + msg_done = ucs_netlink_is_msg_done((const struct nlmsghdr *)recv_msg); + ucs_free(recv_msg); + } while ((nlmsg_flags & NLM_F_DUMP) && !msg_done); out: ucs_close_fd(&netlink_fd); - ucs_free(recv_msg); return status; } @@ -262,8 +273,8 @@ int ucs_netlink_route_exists(int if_index, const struct sockaddr *sa_remote) ucs_netlink_route_info_t info; UCS_INIT_ONCE(&init_once) { + rtm.rtm_table = RT_TABLE_UNSPEC; /* fetch all the tables */ rtm.rtm_family = AF_INET; - rtm.rtm_table = RT_TABLE_MAIN; ucs_netlink_send_request(NETLINK_ROUTE, RTM_GETROUTE, NLM_F_DUMP, &rtm, sizeof(rtm), ucs_netlink_parse_rt_entry_cb, NULL);