virtio_transport_common.c

/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */

#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

static const struct virtio_transport *virtio_transport_get_ops(void)
{
        const struct vsock_transport *t = vsock_core_get_transport();

        return container_of(t, struct virtio_transport, transport);
}

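/* Allocate a packet and fill in the virtio-vsock header in little-endian
 * byte order. If info carries a msghdr, the payload is copied from it into a
 * freshly allocated buffer; on any failure the partially built packet is
 * freed and NULL is returned.
 */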
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
                           size_t len,
                           u32 src_cid,
                           u32 src_port,
                           u32 dst_cid,
                           u32 dst_port)
{
        struct virtio_vsock_pkt *pkt;
        int err;

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return NULL;

        pkt->hdr.type = cpu_to_le16(info->type);
        pkt->hdr.op = cpu_to_le16(info->op);
        pkt->hdr.src_cid = cpu_to_le64(src_cid);
        pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
        pkt->hdr.src_port = cpu_to_le32(src_port);
        pkt->hdr.dst_port = cpu_to_le32(dst_port);
        pkt->hdr.flags = cpu_to_le32(info->flags);
        pkt->len = len;
        pkt->hdr.len = cpu_to_le32(len);
        pkt->reply = info->reply;

        if (info->msg && len > 0) {
                pkt->buf = kmalloc(len, GFP_KERNEL);
                if (!pkt->buf)
                        goto out_pkt;
                err = memcpy_from_msg(pkt->buf, info->msg, len);
                if (err)
                        goto out;
        }

        trace_virtio_transport_alloc_pkt(src_cid, src_port,
                                         dst_cid, dst_port,
                                         len,
                                         info->type,
                                         info->op,
                                         info->flags);

        return pkt;

out:
        kfree(pkt->buf);
out_pkt:
        kfree(pkt);
        return NULL;
}

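/* Common send path: resolve source/destination addresses, cap the payload at
 * VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE, reserve transmit credit, then allocate
 * the packet and hand it to the transport's send_pkt callback.
 */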
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
                                          struct virtio_vsock_pkt_info *info)
{
        u32 src_cid, src_port, dst_cid, dst_port;
        struct virtio_vsock_sock *vvs;
        struct virtio_vsock_pkt *pkt;
        u32 pkt_len = info->pkt_len;

        src_cid = vm_sockets_get_local_cid();
        src_port = vsk->local_addr.svm_port;
        if (!info->remote_cid) {
                dst_cid = vsk->remote_addr.svm_cid;
                dst_port = vsk->remote_addr.svm_port;
        } else {
                dst_cid = info->remote_cid;
                dst_port = info->remote_port;
        }

        vvs = vsk->trans;

        /* we can send less than pkt_len bytes */
        if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
                pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

        /* virtio_transport_get_credit might return less than pkt_len credit */
        pkt_len = virtio_transport_get_credit(vvs, pkt_len);

        /* Do not send zero length OP_RW pkt */
        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
                return pkt_len;

        pkt = virtio_transport_alloc_pkt(info, pkt_len,
                                         src_cid, src_port,
                                         dst_cid, dst_port);
        if (!pkt) {
                virtio_transport_put_credit(vvs, pkt_len);
                return -ENOMEM;
        }

        virtio_transport_inc_tx_pkt(vvs, pkt);

        return virtio_transport_get_ops()->send_pkt(pkt);
}

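/* Credit-based flow control bookkeeping: rx_bytes tracks payload queued but
 * not yet consumed, fwd_cnt counts bytes that have been handed to the
 * application, and tx_cnt is charged against the peer's advertised
 * buf_alloc/fwd_cnt before anything is transmitted.
 */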
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
                                        struct virtio_vsock_pkt *pkt)
{
        vvs->rx_bytes += pkt->len;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
                                        struct virtio_vsock_pkt *pkt)
{
        vvs->rx_bytes -= pkt->len;
        vvs->fwd_cnt += pkt->len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
        spin_lock_bh(&vvs->tx_lock);
        pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
        pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
        spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
        u32 ret;

        spin_lock_bh(&vvs->tx_lock);
        ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
        if (ret > credit)
                ret = credit;
        vvs->tx_cnt += ret;
        spin_unlock_bh(&vvs->tx_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
        spin_lock_bh(&vvs->tx_lock);
        vvs->tx_cnt -= credit;
        spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
                                               int type,
                                               struct virtio_vsock_hdr *hdr)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
                .type = type,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}

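/* Copy queued receive packets into the caller's msghdr, freeing packets that
 * have been fully consumed, and finish by sending a credit update so the
 * peer learns about the newly freed buffer space.
 */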
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
                                   struct msghdr *msg,
                                   size_t len)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        struct virtio_vsock_pkt *pkt;
        size_t bytes, total = 0;
        int err = -EFAULT;

        spin_lock_bh(&vvs->rx_lock);
        while (total < len && !list_empty(&vvs->rx_queue)) {
                pkt = list_first_entry(&vvs->rx_queue,
                                       struct virtio_vsock_pkt, list);

                bytes = len - total;
                if (bytes > pkt->len - pkt->off)
                        bytes = pkt->len - pkt->off;

                /* sk_lock is held by caller so no one else can dequeue.
                 * Unlock rx_lock since memcpy_to_msg() may sleep.
                 */
                spin_unlock_bh(&vvs->rx_lock);

                err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
                if (err)
                        goto out;

                spin_lock_bh(&vvs->rx_lock);

                total += bytes;
                pkt->off += bytes;
                if (pkt->off == pkt->len) {
                        virtio_transport_dec_rx_pkt(vvs, pkt);
                        list_del(&pkt->list);
                        virtio_transport_free_pkt(pkt);
                }
        }
        spin_unlock_bh(&vvs->rx_lock);

        /* Send a credit pkt to peer */
        virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
                                            NULL);

        return total;

out:
        if (total)
                err = total;
        return err;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len, int flags)
{
        if (flags & MSG_PEEK)
                return -EOPNOTSUPP;

        return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
                               struct msghdr *msg,
                               size_t len, int flags)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        spin_lock_bh(&vvs->rx_lock);
        bytes = vvs->rx_bytes;
        spin_unlock_bh(&vvs->rx_lock);

        return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
        if (bytes < 0)
                bytes = 0;

        return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        spin_lock_bh(&vvs->tx_lock);
        bytes = virtio_transport_has_space(vsk);
        spin_unlock_bh(&vvs->tx_lock);

        return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

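/* Allocate the per-socket transport state. A child socket created from a
 * listener inherits the parent's buffer size limits and peer buffer
 * allocation.
 */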
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
                                    struct vsock_sock *psk)
{
        struct virtio_vsock_sock *vvs;

        vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
        if (!vvs)
                return -ENOMEM;

        vsk->trans = vvs;
        vvs->vsk = vsk;
        if (psk) {
                struct virtio_vsock_sock *ptrans = psk->trans;

                vvs->buf_size = ptrans->buf_size;
                vvs->buf_size_min = ptrans->buf_size_min;
                vvs->buf_size_max = ptrans->buf_size_max;
                vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
        } else {
                vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
                vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
                vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
        }

        vvs->buf_alloc = vvs->buf_size;

        spin_lock_init(&vvs->rx_lock);
        spin_lock_init(&vvs->tx_lock);
        INIT_LIST_HEAD(&vvs->rx_queue);

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);

u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);

u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);

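/* The setters below clamp values to VIRTIO_VSOCK_MAX_BUF_SIZE and keep the
 * buf_size_min <= buf_size <= buf_size_max invariant.
 */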
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val < vvs->buf_size_min)
                vvs->buf_size_min = val;
        if (val > vvs->buf_size_max)
                vvs->buf_size_max = val;
        vvs->buf_size = val;
        vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);

void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val > vvs->buf_size)
                vvs->buf_size = val;
        vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);

void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val < vvs->buf_size)
                vvs->buf_size = val;
        vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);

int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
                                size_t target,
                                bool *data_ready_now)
{
        if (vsock_stream_has_data(vsk))
                *data_ready_now = true;
        else
                *data_ready_now = false;

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
                                 size_t target,
                                 bool *space_avail_now)
{
        s64 free_space;

        free_space = vsock_stream_has_space(vsk);
        if (free_space > 0)
                *space_avail_now = true;
        else if (free_space == 0)
                *space_avail_now = false;

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

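/* The notify callbacks below are no-ops: this transport has no extra work to
 * do around the generic receive and send paths.
 */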
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
                                      size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
                                           size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
                                             size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
                                              size_t target, ssize_t copied, bool data_read,
                                              struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
                                      struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
                                           struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
                                             struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
                                              ssize_t written, struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
        return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
        return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
                                struct sockaddr_vm *addr)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
        return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

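/* Initiate a connection by sending an OP_REQUEST packet to the remote peer;
 * the handshake completes when the OP_RESPONSE is handled in
 * virtio_transport_recv_connecting().
 */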
int virtio_transport_connect(struct vsock_sock *vsk)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_REQUEST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_SHUTDOWN,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .flags = (mode & RCV_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
                         (mode & SEND_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
                               struct sockaddr_vm *remote_addr,
                               struct msghdr *msg,
                               size_t dgram_len)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RW,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .msg = msg,
                .pkt_len = len,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

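/* Send an OP_RST for this socket, unless the packet that triggered the reset
 * is itself a RST (to avoid replying to a reset with another reset).
 */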
static int virtio_transport_reset(struct vsock_sock *vsk,
                                  struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .reply = !!pkt,
        };

        /* Send RST only if the original pkt is not a RST pkt */
        if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

        return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket. There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
                .type = le16_to_cpu(pkt->hdr.type),
                .reply = true,
        };

        /* Send RST only if the original pkt is not a RST pkt */
        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

        pkt = virtio_transport_alloc_pkt(&info, 0,
                                         le64_to_cpu(pkt->hdr.dst_cid),
                                         le32_to_cpu(pkt->hdr.dst_port),
                                         le64_to_cpu(pkt->hdr.src_cid),
                                         le32_to_cpu(pkt->hdr.src_port));
        if (!pkt)
                return -ENOMEM;

        return virtio_transport_get_ops()->send_pkt(pkt);
}

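/* Wait up to the given timeout for the connection to finish closing
 * (SOCK_DONE), bailing out early on a pending signal.
 */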
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
        if (timeout) {
                DEFINE_WAIT_FUNC(wait, woken_wake_function);

                add_wait_queue(sk_sleep(sk), &wait);

                do {
                        if (sk_wait_event(sk, &timeout,
                                          sock_flag(sk, SOCK_DONE), &wait))
                                break;
                } while (!signal_pending(current) && timeout);

                remove_wait_queue(sk_sleep(sk), &wait);
        }
}

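/* Tear down the connection: mark the socket done and disconnected, then, if
 * the delayed close work was scheduled (and, when cancel_timeout is set,
 * successfully cancelled), remove the socket and drop the work's reference.
 */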
static void virtio_transport_do_close(struct vsock_sock *vsk,
                                      bool cancel_timeout)
{
        struct sock *sk = sk_vsock(vsk);

        sock_set_flag(sk, SOCK_DONE);
        vsk->peer_shutdown = SHUTDOWN_MASK;
        if (vsock_stream_has_data(vsk) <= 0)
                sk->sk_state = SS_DISCONNECTING;
        sk->sk_state_change(sk);

        if (vsk->close_work_scheduled &&
            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
                vsk->close_work_scheduled = false;

                vsock_remove_sock(vsk);

                /* Release refcnt obtained when we scheduled the timeout */
                sock_put(sk);
        }
}

static void virtio_transport_close_timeout(struct work_struct *work)
{
        struct vsock_sock *vsk =
                container_of(work, struct vsock_sock, close_work.work);
        struct sock *sk = sk_vsock(vsk);

        sock_hold(sk);
        lock_sock(sk);

        if (!sock_flag(sk, SOCK_DONE)) {
                (void)virtio_transport_reset(vsk, NULL);

                virtio_transport_do_close(vsk, false);
        }

        vsk->close_work_scheduled = false;

        release_sock(sk);
        sock_put(sk);
}

/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
        struct sock *sk = &vsk->sk;

        if (!(sk->sk_state == SS_CONNECTED ||
              sk->sk_state == SS_DISCONNECTING))
                return true;

        /* Already received SHUTDOWN from peer, reply with RST */
        if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
                (void)virtio_transport_reset(vsk, NULL);
                return true;
        }

        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
                (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

        if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
                virtio_transport_wait_close(sk, sk->sk_lingertime);

        if (sock_flag(sk, SOCK_DONE)) {
                return true;
        }

        sock_hold(sk);
        INIT_DELAYED_WORK(&vsk->close_work,
                          virtio_transport_close_timeout);
        vsk->close_work_scheduled = true;
        schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
        return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
        struct sock *sk = &vsk->sk;
        bool remove_sock = true;

        lock_sock(sk);
        if (sk->sk_type == SOCK_STREAM)
                remove_sock = virtio_transport_close(vsk);
        release_sock(sk);

        if (remove_sock)
                vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

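/* Handle a packet received while the socket is in SS_CONNECTING: an
 * OP_RESPONSE completes the connect, while an OP_RST or an unexpected op
 * resets the socket and reports the error.
 */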
static int
virtio_transport_recv_connecting(struct sock *sk,
                                 struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        int err;
        int skerr;

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_RESPONSE:
                sk->sk_state = SS_CONNECTED;
                sk->sk_socket->state = SS_CONNECTED;
                vsock_insert_connected(vsk);
                sk->sk_state_change(sk);
                break;
        case VIRTIO_VSOCK_OP_INVALID:
                break;
        case VIRTIO_VSOCK_OP_RST:
                skerr = ECONNRESET;
                err = 0;
                goto destroy;
        default:
                skerr = EPROTO;
                err = -EINVAL;
                goto destroy;
        }
        return 0;

destroy:
        virtio_transport_reset(vsk, pkt);
        sk->sk_state = SS_UNCONNECTED;
        sk->sk_err = skerr;
        sk->sk_error_report(sk);
        return err;
}

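/* Handle a packet received on a connected socket: queue OP_RW payloads on
 * the rx queue, wake writers on credit updates, apply peer shutdown flags,
 * and close on OP_RST.
 */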
static int
virtio_transport_recv_connected(struct sock *sk,
                                struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct virtio_vsock_sock *vvs = vsk->trans;
        int err = 0;

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_RW:
                pkt->len = le32_to_cpu(pkt->hdr.len);
                pkt->off = 0;

                spin_lock_bh(&vvs->rx_lock);
                virtio_transport_inc_rx_pkt(vvs, pkt);
                list_add_tail(&pkt->list, &vvs->rx_queue);
                spin_unlock_bh(&vvs->rx_lock);

                sk->sk_data_ready(sk);
                return err;
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
                sk->sk_write_space(sk);
                break;
        case VIRTIO_VSOCK_OP_SHUTDOWN:
                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
                        vsk->peer_shutdown |= RCV_SHUTDOWN;
                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
                        vsk->peer_shutdown |= SEND_SHUTDOWN;
                if (vsk->peer_shutdown == SHUTDOWN_MASK &&
                    vsock_stream_has_data(vsk) <= 0)
                        sk->sk_state = SS_DISCONNECTING;
                if (le32_to_cpu(pkt->hdr.flags))
                        sk->sk_state_change(sk);
                break;
        case VIRTIO_VSOCK_OP_RST:
                virtio_transport_do_close(vsk, true);
                break;
        default:
                err = -EINVAL;
                break;
        }

        virtio_transport_free_pkt(pkt);
        return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
                                    struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
                               struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RESPONSE,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
                .remote_port = le32_to_cpu(pkt->hdr.src_port),
                .reply = true,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct vsock_sock *vchild;
        struct sock *child;

        if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
                virtio_transport_reset(vsk, pkt);
                return -EINVAL;
        }

        if (sk_acceptq_is_full(sk)) {
                virtio_transport_reset(vsk, pkt);
                return -ENOMEM;
        }

        child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
                               sk->sk_type, 0);
        if (!child) {
                virtio_transport_reset(vsk, pkt);
                return -ENOMEM;
        }

        sk->sk_ack_backlog++;

        lock_sock_nested(child, SINGLE_DEPTH_NESTING);

        child->sk_state = SS_CONNECTED;

        vchild = vsock_sk(child);
        vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
                        le32_to_cpu(pkt->hdr.dst_port));
        vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));

        vsock_insert_connected(vchild);
        vsock_enqueue_accept(sk, child);
        virtio_transport_send_response(vchild, pkt);

        release_sock(child);

        sk->sk_data_ready(sk);
        return 0;
}

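/* Cache the peer's buf_alloc and fwd_cnt advertised in the packet header and
 * report whether there is now transmit space available.
 */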
static bool virtio_transport_space_update(struct sock *sk,
                                          struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct virtio_vsock_sock *vvs = vsk->trans;
        bool space_available;

        /* buf_alloc and fwd_cnt is always included in the hdr */
        spin_lock_bh(&vvs->tx_lock);
        vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
        vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
        space_available = virtio_transport_has_space(vsk);
        spin_unlock_bh(&vvs->tx_lock);
        return space_available;
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
        struct sockaddr_vm src, dst;
        struct vsock_sock *vsk;
        struct sock *sk;
        bool space_available;

        vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));
        vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
                        le32_to_cpu(pkt->hdr.dst_port));

        trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
                                        dst.svm_cid, dst.svm_port,
                                        le32_to_cpu(pkt->hdr.len),
                                        le16_to_cpu(pkt->hdr.type),
                                        le16_to_cpu(pkt->hdr.op),
                                        le32_to_cpu(pkt->hdr.flags),
                                        le32_to_cpu(pkt->hdr.buf_alloc),
                                        le32_to_cpu(pkt->hdr.fwd_cnt));

        if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
                (void)virtio_transport_reset_no_sock(pkt);
                goto free_pkt;
        }

        /* The socket must be in connected or bound table
         * otherwise send reset back
         */
        sk = vsock_find_connected_socket(&src, &dst);
        if (!sk) {
                sk = vsock_find_bound_socket(&dst);
                if (!sk) {
                        (void)virtio_transport_reset_no_sock(pkt);
                        goto free_pkt;
                }
        }

        vsk = vsock_sk(sk);

        space_available = virtio_transport_space_update(sk, pkt);

        lock_sock(sk);

        /* Update CID in case it has changed after a transport reset event */
        vsk->local_addr.svm_cid = dst.svm_cid;

        if (space_available)
                sk->sk_write_space(sk);

        switch (sk->sk_state) {
        case VSOCK_SS_LISTEN:
                virtio_transport_recv_listen(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        case SS_CONNECTING:
                virtio_transport_recv_connecting(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        case SS_CONNECTED:
                virtio_transport_recv_connected(sk, pkt);
                break;
        case SS_DISCONNECTING:
                virtio_transport_recv_disconnecting(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        default:
                virtio_transport_free_pkt(pkt);
                break;
        }
        release_sock(sk);

        /* Release refcnt obtained when we fetched this socket out of the
         * bound or connected list.
         */
        sock_put(sk);
        return;

free_pkt:
        virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
        kfree(pkt->buf);
        kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");