diff --git a/compiler-rt/lib/dfsan/dfsan_custom.cpp b/compiler-rt/lib/dfsan/dfsan_custom.cpp
--- a/compiler-rt/lib/dfsan/dfsan_custom.cpp
+++ b/compiler-rt/lib/dfsan/dfsan_custom.cpp
@@ -890,22 +890,48 @@
   return ret;
 }
 
+// Clears the labels on *msg, its name and control buffers, and the first
+// bytes_written bytes of its scatter/gather (msg_iov) buffers.
+static void clear_msghdr_labels(size_t bytes_written, struct msghdr *msg) {
+  dfsan_set_label(0, msg, sizeof(*msg));
+  // NOTE(review): Linux may report the untruncated address length in
+  // msg_namelen; confirm this cannot exceed the caller's msg_name buffer.
+  dfsan_set_label(0, msg->msg_name, msg->msg_namelen);
+  dfsan_set_label(0, msg->msg_control, msg->msg_controllen);
+  // With MSG_TRUNC the reported length can exceed the total size of msg_iov,
+  // so stop at msg_iovlen instead of relying on assert() (which is compiled
+  // out in release builds) to catch an out-of-bounds msg_iov access.
+  for (size_t i = 0; i < msg->msg_iovlen && bytes_written > 0; ++i) {
+    struct iovec *iov = &msg->msg_iov[i];
+    size_t iov_written =
+        bytes_written < iov->iov_len ? bytes_written : iov->iov_len;
+    dfsan_set_label(0, iov->iov_base, iov_written);
+    bytes_written -= iov_written;
+  }
+}
+
+SANITIZER_INTERFACE_ATTRIBUTE int __dfsw_recvmmsg(
+    int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags,
+    struct timespec *timeout, dfsan_label sockfd_label,
+    dfsan_label msgvec_label, dfsan_label vlen_label, dfsan_label flags_label,
+    dfsan_label timeout_label, dfsan_label *ret_label) {
+  int ret = recvmmsg(sockfd, msgvec, vlen, flags, timeout);
+  // Only the first `ret` entries were filled in; on error (ret < 0) the loop
+  // does not run.
+  for (int i = 0; i < ret; ++i) {
+    dfsan_set_label(0, &msgvec[i].msg_len, sizeof(msgvec[i].msg_len));
+    clear_msghdr_labels(msgvec[i].msg_len, &msgvec[i].msg_hdr);
+  }
+  *ret_label = 0;
+  return ret;
+}
+
 SANITIZER_INTERFACE_ATTRIBUTE ssize_t __dfsw_recvmsg(
     int sockfd, struct msghdr *msg, int flags, dfsan_label sockfd_label,
     dfsan_label msg_label, dfsan_label flags_label, dfsan_label *ret_label) {
   ssize_t ret = recvmsg(sockfd, msg, flags);
-  if (ret >= 0) {
-    dfsan_set_label(0, msg, sizeof(*msg));
-    dfsan_set_label(0, msg->msg_name, msg->msg_namelen);
-    dfsan_set_label(0, msg->msg_control, msg->msg_controllen);
-    for (size_t remaining = ret, i = 0; remaining > 0; ++i) {
-      assert(i < msg->msg_iovlen);
-      struct iovec *iov = &msg->msg_iov[i];
-      size_t written = remaining < iov->iov_len ? remaining : iov->iov_len;
-      dfsan_set_label(0, iov->iov_base, written);
-      remaining -= written;
-    }
-  }
+  if (ret >= 0)
+    clear_msghdr_labels(ret, msg);
   *ret_label = 0;
   return ret;
 }
diff --git a/compiler-rt/lib/dfsan/done_abilist.txt b/compiler-rt/lib/dfsan/done_abilist.txt
--- a/compiler-rt/lib/dfsan/done_abilist.txt
+++ b/compiler-rt/lib/dfsan/done_abilist.txt
@@ -199,6 +199,7 @@
 fun:nanosleep=custom
 fun:pread=custom
 fun:read=custom
+fun:recvmmsg=custom
 fun:recvmsg=custom
 fun:sigaltstack=custom
 fun:socketpair=custom
