Эх сурвалжийг харах

Merge branch 'tls-misc-fixes'

Ilya Lesokhin says:

====================
tls: Miscellaneous fixes

Here's a set of miscellaneous fix patches.

Patch 1 makes sure aead_request is initailized properly.
Patches 2-3 Fix a memory leak we've encountered.
patch 4 moves tls_make_aad to allow sharing it in the future.
Patch 5 fixes a TOCTOU issue reported here:
https://www.spinics.net/lists/kernel/msg2608603.html
Patch 6 Avoids callback overriding when tls_set_sw_offload fails.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
David S. Miller 7 жил өмнө
parent
commit
951b796695
3 өөрчлөгдсөн 79 нэмэгдсэн , 60 устгасан
  1. 18 1
      include/net/tls.h
  2. 57 39
      net/tls/tls_main.c
  3. 4 20
      net/tls/tls_sw.c

+ 18 - 1
include/net/tls.h

@@ -83,6 +83,8 @@ struct tls_context {
 
 	void *priv_ctx;
 
+	u8 tx_conf:2;
+
 	u16 prepend_size;
 	u16 tag_size;
 	u16 overhead_size;
@@ -97,7 +99,6 @@ struct tls_context {
 
 	u16 pending_open_record_frags;
 	int (*push_pending_record)(struct sock *sk, int flags);
-	void (*free_resources)(struct sock *sk);
 
 	void (*sk_write_space)(struct sock *sk);
 	void (*sk_proto_close)(struct sock *sk, long timeout);
@@ -122,6 +123,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_sw_sendpage(struct sock *sk, struct page *page,
 		    int offset, size_t size, int flags);
 void tls_sw_close(struct sock *sk, long timeout);
+void tls_sw_free_tx_resources(struct sock *sk);
 
 void tls_sk_destruct(struct sock *sk, struct tls_context *ctx);
 void tls_icsk_clean_acked(struct sock *sk);
@@ -212,6 +214,21 @@ static inline void tls_fill_prepend(struct tls_context *ctx,
 	       ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv_size);
 }
 
+static inline void tls_make_aad(char *buf,
+				size_t size,
+				char *record_sequence,
+				int record_sequence_size,
+				unsigned char record_type)
+{
+	memcpy(buf, record_sequence, record_sequence_size);
+
+	buf[8] = record_type;
+	buf[9] = TLS_1_2_VERSION_MAJOR;
+	buf[10] = TLS_1_2_VERSION_MINOR;
+	buf[11] = size >> 8;
+	buf[12] = size & 0xFF;
+}
+
 static inline struct tls_context *tls_get_ctx(const struct sock *sk)
 {
 	struct inet_connection_sock *icsk = inet_csk(sk);

+ 57 - 39
net/tls/tls_main.c

@@ -45,8 +45,18 @@ MODULE_AUTHOR("Mellanox Technologies");
 MODULE_DESCRIPTION("Transport Layer Security Support");
 MODULE_LICENSE("Dual BSD/GPL");
 
-static struct proto tls_base_prot;
-static struct proto tls_sw_prot;
+enum {
+	TLS_BASE_TX,
+	TLS_SW_TX,
+	TLS_NUM_CONFIG,
+};
+
+static struct proto tls_prots[TLS_NUM_CONFIG];
+
+static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx)
+{
+	sk->sk_prot = &tls_prots[ctx->tx_conf];
+}
 
 int wait_on_pending_writer(struct sock *sk, long *timeo)
 {
@@ -216,6 +226,12 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 	void (*sk_proto_close)(struct sock *sk, long timeout);
 
 	lock_sock(sk);
+	sk_proto_close = ctx->sk_proto_close;
+
+	if (ctx->tx_conf == TLS_BASE_TX) {
+		kfree(ctx);
+		goto skip_tx_cleanup;
+	}
 
 	if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
 		tls_handle_open_record(sk, 0);
@@ -232,13 +248,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 			sg++;
 		}
 	}
-	ctx->free_resources(sk);
+
 	kfree(ctx->rec_seq);
 	kfree(ctx->iv);
 
-	sk_proto_close = ctx->sk_proto_close;
-	kfree(ctx);
+	if (ctx->tx_conf == TLS_SW_TX)
+		tls_sw_free_tx_resources(sk);
 
+skip_tx_cleanup:
 	release_sock(sk);
 	sk_proto_close(sk, timeout);
 }
@@ -338,46 +355,41 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
 static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
 				unsigned int optlen)
 {
-	struct tls_crypto_info *crypto_info, tmp_crypto_info;
+	struct tls_crypto_info *crypto_info;
 	struct tls_context *ctx = tls_get_ctx(sk);
-	struct proto *prot = NULL;
 	int rc = 0;
+	int tx_conf;
 
 	if (!optval || (optlen < sizeof(*crypto_info))) {
 		rc = -EINVAL;
 		goto out;
 	}
 
-	rc = copy_from_user(&tmp_crypto_info, optval, sizeof(*crypto_info));
+	crypto_info = &ctx->crypto_send;
+	/* Currently we don't support set crypto info more than one time */
+	if (TLS_CRYPTO_INFO_READY(crypto_info))
+		goto out;
+
+	rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info));
 	if (rc) {
 		rc = -EFAULT;
 		goto out;
 	}
 
 	/* check version */
