Browse Source

Merge branch 'net-l3mdev-Support-for-sockets-bound-to-enslaved-device'

David Ahern says:

====================
net: l3mdev: Support for sockets bound to enslaved device

A missing piece to the VRF puzzle is the ability to bind sockets to
devices enslaved to a VRF. This patch set adds the enslaved device
index, sdif, to IPv4 and IPv6 socket lookups. The end result for users
is the following scope options for services:

1. "global" services - sockets not bound to any device

   Allows 1 service to work across all network interfaces with
   connected sockets bound to the VRF the connection originates
   (Requires net.ipv4.tcp_l3mdev_accept=1 for TCP and
    net.ipv4.udp_l3mdev_accept=1 for UDP)

2. "VRF" local services - sockets bound to a VRF

   Sockets work across all network interfaces enslaved to a VRF but
   are limited to just the one VRF.

3. "device" services - sockets bound to a specific network interface

   Service works only through the one specific interface.

v3
- convert __inet_lookup_established in dccp_v4_err; missed in v2

v2
- remove sk_lookup struct and add sdif as an argument to existing
  functions

Changes since RFC:
- no significant logic changes; mainly whitespace cleanups
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
David S. Miller 8 years ago
parent
commit
9bcb5a572f

+ 2 - 1
include/linux/igmp.h

@@ -118,7 +118,8 @@ extern int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
 		struct ip_msfilter __user *optval, int __user *optlen);
 extern int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
 		struct group_filter __user *optval, int __user *optlen);
-extern int ip_mc_sf_allow(struct sock *sk, __be32 local, __be32 rmt, int dif);
+extern int ip_mc_sf_allow(struct sock *sk, __be32 local, __be32 rmt,
+			  int dif, int sdif);
 extern void ip_mc_init_dev(struct in_device *);
 extern void ip_mc_destroy_dev(struct in_device *);
 extern void ip_mc_up(struct in_device *);

+ 10 - 0
include/linux/ipv6.h

@@ -158,6 +158,16 @@ static inline bool inet6_is_jumbogram(const struct sk_buff *skb)
 	return !!(IP6CB(skb)->flags & IP6SKB_JUMBOGRAM);
 }
 
+/* can not be used in TCP layer after tcp_v6_fill_cb */
+static inline int inet6_sdif(const struct sk_buff *skb)
+{
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+	if (skb && ipv6_l3mdev_skb(IP6CB(skb)->flags))
+		return IP6CB(skb)->iif;
+#endif
+	return 0;
+}
+
 /* can not be used in TCP layer after tcp_v6_fill_cb */
 static inline bool inet6_exact_dif_match(struct net *net, struct sk_buff *skb)
 {

+ 13 - 9
include/net/inet6_hashtables.h

@@ -49,7 +49,8 @@ struct sock *__inet6_lookup_established(struct net *net,
 					const struct in6_addr *saddr,
 					const __be16 sport,
 					const struct in6_addr *daddr,
-					const u16 hnum, const int dif);
+					const u16 hnum, const int dif,
+					const int sdif);
 
 struct sock *inet6_lookup_listener(struct net *net,
 				   struct inet_hashinfo *hashinfo,
@@ -57,7 +58,8 @@ struct sock *inet6_lookup_listener(struct net *net,
 				   const struct in6_addr *saddr,
 				   const __be16 sport,
 				   const struct in6_addr *daddr,
-				   const unsigned short hnum, const int dif);
+				   const unsigned short hnum,
+				   const int dif, const int sdif);
 
 static inline struct sock *__inet6_lookup(struct net *net,
 					  struct inet_hashinfo *hashinfo,
@@ -66,24 +68,25 @@ static inline struct sock *__inet6_lookup(struct net *net,
 					  const __be16 sport,
 					  const struct in6_addr *daddr,
 					  const u16 hnum,
-					  const int dif,
+					  const int dif, const int sdif,
 					  bool *refcounted)
 {
 	struct sock *sk = __inet6_lookup_established(net, hashinfo, saddr,
-						sport, daddr, hnum, dif);
+						     sport, daddr, hnum,
+						     dif, sdif);
 	*refcounted = true;
 	if (sk)
 		return sk;
 	*refcounted = false;
 	return inet6_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
-				     daddr, hnum, dif);
+				     daddr, hnum, dif, sdif);
 }
 
 static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
 					      struct sk_buff *skb, int doff,
 					      const __be16 sport,
 					      const __be16 dport,
-					      int iif,
+					      int iif, int sdif,
 					      bool *refcounted)
 {
 	struct sock *sk = skb_steal_sock(skb);
@@ -95,7 +98,7 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
 	return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
 			      doff, &ipv6_hdr(skb)->saddr, sport,
 			      &ipv6_hdr(skb)->daddr, ntohs(dport),
-			      iif, refcounted);
+			      iif, sdif, refcounted);
 }
 
 struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
@@ -107,13 +110,14 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
 int inet6_hash(struct sock *sk);
 #endif /* IS_ENABLED(CONFIG_IPV6) */
 
-#define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif)	\
+#define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif, __sdif) \
 	(((__sk)->sk_portpair == (__ports))			&&	\
 	 ((__sk)->sk_family == AF_INET6)			&&	\
 	 ipv6_addr_equal(&(__sk)->sk_v6_daddr, (__saddr))		&&	\
 	 ipv6_addr_equal(&(__sk)->sk_v6_rcv_saddr, (__daddr))	&&	\
 	 (!(__sk)->sk_bound_dev_if	||				\
-	   ((__sk)->sk_bound_dev_if == (__dif))) 		&&	\
+	   ((__sk)->sk_bound_dev_if == (__dif))	||			\
+	   ((__sk)->sk_bound_dev_if == (__sdif)))		&&	\
 	 net_eq(sock_net(__sk), (__net)))
 
 #endif /* _INET6_HASHTABLES_H */

+ 17 - 14
include/net/inet_hashtables.h

@@ -221,16 +221,16 @@ struct sock *__inet_lookup_listener(struct net *net,
 				    const __be32 saddr, const __be16 sport,
 				    const __be32 daddr,
 				    const unsigned short hnum,
-				    const int dif);
+				    const int dif, const int sdif);
 
 static inline struct sock *inet_lookup_listener(struct net *net,
 		struct inet_hashinfo *hashinfo,
 		struct sk_buff *skb, int doff,
 		__be32 saddr, __be16 sport,
-		__be32 daddr, __be16 dport, int dif)
+		__be32 daddr, __be16 dport, int dif, int sdif)
 {
 	return __inet_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
-				      daddr, ntohs(dport), dif);
+				      daddr, ntohs(dport), dif, sdif);
 }
 
 /* Socket demux engine toys. */
