Browse Source

Merge branch 'bpf-sk-msg-peek'

John Fastabend says:

====================
This adds support for the MSG_PEEK flag when redirecting into
an ingress psock sk_msg queue.

The first patch adds some base support to the helpers, then the
feature, and finally we add an option for the test suite to do
a duplicate MSG_PEEK call on every recv to test the feature.

With duplicate MSG_PEEK call all tests continue to PASS.
====================

Acked-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Daniel Borkmann 6 years ago
parent
commit
44d520eb17
5 changed files with 153 additions and 74 deletions
  1. 8 5
      include/linux/skmsg.h
  2. 1 1
      include/net/tcp.h
  3. 27 15
      net/ipv4/tcp_bpf.c
  4. 2 1
      net/tls/tls_sw.c
  5. 115 52
      tools/testing/selftests/bpf/test_sockmap.c

+ 8 - 5
include/linux/skmsg.h

@@ -187,18 +187,21 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
 	sk_msg_init(src);
 }
 
+static inline bool sk_msg_full(const struct sk_msg *msg)
+{
+	return (msg->sg.end == msg->sg.start) && msg->sg.size;
+}
+
 static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
 {
+	if (sk_msg_full(msg))
+		return MAX_MSG_FRAGS;
+
 	return msg->sg.end >= msg->sg.start ?
 		msg->sg.end - msg->sg.start :
 		msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start);
 }
 
-static inline bool sk_msg_full(const struct sk_msg *msg)
-{
-	return (msg->sg.end == msg->sg.start) && msg->sg.size;
-}
-
 static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
 {
 	return &msg->sg.data[which];

+ 1 - 1
include/net/tcp.h

@@ -2089,7 +2089,7 @@ int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
 int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		    int nonblock, int flags, int *addr_len);
 int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
