Browse Source

net, ipx: convert ipx_interface.refcnt from atomic_t to refcount_t

refcount_t type and corresponding API should be
used instead of atomic_t when the variable is used as
a reference counter. This allows to avoid accidental
refcounter overflows that might lead to use-after-free
situations.

Signed-off-by: Elena Reshetova <elena.reshetova@intel.com>
Signed-off-by: Hans Liljestrand <ishkamiel@gmail.com>
Signed-off-by: Kees Cook <keescook@chromium.org>
Signed-off-by: David Windsor <dwindsor@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Reshetova, Elena 8 years ago
parent
commit
d25189ca86
3 changed files with 8 additions and 7 deletions
  1. 4 3
      include/net/ipx.h
  2. 3 3
      net/ipx/af_ipx.c
  3. 1 1
      net/ipx/ipx_proc.c

+ 4 - 3
include/net/ipx.h

@@ -14,6 +14,7 @@
 #include <linux/ipx.h>
 #include <linux/ipx.h>
 #include <linux/list.h>
 #include <linux/list.h>
 #include <linux/slab.h>
 #include <linux/slab.h>
+#include <linux/refcount.h>
 
 
 struct ipx_address {
 struct ipx_address {
 	__be32  net;
 	__be32  net;
@@ -54,7 +55,7 @@ struct ipx_interface {
 	/* IPX address */
 	/* IPX address */
 	__be32			if_netnum;
 	__be32			if_netnum;
 	unsigned char		if_node[IPX_NODE_LEN];
 	unsigned char		if_node[IPX_NODE_LEN];
-	atomic_t		refcnt;
+	refcount_t		refcnt;
 
 
 	/* physical device info */
 	/* physical device info */
 	struct net_device	*if_dev;
 	struct net_device	*if_dev;
@@ -139,7 +140,7 @@ const char *ipx_device_name(struct ipx_interface *intrfc);
 
 
 static __inline__ void ipxitf_hold(struct ipx_interface *intrfc)
 static __inline__ void ipxitf_hold(struct ipx_interface *intrfc)
 {
 {
-	atomic_inc(&intrfc->refcnt);
+	refcount_inc(&intrfc->refcnt);
 }
 }
 
 
 void ipxitf_down(struct ipx_interface *intrfc);
 void ipxitf_down(struct ipx_interface *intrfc);
@@ -157,7 +158,7 @@ int ipxrtr_ioctl(unsigned int cmd, void __user *arg);
 
 
 static __inline__ void ipxitf_put(struct ipx_interface *intrfc)
 static __inline__ void ipxitf_put(struct ipx_interface *intrfc)
 {
 {
-	if (atomic_dec_and_test(&intrfc->refcnt))
+	if (refcount_dec_and_test(&intrfc->refcnt))
 		ipxitf_down(intrfc);
 		ipxitf_down(intrfc);
 }
 }
 
 

+ 3 - 3
net/ipx/af_ipx.c

@@ -308,7 +308,7 @@ void ipxitf_down(struct ipx_interface *intrfc)
 
 
 static void __ipxitf_put(struct ipx_interface *intrfc)
 static void __ipxitf_put(struct ipx_interface *intrfc)
 {
 {
-	if (atomic_dec_and_test(&intrfc->refcnt))
+	if (refcount_dec_and_test(&intrfc->refcnt))
 		__ipxitf_down(intrfc);
 		__ipxitf_down(intrfc);
 }
 }
 
 
@@ -876,7 +876,7 @@ static struct ipx_interface *ipxitf_alloc(struct net_device *dev, __be32 netnum,
 		intrfc->if_ipx_offset 	= ipx_offset;
 		intrfc->if_ipx_offset 	= ipx_offset;
 		intrfc->if_sknum 	= IPX_MIN_EPHEMERAL_SOCKET;
 		intrfc->if_sknum 	= IPX_MIN_EPHEMERAL_SOCKET;
 		INIT_HLIST_HEAD(&intrfc->if_sklist);
 		INIT_HLIST_HEAD(&intrfc->if_sklist);
-		atomic_set(&intrfc->refcnt, 1);
+		refcount_set(&intrfc->refcnt, 1);
 		spin_lock_init(&intrfc->if_sklist_lock);
 		spin_lock_init(&intrfc->if_sklist_lock);
 	}
 	}
 
 
@@ -1105,7 +1105,7 @@ static struct ipx_interface *ipxitf_auto_create(struct net_device *dev,
 		memcpy((char *)&(intrfc->if_node[IPX_NODE_LEN-dev->addr_len]),
 		memcpy((char *)&(intrfc->if_node[IPX_NODE_LEN-dev->addr_len]),
 			dev->dev_addr, dev->addr_len);
 			dev->dev_addr, dev->addr_len);
 		spin_lock_init(&intrfc->if_sklist_lock);
 		spin_lock_init(&intrfc->if_sklist_lock);
-		atomic_set(&intrfc->refcnt, 1);
+		refcount_set(&intrfc->refcnt, 1);
 		ipxitf_insert(intrfc);
 		ipxitf_insert(intrfc);
 		dev_hold(dev);
 		dev_hold(dev);
 	}
 	}

+ 1 - 1
net/ipx/ipx_proc.c

@@ -53,7 +53,7 @@ static int ipx_seq_interface_show(struct seq_file *seq, void *v)
 	seq_printf(seq, "%-11s", ipx_device_name(i));
 	seq_printf(seq, "%-11s", ipx_device_name(i));
 	seq_printf(seq, "%-9s", ipx_frame_name(i->if_dlink_type));
 	seq_printf(seq, "%-9s", ipx_frame_name(i->if_dlink_type));
 #ifdef IPX_REFCNT_DEBUG
 #ifdef IPX_REFCNT_DEBUG
-	seq_printf(seq, "%6d", atomic_read(&i->refcnt));
+	seq_printf(seq, "%6d", refcount_read(&i->refcnt));
 #endif
 #endif
 	seq_puts(seq, "\n");
 	seq_puts(seq, "\n");
 out:
 out: