auth.c 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924
  1. /*
  2. * linux/net/sunrpc/auth.c
  3. *
  4. * Generic RPC client authentication API.
  5. *
  6. * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
  7. */
  8. #include <linux/types.h>
  9. #include <linux/sched.h>
  10. #include <linux/cred.h>
  11. #include <linux/module.h>
  12. #include <linux/slab.h>
  13. #include <linux/errno.h>
  14. #include <linux/hash.h>
  15. #include <linux/sunrpc/clnt.h>
  16. #include <linux/sunrpc/gss_api.h>
  17. #include <linux/spinlock.h>
  18. #if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
  19. # define RPCDBG_FACILITY RPCDBG_AUTH
  20. #endif
  21. #define RPC_CREDCACHE_DEFAULT_HASHBITS (4)
  22. struct rpc_cred_cache {
  23. struct hlist_head *hashtable;
  24. unsigned int hashbits;
  25. spinlock_t lock;
  26. };
  27. static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
  28. static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
  29. [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
  30. [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
  31. NULL, /* others can be loadable modules */
  32. };
  33. static LIST_HEAD(cred_unused);
  34. static unsigned long number_cred_unused;
  35. #define MAX_HASHTABLE_BITS (14)
  36. static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
  37. {
  38. unsigned long num;
  39. unsigned int nbits;
  40. int ret;
  41. if (!val)
  42. goto out_inval;
  43. ret = kstrtoul(val, 0, &num);
  44. if (ret)
  45. goto out_inval;
  46. nbits = fls(num - 1);
  47. if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
  48. goto out_inval;
  49. *(unsigned int *)kp->arg = nbits;
  50. return 0;
  51. out_inval:
  52. return -EINVAL;
  53. }
  54. static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
  55. {
  56. unsigned int nbits;
  57. nbits = *(unsigned int *)kp->arg;
  58. return sprintf(buffer, "%u", 1U << nbits);
  59. }
  60. #define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
  61. static const struct kernel_param_ops param_ops_hashtbl_sz = {
  62. .set = param_set_hashtbl_sz,
  63. .get = param_get_hashtbl_sz,
  64. };
  65. module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
  66. MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
  67. static unsigned long auth_max_cred_cachesize = ULONG_MAX;
  68. module_param(auth_max_cred_cachesize, ulong, 0644);
  69. MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");
  70. static u32
  71. pseudoflavor_to_flavor(u32 flavor) {
  72. if (flavor > RPC_AUTH_MAXFLAVOR)
  73. return RPC_AUTH_GSS;
  74. return flavor;
  75. }
  76. int
  77. rpcauth_register(const struct rpc_authops *ops)
  78. {
  79. const struct rpc_authops *old;
  80. rpc_authflavor_t flavor;
  81. if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
  82. return -EINVAL;
  83. old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops);
  84. if (old == NULL || old == ops)
  85. return 0;
  86. return -EPERM;
  87. }
  88. EXPORT_SYMBOL_GPL(rpcauth_register);
  89. int
  90. rpcauth_unregister(const struct rpc_authops *ops)
  91. {
  92. const struct rpc_authops *old;
  93. rpc_authflavor_t flavor;
  94. if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
  95. return -EINVAL;
  96. old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL);
  97. if (old == ops || old == NULL)
  98. return 0;
  99. return -EPERM;
  100. }
  101. EXPORT_SYMBOL_GPL(rpcauth_unregister);
  102. static const struct rpc_authops *
  103. rpcauth_get_authops(rpc_authflavor_t flavor)
  104. {
  105. const struct rpc_authops *ops;
  106. if (flavor >= RPC_AUTH_MAXFLAVOR)
  107. return NULL;
  108. rcu_read_lock();
  109. ops = rcu_dereference(auth_flavors[flavor]);
  110. if (ops == NULL) {
  111. rcu_read_unlock();
  112. request_module("rpc-auth-%u", flavor);
  113. rcu_read_lock();
  114. ops = rcu_dereference(auth_flavors[flavor]);
  115. if (ops == NULL)
  116. goto out;
  117. }
  118. if (!try_module_get(ops->owner))
  119. ops = NULL;
  120. out:
  121. rcu_read_unlock();
  122. return ops;
  123. }
  124. static void
  125. rpcauth_put_authops(const struct rpc_authops *ops)
  126. {
  127. module_put(ops->owner);
  128. }
  129. /**
  130. * rpcauth_get_pseudoflavor - check if security flavor is supported
  131. * @flavor: a security flavor
  132. * @info: a GSS mech OID, quality of protection, and service value
  133. *
  134. * Verifies that an appropriate kernel module is available or already loaded.
  135. * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
  136. * not supported locally.
  137. */
  138. rpc_authflavor_t
  139. rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
  140. {
  141. const struct rpc_authops *ops = rpcauth_get_authops(flavor);
  142. rpc_authflavor_t pseudoflavor;
  143. if (!ops)
  144. return RPC_AUTH_MAXFLAVOR;
  145. pseudoflavor = flavor;
  146. if (ops->info2flavor != NULL)
  147. pseudoflavor = ops->info2flavor(info);
  148. rpcauth_put_authops(ops);
  149. return pseudoflavor;
  150. }
  151. EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
  152. /**
  153. * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
  154. * @pseudoflavor: GSS pseudoflavor to match
  155. * @info: rpcsec_gss_info structure to fill in
  156. *
  157. * Returns zero and fills in "info" if pseudoflavor matches a
  158. * supported mechanism.
  159. */
  160. int
  161. rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
  162. {
  163. rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
  164. const struct rpc_authops *ops;
  165. int result;
  166. ops = rpcauth_get_authops(flavor);
  167. if (ops == NULL)
  168. return -ENOENT;
  169. result = -ENOENT;
  170. if (ops->flavor2info != NULL)
  171. result = ops->flavor2info(pseudoflavor, info);
  172. rpcauth_put_authops(ops);
  173. return result;
  174. }
  175. EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
  176. /**
  177. * rpcauth_list_flavors - discover registered flavors and pseudoflavors
  178. * @array: array to fill in
  179. * @size: size of "array"
  180. *
  181. * Returns the number of array items filled in, or a negative errno.
  182. *
  183. * The returned array is not sorted by any policy. Callers should not
  184. * rely on the order of the items in the returned array.
  185. */
  186. int
  187. rpcauth_list_flavors(rpc_authflavor_t *array, int size)
  188. {
  189. const struct rpc_authops *ops;
  190. rpc_authflavor_t flavor, pseudos[4];
  191. int i, len, result = 0;
  192. rcu_read_lock();
  193. for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
  194. ops = rcu_dereference(auth_flavors[flavor]);
  195. if (result >= size) {
  196. result = -ENOMEM;
  197. break;
  198. }
  199. if (ops == NULL)
  200. continue;
  201. if (ops->list_pseudoflavors == NULL) {
  202. array[result++] = ops->au_flavor;
  203. continue;
  204. }
  205. len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
  206. if (len < 0) {
  207. result = len;
  208. break;
  209. }
  210. for (i = 0; i < len; i++) {
  211. if (result >= size) {
  212. result = -ENOMEM;
  213. break;
  214. }
  215. array[result++] = pseudos[i];
  216. }
  217. }
  218. rcu_read_unlock();
  219. dprintk("RPC: %s returns %d\n", __func__, result);
  220. return result;
  221. }
  222. EXPORT_SYMBOL_GPL(rpcauth_list_flavors);
  223. struct rpc_auth *
  224. rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
  225. {
  226. struct rpc_auth *auth = ERR_PTR(-EINVAL);
  227. const struct rpc_authops *ops;
  228. u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor);
  229. ops = rpcauth_get_authops(flavor);
  230. if (ops == NULL)
  231. goto out;
  232. auth = ops->create(args, clnt);
  233. rpcauth_put_authops(ops);
  234. if (IS_ERR(auth))
  235. return auth;
  236. if (clnt->cl_auth)
  237. rpcauth_release(clnt->cl_auth);
  238. clnt->cl_auth = auth;
  239. out:
  240. return auth;
  241. }
  242. EXPORT_SYMBOL_GPL(rpcauth_create);
  243. void
  244. rpcauth_release(struct rpc_auth *auth)
  245. {
  246. if (!refcount_dec_and_test(&auth->au_count))
  247. return;
  248. auth->au_ops->destroy(auth);
  249. }
  250. static DEFINE_SPINLOCK(rpc_credcache_lock);
  251. /*
  252. * On success, the caller is responsible for freeing the reference
  253. * held by the hashtable
  254. */
  255. static bool
  256. rpcauth_unhash_cred_locked(struct rpc_cred *cred)
  257. {
  258. if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
  259. return false;
  260. hlist_del_rcu(&cred->cr_hash);
  261. return true;
  262. }
  263. static bool
  264. rpcauth_unhash_cred(struct rpc_cred *cred)
  265. {
  266. spinlock_t *cache_lock;
  267. bool ret;
  268. if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
  269. return false;
  270. cache_lock = &cred->cr_auth->au_credcache->lock;
  271. spin_lock(cache_lock);
  272. ret = rpcauth_unhash_cred_locked(cred);
  273. spin_unlock(cache_lock);
  274. return ret;
  275. }
  276. /*
  277. * Initialize RPC credential cache
  278. */
  279. int
  280. rpcauth_init_credcache(struct rpc_auth *auth)
  281. {
  282. struct rpc_cred_cache *new;
  283. unsigned int hashsize;
  284. new = kmalloc(sizeof(*new), GFP_KERNEL);
  285. if (!new)
  286. goto out_nocache;
  287. new->hashbits = auth_hashbits;
  288. hashsize = 1U << new->hashbits;
  289. new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
  290. if (!new->hashtable)
  291. goto out_nohashtbl;
  292. spin_lock_init(&new->lock);
  293. auth->au_credcache = new;
  294. return 0;
  295. out_nohashtbl:
  296. kfree(new);
  297. out_nocache:
  298. return -ENOMEM;
  299. }
  300. EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
  301. /*
  302. * Setup a credential key lifetime timeout notification
  303. */
  304. int
  305. rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred)
  306. {
  307. if (!cred->cr_auth->au_ops->key_timeout)
  308. return 0;
  309. return cred->cr_auth->au_ops->key_timeout(auth, cred);
  310. }
  311. EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify);
  312. bool
  313. rpcauth_cred_key_to_expire(struct rpc_auth *auth, struct rpc_cred *cred)
  314. {
  315. if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT)
  316. return false;
  317. if (!cred->cr_ops->crkey_to_expire)
  318. return false;
  319. return cred->cr_ops->crkey_to_expire(cred);
  320. }
  321. EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire);
  322. char *
  323. rpcauth_stringify_acceptor(struct rpc_cred *cred)
  324. {
  325. if (!cred->cr_ops->crstringify_acceptor)
  326. return NULL;
  327. return cred->cr_ops->crstringify_acceptor(cred);
  328. }
  329. EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);
  330. /*
  331. * Destroy a list of credentials
  332. */
  333. static inline
  334. void rpcauth_destroy_credlist(struct list_head *head)
  335. {
  336. struct rpc_cred *cred;
  337. while (!list_empty(head)) {
  338. cred = list_entry(head->next, struct rpc_cred, cr_lru);
  339. list_del_init(&cred->cr_lru);
  340. put_rpccred(cred);
  341. }
  342. }
  343. static void
  344. rpcauth_lru_add_locked(struct rpc_cred *cred)
  345. {
  346. if (!list_empty(&cred->cr_lru))
  347. return;
  348. number_cred_unused++;
  349. list_add_tail(&cred->cr_lru, &cred_unused);
  350. }
  351. static void
  352. rpcauth_lru_add(struct rpc_cred *cred)
  353. {
  354. if (!list_empty(&cred->cr_lru))
  355. return;
  356. spin_lock(&rpc_credcache_lock);
  357. rpcauth_lru_add_locked(cred);
  358. spin_unlock(&rpc_credcache_lock);
  359. }
  360. static void
  361. rpcauth_lru_remove_locked(struct rpc_cred *cred)
  362. {
  363. if (list_empty(&cred->cr_lru))
  364. return;
  365. number_cred_unused--;
  366. list_del_init(&cred->cr_lru);
  367. }
  368. static void
  369. rpcauth_lru_remove(struct rpc_cred *cred)
  370. {
  371. if (list_empty(&cred->cr_lru))
  372. return;
  373. spin_lock(&rpc_credcache_lock);
  374. rpcauth_lru_remove_locked(cred);
  375. spin_unlock(&rpc_credcache_lock);
  376. }
  377. /*
  378. * Clear the RPC credential cache, and delete those credentials
  379. * that are not referenced.
  380. */
  381. void
  382. rpcauth_clear_credcache(struct rpc_cred_cache *cache)
  383. {
  384. LIST_HEAD(free);
  385. struct hlist_head *head;
  386. struct rpc_cred *cred;
  387. unsigned int hashsize = 1U << cache->hashbits;
  388. int i;
  389. spin_lock(&rpc_credcache_lock);
  390. spin_lock(&cache->lock);
  391. for (i = 0; i < hashsize; i++) {
  392. head = &cache->hashtable[i];
  393. while (!hlist_empty(head)) {
  394. cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
  395. rpcauth_unhash_cred_locked(cred);
  396. /* Note: We now hold a reference to cred */
  397. rpcauth_lru_remove_locked(cred);
  398. list_add_tail(&cred->cr_lru, &free);
  399. }
  400. }
  401. spin_unlock(&cache->lock);
  402. spin_unlock(&rpc_credcache_lock);
  403. rpcauth_destroy_credlist(&free);
  404. }
  405. /*
  406. * Destroy the RPC credential cache
  407. */
  408. void
  409. rpcauth_destroy_credcache(struct rpc_auth *auth)
  410. {
  411. struct rpc_cred_cache *cache = auth->au_credcache;
  412. if (cache) {
  413. auth->au_credcache = NULL;
  414. rpcauth_clear_credcache(cache);
  415. kfree(cache->hashtable);
  416. kfree(cache);
  417. }
  418. }
  419. EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
  420. #define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)
  421. /*
  422. * Remove stale credentials. Avoid sleeping inside the loop.
  423. */
  424. static long
  425. rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
  426. {
  427. struct rpc_cred *cred, *next;
  428. unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
  429. long freed = 0;
  430. list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {
  431. if (nr_to_scan-- == 0)
  432. break;
  433. if (refcount_read(&cred->cr_count) > 1) {
  434. rpcauth_lru_remove_locked(cred);
  435. continue;
  436. }
  437. /*
  438. * Enforce a 60 second garbage collection moratorium
  439. * Note that the cred_unused list must be time-ordered.
  440. */
  441. if (!time_in_range(cred->cr_expire, expired, jiffies))
  442. continue;
  443. if (!rpcauth_unhash_cred(cred))
  444. continue;
  445. rpcauth_lru_remove_locked(cred);
  446. freed++;
  447. list_add_tail(&cred->cr_lru, free);
  448. }
  449. return freed ? freed : SHRINK_STOP;
  450. }
  451. static unsigned long
  452. rpcauth_cache_do_shrink(int nr_to_scan)
  453. {
  454. LIST_HEAD(free);
  455. unsigned long freed;
  456. spin_lock(&rpc_credcache_lock);
  457. freed = rpcauth_prune_expired(&free, nr_to_scan);
  458. spin_unlock(&rpc_credcache_lock);
  459. rpcauth_destroy_credlist(&free);
  460. return freed;
  461. }
  462. /*
  463. * Run memory cache shrinker.
  464. */
  465. static unsigned long
  466. rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
  467. {
  468. if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
  469. return SHRINK_STOP;
  470. /* nothing left, don't come back */
  471. if (list_empty(&cred_unused))
  472. return SHRINK_STOP;
  473. return rpcauth_cache_do_shrink(sc->nr_to_scan);
  474. }
  475. static unsigned long
  476. rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
  477. {
  478. return number_cred_unused * sysctl_vfs_cache_pressure / 100;
  479. }
  480. static void
  481. rpcauth_cache_enforce_limit(void)
  482. {
  483. unsigned long diff;
  484. unsigned int nr_to_scan;
  485. if (number_cred_unused <= auth_max_cred_cachesize)
  486. return;
  487. diff = number_cred_unused - auth_max_cred_cachesize;
  488. nr_to_scan = 100;
  489. if (diff < nr_to_scan)
  490. nr_to_scan = diff;
  491. rpcauth_cache_do_shrink(nr_to_scan);
  492. }
  493. /*
  494. * Look up a process' credentials in the authentication cache
  495. */
  496. struct rpc_cred *
  497. rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
  498. int flags, gfp_t gfp)
  499. {
  500. LIST_HEAD(free);
  501. struct rpc_cred_cache *cache = auth->au_credcache;
  502. struct rpc_cred *cred = NULL,
  503. *entry, *new;
  504. unsigned int nr;
  505. nr = auth->au_ops->hash_cred(acred, cache->hashbits);
  506. rcu_read_lock();
  507. hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
  508. if (!entry->cr_ops->crmatch(acred, entry, flags))
  509. continue;
  510. if (flags & RPCAUTH_LOOKUP_RCU) {
  511. if (test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags) ||
  512. refcount_read(&entry->cr_count) == 0)
  513. continue;
  514. cred = entry;
  515. break;
  516. }
  517. cred = get_rpccred(entry);
  518. if (cred)
  519. break;
  520. }
  521. rcu_read_unlock();
  522. if (cred != NULL)
  523. goto found;
  524. if (flags & RPCAUTH_LOOKUP_RCU)
  525. return ERR_PTR(-ECHILD);
  526. new = auth->au_ops->crcreate(auth, acred, flags, gfp);
  527. if (IS_ERR(new)) {
  528. cred = new;
  529. goto out;
  530. }
  531. spin_lock(&cache->lock);
  532. hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
  533. if (!entry->cr_ops->crmatch(acred, entry, flags))
  534. continue;
  535. cred = get_rpccred(entry);
  536. if (cred)
  537. break;
  538. }
  539. if (cred == NULL) {
  540. cred = new;
  541. set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
  542. refcount_inc(&cred->cr_count);
  543. hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
  544. } else
  545. list_add_tail(&new->cr_lru, &free);
  546. spin_unlock(&cache->lock);
  547. rpcauth_cache_enforce_limit();
  548. found:
  549. if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
  550. cred->cr_ops->cr_init != NULL &&
  551. !(flags & RPCAUTH_LOOKUP_NEW)) {
  552. int res = cred->cr_ops->cr_init(auth, cred);
  553. if (res < 0) {
  554. put_rpccred(cred);
  555. cred = ERR_PTR(res);
  556. }
  557. }
  558. rpcauth_destroy_credlist(&free);
  559. out:
  560. return cred;
  561. }
  562. EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
  563. struct rpc_cred *
  564. rpcauth_lookupcred(struct rpc_auth *auth, int flags)
  565. {
  566. struct auth_cred acred;
  567. struct rpc_cred *ret;
  568. const struct cred *cred = current_cred();
  569. dprintk("RPC: looking up %s cred\n",
  570. auth->au_ops->au_name);
  571. memset(&acred, 0, sizeof(acred));
  572. acred.uid = cred->fsuid;
  573. acred.gid = cred->fsgid;
  574. acred.group_info = cred->group_info;
  575. ret = auth->au_ops->lookup_cred(auth, &acred, flags);
  576. return ret;
  577. }
  578. EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
  579. void
  580. rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
  581. struct rpc_auth *auth, const struct rpc_credops *ops)
  582. {
  583. INIT_HLIST_NODE(&cred->cr_hash);
  584. INIT_LIST_HEAD(&cred->cr_lru);
  585. refcount_set(&cred->cr_count, 1);
  586. cred->cr_auth = auth;
  587. cred->cr_ops = ops;
  588. cred->cr_expire = jiffies;
  589. cred->cr_uid = acred->uid;
  590. }
  591. EXPORT_SYMBOL_GPL(rpcauth_init_cred);
  592. struct rpc_cred *
  593. rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
  594. {
  595. dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
  596. cred->cr_auth->au_ops->au_name, cred);
  597. return get_rpccred(cred);
  598. }
  599. EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
  600. static struct rpc_cred *
  601. rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
  602. {
  603. struct rpc_auth *auth = task->tk_client->cl_auth;
  604. struct auth_cred acred = {
  605. .uid = GLOBAL_ROOT_UID,
  606. .gid = GLOBAL_ROOT_GID,
  607. };
  608. dprintk("RPC: %5u looking up %s cred\n",
  609. task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
  610. return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
  611. }
  612. static struct rpc_cred *
  613. rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
  614. {
  615. struct rpc_auth *auth = task->tk_client->cl_auth;
  616. dprintk("RPC: %5u looking up %s cred\n",
  617. task->tk_pid, auth->au_ops->au_name);
  618. return rpcauth_lookupcred(auth, lookupflags);
  619. }
  620. static int
  621. rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
  622. {
  623. struct rpc_rqst *req = task->tk_rqstp;
  624. struct rpc_cred *new;
  625. int lookupflags = 0;
  626. if (flags & RPC_TASK_ASYNC)
  627. lookupflags |= RPCAUTH_LOOKUP_NEW;
  628. if (cred != NULL)
  629. new = cred->cr_ops->crbind(task, cred, lookupflags);
  630. else if (flags & RPC_TASK_ROOTCREDS)
  631. new = rpcauth_bind_root_cred(task, lookupflags);
  632. else
  633. new = rpcauth_bind_new_cred(task, lookupflags);
  634. if (IS_ERR(new))
  635. return PTR_ERR(new);
  636. put_rpccred(req->rq_cred);
  637. req->rq_cred = new;
  638. return 0;
  639. }
  640. void
  641. put_rpccred(struct rpc_cred *cred)
  642. {
  643. if (cred == NULL)
  644. return;
  645. rcu_read_lock();
  646. if (refcount_dec_and_test(&cred->cr_count))
  647. goto destroy;
  648. if (refcount_read(&cred->cr_count) != 1 ||
  649. !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
  650. goto out;
  651. if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
  652. cred->cr_expire = jiffies;
  653. rpcauth_lru_add(cred);
  654. /* Race breaker */
  655. if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)))
  656. rpcauth_lru_remove(cred);
  657. } else if (rpcauth_unhash_cred(cred)) {
  658. rpcauth_lru_remove(cred);
  659. if (refcount_dec_and_test(&cred->cr_count))
  660. goto destroy;
  661. }
  662. out:
  663. rcu_read_unlock();
  664. return;
  665. destroy:
  666. rcu_read_unlock();
  667. cred->cr_ops->crdestroy(cred);
  668. }
  669. EXPORT_SYMBOL_GPL(put_rpccred);
  670. __be32 *
  671. rpcauth_marshcred(struct rpc_task *task, __be32 *p)
  672. {
  673. struct rpc_cred *cred = task->tk_rqstp->rq_cred;
  674. dprintk("RPC: %5u marshaling %s cred %p\n",
  675. task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
  676. return cred->cr_ops->crmarshal(task, p);
  677. }
  678. __be32 *
  679. rpcauth_checkverf(struct rpc_task *task, __be32 *p)
  680. {
  681. struct rpc_cred *cred = task->tk_rqstp->rq_cred;
  682. dprintk("RPC: %5u validating %s cred %p\n",
  683. task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
  684. return cred->cr_ops->crvalidate(task, p);
  685. }
  686. static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
  687. __be32 *data, void *obj)
  688. {
  689. struct xdr_stream xdr;
  690. xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data);
  691. encode(rqstp, &xdr, obj);
  692. }
  693. int
  694. rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
  695. __be32 *data, void *obj)
  696. {
  697. struct rpc_cred *cred = task->tk_rqstp->rq_cred;
  698. dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
  699. task->tk_pid, cred->cr_ops->cr_name, cred);
  700. if (cred->cr_ops->crwrap_req)
  701. return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
  702. /* By default, we encode the arguments normally. */
  703. rpcauth_wrap_req_encode(encode, rqstp, data, obj);
  704. return 0;
  705. }
  706. static int
  707. rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
  708. __be32 *data, void *obj)
  709. {
  710. struct xdr_stream xdr;
  711. xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data);
  712. return decode(rqstp, &xdr, obj);
  713. }
  714. int
  715. rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
  716. __be32 *data, void *obj)
  717. {
  718. struct rpc_cred *cred = task->tk_rqstp->rq_cred;
  719. dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
  720. task->tk_pid, cred->cr_ops->cr_name, cred);
  721. if (cred->cr_ops->crunwrap_resp)
  722. return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
  723. data, obj);
  724. /* By default, we decode the arguments normally. */
  725. return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
  726. }
  727. bool
  728. rpcauth_xmit_need_reencode(struct rpc_task *task)
  729. {
  730. struct rpc_cred *cred = task->tk_rqstp->rq_cred;
  731. if (!cred || !cred->cr_ops->crneed_reencode)
  732. return false;
  733. return cred->cr_ops->crneed_reencode(task);
  734. }
  735. int
  736. rpcauth_refreshcred(struct rpc_task *task)
  737. {
  738. struct rpc_cred *cred;
  739. int err;
  740. cred = task->tk_rqstp->rq_cred;
  741. if (cred == NULL) {
  742. err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
  743. if (err < 0)
  744. goto out;
  745. cred = task->tk_rqstp->rq_cred;
  746. }
  747. dprintk("RPC: %5u refreshing %s cred %p\n",
  748. task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
  749. err = cred->cr_ops->crrefresh(task);
  750. out:
  751. if (err < 0)
  752. task->tk_status = err;
  753. return err;
  754. }
  755. void
  756. rpcauth_invalcred(struct rpc_task *task)
  757. {
  758. struct rpc_cred *cred = task->tk_rqstp->rq_cred;
  759. dprintk("RPC: %5u invalidating %s cred %p\n",
  760. task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
  761. if (cred)
  762. clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
  763. }
  764. int
  765. rpcauth_uptodatecred(struct rpc_task *task)
  766. {
  767. struct rpc_cred *cred = task->tk_rqstp->rq_cred;
  768. return cred == NULL ||
  769. test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
  770. }
  771. static struct shrinker rpc_cred_shrinker = {
  772. .count_objects = rpcauth_cache_shrink_count,
  773. .scan_objects = rpcauth_cache_shrink_scan,
  774. .seeks = DEFAULT_SEEKS,
  775. };
  776. int __init rpcauth_init_module(void)
  777. {
  778. int err;
  779. err = rpc_init_authunix();
  780. if (err < 0)
  781. goto out1;
  782. err = rpc_init_generic_auth();
  783. if (err < 0)
  784. goto out2;
  785. err = register_shrinker(&rpc_cred_shrinker);
  786. if (err < 0)
  787. goto out3;
  788. return 0;
  789. out3:
  790. rpc_destroy_generic_auth();
  791. out2:
  792. rpc_destroy_authunix();
  793. out1:
  794. return err;
  795. }
  796. void rpcauth_remove_module(void)
  797. {
  798. rpc_destroy_authunix();
  799. rpc_destroy_generic_auth();
  800. unregister_shrinker(&rpc_cred_shrinker);
  801. }