Merge branch 'concurrent_hash_tables'

Thomas Graf says:

====================
Lockless netlink_lookup() with new concurrent hash table

Netlink sockets are maintained in a hash table to allow efficient lookup
via the port ID for unicast messages. However, lookups currently require
a read lock to be taken. This series adds a new generic, resizable,
scalable, concurrent hash table based on the paper referenced in the first
patch. It then makes use of the new data type to implement lockless
netlink_lookup().

Patch 3/3 to convert nft_hash is included for reference but should be
merged via the netfilter tree. Inclusion in this series is to provide
context for the suggested API.

Against net-next since the initial user of the new hash table is in net/

Changes:
v4-v5:
 - use GFP_KERNEL to alloc Netlink buckets as suggested by Nikolay
   Aleksandrov
 - free nft hash element on removal as spotted by Nikolay Aleksandrov
   and Patrick McHardy
v3-v4:
 - fixed wrong shift assignment placement as spotted by Nikolay Aleksandrov
 - reverted default size of nft_hash to 4 as requested by Patrick McHardy,
   default size for other hash tables remains at 64 if no hint is given
 - fixed copyright as requested by Patrick McHardy
v2-v3:
 - fixed typo in nft_hash_destroy() when passing rhashtable handle
v1-v2:
 - fixed traversal off-by-one as spotted by Tobias Klauser
 - removed unlikely() from BUG_ON() as spotted by Josh Triplett
 - new 3rd patch to convert nft_hash to rhashtable
 - make rhashtable_insert() return void
 - nl_sk_hash_lock must be a mutex
 - fixed wrong name of rht_shrink_below_30()
 - exported symbols rht_grow_above_75() and rht_shrink_below_30()
 - allow table freeing with RCU callback
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
David S. Miller, 11 years ago
commit bae2e81a69

+ 213 - 0
include/linux/rhashtable.h

@@ -0,0 +1,213 @@
+/*
+ * Resizable, Scalable, Concurrent Hash Table
+ *
+ * Copyright (c) 2014 Thomas Graf <tgraf@suug.ch>
+ * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
+ *
+ * Based on the following paper by Josh Triplett, Paul E. McKenney
+ * and Jonathan Walpole:
+ * https://www.usenix.org/legacy/event/atc11/tech/final_files/Triplett.pdf
+ *
+ * Code partially derived from nft_hash
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2 as
+ * published by the Free Software Foundation.
+ */
+
+#ifndef _LINUX_RHASHTABLE_H
+#define _LINUX_RHASHTABLE_H
+
+#include <linux/rculist.h>
+
+struct rhash_head {
+	struct rhash_head		*next;
+};
+
+#define INIT_HASH_HEAD(ptr) ((ptr)->next = NULL)
+
+struct bucket_table {
+	size_t				size;
+	struct rhash_head __rcu		*buckets[];
+};
+
+typedef u32 (*rht_hashfn_t)(const void *data, u32 len, u32 seed);
+typedef u32 (*rht_obj_hashfn_t)(const void *data, u32 seed);
+
+struct rhashtable;
+
+/**
+ * struct rhashtable_params - Hash table construction parameters
+ * @nelem_hint: Hint on number of elements, should be 75% of desired size
+ * @key_len: Length of key
+ * @key_offset: Offset of key in struct to be hashed
+ * @head_offset: Offset of rhash_head in struct to be hashed
+ * @hash_rnd: Seed to use while hashing
+ * @max_shift: Maximum number of shifts while expanding
+ * @hashfn: Function to hash key
+ * @obj_hashfn: Function to hash object
+ * @grow_decision: If defined, may return true if table should expand
+ * @shrink_decision: If defined, may return true if table should shrink
+ * @mutex_is_held: Must return true if protecting mutex is held
+ */
+struct rhashtable_params {
+	size_t			nelem_hint;
+	size_t			key_len;
+	size_t			key_offset;
+	size_t			head_offset;
+	u32			hash_rnd;
+	size_t			max_shift;
+	rht_hashfn_t		hashfn;
+	rht_obj_hashfn_t	obj_hashfn;
+	bool			(*grow_decision)(const struct rhashtable *ht,
+						 size_t new_size);
+	bool			(*shrink_decision)(const struct rhashtable *ht,
+						   size_t new_size);
+	int			(*mutex_is_held)(void);
+};
+
+/**
+ * struct rhashtable - Hash table handle
+ * @tbl: Bucket table
+ * @nelems: Number of elements in table
+ * @shift: Current size (1 << shift)
+ * @p: Configuration parameters
+ */
+struct rhashtable {
+	struct bucket_table __rcu	*tbl;
+	size_t				nelems;
+	size_t				shift;
+	struct rhashtable_params	p;
+};
+
+#ifdef CONFIG_PROVE_LOCKING
+int lockdep_rht_mutex_is_held(const struct rhashtable *ht);
+#else
+static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht)
+{
+	return 1;
+}
+#endif /* CONFIG_PROVE_LOCKING */
+
+int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params);
+
+u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len);
+u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr);
+
+void rhashtable_insert(struct rhashtable *ht, struct rhash_head *node, gfp_t);
+bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *node, gfp_t);
+void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj,
+			     struct rhash_head **pprev, gfp_t flags);
+
+bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size);
+bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size);
+
+int rhashtable_expand(struct rhashtable *ht, gfp_t flags);
+int rhashtable_shrink(struct rhashtable *ht, gfp_t flags);
+
+void *rhashtable_lookup(const struct rhashtable *ht, const void *key);
+void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash,
+				bool (*compare)(void *, void *), void *arg);
+
+void rhashtable_destroy(const struct rhashtable *ht);
+
+#define rht_dereference(p, ht) \
+	rcu_dereference_protected(p, lockdep_rht_mutex_is_held(ht))
+
+#define rht_dereference_rcu(p, ht) \
+	rcu_dereference_check(p, lockdep_rht_mutex_is_held(ht))
+
+/* Internal, use rht_obj() instead */
+#define rht_entry(ptr, type, member) container_of(ptr, type, member)
+#define rht_entry_safe(ptr, type, member) \
+({ \
+	typeof(ptr) __ptr = (ptr); \
+	   __ptr ? rht_entry(__ptr, type, member) : NULL; \
+})
+#define rht_entry_safe_rcu(ptr, type, member) \
+({ \
+	typeof(*ptr) __rcu *__ptr = (typeof(*ptr) __rcu __force *)ptr; \
+	__ptr ? container_of((typeof(ptr))rcu_dereference_raw(__ptr), type, member) : NULL; \
+})
+
+#define rht_next_entry_safe(pos, ht, member) \
+({ \
+	pos ? rht_entry_safe(rht_dereference((pos)->member.next, ht), \
+			     typeof(*(pos)), member) : NULL; \
+})
+
+/**
+ * rht_for_each - iterate over hash chain
+ * @pos:	&struct rhash_head to use as a loop cursor.
+ * @head:	head of the hash chain (struct rhash_head *)
+ * @ht:		pointer to your struct rhashtable
+ */
+#define rht_for_each(pos, head, ht) \
+	for (pos = rht_dereference(head, ht); \
+	     pos; \
+	     pos = rht_dereference((pos)->next, ht))
+
+/**
+ * rht_for_each_entry - iterate over hash chain of given type
+ * @pos:	type * to use as a loop cursor.
+ * @head:	head of the hash chain (struct rhash_head *)
+ * @ht:		pointer to your struct rhashtable
+ * @member:	name of the rhash_head within the hashable struct.
+ */
+#define rht_for_each_entry(pos, head, ht, member) \
+	for (pos = rht_entry_safe(rht_dereference(head, ht), \
+				   typeof(*(pos)), member); \
+	     pos; \
+	     pos = rht_next_entry_safe(pos, ht, member))
+
+/**
+ * rht_for_each_entry_safe - safely iterate over hash chain of given type
+ * @pos:	type * to use as a loop cursor.
+ * @n:		type * to use for temporary next object storage
+ * @head:	head of the hash chain (struct rhash_head *)
+ * @ht:		pointer to your struct rhashtable
+ * @member:	name of the rhash_head within the hashable struct.
+ *
+ * This hash chain list-traversal primitive allows for the looped code to
+ * remove the loop cursor from the list.
+ */
+#define rht_for_each_entry_safe(pos, n, head, ht, member)		\
+	for (pos = rht_entry_safe(rht_dereference(head, ht), \
+				  typeof(*(pos)), member), \
+	     n = rht_next_entry_safe(pos, ht, member); \
+	     pos; \
+	     pos = n, \
+	     n = rht_next_entry_safe(pos, ht, member))
+
+/**
+ * rht_for_each_rcu - iterate over rcu hash chain
+ * @pos:	&struct rhash_head to use as a loop cursor.
+ * @head:	head of the hash chain (struct rhash_head *)
+ * @ht:		pointer to your struct rhashtable
+ *
+ * This hash chain list-traversal primitive may safely run concurrently with
+ * the _rcu mutation primitives such as rhashtable_insert() as long as the
+ * traversal is guarded by rcu_read_lock().
+ */
+#define rht_for_each_rcu(pos, head, ht) \
+	for (pos = rht_dereference_rcu(head, ht); \
+	     pos; \
+	     pos = rht_dereference_rcu((pos)->next, ht))
+
+/**
+ * rht_for_each_entry_rcu - iterate over rcu hash chain of given type
+ * @pos:	type * to use as a loop cursor.
+ * @head:	head of the hash chain (struct rhash_head *)
+ * @member:	name of the rhash_head within the hashable struct.
+ *
+ * This hash chain list-traversal primitive may safely run concurrently with
+ * the _rcu mutation primitives such as rhashtable_insert() as long as the
+ * traversal is guarded by rcu_read_lock().
+ */
+#define rht_for_each_entry_rcu(pos, head, member) \
+	for (pos = rht_entry_safe_rcu(head, typeof(*(pos)), member); \
+	     pos; \
+	     pos = rht_entry_safe_rcu((pos)->member.next, \
+				      typeof(*(pos)), member))
+
+#endif /* _LINUX_RHASHTABLE_H */
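[Editor's note: the block below is not part of the patch. It is a minimal usage sketch based only on the API declared in this header; struct my_obj, my_table and my_mutex_is_held() are invented names for illustration.]

    #include <linux/kernel.h>
    #include <linux/jhash.h>
    #include <linux/rhashtable.h>

    struct my_obj {
    	u32			key;
    	struct rhash_head	node;	/* linkage owned by the hash table */
    };

    /* Writer-side serialization is the caller's job; lockdep only asks this. */
    static int my_mutex_is_held(void)
    {
    	return 1;
    }

    static struct rhashtable my_table;

    static int my_table_setup(void)
    {
    	struct rhashtable_params params = {
    		.nelem_hint	 = 64,
    		.head_offset	 = offsetof(struct my_obj, node),
    		.key_offset	 = offsetof(struct my_obj, key),
    		.key_len	 = sizeof(u32),
    		.hashfn		 = jhash,
    		.grow_decision	 = rht_grow_above_75,
    		.shrink_decision = rht_shrink_below_30,
    		.mutex_is_held	 = my_mutex_is_held,
    	};

    	return rhashtable_init(&my_table, &params);
    }

Writers would then call rhashtable_insert()/rhashtable_remove() under their own mutex, while readers call rhashtable_lookup() inside an rcu_read_lock()/rcu_read_unlock() section.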

+ 8 - 0
lib/Kconfig.debug

@@ -1550,6 +1550,14 @@ config TEST_STRING_HELPERS
 config TEST_KSTRTOX
 	tristate "Test kstrto*() family of functions at runtime"
 
 
+config TEST_RHASHTABLE
+	bool "Perform selftest on resizable hash table"
+	default n
+	help
+	  Enable this option to test the rhashtable functions at boot.
+
+	  If unsure, say N.
+
 endmenu # runtime tests
 
 config PROVIDE_OHCI1394_DMA_INIT

+ 1 - 1
lib/Makefile

@@ -26,7 +26,7 @@ obj-y += bcd.o div64.o sort.o parser.o halfmd4.o debug_locks.o random32.o \
 	 bust_spinlocks.o hexdump.o kasprintf.o bitmap.o scatterlist.o \
 	 gcd.o lcm.o list_sort.o uuid.o flex_array.o iovec.o clz_ctz.o \
 	 bsearch.o find_last_bit.o find_next_bit.o llist.o memweight.o kfifo.o \
-	 percpu-refcount.o percpu_ida.o hash.o
+	 percpu-refcount.o percpu_ida.o hash.o rhashtable.o
 obj-y += string_helpers.o
 obj-$(CONFIG_TEST_STRING_HELPERS) += test-string_helpers.o
 obj-y += kstrtox.o

+ 797 - 0
lib/rhashtable.c

@@ -0,0 +1,797 @@
+/*
+ * Resizable, Scalable, Concurrent Hash Table
+ *
+ * Copyright (c) 2014 Thomas Graf <tgraf@suug.ch>
+ * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
+ *
+ * Based on the following paper:
+ * https://www.usenix.org/legacy/event/atc11/tech/final_files/Triplett.pdf
+ *
+ * Code partially derived from nft_hash
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2 as
+ * published by the Free Software Foundation.
+ */
+
+#include <linux/kernel.h>
+#include <linux/init.h>
+#include <linux/log2.h>
+#include <linux/slab.h>
+#include <linux/vmalloc.h>
+#include <linux/mm.h>
+#include <linux/hash.h>
+#include <linux/random.h>
+#include <linux/rhashtable.h>
+
+#define HASH_DEFAULT_SIZE	64UL
+#define HASH_MIN_SIZE		4UL
+
+#define ASSERT_RHT_MUTEX(HT) BUG_ON(!lockdep_rht_mutex_is_held(HT))
+
+#ifdef CONFIG_PROVE_LOCKING
+int lockdep_rht_mutex_is_held(const struct rhashtable *ht)
+{
+	return ht->p.mutex_is_held();
+}
+EXPORT_SYMBOL_GPL(lockdep_rht_mutex_is_held);
+#endif
+
+/**
+ * rht_obj - cast hash head to outer object
+ * @ht:		hash table
+ * @he:		hashed node
+ */
+void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he)
+{
+	return (void *) he - ht->p.head_offset;
+}
+EXPORT_SYMBOL_GPL(rht_obj);
+
+static u32 __hashfn(const struct rhashtable *ht, const void *key,
+		      u32 len, u32 hsize)
+{
+	u32 h;
+
+	h = ht->p.hashfn(key, len, ht->p.hash_rnd);
+
+	return h & (hsize - 1);
+}
+
+/**
+ * rhashtable_hashfn - compute hash for key of given length
+ * @ht:		hash table to compute the hash for
+ * @key:	pointer to key
+ * @len:	length of key
+ *
+ * Computes the hash value using the hash function provided in the 'hashfn'
+ * of struct rhashtable_params. The returned value is guaranteed to be
+ * smaller than the number of buckets in the hash table.
+ */
+u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len)
+{
+	struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
+
+	return __hashfn(ht, key, len, tbl->size);
+}
+EXPORT_SYMBOL_GPL(rhashtable_hashfn);
+
+static u32 obj_hashfn(const struct rhashtable *ht, const void *ptr, u32 hsize)
+{
+	if (unlikely(!ht->p.key_len)) {
+		u32 h;
+
+		h = ht->p.obj_hashfn(ptr, ht->p.hash_rnd);
+
+		return h & (hsize - 1);
+	}
+
+	return __hashfn(ht, ptr + ht->p.key_offset, ht->p.key_len, hsize);
+}
+
+/**
+ * rhashtable_obj_hashfn - compute hash for hashed object
+ * @ht:		hash table to compute the hash for
+ * @ptr:	pointer to hashed object
+ *
+ * Computes the hash value using either 'hashfn' or 'obj_hashfn', depending on
+ * whether the hash table is set up to work with
+ * a fixed length key. The returned value is guaranteed to be smaller than
+ * the number of buckets in the hash table.
+ */
+u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr)
+{
+	struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
+
+	return obj_hashfn(ht, ptr, tbl->size);
+}
+EXPORT_SYMBOL_GPL(rhashtable_obj_hashfn);
+
+static u32 head_hashfn(const struct rhashtable *ht,
+		       const struct rhash_head *he, u32 hsize)
+{
+	return obj_hashfn(ht, rht_obj(ht, he), hsize);
+}
+
+static struct bucket_table *bucket_table_alloc(size_t nbuckets, gfp_t flags)
+{
+	struct bucket_table *tbl;
+	size_t size;
+
+	size = sizeof(*tbl) + nbuckets * sizeof(tbl->buckets[0]);
+	tbl = kzalloc(size, flags);
+	if (tbl == NULL)
+		tbl = vzalloc(size);
+
+	if (tbl == NULL)
+		return NULL;
+
+	tbl->size = nbuckets;
+
+	return tbl;
+}
+
+static void bucket_table_free(const struct bucket_table *tbl)
+{
+	kvfree(tbl);
+}
+
+/**
+ * rht_grow_above_75 - returns true if nelems > 0.75 * table-size
+ * @ht:		hash table
+ * @new_size:	new table size
+ */
+bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size)
+{
+	/* Expand table when exceeding 75% load */
+	return ht->nelems > (new_size / 4 * 3);
+}
+EXPORT_SYMBOL_GPL(rht_grow_above_75);
+
+/**
+ * rht_shrink_below_30 - returns true if nelems < 0.3 * table-size
+ * @ht:		hash table
+ * @new_size:	new table size
+ */
+bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size)
+{
+	/* Shrink table beneath 30% load */
+	return ht->nelems < (new_size * 3 / 10);
+}
+EXPORT_SYMBOL_GPL(rht_shrink_below_30);
+
+static void hashtable_chain_unzip(const struct rhashtable *ht,
+				  const struct bucket_table *new_tbl,
+				  struct bucket_table *old_tbl, size_t n)
+{
+	struct rhash_head *he, *p, *next;
+	unsigned int h;
+
+	/* Old bucket empty, no work needed. */
+	p = rht_dereference(old_tbl->buckets[n], ht);
+	if (!p)
+		return;
+
+	/* Advance the old bucket pointer one or more times until it
+	 * reaches a node that doesn't hash to the same bucket as the
+	 * previous node. Call that previous node p.
+	 */
+	h = head_hashfn(ht, p, new_tbl->size);
+	rht_for_each(he, p->next, ht) {
+		if (head_hashfn(ht, he, new_tbl->size) != h)
+			break;
+		p = he;
+	}
+	RCU_INIT_POINTER(old_tbl->buckets[n], p->next);
+
+	/* Find the subsequent node which does hash to the same
+	 * bucket as node P, or NULL if no such node exists.
+	 */
+	next = NULL;
+	if (he) {
+		rht_for_each(he, he->next, ht) {
+			if (head_hashfn(ht, he, new_tbl->size) == h) {
+				next = he;
+				break;
+			}
+		}
+	}
+
+	/* Set p's next pointer to that subsequent node pointer,
+	 * bypassing the nodes which do not hash to p's bucket
+	 */
+	RCU_INIT_POINTER(p->next, next);
+}
+
+/**
+ * rhashtable_expand - Expand hash table while allowing concurrent lookups
+ * @ht:		the hash table to expand
+ * @flags:	allocation flags
+ *
+ * A secondary bucket array is allocated and the hash entries are migrated
+ * while keeping them on both lists until the end of the RCU grace period.
+ *
+ * This function may only be called in a context where it is safe to call
+ * synchronize_rcu(), e.g. not within a rcu_read_lock() section.
+ *
+ * The caller must ensure that no concurrent table mutations take place.
+ * It is however valid to have concurrent lookups if they are RCU protected.
+ */
+int rhashtable_expand(struct rhashtable *ht, gfp_t flags)
+{
+	struct bucket_table *new_tbl, *old_tbl = rht_dereference(ht->tbl, ht);
+	struct rhash_head *he;
+	unsigned int i, h;
+	bool complete;
+
+	ASSERT_RHT_MUTEX(ht);
+
+	if (ht->p.max_shift && ht->shift >= ht->p.max_shift)
+		return 0;
+
+	new_tbl = bucket_table_alloc(old_tbl->size * 2, flags);
+	if (new_tbl == NULL)
+		return -ENOMEM;
+
+	ht->shift++;
+
+	/* For each new bucket, search the corresponding old bucket
+	 * for the first entry that hashes to the new bucket, and
+	 * link the new bucket to that entry. Since all the entries
+	 * which will end up in the new bucket appear in the same
+	 * old bucket, this constructs an entirely valid new hash
+	 * table, but with multiple buckets "zipped" together into a
+	 * single imprecise chain.
+	 */
+	for (i = 0; i < new_tbl->size; i++) {
+		h = i & (old_tbl->size - 1);
+		rht_for_each(he, old_tbl->buckets[h], ht) {
+			if (head_hashfn(ht, he, new_tbl->size) == i) {
+				RCU_INIT_POINTER(new_tbl->buckets[i], he);
+				break;
+			}
+		}
+	}
+
+	/* Publish the new table pointer. Lookups may now traverse
+	 * the new table, but they will not benefit from any
+	 * additional efficiency until later steps unzip the buckets.
+	 */
+	rcu_assign_pointer(ht->tbl, new_tbl);
+
+	/* Unzip interleaved hash chains */
+	do {
+		/* Wait for readers. All new readers will see the new
+		 * table, and thus no references to the old table will
+		 * remain.
+		 */
+		synchronize_rcu();
+
+		/* For each bucket in the old table (each of which
+		 * contains items from multiple buckets of the new
+		 * table): ...
+		 */
+		complete = true;
+		for (i = 0; i < old_tbl->size; i++) {
+			hashtable_chain_unzip(ht, new_tbl, old_tbl, i);
+			if (old_tbl->buckets[i] != NULL)
+				complete = false;
+		}
+	} while (!complete);
+
+	bucket_table_free(old_tbl);
+	return 0;
+}
+EXPORT_SYMBOL_GPL(rhashtable_expand);
+
+/**
+ * rhashtable_shrink - Shrink hash table while allowing concurrent lookups
+ * @ht:		the hash table to shrink
+ * @flags:	allocation flags
+ *
+ * This function may only be called in a context where it is safe to call
+ * synchronize_rcu(), e.g. not within a rcu_read_lock() section.
+ *
+ * The caller must ensure that no concurrent table mutations take place.
+ * It is however valid to have concurrent lookups if they are RCU protected.
+ */
+int rhashtable_shrink(struct rhashtable *ht, gfp_t flags)
+{
+	struct bucket_table *ntbl, *tbl = rht_dereference(ht->tbl, ht);
+	struct rhash_head __rcu **pprev;
+	unsigned int i;
+
+	ASSERT_RHT_MUTEX(ht);
+
+	if (tbl->size <= HASH_MIN_SIZE)
+		return 0;
+
+	ntbl = bucket_table_alloc(tbl->size / 2, flags);
+	if (ntbl == NULL)
+		return -ENOMEM;
+
+	ht->shift--;
+
+	/* Link each bucket in the new table to the first bucket
+	 * in the old table that contains entries which will hash
+	 * to the new bucket.
+	 */
+	for (i = 0; i < ntbl->size; i++) {
+		ntbl->buckets[i] = tbl->buckets[i];
+
+		/* Link each bucket in the new table to the first bucket
+		 * in the old table that contains entries which will hash
+		 * to the new bucket.
+		 */
+		for (pprev = &ntbl->buckets[i]; *pprev != NULL;
+		     pprev = &rht_dereference(*pprev, ht)->next)
+			;
+		RCU_INIT_POINTER(*pprev, tbl->buckets[i + ntbl->size]);
+	}
+
+	/* Publish the new, valid hash table */
+	rcu_assign_pointer(ht->tbl, ntbl);
+
+	/* Wait for readers. No new readers will have references to the
+	 * old hash table.
+	 */
+	synchronize_rcu();
+
+	bucket_table_free(tbl);
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(rhashtable_shrink);
+
+/**
+ * rhashtable_insert - insert object into hash table
+ * @ht:		hash table
+ * @obj:	pointer to hash head inside object
+ * @flags:	allocation flags (table expansion)
+ *
+ * Will automatically grow the table via rhashtable_expand() if the
+ * grow_decision function specified at rhashtable_init() returns true.
+ *
+ * The caller must ensure that no concurrent table mutations occur. It is
+ * however valid to have concurrent lookups if they are RCU protected.
+ */
+void rhashtable_insert(struct rhashtable *ht, struct rhash_head *obj,
+		       gfp_t flags)
+{
+	struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
+	u32 hash;
+
+	ASSERT_RHT_MUTEX(ht);
+
+	hash = head_hashfn(ht, obj, tbl->size);
+	RCU_INIT_POINTER(obj->next, tbl->buckets[hash]);
+	rcu_assign_pointer(tbl->buckets[hash], obj);
+	ht->nelems++;
+
+	if (ht->p.grow_decision && ht->p.grow_decision(ht, tbl->size))
+		rhashtable_expand(ht, flags);
+}
+EXPORT_SYMBOL_GPL(rhashtable_insert);
+
+/**
+ * rhashtable_remove_pprev - remove object from hash table given previous element
+ * @ht:		hash table
+ * @obj:	pointer to hash head inside object
+ * @pprev:	pointer to previous element
+ * @flags:	allocation flags (table expansion)
+ *
+ * Identical to rhashtable_remove() but caller is already aware of the element
+ * in front of the element to be deleted. This is in particular useful for
+ * deletion when combined with walking or lookup.
+ */
+void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj,
+			     struct rhash_head **pprev, gfp_t flags)
+{
+	struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
+
+	ASSERT_RHT_MUTEX(ht);
+
+	RCU_INIT_POINTER(*pprev, obj->next);
+	ht->nelems--;
+
+	if (ht->p.shrink_decision &&
+	    ht->p.shrink_decision(ht, tbl->size))
+		rhashtable_shrink(ht, flags);
+}
+EXPORT_SYMBOL_GPL(rhashtable_remove_pprev);
+
+/**
+ * rhashtable_remove - remove object from hash table
+ * @ht:		hash table
+ * @obj:	pointer to hash head inside object
+ * @flags:	allocation flags (table expansion)
+ *
+ * Since the hash chain is singly linked, the removal operation needs to
+ * walk the bucket chain upon removal. The removal operation is thus
+ * considerably slower if the hash table is not correctly sized.
+ *
+ * Will automatically shrink the table via rhashtable_shrink() if the
+ * shrink_decision function specified at rhashtable_init() returns true.
+ *
+ * The caller must ensure that no concurrent table mutations occur. It is
+ * however valid to have concurrent lookups if they are RCU protected.
+ */
+bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *obj,
+		       gfp_t flags)
+{
+	struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
+	struct rhash_head __rcu **pprev;
+	struct rhash_head *he;
+	u32 h;
+
+	ASSERT_RHT_MUTEX(ht);
+
+	h = head_hashfn(ht, obj, tbl->size);
+
+	pprev = &tbl->buckets[h];
+	rht_for_each(he, tbl->buckets[h], ht) {
+		if (he != obj) {
+			pprev = &he->next;
+			continue;
+		}
+
+		rhashtable_remove_pprev(ht, he, pprev, flags);
+		return true;
+	}
+
+	return false;
+}
+EXPORT_SYMBOL_GPL(rhashtable_remove);
+
+/**
+ * rhashtable_lookup - lookup key in hash table
+ * @ht:		hash table
+ * @key:	pointer to key
+ *
+ * Computes the hash value for the key and traverses the bucket chain looking
+ * for an entry with an identical key. The first matching entry is returned.
+ *
+ * This lookup function may only be used for a fixed key hash table (key_len
+ * parameter set). It will BUG() if used inappropriately.
+ *
+ * Lookups may occur in parallel with hash mutations as long as the lookup is
+ * guarded by rcu_read_lock(). The caller must take care of this.
+ */
+void *rhashtable_lookup(const struct rhashtable *ht, const void *key)
+{
+	const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
+	struct rhash_head *he;
+	u32 h;
+
+	BUG_ON(!ht->p.key_len);
+
+	h = __hashfn(ht, key, ht->p.key_len, tbl->size);
+	rht_for_each_rcu(he, tbl->buckets[h], ht) {
+		if (memcmp(rht_obj(ht, he) + ht->p.key_offset, key,
+			   ht->p.key_len))
+			continue;
+		return (void *) he - ht->p.head_offset;
+	}
+
+	return NULL;
+}
+EXPORT_SYMBOL_GPL(rhashtable_lookup);
+
+/**
+ * rhashtable_lookup_compare - search hash table with compare function
+ * @ht:		hash table
+ * @hash:	hash value of desired entry
+ * @compare:	compare function, must return true on match
+ * @arg:	argument passed on to compare function
+ *
+ * Traverses the bucket chain behind the provided hash value and calls the
+ * specified compare function for each entry.
+ *
+ * Lookups may occur in parallel with hash mutations as long as the lookup is
+ * guarded by rcu_read_lock(). The caller must take care of this.
+ *
+ * Returns the first entry on which the compare function returned true.
+ */
+void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash,
+				bool (*compare)(void *, void *), void *arg)
+{
+	const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
+	struct rhash_head *he;
+
+	if (unlikely(hash >= tbl->size))
+		return NULL;
+
+	rht_for_each_rcu(he, tbl->buckets[hash], ht) {
+		if (!compare(rht_obj(ht, he), arg))
+			continue;
+		return (void *) he - ht->p.head_offset;
+	}
+
+	return NULL;
+}
+EXPORT_SYMBOL_GPL(rhashtable_lookup_compare);
+
+static size_t rounded_hashtable_size(unsigned int nelem)
+{
+	return max(roundup_pow_of_two(nelem * 4 / 3), HASH_MIN_SIZE);
+}
+
+/**
+ * rhashtable_init - initialize a new hash table
+ * @ht:		hash table to be initialized
+ * @params:	configuration parameters
+ *
+ * Initializes a new hash table based on the provided configuration
+ * parameters. A table can be configured either with a variable or
+ * fixed length key:
+ *
+ * Configuration Example 1: Fixed length keys
+ * struct test_obj {
+ *	int			key;
+ *	void *			my_member;
+ *	struct rhash_head	node;
+ * };
+ *
+ * struct rhashtable_params params = {
+ *	.head_offset = offsetof(struct test_obj, node),
+ *	.key_offset = offsetof(struct test_obj, key),
+ *	.key_len = sizeof(int),
+ *	.hashfn = arch_fast_hash,
+ *	.mutex_is_held = &my_mutex_is_held,
+ * };
+ *
+ * Configuration Example 2: Variable length keys
+ * struct test_obj {
+ *	[...]
+ *	struct rhash_head	node;
+ * };
+ *
+ * u32 my_hash_fn(const void *data, u32 seed)
+ * {
+ *	struct test_obj *obj = data;
+ *
+ *	return [... hash ...];
+ * }
+ *
+ * struct rhashtable_params params = {
+ *	.head_offset = offsetof(struct test_obj, node),
+ *	.hashfn = arch_fast_hash,
+ *	.obj_hashfn = my_hash_fn,
+ *	.mutex_is_held = &my_mutex_is_held,
+ * };
+ */
+int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params)
+{
+	struct bucket_table *tbl;
+	size_t size;
+
+	size = HASH_DEFAULT_SIZE;
+
+	if ((params->key_len && !params->hashfn) ||
+	    (!params->key_len && !params->obj_hashfn))
+		return -EINVAL;
+
+	if (params->nelem_hint)
+		size = rounded_hashtable_size(params->nelem_hint);
+
+	tbl = bucket_table_alloc(size, GFP_KERNEL);
+	if (tbl == NULL)
+		return -ENOMEM;
+
+	memset(ht, 0, sizeof(*ht));
+	ht->shift = ilog2(tbl->size);
+	memcpy(&ht->p, params, sizeof(*params));
+	RCU_INIT_POINTER(ht->tbl, tbl);
+
+	if (!ht->p.hash_rnd)
+		get_random_bytes(&ht->p.hash_rnd, sizeof(ht->p.hash_rnd));
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(rhashtable_init);
+
+/**
+ * rhashtable_destroy - destroy hash table
+ * @ht:		the hash table to destroy
+ *
+ * Frees the bucket array.
+ */
+void rhashtable_destroy(const struct rhashtable *ht)
+{
+	const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
+
+	bucket_table_free(tbl);
+}
+EXPORT_SYMBOL_GPL(rhashtable_destroy);
+
+/**************************************************************************
+ * Self Test
+ **************************************************************************/
+
+#ifdef CONFIG_TEST_RHASHTABLE
+
+#define TEST_HT_SIZE	8
+#define TEST_ENTRIES	2048
+#define TEST_PTR	((void *) 0xdeadbeef)
+#define TEST_NEXPANDS	4
+
+static int test_mutex_is_held(void)
+{
+	return 1;
+}
+
+struct test_obj {
+	void			*ptr;
+	int			value;
+	struct rhash_head	node;
+};
+
+static int __init test_rht_lookup(struct rhashtable *ht)
+{
+	unsigned int i;
+
+	for (i = 0; i < TEST_ENTRIES * 2; i++) {
+		struct test_obj *obj;
+		bool expected = !(i % 2);
+		u32 key = i;
+
+		obj = rhashtable_lookup(ht, &key);
+
+		if (expected && !obj) {
+			pr_warn("Test failed: Could not find key %u\n", key);
+			return -ENOENT;
+		} else if (!expected && obj) {
+			pr_warn("Test failed: Unexpected entry found for key %u\n",
+				key);
+			return -EEXIST;
+		} else if (expected && obj) {
+			if (obj->ptr != TEST_PTR || obj->value != i) {
+				pr_warn("Test failed: Lookup value mismatch %p!=%p, %u!=%u\n",
+					obj->ptr, TEST_PTR, obj->value, i);
+				return -EINVAL;
+			}
+		}
+	}
+
+	return 0;
+}
+
+static void test_bucket_stats(struct rhashtable *ht,
+				     struct bucket_table *tbl,
+				     bool quiet)
+{
+	unsigned int cnt, i, total = 0;
+	struct test_obj *obj;
+
+	for (i = 0; i < tbl->size; i++) {
+		cnt = 0;
+
+		if (!quiet)
+			pr_info(" [%#4x/%zu]", i, tbl->size);
+
+		rht_for_each_entry_rcu(obj, tbl->buckets[i], node) {
+			cnt++;
+			total++;
+			if (!quiet)
+				pr_cont(" [%p],", obj);
+		}
+
+		if (!quiet)
+			pr_cont("\n  [%#x] first element: %p, chain length: %u\n",
+				i, tbl->buckets[i], cnt);
+	}
+
+	pr_info("  Traversal complete: counted=%u, nelems=%zu, entries=%d\n",
+		total, ht->nelems, TEST_ENTRIES);
+}
+
+static int __init test_rhashtable(struct rhashtable *ht)
+{
+	struct bucket_table *tbl;
+	struct test_obj *obj, *next;
+	int err;
+	unsigned int i;
+
+	/*
+	 * Insertion Test:
+	 * Insert TEST_ENTRIES into table with all keys even numbers
+	 */
+	pr_info("  Adding %d keys\n", TEST_ENTRIES);
+	for (i = 0; i < TEST_ENTRIES; i++) {
+		struct test_obj *obj;
+
+		obj = kzalloc(sizeof(*obj), GFP_KERNEL);
+		if (!obj) {
+			err = -ENOMEM;
+			goto error;
+		}
+
+		obj->ptr = TEST_PTR;
+		obj->value = i * 2;
+
+		rhashtable_insert(ht, &obj->node, GFP_KERNEL);
+	}
+
+	rcu_read_lock();
+	tbl = rht_dereference_rcu(ht->tbl, ht);
+	test_bucket_stats(ht, tbl, true);
+	test_rht_lookup(ht);
+	rcu_read_unlock();
+
+	for (i = 0; i < TEST_NEXPANDS; i++) {
+		pr_info("  Table expansion iteration %u...\n", i);
+		rhashtable_expand(ht, GFP_KERNEL);
+
+		rcu_read_lock();
+		pr_info("  Verifying lookups...\n");
+		test_rht_lookup(ht);
+		rcu_read_unlock();
+	}
+
+	for (i = 0; i < TEST_NEXPANDS; i++) {
+		pr_info("  Table shrinkage iteration %u...\n", i);
+		rhashtable_shrink(ht, GFP_KERNEL);
+
+		rcu_read_lock();
+		pr_info("  Verifying lookups...\n");
+		test_rht_lookup(ht);
+		rcu_read_unlock();
+	}
+
+	pr_info("  Deleting %d keys\n", TEST_ENTRIES);
+	for (i = 0; i < TEST_ENTRIES; i++) {
+		u32 key = i * 2;
+
+		obj = rhashtable_lookup(ht, &key);
+		BUG_ON(!obj);
+
+		rhashtable_remove(ht, &obj->node, GFP_KERNEL);
+		kfree(obj);
+	}
+
+	return 0;
+
+error:
+	tbl = rht_dereference_rcu(ht->tbl, ht);
+	for (i = 0; i < tbl->size; i++)
+		rht_for_each_entry_safe(obj, next, tbl->buckets[i], ht, node)
+			kfree(obj);
+
+	return err;
+}
+
+static int __init test_rht_init(void)
+{
+	struct rhashtable ht;
+	struct rhashtable_params params = {
+		.nelem_hint = TEST_HT_SIZE,
+		.head_offset = offsetof(struct test_obj, node),
+		.key_offset = offsetof(struct test_obj, value),
+		.key_len = sizeof(int),
+		.hashfn = arch_fast_hash,
+		.mutex_is_held = &test_mutex_is_held,
+		.grow_decision = rht_grow_above_75,
+		.shrink_decision = rht_shrink_below_30,
+	};
+	int err;
+
+	pr_info("Running resizable hashtable tests...\n");
+
+	err = rhashtable_init(&ht, &params);
+	if (err < 0) {
+		pr_warn("Test failed: Unable to initialize hashtable: %d\n",
+			err);
+		return err;
+	}
+
+	err = test_rhashtable(&ht);
+
+	rhashtable_destroy(&ht);
+
+	return err;
+}
+
+subsys_initcall(test_rht_init);
+
+#endif /* CONFIG_TEST_RHASHTABLE */
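[Editor's note: the block below is not part of the patch. rhashtable_expand() relies on the fact that a bucket index is simply hash & (size - 1), so when the table doubles, the entries of old bucket h can only land in new buckets h and h + old_size. The standalone userspace program below just checks that invariant; it is an illustration, not kernel code.]

    #include <stdint.h>
    #include <stdio.h>

    int main(void)
    {
    	const uint32_t old_size = 4, new_size = 8;	/* one doubling step */
    	uint32_t hash;

    	for (hash = 0; hash < 64; hash++) {
    		uint32_t old_bucket = hash & (old_size - 1);
    		uint32_t new_bucket = hash & (new_size - 1);

    		if (new_bucket != old_bucket &&
    		    new_bucket != old_bucket + old_size)
    			printf("invariant violated for hash %u\n", hash);
    	}
    	printf("old bucket h only feeds new buckets h and h + %u\n", old_size);
    	return 0;
    }

This is what lets the expand path link each new bucket to the first matching entry of the corresponding old chain and then unzip the interleaved chains across RCU grace periods, while lookups keep working on both tables.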

+ 55 - 236
net/netfilter/nft_hash.c

@@ -15,209 +15,40 @@
 #include <linux/log2.h>
 #include <linux/jhash.h>
 #include <linux/netlink.h>
-#include <linux/vmalloc.h>
+#include <linux/rhashtable.h>
 #include <linux/netfilter.h>
 #include <linux/netfilter/nf_tables.h>
 #include <net/netfilter/nf_tables.h>
 
 
-#define NFT_HASH_MIN_SIZE	4UL
-
-struct nft_hash {
-	struct nft_hash_table __rcu	*tbl;
-};
-
-struct nft_hash_table {
-	unsigned int			size;
-	struct nft_hash_elem __rcu	*buckets[];
-};
+/* We target a hash table size of 4, element hint is 75% of final size */
+#define NFT_HASH_ELEMENT_HINT 3
 
 
 struct nft_hash_elem {
-	struct nft_hash_elem __rcu	*next;
+	struct rhash_head		node;
 	struct nft_data			key;
 	struct nft_data			data[];
 };
 
 
-#define nft_hash_for_each_entry(i, head) \
-	for (i = nft_dereference(head); i != NULL; i = nft_dereference(i->next))
-#define nft_hash_for_each_entry_rcu(i, head) \
-	for (i = rcu_dereference(head); i != NULL; i = rcu_dereference(i->next))
-
-static u32 nft_hash_rnd __read_mostly;
-static bool nft_hash_rnd_initted __read_mostly;
-
-static unsigned int nft_hash_data(const struct nft_data *data,
-				  unsigned int hsize, unsigned int len)
-{
-	unsigned int h;
-
-	h = jhash(data->data, len, nft_hash_rnd);
-	return h & (hsize - 1);
-}
-
 static bool nft_hash_lookup(const struct nft_set *set,
 			    const struct nft_data *key,
 			    struct nft_data *data)
 {
-	const struct nft_hash *priv = nft_set_priv(set);
-	const struct nft_hash_table *tbl = rcu_dereference(priv->tbl);
+	const struct rhashtable *priv = nft_set_priv(set);
 	const struct nft_hash_elem *he;
-	unsigned int h;
-
-	h = nft_hash_data(key, tbl->size, set->klen);
-	nft_hash_for_each_entry_rcu(he, tbl->buckets[h]) {
-		if (nft_data_cmp(&he->key, key, set->klen))
-			continue;
-		if (set->flags & NFT_SET_MAP)
-			nft_data_copy(data, he->data);
-		return true;
-	}
-	return false;
-}
-
-static void nft_hash_tbl_free(const struct nft_hash_table *tbl)
-{
-	kvfree(tbl);
-}
-
-static unsigned int nft_hash_tbl_size(unsigned int nelem)
-{
-	return max(roundup_pow_of_two(nelem * 4 / 3), NFT_HASH_MIN_SIZE);
-}
-
-static struct nft_hash_table *nft_hash_tbl_alloc(unsigned int nbuckets)
-{
-	struct nft_hash_table *tbl;
-	size_t size;
-
-	size = sizeof(*tbl) + nbuckets * sizeof(tbl->buckets[0]);
-	tbl = kzalloc(size, GFP_KERNEL | __GFP_REPEAT | __GFP_NOWARN);
-	if (tbl == NULL)
-		tbl = vzalloc(size);
-	if (tbl == NULL)
-		return NULL;
-	tbl->size = nbuckets;
-
-	return tbl;
-}
-
-static void nft_hash_chain_unzip(const struct nft_set *set,
-				 const struct nft_hash_table *ntbl,
-				 struct nft_hash_table *tbl, unsigned int n)
-{
-	struct nft_hash_elem *he, *last, *next;
-	unsigned int h;
-
-	he = nft_dereference(tbl->buckets[n]);
-	if (he == NULL)
-		return;
-	h = nft_hash_data(&he->key, ntbl->size, set->klen);
-
-	/* Find last element of first chain hashing to bucket h */
-	last = he;
-	nft_hash_for_each_entry(he, he->next) {
-		if (nft_hash_data(&he->key, ntbl->size, set->klen) != h)
-			break;
-		last = he;
-	}
-
-	/* Unlink first chain from the old table */
-	RCU_INIT_POINTER(tbl->buckets[n], last->next);
 
 
-	/* If end of chain reached, done */
-	if (he == NULL)
-		return;
+	he = rhashtable_lookup(priv, key);
+	if (he && set->flags & NFT_SET_MAP)
+		nft_data_copy(data, he->data);
 
 
-	/* Find first element of second chain hashing to bucket h */
-	next = NULL;
-	nft_hash_for_each_entry(he, he->next) {
-		if (nft_hash_data(&he->key, ntbl->size, set->klen) != h)
-			continue;
-		next = he;
-		break;
-	}
-
-	/* Link the two chains */
-	RCU_INIT_POINTER(last->next, next);
-}
-
-static int nft_hash_tbl_expand(const struct nft_set *set, struct nft_hash *priv)
-{
-	struct nft_hash_table *tbl = nft_dereference(priv->tbl), *ntbl;
-	struct nft_hash_elem *he;
-	unsigned int i, h;
-	bool complete;
-
-	ntbl = nft_hash_tbl_alloc(tbl->size * 2);
-	if (ntbl == NULL)
-		return -ENOMEM;
-
-	/* Link new table's buckets to first element in the old table
-	 * hashing to the new bucket.
-	 */
-	for (i = 0; i < ntbl->size; i++) {
-		h = i < tbl->size ? i : i - tbl->size;
-		nft_hash_for_each_entry(he, tbl->buckets[h]) {
-			if (nft_hash_data(&he->key, ntbl->size, set->klen) != i)
-				continue;
-			RCU_INIT_POINTER(ntbl->buckets[i], he);
-			break;
-		}
-	}
-
-	/* Publish new table */
-	rcu_assign_pointer(priv->tbl, ntbl);
-
-	/* Unzip interleaved hash chains */
-	do {
-		/* Wait for readers to use new table/unzipped chains */
-		synchronize_rcu();
-
-		complete = true;
-		for (i = 0; i < tbl->size; i++) {
-			nft_hash_chain_unzip(set, ntbl, tbl, i);
-			if (tbl->buckets[i] != NULL)
-				complete = false;
-		}
-	} while (!complete);
-
-	nft_hash_tbl_free(tbl);
-	return 0;
-}
-
-static int nft_hash_tbl_shrink(const struct nft_set *set, struct nft_hash *priv)
-{
-	struct nft_hash_table *tbl = nft_dereference(priv->tbl), *ntbl;
-	struct nft_hash_elem __rcu **pprev;
-	unsigned int i;
-
-	ntbl = nft_hash_tbl_alloc(tbl->size / 2);
-	if (ntbl == NULL)
-		return -ENOMEM;
-
-	for (i = 0; i < ntbl->size; i++) {
-		ntbl->buckets[i] = tbl->buckets[i];
-
-		for (pprev = &ntbl->buckets[i]; *pprev != NULL;
-		     pprev = &nft_dereference(*pprev)->next)
-			;
-		RCU_INIT_POINTER(*pprev, tbl->buckets[i + ntbl->size]);
-	}
-
-	/* Publish new table */
-	rcu_assign_pointer(priv->tbl, ntbl);
-	synchronize_rcu();
-
-	nft_hash_tbl_free(tbl);
-	return 0;
+	return !!he;
 }
 
 static int nft_hash_insert(const struct nft_set *set,
 			   const struct nft_set_elem *elem)
 {
-	struct nft_hash *priv = nft_set_priv(set);
-	struct nft_hash_table *tbl = nft_dereference(priv->tbl);
+	struct rhashtable *priv = nft_set_priv(set);
 	struct nft_hash_elem *he;
-	unsigned int size, h;
+	unsigned int size;
 
 
 	if (elem->flags != 0)
 		return -EINVAL;
@@ -234,13 +65,7 @@ static int nft_hash_insert(const struct nft_set *set,
 	if (set->flags & NFT_SET_MAP)
 		nft_data_copy(he->data, &elem->data);
 
 
-	h = nft_hash_data(&he->key, tbl->size, set->klen);
-	RCU_INIT_POINTER(he->next, tbl->buckets[h]);
-	rcu_assign_pointer(tbl->buckets[h], he);
-
-	/* Expand table when exceeding 75% load */
-	if (set->nelems + 1 > tbl->size / 4 * 3)
-		nft_hash_tbl_expand(set, priv);
+	rhashtable_insert(priv, &he->node, GFP_KERNEL);
 
 
 	return 0;
 }
@@ -257,36 +82,31 @@ static void nft_hash_elem_destroy(const struct nft_set *set,
 static void nft_hash_remove(const struct nft_set *set,
 			    const struct nft_set_elem *elem)
 {
-	struct nft_hash *priv = nft_set_priv(set);
-	struct nft_hash_table *tbl = nft_dereference(priv->tbl);
-	struct nft_hash_elem *he, __rcu **pprev;
+	struct rhashtable *priv = nft_set_priv(set);
+	struct rhash_head *he, __rcu **pprev;
 
 
 	pprev = elem->cookie;
-	he = nft_dereference((*pprev));
+	he = rht_dereference((*pprev), priv);
+
+	rhashtable_remove_pprev(priv, he, pprev, GFP_KERNEL);
 
 
-	RCU_INIT_POINTER(*pprev, he->next);
 	synchronize_rcu();
 	kfree(he);
-
-	/* Shrink table beneath 30% load */
-	if (set->nelems - 1 < tbl->size * 3 / 10 &&
-	    tbl->size > NFT_HASH_MIN_SIZE)
-		nft_hash_tbl_shrink(set, priv);
 }
 
 static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem)
 {
-	const struct nft_hash *priv = nft_set_priv(set);
-	const struct nft_hash_table *tbl = nft_dereference(priv->tbl);
-	struct nft_hash_elem __rcu * const *pprev;
+	const struct rhashtable *priv = nft_set_priv(set);
+	const struct bucket_table *tbl = rht_dereference_rcu(priv->tbl, priv);
+	struct rhash_head __rcu * const *pprev;
 	struct nft_hash_elem *he;
-	unsigned int h;
+	u32 h;
 
 
-	h = nft_hash_data(&elem->key, tbl->size, set->klen);
+	h = rhashtable_hashfn(priv, &elem->key, set->klen);
 	pprev = &tbl->buckets[h];
-	nft_hash_for_each_entry(he, tbl->buckets[h]) {
+	rht_for_each_entry_rcu(he, tbl->buckets[h], node) {
 		if (nft_data_cmp(&he->key, &elem->key, set->klen)) {
-			pprev = &he->next;
+			pprev = &he->node.next;
 			continue;
 		}
 
 
@@ -302,14 +122,15 @@ static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem)
 static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set,
 			  struct nft_set_iter *iter)
 {
-	const struct nft_hash *priv = nft_set_priv(set);
-	const struct nft_hash_table *tbl = nft_dereference(priv->tbl);
+	const struct rhashtable *priv = nft_set_priv(set);
+	const struct bucket_table *tbl;
 	const struct nft_hash_elem *he;
 	struct nft_set_elem elem;
 	unsigned int i;
 
 
+	tbl = rht_dereference_rcu(priv->tbl, priv);
 	for (i = 0; i < tbl->size; i++) {
-		nft_hash_for_each_entry(he, tbl->buckets[i]) {
+		rht_for_each_entry_rcu(he, tbl->buckets[i], node) {
 			if (iter->count < iter->skip)
 				goto cont;
 
 
@@ -329,48 +150,46 @@ cont:
 
 
 static unsigned int nft_hash_privsize(const struct nlattr * const nla[])
 {
-	return sizeof(struct nft_hash);
+	return sizeof(struct rhashtable);
+}
+
+static int lockdep_nfnl_lock_is_held(void)
+{
+	return lockdep_nfnl_is_held(NFNL_SUBSYS_NFTABLES);
 }
 
 static int nft_hash_init(const struct nft_set *set,
 			 const struct nft_set_desc *desc,
 			 const struct nlattr * const tb[])
 {
-	struct nft_hash *priv = nft_set_priv(set);
-	struct nft_hash_table *tbl;
-	unsigned int size;
+	struct rhashtable *priv = nft_set_priv(set);
+	struct rhashtable_params params = {
+		.nelem_hint = desc->size ? : NFT_HASH_ELEMENT_HINT,
+		.head_offset = offsetof(struct nft_hash_elem, node),
+		.key_offset = offsetof(struct nft_hash_elem, key),
+		.key_len = set->klen,
+		.hashfn = jhash,
+		.grow_decision = rht_grow_above_75,
+		.shrink_decision = rht_shrink_below_30,
+		.mutex_is_held = lockdep_nfnl_lock_is_held,
+	};
 
 
-	if (unlikely(!nft_hash_rnd_initted)) {
-		get_random_bytes(&nft_hash_rnd, 4);
-		nft_hash_rnd_initted = true;
-	}
-
-	size = NFT_HASH_MIN_SIZE;
-	if (desc->size)
-		size = nft_hash_tbl_size(desc->size);
-
-	tbl = nft_hash_tbl_alloc(size);
-	if (tbl == NULL)
-		return -ENOMEM;
-	RCU_INIT_POINTER(priv->tbl, tbl);
-	return 0;
+	return rhashtable_init(priv, &params);
 }
 
 static void nft_hash_destroy(const struct nft_set *set)
 {
-	const struct nft_hash *priv = nft_set_priv(set);
-	const struct nft_hash_table *tbl = nft_dereference(priv->tbl);
+	const struct rhashtable *priv = nft_set_priv(set);
+	const struct bucket_table *tbl;
 	struct nft_hash_elem *he, *next;
 	unsigned int i;
 
 
-	for (i = 0; i < tbl->size; i++) {
-		for (he = nft_dereference(tbl->buckets[i]); he != NULL;
-		     he = next) {
-			next = nft_dereference(he->next);
+	tbl = rht_dereference(priv->tbl, priv);
+	for (i = 0; i < tbl->size; i++)
+		rht_for_each_entry_safe(he, next, tbl->buckets[i], priv, node)
 			nft_hash_elem_destroy(set, he);
-		}
-	}
-	kfree(tbl);
+
+	rhashtable_destroy(priv);
 }
 
 static bool nft_hash_estimate(const struct nft_set_desc *desc, u32 features,
@@ -383,8 +202,8 @@ static bool nft_hash_estimate(const struct nft_set_desc *desc, u32 features,
 		esize += FIELD_SIZEOF(struct nft_hash_elem, data[0]);
 
 
 	if (desc->size) {
-		est->size = sizeof(struct nft_hash) +
-			    nft_hash_tbl_size(desc->size) *
+		est->size = sizeof(struct rhashtable) +
+			    roundup_pow_of_two(desc->size * 4 / 3) *
 			    sizeof(struct nft_hash_elem *) +
 			    desc->size * esize;
 	} else {

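[Editor's note: not part of the patch. The NFT_HASH_ELEMENT_HINT value of 3 follows from rounded_hashtable_size() in lib/rhashtable.c: a hint of 3 elements gives roundup_pow_of_two(3 * 4 / 3) = roundup_pow_of_two(4) = 4 buckets, i.e. the default table size of 4 that the changelog says was restored at Patrick McHardy's request.]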
+ 107 - 178
net/netlink/af_netlink.c

@@ -58,7 +58,9 @@
 #include <linux/mutex.h>
 #include <linux/vmalloc.h>
 #include <linux/if_arp.h>
+#include <linux/rhashtable.h>
 #include <asm/cacheflush.h>
+#include <linux/hash.h>
 
 
 #include <net/net_namespace.h>
 #include <net/sock.h>
@@ -100,6 +102,18 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
 
 
 #define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
 
 
+/* Protects netlink socket hash table mutations */
+DEFINE_MUTEX(nl_sk_hash_lock);
+
+static int lockdep_nl_sk_hash_is_held(void)
+{
+#ifdef CONFIG_LOCKDEP
+	return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1;
+#else
+	return 1;
+#endif
+}
+
 static ATOMIC_NOTIFIER_HEAD(netlink_chain);
 
 static DEFINE_SPINLOCK(netlink_tap_lock);
@@ -110,11 +124,6 @@ static inline u32 netlink_group_mask(u32 group)
 	return group ? 1 << (group - 1) : 0;
 }
 
 
-static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u32 portid)
-{
-	return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask];
-}
-
 int netlink_add_tap(struct netlink_tap *nt)
 {
 	if (unlikely(nt->dev->type != ARPHRD_NETLINK))
@@ -983,105 +992,48 @@ netlink_unlock_table(void)
 		wake_up(&nl_table_wait);
 }
 
 
-static bool netlink_compare(struct net *net, struct sock *sk)
-{
-	return net_eq(sock_net(sk), net);
-}
-
-static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
+struct netlink_compare_arg
 {
-	struct netlink_table *table = &nl_table[protocol];
-	struct nl_portid_hash *hash = &table->hash;
-	struct hlist_head *head;
-	struct sock *sk;
-
-	read_lock(&nl_table_lock);
-	head = nl_portid_hashfn(hash, portid);
-	sk_for_each(sk, head) {
-		if (table->compare(net, sk) &&
-		    (nlk_sk(sk)->portid == portid)) {
-			sock_hold(sk);
-			goto found;
-		}
-	}
-	sk = NULL;
-found:
-	read_unlock(&nl_table_lock);
-	return sk;
-}
+	struct net *net;
+	u32 portid;
+};
 
 
-static struct hlist_head *nl_portid_hash_zalloc(size_t size)
+static bool netlink_compare(void *ptr, void *arg)
 {
-	if (size <= PAGE_SIZE)
-		return kzalloc(size, GFP_ATOMIC);
-	else
-		return (struct hlist_head *)
-			__get_free_pages(GFP_ATOMIC | __GFP_ZERO,
-					 get_order(size));
-}
+	struct netlink_compare_arg *x = arg;
+	struct sock *sk = ptr;
 
 
-static void nl_portid_hash_free(struct hlist_head *table, size_t size)
-{
-	if (size <= PAGE_SIZE)
-		kfree(table);
-	else
-		free_pages((unsigned long)table, get_order(size));
+	return nlk_sk(sk)->portid == x->portid &&
+	       net_eq(sock_net(sk), x->net);
 }
 
 
-static int nl_portid_hash_rehash(struct nl_portid_hash *hash, int grow)
+static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
+				     struct net *net)
 {
-	unsigned int omask, mask, shift;
-	size_t osize, size;
-	struct hlist_head *otable, *table;
-	int i;
-
-	omask = mask = hash->mask;
-	osize = size = (mask + 1) * sizeof(*table);
-	shift = hash->shift;
-
-	if (grow) {
-		if (++shift > hash->max_shift)
-			return 0;
-		mask = mask * 2 + 1;
-		size *= 2;
-	}
+	struct netlink_compare_arg arg = {
+		.net = net,
+		.portid = portid,
+	};
+	u32 hash;
 
 
-	table = nl_portid_hash_zalloc(size);
-	if (!table)
-		return 0;
+	hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));
 
 
-	otable = hash->table;
-	hash->table = table;
-	hash->mask = mask;
-	hash->shift = shift;
-	get_random_bytes(&hash->rnd, sizeof(hash->rnd));
-
-	for (i = 0; i <= omask; i++) {
-		struct sock *sk;
-		struct hlist_node *tmp;
-
-		sk_for_each_safe(sk, tmp, &otable[i])
-			__sk_add_node(sk, nl_portid_hashfn(hash, nlk_sk(sk)->portid));
-	}
-
-	nl_portid_hash_free(otable, osize);
-	hash->rehash_time = jiffies + 10 * 60 * HZ;
-	return 1;
+	return rhashtable_lookup_compare(&table->hash, hash,
+					 &netlink_compare, &arg);
 }
 
 
-static inline int nl_portid_hash_dilute(struct nl_portid_hash *hash, int len)
+static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
 {
-	int avg = hash->entries >> hash->shift;
-
-	if (unlikely(avg > 1) && nl_portid_hash_rehash(hash, 1))
-		return 1;
+	struct netlink_table *table = &nl_table[protocol];
+	struct sock *sk;
 
 
-	if (unlikely(len > avg) && time_after(jiffies, hash->rehash_time)) {
-		nl_portid_hash_rehash(hash, 0);
-		return 1;
-	}
+	rcu_read_lock();
+	sk = __netlink_lookup(table, portid, net);
+	if (sk)
+		sock_hold(sk);
+	rcu_read_unlock();
 
 
-	return 0;
+	return sk;
 }
 
 static const struct proto_ops netlink_ops;
@@ -1113,22 +1065,10 @@ netlink_update_listeners(struct sock *sk)
 static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
 {
 	struct netlink_table *table = &nl_table[sk->sk_protocol];
-	struct nl_portid_hash *hash = &table->hash;
-	struct hlist_head *head;
 	int err = -EADDRINUSE;
-	struct sock *osk;
-	int len;
 
 
-	netlink_table_grab();
-	head = nl_portid_hashfn(hash, portid);
-	len = 0;
-	sk_for_each(osk, head) {
-		if (table->compare(net, osk) &&
-		    (nlk_sk(osk)->portid == portid))
-			break;
-		len++;
-	}
-	if (osk)
+	mutex_lock(&nl_sk_hash_lock);
+	if (__netlink_lookup(table, portid, net))
 		goto err;
 
 	err = -EBUSY;
@@ -1136,26 +1076,31 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
 		goto err;
 
 	err = -ENOMEM;
-	if (BITS_PER_LONG > 32 && unlikely(hash->entries >= UINT_MAX))
+	if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX))
 		goto err;
 
 
-	if (len && nl_portid_hash_dilute(hash, len))
-		head = nl_portid_hashfn(hash, portid);
-	hash->entries++;
 	nlk_sk(sk)->portid = portid;
-	sk_add_node(sk, head);
+	sock_hold(sk);
+	rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
 	err = 0;
-
 err:
-	netlink_table_ungrab();
+	mutex_unlock(&nl_sk_hash_lock);
 	return err;
 }
 
 static void netlink_remove(struct sock *sk)
 {
+	struct netlink_table *table;
+
+	mutex_lock(&nl_sk_hash_lock);
+	table = &nl_table[sk->sk_protocol];
+	if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) {
+		WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
+		__sock_put(sk);
+	}
+	mutex_unlock(&nl_sk_hash_lock);
+
 	netlink_table_grab();
-	if (sk_del_node_init(sk))
-		nl_table[sk->sk_protocol].hash.entries--;
 	if (nlk_sk(sk)->subscriptions)
 		__sk_del_bind_node(sk);
 	netlink_table_ungrab();
@@ -1311,6 +1256,9 @@ static int netlink_release(struct socket *sock)
 	}
 	netlink_table_ungrab();
 
 
+	/* Wait for readers to complete */
+	synchronize_net();
+
 	kfree(nlk->groups);
 	nlk->groups = NULL;
 
 
@@ -1326,30 +1274,22 @@ static int netlink_autobind(struct socket *sock)
 	struct sock *sk = sock->sk;
 	struct net *net = sock_net(sk);
 	struct netlink_table *table = &nl_table[sk->sk_protocol];
-	struct nl_portid_hash *hash = &table->hash;
-	struct hlist_head *head;
-	struct sock *osk;
 	s32 portid = task_tgid_vnr(current);
 	int err;
 	static s32 rover = -4097;
 
 retry:
 	cond_resched();
-	netlink_table_grab();
-	head = nl_portid_hashfn(hash, portid);
-	sk_for_each(osk, head) {
-		if (!table->compare(net, osk))
-			continue;
-		if (nlk_sk(osk)->portid == portid) {
-			/* Bind collision, search negative portid values. */
-			portid = rover--;
-			if (rover > -4097)
-				rover = -4097;
-			netlink_table_ungrab();
-			goto retry;
-		}
+	rcu_read_lock();
+	if (__netlink_lookup(table, portid, net)) {
+		/* Bind collision, search negative portid values. */
+		portid = rover--;
+		if (rover > -4097)
+			rover = -4097;
+		rcu_read_unlock();
+		goto retry;
 	}
-	netlink_table_ungrab();
+	rcu_read_unlock();
 
 
 	err = netlink_insert(sk, net, portid);
 	if (err == -EADDRINUSE)
@@ -2953,14 +2893,18 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
 {
 	struct nl_seq_iter *iter = seq->private;
 	int i, j;
+	struct netlink_sock *nlk;
 	struct sock *s;
 	loff_t off = 0;
 
 	for (i = 0; i < MAX_LINKS; i++) {
-		struct nl_portid_hash *hash = &nl_table[i].hash;
+		struct rhashtable *ht = &nl_table[i].hash;
+		const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
+
+		for (j = 0; j < tbl->size; j++) {
+			rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
+				s = (struct sock *)nlk;
 
 
-		for (j = 0; j <= hash->mask; j++) {
-			sk_for_each(s, &hash->table[j]) {
 				if (sock_net(s) != seq_file_net(seq))
 					continue;
 				if (off == pos) {
@@ -2976,15 +2920,14 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
 }
 
 static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
-	__acquires(nl_table_lock)
 {
-	read_lock(&nl_table_lock);
+	rcu_read_lock();
 	return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
 }
 
 static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 {
-	struct sock *s;
+	struct netlink_sock *nlk;
 	struct nl_seq_iter *iter;
 	struct net *net;
 	int i, j;
@@ -2996,28 +2939,26 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 
 
 	net = seq_file_net(seq);
 	iter = seq->private;
-	s = v;
-	do {
-		s = sk_next(s);
-	} while (s && !nl_table[s->sk_protocol].compare(net, s));
-	if (s)
-		return s;
+	nlk = v;
+
+	rht_for_each_entry_rcu(nlk, nlk->node.next, node)
+		if (net_eq(sock_net((struct sock *)nlk), net))
+			return nlk;
 
 
 	i = iter->link;
 	j = iter->hash_idx + 1;
 
 	do {
-		struct nl_portid_hash *hash = &nl_table[i].hash;
-
-		for (; j <= hash->mask; j++) {
-			s = sk_head(&hash->table[j]);
+		struct rhashtable *ht = &nl_table[i].hash;
+		const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
 
 
-			while (s && !nl_table[s->sk_protocol].compare(net, s))
-				s = sk_next(s);
-			if (s) {
-				iter->link = i;
-				iter->hash_idx = j;
-				return s;
+		for (; j < tbl->size; j++) {
+			rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
+				if (net_eq(sock_net((struct sock *)nlk), net)) {
+					iter->link = i;
+					iter->hash_idx = j;
+					return nlk;
+				}
 			}
 			}
 		}
 
@@ -3028,9 +2969,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 }
 }
 
 static void netlink_seq_stop(struct seq_file *seq, void *v)
 static void netlink_seq_stop(struct seq_file *seq, void *v)
-	__releases(nl_table_lock)
 {
-	read_unlock(&nl_table_lock);
+	rcu_read_unlock();
 }
 
 
 
 
@@ -3168,9 +3108,17 @@ static struct pernet_operations __net_initdata netlink_net_ops = {
 static int __init netlink_proto_init(void)
 {
 	int i;
-	unsigned long limit;
-	unsigned int order;
 	int err = proto_register(&netlink_proto, 0);
+	struct rhashtable_params ht_params = {
+		.head_offset = offsetof(struct netlink_sock, node),
+		.key_offset = offsetof(struct netlink_sock, portid),
+		.key_len = sizeof(u32), /* portid */
+		.hashfn = arch_fast_hash,
+		.max_shift = 16, /* 64K */
+		.grow_decision = rht_grow_above_75,
+		.shrink_decision = rht_shrink_below_30,
+		.mutex_is_held = lockdep_nl_sk_hash_is_held,
+	};
 
 
 	if (err != 0)
 		goto out;
@@ -3181,32 +3129,13 @@ static int __init netlink_proto_init(void)
 	if (!nl_table)
 		goto panic;
 
 
-	if (totalram_pages >= (128 * 1024))
-		limit = totalram_pages >> (21 - PAGE_SHIFT);
-	else
-		limit = totalram_pages >> (23 - PAGE_SHIFT);
-
-	order = get_bitmask_order(limit) - 1 + PAGE_SHIFT;
-	limit = (1UL << order) / sizeof(struct hlist_head);
-	order = get_bitmask_order(min(limit, (unsigned long)UINT_MAX)) - 1;
-
 	for (i = 0; i < MAX_LINKS; i++) {
-		struct nl_portid_hash *hash = &nl_table[i].hash;
-
-		hash->table = nl_portid_hash_zalloc(1 * sizeof(*hash->table));
-		if (!hash->table) {
-			while (i-- > 0)
-				nl_portid_hash_free(nl_table[i].hash.table,
-						 1 * sizeof(*hash->table));
+		if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) {
+			while (--i > 0)
+				rhashtable_destroy(&nl_table[i].hash);
 			kfree(nl_table);
 			goto panic;
 		}
-		hash->max_shift = order;
-		hash->shift = 0;
-		hash->mask = 0;
-		hash->rehash_time = jiffies;
-
-		nl_table[i].compare = netlink_compare;
 	}
 
 	INIT_LIST_HEAD(&netlink_tap_all);

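[Editor's note: not part of the patch. Condensed from the hunks above, the resulting locking contract for the netlink socket hash table is: readers rely on RCU only, writers serialize on nl_sk_hash_lock. Error handling is omitted; only names introduced by the patch are used.]

    /* Reader side: lockless lookup under rcu_read_lock(). */
    rcu_read_lock();
    sk = __netlink_lookup(table, portid, net);
    if (sk)
    	sock_hold(sk);	/* take a reference before leaving the RCU section */
    rcu_read_unlock();

    /* Writer side: mutations take the hash mutex and never block readers. */
    mutex_lock(&nl_sk_hash_lock);
    rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
    mutex_unlock(&nl_sk_hash_lock);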
+ 4 - 14
net/netlink/af_netlink.h

@@ -1,6 +1,7 @@
 #ifndef _AF_NETLINK_H
 #define _AF_NETLINK_H
 
 
+#include <linux/rhashtable.h>
 #include <net/sock.h>
 
 #define NLGRPSZ(x)	(ALIGN(x, sizeof(unsigned long) * 8) / 8)
@@ -47,6 +48,8 @@ struct netlink_sock {
 	struct netlink_ring	tx_ring;
 	atomic_t		mapped;
 #endif /* CONFIG_NETLINK_MMAP */
+
+	struct rhash_head	node;
 };
 
 static inline struct netlink_sock *nlk_sk(struct sock *sk)
@@ -54,21 +57,8 @@ static inline struct netlink_sock *nlk_sk(struct sock *sk)
 	return container_of(sk, struct netlink_sock, sk);
 }
 
 
-struct nl_portid_hash {
-	struct hlist_head	*table;
-	unsigned long		rehash_time;
-
-	unsigned int		mask;
-	unsigned int		shift;
-
-	unsigned int		entries;
-	unsigned int		max_shift;
-
-	u32			rnd;
-};
-
 struct netlink_table {
-	struct nl_portid_hash	hash;
+	struct rhashtable	hash;
 	struct hlist_head	mc_list;
 	struct listeners __rcu	*listeners;
 	unsigned int		flags;

+ 8 - 3
net/netlink/diag.c

@@ -4,6 +4,7 @@
 #include <linux/netlink.h>
 #include <linux/sock_diag.h>
 #include <linux/netlink_diag.h>
+#include <linux/rhashtable.h>
 
 
 #include "af_netlink.h"
 #include "af_netlink.h"
 
 
@@ -101,16 +102,20 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
 				int protocol, int s_num)
 {
 	struct netlink_table *tbl = &nl_table[protocol];
-	struct nl_portid_hash *hash = &tbl->hash;
+	struct rhashtable *ht = &tbl->hash;
+	const struct bucket_table *htbl = rht_dereference(ht->tbl, ht);
 	struct net *net = sock_net(skb->sk);
 	struct netlink_diag_req *req;
+	struct netlink_sock *nlsk;
 	struct sock *sk;
 	int ret = 0, num = 0, i;
 
 	req = nlmsg_data(cb->nlh);
 
 
-	for (i = 0; i <= hash->mask; i++) {
-		sk_for_each(sk, &hash->table[i]) {
+	for (i = 0; i < htbl->size; i++) {
+		rht_for_each_entry(nlsk, htbl->buckets[i], ht, node) {
+			sk = (struct sock *)nlsk;
+
 			if (!net_eq(sock_net(sk), net))
 				continue;
 			if (num < s_num) {