@@ -262,22 +262,24 @@ static inline struct sock *inet_lookup_listener(struct net *net,
 				   (((__force __u64)(__be32)(__daddr)) << 32) | \
 				   ((__force __u64)(__be32)(__saddr)))
 #endif /* __BIG_ENDIAN */
-#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif)	\
+#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
 	(((__sk)->sk_portpair == (__ports))			&&	\
 	 ((__sk)->sk_addrpair == (__cookie))			&&	\
 	 (!(__sk)->sk_bound_dev_if	||				\
-	   ((__sk)->sk_bound_dev_if == (__dif))) 		&& 	\
+	   ((__sk)->sk_bound_dev_if == (__dif))			||	\
+	   ((__sk)->sk_bound_dev_if == (__sdif)))		&&	\
 	 net_eq(sock_net(__sk), (__net)))
 #else /* 32-bit arch */
 #define INET_ADDR_COOKIE(__name, __saddr, __daddr) \
 	const int __name __deprecated __attribute__((unused))
 
-#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif) \
+#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
 	(((__sk)->sk_portpair == (__ports))		&&		\
 	 ((__sk)->sk_daddr	== (__saddr))		&&		\
 	 ((__sk)->sk_rcv_saddr	== (__daddr))		&&		\
 	 (!(__sk)->sk_bound_dev_if	||				\
-	   ((__sk)->sk_bound_dev_if == (__dif))) 	&&		\
+	   ((__sk)->sk_bound_dev_if == (__dif))		||		\
+	   ((__sk)->sk_bound_dev_if == (__sdif)))	&&		\
 	 net_eq(sock_net(__sk), (__net)))
 #endif /* 64-bit arch */
 
@@ -288,7 +290,7 @@ struct sock *__inet_lookup_established(struct net *net,
 				       struct inet_hashinfo *hashinfo,
 				       const __be32 saddr, const __be16 sport,
 				       const __be32 daddr, const u16 hnum,
-				       const int dif);
+				       const int dif, const int sdif);
 
 static inline struct sock *
 	inet_lookup_established(struct net *net, struct inet_hashinfo *hashinfo,
@@ -297,7 +299,7 @@ static inline struct sock *
 				const int dif)
 {
 	return __inet_lookup_established(net, hashinfo, saddr, sport, daddr,
-					 ntohs(dport), dif);
+					 ntohs(dport), dif, 0);
 }
 
 static inline struct sock *__inet_lookup(struct net *net,
@@ -305,20 +307,20 @@ static inline struct sock *__inet_lookup(struct net *net,
 					 struct sk_buff *skb, int doff,
 					 const __be32 saddr, const __be16 sport,
 					 const __be32 daddr, const __be16 dport,
-					 const int dif,
+					 const int dif, const int sdif,
 					 bool *refcounted)
 {
 	u16 hnum = ntohs(dport);
 	struct sock *sk;
 
 	sk = __inet_lookup_established(net, hashinfo, saddr, sport,
-				       daddr, hnum, dif);
+				       daddr, hnum, dif, sdif);
 	*refcounted = true;
 	if (sk)
 		return sk;
 	*refcounted = false;
 	return __inet_lookup_listener(net, hashinfo, skb, doff, saddr,
-				      sport, daddr, hnum, dif);
+				      sport, daddr, hnum, dif, sdif);
 }
 
 static inline struct sock *inet_lookup(struct net *net,
@@ -332,7 +334,7 @@ static inline struct sock *inet_lookup(struct net *net,
 	bool refcounted;
 
 	sk = __inet_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
-			   dport, dif, &refcounted);
+			   dport, dif, 0, &refcounted);
 
 	if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
 		sk = NULL;
@@ -344,6 +346,7 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
 					     int doff,
 					     const __be16 sport,
 					     const __be16 dport,
+					     const int sdif,
 					     bool *refcounted)
 {
 	struct sock *sk = skb_steal_sock(skb);
@@ -355,7 +358,7 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
 
 	return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
 			     doff, iph->saddr, sport,
-			     iph->daddr, dport, inet_iif(skb),
+			     iph->daddr, dport, inet_iif(skb), sdif,
 			     refcounted);
 }
 

+ 10 - 0
include/net/ip.h

@@ -78,6 +78,16 @@ struct ipcm_cookie {
 #define IPCB(skb) ((struct inet_skb_parm*)((skb)->cb))
 #define PKTINFO_SKB_CB(skb) ((struct in_pktinfo *)((skb)->cb))
 
+/* return enslaved device index if relevant */
+static inline int inet_sdif(struct sk_buff *skb)
+{
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+	if (skb && ipv4_l3mdev_skb(IPCB(skb)->flags))
+		return IPCB(skb)->iif;
+#endif
+	return 0;
+}
+
 struct ip_ra_chain {
 	struct ip_ra_chain __rcu *next;
 	struct sock		*sk;

+ 1 - 1
include/net/raw.h

@@ -26,7 +26,7 @@ extern struct proto raw_prot;
 extern struct raw_hashinfo raw_v4_hashinfo;
 struct sock *__raw_v4_lookup(struct net *net, struct sock *sk,
 			     unsigned short num, __be32 raddr,
-			     __be32 laddr, int dif);
+			     __be32 laddr, int dif, int sdif);
 
 int raw_abort(struct sock *sk, int err);
 void raw_icmp_error(struct sk_buff *, int, u32);

+ 1 - 1
include/net/rawv6.h

@@ -6,7 +6,7 @@
 extern struct raw_hashinfo raw_v6_hashinfo;
 struct sock *__raw_v6_lookup(struct net *net, struct sock *sk,
 			     unsigned short num, const struct in6_addr *loc_addr,
-			     const struct in6_addr *rmt_addr, int dif);
+			     const struct in6_addr *rmt_addr, int dif, int sdif);
 
 int raw_abort(struct sock *sk, int err);
 

+ 20 - 0
include/net/tcp.h

@@ -827,6 +827,16 @@ static inline int tcp_v6_iif(const struct sk_buff *skb)
 
 	return l3_slave ? skb->skb_iif : TCP_SKB_CB(skb)->header.h6.iif;
 }
+
+/* TCP_SKB_CB reference means this can not be used from early demux */
+static inline int tcp_v6_sdif(const struct sk_buff *skb)
+{
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+	if (skb && ipv6_l3mdev_skb(TCP_SKB_CB(skb)->header.h6.flags))
+		return TCP_SKB_CB(skb)->header.h6.iif;
+#endif
+	return 0;
+}
 #endif
 
 /* TCP_SKB_CB reference means this can not be used from early demux */
@@ -840,6 +850,16 @@ static inline bool inet_exact_dif_match(struct net *net, struct sk_buff *skb)
 	return false;
 }
 