-	if (tmp_crypto_info.version != TLS_1_2_VERSION) {
+	if (crypto_info->version != TLS_1_2_VERSION) {
 		rc = -ENOTSUPP;
-		goto out;
+		goto err_crypto_info;
 	}
 
-	/* get user crypto info */
-	crypto_info = &ctx->crypto_send;
-
-	/* Currently we don't support set crypto info more than one time */
-	if (TLS_CRYPTO_INFO_READY(crypto_info))
-		goto out;
-
-	switch (tmp_crypto_info.cipher_type) {
+	switch (crypto_info->cipher_type) {
 	case TLS_CIPHER_AES_GCM_128: {
 		if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) {
 			rc = -EINVAL;
 			goto out;
 		}
-		rc = copy_from_user(
-		  crypto_info,
-		  optval,
-		  sizeof(struct tls12_crypto_info_aes_gcm_128));
-
+		rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
+				    optlen - sizeof(*crypto_info));
 		if (rc) {
 			rc = -EFAULT;
 			goto err_crypto_info;
@@ -389,18 +401,16 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
 		goto out;
 	}
 
-	ctx->sk_write_space = sk->sk_write_space;
-	sk->sk_write_space = tls_write_space;
-
-	ctx->sk_proto_close = sk->sk_prot->close;
-
 	/* currently SW is default, we will have ethtool in future */
 	rc = tls_set_sw_offload(sk, ctx);
-	prot = &tls_sw_prot;
+	tx_conf = TLS_SW_TX;
 	if (rc)
 		goto err_crypto_info;
 
-	sk->sk_prot = prot;
+	ctx->tx_conf = tx_conf;
+	update_sk_prot(sk, ctx);
+	ctx->sk_write_space = sk->sk_write_space;
+	sk->sk_write_space = tls_write_space;
 	goto out;
 
 err_crypto_info:
@@ -453,7 +463,10 @@ static int tls_init(struct sock *sk)
 	icsk->icsk_ulp_data = ctx;
 	ctx->setsockopt = sk->sk_prot->setsockopt;
 	ctx->getsockopt = sk->sk_prot->getsockopt;
-	sk->sk_prot = &tls_base_prot;
+	ctx->sk_proto_close = sk->sk_prot->close;
+
+	ctx->tx_conf = TLS_BASE_TX;
+	update_sk_prot(sk, ctx);
 out:
 	return rc;
 }
@@ -464,16 +477,21 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 	.init			= tls_init,
 };
 
+static void build_protos(struct proto *prot, struct proto *base)
+{
+	prot[TLS_BASE_TX] = *base;
+	prot[TLS_BASE_TX].setsockopt	= tls_setsockopt;
+	prot[TLS_BASE_TX].getsockopt	= tls_getsockopt;
+	prot[TLS_BASE_TX].close		= tls_sk_proto_close;
+
+	prot[TLS_SW_TX] = prot[TLS_BASE_TX];
+	prot[TLS_SW_TX].sendmsg		= tls_sw_sendmsg;
+	prot[TLS_SW_TX].sendpage	= tls_sw_sendpage;
+}
+
 static int __init tls_register(void)
 {
-	tls_base_prot			= tcp_prot;
-	tls_base_prot.setsockopt	= tls_setsockopt;
-	tls_base_prot.getsockopt	= tls_getsockopt;
-
-	tls_sw_prot			= tls_base_prot;
-	tls_sw_prot.sendmsg		= tls_sw_sendmsg;
-	tls_sw_prot.sendpage            = tls_sw_sendpage;
-	tls_sw_prot.close               = tls_sk_proto_close;
+	build_protos(tls_prots, &tcp_prot);
 
 	tcp_register_ulp(&tcp_tls_ulp_ops);
 

+ 4 - 20
net/tls/tls_sw.c

@@ -39,22 +39,6 @@
 
 #include <net/tls.h>
 
-static inline void tls_make_aad(int recv,
-				char *buf,
-				size_t size,
-				char *record_sequence,
-				int record_sequence_size,
-				unsigned char record_type)
-{
-	memcpy(buf, record_sequence, record_sequence_size);
-
-	buf[8] = record_type;
-	buf[9] = TLS_1_2_VERSION_MAJOR;
-	buf[10] = TLS_1_2_VERSION_MINOR;
-	buf[11] = size >> 8;
-	buf[12] = size & 0xFF;
-}
-
 static void trim_sg(struct sock *sk, struct scatterlist *sg,
 		    int *sg_num_elem, unsigned int *sg_size, int target_size)
 {
@@ -219,7 +203,7 @@ static int tls_do_encryption(struct tls_context *tls_ctx,
 	struct aead_request *aead_req;
 	int rc;
 
-	aead_req = kmalloc(req_size, flags);
+	aead_req = kzalloc(req_size, flags);
 	if (!aead_req)
 		return -ENOMEM;
 
@@ -249,7 +233,7 @@ static int tls_push_record(struct sock *sk, int flags,
 	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
 	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
 
-	tls_make_aad(0, ctx->aad_space, ctx->sg_plaintext_size,
+	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
 		     tls_ctx->rec_seq, tls_ctx->rec_seq_size,
 		     record_type);
 
@@ -639,7 +623,7 @@ sendpage_end:
 	return ret;
 }
 
-static void tls_sw_free_resources(struct sock *sk)
+void tls_sw_free_tx_resources(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
@@ -650,6 +634,7 @@ static void tls_sw_free_resources(struct sock *sk)
 	tls_free_both_sg(sk);
 
 	kfree(ctx);
+	kfree(tls_ctx);
 }
 
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
@@ -679,7 +664,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 	}
 
 	ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
-	ctx->free_resources = tls_sw_free_resources;
 
 	crypto_info = &ctx->crypto_send;
 	switch (crypto_info->cipher_type) {