Commit e341694e3eb57fcda9f1adc7bfea42fe080d8d7a

Authored by Thomas Graf
Committed by David S. Miller
1 parent 7e1e77636e

netlink: Convert netlink_lookup() to use RCU protected hash table

Heavy Netlink users such as Open vSwitch spend a considerable amount of
time in netlink_lookup() due to the read-lock on nl_table_lock. Use of
RCU relieves the lock contention.

Makes use of the new resizable hash table to avoid locking on the
lookup.
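
On the fast path this reduces to the classic RCU reader pattern. A
minimal sketch, mirroring the rewritten netlink_lookup() in the diff
below:

    rcu_read_lock();
    sk = __netlink_lookup(table, portid, net);  /* lock-free hash lookup */
    if (sk)
        sock_hold(sk);  /* pin the socket before leaving the read side */
    rcu_read_unlock();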

The hash table will grow once the number of entries exceeds 75% of the
table size, up to a maximum table size of 64K. It will automatically
shrink once utilization falls below 30%.
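
These thresholds correspond to the rhashtable parameters set up in
netlink_proto_init() (condensed from the diff below):

    struct rhashtable_params ht_params = {
        ...
        .max_shift       = 16,                  /* cap at 2^16 = 64K buckets */
        .grow_decision   = rht_grow_above_75,   /* grow when >75% full */
        .shrink_decision = rht_shrink_below_30, /* shrink when <30% full */
    };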

Also splits hash table mutations off from nl_table_lock onto a separate
mutex (nl_sk_hash_lock); unlike the rwlock, the mutex allows
synchronize_rcu() to sleep while waiting for readers during expansion
and shrinking.
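
Writers now serialize on that mutex; a sketch of the insert path as it
appears in netlink_insert() below (the mutex, unlike the rwlock, may be
held across the sleeping synchronize_rcu() that a resize performs):

    mutex_lock(&nl_sk_hash_lock);
    if (!__netlink_lookup(table, portid, net)) {  /* portid still free? */
        nlk_sk(sk)->portid = portid;
        sock_hold(sk);  /* the hash table holds its own reference */
        rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
    }
    mutex_unlock(&nl_sk_hash_lock);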

Before:
   9.16%  kpktgend_0  [openvswitch]      [k] masked_flow_lookup
   6.42%  kpktgend_0  [pktgen]           [k] mod_cur_headers
   6.26%  kpktgend_0  [pktgen]           [k] pktgen_thread_worker
   6.23%  kpktgend_0  [kernel.kallsyms]  [k] memset
   4.79%  kpktgend_0  [kernel.kallsyms]  [k] netlink_lookup
   4.37%  kpktgend_0  [kernel.kallsyms]  [k] memcpy
   3.60%  kpktgend_0  [openvswitch]      [k] ovs_flow_extract
   2.69%  kpktgend_0  [kernel.kallsyms]  [k] jhash2

After:
  15.26%  kpktgend_0  [openvswitch]      [k] masked_flow_lookup
   8.12%  kpktgend_0  [pktgen]           [k] pktgen_thread_worker
   7.92%  kpktgend_0  [pktgen]           [k] mod_cur_headers
   5.11%  kpktgend_0  [kernel.kallsyms]  [k] memset
   4.11%  kpktgend_0  [openvswitch]      [k] ovs_flow_extract
   4.06%  kpktgend_0  [kernel.kallsyms]  [k] _raw_spin_lock
   3.90%  kpktgend_0  [kernel.kallsyms]  [k] jhash2
   [...]
   0.67%  kpktgend_0  [kernel.kallsyms]  [k] netlink_lookup

Signed-off-by: Thomas Graf <tgraf@suug.ch>
Reviewed-by: Nikolay Aleksandrov <nikolay@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>

Showing 3 changed files with 119 additions and 195 deletions

net/netlink/af_netlink.c
... ... @@ -58,7 +58,9 @@
58 58 #include <linux/mutex.h>
59 59 #include <linux/vmalloc.h>
60 60 #include <linux/if_arp.h>
  61 +#include <linux/rhashtable.h>
61 62 #include <asm/cacheflush.h>
  63 +#include <linux/hash.h>
62 64  
63 65 #include <net/net_namespace.h>
64 66 #include <net/sock.h>
... ... @@ -100,6 +102,18 @@
100 102  
101 103 #define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
102 104  
  105 +/* Protects netlink socket hash table mutations */
  106 +DEFINE_MUTEX(nl_sk_hash_lock);
  107 +
  108 +static int lockdep_nl_sk_hash_is_held(void)
  109 +{
  110 +#ifdef CONFIG_LOCKDEP
  111 + return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1;
  112 +#else
  113 + return 1;
  114 +#endif
  115 +}
  116 +
103 117 static ATOMIC_NOTIFIER_HEAD(netlink_chain);
104 118  
105 119 static DEFINE_SPINLOCK(netlink_tap_lock);
... ... @@ -110,11 +124,6 @@
110 124 return group ? 1 << (group - 1) : 0;
111 125 }
112 126  
113   -static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u32 portid)
114   -{
115   - return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask];
116   -}
117   -
118 127 int netlink_add_tap(struct netlink_tap *nt)
119 128 {
120 129 if (unlikely(nt->dev->type != ARPHRD_NETLINK))
... ... @@ -983,105 +992,48 @@
983 992 wake_up(&nl_table_wait);
984 993 }
985 994  
986   -static bool netlink_compare(struct net *net, struct sock *sk)
  995 +struct netlink_compare_arg
987 996 {
988   - return net_eq(sock_net(sk), net);
989   -}
  997 + struct net *net;
  998 + u32 portid;
  999 +};
990 1000  
991   -static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
  1001 +static bool netlink_compare(void *ptr, void *arg)
992 1002 {
993   - struct netlink_table *table = &nl_table[protocol];
994   - struct nl_portid_hash *hash = &table->hash;
995   - struct hlist_head *head;
996   - struct sock *sk;
  1003 + struct netlink_compare_arg *x = arg;
  1004 + struct sock *sk = ptr;
997 1005  
998   - read_lock(&nl_table_lock);
999   - head = nl_portid_hashfn(hash, portid);
1000   - sk_for_each(sk, head) {
1001   - if (table->compare(net, sk) &&
1002   - (nlk_sk(sk)->portid == portid)) {
1003   - sock_hold(sk);
1004   - goto found;
1005   - }
1006   - }
1007   - sk = NULL;
1008   -found:
1009   - read_unlock(&nl_table_lock);
1010   - return sk;
  1006 + return nlk_sk(sk)->portid == x->portid &&
  1007 + net_eq(sock_net(sk), x->net);
1011 1008 }
1012 1009  
1013   -static struct hlist_head *nl_portid_hash_zalloc(size_t size)
  1010 +static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
  1011 + struct net *net)
1014 1012 {
1015   - if (size <= PAGE_SIZE)
1016   - return kzalloc(size, GFP_ATOMIC);
1017   - else
1018   - return (struct hlist_head *)
1019   - __get_free_pages(GFP_ATOMIC | __GFP_ZERO,
1020   - get_order(size));
1021   -}
  1013 + struct netlink_compare_arg arg = {
  1014 + .net = net,
  1015 + .portid = portid,
  1016 + };
  1017 + u32 hash;
1022 1018  
1023   -static void nl_portid_hash_free(struct hlist_head *table, size_t size)
1024   -{
1025   - if (size <= PAGE_SIZE)
1026   - kfree(table);
1027   - else
1028   - free_pages((unsigned long)table, get_order(size));
1029   -}
  1019 + hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));
1030 1020  
1031   -static int nl_portid_hash_rehash(struct nl_portid_hash *hash, int grow)
1032   -{
1033   - unsigned int omask, mask, shift;
1034   - size_t osize, size;
1035   - struct hlist_head *otable, *table;
1036   - int i;
1037   -
1038   - omask = mask = hash->mask;
1039   - osize = size = (mask + 1) * sizeof(*table);
1040   - shift = hash->shift;
1041   -
1042   - if (grow) {
1043   - if (++shift > hash->max_shift)
1044   - return 0;
1045   - mask = mask * 2 + 1;
1046   - size *= 2;
1047   - }
1048   -
1049   - table = nl_portid_hash_zalloc(size);
1050   - if (!table)
1051   - return 0;
1052   -
1053   - otable = hash->table;
1054   - hash->table = table;
1055   - hash->mask = mask;
1056   - hash->shift = shift;
1057   - get_random_bytes(&hash->rnd, sizeof(hash->rnd));
1058   -
1059   - for (i = 0; i <= omask; i++) {
1060   - struct sock *sk;
1061   - struct hlist_node *tmp;
1062   -
1063   - sk_for_each_safe(sk, tmp, &otable[i])
1064   - __sk_add_node(sk, nl_portid_hashfn(hash, nlk_sk(sk)->portid));
1065   - }
1066   -
1067   - nl_portid_hash_free(otable, osize);
1068   - hash->rehash_time = jiffies + 10 * 60 * HZ;
1069   - return 1;
  1021 + return rhashtable_lookup_compare(&table->hash, hash,
  1022 + &netlink_compare, &arg);
1070 1023 }
1071 1024  
1072   -static inline int nl_portid_hash_dilute(struct nl_portid_hash *hash, int len)
  1025 +static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
1073 1026 {
1074   - int avg = hash->entries >> hash->shift;
  1027 + struct netlink_table *table = &nl_table[protocol];
  1028 + struct sock *sk;
1075 1029  
1076   - if (unlikely(avg > 1) && nl_portid_hash_rehash(hash, 1))
1077   - return 1;
  1030 + rcu_read_lock();
  1031 + sk = __netlink_lookup(table, portid, net);
  1032 + if (sk)
  1033 + sock_hold(sk);
  1034 + rcu_read_unlock();
1078 1035  
1079   - if (unlikely(len > avg) && time_after(jiffies, hash->rehash_time)) {
1080   - nl_portid_hash_rehash(hash, 0);
1081   - return 1;
1082   - }
1083   -
1084   - return 0;
  1036 + return sk;
1085 1037 }
1086 1038  
1087 1039 static const struct proto_ops netlink_ops;
... ... @@ -1113,22 +1065,10 @@
1113 1065 static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
1114 1066 {
1115 1067 struct netlink_table *table = &nl_table[sk->sk_protocol];
1116   - struct nl_portid_hash *hash = &table->hash;
1117   - struct hlist_head *head;
1118 1068 int err = -EADDRINUSE;
1119   - struct sock *osk;
1120   - int len;
1121 1069  
1122   - netlink_table_grab();
1123   - head = nl_portid_hashfn(hash, portid);
1124   - len = 0;
1125   - sk_for_each(osk, head) {
1126   - if (table->compare(net, osk) &&
1127   - (nlk_sk(osk)->portid == portid))
1128   - break;
1129   - len++;
1130   - }
1131   - if (osk)
  1070 + mutex_lock(&nl_sk_hash_lock);
  1071 + if (__netlink_lookup(table, portid, net))
1132 1072 goto err;
1133 1073  
1134 1074 err = -EBUSY;
1135 1075  
1136 1076  
1137 1077  
1138 1078  
1139 1079  
1140 1080  
... ... @@ -1136,26 +1076,31 @@
1136 1076 goto err;
1137 1077  
1138 1078 err = -ENOMEM;
1139   - if (BITS_PER_LONG > 32 && unlikely(hash->entries >= UINT_MAX))
  1079 + if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX))
1140 1080 goto err;
1141 1081  
1142   - if (len && nl_portid_hash_dilute(hash, len))
1143   - head = nl_portid_hashfn(hash, portid);
1144   - hash->entries++;
1145 1082 nlk_sk(sk)->portid = portid;
1146   - sk_add_node(sk, head);
  1083 + sock_hold(sk);
  1084 + rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
1147 1085 err = 0;
1148   -
1149 1086 err:
1150   - netlink_table_ungrab();
  1087 + mutex_unlock(&nl_sk_hash_lock);
1151 1088 return err;
1152 1089 }
1153 1090  
1154 1091 static void netlink_remove(struct sock *sk)
1155 1092 {
  1093 + struct netlink_table *table;
  1094 +
  1095 + mutex_lock(&nl_sk_hash_lock);
  1096 + table = &nl_table[sk->sk_protocol];
  1097 + if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) {
  1098 + WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
  1099 + __sock_put(sk);
  1100 + }
  1101 + mutex_unlock(&nl_sk_hash_lock);
  1102 +
1156 1103 netlink_table_grab();
1157   - if (sk_del_node_init(sk))
1158   - nl_table[sk->sk_protocol].hash.entries--;
1159 1104 if (nlk_sk(sk)->subscriptions)
1160 1105 __sk_del_bind_node(sk);
1161 1106 netlink_table_ungrab();
... ... @@ -1311,6 +1256,9 @@
1311 1256 }
1312 1257 netlink_table_ungrab();
1313 1258  
  1259 + /* Wait for readers to complete */
  1260 + synchronize_net();
  1261 +
1314 1262 kfree(nlk->groups);
1315 1263 nlk->groups = NULL;
1316 1264  
... ... @@ -1326,30 +1274,22 @@
1326 1274 struct sock *sk = sock->sk;
1327 1275 struct net *net = sock_net(sk);
1328 1276 struct netlink_table *table = &nl_table[sk->sk_protocol];
1329   - struct nl_portid_hash *hash = &table->hash;
1330   - struct hlist_head *head;
1331   - struct sock *osk;
1332 1277 s32 portid = task_tgid_vnr(current);
1333 1278 int err;
1334 1279 static s32 rover = -4097;
1335 1280  
1336 1281 retry:
1337 1282 cond_resched();
1338   - netlink_table_grab();
1339   - head = nl_portid_hashfn(hash, portid);
1340   - sk_for_each(osk, head) {
1341   - if (!table->compare(net, osk))
1342   - continue;
1343   - if (nlk_sk(osk)->portid == portid) {
1344   - /* Bind collision, search negative portid values. */
1345   - portid = rover--;
1346   - if (rover > -4097)
1347   - rover = -4097;
1348   - netlink_table_ungrab();
1349   - goto retry;
1350   - }
  1283 + rcu_read_lock();
  1284 + if (__netlink_lookup(table, portid, net)) {
  1285 + /* Bind collision, search negative portid values. */
  1286 + portid = rover--;
  1287 + if (rover > -4097)
  1288 + rover = -4097;
  1289 + rcu_read_unlock();
  1290 + goto retry;
1351 1291 }
1352   - netlink_table_ungrab();
  1292 + rcu_read_unlock();
1353 1293  
1354 1294 err = netlink_insert(sk, net, portid);
1355 1295 if (err == -EADDRINUSE)
... ... @@ -2953,14 +2893,18 @@
2953 2893 {
2954 2894 struct nl_seq_iter *iter = seq->private;
2955 2895 int i, j;
  2896 + struct netlink_sock *nlk;
2956 2897 struct sock *s;
2957 2898 loff_t off = 0;
2958 2899  
2959 2900 for (i = 0; i < MAX_LINKS; i++) {
2960   - struct nl_portid_hash *hash = &nl_table[i].hash;
  2901 + struct rhashtable *ht = &nl_table[i].hash;
  2902 + const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
2961 2903  
2962   - for (j = 0; j <= hash->mask; j++) {
2963   - sk_for_each(s, &hash->table[j]) {
  2904 + for (j = 0; j < tbl->size; j++) {
  2905 + rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
  2906 + s = (struct sock *)nlk;
  2907 +
2964 2908 if (sock_net(s) != seq_file_net(seq))
2965 2909 continue;
2966 2910 if (off == pos) {
... ... @@ -2976,15 +2920,14 @@
2976 2920 }
2977 2921  
2978 2922 static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
2979   - __acquires(nl_table_lock)
2980 2923 {
2981   - read_lock(&nl_table_lock);
  2924 + rcu_read_lock();
2982 2925 return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
2983 2926 }
2984 2927  
2985 2928 static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
2986 2929 {
2987   - struct sock *s;
  2930 + struct netlink_sock *nlk;
2988 2931 struct nl_seq_iter *iter;
2989 2932 struct net *net;
2990 2933 int i, j;
... ... @@ -2996,28 +2939,26 @@
2996 2939  
2997 2940 net = seq_file_net(seq);
2998 2941 iter = seq->private;
2999   - s = v;
3000   - do {
3001   - s = sk_next(s);
3002   - } while (s && !nl_table[s->sk_protocol].compare(net, s));
3003   - if (s)
3004   - return s;
  2942 + nlk = v;
3005 2943  
  2944 + rht_for_each_entry_rcu(nlk, nlk->node.next, node)
  2945 + if (net_eq(sock_net((struct sock *)nlk), net))
  2946 + return nlk;
  2947 +
3006 2948 i = iter->link;
3007 2949 j = iter->hash_idx + 1;
3008 2950  
3009 2951 do {
3010   - struct nl_portid_hash *hash = &nl_table[i].hash;
  2952 + struct rhashtable *ht = &nl_table[i].hash;
  2953 + const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
3011 2954  
3012   - for (; j <= hash->mask; j++) {
3013   - s = sk_head(&hash->table[j]);
3014   -
3015   - while (s && !nl_table[s->sk_protocol].compare(net, s))
3016   - s = sk_next(s);
3017   - if (s) {
3018   - iter->link = i;
3019   - iter->hash_idx = j;
3020   - return s;
  2955 + for (; j < tbl->size; j++) {
  2956 + rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
  2957 + if (net_eq(sock_net((struct sock *)nlk), net)) {
  2958 + iter->link = i;
  2959 + iter->hash_idx = j;
  2960 + return nlk;
  2961 + }
3021 2962 }
3022 2963 }
3023 2964  
... ... @@ -3028,9 +2969,8 @@
3028 2969 }
3029 2970  
3030 2971 static void netlink_seq_stop(struct seq_file *seq, void *v)
3031   - __releases(nl_table_lock)
3032 2972 {
3033   - read_unlock(&nl_table_lock);
  2973 + rcu_read_unlock();
3034 2974 }
3035 2975  
3036 2976  
... ... @@ -3168,9 +3108,17 @@
3168 3108 static int __init netlink_proto_init(void)
3169 3109 {
3170 3110 int i;
3171   - unsigned long limit;
3172   - unsigned int order;
3173 3111 int err = proto_register(&netlink_proto, 0);
  3112 + struct rhashtable_params ht_params = {
  3113 + .head_offset = offsetof(struct netlink_sock, node),
  3114 + .key_offset = offsetof(struct netlink_sock, portid),
  3115 + .key_len = sizeof(u32), /* portid */
  3116 + .hashfn = arch_fast_hash,
  3117 + .max_shift = 16, /* 64K */
  3118 + .grow_decision = rht_grow_above_75,
  3119 + .shrink_decision = rht_shrink_below_30,
  3120 + .mutex_is_held = lockdep_nl_sk_hash_is_held,
  3121 + };
3174 3122  
3175 3123 if (err != 0)
3176 3124 goto out;
... ... @@ -3181,32 +3129,13 @@
3181 3129 if (!nl_table)
3182 3130 goto panic;
3183 3131  
3184   - if (totalram_pages >= (128 * 1024))
3185   - limit = totalram_pages >> (21 - PAGE_SHIFT);
3186   - else
3187   - limit = totalram_pages >> (23 - PAGE_SHIFT);
3188   -
3189   - order = get_bitmask_order(limit) - 1 + PAGE_SHIFT;
3190   - limit = (1UL << order) / sizeof(struct hlist_head);
3191   - order = get_bitmask_order(min(limit, (unsigned long)UINT_MAX)) - 1;
3192   -
3193 3132 for (i = 0; i < MAX_LINKS; i++) {
3194   - struct nl_portid_hash *hash = &nl_table[i].hash;
3195   -
3196   - hash->table = nl_portid_hash_zalloc(1 * sizeof(*hash->table));
3197   - if (!hash->table) {
3198   - while (i-- > 0)
3199   - nl_portid_hash_free(nl_table[i].hash.table,
3200   - 1 * sizeof(*hash->table));
  3133 + if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) {
  3134 + while (--i > 0)
  3135 + rhashtable_destroy(&nl_table[i].hash);
3201 3136 kfree(nl_table);
3202 3137 goto panic;
3203 3138 }
3204   - hash->max_shift = order;
3205   - hash->shift = 0;
3206   - hash->mask = 0;
3207   - hash->rehash_time = jiffies;
3208   -
3209   - nl_table[i].compare = netlink_compare;
3210 3139 }
3211 3140  
3212 3141 INIT_LIST_HEAD(&netlink_tap_all);
net/netlink/af_netlink.h
1 1 #ifndef _AF_NETLINK_H
2 2 #define _AF_NETLINK_H
3 3  
  4 +#include <linux/rhashtable.h>
4 5 #include <net/sock.h>
5 6  
6 7 #define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8)
... ... @@ -47,6 +48,8 @@
47 48 struct netlink_ring tx_ring;
48 49 atomic_t mapped;
49 50 #endif /* CONFIG_NETLINK_MMAP */
  51 +
  52 + struct rhash_head node;
50 53 };
51 54  
52 55 static inline struct netlink_sock *nlk_sk(struct sock *sk)
... ... @@ -54,21 +57,8 @@
54 57 return container_of(sk, struct netlink_sock, sk);
55 58 }
56 59  
57   -struct nl_portid_hash {
58   - struct hlist_head *table;
59   - unsigned long rehash_time;
60   -
61   - unsigned int mask;
62   - unsigned int shift;
63   -
64   - unsigned int entries;
65   - unsigned int max_shift;
66   -
67   - u32 rnd;
68   -};
69   -
70 60 struct netlink_table {
71   - struct nl_portid_hash hash;
  61 + struct rhashtable hash;
72 62 struct hlist_head mc_list;
73 63 struct listeners __rcu *listeners;
74 64 unsigned int flags;
net/netlink/diag.c
... ... @@ -4,6 +4,7 @@
4 4 #include <linux/netlink.h>
5 5 #include <linux/sock_diag.h>
6 6 #include <linux/netlink_diag.h>
  7 +#include <linux/rhashtable.h>
7 8  
8 9 #include "af_netlink.h"
9 10  
... ... @@ -101,16 +102,20 @@
101 102 int protocol, int s_num)
102 103 {
103 104 struct netlink_table *tbl = &nl_table[protocol];
104   - struct nl_portid_hash *hash = &tbl->hash;
  105 + struct rhashtable *ht = &tbl->hash;
  106 + const struct bucket_table *htbl = rht_dereference(ht->tbl, ht);
105 107 struct net *net = sock_net(skb->sk);
106 108 struct netlink_diag_req *req;
  109 + struct netlink_sock *nlsk;
107 110 struct sock *sk;
108 111 int ret = 0, num = 0, i;
109 112  
110 113 req = nlmsg_data(cb->nlh);
111 114  
112   - for (i = 0; i <= hash->mask; i++) {
113   - sk_for_each(sk, &hash->table[i]) {
  115 + for (i = 0; i < htbl->size; i++) {
  116 + rht_for_each_entry(nlsk, htbl->buckets[i], ht, node) {
  117 + sk = (struct sock *)nlsk;
  118 +
114 119 if (!net_eq(sock_net(sk), net))
115 120 continue;
116 121 if (num < s_num) {