+/* TCP_SKB_CB reference means this can not be used from early demux */
+static inline int tcp_v4_sdif(struct sk_buff *skb)
+{
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+	if (skb && ipv4_l3mdev_skb(TCP_SKB_CB(skb)->header.h4.flags))
+		return TCP_SKB_CB(skb)->header.h4.iif;
+#endif
+	return 0;
+}
+
 /* Due to TSO, an SKB can be composed of multiple actual
  * packets.  To keep these tracked properly, we use this.
  */

+ 2 - 2
include/net/udp.h

@@ -287,7 +287,7 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
 struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
 			     __be32 daddr, __be16 dport, int dif);
 struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
-			       __be32 daddr, __be16 dport, int dif,
+			       __be32 daddr, __be16 dport, int dif, int sdif,
 			       struct udp_table *tbl, struct sk_buff *skb);
 struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
 				 __be16 sport, __be16 dport);
@@ -298,7 +298,7 @@ struct sock *udp6_lib_lookup(struct net *net,
 struct sock *__udp6_lib_lookup(struct net *net,
 			       const struct in6_addr *saddr, __be16 sport,
 			       const struct in6_addr *daddr, __be16 dport,
-			       int dif, struct udp_table *tbl,
+			       int dif, int sdif, struct udp_table *tbl,
 			       struct sk_buff *skb);
 struct sock *udp6_lib_lookup_skb(struct sk_buff *skb,
 				 __be16 sport, __be16 dport);

+ 2 - 2
net/dccp/ipv4.c

@@ -256,7 +256,7 @@ static void dccp_v4_err(struct sk_buff *skb, u32 info)
 	sk = __inet_lookup_established(net, &dccp_hashinfo,
 				       iph->daddr, dh->dccph_dport,
 				       iph->saddr, ntohs(dh->dccph_sport),
-				       inet_iif(skb));
+				       inet_iif(skb), 0);
 	if (!sk) {
 		__ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
 		return;
@@ -804,7 +804,7 @@ static int dccp_v4_rcv(struct sk_buff *skb)
 
 lookup:
 	sk = __inet_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
-			       dh->dccph_sport, dh->dccph_dport, &refcounted);
+			       dh->dccph_sport, dh->dccph_dport, 0, &refcounted);
 	if (!sk) {
 		dccp_pr_debug("failed to look up flow ID in table and "
 			      "get corresponding socket\n");

+ 2 - 2
net/dccp/ipv6.c

@@ -89,7 +89,7 @@ static void dccp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
 	sk = __inet6_lookup_established(net, &dccp_hashinfo,
 					&hdr->daddr, dh->dccph_dport,
 					&hdr->saddr, ntohs(dh->dccph_sport),
-					inet6_iif(skb));
+					inet6_iif(skb), 0);
 
 	if (!sk) {
 		__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
@@ -687,7 +687,7 @@ static int dccp_v6_rcv(struct sk_buff *skb)
 lookup:
 	sk = __inet6_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
 			        dh->dccph_sport, dh->dccph_dport,
-				inet6_iif(skb), &refcounted);
+				inet6_iif(skb), 0, &refcounted);
 	if (!sk) {
 		dccp_pr_debug("failed to look up flow ID in table and "
 			      "get corresponding socket\n");

+ 4 - 2
net/ipv4/igmp.c

@@ -2549,7 +2549,8 @@ done:
 /*
  * check if a multicast source filter allows delivery for a given <src,dst,intf>
  */
-int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr, int dif)
+int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr,
+		   int dif, int sdif)
 {
 	struct inet_sock *inet = inet_sk(sk);
 	struct ip_mc_socklist *pmc;
@@ -2564,7 +2565,8 @@ int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr, int dif)
 	rcu_read_lock();
 	for_each_pmc_rcu(inet, pmc) {
 		if (pmc->multi.imr_multiaddr.s_addr == loc_addr &&
-		    pmc->multi.imr_ifindex == dif)
+		    (pmc->multi.imr_ifindex == dif ||
+		     (sdif && pmc->multi.imr_ifindex == sdif)))
 			break;
 	}
 	ret = inet->mc_all;

+ 17 - 10
net/ipv4/inet_hashtables.c

