Browse Source

Merge branch 'vhost-skb-leaks'

Wei Xu says:

====================
vhost: fix a few skb leaks

Matthew found a roughly 40% tcp throughput regression with commit
c67df11f(vhost_net: try batch dequing from skb array) as discussed
in the following thread:
https://www.mail-archive.com/netdev@vger.kernel.org/msg187936.html

v4:
- fix zero iov iterator count in tap/tap_do_read()(Jason)
- don't put tun in case of EBADFD(Jason)
- Replace msg->msg_control with new 'skb' when calling tun/tap_do_read()

v3:
- move freeing skb from vhost to tun/tap recvmsg() to not
  confuse the callers.

v2:
- add Matthew as the reporter, thanks matthew.
- moving zero headcount check ahead instead of defer consuming skb
  due to jason and mst's comment.
- add freeing skb in favor of recvmsg() fails.
====================

Acked-by: Michael S. Tsirkin <mst@redhat.com>
Tested-by: Matthew Rosato <mjrosato@linux.vnet.ibm.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
David S. Miller 7 years ago
parent
commit
7344ba039f
3 changed files with 38 additions and 20 deletions
  1. 10 4
      drivers/net/tap.c
  2. 18 6
      drivers/net/tun.c
  3. 10 10
      drivers/vhost/net.c

+ 10 - 4
drivers/net/tap.c

@@ -829,8 +829,11 @@ static ssize_t tap_do_read(struct tap_queue *q,
 	DEFINE_WAIT(wait);
 	ssize_t ret = 0;
 
-	if (!iov_iter_count(to))
+	if (!iov_iter_count(to)) {
+		if (skb)
+			kfree_skb(skb);
 		return 0;
+	}
 
 	if (skb)
 		goto put;
@@ -1154,11 +1157,14 @@ static int tap_recvmsg(struct socket *sock, struct msghdr *m,
 		       size_t total_len, int flags)
 {
 	struct tap_queue *q = container_of(sock, struct tap_queue, sock);
+	struct sk_buff *skb = m->msg_control;
 	int ret;
-	if (flags & ~(MSG_DONTWAIT|MSG_TRUNC))
+	if (flags & ~(MSG_DONTWAIT|MSG_TRUNC)) {
+		if (skb)
+			kfree_skb(skb);
 		return -EINVAL;
-	ret = tap_do_read(q, &m->msg_iter, flags & MSG_DONTWAIT,
-			  m->msg_control);
+	}
+	ret = tap_do_read(q, &m->msg_iter, flags & MSG_DONTWAIT, skb);
 	if (ret > total_len) {
 		m->msg_flags |= MSG_TRUNC;
 		ret = flags & MSG_TRUNC ? ret : total_len;

+ 18 - 6
drivers/net/tun.c

@@ -1952,8 +1952,11 @@ static ssize_t tun_do_read(struct tun_struct *tun, struct tun_file *tfile,
 
 	tun_debug(KERN_INFO, tun, "tun_do_read\n");
 
-	if (!iov_iter_count(to))
+	if (!iov_iter_count(to)) {
+		if (skb)
+			kfree_skb(skb);
 		return 0;
+	}
 
 	if (!skb) {
 		/* Read frames from ring */
@@ -2069,22 +2072,24 @@ static int tun_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len,
 {
 	struct tun_file *tfile = container_of(sock, struct tun_file, socket);
 	struct tun_struct *tun = tun_get(tfile);
+	struct sk_buff *skb = m->msg_control;
 	int ret;
 
-	if (!tun)
-		return -EBADFD;
+	if (!tun) {
+		ret = -EBADFD;
+		goto out_free_skb;
+	}
 
 	if (flags & ~(MSG_DONTWAIT|MSG_TRUNC|MSG_ERRQUEUE)) {
 		ret = -EINVAL;
-		goto out;
+		goto out_put_tun;
 	}
 	if (flags & MSG_ERRQUEUE) {
 		ret = sock_recv_errqueue(sock->sk, m, total_len,
 					 SOL_PACKET, TUN_TX_TIMESTAMP);
 		goto out;
 	}
-	ret = tun_do_read(tun, tfile, &m->msg_iter, flags & MSG_DONTWAIT,
-			  m->msg_control);
+	ret = tun_do_read(tun, tfile, &m->msg_iter, flags & MSG_DONTWAIT, skb);
 	if (ret > (ssize_t)total_len) {
 		m->msg_flags |= MSG_TRUNC;
 		ret = flags & MSG_TRUNC ? ret : total_len;
@@ -2092,6 +2097,13 @@ static int tun_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len,
 out:
 	tun_put(tun);
 	return ret;
+
+out_put_tun:
+	tun_put(tun);
+out_free_skb:
+	if (skb)
+		kfree_skb(skb);
+	return ret;
 }
 
 static int tun_peek_len(struct socket *sock)

+ 10 - 10
drivers/vhost/net.c

@@ -778,16 +778,6 @@ static void handle_rx(struct vhost_net *net)
 		/* On error, stop handling until the next kick. */
 		if (unlikely(headcount < 0))
 			goto out;
-		if (nvq->rx_array)
-			msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
-		/* On overrun, truncate and discard */
-		if (unlikely(headcount > UIO_MAXIOV)) {
-			iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
-			err = sock->ops->recvmsg(sock, &msg,
-						 1, MSG_DONTWAIT | MSG_TRUNC);
-			pr_debug("Discarded rx packet: len %zd\n", sock_len);
-			continue;
-		}
 		/* OK, now we need to know about added descriptors. */
 		if (!headcount) {
 			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
@@ -800,6 +790,16 @@ static void handle_rx(struct vhost_net *net)
 			 * they refilled. */
 			goto out;
 		}
+		if (nvq->rx_array)
+			msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
+		/* On overrun, truncate and discard */
+		if (unlikely(headcount > UIO_MAXIOV)) {
+			iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
+			err = sock->ops->recvmsg(sock, &msg,
+						 1, MSG_DONTWAIT | MSG_TRUNC);
+			pr_debug("Discarded rx packet: len %zd\n", sock_len);
+			continue;
+		}
 		/* We don't need to be notified again. */
 		iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len);
 		fixup = msg.msg_iter;