diff --git a/compiler-rt/test/dfsan/custom.cpp b/compiler-rt/test/dfsan/custom.cpp
--- a/compiler-rt/test/dfsan/custom.cpp
+++ b/compiler-rt/test/dfsan/custom.cpp
@@ -337,6 +337,64 @@
   free(crv);
 }
 
+void test_recvmmsg() {
+  int sockfds[2];
+  int ret = socketpair(AF_UNIX, SOCK_DGRAM, 0, sockfds);
+  assert(ret != -1);
+
+  // Setup messages to send.
+  struct mmsghdr smmsg[2] = {};
+  char sbuf0[] = "abcdefghijkl";
+  struct iovec siov0[2] = {{&sbuf0[0], 4}, {&sbuf0[4], 4}};
+  smmsg[0].msg_hdr.msg_iov = siov0;
+  smmsg[0].msg_hdr.msg_iovlen = 2;
+  char sbuf1[] = "1234567890";
+  struct iovec siov1[1] = {{&sbuf1[0], 7}};
+  smmsg[1].msg_hdr.msg_iov = siov1;
+  smmsg[1].msg_hdr.msg_iovlen = 1;
+
+  // Send messages.
+  int sent_msgs = sendmmsg(sockfds[0], smmsg, 2, 0);
+  assert(sent_msgs == 2);
+
+  // Setup receive buffers.
+  struct mmsghdr rmmsg[2] = {};
+  char rbuf0[128];
+  struct iovec riov0[2] = {{&rbuf0[0], 4}, {&rbuf0[4], 4}};
+  rmmsg[0].msg_hdr.msg_iov = riov0;
+  rmmsg[0].msg_hdr.msg_iovlen = 2;
+  char rbuf1[128];
+  struct iovec riov1[1] = {{&rbuf1[0], 16}};
+  rmmsg[1].msg_hdr.msg_iov = riov1;
+  rmmsg[1].msg_hdr.msg_iovlen = 1;
+  struct timespec timeout = {1, 1};
+  dfsan_set_label(i_label, rbuf0, sizeof(rbuf0));
+  dfsan_set_label(i_label, rbuf1, sizeof(rbuf1));
+  dfsan_set_label(i_label, &rmmsg[0].msg_len, sizeof(rmmsg[0].msg_len));
+  dfsan_set_label(i_label, &rmmsg[1].msg_len, sizeof(rmmsg[1].msg_len));
+  dfsan_set_label(i_label, &timeout, sizeof(timeout));
+
+  // Receive messages and check labels.
+  int received_msgs = recvmmsg(sockfds[1], rmmsg, 2, 0, &timeout);
+  assert(received_msgs == sent_msgs);
+  assert(rmmsg[0].msg_len == smmsg[0].msg_len);
+  assert(rmmsg[1].msg_len == smmsg[1].msg_len);
+  assert(memcmp(sbuf0, rbuf0, 8) == 0);
+  assert(memcmp(sbuf1, rbuf1, 7) == 0);
+  ASSERT_ZERO_LABEL(received_msgs);
+  ASSERT_ZERO_LABEL(rmmsg[0].msg_len);
+  ASSERT_ZERO_LABEL(rmmsg[1].msg_len);
+  ASSERT_READ_ZERO_LABEL(&rbuf0[0], 8);
+  ASSERT_READ_LABEL(&rbuf0[8], 1, i_label);
+  ASSERT_READ_ZERO_LABEL(&rbuf1[0], 7);
+  ASSERT_READ_LABEL(&rbuf1[7], 1, i_label);
+  ASSERT_LABEL(timeout.tv_sec, i_label);
+  ASSERT_LABEL(timeout.tv_nsec, i_label);
+
+  close(sockfds[0]);
+  close(sockfds[1]);
+}
+
 void test_recvmsg() {
   int sockfds[2];
   int ret = socketpair(AF_UNIX, SOCK_DGRAM, 0, sockfds);
@@ -1177,6 +1235,7 @@
   test_pread();
   test_pthread_create();
   test_read();
+  test_recvmmsg();
   test_recvmsg();
   test_sched_getaffinity();
   test_select();