@@ -170,7 +170,7 @@ EXPORT_SYMBOL_GPL(__inet_inherit_port);
 
 static inline int compute_score(struct sock *sk, struct net *net,
 				const unsigned short hnum, const __be32 daddr,
-				const int dif, bool exact_dif)
+				const int dif, const int sdif, bool exact_dif)
 {
 	int score = -1;
 	struct inet_sock *inet = inet_sk(sk);
@@ -185,9 +185,13 @@ static inline int compute_score(struct sock *sk, struct net *net,
 			score += 4;
 		}
 		if (sk->sk_bound_dev_if || exact_dif) {
-			if (sk->sk_bound_dev_if != dif)
+			bool dev_match = (sk->sk_bound_dev_if == dif ||
+					  sk->sk_bound_dev_if == sdif);
+
+			if (exact_dif && !dev_match)
 				return -1;
-			score += 4;
+			if (sk->sk_bound_dev_if && dev_match)
+				score += 4;
 		}
 		if (sk->sk_incoming_cpu == raw_smp_processor_id())
 			score++;
@@ -208,7 +212,7 @@ struct sock *__inet_lookup_listener(struct net *net,
 				    struct sk_buff *skb, int doff,
 				    const __be32 saddr, __be16 sport,
 				    const __be32 daddr, const unsigned short hnum,
-				    const int dif)
+				    const int dif, const int sdif)
 {
 	unsigned int hash = inet_lhashfn(net, hnum);
 	struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
@@ -218,7 +222,8 @@ struct sock *__inet_lookup_listener(struct net *net,
 	u32 phash = 0;
 
 	sk_for_each_rcu(sk, &ilb->head) {
-		score = compute_score(sk, net, hnum, daddr, dif, exact_dif);
+		score = compute_score(sk, net, hnum, daddr,
+				      dif, sdif, exact_dif);
 		if (score > hiscore) {
 			reuseport = sk->sk_reuseport;
 			if (reuseport) {
@@ -268,7 +273,7 @@ struct sock *__inet_lookup_established(struct net *net,
 				  struct inet_hashinfo *hashinfo,
 				  const __be32 saddr, const __be16 sport,
 				  const __be32 daddr, const u16 hnum,
-				  const int dif)
+				  const int dif, const int sdif)
 {
 	INET_ADDR_COOKIE(acookie, saddr, daddr);
 	const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
@@ -286,11 +291,12 @@ begin:
 		if (sk->sk_hash != hash)
 			continue;
 		if (likely(INET_MATCH(sk, net, acookie,
-				      saddr, daddr, ports, dif))) {
+				      saddr, daddr, ports, dif, sdif))) {
 			if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
 				goto out;
 			if (unlikely(!INET_MATCH(sk, net, acookie,
-						 saddr, daddr, ports, dif))) {
+						 saddr, daddr, ports,
+						 dif, sdif))) {
 				sock_gen_put(sk);
 				goto begin;
 			}
@@ -321,9 +327,10 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
 	__be32 daddr = inet->inet_rcv_saddr;
 	__be32 saddr = inet->inet_daddr;
 	int dif = sk->sk_bound_dev_if;
+	struct net *net = sock_net(sk);
+	int sdif = l3mdev_master_ifindex_by_index(net, dif);
 	INET_ADDR_COOKIE(acookie, saddr, daddr);
 	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
-	struct net *net = sock_net(sk);
 	unsigned int hash = inet_ehashfn(net, daddr, lport,
 					 saddr, inet->inet_dport);
 	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
@@ -339,7 +346,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
 			continue;
 
 		if (likely(INET_MATCH(sk2, net, acookie,
-					 saddr, daddr, ports, dif))) {
+					 saddr, daddr, ports, dif, sdif))) {
 			if (sk2->sk_state == TCP_TIME_WAIT) {
 				tw = inet_twsk(sk2);
 				if (twsk_unique(sk, sk2, twp))

+ 12 - 6
net/ipv4/raw.c

@@ -122,7 +122,8 @@ void raw_unhash_sk(struct sock *sk)
 EXPORT_SYMBOL_GPL(raw_unhash_sk);
 
 struct sock *__raw_v4_lookup(struct net *net, struct sock *sk,
-		unsigned short num, __be32 raddr, __be32 laddr, int dif)
+			     unsigned short num, __be32 raddr, __be32 laddr,
+			     int dif, int sdif)
 {
 	sk_for_each_from(sk) {
 		struct inet_sock *inet = inet_sk(sk);
@@ -130,7 +131,8 @@ struct sock *__raw_v4_lookup(struct net *net, struct sock *sk,
 		if (net_eq(sock_net(sk), net) && inet->inet_num == num	&&
 		    !(inet->inet_daddr && inet->inet_daddr != raddr) 	&&
 		    !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) &&
-		    !(sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif))
+		    !(sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif &&
+		      sk->sk_bound_dev_if != sdif))
 			goto found; /* gotcha */
 	}
 	sk = NULL;
@@ -171,6 +173,7 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb)
  */
 static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
 {
+	int sdif = inet_sdif(skb);
 	struct sock *sk;
 	struct hlist_head *head;
 	int delivered = 0;
@@ -184,13 +187,13 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
 	net = dev_net(skb->dev);
 	sk = __raw_v4_lookup(net, __sk_head(head), iph->protocol,
 			     iph->saddr, iph->daddr,
-			     skb->dev->ifindex);
+			     skb->dev->ifindex, sdif);
 
 	while (sk) {
 		delivered = 1;
 		if ((iph->protocol != IPPROTO_ICMP || !icmp_filter(sk, skb)) &&
 		    ip_mc_sf_allow(sk, iph->daddr, iph->saddr,
-				   skb->dev->ifindex)) {
+				   skb->dev->ifindex, sdif)) {
 			struct sk_buff *clone = skb_clone(skb, GFP_ATOMIC);
 
 			/* Not releasing hash table! */
@@ -199,7 +202,7 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
 		}
 		sk = __raw_v4_lookup(net, sk_next(sk), iph->protocol,
 				     iph->saddr, iph->daddr,
-				     skb->dev->ifindex);
+				     skb->dev->ifindex, sdif);
 	}
 out:
 	read_unlock(&raw_v4_hashinfo.lock);
@@ -297,12 +300,15 @@ void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info)
 	read_lock(&raw_v4_hashinfo.lock);
 	raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]);
 	if (raw_sk) {
+		int dif = skb->dev->ifindex;
+		int sdif = inet_sdif(skb);
+
 		iph = (const struct iphdr *)skb->data;
 		net = dev_net(skb->dev);
 
 		while ((raw_sk = __raw_v4_lookup(net, raw_sk, protocol,
 						iph->daddr, iph->saddr,
-						skb->dev->ifindex)) != NULL) {
+						dif, sdif)) != NULL) {
 			raw_err(raw_sk, skb, info);
 			raw_sk = sk_next(raw_sk);
 			iph = (const struct iphdr *)skb->data;

+ 2 - 2
net/ipv4/raw_diag.c

@@ -46,13 +46,13 @@ static struct sock *raw_lookup(struct net *net, struct sock *from,
 		sk = __raw_v4_lookup(net, from, r->sdiag_raw_protocol,
 				     r->id.idiag_dst[0],
 				     r->id.idiag_src[0],
-				     r->id.idiag_if);
+				     r->id.idiag_if, 0);
 #if IS_ENABLED(CONFIG_IPV6)
 	else
 		sk = __raw_v6_lookup(net, from, r->sdiag_raw_protocol,
 				     (const struct in6_addr *)r->id.idiag_src,
 				     (const struct in6_addr *)r->id.idiag_dst,
-				     r->id.idiag_if);
+				     r->id.idiag_if, 0);
 #endif
 	return sk;
 }

+ 8 - 5
net/ipv4/tcp_ipv4.c

@@ -383,7 +383,7 @@ void tcp_v4_err(struct sk_buff *icmp_skb, u32 info)
 
 	sk = __inet_lookup_established(net, &tcp_hashinfo, iph->daddr,
 				       th->dest, iph->saddr, ntohs(th->source),
-				       inet_iif(icmp_skb));
+				       inet_iif(icmp_skb), 0);
 	if (!sk) {
 		__ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
 		return;
@@ -659,7 +659,8 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
 		sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0,
 					     ip_hdr(skb)->saddr,
 					     th->source, ip_hdr(skb)->daddr,
-					     ntohs(th->source), inet_iif(skb));
+					     ntohs(th->source), inet_iif(skb),
+					     tcp_v4_sdif(skb));
 		/* don't send rst if it can't find key */
 		if (!sk1)
 			goto out;
@@ -1523,7 +1524,7 @@ void tcp_v4_early_demux(struct sk_buff *skb)
 	sk = __inet_lookup_established(dev_net(skb->dev), &tcp_hashinfo,
 				       iph->saddr, th->source,
 				       iph->daddr, ntohs(th->dest),
-				       skb->skb_iif);
+				       skb->skb_iif, inet_sdif(skb));
 	if (sk) {
 		skb->sk = sk;
 		skb->destructor = sock_edemux;
@@ -1588,6 +1589,7 @@ EXPORT_SYMBOL(tcp_filter);
 int tcp_v4_rcv(struct sk_buff *skb)
 {
 	struct net *net = dev_net(skb->dev);
+	int sdif = inet_sdif(skb);
 	const struct iphdr *iph;
 	const struct tcphdr *th;
 	bool refcounted;
@@ -1638,7 +1640,7 @@ int tcp_v4_rcv(struct sk_buff *skb)
 
 lookup:
 	sk = __inet_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), th->source,
-			       th->dest, &refcounted);
+			       th->dest, sdif, &refcounted);
 	if (!sk)
 		goto no_tcp_socket;
 
