tcp_bpf.c

// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
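
/* Report whether BPF-redirected data is queued on the psock ingress_msg
 * list; installed as the ->stream_memory_read hook in
 * tcp_bpf_rebuild_protos() below.
 */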
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}
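
/* Sleep for up to @timeo until either the psock ingress_msg list or the
 * socket receive queue has data, in the spirit of sk_wait_data() but aware
 * of the BPF ingress path.
 */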
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}
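
/* Copy up to @len bytes from the psock ingress_msg list into @msg. Unless
 * MSG_PEEK is set, consumed bytes are uncharged from the socket and fully
 * drained sk_msg entries are unlinked and freed.
 */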
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
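
/* recvmsg hook for sockmap sockets: prefer data queued by the BPF ingress
 * path and fall back to tcp_recvmsg() when only the regular receive queue
 * has data or no psock is attached.
 */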
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags,
					   addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}
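
/* Handle a redirect with the ingress flag set: move up to @apply_bytes of
 * scatterlist elements from @msg onto the target psock's ingress_msg list,
 * charging the receiving socket and waking its readers via sk_data_ready.
 */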
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk->sk_data_ready(sk);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}
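
/* Transmit up to @apply_bytes of @msg on the wire via do_tcp_sendpages(),
 * retrying partial sends and releasing pages (and, when @uncharge is set,
 * memory accounting) as scatterlist elements are fully sent.
 */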
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		ret = do_tcp_sendpages(sk, page, off, size, flags);
		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;

		if (uncharge)
			sk_mem_uncharge(sk, ret);

		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}

		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}

		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}
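
/* Entry point for sk_msg redirection to another TCP socket: data is either
 * queued to the peer's ingress path or pushed out through its send path,
 * depending on whether the ingress flag was set on the redirect.
 */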
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
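
/* Apply the msg verdict program's decision to @msg: cork until the requested
 * cork_bytes are collected, then pass the data to the local send path
 * (__SK_PASS), redirect it to another socket (__SK_REDIRECT), or drop it
 * (__SK_DROP).
 */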
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend;
	int ret;

more_data:
	if (psock->eval == __SK_NONE)
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= tosend;
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}
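
/* sendmsg hook for sockmap sockets: copy user data into an sk_msg, honour
 * any cork_bytes requested by the BPF program, and hand the result to
 * tcp_bpf_send_verdict(). Falls back to tcp_sendmsg() when no psock is
 * attached.
 */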
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
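
/* sendpage hook for sockmap sockets: append the page to the pending sk_msg
 * (or the cork buffer) and run the verdict path, falling back to
 * tcp_sendpage() when no psock is attached.
 */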
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;

	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
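
/* Detach a psock from the socket: free any cork buffer, purge queued
 * ingress data and drop all sockmap links.
 */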
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	sk_psock_cork_free(psock);
	__sk_psock_purge_ingress_msg(psock);
	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}
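
/* Per-family (IPv4/IPv6) proto templates. TCP_BPF_BASE overrides the
 * receive and teardown hooks; TCP_BPF_TX additionally overrides sendmsg
 * and sendpage for sockets with a msg_parser program attached.
 */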
enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close = tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}
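
/* Switch an already-managed socket to the proto variant matching its
 * currently attached programs; called with the socket lock held.
 */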
void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}