浏览代码

IB/rdma_cm: Add wrapper for cma reference count

Currently, cma users can't increase or decrease the cma reference
count. This is necassary when setting cma attributes (like the
default GID type) in order to avoid use-after-free errors.
Adding cma_ref_dev and cma_deref_dev APIs.

Signed-off-by: Matan Barak <matanb@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
Matan Barak 9 年之前
父节点
当前提交
218a773f76
共有 2 个文件被更改,包括 13 次插入2 次删除
  1. 9 2
      drivers/infiniband/core/cma.c
  2. 4 0
      drivers/infiniband/core/core_priv.h

+ 9 - 2
drivers/infiniband/core/cma.c

@@ -60,6 +60,8 @@
 #include <rdma/ib_sa.h>
 #include <rdma/iw_cm.h>
 
+#include "core_priv.h"
+
 MODULE_AUTHOR("Sean Hefty");
 MODULE_DESCRIPTION("Generic RDMA CM Agent");
 MODULE_LICENSE("Dual BSD/GPL");
@@ -185,6 +187,11 @@ enum {
 	CMA_OPTION_AFONLY,
 };
 
+void cma_ref_dev(struct cma_device *cma_dev)
+{
+	atomic_inc(&cma_dev->refcount);
+}
+
 /*
  * Device removal can occur at anytime, so we need extra handling to
  * serialize notifying the user of device removal with other callbacks.
@@ -339,7 +346,7 @@ static inline void cma_set_ip_ver(struct cma_hdr *hdr, u8 ip_ver)
 static void cma_attach_to_dev(struct rdma_id_private *id_priv,
 			      struct cma_device *cma_dev)
 {
-	atomic_inc(&cma_dev->refcount);
+	cma_ref_dev(cma_dev);
 	id_priv->cma_dev = cma_dev;
 	id_priv->id.device = cma_dev->device;
 	id_priv->id.route.addr.dev_addr.transport =
@@ -347,7 +354,7 @@ static void cma_attach_to_dev(struct rdma_id_private *id_priv,
 	list_add_tail(&id_priv->list, &cma_dev->id_list);
 }
 
-static inline void cma_deref_dev(struct cma_device *cma_dev)
+void cma_deref_dev(struct cma_device *cma_dev)
 {
 	if (atomic_dec_and_test(&cma_dev->refcount))
 		complete(&cma_dev->comp);

+ 4 - 0
drivers/infiniband/core/core_priv.h

@@ -38,6 +38,10 @@
 
 #include <rdma/ib_verbs.h>
 
+struct cma_device;
+void cma_ref_dev(struct cma_device *cma_dev);
+void cma_deref_dev(struct cma_device *cma_dev);
+
 int  ib_device_register_sysfs(struct ib_device *device,
 			      int (*port_callback)(struct ib_device *,
 						   u8, struct kobject *));