@@ -1766,7 +1768,8 @@ do_time_wait:
 							__tcp_hdrlen(th),
 							iph->saddr, th->source,
 							iph->daddr, th->dest,
-							inet_iif(skb));
+							inet_iif(skb),
+							sdif);
 		if (sk2) {
 			inet_twsk_deschedule_put(inet_twsk(sk));
 			sk = sk2;

+ 38 - 28
net/ipv4/udp.c

@@ -380,8 +380,8 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum)
 
 static int compute_score(struct sock *sk, struct net *net,
 			 __be32 saddr, __be16 sport,
-			 __be32 daddr, unsigned short hnum, int dif,
-			 bool exact_dif)
+			 __be32 daddr, unsigned short hnum,
+			 int dif, int sdif, bool exact_dif)
 {
 	int score;
 	struct inet_sock *inet;
@@ -413,10 +413,15 @@ static int compute_score(struct sock *sk, struct net *net,
 	}
 
 	if (sk->sk_bound_dev_if || exact_dif) {
-		if (sk->sk_bound_dev_if != dif)
+		bool dev_match = (sk->sk_bound_dev_if == dif ||
+				  sk->sk_bound_dev_if == sdif);
+
+		if (exact_dif && !dev_match)
 			return -1;
-		score += 4;
+		if (sk->sk_bound_dev_if && dev_match)
+			score += 4;
 	}
+
 	if (sk->sk_incoming_cpu == raw_smp_processor_id())
 		score++;
 	return score;
@@ -436,10 +441,11 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
 
 /* called with rcu_read_lock() */
 static struct sock *udp4_lib_lookup2(struct net *net,
-		__be32 saddr, __be16 sport,
-		__be32 daddr, unsigned int hnum, int dif, bool exact_dif,
-		struct udp_hslot *hslot2,
-		struct sk_buff *skb)
+				     __be32 saddr, __be16 sport,
+				     __be32 daddr, unsigned int hnum,
+				     int dif, int sdif, bool exact_dif,
+				     struct udp_hslot *hslot2,
+				     struct sk_buff *skb)
 {
 	struct sock *sk, *result;
 	int score, badness, matches = 0, reuseport = 0;
@@ -449,7 +455,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
 	badness = 0;
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, exact_dif);
+				      daddr, hnum, dif, sdif, exact_dif);
 		if (score > badness) {
 			reuseport = sk->sk_reuseport;
 			if (reuseport) {
@@ -477,8 +483,8 @@ static struct sock *udp4_lib_lookup2(struct net *net,
  * harder than this. -DaveM
  */
 struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
-		__be16 sport, __be32 daddr, __be16 dport,
-		int dif, struct udp_table *udptable, struct sk_buff *skb)
+		__be16 sport, __be32 daddr, __be16 dport, int dif,
+		int sdif, struct udp_table *udptable, struct sk_buff *skb)
 {
 	struct sock *sk, *result;
 	unsigned short hnum = ntohs(dport);
@@ -496,7 +502,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 			goto begin;
 
 		result = udp4_lib_lookup2(net, saddr, sport,
-					  daddr, hnum, dif,
+					  daddr, hnum, dif, sdif,
 					  exact_dif, hslot2, skb);
 		if (!result) {
 			unsigned int old_slot2 = slot2;
@@ -511,7 +517,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 				goto begin;
 
 			result = udp4_lib_lookup2(net, saddr, sport,
-						  daddr, hnum, dif,
+						  daddr, hnum, dif, sdif,
 						  exact_dif, hslot2, skb);
 		}
 		return result;
@@ -521,7 +527,7 @@ begin:
 	badness = 0;
 	sk_for_each_rcu(sk, &hslot->head) {
 		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, exact_dif);
+				      daddr, hnum, dif, sdif, exact_dif);
 		if (score > badness) {
 			reuseport = sk->sk_reuseport;
 			if (reuseport) {
@@ -554,7 +560,7 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
 
 	return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport,
 				 iph->daddr, dport, inet_iif(skb),
-				 udptable, skb);
+				 inet_sdif(skb), udptable, skb);
 }
 
 struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
@@ -576,7 +582,7 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
 	struct sock *sk;
 
 	sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport,
-			       dif, &udp_table, NULL);
+			       dif, 0, &udp_table, NULL);
 	if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
 		sk = NULL;
 	return sk;
@@ -587,7 +593,7 @@ EXPORT_SYMBOL_GPL(udp4_lib_lookup);
 static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
 				       __be16 loc_port, __be32 loc_addr,
 				       __be16 rmt_port, __be32 rmt_addr,
-				       int dif, unsigned short hnum)
+				       int dif, int sdif, unsigned short hnum)
 {
 	struct inet_sock *inet = inet_sk(sk);
 
@@ -597,9 +603,10 @@ static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
 	    (inet->inet_dport != rmt_port && inet->inet_dport) ||
 	    (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
 	    ipv6_only_sock(sk) ||
-	    (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif))
+	    (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif &&
+	     sk->sk_bound_dev_if != sdif))
 		return false;
-	if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif))
+	if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif, sdif))
 		return false;
 	return true;
 }
@@ -628,8 +635,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
 	struct net *net = dev_net(skb->dev);
 
 	sk = __udp4_lib_lookup(net, iph->daddr, uh->dest,
-			iph->saddr, uh->source, skb->dev->ifindex, udptable,
-			NULL);
+			       iph->saddr, uh->source, skb->dev->ifindex, 0,
+			       udptable, NULL);
 	if (!sk) {
 		__ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
 		return;	/* No socket for error */
@@ -1953,6 +1960,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
 	unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10);
 	unsigned int offset = offsetof(typeof(*sk), sk_node);
 	int dif = skb->dev->ifindex;
+	int sdif = inet_sdif(skb);
 	struct hlist_node *node;
 	struct sk_buff *nskb;
 
