Browse Source

Merge git://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf

Daniel Borkmann says:

====================
pull-request: bpf 2018-08-29

The following pull-request contains BPF updates for your *net* tree.

The main changes are:

1) Fix a build error in sk_reuseport_convert_ctx_access() when
   compiling with clang which cannot resolve hweight_long() at
   build time inside the BUILD_BUG_ON() assertion, from Stefan.

2) Several fixes for BPF sockmap, four of them in getting the
   bpf_msg_pull_data() helper to work, one use after free case
   in bpf_tcp_close() and one refcount leak in bpf_tcp_recvmsg(),
   from Daniel.

3) Another fix for BPF sockmap where we misaccount sk_mem_uncharge()
   in the socket redirect error case from unwinding scatterlist
   twice, from John.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
David S. Miller 7 years ago
parent
commit
6a5d39aa9a
2 changed files with 54 additions and 50 deletions
  1. 25 27
      kernel/bpf/sockmap.c
  2. 29 23
      net/core/filter.c

+ 25 - 27
kernel/bpf/sockmap.c

@@ -236,7 +236,7 @@ static int bpf_tcp_init(struct sock *sk)
 }
 
 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
-static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
 
 static void bpf_tcp_release(struct sock *sk)
 {
@@ -248,7 +248,7 @@ static void bpf_tcp_release(struct sock *sk)
 		goto out;
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 		psock->cork = NULL;
 	}
@@ -330,14 +330,14 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 	close_fun = psock->save_close;
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 		psock->cork = NULL;
 	}
 
 	list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
 		list_del(&md->list);
-		free_start_sg(psock->sock, md);
+		free_start_sg(psock->sock, md, true);
 		kfree(md);
 	}
 
@@ -369,7 +369,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 			/* If another thread deleted this object skip deletion.
 			 * The refcnt on psock may or may not be zero.
 			 */
-			if (l) {
+			if (l && l == link) {
 				hlist_del_rcu(&link->hash_node);
 				smap_release_sock(psock, link->sk);
 				free_htab_elem(htab, link);
@@ -570,14 +570,16 @@ static void free_bytes_sg(struct sock *sk, int bytes,
 	md->sg_start = i;
 }
 
-static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
+static int free_sg(struct sock *sk, int start,
+		   struct sk_msg_buff *md, bool charge)
 {
 	struct scatterlist *sg = md->sg_data;
 	int i = start, free = 0;
 
 	while (sg[i].length) {
 		free += sg[i].length;
-		sk_mem_uncharge(sk, sg[i].length);
+		if (charge)
+			sk_mem_uncharge(sk, sg[i].length);
 		if (!md->skb)
 			put_page(sg_page(&sg[i]));
 		sg[i].length = 0;
@@ -594,9 +596,9 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
 	return free;
 }
 
-static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
 {
-	int free = free_sg(sk, md->sg_start, md);
+	int free = free_sg(sk, md->sg_start, md, charge);
 
 	md->sg_start = md->sg_end;
 	return free;
@@ -604,7 +606,7 @@ static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
 
 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
 {
-	return free_sg(sk, md->sg_curr, md);
+	return free_sg(sk, md->sg_curr, md, true);
 }
 
 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
@@ -718,7 +720,7 @@ static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
 		list_add_tail(&r->list, &psock->ingress);
 		sk->sk_data_ready(sk);
 	} else {
-		free_start_sg(sk, r);
+		free_start_sg(sk, r, true);
 		kfree(r);
 	}
 
@@ -752,14 +754,10 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
 		release_sock(sk);
 	}
 	smap_release_sock(psock, sk);
-	if (unlikely(err))
-		goto out;
-	return 0;
+	return err;
 out_rcu:
 	rcu_read_unlock();
-out:
-	free_bytes_sg(NULL, send, md, false);
-	return err;
+	return 0;
 }
 
 static inline void bpf_md_init(struct smap_psock *psock)
@@ -822,7 +820,7 @@ more_data:
 	case __SK_PASS:
 		err = bpf_tcp_push(sk, send, m, flags, true);
 		if (unlikely(err)) {
-			*copied -= free_start_sg(sk, m);
+			*copied -= free_start_sg(sk, m, true);
 			break;
 		}
 
@@ -845,16 +843,17 @@ more_data:
 		lock_sock(sk);
 
 		if (unlikely(err < 0)) {
-			free_start_sg(sk, m);
+			int free = free_start_sg(sk, m, false);
+
 			psock->sg_size = 0;
 			if (!cork)
-				*copied -= send;
+				*copied -= free;
 		} else {
 			psock->sg_size -= send;
 		}
 
 		if (cork) {
-			free_start_sg(sk, m);
+			free_start_sg(sk, m, true);
 			psock->sg_size = 0;
 			kfree(m);
 			m = NULL;
@@ -912,6 +911,8 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 
 	if (unlikely(flags & MSG_ERRQUEUE))
 		return inet_recv_error(sk, msg, len, addr_len);
+	if (!skb_queue_empty(&sk->sk_receive_queue))
+		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 
 	rcu_read_lock();
 	psock = smap_psock_sk(sk);
@@ -922,9 +923,6 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		goto out;
 	rcu_read_unlock();
 
-	if (!skb_queue_empty(&sk->sk_receive_queue))
-		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
-
 	lock_sock(sk);
 bytes_ready:
 	while (copied != len) {
@@ -1122,7 +1120,7 @@ wait_for_memory:
 		err = sk_stream_wait_memory(sk, &timeo);
 		if (err) {
 			if (m && m != psock->cork)
-				free_start_sg(sk, m);
+				free_start_sg(sk, m, true);
 			goto out_err;
 		}
 	}
@@ -1581,13 +1579,13 @@ static void smap_gc_work(struct work_struct *w)
 		bpf_prog_put(psock->bpf_tx_msg);
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 	}
 
 	list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
 		list_del(&md->list);
-		free_start_sg(psock->sock, md);
+		free_start_sg(psock->sock, md, true);
 		kfree(md);
 	}
 

