|
@@ -4,6 +4,7 @@
|
|
|
* Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
|
|
|
* Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
|
|
|
* Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
|
|
|
+ * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
|
|
|
*
|
|
|
* This software is available to you under a choice of one of two
|
|
|
* licenses. You may choose to be licensed under the terms of the GNU
|
|
@@ -258,21 +259,58 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
|
|
|
return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
|
|
|
}
|
|
|
|
|
|
-static void tls_free_open_rec(struct sock *sk)
|
|
|
+static struct tls_rec *tls_get_rec(struct sock *sk)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
- struct tls_rec *rec = ctx->open_rec;
|
|
|
+ struct sk_msg *msg_pl, *msg_en;
|
|
|
+ struct tls_rec *rec;
|
|
|
+ int mem_size;
|
|
|
|
|
|
- /* Return if there is no open record */
|
|
|
+ mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
|
|
|
+
|
|
|
+ rec = kzalloc(mem_size, sk->sk_allocation);
|
|
|
if (!rec)
|
|
|
- return;
|
|
|
+ return NULL;
|
|
|
|
|
|
+ msg_pl = &rec->msg_plaintext;
|
|
|
+ msg_en = &rec->msg_encrypted;
|
|
|
+
|
|
|
+ sk_msg_init(msg_pl);
|
|
|
+ sk_msg_init(msg_en);
|
|
|
+
|
|
|
+ sg_init_table(rec->sg_aead_in, 2);
|
|
|
+ sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
|
|
|
+ sizeof(rec->aad_space));
|
|
|
+ sg_unmark_end(&rec->sg_aead_in[1]);
|
|
|
+
|
|
|
+ sg_init_table(rec->sg_aead_out, 2);
|
|
|
+ sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
|
|
|
+ sizeof(rec->aad_space));
|
|
|
+ sg_unmark_end(&rec->sg_aead_out[1]);
|
|
|
+
|
|
|
+ return rec;
|
|
|
+}
|
|
|
+
|
|
|
+static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
|
|
|
+{
|
|
|
sk_msg_free(sk, &rec->msg_encrypted);
|
|
|
sk_msg_free(sk, &rec->msg_plaintext);
|
|
|
kfree(rec);
|
|
|
}
|
|
|
|
|
|
+static void tls_free_open_rec(struct sock *sk)
|
|
|
+{
|
|
|
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
+ struct tls_rec *rec = ctx->open_rec;
|
|
|
+
|
|
|
+ if (rec) {
|
|
|
+ tls_free_rec(sk, rec);
|
|
|
+ ctx->open_rec = NULL;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
int tls_tx_records(struct sock *sk, int flags)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
@@ -439,16 +477,135 @@ static int tls_do_encryption(struct sock *sk,
|
|
|
return rc;
|
|
|
}
|
|
|
|
|
|
+static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
|
|
|
+ struct tls_rec **to, struct sk_msg *msg_opl,
|
|
|
+ struct sk_msg *msg_oen, u32 split_point,
|
|
|
+ u32 tx_overhead_size, u32 *orig_end)
|
|
|
+{
|
|
|
+ u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
|
|
|
+ struct scatterlist *sge, *osge, *nsge;
|
|
|
+ u32 orig_size = msg_opl->sg.size;
|
|
|
+ struct scatterlist tmp = { };
|
|
|
+ struct sk_msg *msg_npl;
|
|
|
+ struct tls_rec *new;
|
|
|
+ int ret;
|
|
|
+
|
|
|
+ new = tls_get_rec(sk);
|
|
|
+ if (!new)
|
|
|
+ return -ENOMEM;
|
|
|
+ ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
|
|
|
+ tx_overhead_size, 0);
|
|
|
+ if (ret < 0) {
|
|
|
+ tls_free_rec(sk, new);
|
|
|
+ return ret;
|
|
|
+ }
|
|
|
+
|
|
|
+ *orig_end = msg_opl->sg.end;
|
|
|
+ i = msg_opl->sg.start;
|
|
|
+ sge = sk_msg_elem(msg_opl, i);
|
|
|
+ while (apply && sge->length) {
|
|
|
+ if (sge->length > apply) {
|
|
|
+ u32 len = sge->length - apply;
|
|
|
+
|
|
|
+ get_page(sg_page(sge));
|
|
|
+ sg_set_page(&tmp, sg_page(sge), len,
|
|
|
+ sge->offset + apply);
|
|
|
+ sge->length = apply;
|
|
|
+ bytes += apply;
|
|
|
+ apply = 0;
|
|
|
+ } else {
|
|
|
+ apply -= sge->length;
|
|
|
+ bytes += sge->length;
|
|
|
+ }
|
|
|
+
|
|
|
+ sk_msg_iter_var_next(i);
|
|
|
+ if (i == msg_opl->sg.end)
|
|
|
+ break;
|
|
|
+ sge = sk_msg_elem(msg_opl, i);
|
|
|
+ }
|
|
|
+
|
|
|
+ msg_opl->sg.end = i;
|
|
|
+ msg_opl->sg.curr = i;
|
|
|
+ msg_opl->sg.copybreak = 0;
|
|
|
+ msg_opl->apply_bytes = 0;
|
|
|
+ msg_opl->sg.size = bytes;
|
|
|
+
|
|
|
+ msg_npl = &new->msg_plaintext;
|
|
|
+ msg_npl->apply_bytes = apply;
|
|
|
+ msg_npl->sg.size = orig_size - bytes;
|
|
|
+
|
|
|
+ j = msg_npl->sg.start;
|
|
|
+ nsge = sk_msg_elem(msg_npl, j);
|
|
|
+ if (tmp.length) {
|
|
|
+ memcpy(nsge, &tmp, sizeof(*nsge));
|
|
|
+ sk_msg_iter_var_next(j);
|
|
|
+ nsge = sk_msg_elem(msg_npl, j);
|
|
|
+ }
|
|
|
+
|
|
|
+ osge = sk_msg_elem(msg_opl, i);
|
|
|
+ while (osge->length) {
|
|
|
+ memcpy(nsge, osge, sizeof(*nsge));
|
|
|
+ sg_unmark_end(nsge);
|
|
|
+ sk_msg_iter_var_next(i);
|
|
|
+ sk_msg_iter_var_next(j);
|
|
|
+ if (i == *orig_end)
|
|
|
+ break;
|
|
|
+ osge = sk_msg_elem(msg_opl, i);
|
|
|
+ nsge = sk_msg_elem(msg_npl, j);
|
|
|
+ }
|
|
|
+
|
|
|
+ msg_npl->sg.end = j;
|
|
|
+ msg_npl->sg.curr = j;
|
|
|
+ msg_npl->sg.copybreak = 0;
|
|
|
+
|
|
|
+ *to = new;
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
|
|
|
+ struct tls_rec *from, u32 orig_end)
|
|
|
+{
|
|
|
+ struct sk_msg *msg_npl = &from->msg_plaintext;
|
|
|
+ struct sk_msg *msg_opl = &to->msg_plaintext;
|
|
|
+ struct scatterlist *osge, *nsge;
|
|
|
+ u32 i, j;
|
|
|
+
|
|
|
+ i = msg_opl->sg.end;
|
|
|
+ sk_msg_iter_var_prev(i);
|
|
|
+ j = msg_npl->sg.start;
|
|
|
+
|
|
|
+ osge = sk_msg_elem(msg_opl, i);
|
|
|
+ nsge = sk_msg_elem(msg_npl, j);
|
|
|
+
|
|
|
+ if (sg_page(osge) == sg_page(nsge) &&
|
|
|
+ osge->offset + osge->length == nsge->offset) {
|
|
|
+ osge->length += nsge->length;
|
|
|
+ put_page(sg_page(nsge));
|
|
|
+ }
|
|
|
+
|
|
|
+ msg_opl->sg.end = orig_end;
|
|
|
+ msg_opl->sg.curr = orig_end;
|
|
|
+ msg_opl->sg.copybreak = 0;
|
|
|
+ msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
|
|
|
+ msg_opl->sg.size += msg_npl->sg.size;
|
|
|
+
|
|
|
+ sk_msg_free(sk, &to->msg_encrypted);
|
|
|
+ sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
|
|
|
+
|
|
|
+ kfree(from);
|
|
|
+}
|
|
|
+
|
|
|
static int tls_push_record(struct sock *sk, int flags,
|
|
|
unsigned char record_type)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
- struct tls_rec *rec = ctx->open_rec;
|
|
|
+ struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
|
|
|
+ u32 i, split_point, uninitialized_var(orig_end);
|
|
|
struct sk_msg *msg_pl, *msg_en;
|
|
|
struct aead_request *req;
|
|
|
+ bool split;
|
|
|
int rc;
|
|
|
- u32 i;
|
|
|
|
|
|
if (!rec)
|
|
|
return 0;
|
|
@@ -456,6 +613,18 @@ static int tls_push_record(struct sock *sk, int flags,
|
|
|
msg_pl = &rec->msg_plaintext;
|
|
|
msg_en = &rec->msg_encrypted;
|
|
|
|
|
|
+ split_point = msg_pl->apply_bytes;
|
|
|
+ split = split_point && split_point < msg_pl->sg.size;
|
|
|
+ if (split) {
|
|
|
+ rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
|
|
|
+ split_point, tls_ctx->tx.overhead_size,
|
|
|
+ &orig_end);
|
|
|
+ if (rc < 0)
|
|
|
+ return rc;
|
|
|
+ sk_msg_trim(sk, msg_en, msg_pl->sg.size +
|
|
|
+ tls_ctx->tx.overhead_size);
|
|
|
+ }
|
|
|
+
|
|
|
rec->tx_flags = flags;
|
|
|
req = &rec->aead_req;
|
|
|
|
|
@@ -487,57 +656,139 @@ static int tls_push_record(struct sock *sk, int flags,
|
|
|
|
|
|
rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i);
|
|
|
if (rc < 0) {
|
|
|
- if (rc != -EINPROGRESS)
|
|
|
+ if (rc != -EINPROGRESS) {
|
|
|
tls_err_abort(sk, EBADMSG);
|
|
|
+ if (split) {
|
|
|
+ tls_ctx->pending_open_record_frags = true;
|
|
|
+ tls_merge_open_record(sk, rec, tmp, orig_end);
|
|
|
+ }
|
|
|
+ }
|
|
|
return rc;
|
|
|
+ } else if (split) {
|
|
|
+ msg_pl = &tmp->msg_plaintext;
|
|
|
+ msg_en = &tmp->msg_encrypted;
|
|
|
+ sk_msg_trim(sk, msg_en, msg_pl->sg.size +
|
|
|
+ tls_ctx->tx.overhead_size);
|
|
|
+ tls_ctx->pending_open_record_frags = true;
|
|
|
+ ctx->open_rec = tmp;
|
|
|
}
|
|
|
|
|
|
return tls_tx_records(sk, flags);
|
|
|
}
|
|
|
|
|
|
-static int tls_sw_push_pending_record(struct sock *sk, int flags)
|
|
|
-{
|
|
|
- return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
|
|
|
-}
|
|
|
-
|
|
|
-static struct tls_rec *get_rec(struct sock *sk)
|
|
|
+static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
|
|
|
+ bool full_record, u8 record_type,
|
|
|
+ size_t *copied, int flags)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
- struct sk_msg *msg_pl, *msg_en;
|
|
|
+ struct sk_msg msg_redir = { };
|
|
|
+ struct sk_psock *psock;
|
|
|
+ struct sock *sk_redir;
|
|
|
struct tls_rec *rec;
|
|
|
- int mem_size;
|
|
|
+ int err = 0, send;
|
|
|
+ bool enospc;
|
|
|
+
|
|
|
+ psock = sk_psock_get(sk);
|
|
|
+ if (!psock)
|
|
|
+ return tls_push_record(sk, flags, record_type);
|
|
|
+more_data:
|
|
|
+ enospc = sk_msg_full(msg);
|
|
|
+ if (psock->eval == __SK_NONE)
|
|
|
+ psock->eval = sk_psock_msg_verdict(sk, psock, msg);
|
|
|
+ if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
|
|
|
+ !enospc && !full_record) {
|
|
|
+ err = -ENOSPC;
|
|
|
+ goto out_err;
|
|
|
+ }
|
|
|
+ msg->cork_bytes = 0;
|
|
|
+ send = msg->sg.size;
|
|
|
+ if (msg->apply_bytes && msg->apply_bytes < send)
|
|
|
+ send = msg->apply_bytes;
|
|
|
+
|
|
|
+ switch (psock->eval) {
|
|
|
+ case __SK_PASS:
|
|
|
+ err = tls_push_record(sk, flags, record_type);
|
|
|
+ if (err < 0) {
|
|
|
+ *copied -= sk_msg_free(sk, msg);
|
|
|
+ tls_free_open_rec(sk);
|
|
|
+ goto out_err;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ case __SK_REDIRECT:
|
|
|
+ sk_redir = psock->sk_redir;
|
|
|
+ memcpy(&msg_redir, msg, sizeof(*msg));
|
|
|
+ if (msg->apply_bytes < send)
|
|
|
+ msg->apply_bytes = 0;
|
|
|
+ else
|
|
|
+ msg->apply_bytes -= send;
|
|
|
+ sk_msg_return_zero(sk, msg, send);
|
|
|
+ msg->sg.size -= send;
|
|
|
+ release_sock(sk);
|
|
|
+ err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
|
|
|
+ lock_sock(sk);
|
|
|
+ if (err < 0) {
|
|
|
+ *copied -= sk_msg_free_nocharge(sk, &msg_redir);
|
|
|
+ msg->sg.size = 0;
|
|
|
+ }
|
|
|
+ if (msg->sg.size == 0)
|
|
|
+ tls_free_open_rec(sk);
|
|
|
+ break;
|
|
|
+ case __SK_DROP:
|
|
|
+ default:
|
|
|
+ sk_msg_free_partial(sk, msg, send);
|
|
|
+ if (msg->apply_bytes < send)
|
|
|
+ msg->apply_bytes = 0;
|
|
|
+ else
|
|
|
+ msg->apply_bytes -= send;
|
|
|
+ if (msg->sg.size == 0)
|
|
|
+ tls_free_open_rec(sk);
|
|
|
+ *copied -= send;
|
|
|
+ err = -EACCES;
|
|
|
+ }
|
|
|
|
|
|
- /* Return if we already have an open record */
|
|
|
- if (ctx->open_rec)
|
|
|
- return ctx->open_rec;
|
|
|
+ if (likely(!err)) {
|
|
|
+ bool reset_eval = !ctx->open_rec;
|
|
|
|
|
|
- mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
|
|
|
+ rec = ctx->open_rec;
|
|
|
+ if (rec) {
|
|
|
+ msg = &rec->msg_plaintext;
|
|
|
+ if (!msg->apply_bytes)
|
|
|
+ reset_eval = true;
|
|
|
+ }
|
|
|
+ if (reset_eval) {
|
|
|
+ psock->eval = __SK_NONE;
|
|
|
+ if (psock->sk_redir) {
|
|
|
+ sock_put(psock->sk_redir);
|
|
|
+ psock->sk_redir = NULL;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (rec)
|
|
|
+ goto more_data;
|
|
|
+ }
|
|
|
+ out_err:
|
|
|
+ sk_psock_put(sk, psock);
|
|
|
+ return err;
|
|
|
+}
|
|
|
+
|
|
|
+static int tls_sw_push_pending_record(struct sock *sk, int flags)
|
|
|
+{
|
|
|
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
+ struct tls_rec *rec = ctx->open_rec;
|
|
|
+ struct sk_msg *msg_pl;
|
|
|
+ size_t copied;
|
|
|
|
|
|
- rec = kzalloc(mem_size, sk->sk_allocation);
|
|
|
if (!rec)
|
|
|
- return NULL;
|
|
|
+ return 0;
|
|
|
|
|
|
msg_pl = &rec->msg_plaintext;
|
|
|
- msg_en = &rec->msg_encrypted;
|
|
|
-
|
|
|
- sk_msg_init(msg_pl);
|
|
|
- sk_msg_init(msg_en);
|
|
|
-
|
|
|
- sg_init_table(rec->sg_aead_in, 2);
|
|
|
- sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
|
|
|
- sizeof(rec->aad_space));
|
|
|
- sg_unmark_end(&rec->sg_aead_in[1]);
|
|
|
-
|
|
|
- sg_init_table(rec->sg_aead_out, 2);
|
|
|
- sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
|
|
|
- sizeof(rec->aad_space));
|
|
|
- sg_unmark_end(&rec->sg_aead_out[1]);
|
|
|
-
|
|
|
- ctx->open_rec = rec;
|
|
|
- rec->inplace_crypto = 1;
|
|
|
+ copied = msg_pl->sg.size;
|
|
|
+ if (!copied)
|
|
|
+ return 0;
|
|
|
|
|
|
- return rec;
|
|
|
+ return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
|
|
|
+ &copied, flags);
|
|
|
}
|
|
|
|
|
|
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
|
|
@@ -589,7 +840,10 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
|
|
|
goto send_end;
|
|
|
}
|
|
|
|
|
|
- rec = get_rec(sk);
|
|
|
+ if (ctx->open_rec)
|
|
|
+ rec = ctx->open_rec;
|
|
|
+ else
|
|
|
+ rec = ctx->open_rec = tls_get_rec(sk);
|
|
|
if (!rec) {
|
|
|
ret = -ENOMEM;
|
|
|
goto send_end;
|
|
@@ -628,6 +882,8 @@ alloc_encrypted:
|
|
|
}
|
|
|
|
|
|
if (!is_kvec && (full_record || eor) && !async_capable) {
|
|
|
+ u32 first = msg_pl->sg.end;
|
|
|
+
|
|
|
ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
|
|
|
msg_pl, try_to_copy);
|
|
|
if (ret)
|
|
@@ -637,15 +893,27 @@ alloc_encrypted:
|
|
|
|
|
|
num_zc++;
|
|
|
copied += try_to_copy;
|
|
|
- ret = tls_push_record(sk, msg->msg_flags, record_type);
|
|
|
+
|
|
|
+ sk_msg_sg_copy_set(msg_pl, first);
|
|
|
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
|
|
|
+ record_type, &copied,
|
|
|
+ msg->msg_flags);
|
|
|
if (ret) {
|
|
|
if (ret == -EINPROGRESS)
|
|
|
num_async++;
|
|
|
+ else if (ret == -ENOMEM)
|
|
|
+ goto wait_for_memory;
|
|
|
+ else if (ret == -ENOSPC)
|
|
|
+ goto rollback_iter;
|
|
|
else if (ret != -EAGAIN)
|
|
|
goto send_end;
|
|
|
}
|
|
|
continue;
|
|
|
-
|
|
|
+rollback_iter:
|
|
|
+ copied -= try_to_copy;
|
|
|
+ sk_msg_sg_copy_clear(msg_pl, first);
|
|
|
+ iov_iter_revert(&msg->msg_iter,
|
|
|
+ msg_pl->sg.size - orig_size);
|
|
|
fallback_to_reg_send:
|
|
|
sk_msg_trim(sk, msg_pl, orig_size);
|
|
|
}
|
|
@@ -678,12 +946,19 @@ fallback_to_reg_send:
|
|
|
tls_ctx->pending_open_record_frags = true;
|
|
|
copied += try_to_copy;
|
|
|
if (full_record || eor) {
|
|
|
- ret = tls_push_record(sk, msg->msg_flags, record_type);
|
|
|
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
|
|
|
+ record_type, &copied,
|
|
|
+ msg->msg_flags);
|
|
|
if (ret) {
|
|
|
if (ret == -EINPROGRESS)
|
|
|
num_async++;
|
|
|
- else if (ret != -EAGAIN)
|
|
|
+ else if (ret == -ENOMEM)
|
|
|
+ goto wait_for_memory;
|
|
|
+ else if (ret != -EAGAIN) {
|
|
|
+ if (ret == -ENOSPC)
|
|
|
+ ret = 0;
|
|
|
goto send_end;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -742,10 +1017,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
unsigned char record_type = TLS_RECORD_TYPE_DATA;
|
|
|
- size_t orig_size = size;
|
|
|
struct sk_msg *msg_pl;
|
|
|
struct tls_rec *rec;
|
|
|
int num_async = 0;
|
|
|
+ size_t copied = 0;
|
|
|
bool full_record;
|
|
|
int record_room;
|
|
|
int ret = 0;
|
|
@@ -778,7 +1053,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
|
|
|
goto sendpage_end;
|
|
|
}
|
|
|
|
|
|
- rec = get_rec(sk);
|
|
|
+ if (ctx->open_rec)
|
|
|
+ rec = ctx->open_rec;
|
|
|
+ else
|
|
|
+ rec = ctx->open_rec = tls_get_rec(sk);
|
|
|
if (!rec) {
|
|
|
ret = -ENOMEM;
|
|
|
goto sendpage_end;
|
|
@@ -788,6 +1066,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
|
|
|
|
|
|
full_record = false;
|
|
|
record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
|
|
|
+ copied = 0;
|
|
|
copy = size;
|
|
|
if (copy >= record_room) {
|
|
|
copy = record_room;
|
|
@@ -818,16 +1097,23 @@ alloc_payload:
|
|
|
|
|
|
offset += copy;
|
|
|
size -= copy;
|
|
|
+ copied += copy;
|
|
|
|
|
|
tls_ctx->pending_open_record_frags = true;
|
|
|
if (full_record || eor || sk_msg_full(msg_pl)) {
|
|
|
rec->inplace_crypto = 0;
|
|
|
- ret = tls_push_record(sk, flags, record_type);
|
|
|
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
|
|
|
+ record_type, &copied, flags);
|
|
|
if (ret) {
|
|
|
if (ret == -EINPROGRESS)
|
|
|
num_async++;
|
|
|
- else if (ret != -EAGAIN)
|
|
|
+ else if (ret == -ENOMEM)
|
|
|
+ goto wait_for_memory;
|
|
|
+ else if (ret != -EAGAIN) {
|
|
|
+ if (ret == -ENOSPC)
|
|
|
+ ret = 0;
|
|
|
goto sendpage_end;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
continue;
|
|
@@ -851,24 +1137,20 @@ wait_for_memory:
|
|
|
}
|
|
|
}
|
|
|
sendpage_end:
|
|
|
- if (orig_size > size)
|
|
|
- ret = orig_size - size;
|
|
|
- else
|
|
|
- ret = sk_stream_error(sk, flags, ret);
|
|
|
-
|
|
|
+ ret = sk_stream_error(sk, flags, ret);
|
|
|
release_sock(sk);
|
|
|
- return ret;
|
|
|
+ return copied ? copied : ret;
|
|
|
}
|
|
|
|
|
|
-static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
|
|
|
- long timeo, int *err)
|
|
|
+static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
|
|
|
+ int flags, long timeo, int *err)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
struct sk_buff *skb;
|
|
|
DEFINE_WAIT_FUNC(wait, woken_wake_function);
|
|
|
|
|
|
- while (!(skb = ctx->recv_pkt)) {
|
|
|
+ while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
|
|
|
if (sk->sk_err) {
|
|
|
*err = sock_error(sk);
|
|
|
return NULL;
|
|
@@ -887,7 +1169,10 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
|
|
|
|
|
|
add_wait_queue(sk_sleep(sk), &wait);
|
|
|
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
|
|
- sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
|
|
|
+ sk_wait_event(sk, &timeo,
|
|
|
+ ctx->recv_pkt != skb ||
|
|
|
+ !sk_psock_queue_empty(psock),
|
|
|
+ &wait);
|
|
|
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
|
|
remove_wait_queue(sk_sleep(sk), &wait);
|
|
|
|
|
@@ -1164,6 +1449,7 @@ int tls_sw_recvmsg(struct sock *sk,
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
+ struct sk_psock *psock;
|
|
|
unsigned char control;
|
|
|
struct strp_msg *rxm;
|
|
|
struct sk_buff *skb;
|
|
@@ -1179,6 +1465,7 @@ int tls_sw_recvmsg(struct sock *sk,
|
|
|
if (unlikely(flags & MSG_ERRQUEUE))
|
|
|
return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
|
|
|
|
|
|
+ psock = sk_psock_get(sk);
|
|
|
lock_sock(sk);
|
|
|
|
|
|
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
|
|
@@ -1188,9 +1475,19 @@ int tls_sw_recvmsg(struct sock *sk,
|
|
|
bool async = false;
|
|
|
int chunk = 0;
|
|
|
|
|
|
- skb = tls_wait_data(sk, flags, timeo, &err);
|
|
|
- if (!skb)
|
|
|
+ skb = tls_wait_data(sk, psock, flags, timeo, &err);
|
|
|
+ if (!skb) {
|
|
|
+ if (psock) {
|
|
|
+ int ret = __tcp_bpf_recvmsg(sk, psock, msg, len);
|
|
|
+
|
|
|
+ if (ret > 0) {
|
|
|
+ copied += ret;
|
|
|
+ len -= ret;
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ }
|
|
|
goto recv_end;
|
|
|
+ }
|
|
|
|
|
|
rxm = strp_msg(skb);
|
|
|
|
|
@@ -1296,6 +1593,8 @@ recv_end:
|
|
|
}
|
|
|
|
|
|
release_sock(sk);
|
|
|
+ if (psock)
|
|
|
+ sk_psock_put(sk, psock);
|
|
|
return copied ? : err;
|
|
|
}
|
|
|
|
|
@@ -1318,7 +1617,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
|
|
|
|
|
|
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
|
|
|
|
|
|
- skb = tls_wait_data(sk, flags, timeo, &err);
|
|
|
+ skb = tls_wait_data(sk, NULL, flags, timeo, &err);
|
|
|
if (!skb)
|
|
|
goto splice_read_end;
|
|
|
|
|
@@ -1356,11 +1655,16 @@ bool tls_sw_stream_read(const struct sock *sk)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
+ bool ingress_empty = true;
|
|
|
+ struct sk_psock *psock;
|
|
|
|
|
|
- if (ctx->recv_pkt)
|
|
|
- return true;
|
|
|
+ rcu_read_lock();
|
|
|
+ psock = sk_psock(sk);
|
|
|
+ if (psock)
|
|
|
+ ingress_empty = list_empty(&psock->ingress_msg);
|
|
|
+ rcu_read_unlock();
|
|
|
|
|
|
- return false;
|
|
|
+ return !ingress_empty || ctx->recv_pkt;
|
|
|
}
|
|
|
|
|
|
static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
|
|
@@ -1439,8 +1743,15 @@ static void tls_data_ready(struct sock *sk)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
+ struct sk_psock *psock;
|
|
|
|
|
|
strp_data_ready(&ctx->strp);
|
|
|
+
|
|
|
+ psock = sk_psock_get(sk);
|
|
|
+ if (psock && !list_empty(&psock->ingress_msg)) {
|
|
|
+ ctx->saved_data_ready(sk);
|
|
|
+ sk_psock_put(sk, psock);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
void tls_sw_free_resources_tx(struct sock *sk)
|