@@ -1967,7 +1975,7 @@ start_lookup:
 
 	sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) {
 		if (!__udp_is_mcast_sock(net, sk, uh->dest, daddr,
-					 uh->source, saddr, dif, hnum))
+					 uh->source, saddr, dif, sdif, hnum))
 			continue;
 
 		if (!first) {
@@ -2157,7 +2165,7 @@ drop:
 static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
 						  __be16 loc_port, __be32 loc_addr,
 						  __be16 rmt_port, __be32 rmt_addr,
-						  int dif)
+						  int dif, int sdif)
 {
 	struct sock *sk, *result;
 	unsigned short hnum = ntohs(loc_port);
@@ -2171,7 +2179,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
 	result = NULL;
 	sk_for_each_rcu(sk, &hslot->head) {
 		if (__udp_is_mcast_sock(net, sk, loc_port, loc_addr,
-					rmt_port, rmt_addr, dif, hnum)) {
+					rmt_port, rmt_addr, dif, sdif, hnum)) {
 			if (result)
 				return NULL;
 			result = sk;
@@ -2188,7 +2196,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
 static struct sock *__udp4_lib_demux_lookup(struct net *net,
 					    __be16 loc_port, __be32 loc_addr,
 					    __be16 rmt_port, __be32 rmt_addr,
-					    int dif)
+					    int dif, int sdif)
 {
 	unsigned short hnum = ntohs(loc_port);
 	unsigned int hash2 = udp4_portaddr_hash(net, loc_addr, hnum);
@@ -2200,7 +2208,7 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net,
 
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		if (INET_MATCH(sk, net, acookie, rmt_addr,
-			       loc_addr, ports, dif))
+			       loc_addr, ports, dif, sdif))
 			return sk;
 		/* Only check first socket in chain */
 		break;
@@ -2216,6 +2224,7 @@ void udp_v4_early_demux(struct sk_buff *skb)
 	struct sock *sk = NULL;
 	struct dst_entry *dst;
 	int dif = skb->dev->ifindex;
+	int sdif = inet_sdif(skb);
 	int ours;
 
 	/* validate the packet */
@@ -2241,10 +2250,11 @@ void udp_v4_early_demux(struct sk_buff *skb)
 		}
 
 		sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,
-						   uh->source, iph->saddr, dif);
+						   uh->source, iph->saddr,
+						   dif, sdif);
 	} else if (skb->pkt_type == PACKET_HOST) {
 		sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr,
-					     uh->source, iph->saddr, dif);
+					     uh->source, iph->saddr, dif, sdif);
 	}
 
 	if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt))

+ 5 - 5
net/ipv4/udp_diag.c

@@ -45,7 +45,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
 		sk = __udp4_lib_lookup(net,
 				req->id.idiag_src[0], req->id.idiag_sport,
 				req->id.idiag_dst[0], req->id.idiag_dport,
-				req->id.idiag_if, tbl, NULL);
+				req->id.idiag_if, 0, tbl, NULL);
 #if IS_ENABLED(CONFIG_IPV6)
 	else if (req->sdiag_family == AF_INET6)
 		sk = __udp6_lib_lookup(net,
@@ -53,7 +53,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
 				req->id.idiag_sport,
 				(struct in6_addr *)req->id.idiag_dst,
 				req->id.idiag_dport,
-				req->id.idiag_if, tbl, NULL);
+				req->id.idiag_if, 0, tbl, NULL);
 #endif
 	if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
 		sk = NULL;
