// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>
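
/* Decide whether sk_msg_alloc() may coalesce new data into the last
 * scatterlist element: the first element eligible for coalescing must
 * still lie inside the used region of the sg ring, taking the wrapped
 * (end < start) case into account.
 */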
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
        if (msg->sg.end > msg->sg.start &&
            elem_first_coalesce < msg->sg.end)
                return true;

        if (msg->sg.end < msg->sg.start &&
            (elem_first_coalesce > msg->sg.start ||
             elem_first_coalesce < msg->sg.end))
                return true;

        return false;
}
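
/* Grow @msg to hold @len bytes in total, allocating from the socket's
 * page frag. New space is merged into the last element when it is
 * contiguous with the current frag offset and coalescing is allowed by
 * @elem_first_coalesce; otherwise a fresh element is appended. Allocated
 * bytes are charged to @sk. Returns 0 on success, -ENOMEM if memory
 * cannot be obtained or scheduled, or -ENOSPC if the sg ring is full.
 */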
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
                 int elem_first_coalesce)
{
        struct page_frag *pfrag = sk_page_frag(sk);
        int ret = 0;

        len -= msg->sg.size;
        while (len > 0) {
                struct scatterlist *sge;
                u32 orig_offset;
                int use, i;

                if (!sk_page_frag_refill(sk, pfrag))
                        return -ENOMEM;

                orig_offset = pfrag->offset;
                use = min_t(int, len, pfrag->size - orig_offset);
                if (!sk_wmem_schedule(sk, use))
                        return -ENOMEM;

                i = msg->sg.end;
                sk_msg_iter_var_prev(i);
                sge = &msg->sg.data[i];

                if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
                    sg_page(sge) == pfrag->page &&
                    sge->offset + sge->length == orig_offset) {
                        sge->length += use;
                } else {
                        if (sk_msg_full(msg)) {
                                ret = -ENOSPC;
                                break;
                        }

                        sge = &msg->sg.data[msg->sg.end];
                        sg_unmark_end(sge);
                        sg_set_page(sge, pfrag->page, use, orig_offset);
                        get_page(pfrag->page);
                        sk_msg_iter_next(msg, end);
                }

                sk_mem_charge(sk, use);
                msg->sg.size += use;
                pfrag->offset += use;
                len -= use;
        }

        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);
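
/* Clone @len bytes starting at offset @off of @src into @dst by adding
 * scatterlist entries that reference the same pages, charging @sk for
 * the cloned bytes. Returns -ENOSPC if @dst is already full or @src does
 * not cover the requested range.
 */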
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
                 u32 off, u32 len)
{
        int i = src->sg.start;
        struct scatterlist *sge = sk_msg_elem(src, i);
        u32 sge_len, sge_off;

        if (sk_msg_full(dst))
                return -ENOSPC;

        while (off) {
                if (sge->length > off)
                        break;
                off -= sge->length;
                sk_msg_iter_var_next(i);
                if (i == src->sg.end && off)
                        return -ENOSPC;
                sge = sk_msg_elem(src, i);
        }

        while (len) {
                sge_len = sge->length - off;
                sge_off = sge->offset + off;
                if (sge_len > len)
                        sge_len = len;
                off = 0;
                len -= sge_len;
                sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
                sk_mem_charge(sk, sge_len);
                sk_msg_iter_var_next(i);
                if (i == src->sg.end && len)
                        return -ENOSPC;
                sge = sk_msg_elem(src, i);
        }

        return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);
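
/* Uncharge @bytes from @sk and consume the corresponding data from the
 * front of @msg, zeroing every element that is fully consumed and moving
 * msg->sg.start past it.
 */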
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
        int i = msg->sg.start;

        do {
                struct scatterlist *sge = sk_msg_elem(msg, i);

                if (bytes < sge->length) {
                        sge->length -= bytes;
                        sge->offset += bytes;
                        sk_mem_uncharge(sk, bytes);
                        break;
                }

                sk_mem_uncharge(sk, sge->length);
                bytes -= sge->length;
                sge->length = 0;
                sge->offset = 0;
                sk_msg_iter_var_next(i);
        } while (bytes && i != msg->sg.end);
        msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);
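
/* Uncharge up to @bytes from @sk, walking @msg from start to end. The
 * scatterlist itself is left untouched.
 */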
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
        int i = msg->sg.start;

        do {
                struct scatterlist *sge = &msg->sg.data[i];
                int uncharge = (bytes < sge->length) ? bytes : sge->length;

                sk_mem_uncharge(sk, uncharge);
                bytes -= uncharge;
                sk_msg_iter_var_next(i);
        } while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);
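
/* Free helpers: drop the page behind a single element (unless the data
 * is backed by msg->skb), optionally uncharging the socket, and tear
 * down a whole sk_msg, returning the number of bytes freed.
 */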
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
                            bool charge)
{
        struct scatterlist *sge = sk_msg_elem(msg, i);
        u32 len = sge->length;

        if (charge)
                sk_mem_uncharge(sk, len);
        if (!msg->skb)
                put_page(sg_page(sge));
        memset(sge, 0, sizeof(*sge));
        return len;
}

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
                         bool charge)
{
        struct scatterlist *sge = sk_msg_elem(msg, i);
        int freed = 0;

        while (msg->sg.size) {
                msg->sg.size -= sge->length;
                freed += sk_msg_free_elem(sk, msg, i, charge);
                sk_msg_iter_var_next(i);
                sk_msg_check_to_free(msg, i, msg->sg.size);
                sge = sk_msg_elem(msg, i);
        }

        if (msg->skb)
                consume_skb(msg->skb);
        sk_msg_init(msg);
        return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
        return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
        return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes, bool charge)
{
        struct scatterlist *sge;
        u32 i = msg->sg.start;

        while (bytes) {
                sge = sk_msg_elem(msg, i);
                if (!sge->length)
                        break;
                if (bytes < sge->length) {
                        if (charge)
                                sk_mem_uncharge(sk, bytes);
                        sge->length -= bytes;
                        sge->offset += bytes;
                        msg->sg.size -= bytes;
                        break;
                }

                msg->sg.size -= sge->length;
                bytes -= sge->length;
                sk_msg_free_elem(sk, msg, i, charge);
                sk_msg_iter_var_next(i);
                sk_msg_check_to_free(msg, i, bytes);
        }
        msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
        __sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes)
{
        __sk_msg_free_partial(sk, msg, bytes, false);
}
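
/* Trim @msg down to @len bytes, freeing whole elements from the tail and
 * shortening the last surviving element as needed. The curr/copybreak
 * copy position is pulled back if it now points past the trimmed end.
 */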
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
        int trim = msg->sg.size - len;
        u32 i = msg->sg.end;

        if (trim <= 0) {
                WARN_ON(trim < 0);
                return;
        }

        sk_msg_iter_var_prev(i);
        msg->sg.size = len;
        while (msg->sg.data[i].length &&
               trim >= msg->sg.data[i].length) {
                trim -= msg->sg.data[i].length;
                sk_msg_free_elem(sk, msg, i, true);
                sk_msg_iter_var_prev(i);
                if (!trim)
                        goto out;
        }

        msg->sg.data[i].length -= trim;
        sk_mem_uncharge(sk, trim);
out:
        /* If we trim data before the curr pointer, update copybreak and
         * curr so that any future copy operations start at the new copy
         * location. However, trimmed data that has not yet been used in a
         * copy op does not require an update.
         */
        if (msg->sg.curr >= i) {
                msg->sg.curr = i;
                msg->sg.copybreak = msg->sg.data[i].length;
        }
        sk_msg_iter_var_next(i);
        msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);
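
/* Pin up to @bytes of user pages from @from directly into @msg's
 * scatterlist without copying data, charging @sk for every element
 * added. On error the iterator is reverted and the caller is expected
 * to trim away any elements that were already linked in.
 */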
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                              struct sk_msg *msg, u32 bytes)
{
        int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
        const int to_max_pages = MAX_MSG_FRAGS;
        struct page *pages[MAX_MSG_FRAGS];
        ssize_t orig, copied, use, offset;

        orig = msg->sg.size;
        while (bytes > 0) {
                i = 0;
                maxpages = to_max_pages - num_elems;
                if (maxpages == 0) {
                        ret = -EFAULT;
                        goto out;
                }

                copied = iov_iter_get_pages(from, pages, bytes, maxpages,
                                            &offset);
                if (copied <= 0) {
                        ret = -EFAULT;
                        goto out;
                }

                iov_iter_advance(from, copied);
                bytes -= copied;
                msg->sg.size += copied;

                while (copied) {
                        use = min_t(int, copied, PAGE_SIZE - offset);
                        sg_set_page(&msg->sg.data[msg->sg.end],
                                    pages[i], use, offset);
                        sg_unmark_end(&msg->sg.data[msg->sg.end]);
                        sk_mem_charge(sk, use);

                        offset = 0;
                        copied -= use;
                        sk_msg_iter_next(msg, end);
                        num_elems++;
                        i++;
                }
                /* When zerocopy is mixed with sk_msg_*copy* operations we
                 * may have a copybreak set; in that case clear it and
                 * prefer the zerocopy remainder when possible.
                 */
                msg->sg.copybreak = 0;
                msg->sg.curr = msg->sg.end;
        }
out:
        /* Revert iov_iter updates; msg will need to use 'trim' later if
         * it also needs to be cleared.
         */
        if (ret)
                iov_iter_revert(from, msg->sg.size - orig);
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
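
/* Copy @bytes from @from into @msg's already-allocated elements, starting
 * at the curr/copybreak position and advancing it as data lands. Returns
 * -EFAULT if the copy from the iterator comes up short; a non-negative
 * return indicates success.
 */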
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
                             struct sk_msg *msg, u32 bytes)
{
        int ret = -ENOSPC, i = msg->sg.curr;
        struct scatterlist *sge;
        u32 copy, buf_size;
        void *to;

        do {
                sge = sk_msg_elem(msg, i);
                /* This is possible if a trim operation shrunk the buffer */
                if (msg->sg.copybreak >= sge->length) {
                        msg->sg.copybreak = 0;
                        sk_msg_iter_var_next(i);
                        if (i == msg->sg.end)
                                break;
                        sge = sk_msg_elem(msg, i);
                }

                buf_size = sge->length - msg->sg.copybreak;
                copy = (buf_size > bytes) ? bytes : buf_size;
                to = sg_virt(sge) + msg->sg.copybreak;
                msg->sg.copybreak += copy;
                if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
                        ret = copy_from_iter_nocache(to, copy, from);
                else
                        ret = copy_from_iter(to, copy, from);
                if (ret != copy) {
                        ret = -EFAULT;
                        goto out;
                }

                bytes -= copy;
                if (!bytes)
                        break;
                msg->sg.copybreak = 0;
                sk_msg_iter_var_next(i);
        } while (i != msg->sg.end);
out:
        msg->sg.curr = i;
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
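
/* Wrap an ingress @skb in a freshly allocated sk_msg, map its data into
 * the sg ring with skb_to_sgvec(), charge the receive memory to the
 * socket, queue the msg on the psock ingress list and wake the receiver
 * via sk_data_ready(). Returns the number of bytes queued or a negative
 * error.
 */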
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
        struct sock *sk = psock->sk;
        int copied = 0, num_sge;
        struct sk_msg *msg;

        msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
        if (unlikely(!msg))
                return -EAGAIN;
        if (!sk_rmem_schedule(sk, skb, skb->len)) {
                kfree(msg);
                return -EAGAIN;
        }

        sk_msg_init(msg);
        num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
        if (unlikely(num_sge < 0)) {
                kfree(msg);
                return num_sge;
        }

        sk_mem_charge(sk, skb->len);
        copied = skb->len;
        msg->sg.start = 0;
        msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
        msg->skb = skb;

        sk_psock_queue_msg(psock, msg);
        sk->sk_data_ready(sk);
        return copied;
}
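
/* Depending on the direction recorded for the skb, either queue it as
 * local ingress data or transmit it with skb_send_sock_locked().
 */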
static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
                               u32 off, u32 len, bool ingress)
{
        if (ingress)
                return sk_psock_skb_ingress(psock, skb);
        else
                return skb_send_sock_locked(psock->sk, skb, off, len);
}
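
/* Work callback that drains psock->ingress_skb. Each skb is handed to
 * sk_psock_handle_skb(); a partially handled skb is stashed in
 * work_state so the next run resumes at the same offset, while hard
 * errors report an error on the psock and disable further transmit.
 */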
static void sk_psock_backlog(struct work_struct *work)
{
        struct sk_psock *psock = container_of(work, struct sk_psock, work);
        struct sk_psock_work_state *state = &psock->work_state;
        struct sk_buff *skb;
        bool ingress;
        u32 len, off;
        int ret;

        /* Lock sock to avoid losing sk_socket during loop. */
        lock_sock(psock->sk);
        if (state->skb) {
                skb = state->skb;
                len = state->len;
                off = state->off;
                state->skb = NULL;
                goto start;
        }

        while ((skb = skb_dequeue(&psock->ingress_skb))) {
                len = skb->len;
                off = 0;
start:
                ingress = tcp_skb_bpf_ingress(skb);
                do {
                        ret = -EIO;
                        if (likely(psock->sk->sk_socket))
                                ret = sk_psock_handle_skb(psock, skb, off,
                                                          len, ingress);
                        if (ret <= 0) {
                                if (ret == -EAGAIN) {
                                        state->skb = skb;
                                        state->len = len;
                                        state->off = off;
                                        goto end;
                                }
                                /* Hard errors break pipe and stop xmit. */
                                sk_psock_report_error(psock, ret ? -ret : EPIPE);
                                sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
                                kfree_skb(skb);
                                goto end;
                        }
                        off += ret;
                        len -= ret;
                } while (len);

                if (!ingress)
                        kfree_skb(skb);
        }
end:
        release_sock(psock->sk);
}
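
/* Allocate and initialise a psock for @sk, attach it via sk_user_data
 * and take a reference on the socket. Returns NULL on allocation
 * failure.
 */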
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
        struct sk_psock *psock = kzalloc_node(sizeof(*psock),
                                              GFP_ATOMIC | __GFP_NOWARN,
                                              node);
        if (!psock)
                return NULL;

        psock->sk = sk;
        psock->eval = __SK_NONE;

        INIT_LIST_HEAD(&psock->link);
        spin_lock_init(&psock->link_lock);

        INIT_WORK(&psock->work, sk_psock_backlog);
        INIT_LIST_HEAD(&psock->ingress_msg);
        skb_queue_head_init(&psock->ingress_skb);

        sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
        refcount_set(&psock->refcnt, 1);

        rcu_assign_sk_user_data(sk, psock);
        sock_hold(sk);

        return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
        struct sk_psock_link *link;

        spin_lock_bh(&psock->link_lock);
        link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
                                        list);
        if (link)
                list_del(&link->list);
        spin_unlock_bh(&psock->link_lock);
        return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
        struct sk_msg *msg, *tmp;

        list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
                list_del(&msg->list);
                sk_msg_free(psock->sk, msg);
                kfree(msg);
        }
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
        __skb_queue_purge(&psock->ingress_skb);
        __sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
        struct sk_psock_link *link, *tmp;

        list_for_each_entry_safe(link, tmp, &psock->link, list) {
                list_del(&link->list);
                sk_psock_free_link(link);
        }
}
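
/* psock teardown runs in two stages: sk_psock_drop() detaches the psock
 * from the socket and schedules sk_psock_destroy() via RCU, which in
 * turn defers the final cleanup to a workqueue so that the strparser
 * and the backlog work can be shut down in process context.
 */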
static void sk_psock_destroy_deferred(struct work_struct *gc)
{
        struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

        /* No sk_callback_lock since already detached. */
        if (psock->parser.enabled)
                strp_done(&psock->parser.strp);

        cancel_work_sync(&psock->work);

        psock_progs_drop(&psock->progs);

        sk_psock_link_destroy(psock);
        sk_psock_cork_free(psock);
        sk_psock_zap_ingress(psock);

        if (psock->sk_redir)
                sock_put(psock->sk_redir);
        sock_put(psock->sk);
        kfree(psock);
}

void sk_psock_destroy(struct rcu_head *rcu)
{
        struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

        INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
        schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);

void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
        rcu_assign_sk_user_data(sk, NULL);
        sk_psock_cork_free(psock);
        sk_psock_restore_proto(sk, psock);

        write_lock_bh(&sk->sk_callback_lock);
        if (psock->progs.skb_parser)
                sk_psock_stop_strp(sk, psock);
        write_unlock_bh(&sk->sk_callback_lock);
        sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

        call_rcu_sched(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)
{
        switch (verdict) {
        case SK_PASS:
                return redir ? __SK_REDIRECT : __SK_PASS;
        case SK_DROP:
        default:
                break;
        }

        return __SK_DROP;
}
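
/* Run the attached msg_parser BPF program on @msg and map its verdict to
 * __SK_PASS, __SK_REDIRECT or __SK_DROP. On a redirect verdict the
 * target socket chosen by the program is pinned in psock->sk_redir.
 */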
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
                         struct sk_msg *msg)
{
        struct bpf_prog *prog;
        int ret;

        preempt_disable();
        rcu_read_lock();
        prog = READ_ONCE(psock->progs.msg_parser);
        if (unlikely(!prog)) {
                ret = __SK_PASS;
                goto out;
        }

        sk_msg_compute_data_pointers(msg);
        msg->sk = sk;
        ret = BPF_PROG_RUN(prog, msg);
        ret = sk_psock_map_verd(ret, msg->sk_redir);
        psock->apply_bytes = msg->apply_bytes;
        if (ret == __SK_REDIRECT) {
                if (psock->sk_redir)
                        sock_put(psock->sk_redir);
                psock->sk_redir = msg->sk_redir;
                if (!psock->sk_redir) {
                        ret = __SK_DROP;
                        goto out;
                }
                sock_hold(psock->sk_redir);
        }
out:
        rcu_read_unlock();
        preempt_enable();
        return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
                            struct sk_buff *skb)
{
        int ret;

        skb->sk = psock->sk;
        bpf_compute_data_end_sk_skb(skb);
        preempt_disable();
        ret = BPF_PROG_RUN(prog, skb);
        preempt_enable();
        /* strparser clones the skb before handing it to an upper layer,
         * meaning skb_orphan has been called. We NULL sk on the way out
         * to ensure we don't trigger a BUG_ON() in skb/sk operations
         * later and because we are not charging the memory of this skb
         * to any socket yet.
         */
        skb->sk = NULL;
        return ret;
}

static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
        struct sk_psock_parser *parser;

        parser = container_of(strp, struct sk_psock_parser, strp);
        return container_of(parser, struct sk_psock, parser);
}
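
/* Apply an skb verdict: on __SK_REDIRECT the skb is handed to the target
 * psock's backlog queue (taking write ownership for egress), provided
 * the target is alive, has TX enabled and has room. Every other verdict
 * drops the skb.
 */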
static void sk_psock_verdict_apply(struct sk_psock *psock,
                                   struct sk_buff *skb, int verdict)
{
        struct sk_psock *psock_other;
        struct sock *sk_other;
        bool ingress;

        switch (verdict) {
        case __SK_REDIRECT:
                sk_other = tcp_skb_bpf_redirect_fetch(skb);
                if (unlikely(!sk_other))
                        goto out_free;
                psock_other = sk_psock(sk_other);
                if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
                    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
                        goto out_free;
                ingress = tcp_skb_bpf_ingress(skb);
                if ((!ingress && sock_writeable(sk_other)) ||
                    (ingress &&
                     atomic_read(&sk_other->sk_rmem_alloc) <=
                     sk_other->sk_rcvbuf)) {
                        if (!ingress)
                                skb_set_owner_w(skb, sk_other);
                        skb_queue_tail(&psock_other->ingress_skb, skb);
                        schedule_work(&psock_other->work);
                        break;
                }
                /* fall-through */
        case __SK_DROP:
                /* fall-through */
        default:
out_free:
                kfree_skb(skb);
        }
}
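
/* strparser callbacks: parse_msg lets the skb_parser program determine
 * message boundaries, while rcv_msg runs the skb_verdict program on each
 * parsed message and applies the result.
 */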
static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
        struct sk_psock *psock = sk_psock_from_strp(strp);
        struct bpf_prog *prog;
        int ret = __SK_DROP;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.skb_verdict);
        if (likely(prog)) {
                skb_orphan(skb);
                tcp_skb_bpf_redirect_clear(skb);
                ret = sk_psock_bpf_run(psock, prog, skb);
                ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
        }
        rcu_read_unlock();
        sk_psock_verdict_apply(psock, skb, ret);
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
        return err;
}

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
        struct sk_psock *psock = sk_psock_from_strp(strp);
        struct bpf_prog *prog;
        int ret = skb->len;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.skb_parser);
        if (likely(prog))
                ret = sk_psock_bpf_run(psock, prog, skb);
        rcu_read_unlock();
        return ret;
}

/* Called with socket lock held. */
static void sk_psock_data_ready(struct sock *sk)
{
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock)) {
                write_lock_bh(&sk->sk_callback_lock);
                strp_data_ready(&psock->parser.strp);
                write_unlock_bh(&sk->sk_callback_lock);
        }
        rcu_read_unlock();
}

static void sk_psock_write_space(struct sock *sk)
{
        struct sk_psock *psock;
        void (*write_space)(struct sock *sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
                schedule_work(&psock->work);
        write_space = psock->saved_write_space;
        rcu_read_unlock();
        write_space(sk);
}
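
/* Install, enable and disable the strparser on a socket. Starting the
 * parser saves and replaces sk_data_ready/sk_write_space so that
 * incoming data is fed to the parser and write space kicks the backlog
 * work.
 */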
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
        static const struct strp_callbacks cb = {
                .rcv_msg        = sk_psock_strp_read,
                .read_sock_done = sk_psock_strp_read_done,
                .parse_msg      = sk_psock_strp_parse,
        };

        psock->parser.enabled = false;
        return strp_init(&psock->parser.strp, sk, &cb);
}

void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_parser *parser = &psock->parser;

        if (parser->enabled)
                return;

        parser->saved_data_ready = sk->sk_data_ready;
        sk->sk_data_ready = sk_psock_data_ready;
        sk->sk_write_space = sk_psock_write_space;
        parser->enabled = true;
}

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_parser *parser = &psock->parser;

        if (!parser->enabled)
                return;

        sk->sk_data_ready = parser->saved_data_ready;
        parser->saved_data_ready = NULL;
        strp_stop(&parser->strp);
        parser->enabled = false;
}