tls_main.c

/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 *
 * - Redistributions of source code must retain the above
 *   copyright notice, this list of conditions and the following
 *   disclaimer.
 *
 * - Redistributions in binary form must reproduce the above
 *   copyright notice, this list of conditions and the following
 *   disclaimer in the documentation and/or other materials
 *   provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>

#include <net/tls.h>

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static LIST_HEAD(device_list);
static DEFINE_MUTEX(device_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops;

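/* Point sk->sk_prot at the entry of the pre-built proto table that matches
 * the socket's address family and its current TX/RX configuration.
 */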
static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

	sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
}

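/* Sleep until any in-flight write on the socket has drained, honoring the
 * caller's timeout and pending signals.
 */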
int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	int rc = 0;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
			break;
	}

	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}

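/* Transmit an already-encrypted record, described by a scatterlist, through
 * the underlying TCP socket. On a partial send the remaining scatterlist
 * and offset are parked in the context so a later write_space event can
 * resume the transmission.
 */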
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	ctx->in_tcp_sendpages = true;
	while (1) {
		if (sg_is_last(sg))
			sendpage_flags = flags;

		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

		if (ret != size) {
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->in_tcp_sendpages = false;
			return ret;
		}

		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
	ctx->in_tcp_sendpages = false;
	ctx->sk_write_space(sk);

	return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (tls_is_pending_open_record(ctx))
		return ctx->push_pending_record(sk, flags);

	return 0;
}

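/* Parse SOL_TLS control messages attached to a sendmsg() call. A
 * TLS_SET_RECORD_TYPE cmsg forces any open record to be pushed first,
 * because the record type applies to an entire record.
 */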
int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
		      unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}

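/* Userspace counterpart, as a sketch: a non-application-data record (here
 * 21 = alert; application_data, 23, is the default) is requested by
 * attaching TLS_SET_RECORD_TYPE to sendmsg(). fd is assumed to be a socket
 * that already has kTLS TX configured.
 *
 *	char buf[CMSG_SPACE(sizeof(unsigned char))];
 *	struct msghdr msg = { 0 };
 *	struct cmsghdr *cmsg;
 *	unsigned char record_type = 21;
 *
 *	msg.msg_control = buf;
 *	msg.msg_controllen = sizeof(buf);
 *	cmsg = CMSG_FIRSTHDR(&msg);
 *	cmsg->cmsg_level = SOL_TLS;
 *	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 *	cmsg->cmsg_len = CMSG_LEN(sizeof(record_type));
 *	*CMSG_DATA(cmsg) = record_type;
 *	msg.msg_controllen = cmsg->cmsg_len;
 *	// point msg.msg_iov at the record payload, then sendmsg(fd, &msg, 0)
 */
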
int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx,
				   int flags, long *timeo)
{
	struct scatterlist *sg;
	u16 offset;

	if (!tls_is_partially_sent_record(ctx))
		return ctx->push_pending_record(sk, flags);

	sg = ctx->partially_sent_record;
	offset = ctx->partially_sent_offset;

	ctx->partially_sent_record = NULL;
	return tls_push_sg(sk, ctx, sg, offset, flags);
}

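/* write_space callback installed while TX is configured. If a closed record
 * is still queued, retry pushing it under GFP_ATOMIC before waking the
 * original writer.
 */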
static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* We are already sending pages, ignore notification */
	if (ctx->in_tcp_sendpages)
		return;

	if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) {
		gfp_t sk_allocation = sk->sk_allocation;
		int rc;
		long timeo = 0;

		sk->sk_allocation = GFP_ATOMIC;
		rc = tls_push_pending_closed_record(sk, ctx,
						    MSG_DONTWAIT |
						    MSG_NOSIGNAL,
						    &timeo);
		sk->sk_allocation = sk_allocation;

		if (rc < 0)
			return;
	}

	ctx->sk_write_space(sk);
}

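/* close() path: finish or drop any pending record, release the
 * per-direction crypto state, then hand off to the original protocol's
 * close routine.
 */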
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	void (*sk_proto_close)(struct sock *sk, long timeout);
	bool free_ctx = false;

	lock_sock(sk);
	sk_proto_close = ctx->sk_proto_close;

	if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) ||
	    (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) {
		free_ctx = true;
		goto skip_tx_cleanup;
	}

	if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
		tls_handle_open_record(sk, 0);

	if (ctx->partially_sent_record) {
		struct scatterlist *sg = ctx->partially_sent_record;

		while (1) {
			put_page(sg_page(sg));
			sk_mem_uncharge(sk, sg->length);

			if (sg_is_last(sg))
				break;
			sg++;
		}
	}

	/* We need these for tls_sw_fallback handling of other packets */
	if (ctx->tx_conf == TLS_SW) {
		kfree(ctx->tx.rec_seq);
		kfree(ctx->tx.iv);
		tls_sw_free_resources_tx(sk);
	}

	if (ctx->rx_conf == TLS_SW) {
		kfree(ctx->rx.rec_seq);
		kfree(ctx->rx.iv);
		tls_sw_free_resources_rx(sk);
	}

#ifdef CONFIG_TLS_DEVICE
	if (ctx->rx_conf == TLS_HW)
		tls_device_offload_cleanup_rx(sk);

	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
#else
	{
#endif
		kfree(ctx);
		ctx = NULL;
	}

skip_tx_cleanup:
	release_sock(sk);
	sk_proto_close(sk, timeout);
	/* free ctx for TLS_HW_RECORD, used by tcp_set_state
	 * for sk->sk_prot->unhash [tls_hw_unhash]
	 */
	if (free_ctx)
		kfree(ctx);
}

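/* getsockopt(SOL_TLS, TLS_TX): copy the TX crypto parameters back to
 * userspace, including the current IV and record sequence number.
 */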
static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
				int __user *optlen)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	crypto_info = &ctx->crypto_send;

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *
		  crypto_info_aes_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_128->iv,
		       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
	}

out:
	return rc;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
		rc = do_tls_getsockopt_tx(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->getsockopt(sk, level, optname, optval, optlen);

	return do_tls_getsockopt(sk, optname, optval, optlen);
}

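/* Install crypto state for one direction (tx != 0 selects TX). Validates
 * the user-supplied tls_crypto_info, prefers device offload when built in,
 * falls back to the software implementation, and finally swaps sk_prot.
 */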
static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	int rc = 0;
	int conf;

	if (!optval || (optlen < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (tx)
		crypto_info = &ctx->crypto_send;
	else
		crypto_info = &ctx->crypto_recv;

	/* Currently we don't support setting crypto info more than once */
	if (TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION) {
		rc = -ENOTSUPP;
		goto err_crypto_info;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
		rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
				    optlen - sizeof(*crypto_info));
		if (rc) {
			rc = -EFAULT;
			goto err_crypto_info;
		}
		break;
	}
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (tx) {
#ifdef CONFIG_TLS_DEVICE
		rc = tls_set_device_offload(sk, ctx);
		conf = TLS_HW;
		if (rc) {
#else
		{
#endif
			rc = tls_set_sw_offload(sk, ctx, 1);
			conf = TLS_SW;
		}
	} else {
#ifdef CONFIG_TLS_DEVICE
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (rc) {
#else
		{
#endif
			rc = tls_set_sw_offload(sk, ctx, 0);
			conf = TLS_SW;
		}
	}

	if (rc)
		goto err_crypto_info;

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		sk->sk_socket->ops = &tls_sw_proto_ops;
	}
	goto out;

err_crypto_info:
	memset(crypto_info, 0, sizeof(*crypto_info));
out:
	return rc;
}

static int do_tls_setsockopt(struct sock *sk, int optname,
			     char __user *optval, unsigned int optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		lock_sock(sk);
		rc = do_tls_setsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		release_sock(sk);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->setsockopt(sk, level, optname, optval, optlen);

	return do_tls_setsockopt(sk, optname, optval, optlen);
}

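/* Userspace side, as a sketch: after the TLS handshake completes on an
 * established TCP socket fd, attach the ULP and hand the negotiated keys
 * to the kernel (the key/iv/salt/rec_seq values come from the handshake).
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	// memcpy(ci.key, ...), memcpy(ci.iv, ...), etc.
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 */
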
static struct tls_context *create_ctx(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return NULL;

	icsk->icsk_ulp_data = ctx;
	return ctx;
}

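/* Probe the registered tls_device drivers; the first one whose ->feature()
 * callback accepts takes over record handling (TLS_HW_RECORD). Returns
 * nonzero when a device claimed the socket.
 */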
static int tls_hw_prot(struct sock *sk)
{
	struct tls_context *ctx;
	struct tls_device *dev;
	int rc = 0;

	mutex_lock(&device_mutex);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->feature && dev->feature(dev)) {
			ctx = create_ctx(sk);
			if (!ctx)
				goto out;

			ctx->hash = sk->sk_prot->hash;
			ctx->unhash = sk->sk_prot->unhash;
			ctx->sk_proto_close = sk->sk_prot->close;
			ctx->rx_conf = TLS_HW_RECORD;
			ctx->tx_conf = TLS_HW_RECORD;
			update_sk_prot(sk, ctx);
			rc = 1;
			break;
		}
	}
out:
	mutex_unlock(&device_mutex);
	return rc;
}

static void tls_hw_unhash(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_device *dev;

	mutex_lock(&device_mutex);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->unhash)
			dev->unhash(dev, sk);
	}
	mutex_unlock(&device_mutex);
	ctx->unhash(sk);
}

static int tls_hw_hash(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_device *dev;
	int err;

	err = ctx->hash(sk);
	mutex_lock(&device_mutex);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->hash)
			err |= dev->hash(dev, sk);
	}
	mutex_unlock(&device_mutex);

	if (err)
		tls_hw_unhash(sk);
	return err;
}

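/* Fill the proto table for one address family. The first index is the TX
 * config, the second the RX config; each entry starts as a copy of the
 * base TCP proto and overrides only the ops that differ for that
 * combination.
 */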
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;

	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage;

	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;

	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif

	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close;
}

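/* ULP init hook, invoked via setsockopt(SOL_TCP, TCP_ULP, "tls"). Hardware
 * record offload wins if a device claims the socket; otherwise allocate a
 * context in TLS_BASE/TLS_BASE mode and wait for TLS_TX/TLS_RX
 * setsockopts.
 */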
static int tls_init(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct tls_context *ctx;
	int rc = 0;

	if (tls_hw_prot(sk))
		goto out;

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather than
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTSUPP;

	/* allocate tls context */
	ctx = create_ctx(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}
	ctx->setsockopt = sk->sk_prot->setsockopt;
	ctx->getsockopt = sk->sk_prot->getsockopt;
	ctx->sk_proto_close = sk->sk_prot->close;

	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
	if (ip_ver == TLSV6 &&
	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(sk->sk_prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], sk->sk_prot);
			smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	ctx->tx_conf = TLS_BASE;
	ctx->rx_conf = TLS_BASE;
	update_sk_prot(sk, ctx);
out:
	return rc;
}

void tls_register_device(struct tls_device *device)
{
	mutex_lock(&device_mutex);
	list_add_tail(&device->dev_list, &device_list);
	mutex_unlock(&device_mutex);
}
EXPORT_SYMBOL(tls_register_device);

void tls_unregister_device(struct tls_device *device)
{
	mutex_lock(&device_mutex);
	list_del(&device->dev_list);
	mutex_unlock(&device_mutex);
}
EXPORT_SYMBOL(tls_unregister_device);

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name = "tls",
	.uid = TCP_ULP_TLS,
	.user_visible = true,
	.owner = THIS_MODULE,
	.init = tls_init,
};

static int __init tls_register(void)
{
	build_protos(tls_prots[TLSV4], &tcp_prot);

	tls_sw_proto_ops = inet_stream_ops;
	tls_sw_proto_ops.poll = tls_sw_poll;
	tls_sw_proto_ops.splice_read = tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
	tls_device_init();
#endif
	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
}

static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
#ifdef CONFIG_TLS_DEVICE
	tls_device_cleanup();
#endif
}

module_init(tls_register);
module_exit(tls_unregister);