@@ -182,7 +182,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
 		sk = __udp4_lib_lookup(net,
 				req->id.idiag_dst[0], req->id.idiag_dport,
 				req->id.idiag_src[0], req->id.idiag_sport,
-				req->id.idiag_if, tbl, NULL);
+				req->id.idiag_if, 0, tbl, NULL);
 #if IS_ENABLED(CONFIG_IPV6)
 	else if (req->sdiag_family == AF_INET6) {
 		if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
@@ -190,7 +190,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
 			sk = __udp4_lib_lookup(net,
 					req->id.idiag_dst[3], req->id.idiag_dport,
 					req->id.idiag_src[3], req->id.idiag_sport,
-					req->id.idiag_if, tbl, NULL);
+					req->id.idiag_if, 0, tbl, NULL);
 
 		else
 			sk = __udp6_lib_lookup(net,
@@ -198,7 +198,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
 					req->id.idiag_dport,
 					(struct in6_addr *)req->id.idiag_src,
 					req->id.idiag_sport,
-					req->id.idiag_if, tbl, NULL);
+					req->id.idiag_if, 0, tbl, NULL);
 	}
 #endif
 	else {

+ 17 - 11
net/ipv6/inet6_hashtables.c

@@ -56,7 +56,7 @@ struct sock *__inet6_lookup_established(struct net *net,
 					   const __be16 sport,
 					   const struct in6_addr *daddr,
 					   const u16 hnum,
-					   const int dif)
+					   const int dif, const int sdif)
 {
 	struct sock *sk;
 	const struct hlist_nulls_node *node;
@@ -73,12 +73,12 @@ begin:
 	sk_nulls_for_each_rcu(sk, node, &head->chain) {
 		if (sk->sk_hash != hash)
 			continue;
-		if (!INET6_MATCH(sk, net, saddr, daddr, ports, dif))
+		if (!INET6_MATCH(sk, net, saddr, daddr, ports, dif, sdif))
 			continue;
 		if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
 			goto out;
 
-		if (unlikely(!INET6_MATCH(sk, net, saddr, daddr, ports, dif))) {
+		if (unlikely(!INET6_MATCH(sk, net, saddr, daddr, ports, dif, sdif))) {
 			sock_gen_put(sk);
 			goto begin;
 		}
@@ -96,7 +96,7 @@ EXPORT_SYMBOL(__inet6_lookup_established);
 static inline int compute_score(struct sock *sk, struct net *net,
 				const unsigned short hnum,
 				const struct in6_addr *daddr,
-				const int dif, bool exact_dif)
+				const int dif, const int sdif, bool exact_dif)
 {
 	int score = -1;
 
@@ -110,9 +110,13 @@ static inline int compute_score(struct sock *sk, struct net *net,
 			score++;
 		}
 		if (sk->sk_bound_dev_if || exact_dif) {
-			if (sk->sk_bound_dev_if != dif)
+			bool dev_match = (sk->sk_bound_dev_if == dif ||
+					  sk->sk_bound_dev_if == sdif);
+
+			if (exact_dif && !dev_match)
 				return -1;
-			score++;
+			if (sk->sk_bound_dev_if && dev_match)
+				score++;
 		}
 		if (sk->sk_incoming_cpu == raw_smp_processor_id())
 			score++;
@@ -126,7 +130,7 @@ struct sock *inet6_lookup_listener(struct net *net,
 		struct sk_buff *skb, int doff,
 		const struct in6_addr *saddr,
 		const __be16 sport, const struct in6_addr *daddr,
-		const unsigned short hnum, const int dif)
+		const unsigned short hnum, const int dif, const int sdif)
 {
 	unsigned int hash = inet_lhashfn(net, hnum);
 	struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
@@ -136,7 +140,7 @@ struct sock *inet6_lookup_listener(struct net *net,
 	u32 phash = 0;
 
 	sk_for_each(sk, &ilb->head) {
-		score = compute_score(sk, net, hnum, daddr, dif, exact_dif);
+		score = compute_score(sk, net, hnum, daddr, dif, sdif, exact_dif);
 		if (score > hiscore) {
 			reuseport = sk->sk_reuseport;
 			if (reuseport) {
@@ -171,7 +175,7 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
 	bool refcounted;
 
 	sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
-			    ntohs(dport), dif, &refcounted);
+			    ntohs(dport), dif, 0, &refcounted);
 	if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
 		sk = NULL;
 	return sk;
@@ -187,8 +191,9 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
 	const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
 	const struct in6_addr *saddr = &sk->sk_v6_daddr;
 	const int dif = sk->sk_bound_dev_if;
-	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
 	struct net *net = sock_net(sk);
+	const int sdif = l3mdev_master_ifindex_by_index(net, dif);
+	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
 	const unsigned int hash = inet6_ehashfn(net, daddr, lport, saddr,
 						inet->inet_dport);
 	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
@@ -203,7 +208,8 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
 		if (sk2->sk_hash != hash)
 			continue;
 
-		if (likely(INET6_MATCH(sk2, net, saddr, daddr, ports, dif))) {
+		if (likely(INET6_MATCH(sk2, net, saddr, daddr, ports,
+				       dif, sdif))) {
 			if (sk2->sk_state == TCP_TIME_WAIT) {
 				tw = inet_twsk(sk2);
 				if (twsk_unique(sk, sk2, twp))

+ 8 - 5
net/ipv6/raw.c

@@ -72,7 +72,7 @@ EXPORT_SYMBOL_GPL(raw_v6_hashinfo);
 
 struct sock *__raw_v6_lookup(struct net *net, struct sock *sk,
 		unsigned short num, const struct in6_addr *loc_addr,
-		const struct in6_addr *rmt_addr, int dif)
+		const struct in6_addr *rmt_addr, int dif, int sdif)
 {
 	bool is_multicast = ipv6_addr_is_multicast(loc_addr);
 
@@ -86,7 +86,9 @@ struct sock *__raw_v6_lookup(struct net *net, struct sock *sk,
 			    !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr))
 				continue;
 
-			if (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif)
+			if (sk->sk_bound_dev_if &&
+			    sk->sk_bound_dev_if != dif &&
+			    sk->sk_bound_dev_if != sdif)
 				continue;
 
 			if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr)) {
@@ -178,7 +180,8 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
 		goto out;
 
 	net = dev_net(skb->dev);
-	sk = __raw_v6_lookup(net, sk, nexthdr, daddr, saddr, inet6_iif(skb));
+	sk = __raw_v6_lookup(net, sk, nexthdr, daddr, saddr,
+			     inet6_iif(skb), inet6_sdif(skb));
 
 	while (sk) {
 		int filtered;
@@ -222,7 +225,7 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
 			}
 		}
 		sk = __raw_v6_lookup(net, sk_next(sk), nexthdr, daddr, saddr,
-				     inet6_iif(skb));
+				     inet6_iif(skb), inet6_sdif(skb));
 	}
 out:
 	read_unlock(&raw_v6_hashinfo.lock);
@@ -378,7 +381,7 @@ void raw6_icmp_error(struct sk_buff *skb, int nexthdr,
 		net = dev_net(skb->dev);
 
 		while ((sk = __raw_v6_lookup(net, sk, nexthdr, saddr, daddr,
-						inet6_iif(skb)))) {
+					     inet6_iif(skb), inet6_iif(skb)))) {
 			rawv6_err(sk, skb, NULL, type, code,
 					inner_offset, info);
 			sk = sk_next(sk);

+ 8 - 5
net/ipv6/tcp_ipv6.c

@@ -350,7 +350,7 @@ static void tcp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
 	sk = __inet6_lookup_established(net, &tcp_hashinfo,
 					&hdr->daddr, th->dest,
 					&hdr->saddr, ntohs(th->source),
-					skb->dev->ifindex);
+					skb->dev->ifindex, inet6_sdif(skb));
 
 	if (!sk) {
 		__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
@@ -918,7 +918,8 @@ static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb)
 					   &tcp_hashinfo, NULL, 0,
 					   &ipv6h->saddr,
 					   th->source, &ipv6h->daddr,
-					   ntohs(th->source), tcp_v6_iif(skb));
+					   ntohs(th->source), tcp_v6_iif(skb),
+					   tcp_v6_sdif(skb));
 		if (!sk1)
 			goto out;
 
@@ -1397,6 +1398,7 @@ static void tcp_v6_fill_cb(struct sk_buff *skb, const struct ipv6hdr *hdr,
 
 static int tcp_v6_rcv(struct sk_buff *skb)
 {
+	int sdif = inet6_sdif(skb);
 	const struct tcphdr *th;
 	const struct ipv6hdr *hdr;
 	bool refcounted;
@@ -1430,7 +1432,7 @@ static int tcp_v6_rcv(struct sk_buff *skb)
 
 lookup:
 	sk = __inet6_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th),
-				th->source, th->dest, inet6_iif(skb),
+				th->source, th->dest, inet6_iif(skb), sdif,
 				&refcounted);
 	if (!sk)
 		goto no_tcp_socket;
@@ -1563,7 +1565,8 @@ do_time_wait:
 					    skb, __tcp_hdrlen(th),
 					    &ipv6_hdr(skb)->saddr, th->source,
 					    &ipv6_hdr(skb)->daddr,