-		      struct msghdr *msg, int len);
+		      struct msghdr *msg, int len, int flags);
 
 /* Call BPF_SOCK_OPS program that returns an int. If the return value
  * is < 0, then the BPF op failed (for example if the loaded BPF

+ 27 - 15
net/ipv4/tcp_bpf.c

@@ -39,17 +39,19 @@ static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
 }
 
 int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
-		      struct msghdr *msg, int len)
+		      struct msghdr *msg, int len, int flags)
 {
 	struct iov_iter *iter = &msg->msg_iter;
+	int peek = flags & MSG_PEEK;
 	int i, ret, copied = 0;
+	struct sk_msg *msg_rx;
+
+	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
+					  struct sk_msg, list);
 
 	while (copied != len) {
 		struct scatterlist *sge;
-		struct sk_msg *msg_rx;
 
-		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
-						  struct sk_msg, list);
 		if (unlikely(!msg_rx))
 			break;
 
@@ -70,22 +72,30 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
 			}
 
 			copied += copy;
-			sge->offset += copy;
-			sge->length -= copy;
-			sk_mem_uncharge(sk, copy);
-			msg_rx->sg.size -= copy;
-			if (!sge->length) {
-				i++;
-				if (i == MAX_SKB_FRAGS)
-					i = 0;
-				if (!msg_rx->skb)
-					put_page(page);
+			if (likely(!peek)) {
+				sge->offset += copy;
+				sge->length -= copy;
+				sk_mem_uncharge(sk, copy);
+				msg_rx->sg.size -= copy;
+
+				if (!sge->length) {
+					sk_msg_iter_var_next(i);
+					if (!msg_rx->skb)
+						put_page(page);
+				}
+			} else {
+				sk_msg_iter_var_next(i);
 			}
 
 			if (copied == len)
 				break;
 		} while (i != msg_rx->sg.end);
 
+		if (unlikely(peek)) {
+			msg_rx = list_next_entry(msg_rx, list);
+			continue;
+		}
+
 		msg_rx->sg.start = i;
 		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
 			list_del(&msg_rx->list);
@@ -93,6 +103,8 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
 				consume_skb(msg_rx->skb);
 			kfree(msg_rx);
 		}
+		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
+						  struct sk_msg, list);
 	}
 
 	return copied;
@@ -115,7 +127,7 @@ int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 	lock_sock(sk);
 msg_bytes_ready:
-	copied = __tcp_bpf_recvmsg(sk, psock, msg, len);
+	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
 	if (!copied) {
 		int data, err = 0;
 		long timeo;

+ 2 - 1
net/tls/tls_sw.c

@@ -1478,7 +1478,8 @@ int tls_sw_recvmsg(struct sock *sk,
 		skb = tls_wait_data(sk, psock, flags, timeo, &err);
 		if (!skb) {
 			if (psock) {
-				int ret = __tcp_bpf_recvmsg(sk, psock, msg, len);
+				int ret = __tcp_bpf_recvmsg(sk, psock,
+							    msg, len, flags);
 
 				if (ret > 0) {
 					copied += ret;

+ 115 - 52
tools/testing/selftests/bpf/test_sockmap.c

@@ -80,6 +80,7 @@ int txmsg_end;
 int txmsg_ingress;
 int txmsg_skb;
 int ktls;
+int peek_flag;
 
 static const struct option long_options[] = {
 	{"help",	no_argument,		NULL, 'h' },
@@ -102,6 +103,7 @@ static const struct option long_options[] = {
 	{"txmsg_ingress", no_argument,		&txmsg_ingress, 1 },
 	{"txmsg_skb", no_argument,		&txmsg_skb, 1 },
 	{"ktls", no_argument,			&ktls, 1 },
+	{"peek", no_argument,			&peek_flag, 1 },
 	{0, 0, NULL, 0 }
 };
 
@@ -352,33 +354,40 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
 	return 0;
 }
 
-static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
-		    struct msg_stats *s, bool tx,
-		    struct sockmap_options *opt)
+static void msg_free_iov(struct msghdr *msg)
 {
-	struct msghdr msg = {0};
-	int err, i, flags = MSG_NOSIGNAL;
+	int i;
+
+	for (i = 0; i < msg->msg_iovlen; i++)
+		free(msg->msg_iov[i].iov_base);
+	free(msg->msg_iov);
+	msg->msg_iov = NULL;
+	msg->msg_iovlen = 0;
+}
+
+static int msg_alloc_iov(struct msghdr *msg,
+			 int iov_count, int iov_length,
+			 bool data, bool xmit)
+{
+	unsigned char k = 0;
 	struct iovec *iov;
-	unsigned char k;
-	bool data_test = opt->data_test;
-	bool drop = opt->drop_expected;
+	int i;
 
 	iov = calloc(iov_count, sizeof(struct iovec));
 	if (!iov)
 		return errno;
 
-	k = 0;
 	for (i = 0; i < iov_count; i++) {
 		unsigned char *d = calloc(iov_length, sizeof(char));
 
 		if (!d) {
 			fprintf(stderr, "iov_count %i/%i OOM\n", i, iov_count);
-			goto out_errno;
+			goto unwind_iov;
 		}
 		iov[i].iov_base = d;
 		iov[i].iov_len = iov_length;
 
-		if (data_test && tx) {
+		if (data && xmit) {
 			int j;
 
 			for (j = 0; j < iov_length; j++)
@@ -386,9 +395,60 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
 		}
 	}
 
-	msg.msg_iov = iov;
-	msg.msg_iovlen = iov_count;
-	k = 0;
+	msg->msg_iov = iov;
+	msg->msg_iovlen = iov_count;
+
+	return 0;
+unwind_iov:
+	for (i--; i >= 0 ; i--)
+		free(msg->msg_iov[i].iov_base);
+	return -ENOMEM;
+}
+
+static int msg_verify_data(struct msghdr *msg, int size, int chunk_sz)
+{
+	int i, j, bytes_cnt = 0;
+	unsigned char k = 0;
+
+	for (i = 0; i < msg->msg_iovlen; i++) {
+		unsigned char *d = msg->msg_iov[i].iov_base;
+
+		for (j = 0;
+		     j < msg->msg_iov[i].iov_len && size; j++) {
+			if (d[j] != k++) {
+				fprintf(stderr,
+					"detected data corruption @iov[%i]:%i %02x != %02x, %02x ?= %02x\n",
+					i, j, d[j], k - 1, d[j+1], k);
+				return -EIO;
+			}
+			bytes_cnt++;
+			if (bytes_cnt == chunk_sz) {
+				k = 0;
+				bytes_cnt = 0;
+			}
+			size--;
+		}
+	}
+	return 0;
+}
+
+static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
+		    struct msg_stats *s, bool tx,
+		    struct sockmap_options *opt)
+{
+	struct msghdr msg = {0}, msg_peek = {0};
+	int err, i, flags = MSG_NOSIGNAL;
+	bool drop = opt->drop_expected;
+	bool data = opt->data_test;
+
+	err = msg_alloc_iov(&msg, iov_count, iov_length, data, tx);
+	if (err)
+		goto out_errno;
+	if (peek_flag) {
+		err = msg_alloc_iov(&msg_peek, iov_count, iov_length, data, tx);
+		if (err)
+			goto out_errno;
+	}
 
 	if (tx) {
 		clock_gettime(CLOCK_MONOTONIC, &s->start);
@@ -408,19 +468,12 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
 		}
 		clock_gettime(CLOCK_MONOTONIC, &s->end);
 	} else {
-		int slct, recv, max_fd = fd;
+		int slct, recvp = 0, recv, max_fd = fd;
 		int fd_flags = O_NONBLOCK;
 		struct timeval timeout;
 		float total_bytes;
-		int bytes_cnt = 0;
-		int chunk_sz;
 		fd_set w;
 
-		if (opt->sendpage)
-			chunk_sz = iov_length * cnt;
-		else
-			chunk_sz = iov_length * iov_count;
-
 		fcntl(fd, fd_flags);
 		total_bytes = (float)iov_count * (float)iov_length * (float)cnt;
 		err = clock_gettime(CLOCK_MONOTONIC, &s->start);
@@ -452,6 +505,19 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
 				goto out_errno;
 			}
 
+			errno = 0;
+			if (peek_flag) {
+				flags |= MSG_PEEK;
+				recvp = recvmsg(fd, &msg_peek, flags);
+				if (recvp < 0) {
+					if (errno != EWOULDBLOCK) {
+						clock_gettime(CLOCK_MONOTONIC, &s->end);
+						goto out_errno;
+					}
+				}
+				flags = 0;
+			}
+
 			recv = recvmsg(fd, &msg, flags);
 			if (recv < 0) {
 				if (errno != EWOULDBLOCK) {
@@ -463,27 +529,23 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
 
 			s->bytes_recvd += recv;
 
-			if (data_test) {
-				int j;
-
-				for (i = 0; i < msg.msg_iovlen; i++) {
-					unsigned char *d = iov[i].iov_base;
-
-					for (j = 0;
-					     j < iov[i].iov_len && recv; j++) {
-						if (d[j] != k++) {
-							errno = -EIO;
-							fprintf(stderr,
-								"detected data corruption @iov[%i]:%i %02x != %02x, %02x ?= %02x\n",
-								i, j, d[j], k - 1, d[j+1], k);
-							goto out_errno;
-						}
-						bytes_cnt++;
-						if (bytes_cnt == chunk_sz) {
-							k = 0;
-							bytes_cnt = 0;
-						}
-						recv--;
+			if (data) {
+				int chunk_sz = opt->sendpage ?
+						iov_length * cnt :
+						iov_length * iov_count;
+
+				errno = msg_verify_data(&msg, recv, chunk_sz);
+				if (errno) {
+					perror("data verify msg failed\n");
+					goto out_errno;
+				}
+				if (recvp) {
+					errno = msg_verify_data(&msg_peek,
+								recvp,
+								chunk_sz);
+					if (errno) {
+						perror("data verify msg_peek failed\n");
+						goto out_errno;
 					}
 				}
 			}
@@ -491,14 +553,12 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
 		clock_gettime(CLOCK_MONOTONIC, &s->end);
 	}
 
-	for (i = 0; i < iov_count; i++)
-		free(iov[i].iov_base);
-	free(iov);
-	return 0;
+	msg_free_iov(&msg);
+	msg_free_iov(&msg_peek);
+	return err;
 out_errno:
-	for (i = 0; i < iov_count; i++)
-		free(iov[i].iov_base);
-	free(iov);
+	msg_free_iov(&msg);
+	msg_free_iov(&msg_peek);
 	return errno;
 }
 
@@ -565,9 +625,10 @@ static int sendmsg_test(struct sockmap_options *opt)
 		}
 		if (opt->verbose)
 			fprintf(stdout,
-				"rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB %fB/s %fGB/s\n",
+				"rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB %fB/s %fGB/s %s\n",
 				s.bytes_sent, sent_Bps, sent_Bps/giga,
-				s.bytes_recvd, recvd_Bps, recvd_Bps/giga);
+				s.bytes_recvd, recvd_Bps, recvd_Bps/giga,
+				peek_flag ? "(peek_msg)" : "");
 		if (err && txmsg_cork)
 			err = 0;
 		exit(err ? 1 : 0);
@@ -999,6 +1060,8 @@ static void test_options(char *options)
 		strncat(options, "skb,", OPTSTRING);
 	if (ktls)
 		strncat(options, "ktls,", OPTSTRING);
+	if (peek_flag)
+		strncat(options, "peek,", OPTSTRING);
 }
 
 static int __test_exec(int cgrp, int test, struct sockmap_options *opt)