+ 29 - 23
net/core/filter.c

@@ -2282,14 +2282,21 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
 	.arg2_type      = ARG_ANYTHING,
 };
 
+#define sk_msg_iter_var(var)			\
+	do {					\
+		var++;				\
+		if (var == MAX_SKB_FRAGS)	\
+			var = 0;		\
+	} while (0)
+
 BPF_CALL_4(bpf_msg_pull_data,
 	   struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags)
 {
 	unsigned int len = 0, offset = 0, copy = 0;
+	int bytes = end - start, bytes_sg_total;
 	struct scatterlist *sg = msg->sg_data;
 	int first_sg, last_sg, i, shift;
 	unsigned char *p, *to, *from;
-	int bytes = end - start;
 	struct page *page;
 
 	if (unlikely(flags || end <= start))
@@ -2299,21 +2306,22 @@ BPF_CALL_4(bpf_msg_pull_data,
 	i = msg->sg_start;
 	do {
 		len = sg[i].length;
-		offset += len;
 		if (start < offset + len)
 			break;
-		i++;
-		if (i == MAX_SKB_FRAGS)
-			i = 0;
+		offset += len;
+		sk_msg_iter_var(i);
 	} while (i != msg->sg_end);
 
 	if (unlikely(start >= offset + len))
 		return -EINVAL;
 
-	if (!msg->sg_copy[i] && bytes <= len)
-		goto out;
-
 	first_sg = i;
+	/* The start may point into the sg element so we need to also
+	 * account for the headroom.
+	 */
+	bytes_sg_total = start - offset + bytes;
+	if (!msg->sg_copy[i] && bytes_sg_total <= len)
+		goto out;
 
 	/* At this point we need to linearize multiple scatterlist
 	 * elements or a single shared page. Either way we need to
@@ -2327,15 +2335,13 @@ BPF_CALL_4(bpf_msg_pull_data,
 	 */
 	do {
 		copy += sg[i].length;
-		i++;
-		if (i == MAX_SKB_FRAGS)
-			i = 0;
-		if (bytes < copy)
+		sk_msg_iter_var(i);
+		if (bytes_sg_total <= copy)
 			break;
 	} while (i != msg->sg_end);
 	last_sg = i;
 
-	if (unlikely(copy < end - start))
+	if (unlikely(bytes_sg_total > copy))
 		return -EINVAL;
 
 	page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC, get_order(copy));
@@ -2355,9 +2361,7 @@ BPF_CALL_4(bpf_msg_pull_data,
 		sg[i].length = 0;
 		put_page(sg_page(&sg[i]));
 
-		i++;
-		if (i == MAX_SKB_FRAGS)
-			i = 0;
+		sk_msg_iter_var(i);
 	} while (i != last_sg);
 
 	sg[first_sg].length = copy;
@@ -2367,11 +2371,15 @@ BPF_CALL_4(bpf_msg_pull_data,
 	 * had a single entry though we can just replace it and
 	 * be done. Otherwise walk the ring and shift the entries.
 	 */
-	shift = last_sg - first_sg - 1;
+	WARN_ON_ONCE(last_sg == first_sg);
+	shift = last_sg > first_sg ?
+		last_sg - first_sg - 1 :
+		MAX_SKB_FRAGS - first_sg + last_sg - 1;
 	if (!shift)
 		goto out;
 
-	i = first_sg + 1;
+	i = first_sg;
+	sk_msg_iter_var(i);
 	do {
 		int move_from;
 
@@ -2388,15 +2396,13 @@ BPF_CALL_4(bpf_msg_pull_data,
 		sg[move_from].page_link = 0;
 		sg[move_from].offset = 0;
 
-		i++;
-		if (i == MAX_SKB_FRAGS)
-			i = 0;
+		sk_msg_iter_var(i);
 	} while (1);
 	msg->sg_end -= shift;
 	if (msg->sg_end < 0)
 		msg->sg_end += MAX_SKB_FRAGS;
 out:
-	msg->data = sg_virt(&sg[i]) + start - offset;
+	msg->data = sg_virt(&sg[first_sg]) + start - offset;
 	msg->data_end = msg->data + bytes;
 
 	return 0;
@@ -7281,7 +7287,7 @@ static u32 sk_reuseport_convert_ctx_access(enum bpf_access_type type,
 		break;
 
 	case offsetof(struct sk_reuseport_md, ip_protocol):
-		BUILD_BUG_ON(hweight_long(SK_FL_PROTO_MASK) != BITS_PER_BYTE);
+		BUILD_BUG_ON(HWEIGHT32(SK_FL_PROTO_MASK) != BITS_PER_BYTE);
 		SK_REUSEPORT_LOAD_SK_FIELD_SIZE_OFF(__sk_flags_offset,
 						    BPF_W, 0);
 		*insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK);