tcp_bpf.c

// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>

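/* Installed as ->stream_memory_read (see tcp_bpf_rebuild_protos): report
 * whether the psock ingress list holds data, since BPF-redirected bytes
 * queued there never appear on sk_receive_queue.
 */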
static bool tcp_bpf_stream_read(const struct sock *sk)
{
        struct sk_psock *psock;
        bool empty = true;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock))
                empty = list_empty(&psock->ingress_msg);
        rcu_read_unlock();
        return !empty;
}

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
                             int flags, long timeo, int *err)
{
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int ret;

        add_wait_queue(sk_sleep(sk), &wait);
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        ret = sk_wait_event(sk, &timeo,
                            !list_empty(&psock->ingress_msg) ||
                            !skb_queue_empty(&sk->sk_receive_queue), &wait);
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        remove_wait_queue(sk_sleep(sk), &wait);
        return ret;
}

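/* Copy data from the psock ingress list into the msghdr iterator.
 * MSG_PEEK walks the list without consuming it; otherwise copied bytes
 * are uncharged from the socket and fully drained messages are freed.
 */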
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
                      struct msghdr *msg, int len, int flags)
{
        struct iov_iter *iter = &msg->msg_iter;
        int peek = flags & MSG_PEEK;
        int i, ret, copied = 0;
        struct sk_msg *msg_rx;

        msg_rx = list_first_entry_or_null(&psock->ingress_msg,
                                          struct sk_msg, list);

        while (copied != len) {
                struct scatterlist *sge;

                if (unlikely(!msg_rx))
                        break;

                i = msg_rx->sg.start;
                do {
                        struct page *page;
                        int copy;

                        sge = sk_msg_elem(msg_rx, i);
                        copy = sge->length;
                        page = sg_page(sge);
                        if (copied + copy > len)
                                copy = len - copied;
                        ret = copy_page_to_iter(page, sge->offset, copy, iter);
                        if (ret != copy) {
                                msg_rx->sg.start = i;
                                return -EFAULT;
                        }

                        copied += copy;
                        if (likely(!peek)) {
                                sge->offset += copy;
                                sge->length -= copy;
                                sk_mem_uncharge(sk, copy);
                                msg_rx->sg.size -= copy;

                                if (!sge->length) {
                                        sk_msg_iter_var_next(i);
                                        if (!msg_rx->skb)
                                                put_page(page);
                                }
                        } else {
                                sk_msg_iter_var_next(i);
                        }

                        if (copied == len)
                                break;
                } while (i != msg_rx->sg.end);

                if (unlikely(peek)) {
                        msg_rx = list_next_entry(msg_rx, list);
                        continue;
                }

                msg_rx->sg.start = i;
                if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
                        list_del(&msg_rx->list);
                        if (msg_rx->skb)
                                consume_skb(msg_rx->skb);
                        kfree(msg_rx);
                }
                msg_rx = list_first_entry_or_null(&psock->ingress_msg,
                                                  struct sk_msg, list);
        }

        return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

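/* recvmsg handler installed in place of tcp_recvmsg. Falls back to
 * tcp_recvmsg when there is no psock or data is already waiting on
 * sk_receive_queue; otherwise drains the psock ingress list, waiting
 * for more data when the socket timeout allows it.
 */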
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                    int nonblock, int flags, int *addr_len)
{
        struct sk_psock *psock;
        int copied, ret;

        if (unlikely(flags & MSG_ERRQUEUE))
                return inet_recv_error(sk, msg, len, addr_len);
        if (!skb_queue_empty(&sk->sk_receive_queue))
                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
        lock_sock(sk);
msg_bytes_ready:
        copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
        if (!copied) {
                int data, err = 0;
                long timeo;

                timeo = sock_rcvtimeo(sk, nonblock);
                data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
                if (data) {
                        if (skb_queue_empty(&sk->sk_receive_queue))
                                goto msg_bytes_ready;
                        release_sock(sk);
                        sk_psock_put(sk, psock);
                        return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
                }
                if (err) {
                        ret = err;
                        goto out;
                }
        }
        ret = copied;
out:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return ret;
}

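/* Ingress redirect: charge up to @apply_bytes against @sk, transfer the
 * corresponding scatterlist entries from @msg onto a fresh sk_msg queued
 * on the psock ingress list, and wake the receiver via sk_data_ready.
 */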
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
                           struct sk_msg *msg, u32 apply_bytes, int flags)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        u32 size, copied = 0;
        struct sk_msg *tmp;
        int i, ret = 0;

        tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
        if (unlikely(!tmp))
                return -ENOMEM;

        lock_sock(sk);
        tmp->sg.start = msg->sg.start;
        i = msg->sg.start;
        do {
                sge = sk_msg_elem(msg, i);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                if (!sk_wmem_schedule(sk, size)) {
                        if (!copied)
                                ret = -ENOMEM;
                        break;
                }

                sk_mem_charge(sk, size);
                sk_msg_xfer(tmp, msg, i, size);
                copied += size;
                if (sge->length)
                        get_page(sk_msg_page(tmp, i));
                sk_msg_iter_var_next(i);
                tmp->sg.end = i;
                if (apply) {
                        apply_bytes -= size;
                        if (!apply_bytes)
                                break;
                }
        } while (i != msg->sg.end);

        if (!ret) {
                msg->sg.start = i;
                msg->sg.size -= apply_bytes;
                sk_psock_queue_msg(psock, tmp);
                sk->sk_data_ready(sk);
        } else {
                sk_msg_free(sk, tmp);
                kfree(tmp);
        }

        release_sock(sk);
        return ret;
}

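/* Push up to @apply_bytes of @msg out through the TCP stack with
 * do_tcp_sendpages, retrying short sends and releasing scatterlist
 * entries as they are fully transmitted. @uncharge selects whether the
 * sent bytes are uncharged from the socket's memory accounting.
 */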
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
                        int flags, bool uncharge)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        struct page *page;
        int size, ret = 0;
        u32 off;

        while (1) {
                sge = sk_msg_elem(msg, msg->sg.start);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                off  = sge->offset;
                page = sg_page(sge);

                tcp_rate_check_app_limited(sk);
retry:
                ret = do_tcp_sendpages(sk, page, off, size, flags);
                if (ret <= 0)
                        return ret;
                if (apply)
                        apply_bytes -= ret;
                msg->sg.size -= ret;
                sge->offset += ret;
                sge->length -= ret;
                if (uncharge)
                        sk_mem_uncharge(sk, ret);
                if (ret != size) {
                        size -= ret;
                        off  += ret;
                        goto retry;
                }
                if (!sge->length) {
                        put_page(page);
                        sk_msg_iter_next(msg, start);
                        sg_init_table(sge, 1);
                        if (msg->sg.start == msg->sg.end)
                                break;
                }

                if (apply && !apply_bytes)
                        break;
        }

        return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
                               u32 apply_bytes, int flags, bool uncharge)
{
        int ret;

        lock_sock(sk);
        ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
        release_sock(sk);
        return ret;
}

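/* Deliver @msg to the socket chosen by a BPF redirect verdict, either
 * onto its psock ingress queue or out through its TCP send path
 * depending on the ingress flag carried in the message.
 */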
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
                          u32 bytes, int flags)
{
        bool ingress = sk_msg_to_ingress(msg);
        struct sk_psock *psock = sk_psock_get(sk);
        int ret;

        if (unlikely(!psock)) {
                sk_msg_free(sk, msg);
                return 0;
        }

        ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
                        tcp_bpf_push_locked(sk, msg, bytes, flags, false);
        sk_psock_put(sk, psock);
        return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

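/* Run the msg verdict program when no verdict is outstanding and act on
 * the result: __SK_PASS sends on this socket, __SK_REDIRECT hands the
 * data to tcp_bpf_sendmsg_redir, __SK_DROP frees it. Also handles
 * cork_bytes and apply_bytes bookkeeping across passes.
 */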
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
                                struct sk_msg *msg, int *copied, int flags)
{
        bool cork = false, enospc = msg->sg.start == msg->sg.end;
        struct sock *sk_redir;
        u32 tosend;
        int ret;

more_data:
        if (psock->eval == __SK_NONE)
                psock->eval = sk_psock_msg_verdict(sk, psock, msg);

        if (msg->cork_bytes &&
            msg->cork_bytes > msg->sg.size && !enospc) {
                psock->cork_bytes = msg->cork_bytes - msg->sg.size;
                if (!psock->cork) {
                        psock->cork = kzalloc(sizeof(*psock->cork),
                                              GFP_ATOMIC | __GFP_NOWARN);
                        if (!psock->cork)
                                return -ENOMEM;
                }
                memcpy(psock->cork, msg, sizeof(*msg));
                return 0;
        }

        tosend = msg->sg.size;
        if (psock->apply_bytes && psock->apply_bytes < tosend)
                tosend = psock->apply_bytes;

        switch (psock->eval) {
        case __SK_PASS:
                ret = tcp_bpf_push(sk, msg, tosend, flags, true);
                if (unlikely(ret)) {
                        *copied -= sk_msg_free(sk, msg);
                        break;
                }
                sk_msg_apply_bytes(psock, tosend);
                break;
        case __SK_REDIRECT:
                sk_redir = psock->sk_redir;
                sk_msg_apply_bytes(psock, tosend);
                if (psock->cork) {
                        cork = true;
                        psock->cork = NULL;
                }
                sk_msg_return(sk, msg, tosend);
                release_sock(sk);
                ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
                lock_sock(sk);
                if (unlikely(ret < 0)) {
                        int free = sk_msg_free_nocharge(sk, msg);

                        if (!cork)
                                *copied -= free;
                }
                if (cork) {
                        sk_msg_free(sk, msg);
                        kfree(msg);
                        msg = NULL;
                        ret = 0;
                }
                break;
        case __SK_DROP:
        default:
                sk_msg_free_partial(sk, msg, tosend);
                sk_msg_apply_bytes(psock, tosend);
                *copied -= tosend;
                return -EACCES;
        }

        if (likely(!ret)) {
                if (!psock->apply_bytes) {
                        psock->eval = __SK_NONE;
                        if (psock->sk_redir) {
                                sock_put(psock->sk_redir);
                                psock->sk_redir = NULL;
                        }
                }
                if (msg &&
                    msg->sg.data[msg->sg.start].page_link &&
                    msg->sg.data[msg->sg.start].length)
                        goto more_data;
        }
        return ret;
}

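/* sendmsg handler installed when a msg parser program is attached.
 * Copies user data into an sk_msg (or into the pending cork), honours
 * cork_bytes, and feeds each completed batch to tcp_bpf_send_verdict.
 */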
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        struct sk_msg tmp, *msg_tx = NULL;
        int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
        int copied = 0, err = 0;
        struct sk_psock *psock;
        long timeo;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendmsg(sk, msg, size);

        lock_sock(sk);
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        while (msg_data_left(msg)) {
                bool enospc = false;
                u32 copy, osize;

                if (sk->sk_err) {
                        err = -sk->sk_err;
                        goto out_err;
                }

                copy = msg_data_left(msg);
                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
                if (psock->cork) {
                        msg_tx = psock->cork;
                } else {
                        msg_tx = &tmp;
                        sk_msg_init(msg_tx);
                }

                osize = msg_tx->sg.size;
                err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
                if (err) {
                        if (err != -ENOSPC)
                                goto wait_for_memory;
                        enospc = true;
                        copy = msg_tx->sg.size - osize;
                }

                err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
                                               copy);
                if (err < 0) {
                        sk_msg_trim(sk, msg_tx, osize);
                        goto out_err;
                }

                copied += copy;
                if (psock->cork_bytes) {
                        if (size > psock->cork_bytes)
                                psock->cork_bytes = 0;
                        else
                                psock->cork_bytes -= size;
                        if (psock->cork_bytes && !enospc)
                                goto out_err;
                        /* All cork bytes are accounted, rerun the prog. */
                        psock->eval = __SK_NONE;
                        psock->cork_bytes = 0;
                }

                err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
                if (unlikely(err < 0))
                        goto out_err;
                continue;
wait_for_sndbuf:
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
                err = sk_stream_wait_memory(sk, &timeo);
                if (err) {
                        if (msg_tx && msg_tx != psock->cork)
                                sk_msg_free(sk, msg_tx);
                        goto out_err;
                }
        }
out_err:
        if (err < 0)
                err = sk_stream_error(sk, msg->msg_flags, err);
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}

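/* sendpage counterpart of tcp_bpf_sendmsg: append the page to a new or
 * corked sk_msg, account for it, and run the verdict path.
 */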
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
                            size_t size, int flags)
{
        struct sk_msg tmp, *msg = NULL;
        int err = 0, copied = 0;
        struct sk_psock *psock;
        bool enospc = false;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendpage(sk, page, offset, size, flags);

        lock_sock(sk);
        if (psock->cork) {
                msg = psock->cork;
        } else {
                msg = &tmp;
                sk_msg_init(msg);
        }

        /* Catch case where ring is full and sendpage is stalled. */
        if (unlikely(sk_msg_full(msg)))
                goto out_err;

        sk_msg_page_add(msg, page, size, offset);
        sk_mem_charge(sk, size);
        copied = size;
        if (sk_msg_full(msg))
                enospc = true;
        if (psock->cork_bytes) {
                if (size > psock->cork_bytes)
                        psock->cork_bytes = 0;
                else
                        psock->cork_bytes -= size;
                if (psock->cork_bytes && !enospc)
                        goto out_err;
                /* All cork bytes are accounted, rerun the prog. */
                psock->eval = __SK_NONE;
                psock->cork_bytes = 0;
        }

        err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}

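/* Tear down the socket's BPF state: free any pending cork, purge queued
 * ingress messages, and unlink the psock from every map it joined.
 */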
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_link *link;

        sk_psock_cork_free(psock);
        __sk_psock_purge_ingress_msg(psock);
        while ((link = sk_psock_link_pop(psock))) {
                sk_psock_unlink(sk, link);
                sk_psock_free_link(link);
        }
}

static void tcp_bpf_unhash(struct sock *sk)
{
        void (*saved_unhash)(struct sock *sk);
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                if (sk->sk_prot->unhash)
                        sk->sk_prot->unhash(sk);
                return;
        }

        saved_unhash = psock->saved_unhash;
        tcp_bpf_remove(sk, psock);
        rcu_read_unlock();
        saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
        void (*saved_close)(struct sock *sk, long timeout);
        struct sk_psock *psock;

        lock_sock(sk);
        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                release_sock(sk);
                return sk->sk_prot->close(sk, timeout);
        }

        saved_close = psock->saved_close;
        tcp_bpf_remove(sk, psock);
        rcu_read_unlock();
        release_sock(sk);
        saved_close(sk, timeout);
}

enum {
        TCP_BPF_IPV4,
        TCP_BPF_IPV6,
        TCP_BPF_NUM_PROTS,
};

enum {
        TCP_BPF_BASE,
        TCP_BPF_TX,
        TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

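/* Build the override proto table from @base: TCP_BPF_BASE replaces
 * unhash/close/recvmsg/stream_memory_read, and TCP_BPF_TX additionally
 * replaces sendmsg/sendpage for sockets with a msg parser program.
 */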
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
                                   struct proto *base)
{
        prot[TCP_BPF_BASE]                      = *base;
        prot[TCP_BPF_BASE].unhash               = tcp_bpf_unhash;
        prot[TCP_BPF_BASE].close                = tcp_bpf_close;
        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
        prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;

        prot[TCP_BPF_TX]                        = prot[TCP_BPF_BASE];
        prot[TCP_BPF_TX].sendmsg                = tcp_bpf_sendmsg;
        prot[TCP_BPF_TX].sendpage               = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
        if (sk->sk_family == AF_INET6 &&
            unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
                spin_lock_bh(&tcpv6_prot_lock);
                if (likely(ops != tcpv6_prot_saved)) {
                        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
                        smp_store_release(&tcpv6_prot_saved, ops);
                }
                spin_unlock_bh(&tcpv6_prot_lock);
        }
}

static int __init tcp_bpf_v4_build_proto(void)
{
        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
        return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

        sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

        /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
         * or added requiring sk_prot hook updates. We keep original saved
         * hooks in this case.
         */
        sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
        /* In order to avoid retpoline, we make assumptions when we call
         * into ops if e.g. a psock is not present. Make sure they are
         * indeed valid assumptions.
         */
        return ops->recvmsg  == tcp_recvmsg &&
               ops->sendmsg  == tcp_sendmsg &&
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
        struct sk_psock *psock;

        sock_owned_by_me(sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        tcp_bpf_reinit_sk_prot(sk, psock);
        rcu_read_unlock();
}

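/* Switch the socket to the tcp_bpf proto ops. Rejects sockets that
 * already carry a saved proto or whose current ops are not the plain
 * tcp_{recvmsg,sendmsg,sendpage} handlers (see tcp_bpf_assert_proto_ops),
 * and builds the IPv6 override table on first use.
 */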
int tcp_bpf_init(struct sock *sk)
{
        struct proto *ops = READ_ONCE(sk->sk_prot);
        struct sk_psock *psock;

        sock_owned_by_me(sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock || psock->sk_proto ||
                     tcp_bpf_assert_proto_ops(ops))) {
                rcu_read_unlock();
                return -EINVAL;
        }

        tcp_bpf_check_v6_needs_rebuild(sk, ops);
        tcp_bpf_update_sk_prot(sk, psock);
        rcu_read_unlock();
        return 0;
}