|
@@ -52,7 +52,7 @@ static int tls_do_decryption(struct sock *sk,
|
|
|
gfp_t flags)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
struct strp_msg *rxm = strp_msg(skb);
|
|
|
struct aead_request *aead_req;
|
|
|
|
|
@@ -122,7 +122,7 @@ out:
|
|
|
static void trim_both_sgl(struct sock *sk, int target_size)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
|
|
|
trim_sg(sk, ctx->sg_plaintext_data,
|
|
|
&ctx->sg_plaintext_num_elem,
|
|
@@ -141,7 +141,7 @@ static void trim_both_sgl(struct sock *sk, int target_size)
|
|
|
static int alloc_encrypted_sg(struct sock *sk, int len)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
int rc = 0;
|
|
|
|
|
|
rc = sk_alloc_sg(sk, len,
|
|
@@ -155,7 +155,7 @@ static int alloc_encrypted_sg(struct sock *sk, int len)
|
|
|
static int alloc_plaintext_sg(struct sock *sk, int len)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
int rc = 0;
|
|
|
|
|
|
rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
|
|
@@ -181,7 +181,7 @@ static void free_sg(struct sock *sk, struct scatterlist *sg,
|
|
|
static void tls_free_both_sg(struct sock *sk)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
|
|
|
free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
|
|
|
&ctx->sg_encrypted_size);
|
|
@@ -191,7 +191,7 @@ static void tls_free_both_sg(struct sock *sk)
|
|
|
}
|
|
|
|
|
|
static int tls_do_encryption(struct tls_context *tls_ctx,
|
|
|
- struct tls_sw_context *ctx, size_t data_len,
|
|
|
+ struct tls_sw_context_tx *ctx, size_t data_len,
|
|
|
gfp_t flags)
|
|
|
{
|
|
|
unsigned int req_size = sizeof(struct aead_request) +
|
|
@@ -227,7 +227,7 @@ static int tls_push_record(struct sock *sk, int flags,
|
|
|
unsigned char record_type)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
int rc;
|
|
|
|
|
|
sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
|
|
@@ -339,7 +339,7 @@ static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
|
|
|
int bytes)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
struct scatterlist *sg = ctx->sg_plaintext_data;
|
|
|
int copy, i, rc = 0;
|
|
|
|
|
@@ -367,7 +367,7 @@ out:
|
|
|
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
int ret = 0;
|
|
|
int required_size;
|
|
|
long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
|
|
@@ -522,7 +522,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
|
|
|
int offset, size_t size, int flags)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
int ret = 0;
|
|
|
long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
|
|
|
bool eor;
|
|
@@ -636,7 +636,7 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
|
|
|
long timeo, int *err)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
struct sk_buff *skb;
|
|
|
DEFINE_WAIT_FUNC(wait, woken_wake_function);
|
|
|
|
|
@@ -674,7 +674,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
|
|
|
struct scatterlist *sgout)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
|
|
|
struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
|
|
|
struct scatterlist *sgin = &sgin_arr[0];
|
|
@@ -723,7 +723,7 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
|
|
|
unsigned int len)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
struct strp_msg *rxm = strp_msg(skb);
|
|
|
|
|
|
if (len < rxm->full_len) {
|
|
@@ -749,7 +749,7 @@ int tls_sw_recvmsg(struct sock *sk,
|
|
|
int *addr_len)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
unsigned char control;
|
|
|
struct strp_msg *rxm;
|
|
|
struct sk_buff *skb;
|
|
@@ -869,7 +869,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
|
|
|
size_t len, unsigned int flags)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
struct strp_msg *rxm = NULL;
|
|
|
struct sock *sk = sock->sk;
|
|
|
struct sk_buff *skb;
|
|
@@ -922,7 +922,7 @@ unsigned int tls_sw_poll(struct file *file, struct socket *sock,
|
|
|
unsigned int ret;
|
|
|
struct sock *sk = sock->sk;
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
|
|
|
/* Grab POLLOUT and POLLHUP from the underlying socket */
|
|
|
ret = ctx->sk_poll(file, sock, wait);
|
|
@@ -938,7 +938,7 @@ unsigned int tls_sw_poll(struct file *file, struct socket *sock,
|
|
|
static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
char header[tls_ctx->rx.prepend_size];
|
|
|
struct strp_msg *rxm = strp_msg(skb);
|
|
|
size_t cipher_overhead;
|
|
@@ -987,7 +987,7 @@ read_failure:
|
|
|
static void tls_queue(struct strparser *strp, struct sk_buff *skb)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
struct strp_msg *rxm;
|
|
|
|
|
|
rxm = strp_msg(skb);
|
|
@@ -1003,18 +1003,28 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb)
|
|
|
static void tls_data_ready(struct sock *sk)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
|
|
|
strp_data_ready(&ctx->strp);
|
|
|
}
|
|
|
|
|
|
-void tls_sw_free_resources(struct sock *sk)
|
|
|
+void tls_sw_free_resources_tx(struct sock *sk)
|
|
|
{
|
|
|
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
- struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
|
|
|
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
|
|
|
|
|
|
if (ctx->aead_send)
|
|
|
crypto_free_aead(ctx->aead_send);
|
|
|
+ tls_free_both_sg(sk);
|
|
|
+
|
|
|
+ kfree(ctx);
|
|
|
+}
|
|
|
+
|
|
|
+void tls_sw_free_resources_rx(struct sock *sk)
|
|
|
+{
|
|
|
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
|
|
|
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
|
|
+
|
|
|
if (ctx->aead_recv) {
|
|
|
if (ctx->recv_pkt) {
|
|
|
kfree_skb(ctx->recv_pkt);
|
|
@@ -1030,10 +1040,7 @@ void tls_sw_free_resources(struct sock *sk)
|
|
|
lock_sock(sk);
|
|
|
}
|
|
|
|
|
|
- tls_free_both_sg(sk);
|
|
|
-
|
|
|
kfree(ctx);
|
|
|
- kfree(tls_ctx);
|
|
|
}
|
|
|
|
|
|
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
|
|
@@ -1041,7 +1048,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
|
|
|
char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
|
|
|
struct tls_crypto_info *crypto_info;
|
|
|
struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
|
|
|
- struct tls_sw_context *sw_ctx;
|
|
|
+ struct tls_sw_context_tx *sw_ctx_tx = NULL;
|
|
|
+ struct tls_sw_context_rx *sw_ctx_rx = NULL;
|
|
|
struct cipher_context *cctx;
|
|
|
struct crypto_aead **aead;
|
|
|
struct strp_callbacks cb;
|
|
@@ -1054,27 +1062,32 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
|
|
|
goto out;
|
|
|
}
|
|
|
|
|
|
- if (!ctx->priv_ctx) {
|
|
|
- sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
|
|
|
- if (!sw_ctx) {
|
|
|
+ if (tx) {
|
|
|
+ sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
|
|
|
+ if (!sw_ctx_tx) {
|
|
|
rc = -ENOMEM;
|
|
|
goto out;
|
|
|
}
|
|
|
- crypto_init_wait(&sw_ctx->async_wait);
|
|
|
+ crypto_init_wait(&sw_ctx_tx->async_wait);
|
|
|
+ ctx->priv_ctx_tx = sw_ctx_tx;
|
|
|
} else {
|
|
|
- sw_ctx = ctx->priv_ctx;
|
|
|
+ sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
|
|
|
+ if (!sw_ctx_rx) {
|
|
|
+ rc = -ENOMEM;
|
|
|
+ goto out;
|
|
|
+ }
|
|
|
+ crypto_init_wait(&sw_ctx_rx->async_wait);
|
|
|
+ ctx->priv_ctx_rx = sw_ctx_rx;
|
|
|
}
|
|
|
|
|
|
- ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
|
|
|
-
|
|
|
if (tx) {
|
|
|
crypto_info = &ctx->crypto_send;
|
|
|
cctx = &ctx->tx;
|
|
|
- aead = &sw_ctx->aead_send;
|
|
|
+ aead = &sw_ctx_tx->aead_send;
|
|
|
} else {
|
|
|
crypto_info = &ctx->crypto_recv;
|
|
|
cctx = &ctx->rx;
|
|
|
- aead = &sw_ctx->aead_recv;
|
|
|
+ aead = &sw_ctx_rx->aead_recv;
|
|
|
}
|
|
|
|
|
|
switch (crypto_info->cipher_type) {
|
|
@@ -1121,22 +1134,24 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
|
|
|
}
|
|
|
memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
|
|
|
|
|
|
- if (tx) {
|
|
|
- sg_init_table(sw_ctx->sg_encrypted_data,
|
|
|
- ARRAY_SIZE(sw_ctx->sg_encrypted_data));
|
|
|
- sg_init_table(sw_ctx->sg_plaintext_data,
|
|
|
- ARRAY_SIZE(sw_ctx->sg_plaintext_data));
|
|
|
-
|
|
|
- sg_init_table(sw_ctx->sg_aead_in, 2);
|
|
|
- sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
|
|
|
- sizeof(sw_ctx->aad_space));
|
|
|
- sg_unmark_end(&sw_ctx->sg_aead_in[1]);
|
|
|
- sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
|
|
|
- sg_init_table(sw_ctx->sg_aead_out, 2);
|
|
|
- sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
|
|
|
- sizeof(sw_ctx->aad_space));
|
|
|
- sg_unmark_end(&sw_ctx->sg_aead_out[1]);
|
|
|
- sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
|
|
|
+ if (sw_ctx_tx) {
|
|
|
+ sg_init_table(sw_ctx_tx->sg_encrypted_data,
|
|
|
+ ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
|
|
|
+ sg_init_table(sw_ctx_tx->sg_plaintext_data,
|
|
|
+ ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
|
|
|
+
|
|
|
+ sg_init_table(sw_ctx_tx->sg_aead_in, 2);
|
|
|
+ sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
|
|
|
+ sizeof(sw_ctx_tx->aad_space));
|
|
|
+ sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
|
|
|
+ sg_chain(sw_ctx_tx->sg_aead_in, 2,
|
|
|
+ sw_ctx_tx->sg_plaintext_data);
|
|
|
+ sg_init_table(sw_ctx_tx->sg_aead_out, 2);
|
|
|
+ sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
|
|
|
+ sizeof(sw_ctx_tx->aad_space));
|
|
|
+ sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
|
|
|
+ sg_chain(sw_ctx_tx->sg_aead_out, 2,
|
|
|
+ sw_ctx_tx->sg_encrypted_data);
|
|
|
}
|
|
|
|
|
|
if (!*aead) {
|
|
@@ -1161,22 +1176,22 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
|
|
|
if (rc)
|
|
|
goto free_aead;
|
|
|
|
|
|
- if (!tx) {
|
|
|
+ if (sw_ctx_rx) {
|
|
|
/* Set up strparser */
|
|
|
memset(&cb, 0, sizeof(cb));
|
|
|
cb.rcv_msg = tls_queue;
|
|
|
cb.parse_msg = tls_read_size;
|
|
|
|
|
|
- strp_init(&sw_ctx->strp, sk, &cb);
|
|
|
+ strp_init(&sw_ctx_rx->strp, sk, &cb);
|
|
|
|
|
|
write_lock_bh(&sk->sk_callback_lock);
|
|
|
- sw_ctx->saved_data_ready = sk->sk_data_ready;
|
|
|
+ sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
|
|
|
sk->sk_data_ready = tls_data_ready;
|
|
|
write_unlock_bh(&sk->sk_callback_lock);
|
|
|
|
|
|
- sw_ctx->sk_poll = sk->sk_socket->ops->poll;
|
|
|
+ sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
|
|
|
|
|
|
- strp_check_rcv(&sw_ctx->strp);
|
|
|
+ strp_check_rcv(&sw_ctx_rx->strp);
|
|
|
}
|
|
|
|
|
|
goto out;
|
|
@@ -1188,11 +1203,16 @@ free_rec_seq:
|
|
|
kfree(cctx->rec_seq);
|
|
|
cctx->rec_seq = NULL;
|
|
|
free_iv:
|
|
|
- kfree(ctx->tx.iv);
|
|
|
- ctx->tx.iv = NULL;
|
|
|
+ kfree(cctx->iv);
|
|
|
+ cctx->iv = NULL;
|
|
|
free_priv:
|
|
|
- kfree(ctx->priv_ctx);
|
|
|
- ctx->priv_ctx = NULL;
|
|
|
+ if (tx) {
|
|
|
+ kfree(ctx->priv_ctx_tx);
|
|
|
+ ctx->priv_ctx_tx = NULL;
|
|
|
+ } else {
|
|
|
+ kfree(ctx->priv_ctx_rx);
|
|
|
+ ctx->priv_ctx_rx = NULL;
|
|
|
+ }
|
|
|
out:
|
|
|
return rc;
|
|
|
}
|