Selaa lähdekoodia

Merge branch 'fib_trie-next'

Alexander Duyck says:

====================
ipv4/fib_trie: Cleanups to prepare for introduction of key vector

This patch series is meant to mostly just clean up the fib_trie to prepare
it for the introduction of the key_vector.  As such there are a number of
minor clean-ups such as reformatting the tnode to match the format once the
key vector is introduced, some optimizations to drop the need for a leaf
parent pointer, and some changes to remove duplication of effort such as
the 2 look-ups that were essentially being done per node insertion.

v2: Added code to cleanup idx >> n->bits and explain unsigned long logic
    Added code to prevent allocation when tnode size is larger than size_t
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
David S. Miller 10 vuotta sitten
vanhempi
commit
f93eb4ba0f
4 muutettua tiedostoa jossa 433 lisäystä ja 368 poistoa
  1. 42 28
      include/net/ip_fib.h
  2. 4 3
      include/net/netns/ipv4.h
  3. 37 15
      net/ipv4/fib_frontend.c
  4. 350 322
      net/ipv4/fib_trie.c

+ 42 - 28
include/net/ip_fib.h

@@ -185,6 +185,7 @@ struct fib_table {
 	u32			tb_id;
 	int			tb_default;
 	int			tb_num_default;
+	struct rcu_head		rcu;
 	unsigned long		tb_data[0];
 };
 
@@ -206,12 +207,16 @@ void fib_free_table(struct fib_table *tb);
 
 static inline struct fib_table *fib_get_table(struct net *net, u32 id)
 {
+	struct hlist_node *tb_hlist;
 	struct hlist_head *ptr;
 
 	ptr = id == RT_TABLE_LOCAL ?
 		&net->ipv4.fib_table_hash[TABLE_LOCAL_INDEX] :
 		&net->ipv4.fib_table_hash[TABLE_MAIN_INDEX];
-	return hlist_entry(ptr->first, struct fib_table, tb_hlist);
+
+	tb_hlist = rcu_dereference_rtnl(hlist_first_rcu(ptr));
+
+	return hlist_entry(tb_hlist, struct fib_table, tb_hlist);
 }
 
 static inline struct fib_table *fib_new_table(struct net *net, u32 id)