-					    ntohs(th->dest), tcp_v6_iif(skb));
+					    ntohs(th->dest), tcp_v6_iif(skb),
+					    sdif);
 		if (sk2) {
 			struct inet_timewait_sock *tw = inet_twsk(sk);
 			inet_twsk_deschedule_put(tw);
@@ -1610,7 +1613,7 @@ static void tcp_v6_early_demux(struct sk_buff *skb)
 	sk = __inet6_lookup_established(dev_net(skb->dev), &tcp_hashinfo,
 					&hdr->saddr, th->source,
 					&hdr->daddr, ntohs(th->dest),
-					inet6_iif(skb));
+					inet6_iif(skb), inet6_sdif(skb));
 	if (sk) {
 		skb->sk = sk;
 		skb->destructor = sock_edemux;

+ 26 - 21
net/ipv6/udp.c

@@ -129,7 +129,7 @@ static void udp_v6_rehash(struct sock *sk)
 static int compute_score(struct sock *sk, struct net *net,
 			 const struct in6_addr *saddr, __be16 sport,
 			 const struct in6_addr *daddr, unsigned short hnum,
-			 int dif, bool exact_dif)
+			 int dif, int sdif, bool exact_dif)
 {
 	int score;
 	struct inet_sock *inet;
@@ -161,9 +161,13 @@ static int compute_score(struct sock *sk, struct net *net,
 	}
 
 	if (sk->sk_bound_dev_if || exact_dif) {
-		if (sk->sk_bound_dev_if != dif)
+		bool dev_match = (sk->sk_bound_dev_if == dif ||
+				  sk->sk_bound_dev_if == sdif);
+
+		if (exact_dif && !dev_match)
 			return -1;
-		score++;
+		if (sk->sk_bound_dev_if && dev_match)
+			score++;
 	}
 
 	if (sk->sk_incoming_cpu == raw_smp_processor_id())
@@ -175,9 +179,9 @@ static int compute_score(struct sock *sk, struct net *net,
 /* called with rcu_read_lock() */
 static struct sock *udp6_lib_lookup2(struct net *net,
 		const struct in6_addr *saddr, __be16 sport,
-		const struct in6_addr *daddr, unsigned int hnum, int dif,
-		bool exact_dif, struct udp_hslot *hslot2,
-		struct sk_buff *skb)
+		const struct in6_addr *daddr, unsigned int hnum,
+		int dif, int sdif, bool exact_dif,
+		struct udp_hslot *hslot2, struct sk_buff *skb)
 {
 	struct sock *sk, *result;
 	int score, badness, matches = 0, reuseport = 0;
@@ -187,7 +191,7 @@ static struct sock *udp6_lib_lookup2(struct net *net,
 	badness = -1;
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, exact_dif);
+				      daddr, hnum, dif, sdif, exact_dif);
 		if (score > badness) {
 			reuseport = sk->sk_reuseport;
 			if (reuseport) {
@@ -214,10 +218,10 @@ static struct sock *udp6_lib_lookup2(struct net *net,
 
 /* rcu_read_lock() must be held */
 struct sock *__udp6_lib_lookup(struct net *net,
-				      const struct in6_addr *saddr, __be16 sport,
-				      const struct in6_addr *daddr, __be16 dport,
-				      int dif, struct udp_table *udptable,
-				      struct sk_buff *skb)
+			       const struct in6_addr *saddr, __be16 sport,
+			       const struct in6_addr *daddr, __be16 dport,
+			       int dif, int sdif, struct udp_table *udptable,
+			       struct sk_buff *skb)
 {
 	struct sock *sk, *result;
 	unsigned short hnum = ntohs(dport);
@@ -235,7 +239,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
 			goto begin;
 
 		result = udp6_lib_lookup2(net, saddr, sport,
-					  daddr, hnum, dif, exact_dif,
+					  daddr, hnum, dif, sdif, exact_dif,
 					  hslot2, skb);
 		if (!result) {
 			unsigned int old_slot2 = slot2;
@@ -250,7 +254,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
 				goto begin;
 
 			result = udp6_lib_lookup2(net, saddr, sport,
-						  daddr, hnum, dif,
+						  daddr, hnum, dif, sdif,
 						  exact_dif, hslot2,
 						  skb);
 		}
@@ -261,7 +265,7 @@ begin:
 	badness = -1;
 	sk_for_each_rcu(sk, &hslot->head) {
 		score = compute_score(sk, net, saddr, sport, daddr, hnum, dif,
-				      exact_dif);
+				      sdif, exact_dif);
 		if (score > badness) {
 			reuseport = sk->sk_reuseport;
 			if (reuseport) {
@@ -294,7 +298,7 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
 
 	return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport,
 				 &iph->daddr, dport, inet6_iif(skb),
-				 udptable, skb);
+				 inet6_sdif(skb), udptable, skb);
 }
 
 struct sock *udp6_lib_lookup_skb(struct sk_buff *skb,
@@ -304,7 +308,7 @@ struct sock *udp6_lib_lookup_skb(struct sk_buff *skb,
 
 	return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport,
 				 &iph->daddr, dport, inet6_iif(skb),
-				 &udp_table, skb);
+				 inet6_sdif(skb), &udp_table, skb);
 }
 EXPORT_SYMBOL_GPL(udp6_lib_lookup_skb);
 
@@ -320,7 +324,7 @@ struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be
 	struct sock *sk;
 
 	sk =  __udp6_lib_lookup(net, saddr, sport, daddr, dport,
-				dif, &udp_table, NULL);
+				dif, 0, &udp_table, NULL);
 	if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
 		sk = NULL;
 	return sk;
@@ -501,7 +505,7 @@ void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
 	struct net *net = dev_net(skb->dev);
 
 	sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source,
-			       inet6_iif(skb), udptable, skb);
+			       inet6_iif(skb), 0, udptable, skb);
 	if (!sk) {
 		__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
 				  ICMP6_MIB_INERRORS);
@@ -893,7 +897,7 @@ discard:
 static struct sock *__udp6_lib_demux_lookup(struct net *net,
 			__be16 loc_port, const struct in6_addr *loc_addr,
 			__be16 rmt_port, const struct in6_addr *rmt_addr,
-			int dif)
+			int dif, int sdif)
 {
 	unsigned short hnum = ntohs(loc_port);
 	unsigned int hash2 = udp6_portaddr_hash(net, loc_addr, hnum);
@@ -904,7 +908,7 @@ static struct sock *__udp6_lib_demux_lookup(struct net *net,
 
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		if (sk->sk_state == TCP_ESTABLISHED &&
-		    INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif))
+		    INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif, sdif))
 			return sk;
 		/* Only check first socket in chain */
 		break;
@@ -919,6 +923,7 @@ static void udp_v6_early_demux(struct sk_buff *skb)
 	struct sock *sk;
 	struct dst_entry *dst;
 	int dif = skb->dev->ifindex;
+	int sdif = inet6_sdif(skb);
 
 	if (!pskb_may_pull(skb, skb_transport_offset(skb) +
 	    sizeof(struct udphdr)))
@@ -930,7 +935,7 @@ static void udp_v6_early_demux(struct sk_buff *skb)
 		sk = __udp6_lib_demux_lookup(net, uh->dest,
 					     &ipv6_hdr(skb)->daddr,
 					     uh->source, &ipv6_hdr(skb)->saddr,
-					     dif);
+					     dif, sdif);
 	else
 		return;
 

+ 3 - 3
net/netfilter/xt_TPROXY.c

@@ -125,7 +125,7 @@ nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, void *hp,
 						      __tcp_hdrlen(tcph),
 						    saddr, sport,
 						    daddr, dport,
-						    in->ifindex);
+						    in->ifindex, 0);
 
 			if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
 				sk = NULL;
@@ -195,7 +195,7 @@ nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, void *hp,
 						   thoff + __tcp_hdrlen(tcph),
 						   saddr, sport,
 						   daddr, ntohs(dport),
-						   in->ifindex);
+						   in->ifindex, 0);
 
 			if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
 				sk = NULL;
@@ -208,7 +208,7 @@ nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, void *hp,
 		case NFT_LOOKUP_ESTABLISHED:
 			sk = __inet6_lookup_established(net, &tcp_hashinfo,
 							saddr, sport, daddr, ntohs(dport),
-							in->ifindex);
+							in->ifindex, 0);
 			break;
 		default:
 			BUG();