Browse Source

Merge branch 'listener_refactor_16'

Eric Dumazet says:

====================
tcp: listener refactor part 16

A CONFIG_PROVE_RCU=y build revealed an RCU splat I had to fix.

I added const qualifiers to various md5 methods, as I expect
to call them on behalf of request sock traffic even if
the listener socket is not locked. This seems ok, but adding
const makes the contract clearer. Note a good reduction
of code size thanks to request/establish sockets convergence.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
David S. Miller 10 years ago
parent
commit
b1275eb32e
5 changed files with 72 additions and 117 deletions
  1. 17 20
      include/net/tcp.h
  2. 13 12
      net/ipv4/tcp.c
  3. 15 38
      net/ipv4/tcp_ipv4.c
  4. 13 11
      net/ipv4/tcp_output.c
  5. 14 36
      net/ipv6/tcp_ipv6.c

+ 17 - 20
include/net/tcp.h

@@ -1296,15 +1296,14 @@ struct tcp_md5sig_pool {
 };
 
 /* - functions */
-int tcp_v4_md5_hash_skb(char *md5_hash, struct tcp_md5sig_key *key,
-			const struct sock *sk, const struct request_sock *req,
-			const struct sk_buff *skb);
+int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
+			const struct sock *sk, const struct sk_buff *skb);
 int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
 		   int family, const u8 *newkey, u8 newkeylen, gfp_t gfp);
 int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
 		   int family);
 struct tcp_md5sig_key *tcp_v4_md5_lookup(struct sock *sk,
-					 struct sock *addr_sk);
+					 const struct sock *addr_sk);
 
 #ifdef CONFIG_TCP_MD5SIG
 struct tcp_md5sig_key *tcp_md5_do_lookup(struct sock *sk,
@@ -1615,28 +1614,26 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops,
 struct tcp_sock_af_ops {
 #ifdef CONFIG_TCP_MD5SIG
 	struct tcp_md5sig_key	*(*md5_lookup) (struct sock *sk,
-						struct sock *addr_sk);
-	int			(*calc_md5_hash) (char *location,
-						  struct tcp_md5sig_key *md5,
-						  const struct sock *sk,
-						  const struct request_sock *req,
-						  const struct sk_buff *skb);
-	int			(*md5_parse) (struct sock *sk,
-					      char __user *optval,
-					      int optlen);
+						const struct sock *addr_sk);
+	int		(*calc_md5_hash)(char *location,
+					 const struct tcp_md5sig_key *md5,
+					 const struct sock *sk,
+					 const struct sk_buff *skb);
+	int		(*md5_parse)(struct sock *sk,
+				     char __user *optval,
+				     int optlen);
 #endif
 };
 
 struct tcp_request_sock_ops {
 	u16 mss_clamp;
 #ifdef CONFIG_TCP_MD5SIG
-	struct tcp_md5sig_key	*(*md5_lookup) (struct sock *sk,
-						struct request_sock *req);
-	int			(*calc_md5_hash) (char *location,
-						  struct tcp_md5sig_key *md5,
-						  const struct sock *sk,
-						  const struct request_sock *req,
-						  const struct sk_buff *skb);
+	struct tcp_md5sig_key *(*req_md5_lookup)(struct sock *sk,
+						 const struct sock *addr_sk);
+	int		(*calc_md5_hash) (char *location,
+					  const struct tcp_md5sig_key *md5,
+					  const struct sock *sk,
+					  const struct sk_buff *skb);
 #endif
 	void (*init_req)(struct request_sock *req, struct sock *sk,
 			 struct sk_buff *skb);

+ 13 - 12
net/ipv4/tcp.c

@@ -1913,18 +1913,19 @@ EXPORT_SYMBOL_GPL(tcp_set_state);
 
 static const unsigned char new_state[16] = {
   /* current state:        new state:      action:	*/
-  /* (Invalid)		*/ TCP_CLOSE,
-  /* TCP_ESTABLISHED	*/ TCP_FIN_WAIT1 | TCP_ACTION_FIN,
-  /* TCP_SYN_SENT	*/ TCP_CLOSE,
-  /* TCP_SYN_RECV	*/ TCP_FIN_WAIT1 | TCP_ACTION_FIN,
-  /* TCP_FIN_WAIT1	*/ TCP_FIN_WAIT1,
-  /* TCP_FIN_WAIT2	*/ TCP_FIN_WAIT2,
-  /* TCP_TIME_WAIT	*/ TCP_CLOSE,
-  /* TCP_CLOSE		*/ TCP_CLOSE,
-  /* TCP_CLOSE_WAIT	*/ TCP_LAST_ACK  | TCP_ACTION_FIN,
-  /* TCP_LAST_ACK	*/ TCP_LAST_ACK,
-  /* TCP_LISTEN		*/ TCP_CLOSE,
-  /* TCP_CLOSING	*/ TCP_CLOSING,
+  [0 /* (Invalid) */]	= TCP_CLOSE,
+  [TCP_ESTABLISHED]	= TCP_FIN_WAIT1 | TCP_ACTION_FIN,
+  [TCP_SYN_SENT]	= TCP_CLOSE,
+  [TCP_SYN_RECV]	= TCP_FIN_WAIT1 | TCP_ACTION_FIN,
+  [TCP_FIN_WAIT1]	= TCP_FIN_WAIT1,
+  [TCP_FIN_WAIT2]	= TCP_FIN_WAIT2,
+  [TCP_TIME_WAIT]	= TCP_CLOSE,
+  [TCP_CLOSE]		= TCP_CLOSE,
+  [TCP_CLOSE_WAIT]	= TCP_LAST_ACK  | TCP_ACTION_FIN,
+  [TCP_LAST_ACK]	= TCP_LAST_ACK,
+  [TCP_LISTEN]		= TCP_CLOSE,
+  [TCP_CLOSING]		= TCP_CLOSING,
+  [TCP_NEW_SYN_RECV]	= TCP_CLOSE,	/* should not happen ! */
 };
 
 static int tcp_close_state(struct sock *sk)

+ 15 - 38
net/ipv4/tcp_ipv4.c

@@ -648,7 +648,7 @@ static void tcp_v4_send_reset(struct sock *sk, struct sk_buff *skb)
 		if (!key)
 			goto release_sk1;
 
-		genhash = tcp_v4_md5_hash_skb(newhash, key, NULL, NULL, skb);
+		genhash = tcp_v4_md5_hash_skb(newhash, key, NULL, skb);
 		if (genhash || memcmp(hash_location, newhash, 16) != 0)
 			goto release_sk1;
 	} else {
@@ -898,10 +898,10 @@ struct tcp_md5sig_key *tcp_md5_do_lookup(struct sock *sk,
 					 const union tcp_md5_addr *addr,
 					 int family)
 {
-	struct tcp_sock *tp = tcp_sk(sk);
+	const struct tcp_sock *tp = tcp_sk(sk);
 	struct tcp_md5sig_key *key;
 	unsigned int size = sizeof(struct in_addr);
-	struct tcp_md5sig_info *md5sig;
+	const struct tcp_md5sig_info *md5sig;
 
 	/* caller either holds rcu_read_lock() or socket lock */
 	md5sig = rcu_dereference_check(tp->md5sig_info,
@@ -924,24 +924,15 @@ struct tcp_md5sig_key *tcp_md5_do_lookup(struct sock *sk,
 EXPORT_SYMBOL(tcp_md5_do_lookup);
 
 struct tcp_md5sig_key *tcp_v4_md5_lookup(struct sock *sk,
-					 struct sock *addr_sk)
+					 const struct sock *addr_sk)
 {
 	union tcp_md5_addr *addr;
 
-	addr = (union tcp_md5_addr *)&inet_sk(addr_sk)->inet_daddr;
+	addr = (union tcp_md5_addr *)&sk->sk_daddr;
 	return tcp_md5_do_lookup(sk, addr, AF_INET);
 }
 EXPORT_SYMBOL(tcp_v4_md5_lookup);
 
-static struct tcp_md5sig_key *tcp_v4_reqsk_md5_lookup(struct sock *sk,
-						      struct request_sock *req)
-{
-	union tcp_md5_addr *addr;
-
-	addr = (union tcp_md5_addr *)&inet_rsk(req)->ir_rmt_addr;
-	return tcp_md5_do_lookup(sk, addr, AF_INET);
-}
-
 /* This can be called on a newly created socket, from other files */
 int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
 		   int family, const u8 *newkey, u8 newkeylen, gfp_t gfp)
@@ -1102,8 +1093,8 @@ clear_hash_noput:
 	return 1;
 }
 
-int tcp_v4_md5_hash_skb(char *md5_hash, struct tcp_md5sig_key *key,
-			const struct sock *sk, const struct request_sock *req,
+int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
+			const struct sock *sk,
 			const struct sk_buff *skb)
 {
 	struct tcp_md5sig_pool *hp;
@@ -1111,12 +1102,9 @@ int tcp_v4_md5_hash_skb(char *md5_hash, struct tcp_md5sig_key *key,
 	const struct tcphdr *th = tcp_hdr(skb);
 	__be32 saddr, daddr;
 
-	if (sk) {
-		saddr = inet_sk(sk)->inet_saddr;
-		daddr = inet_sk(sk)->inet_daddr;
-	} else if (req) {
-		saddr = inet_rsk(req)->ir_loc_addr;
-		daddr = inet_rsk(req)->ir_rmt_addr;
+	if (sk) { /* valid for establish/request sockets */
+		saddr = sk->sk_rcv_saddr;
+		daddr = sk->sk_daddr;
 	} else {
 		const struct iphdr *iph = ip_hdr(skb);
 		saddr = iph->saddr;
@@ -1153,8 +1141,9 @@ clear_hash_noput:
 }
 EXPORT_SYMBOL(tcp_v4_md5_hash_skb);
 
-static bool __tcp_v4_inbound_md5_hash(struct sock *sk,
-				      const struct sk_buff *skb)
+/* Called with rcu_read_lock() */
+static bool tcp_v4_inbound_md5_hash(struct sock *sk,
+				    const struct sk_buff *skb)
 {
 	/*
 	 * This gets called for each TCP segment that arrives
@@ -1194,7 +1183,7 @@ static bool __tcp_v4_inbound_md5_hash(struct sock *sk,
 	 */
 	genhash = tcp_v4_md5_hash_skb(newhash,
 				      hash_expected,
-				      NULL, NULL, skb);
+				      NULL, skb);
 
 	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
 		net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s\n",
@@ -1206,18 +1195,6 @@ static bool __tcp_v4_inbound_md5_hash(struct sock *sk,
 	}
 	return false;
 }
-
-static bool tcp_v4_inbound_md5_hash(struct sock *sk, const struct sk_buff *skb)
-{
-	bool ret;
-
-	rcu_read_lock();
-	ret = __tcp_v4_inbound_md5_hash(sk, skb);
-	rcu_read_unlock();
-
-	return ret;
-}
-
 #endif
 
 static void tcp_v4_init_req(struct request_sock *req, struct sock *sk_listener,
@@ -1261,7 +1238,7 @@ struct request_sock_ops tcp_request_sock_ops __read_mostly = {
 static const struct tcp_request_sock_ops tcp_request_sock_ipv4_ops = {
 	.mss_clamp	=	TCP_MSS_DEFAULT,
 #ifdef CONFIG_TCP_MD5SIG
-	.md5_lookup	=	tcp_v4_reqsk_md5_lookup,
+	.req_md5_lookup	=	tcp_v4_md5_lookup,
 	.calc_md5_hash	=	tcp_v4_md5_hash_skb,
 #endif
 	.init_req	=	tcp_v4_init_req,

+ 13 - 11
net/ipv4/tcp_output.c

@@ -601,15 +601,14 @@ static unsigned int tcp_synack_options(struct sock *sk,
 				   struct request_sock *req,
 				   unsigned int mss, struct sk_buff *skb,
 				   struct tcp_out_options *opts,
-				   struct tcp_md5sig_key **md5,
+				   const struct tcp_md5sig_key *md5,
 				   struct tcp_fastopen_cookie *foc)
 {
 	struct inet_request_sock *ireq = inet_rsk(req);
 	unsigned int remaining = MAX_TCP_OPTION_SPACE;
 
 #ifdef CONFIG_TCP_MD5SIG
-	*md5 = tcp_rsk(req)->af_specific->md5_lookup(sk, req);
-	if (*md5) {
+	if (md5) {
 		opts->options |= OPTION_MD5;
 		remaining -= TCPOLEN_MD5SIG_ALIGNED;
 
@@ -620,8 +619,6 @@ static unsigned int tcp_synack_options(struct sock *sk,
 		 */
 		ireq->tstamp_ok &= !ireq->sack_ok;
 	}
-#else
-	*md5 = NULL;
 #endif
 
 	/* We always send an MSS option. */
@@ -989,7 +986,7 @@ static int tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, int clone_it,
 	if (md5) {
 		sk_nocaps_add(sk, NETIF_F_GSO_MASK);
 		tp->af_specific->calc_md5_hash(opts.hash_location,
-					       md5, sk, NULL, skb);
+					       md5, sk, skb);
 	}
 #endif
 
@@ -2913,7 +2910,7 @@ struct sk_buff *tcp_make_synack(struct sock *sk, struct dst_entry *dst,
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct tcphdr *th;
 	struct sk_buff *skb;
-	struct tcp_md5sig_key *md5;
+	struct tcp_md5sig_key *md5 = NULL;
 	int tcp_header_size;
 	int mss;
 
@@ -2938,7 +2935,12 @@ struct sk_buff *tcp_make_synack(struct sock *sk, struct dst_entry *dst,
 	else
 #endif
 	skb_mstamp_get(&skb->skb_mstamp);
-	tcp_header_size = tcp_synack_options(sk, req, mss, skb, &opts, &md5,
+
+#ifdef CONFIG_TCP_MD5SIG
+	rcu_read_lock();
+	md5 = tcp_rsk(req)->af_specific->req_md5_lookup(sk, req_to_sk(req));
+#endif
+	tcp_header_size = tcp_synack_options(sk, req, mss, skb, &opts, md5,
 					     foc) + sizeof(*th);
 
 	skb_push(skb, tcp_header_size);
@@ -2969,10 +2971,10 @@ struct sk_buff *tcp_make_synack(struct sock *sk, struct dst_entry *dst,
 
 #ifdef CONFIG_TCP_MD5SIG
 	/* Okay, we have all we need - do the md5 hash if needed */
-	if (md5) {
+	if (md5)
 		tcp_rsk(req)->af_specific->calc_md5_hash(opts.hash_location,
-					       md5, NULL, req, skb);
-	}
+					       md5, req_to_sk(req), skb);
+	rcu_read_unlock();
 #endif
 
 	return skb;

+ 14 - 36
net/ipv6/tcp_ipv6.c

@@ -486,17 +486,11 @@ static struct tcp_md5sig_key *tcp_v6_md5_do_lookup(struct sock *sk,
 }
 
 static struct tcp_md5sig_key *tcp_v6_md5_lookup(struct sock *sk,
-						struct sock *addr_sk)
+						const struct sock *addr_sk)
 {
 	return tcp_v6_md5_do_lookup(sk, &addr_sk->sk_v6_daddr);
 }
 
-static struct tcp_md5sig_key *tcp_v6_reqsk_md5_lookup(struct sock *sk,
-						      struct request_sock *req)
-{
-	return tcp_v6_md5_do_lookup(sk, &inet_rsk(req)->ir_v6_rmt_addr);
-}
-
 static int tcp_v6_parse_md5_keys(struct sock *sk, char __user *optval,
 				 int optlen)
 {
@@ -582,9 +576,9 @@ clear_hash_noput:
 	return 1;
 }
 
-static int tcp_v6_md5_hash_skb(char *md5_hash, struct tcp_md5sig_key *key,
+static int tcp_v6_md5_hash_skb(char *md5_hash,
+			       const struct tcp_md5sig_key *key,
 			       const struct sock *sk,
-			       const struct request_sock *req,
 			       const struct sk_buff *skb)
 {
 	const struct in6_addr *saddr, *daddr;
@@ -592,12 +586,9 @@ static int tcp_v6_md5_hash_skb(char *md5_hash, struct tcp_md5sig_key *key,
 	struct hash_desc *desc;
 	const struct tcphdr *th = tcp_hdr(skb);
 
-	if (sk) {
-		saddr = &inet6_sk(sk)->saddr;
+	if (sk) { /* valid for establish/request sockets */
+		saddr = &sk->sk_v6_rcv_saddr;
 		daddr = &sk->sk_v6_daddr;
-	} else if (req) {
-		saddr = &inet_rsk(req)->ir_v6_loc_addr;
-		daddr = &inet_rsk(req)->ir_v6_rmt_addr;
 	} else {
 		const struct ipv6hdr *ip6h = ipv6_hdr(skb);
 		saddr = &ip6h->saddr;
@@ -633,8 +624,7 @@ clear_hash_noput:
 	return 1;
 }
 
-static int __tcp_v6_inbound_md5_hash(struct sock *sk,
-				     const struct sk_buff *skb)
+static bool tcp_v6_inbound_md5_hash(struct sock *sk, const struct sk_buff *skb)
 {
 	const __u8 *hash_location = NULL;
 	struct tcp_md5sig_key *hash_expected;
@@ -648,44 +638,32 @@ static int __tcp_v6_inbound_md5_hash(struct sock *sk,
 
 	/* We've parsed the options - do we have a hash? */
 	if (!hash_expected && !hash_location)
-		return 0;
+		return false;
 
 	if (hash_expected && !hash_location) {
 		NET_INC_STATS_BH(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-		return 1;
+		return true;
 	}
 
 	if (!hash_expected && hash_location) {
 		NET_INC_STATS_BH(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-		return 1;
+		return true;
 	}
 
 	/* check the signature */
 	genhash = tcp_v6_md5_hash_skb(newhash,
 				      hash_expected,
-				      NULL, NULL, skb);
+				      NULL, skb);
 
 	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
 		net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u\n",
 				     genhash ? "failed" : "mismatch",
 				     &ip6h->saddr, ntohs(th->source),
 				     &ip6h->daddr, ntohs(th->dest));
-		return 1;
+		return true;
 	}
-	return 0;
+	return false;
 }
-
-static int tcp_v6_inbound_md5_hash(struct sock *sk, const struct sk_buff *skb)
-{
-	int ret;
-
-	rcu_read_lock();
-	ret = __tcp_v6_inbound_md5_hash(sk, skb);
-	rcu_read_unlock();
-
-	return ret;
-}
-
 #endif
 
 static void tcp_v6_init_req(struct request_sock *req, struct sock *sk,
@@ -736,7 +714,7 @@ static const struct tcp_request_sock_ops tcp_request_sock_ipv6_ops = {
 	.mss_clamp	=	IPV6_MIN_MTU - sizeof(struct tcphdr) -
 				sizeof(struct ipv6hdr),
 #ifdef CONFIG_TCP_MD5SIG
-	.md5_lookup	=	tcp_v6_reqsk_md5_lookup,
+	.req_md5_lookup	=	tcp_v6_md5_lookup,
 	.calc_md5_hash	=	tcp_v6_md5_hash_skb,
 #endif
 	.init_req	=	tcp_v6_init_req,
@@ -893,7 +871,7 @@ static void tcp_v6_send_reset(struct sock *sk, struct sk_buff *skb)
 		if (!key)
 			goto release_sk1;
 
-		genhash = tcp_v6_md5_hash_skb(newhash, key, NULL, NULL, skb);
+		genhash = tcp_v6_md5_hash_skb(newhash, key, NULL, skb);
 		if (genhash || memcmp(hash_location, newhash, 16) != 0)
 			goto release_sk1;
 	} else {