@@ -222,15 +227,19 @@ static inline struct fib_table *fib_new_table(struct net *net, u32 id)
 static inline int fib_lookup(struct net *net, const struct flowi4 *flp,
 			     struct fib_result *res)
 {
-	int err = -ENETUNREACH;
+	struct fib_table *tb;
+	int err;
 
 	rcu_read_lock();
 
-	if (!fib_table_lookup(fib_get_table(net, RT_TABLE_LOCAL), flp, res,
-			      FIB_LOOKUP_NOREF) ||
-	    !fib_table_lookup(fib_get_table(net, RT_TABLE_MAIN), flp, res,
-			      FIB_LOOKUP_NOREF))
-		err = 0;
+	for (err = 0; !err; err = -ENETUNREACH) {
+		tb = fib_get_table(net, RT_TABLE_LOCAL);
+		if (tb && !fib_table_lookup(tb, flp, res, FIB_LOOKUP_NOREF))
+			break;
+		tb = fib_get_table(net, RT_TABLE_MAIN);
+		if (tb && !fib_table_lookup(tb, flp, res, FIB_LOOKUP_NOREF))
+			break;
+	}
 
 	rcu_read_unlock();
 
@@ -249,28 +258,33 @@ int __fib_lookup(struct net *net, struct flowi4 *flp, struct fib_result *res);
 static inline int fib_lookup(struct net *net, struct flowi4 *flp,
 			     struct fib_result *res)
 {
-	if (!net->ipv4.fib_has_custom_rules) {
-		int err = -ENETUNREACH;
-
-		rcu_read_lock();
-
-		res->tclassid = 0;
-		if ((net->ipv4.fib_local &&
-		     !fib_table_lookup(net->ipv4.fib_local, flp, res,
-				       FIB_LOOKUP_NOREF)) ||
-		    (net->ipv4.fib_main &&
-		     !fib_table_lookup(net->ipv4.fib_main, flp, res,
-				       FIB_LOOKUP_NOREF)) ||
-		    (net->ipv4.fib_default &&
-		     !fib_table_lookup(net->ipv4.fib_default, flp, res,
-				       FIB_LOOKUP_NOREF)))
-			err = 0;
-
-		rcu_read_unlock();
-
-		return err;
+	struct fib_table *tb;
+	int err;
+
+	if (net->ipv4.fib_has_custom_rules)
+		return __fib_lookup(net, flp, res);
+
+	rcu_read_lock();
+
+	res->tclassid = 0;
+
+	for (err = 0; !err; err = -ENETUNREACH) {
+		tb = rcu_dereference_rtnl(net->ipv4.fib_local);
+		if (tb && !fib_table_lookup(tb, flp, res, FIB_LOOKUP_NOREF))
+			break;
+
+		tb = rcu_dereference_rtnl(net->ipv4.fib_main);
+		if (tb && !fib_table_lookup(tb, flp, res, FIB_LOOKUP_NOREF))
+			break;
+
+		tb = rcu_dereference_rtnl(net->ipv4.fib_default);
+		if (tb && !fib_table_lookup(tb, flp, res, FIB_LOOKUP_NOREF))
+			break;
 	}
-	return __fib_lookup(net, flp, res);
+
+	rcu_read_unlock();
+
+	return err;
 }
 
 #endif /* CONFIG_IP_MULTIPLE_TABLES */

+ 4 - 3
include/net/netns/ipv4.h

@@ -7,6 +7,7 @@
 
 #include <linux/uidgid.h>
 #include <net/inet_frag.h>
+#include <linux/rcupdate.h>
 
 struct tcpm_hash_bucket;
 struct ctl_table_header;
@@ -38,9 +39,9 @@ struct netns_ipv4 {
 #ifdef CONFIG_IP_MULTIPLE_TABLES
 	struct fib_rules_ops	*rules_ops;
 	bool			fib_has_custom_rules;
-	struct fib_table	*fib_local;
-	struct fib_table	*fib_main;
-	struct fib_table	*fib_default;
+	struct fib_table __rcu	*fib_local;
+	struct fib_table __rcu	*fib_main;
+	struct fib_table __rcu	*fib_default;
 #endif
 #ifdef CONFIG_IP_ROUTE_CLASSID
 	int			fib_num_tclassid_users;

+ 37 - 15
net/ipv4/fib_frontend.c

@@ -89,17 +89,14 @@ struct fib_table *fib_new_table(struct net *net, u32 id)
 
 	switch (id) {
 	case RT_TABLE_LOCAL:
-		net->ipv4.fib_local = tb;
+		rcu_assign_pointer(net->ipv4.fib_local, tb);
 		break;
-
 	case RT_TABLE_MAIN:
-		net->ipv4.fib_main = tb;
+		rcu_assign_pointer(net->ipv4.fib_main, tb);
 		break;
-
 	case RT_TABLE_DEFAULT:
-		net->ipv4.fib_default = tb;
+		rcu_assign_pointer(net->ipv4.fib_default, tb);
 		break;
-
 	default:
 		break;
 	}
@@ -132,13 +129,14 @@ struct fib_table *fib_get_table(struct net *net, u32 id)
 static void fib_flush(struct net *net)
 {
 	int flushed = 0;
-	struct fib_table *tb;
-	struct hlist_head *head;
 	unsigned int h;
 
 	for (h = 0; h < FIB_TABLE_HASHSZ; h++) {
-		head = &net->ipv4.fib_table_hash[h];
-		hlist_for_each_entry(tb, head, tb_hlist)
+		struct hlist_head *head = &net->ipv4.fib_table_hash[h];
+		struct hlist_node *tmp;
+		struct fib_table *tb;
+
+		hlist_for_each_entry_safe(tb, tmp, head, tb_hlist)
 			flushed += fib_table_flush(tb);
 	}
 
@@ -665,10 +663,12 @@ static int inet_dump_fib(struct sk_buff *skb, struct netlink_callback *cb)
 	s_h = cb->args[0];
 	s_e = cb->args[1];
 
+	rcu_read_lock();
+
 	for (h = s_h; h < FIB_TABLE_HASHSZ; h++, s_e = 0) {
 		e = 0;
 		head = &net->ipv4.fib_table_hash[h];
-		hlist_for_each_entry(tb, head, tb_hlist) {
+		hlist_for_each_entry_rcu(tb, head, tb_hlist) {
 			if (e < s_e)
 				goto next;
 			if (dumped)
@@ -682,6 +682,8 @@ next:
 		}
 	}
 out:
+	rcu_read_unlock();
+
 	cb->args[1] = e;
 	cb->args[0] = h;
 
@@ -1117,14 +1119,34 @@ static void ip_fib_net_exit(struct net *net)
 
 	rtnl_lock();
 	for (i = 0; i < FIB_TABLE_HASHSZ; i++) {
-		struct fib_table *tb;
-		struct hlist_head *head;
+		struct hlist_head *head = &net->ipv4.fib_table_hash[i];
 		struct hlist_node *tmp;
+		struct fib_table *tb;
+
+		/* this is done in two passes as flushing the table could
+		 * cause it to be reallocated in order to accommodate new
+		 * tnodes at the root as the table shrinks.
+		 */
+		hlist_for_each_entry_safe(tb, tmp, head, tb_hlist)
+			fib_table_flush(tb);
 
-		head = &net->ipv4.fib_table_hash[i];
 		hlist_for_each_entry_safe(tb, tmp, head, tb_hlist) {
+#ifdef CONFIG_IP_MULTIPLE_TABLES
+			switch (tb->tb_id) {
+			case RT_TABLE_LOCAL:
+				RCU_INIT_POINTER(net->ipv4.fib_local, NULL);
+				break;
+			case RT_TABLE_MAIN:
+				RCU_INIT_POINTER(net->ipv4.fib_main, NULL);
+				break;
+			case RT_TABLE_DEFAULT:
+				RCU_INIT_POINTER(net->ipv4.fib_default, NULL);
+				break;
+			default:
+				break;
+			}
+#endif
 			hlist_del(&tb->tb_hlist);
-			fib_table_flush(tb);
 			fib_free_table(tb);
 		}
 	}

+ 350 - 322
net/ipv4/fib_trie.c

@@ -94,24 +94,27 @@ typedef unsigned int t_key;
 #define get_index(_key, _kv) (((_key) ^ (_kv)->key) >> (_kv)->pos)
 
 struct tnode {
+	struct rcu_head rcu;
+
+	t_key empty_children; /* KEYLENGTH bits needed */
+	t_key full_children;  /* KEYLENGTH bits needed */
+	struct tnode __rcu *parent;
+
 	t_key key;
-	unsigned char bits;		/* 2log(KEYLENGTH) bits needed */
 	unsigned char pos;		/* 2log(KEYLENGTH) bits needed */
+	unsigned char bits;		/* 2log(KEYLENGTH) bits needed */
 	unsigned char slen;
-	struct tnode __rcu *parent;
-	struct rcu_head rcu;
 	union {
-		/* The fields in this struct are valid if bits > 0 (TNODE) */
-		struct {
-			t_key empty_children; /* KEYLENGTH bits needed */
-			t_key full_children;  /* KEYLENGTH bits needed */
-			struct tnode __rcu *child[0];
-		};
-		/* This list pointer if valid if bits == 0 (LEAF) */
+		/* This list pointer is valid if (pos | bits) == 0 (LEAF) */
 		struct hlist_head leaf;
+		/* This array is valid if (pos | bits) > 0 (TNODE) */
+		struct tnode __rcu *tnode[0];
 	};
 };
 
+#define TNODE_SIZE(n)	offsetof(struct tnode, tnode[n])
+#define LEAF_SIZE	TNODE_SIZE(1)
+
 #ifdef CONFIG_IP_FIB_TRIE_STATS
 struct trie_use_stats {
 	unsigned int gets;
@@ -180,14 +183,21 @@ static inline unsigned long tnode_child_length(const struct tnode *tn)
 static inline struct tnode *tnode_get_child(const struct tnode *tn,
 					    unsigned long i)
 {
-	return rtnl_dereference(tn->child[i]);
+	return rtnl_dereference(tn->tnode[i]);
 }
 
 /* caller must hold RCU read lock or RTNL */
 static inline struct tnode *tnode_get_child_rcu(const struct tnode *tn,
 						unsigned long i)
 {
-	return rcu_dereference_rtnl(tn->child[i]);
+	return rcu_dereference_rtnl(tn->tnode[i]);
+}
+
+static inline struct fib_table *trie_get_table(struct trie *t)
+{
+	unsigned long *tb_data = (unsigned long *)t;
+
+	return container_of(tb_data, struct fib_table, tb_data[0]);
 }
 
 /* To understand this stuff, an understanding of keys and all their bits is
@@ -266,7 +276,9 @@ static inline void alias_free_mem_rcu(struct fib_alias *fa)
 }
 
 #define TNODE_KMALLOC_MAX \
-	ilog2((PAGE_SIZE - sizeof(struct tnode)) / sizeof(struct tnode *))
+	ilog2((PAGE_SIZE - TNODE_SIZE(0)) / sizeof(struct tnode *))
+#define TNODE_VMALLOC_MAX \
+	ilog2((SIZE_MAX - TNODE_SIZE(0)) / sizeof(struct tnode *))
 
 static void __node_free_rcu(struct rcu_head *head)
 {
@@ -282,8 +294,17 @@ static void __node_free_rcu(struct rcu_head *head)
 
 #define node_free(n) call_rcu(&n->rcu, __node_free_rcu)
 
-static struct tnode *tnode_alloc(size_t size)
+static struct tnode *tnode_alloc(int bits)
 {
+	size_t size;
+
+	/* verify bits is within bounds */
+	if (bits > TNODE_VMALLOC_MAX)
+		return NULL;
+
+	/* compute the size; the bits check above rules out overflow */
+	size = TNODE_SIZE(1ul << bits);
+
 	if (size <= PAGE_SIZE)
 		return kzalloc(size, GFP_KERNEL);
 	else
@@ -300,7 +321,7 @@ static inline void empty_child_dec(struct tnode *n)
 	n->empty_children-- ? : n->full_children--;
 }
 
-static struct tnode *leaf_new(t_key key)
+static struct tnode *leaf_new(t_key key, struct fib_alias *fa)
 {
 	struct tnode *l = kmem_cache_alloc(trie_leaf_kmem, GFP_KERNEL);
 	if (l) {
@@ -310,20 +331,21 @@ static struct tnode *leaf_new(t_key key)
 		 * as the nodes are searched
 		 */
 		l->key = key;
-		l->slen = 0;
+		l->slen = fa->fa_slen;
 		l->pos = 0;
 		/* set bits to 0 indicating we are not a tnode */
 		l->bits = 0;
 
+		/* link leaf to fib alias */
 		INIT_HLIST_HEAD(&l->leaf);
+		hlist_add_head(&fa->fa_list, &l->leaf);
 	}
 	return l;
 }
 
 static struct tnode *tnode_new(t_key key, int pos, int bits)
 {
-	size_t sz = offsetof(struct tnode, child[1ul << bits]);
-	struct tnode *tn = tnode_alloc(sz);
+	struct tnode *tn = tnode_alloc(bits);
 	unsigned int shift = pos + bits;
 
 	/* verify bits and pos their msb bits clear and values are valid */
@@ -341,7 +363,7 @@ static struct tnode *tnode_new(t_key key, int pos, int bits)
 			tn->empty_children = 1ul << bits;
 	}
 
-	pr_debug("AT %p s=%zu %zu\n", tn, sizeof(struct tnode),
+	pr_debug("AT %p s=%zu %zu\n", tn, TNODE_SIZE(0),
 		 sizeof(struct tnode *) << bits);
 	return tn;
 }
@@ -382,7 +404,7 @@ static void put_child(struct tnode *tn, unsigned long i, struct tnode *n)
 	if (n && (tn->slen < n->slen))
 		tn->slen = n->slen;
 
-	rcu_assign_pointer(tn->child[i], n);
+	rcu_assign_pointer(tn->tnode[i], n);
 }
 
 static void update_children(struct tnode *tn)
@@ -433,7 +455,7 @@ static void tnode_free(struct tnode *tn)
 
 	while (head) {
 		head = head->next;
-		tnode_free_size += offsetof(struct tnode, child[1 << tn->bits]);
+		tnode_free_size += TNODE_SIZE(1ul << tn->bits);
 		node_free(tn);
 
 		tn = container_of(head, struct tnode, rcu);
@@ -786,7 +808,7 @@ static void resize(struct trie *t, struct tnode *tn)
 	 * doing it ourselves.  This way we can let RCU fully do its
 	 * thing without us interfering
 	 */
-	cptr = tp ? &tp->child[get_index(tn->key, tp)] : &t->trie;
+	cptr = tp ? &tp->tnode[get_index(tn->key, tp)] : &t->trie;
 	BUG_ON(tn != rtnl_dereference(*cptr));
 
 	/* Double as long as the resulting node has a number of
@@ -842,10 +864,8 @@ static void resize(struct trie *t, struct tnode *tn)
 	}
 }
 
-static void leaf_pull_suffix(struct tnode *l)
+static void leaf_pull_suffix(struct tnode *tp, struct tnode *l)
 {
-	struct tnode *tp = node_parent(l);
-
 	while (tp && (tp->slen > tp->pos) && (tp->slen > l->slen)) {
 		if (update_suffix(tp) > l->slen)
 			break;
@@ -853,10 +873,8 @@ static void leaf_pull_suffix(struct tnode *l)
 	}
 }
 
-static void leaf_push_suffix(struct tnode *l)
+static void leaf_push_suffix(struct tnode *tn, struct tnode *l)
 {
-	struct tnode *tn = node_parent(l);
-
 	/* if this is a new leaf then tn will be NULL and we can sort
 	 * out parent suffix lengths as a part of trie_rebalance
 	 */
@@ -866,55 +884,10 @@ static void leaf_push_suffix(struct tnode *l)
 	}
 }
 
-static void fib_remove_alias(struct tnode *l, struct fib_alias *old)
-{
-	/* record the location of the previous list_info entry */
-	struct hlist_node **pprev = old->fa_list.pprev;
-	struct fib_alias *fa = hlist_entry(pprev, typeof(*fa), fa_list.next);
-
-	/* remove the fib_alias from the list */
-	hlist_del_rcu(&old->fa_list);
-
-	/* only access fa if it is pointing at the last valid hlist_node */
-	if (hlist_empty(&l->leaf) || (*pprev))
-		return;
-
-	/* update the trie with the latest suffix length */
-	l->slen = fa->fa_slen;
-	leaf_pull_suffix(l);
-}
-
-static void fib_insert_alias(struct tnode *l, struct fib_alias *fa,
-			     struct fib_alias *new)
-{
-	if (fa) {
-		hlist_add_before_rcu(&new->fa_list, &fa->fa_list);
-	} else {
-		struct fib_alias *last;
-
-		hlist_for_each_entry(last, &l->leaf, fa_list) {
-			if (new->fa_slen < last->fa_slen)
-				break;
-			fa = last;
-		}
-
-		if (fa)
-			hlist_add_behind_rcu(&new->fa_list, &fa->fa_list);
-		else
-			hlist_add_head_rcu(&new->fa_list, &l->leaf);
-	}
-
-	/* if we added to the tail node then we need to update slen */
-	if (l->slen < new->fa_slen) {
-		l->slen = new->fa_slen;
-		leaf_push_suffix(l);
-	}
-}
-
 /* rcu_read_lock needs to be hold by caller from readside */
-static struct tnode *fib_find_node(struct trie *t, u32 key)
+static struct tnode *fib_find_node(struct trie *t, struct tnode **tn, u32 key)
 {
-	struct tnode *n = rcu_dereference_rtnl(t->trie);
+	struct tnode *pn = NULL, *n = rcu_dereference_rtnl(t->trie);
 
 	while (n) {
 		unsigned long index = get_index(key, n);
@@ -924,21 +897,30 @@ static struct tnode *fib_find_node(struct trie *t, u32 key)
 		 * prefix plus zeros for the bits in the cindex. The index
 		 * is the difference between the key and this value.  From
 		 * this we can actually derive several pieces of data.
-		 *   if (index & (~0ul << bits))
+		 *   if (index >= (1ul << bits))
 		 *     we have a mismatch in skip bits and failed
 		 *   else
 		 *     we know the value is cindex
+		 *
+		 * This check is safe even if bits == KEYLENGTH due to the
+		 * fact that we can only allocate a node with 32 bits if a
+		 * long is greater than 32 bits.
 		 */
-		if (index & (~0ul << n->bits))
-			return NULL;
+		if (index >= (1ul << n->bits)) {
+			n = NULL;
+			break;
+		}
 
 		/* we have found a leaf. Prefixes have already been compared */
 		if (IS_LEAF(n))
 			break;
 
+		pn = n;
 		n = tnode_get_child_rcu(n, index);
 	}
 
+	*tn = pn;
+
 	return n;
 }
 
@@ -971,61 +953,28 @@ static void trie_rebalance(struct trie *t, struct tnode *tn)
 {
 	struct tnode *tp;
 
-	while ((tp = node_parent(tn)) != NULL) {
+	while (tn) {
+		tp = node_parent(tn);
 		resize(t, tn);
 		tn = tp;
 	}
-
-	/* Handle last (top) tnode */
-	if (IS_TNODE(tn))
-		resize(t, tn);
 }
 
 /* only used from updater-side */
-
-static struct tnode *fib_insert_node(struct trie *t, u32 key, int plen)
+static int fib_insert_node(struct trie *t, struct tnode *tp,
+			   struct fib_alias *new, t_key key)
 {
-	struct tnode *l, *n, *tp = NULL;
-
-	n = rtnl_dereference(t->trie);
-
-	/* If we point to NULL, stop. Either the tree is empty and we should
-	 * just put a new leaf in if, or we have reached an empty child slot,
-	 * and we should just put our new leaf in that.
-	 *
-	 * If we hit a node with a key that does't match then we should stop
-	 * and create a new tnode to replace that node and insert ourselves
-	 * and the other node into the new tnode.
-	 */
-	while (n) {
-		unsigned long index = get_index(key, n);
+	struct tnode *n, *l;
 
-		/* This bit of code is a bit tricky but it combines multiple
-		 * checks into a single check.  The prefix consists of the
-		 * prefix plus zeros for the "bits" in the prefix. The index
-		 * is the difference between the key and this value.  From
-		 * this we can actually derive several pieces of data.
-		 *   if !(index >> bits)
-		 *     we know the value is child index
-		 *   else
-		 *     we have a mismatch in skip bits and failed
-		 */
-		if (index >> n->bits)
-			break;
-
-		/* we have found a leaf. Prefixes have already been compared */
-		if (IS_LEAF(n)) {
-			/* Case 1: n is a leaf, and prefixes match*/
-			return n;
-		}
-
-		tp = n;
-		n = tnode_get_child_rcu(n, index);
-	}
-
-	l = leaf_new(key);
+	l = leaf_new(key, new);
 	if (!l)
-		return NULL;
+		return -ENOMEM;
+
+	/* retrieve child from parent node */
+	if (tp)
+		n = tnode_get_child(tp, get_index(key, tp));
+	else
+		n = rcu_dereference_rtnl(t->trie);
 
 	/* Case 2: n is a LEAF or a TNODE and the key doesn't match.
 	 *
@@ -1039,7 +988,7 @@ static struct tnode *fib_insert_node(struct trie *t, u32 key, int plen)
 		tn = tnode_new(key, __fls(key ^ n->key), 1);
 		if (!tn) {
 			node_free(l);
-			return NULL;
+			return -ENOMEM;
 		}
 
 		/* initialize routes out of node */
@@ -1055,31 +1004,58 @@ static struct tnode *fib_insert_node(struct trie *t, u32 key, int plen)
 	}
 
 	/* Case 3: n is NULL, and will just insert a new leaf */
-	if (tp) {
-		NODE_INIT_PARENT(l, tp);
-		put_child(tp, get_index(key, tp), l);
-		trie_rebalance(t, tp);
+	NODE_INIT_PARENT(l, tp);
+	put_child_root(tp, t, key, l);
+	trie_rebalance(t, tp);
+
+	return 0;
+}
+
+static int fib_insert_alias(struct trie *t, struct tnode *tp,
+			    struct tnode *l, struct fib_alias *new,
+			    struct fib_alias *fa, t_key key)
+{
+	if (!l)
+		return fib_insert_node(t, tp, new, key);
+
+	if (fa) {
+		hlist_add_before_rcu(&new->fa_list, &fa->fa_list);
 	} else {
-		rcu_assign_pointer(t->trie, l);
+		struct fib_alias *last;
+
+		hlist_for_each_entry(last, &l->leaf, fa_list) {
+			if (new->fa_slen < last->fa_slen)
+				break;
+			fa = last;
+		}
+
+		if (fa)
+			hlist_add_behind_rcu(&new->fa_list, &fa->fa_list);
+		else
+			hlist_add_head_rcu(&new->fa_list, &l->leaf);
 	}
 
-	return l;
+	/* if we added to the tail node then we need to update slen */
+	if (l->slen < new->fa_slen) {
+		l->slen = new->fa_slen;
+		leaf_push_suffix(tp, l);
+	}
+
+	return 0;
 }
 
-/*
- * Caller must hold RTNL.
- */
+/* Caller must hold RTNL. */
 int fib_table_insert(struct fib_table *tb, struct fib_config *cfg)
 {
-	struct trie *t = (struct trie *) tb->tb_data;
+	struct trie *t = (struct trie *)tb->tb_data;
 	struct fib_alias *fa, *new_fa;
+	struct tnode *l, *tp;
 	struct fib_info *fi;
 	u8 plen = cfg->fc_dst_len;
 	u8 slen = KEYLENGTH - plen;
 	u8 tos = cfg->fc_tos;
-	u32 key, mask;
+	u32 key;
 	int err;
-	struct tnode *l;
 
 	if (plen > KEYLENGTH)
 		return -EINVAL;
@@ -1088,9 +1064,7 @@ int fib_table_insert(struct fib_table *tb, struct fib_config *cfg)
 
 	pr_debug("Insert table=%u %08x/%d\n", tb->tb_id, key, plen);
 
-	mask = ntohl(inet_make_mask(plen));
-
-	if (key & ~mask)
+	if ((plen < KEYLENGTH) && (key << plen))
 		return -EINVAL;
 
 	fi = fib_create_info(cfg);
@@ -1099,7 +1073,7 @@ int fib_table_insert(struct fib_table *tb, struct fib_config *cfg)
 		goto err;
 	}
 
-	l = fib_find_node(t, key);
+	l = fib_find_node(t, &tp, key);
 	fa = l ? fib_find_alias(&l->leaf, slen, tos, fi->fib_priority) : NULL;
 
 	/* Now fa, if non-NULL, points to the first fib alias
@@ -1198,19 +1172,13 @@ int fib_table_insert(struct fib_table *tb, struct fib_config *cfg)
 	new_fa->fa_slen = slen;
 
 	/* Insert new entry to the list. */
-	if (!l) {
-		l = fib_insert_node(t, key, plen);
-		if (unlikely(!l)) {
-			err = -ENOMEM;
-			goto out_free_new_fa;
-		}
-	}
+	err = fib_insert_alias(t, tp, l, new_fa, fa, key);
+	if (err)
+		goto out_free_new_fa;
 
 	if (!plen)
 		tb->tb_num_default++;
 
-	fib_insert_alias(l, fa, new_fa);
-
 	rt_cache_flush(cfg->fc_nlinfo.nl_net);
 	rtmsg_fib(RTM_NEWROUTE, htonl(key), new_fa, plen, tb->tb_id,
 		  &cfg->fc_nlinfo, 0);
@@ -1243,6 +1211,7 @@ int fib_table_lookup(struct fib_table *tb, const struct flowi4 *flp,
 	const t_key key = ntohl(flp->daddr);
 	struct tnode *n, *pn;
 	struct fib_alias *fa;
+	unsigned long index;
 	t_key cindex;
 
 	n = rcu_dereference(t->trie);
@@ -1258,19 +1227,23 @@ int fib_table_lookup(struct fib_table *tb, const struct flowi4 *flp,
 
 	/* Step 1: Travel to the longest prefix match in the trie */
 	for (;;) {
-		unsigned long index = get_index(key, n);
+		index = get_index(key, n);
 
 		/* This bit of code is a bit tricky but it combines multiple
 		 * checks into a single check.  The prefix consists of the
 		 * prefix plus zeros for the "bits" in the prefix. The index
 		 * is the difference between the key and this value.  From
 		 * this we can actually derive several pieces of data.
-		 *   if (index & (~0ul << bits))
+		 *   if (index >= (1ul << bits))
 		 *     we have a mismatch in skip bits and failed
 		 *   else
 		 *     we know the value is cindex
+		 *
+		 * This check is safe even if bits == KEYLENGTH due to the
+		 * fact that we can only allocate a node with 32 bits if a
+		 * long is greater than 32 bits.
 		 */
-		if (index & (~0ul << n->bits))
+		if (index >= (1ul << n->bits))
 			break;
 
 		/* we have found a leaf. Prefixes have already been compared */
@@ -1293,7 +1266,7 @@ int fib_table_lookup(struct fib_table *tb, const struct flowi4 *flp,
 	/* Step 2: Sort out leaves and begin backtracing for longest prefix */
 	for (;;) {
 		/* record the pointer where our next node pointer is stored */
-		struct tnode __rcu **cptr = n->child;
+		struct tnode __rcu **cptr = n->tnode;
 
 		/* This test verifies that none of the bits that differ
 		 * between the key and the prefix exist in the region of
@@ -1339,19 +1312,22 @@ backtrace:
 			cindex &= cindex - 1;
 
 			/* grab pointer for next child node */
-			cptr = &pn->child[cindex];
+			cptr = &pn->tnode[cindex];
 		}
 	}
 
 found:
+	/* this line carries forward the xor from earlier in the function */
+	index = key ^ n->key;
+
 	/* Step 3: Process the leaf, if that fails fall back to backtracing */
 	hlist_for_each_entry_rcu(fa, &n->leaf, fa_list) {
 		struct fib_info *fi = fa->fa_info;
 		int nhsel, err;
 
-		if (((key ^ n->key) >= (1ul << fa->fa_slen)) &&
+		if ((index >= (1ul << fa->fa_slen)) &&
 		    ((BITS_PER_LONG > KEYLENGTH) || (fa->fa_slen != KEYLENGTH)))
-				continue;
+			continue;
 		if (fa->fa_tos && fa->fa_tos != flp->flowi4_tos)
 			continue;
 		if (fi->fib_dead)
@@ -1399,53 +1375,59 @@ found:
 }
 EXPORT_SYMBOL_GPL(fib_table_lookup);
 
-/*
- * Remove the leaf and return parent.
- */
-static void trie_leaf_remove(struct trie *t, struct tnode *l)
+static void fib_remove_alias(struct trie *t, struct tnode *tp,
+			     struct tnode *l, struct fib_alias *old)
 {
-	struct tnode *tp = node_parent(l);
+	/* record the location of the previous list_info entry */
+	struct hlist_node **pprev = old->fa_list.pprev;
+	struct fib_alias *fa = hlist_entry(pprev, typeof(*fa), fa_list.next);
 
-	pr_debug("entering trie_leaf_remove(%p)\n", l);
+	/* remove the fib_alias from the list */
+	hlist_del_rcu(&old->fa_list);
 
-	if (tp) {
-		put_child(tp, get_index(l->key, tp), NULL);
+	/* if we emptied the list this leaf will be freed and we can sort
+	 * out parent suffix lengths as a part of trie_rebalance
+	 */
+	if (hlist_empty(&l->leaf)) {
+		put_child_root(tp, t, l->key, NULL);
+		node_free(l);
 		trie_rebalance(t, tp);
-	} else {
-		RCU_INIT_POINTER(t->trie, NULL);
+		return;
 	}
 
-	node_free(l);
+	/* only access fa if it is pointing at the last valid hlist_node */
+	if (*pprev)
+		return;
+
+	/* update the trie with the latest suffix length */
+	l->slen = fa->fa_slen;
+	leaf_pull_suffix(tp, l);
 }
 
-/*
- * Caller must hold RTNL.
- */
+/* Caller must hold RTNL. */
 int fib_table_delete(struct fib_table *tb, struct fib_config *cfg)
 {
 	struct trie *t = (struct trie *) tb->tb_data;
 	struct fib_alias *fa, *fa_to_delete;
+	struct tnode *l, *tp;
 	u8 plen = cfg->fc_dst_len;
-	u8 tos = cfg->fc_tos;
 	u8 slen = KEYLENGTH - plen;
-	struct tnode *l;
-	u32 key, mask;
+	u8 tos = cfg->fc_tos;
+	u32 key;
 
 	if (plen > KEYLENGTH)
 		return -EINVAL;
 
 	key = ntohl(cfg->fc_dst);
-	mask = ntohl(inet_make_mask(plen));
 
-	if (key & ~mask)
+	if ((plen < KEYLENGTH) && (key << plen))
 		return -EINVAL;
 
-	l = fib_find_node(t, key);
+	l = fib_find_node(t, &tp, key);
 	if (!l)
 		return -ESRCH;
 
 	fa = fib_find_alias(&l->leaf, slen, tos, 0);
-
 	if (!fa)
 		return -ESRCH;
 
@@ -1474,150 +1456,171 @@ int fib_table_delete(struct fib_table *tb, struct fib_config *cfg)
 	if (!fa_to_delete)
 		return -ESRCH;
 
-	fa = fa_to_delete;
-	rtmsg_fib(RTM_DELROUTE, htonl(key), fa, plen, tb->tb_id,
+	rtmsg_fib(RTM_DELROUTE, htonl(key), fa_to_delete, plen, tb->tb_id,
 		  &cfg->fc_nlinfo, 0);
 
-	fib_remove_alias(l, fa);
-
 	if (!plen)
 		tb->tb_num_default--;
 
-	if (hlist_empty(&l->leaf))
-		trie_leaf_remove(t, l);
+	fib_remove_alias(t, tp, l, fa_to_delete);
 
-	if (fa->fa_state & FA_S_ACCESSED)
+	if (fa_to_delete->fa_state & FA_S_ACCESSED)
 		rt_cache_flush(cfg->fc_nlinfo.nl_net);
 
-	fib_release_info(fa->fa_info);
-	alias_free_mem_rcu(fa);
+	fib_release_info(fa_to_delete->fa_info);
+	alias_free_mem_rcu(fa_to_delete);
 	return 0;
 }
 
-static int trie_flush_leaf(struct tnode *l)
+/* Scan for the next leaf starting at the provided key value */
+static struct tnode *leaf_walk_rcu(struct tnode **tn, t_key key)
 {
-	struct hlist_node *tmp;
-	unsigned char slen = 0;
-	struct fib_alias *fa;
-	int found = 0;
+	struct tnode *pn, *n = *tn;
+	unsigned long cindex;
 
-	hlist_for_each_entry_safe(fa, tmp, &l->leaf, fa_list) {
-		struct fib_info *fi = fa->fa_info;
+	/* record parent node for backtracing */
+	pn = n;
+	cindex = n ? get_index(key, n) : 0;
 
-		if (fi && (fi->fib_flags & RTNH_F_DEAD)) {
-			hlist_del_rcu(&fa->fa_list);
-			fib_release_info(fa->fa_info);
-			alias_free_mem_rcu(fa);
-			found++;
+	/* this loop is meant to try and find the key in the trie */
+	while (n) {
+		unsigned long idx = get_index(key, n);
 
-			continue;
-		}
+		/* guarantee forward progress on the keys */
+		if (IS_LEAF(n) && (n->key >= key))
+			goto found;
+		if (idx >= (1ul << n->bits))
+			break;
 
-		slen = fa->fa_slen;
-	}
+		/* record parent and next child index */
+		pn = n;
+		cindex = idx;
 
-	l->slen = slen;
+		/* descend into the next child */
+		n = tnode_get_child_rcu(pn, cindex++);
+	}
 
-	return found;
-}
+	/* this loop will search for the next leaf with a greater key */
+	while (pn) {
+		/* if we exhausted the parent node we will need to climb */
+		if (cindex >= (1ul << pn->bits)) {
+			t_key pkey = pn->key;
 
-/* Scan for the next right leaf starting at node p->child[idx]
- * Since we have back pointer, no recursion necessary.
- */
-static struct tnode *leaf_walk_rcu(struct tnode *p, struct tnode *c)
-{
-	do {
-		unsigned long idx = c ? idx = get_index(c->key, p) + 1 : 0;
+			pn = node_parent_rcu(pn);
+			if (!pn)
+				break;
 
-		while (idx < tnode_child_length(p)) {
-			c = tnode_get_child_rcu(p, idx++);
-			if (!c)
-				continue;
+			cindex = get_index(pkey, pn) + 1;
+			continue;
+		}
 
-			if (IS_LEAF(c))
-				return c;
+		/* grab the next available node */
+		n = tnode_get_child_rcu(pn, cindex++);
+		if (!n)
+			continue;
 
-			/* Rescan start scanning in new node */
-			p = c;
-			idx = 0;
-		}
+		/* no need to compare keys since we bumped the index */
+		if (IS_LEAF(n))
+			goto found;
 
-		/* Node empty, walk back up to parent */
-		c = p;
-	} while ((p = node_parent_rcu(c)) != NULL);
+		/* Restart scanning in the new node */
+		pn = n;
+		cindex = 0;
+	}
 
+	*tn = pn;
 	return NULL; /* Root of trie */
+found:
+	/* if we are at the limit for keys just return NULL for the tnode */
+	*tn = (n->key == KEY_MAX) ? NULL : pn;
+	return n;
 }
 
-static struct tnode *trie_firstleaf(struct trie *t)
+/* Caller must hold RTNL. */
+int fib_table_flush(struct fib_table *tb)
 {
-	struct tnode *n = rcu_dereference_rtnl(t->trie);
+	struct trie *t = (struct trie *)tb->tb_data;
+	struct hlist_node *tmp;
+	struct fib_alias *fa;
+	struct tnode *n, *pn;
+	unsigned long cindex;
+	unsigned char slen;
+	int found = 0;
 
+	n = rcu_dereference(t->trie);
 	if (!n)
-		return NULL;
+		goto flush_complete;
 
-	if (IS_LEAF(n))          /* trie is just a leaf */
-		return n;
-
-	return leaf_walk_rcu(n, NULL);
-}
+	pn = NULL;
+	cindex = 0;
 
-static struct tnode *trie_nextleaf(struct tnode *l)
-{
-	struct tnode *p = node_parent_rcu(l);
+	while (IS_TNODE(n)) {
+		/* record pn and cindex for leaf walking */
+		pn = n;
+		cindex = 1ul << n->bits;
+backtrace:
+		/* walk trie in reverse order */
+		do {
+			while (!(cindex--)) {
+				t_key pkey = pn->key;
 
-	if (!p)
-		return NULL;	/* trie with just one leaf */
+				n = pn;
+				pn = node_parent(n);
 
-	return leaf_walk_rcu(p, l);
-}
+				/* resize completed node */
+				resize(t, n);
 
-static struct tnode *trie_leafindex(struct trie *t, int index)
-{
-	struct tnode *l = trie_firstleaf(t);
+				/* if we got the root we are done */
+				if (!pn)
+					goto flush_complete;
 
-	while (l && index-- > 0)
-		l = trie_nextleaf(l);
+				cindex = get_index(pkey, pn);
+			}
 
-	return l;
-}
+			/* grab the next available node */
+			n = tnode_get_child(pn, cindex);
+		} while (!n);
+	}
 
+	/* track slen in case any prefixes survive */
+	slen = 0;
 
-/*
- * Caller must hold RTNL.
- */
-int fib_table_flush(struct fib_table *tb)
-{
-	struct trie *t = (struct trie *) tb->tb_data;
-	struct tnode *l, *ll = NULL;
-	int found = 0;
+	hlist_for_each_entry_safe(fa, tmp, &n->leaf, fa_list) {
+		struct fib_info *fi = fa->fa_info;
 
-	for (l = trie_firstleaf(t); l; l = trie_nextleaf(l)) {
-		found += trie_flush_leaf(l);
+		if (fi && (fi->fib_flags & RTNH_F_DEAD)) {
+			hlist_del_rcu(&fa->fa_list);
+			fib_release_info(fa->fa_info);
+			alias_free_mem_rcu(fa);
+			found++;
 
-		if (ll) {
-			if (hlist_empty(&ll->leaf))
-				trie_leaf_remove(t, ll);
-			else
-				leaf_pull_suffix(ll);
+			continue;
 		}
 
-		ll = l;
+		slen = fa->fa_slen;
 	}
 
-	if (ll) {
-		if (hlist_empty(&ll->leaf))
-			trie_leaf_remove(t, ll);
-		else
-			leaf_pull_suffix(ll);
+	/* update leaf slen */
+	n->slen = slen;
+
+	if (hlist_empty(&n->leaf)) {
+		put_child_root(pn, t, n->key, NULL);
+		node_free(n);
+	} else {
+		leaf_pull_suffix(pn, n);
 	}
 
+	/* if the trie is a single leaf (no parent) the loop is complete */
+	if (pn)
+		goto backtrace;
+flush_complete:
 	pr_debug("trie_flush found=%d\n", found);
 	return found;
 }
 
-void fib_free_table(struct fib_table *tb)
+static void __trie_free_rcu(struct rcu_head *head)
 {
+	struct fib_table *tb = container_of(head, struct fib_table, rcu);
 #ifdef CONFIG_IP_FIB_TRIE_STATS
 	struct trie *t = (struct trie *)tb->tb_data;
 
@@ -1626,6 +1629,11 @@ void fib_free_table(struct fib_table *tb)
 	kfree(tb);
 }
 
+void fib_free_table(struct fib_table *tb)
+{
+	call_rcu(&tb->rcu, __trie_free_rcu);
+}
+
 static int fn_trie_dump_leaf(struct tnode *l, struct fib_table *tb,
 			     struct sk_buff *skb, struct netlink_callback *cb)
 {
@@ -1662,44 +1670,40 @@ static int fn_trie_dump_leaf(struct tnode *l, struct fib_table *tb,
 	return skb->len;
 }
 
+/* rcu_read_lock needs to be held by caller from readside */
 int fib_table_dump(struct fib_table *tb, struct sk_buff *skb,
 		   struct netlink_callback *cb)
 {
-	struct tnode *l;
-	struct trie *t = (struct trie *) tb->tb_data;
-	t_key key = cb->args[2];
-	int count = cb->args[3];
-
-	rcu_read_lock();
+	struct trie *t = (struct trie *)tb->tb_data;
+	struct tnode *l, *tp;
 	/* Dump starting at last key.
 	 * Note: 0.0.0.0/0 (ie default) is first key.
 	 */
-	if (count == 0)
-		l = trie_firstleaf(t);
-	else {
-		/* Normally, continue from last key, but if that is missing
-		 * fallback to using slow rescan
-		 */
-		l = fib_find_node(t, key);
-		if (!l)
-			l = trie_leafindex(t, count);
-	}
+	int count = cb->args[2];
+	t_key key = cb->args[3];
+
+	tp = rcu_dereference_rtnl(t->trie);
 
-	while (l) {
-		cb->args[2] = l->key;
+	while ((l = leaf_walk_rcu(&tp, key)) != NULL) {
 		if (fn_trie_dump_leaf(l, tb, skb, cb) < 0) {
-			cb->args[3] = count;
-			rcu_read_unlock();
+			cb->args[3] = key;
+			cb->args[2] = count;
 			return -1;
 		}
 
 		++count;
-		l = trie_nextleaf(l);
+		key = l->key + 1;
+
 		memset(&cb->args[4], 0,
 		       sizeof(cb->args) - 4*sizeof(cb->args[0]));
+
+		/* stop loop if key wrapped back to 0 */
+		if (key < l->key)
+			break;
 	}
-	cb->args[3] = count;
-	rcu_read_unlock();
+
+	cb->args[3] = key;
+	cb->args[2] = count;
 
 	return skb->len;
 }
@@ -1711,7 +1715,7 @@ void __init fib_trie_init(void)
 					  0, SLAB_PANIC, NULL);
 
 	trie_leaf_kmem = kmem_cache_create("ip_fib_trie",
-					   sizeof(struct tnode),
+					   LEAF_SIZE,
 					   0, SLAB_PANIC, NULL);
 }
 
@@ -1869,13 +1873,13 @@ static void trie_show_stats(struct seq_file *seq, struct trie_stat *stat)
 	seq_printf(seq, "\tMax depth:      %u\n", stat->maxdepth);
 
 	seq_printf(seq, "\tLeaves:         %u\n", stat->leaves);
-	bytes = sizeof(struct tnode) * stat->leaves;
+	bytes = LEAF_SIZE * stat->leaves;
 
 	seq_printf(seq, "\tPrefixes:       %u\n", stat->prefixes);
 	bytes += sizeof(struct fib_alias) * stat->prefixes;
 
 	seq_printf(seq, "\tInternal nodes: %u\n\t", stat->tnodes);
-	bytes += sizeof(struct tnode) * stat->tnodes;
+	bytes += TNODE_SIZE(0) * stat->tnodes;
 
 	max = MAX_STAT_DEPTH;
 	while (max > 0 && stat->nodesizes[max-1] == 0)
@@ -1944,7 +1948,7 @@ static int fib_triestat_seq_show(struct seq_file *seq, void *v)
 	seq_printf(seq,
 		   "Basic info: size of leaf:"
 		   " %Zd bytes, size of tnode: %Zd bytes.\n",
-		   sizeof(struct tnode), sizeof(struct tnode));
+		   LEAF_SIZE, TNODE_SIZE(0));
 
 	for (h = 0; h < FIB_TABLE_HASHSZ; h++) {
 		struct hlist_head *head = &net->ipv4.fib_table_hash[h];
@@ -2171,31 +2175,46 @@ static const struct file_operations fib_trie_fops = {
 
 struct fib_route_iter {
 	struct seq_net_private p;
-	struct trie *main_trie;
+	struct fib_table *main_tb;
+	struct tnode *tnode;
 	loff_t	pos;
 	t_key	key;
 };
 
 static struct tnode *fib_route_get_idx(struct fib_route_iter *iter, loff_t pos)
 {
-	struct tnode *l = NULL;
-	struct trie *t = iter->main_trie;
+	struct fib_table *tb = iter->main_tb;
+	struct tnode *l, **tp = &iter->tnode;
+	struct trie *t;
+	t_key key;
 
-	/* use cache location of last found key */
-	if (iter->pos > 0 && pos >= iter->pos && (l = fib_find_node(t, iter->key)))
+	/* use cache location of next-to-find key */
+	if (iter->pos > 0 && pos >= iter->pos) {
 		pos -= iter->pos;
-	else {
+		key = iter->key;
+	} else {
+		t = (struct trie *)tb->tb_data;
+		iter->tnode = rcu_dereference_rtnl(t->trie);
 		iter->pos = 0;
-		l = trie_firstleaf(t);
+		key = 0;
 	}
 
-	while (l && pos-- > 0) {
+	while ((l = leaf_walk_rcu(tp, key)) != NULL) {
+		key = l->key + 1;
 		iter->pos++;
-		l = trie_nextleaf(l);
+
+		if (pos-- <= 0)
+			break;
+
+		l = NULL;
+
+		/* handle unlikely case of a key wrap */
+		if (!key)
+			break;
 	}
 
 	if (l)
-		iter->key = pos;	/* remember it */
+		iter->key = key;	/* remember it */
 	else
 		iter->pos = 0;		/* forget it */
 
@@ -2207,37 +2226,46 @@ static void *fib_route_seq_start(struct seq_file *seq, loff_t *pos)
 {
 	struct fib_route_iter *iter = seq->private;
 	struct fib_table *tb;
+	struct trie *t;
 
 	rcu_read_lock();
+
 	tb = fib_get_table(seq_file_net(seq), RT_TABLE_MAIN);
 	if (!tb)
 		return NULL;
 
-	iter->main_trie = (struct trie *) tb->tb_data;
-	if (*pos == 0)
-		return SEQ_START_TOKEN;
-	else
-		return fib_route_get_idx(iter, *pos - 1);
+	iter->main_tb = tb;
+
+	if (*pos != 0)
+		return fib_route_get_idx(iter, *pos);
+
+	t = (struct trie *)tb->tb_data;
+	iter->tnode = rcu_dereference_rtnl(t->trie);
+	iter->pos = 0;
+	iter->key = 0;
+
+	return SEQ_START_TOKEN;
 }
 
 static void *fib_route_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 {
 	struct fib_route_iter *iter = seq->private;
-	struct tnode *l = v;
+	struct tnode *l = NULL;
+	t_key key = iter->key;
 
 	++*pos;
-	if (v == SEQ_START_TOKEN) {
-		iter->pos = 0;
-		l = trie_firstleaf(iter->main_trie);
-	} else {
+
+	/* only allow key of 0 for start of sequence */
+	if ((v == SEQ_START_TOKEN) || key)
+		l = leaf_walk_rcu(&iter->tnode, key);
+
+	if (l) {
+		iter->key = l->key + 1;
 		iter->pos++;
-		l = trie_nextleaf(l);
+	} else {
+		iter->pos = 0;
 	}
 
-	if (l)
-		iter->key = l->key;
-	else
-		iter->pos = 0;
 	return l;
 }