diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls.h | 11 | ||||
-rw-r--r-- | net/tls/tls_device.c | 101 | ||||
-rw-r--r-- | net/tls/tls_device_fallback.c | 23 | ||||
-rw-r--r-- | net/tls/tls_main.c | 62 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 194 |
5 files changed, 184 insertions, 207 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h index 28a8c0e80e3c..762f424ff2d5 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -127,7 +127,7 @@ struct tls_rec { struct sock *sk; char aad_space[TLS_AAD_SPACE_SIZE]; - u8 iv_data[MAX_IV_SIZE]; + u8 iv_data[TLS_MAX_IV_SIZE]; struct aead_request aead_req; u8 aead_req_ctx[]; }; @@ -142,7 +142,10 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx); int wait_on_pending_writer(struct sock *sk, long *timeo); void tls_err_abort(struct sock *sk, int err); -int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx); +int init_prot_info(struct tls_prot_info *prot, + const struct tls_crypto_info *crypto_info, + const struct tls_cipher_desc *cipher_desc); +int tls_set_sw_offload(struct sock *sk, int tx); void tls_update_rx_zc_capable(struct tls_context *tls_ctx); void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx); void tls_sw_strparser_done(struct tls_context *tls_ctx); @@ -223,7 +226,7 @@ static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx) #ifdef CONFIG_TLS_DEVICE int tls_device_init(void); void tls_device_cleanup(void); -int tls_set_device_offload(struct sock *sk, struct tls_context *ctx); +int tls_set_device_offload(struct sock *sk); void tls_device_free_resources_tx(struct sock *sk); int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx); void tls_device_offload_cleanup_rx(struct sock *sk); @@ -234,7 +237,7 @@ static inline int tls_device_init(void) { return 0; } static inline void tls_device_cleanup(void) {} static inline int -tls_set_device_offload(struct sock *sk, struct tls_context *ctx) +tls_set_device_offload(struct sock *sk) { return -EOPNOTSUPP; } diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 8c94c926606a..bf8ed36b1ad6 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -56,11 +56,8 @@ static struct page *dummy_page; static void tls_device_free_ctx(struct tls_context *ctx) { - if (ctx->tx_conf == TLS_HW) { + if (ctx->tx_conf == TLS_HW) kfree(tls_offload_ctx_tx(ctx)); - kfree(ctx->tx.rec_seq); - kfree(ctx->tx.iv); - } if (ctx->rx_conf == TLS_HW) kfree(tls_offload_ctx_rx(ctx)); @@ -891,14 +888,8 @@ tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx) struct strp_msg *rxm; char *orig_buf, *buf; - switch (tls_ctx->crypto_recv.info.cipher_type) { - case TLS_CIPHER_AES_GCM_128: - case TLS_CIPHER_AES_GCM_256: - break; - default: - return -EINVAL; - } cipher_desc = get_cipher_desc(tls_ctx->crypto_recv.info.cipher_type); + DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); rxm = strp_msg(tls_strp_msg(sw_ctx)); orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv, @@ -1042,22 +1033,45 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk, } } -int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) +static struct tls_offload_context_tx *alloc_offload_ctx_tx(struct tls_context *ctx) +{ + struct tls_offload_context_tx *offload_ctx; + __be64 rcd_sn; + + offload_ctx = kzalloc(sizeof(*offload_ctx), GFP_KERNEL); + if (!offload_ctx) + return NULL; + + INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task); + INIT_LIST_HEAD(&offload_ctx->records_list); + spin_lock_init(&offload_ctx->lock); + sg_init_table(offload_ctx->sg_tx_data, + ARRAY_SIZE(offload_ctx->sg_tx_data)); + + /* start at rec_seq - 1 to account for the start marker record */ + memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn)); + offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1; + + offload_ctx->ctx = ctx; + + return offload_ctx; +} + +int tls_set_device_offload(struct sock *sk) { - struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_prot_info *prot = &tls_ctx->prot_info; - const struct tls_cipher_desc *cipher_desc; struct tls_record_info *start_marker_record; struct tls_offload_context_tx *offload_ctx; + const struct tls_cipher_desc *cipher_desc; struct tls_crypto_info *crypto_info; + struct tls_prot_info *prot; struct net_device *netdev; - char *iv, *rec_seq; + struct tls_context *ctx; struct sk_buff *skb; - __be64 rcd_sn; + char *iv, *rec_seq; int rc; - if (!ctx) - return -EINVAL; + ctx = tls_get_ctx(sk); + prot = &ctx->prot_info; if (ctx->priv_ctx_tx) return -EEXIST; @@ -1085,38 +1099,23 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) goto release_netdev; } + rc = init_prot_info(prot, crypto_info, cipher_desc); + if (rc) + goto release_netdev; + iv = crypto_info_iv(crypto_info, cipher_desc); rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc); - prot->version = crypto_info->version; - prot->cipher_type = crypto_info->cipher_type; - prot->prepend_size = TLS_HEADER_SIZE + cipher_desc->iv; - prot->tag_size = cipher_desc->tag; - prot->overhead_size = prot->prepend_size + prot->tag_size; - prot->iv_size = cipher_desc->iv; - prot->salt_size = cipher_desc->salt; - ctx->tx.iv = kmalloc(cipher_desc->iv + cipher_desc->salt, GFP_KERNEL); - if (!ctx->tx.iv) { - rc = -ENOMEM; - goto release_netdev; - } - memcpy(ctx->tx.iv + cipher_desc->salt, iv, cipher_desc->iv); - - prot->rec_seq_size = cipher_desc->rec_seq; - ctx->tx.rec_seq = kmemdup(rec_seq, cipher_desc->rec_seq, GFP_KERNEL); - if (!ctx->tx.rec_seq) { - rc = -ENOMEM; - goto free_iv; - } + memcpy(ctx->tx.rec_seq, rec_seq, cipher_desc->rec_seq); start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL); if (!start_marker_record) { rc = -ENOMEM; - goto free_rec_seq; + goto release_netdev; } - offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL); + offload_ctx = alloc_offload_ctx_tx(ctx); if (!offload_ctx) { rc = -ENOMEM; goto free_marker_record; @@ -1126,22 +1125,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) if (rc) goto free_offload_ctx; - /* start at rec_seq - 1 to account for the start marker record */ - memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn)); - offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1; - start_marker_record->end_seq = tcp_sk(sk)->write_seq; start_marker_record->len = 0; start_marker_record->num_frags = 0; - - INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task); - offload_ctx->ctx = ctx; - - INIT_LIST_HEAD(&offload_ctx->records_list); list_add_tail(&start_marker_record->list, &offload_ctx->records_list); - spin_lock_init(&offload_ctx->lock); - sg_init_table(offload_ctx->sg_tx_data, - ARRAY_SIZE(offload_ctx->sg_tx_data)); clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked); ctx->push_pending_record = tls_device_push_pending_record; @@ -1198,10 +1185,6 @@ free_offload_ctx: ctx->priv_ctx_tx = NULL; free_marker_record: kfree(start_marker_record); -free_rec_seq: - kfree(ctx->tx.rec_seq); -free_iv: - kfree(ctx->tx.iv); release_netdev: dev_put(netdev); return rc; @@ -1242,7 +1225,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) goto release_lock; } - context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL); + context = kzalloc(sizeof(*context), GFP_KERNEL); if (!context) { rc = -ENOMEM; goto release_lock; @@ -1250,7 +1233,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) context->resync_nh_reset = 1; ctx->priv_ctx_rx = context; - rc = tls_set_sw_offload(sk, ctx, 0); + rc = tls_set_sw_offload(sk, 0); if (rc) goto release_ctx; diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c index 1d743f310f4f..4e7228f275fa 100644 --- a/net/tls/tls_device_fallback.c +++ b/net/tls/tls_device_fallback.c @@ -54,7 +54,7 @@ static int tls_enc_record(struct aead_request *aead_req, struct scatter_walk *out, int *in_len, struct tls_prot_info *prot) { - unsigned char buf[TLS_HEADER_SIZE + MAX_IV_SIZE]; + unsigned char buf[TLS_HEADER_SIZE + TLS_MAX_IV_SIZE]; const struct tls_cipher_desc *cipher_desc; struct scatterlist sg_in[3]; struct scatterlist sg_out[3]; @@ -62,14 +62,8 @@ static int tls_enc_record(struct aead_request *aead_req, u16 len; int rc; - switch (prot->cipher_type) { - case TLS_CIPHER_AES_GCM_128: - case TLS_CIPHER_AES_GCM_256: - break; - default: - return -EINVAL; - } cipher_desc = get_cipher_desc(prot->cipher_type); + DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); buf_size = TLS_HEADER_SIZE + cipher_desc->iv; len = min_t(int, *in_len, buf_size); @@ -338,17 +332,9 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, if (!aead_req) return NULL; - switch (tls_ctx->crypto_send.info.cipher_type) { - case TLS_CIPHER_AES_GCM_128: - salt = tls_ctx->crypto_send.aes_gcm_128.salt; - break; - case TLS_CIPHER_AES_GCM_256: - salt = tls_ctx->crypto_send.aes_gcm_256.salt; - break; - default: - goto free_req; - } cipher_desc = get_cipher_desc(tls_ctx->crypto_send.info.cipher_type); + DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); + buf_len = cipher_desc->salt + cipher_desc->iv + TLS_AAD_SPACE_SIZE + sync_size + cipher_desc->tag; buf = kmalloc(buf_len, GFP_ATOMIC); @@ -356,6 +342,7 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, goto free_req; iv = buf; + salt = crypto_info_salt(&tls_ctx->crypto_send.info, cipher_desc); memcpy(iv, salt, cipher_desc->salt); aad = buf + cipher_desc->salt + cipher_desc->iv; dummy_buf = aad + TLS_AAD_SPACE_SIZE; diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 002483e60c19..1c2c6800949d 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -59,7 +59,8 @@ enum { }; #define CHECK_CIPHER_DESC(cipher,ci) \ - static_assert(cipher ## _IV_SIZE <= MAX_IV_SIZE); \ + static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE); \ + static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE); \ static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE); \ static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE); \ static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE); \ @@ -348,8 +349,6 @@ static void tls_sk_proto_cleanup(struct sock *sk, /* We need these for tls_sw_fallback handling of other packets */ if (ctx->tx_conf == TLS_SW) { - kfree(ctx->tx.rec_seq); - kfree(ctx->tx.iv); tls_sw_release_resources_tx(sk); TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); } else if (ctx->tx_conf == TLS_HW) { @@ -585,6 +584,31 @@ static int tls_getsockopt(struct sock *sk, int level, int optname, return do_tls_getsockopt(sk, optname, optval, optlen); } +static int validate_crypto_info(const struct tls_crypto_info *crypto_info, + const struct tls_crypto_info *alt_crypto_info) +{ + if (crypto_info->version != TLS_1_2_VERSION && + crypto_info->version != TLS_1_3_VERSION) + return -EINVAL; + + switch (crypto_info->cipher_type) { + case TLS_CIPHER_ARIA_GCM_128: + case TLS_CIPHER_ARIA_GCM_256: + if (crypto_info->version != TLS_1_2_VERSION) + return -EINVAL; + break; + } + + /* Ensure that TLS version and ciphers are same in both directions */ + if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { + if (alt_crypto_info->version != crypto_info->version || + alt_crypto_info->cipher_type != crypto_info->cipher_type) + return -EINVAL; + } + + return 0; +} + static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, unsigned int optlen, int tx) { @@ -616,21 +640,9 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, goto err_crypto_info; } - /* check version */ - if (crypto_info->version != TLS_1_2_VERSION && - crypto_info->version != TLS_1_3_VERSION) { - rc = -EINVAL; + rc = validate_crypto_info(crypto_info, alt_crypto_info); + if (rc) goto err_crypto_info; - } - - /* Ensure that TLS version and ciphers are same in both directions */ - if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { - if (alt_crypto_info->version != crypto_info->version || - alt_crypto_info->cipher_type != crypto_info->cipher_type) { - rc = -EINVAL; - goto err_crypto_info; - } - } cipher_desc = get_cipher_desc(crypto_info->cipher_type); if (!cipher_desc) { @@ -638,16 +650,6 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, goto err_crypto_info; } - switch (crypto_info->cipher_type) { - case TLS_CIPHER_ARIA_GCM_128: - case TLS_CIPHER_ARIA_GCM_256: - if (crypto_info->version != TLS_1_2_VERSION) { - rc = -EINVAL; - goto err_crypto_info; - } - break; - } - if (optlen != cipher_desc->crypto_info) { rc = -EINVAL; goto err_crypto_info; @@ -662,13 +664,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, } if (tx) { - rc = tls_set_device_offload(sk, ctx); + rc = tls_set_device_offload(sk); conf = TLS_HW; if (!rc) { TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); } else { - rc = tls_set_sw_offload(sk, ctx, 1); + rc = tls_set_sw_offload(sk, 1); if (rc) goto err_crypto_info; TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); @@ -682,7 +684,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); } else { - rc = tls_set_sw_offload(sk, ctx, 0); + rc = tls_set_sw_offload(sk, 0); if (rc) goto err_crypto_info; TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index e9d1e83a859d..a78e8e722409 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -60,7 +60,7 @@ struct tls_decrypt_arg { struct tls_decrypt_ctx { struct sock *sk; - u8 iv[MAX_IV_SIZE]; + u8 iv[TLS_MAX_IV_SIZE]; u8 aad[TLS_MAX_AAD_SIZE]; u8 tail; struct scatterlist sg[]; @@ -1491,7 +1491,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, */ aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); aead_size = ALIGN(aead_size, __alignof__(*dctx)); - mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout), + mem = kmalloc(aead_size + struct_size(dctx, sg, size_add(n_sgin, n_sgout)), sk->sk_allocation); if (!mem) { err = -ENOMEM; @@ -2326,7 +2326,7 @@ int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_prot_info *prot = &tls_ctx->prot_info; - char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; + char header[TLS_HEADER_SIZE + TLS_MAX_IV_SIZE]; size_t cipher_overhead; size_t data_len = 0; int ret; @@ -2474,9 +2474,6 @@ void tls_sw_release_resources_rx(struct sock *sk) struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - kfree(tls_ctx->rx.rec_seq); - kfree(tls_ctx->rx.iv); - if (ctx->aead_recv) { __skb_queue_purge(&ctx->rx_list); crypto_free_aead(ctx->aead_recv); @@ -2588,69 +2585,113 @@ void tls_update_rx_zc_capable(struct tls_context *tls_ctx) tls_ctx->prot_info.version != TLS_1_3_VERSION; } -int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) +static struct tls_sw_context_tx *init_ctx_tx(struct tls_context *ctx, struct sock *sk) +{ + struct tls_sw_context_tx *sw_ctx_tx; + + if (!ctx->priv_ctx_tx) { + sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); + if (!sw_ctx_tx) + return NULL; + } else { + sw_ctx_tx = ctx->priv_ctx_tx; + } + + crypto_init_wait(&sw_ctx_tx->async_wait); + spin_lock_init(&sw_ctx_tx->encrypt_compl_lock); + INIT_LIST_HEAD(&sw_ctx_tx->tx_list); + INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); + sw_ctx_tx->tx_work.sk = sk; + + return sw_ctx_tx; +} + +static struct tls_sw_context_rx *init_ctx_rx(struct tls_context *ctx) +{ + struct tls_sw_context_rx *sw_ctx_rx; + + if (!ctx->priv_ctx_rx) { + sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); + if (!sw_ctx_rx) + return NULL; + } else { + sw_ctx_rx = ctx->priv_ctx_rx; + } + + crypto_init_wait(&sw_ctx_rx->async_wait); + spin_lock_init(&sw_ctx_rx->decrypt_compl_lock); + init_waitqueue_head(&sw_ctx_rx->wq); + skb_queue_head_init(&sw_ctx_rx->rx_list); + skb_queue_head_init(&sw_ctx_rx->async_hold); + + return sw_ctx_rx; +} + +int init_prot_info(struct tls_prot_info *prot, + const struct tls_crypto_info *crypto_info, + const struct tls_cipher_desc *cipher_desc) +{ + u16 nonce_size = cipher_desc->nonce; + + if (crypto_info->version == TLS_1_3_VERSION) { + nonce_size = 0; + prot->aad_size = TLS_HEADER_SIZE; + prot->tail_size = 1; + } else { + prot->aad_size = TLS_AAD_SPACE_SIZE; + prot->tail_size = 0; + } + + /* Sanity-check the sizes for stack allocations. */ + if (nonce_size > TLS_MAX_IV_SIZE || prot->aad_size > TLS_MAX_AAD_SIZE) + return -EINVAL; + + prot->version = crypto_info->version; + prot->cipher_type = crypto_info->cipher_type; + prot->prepend_size = TLS_HEADER_SIZE + nonce_size; + prot->tag_size = cipher_desc->tag; + prot->overhead_size = prot->prepend_size + prot->tag_size + prot->tail_size; + prot->iv_size = cipher_desc->iv; + prot->salt_size = cipher_desc->salt; + prot->rec_seq_size = cipher_desc->rec_seq; + + return 0; +} + +int tls_set_sw_offload(struct sock *sk, int tx) { - struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_prot_info *prot = &tls_ctx->prot_info; - struct tls_crypto_info *crypto_info; struct tls_sw_context_tx *sw_ctx_tx = NULL; struct tls_sw_context_rx *sw_ctx_rx = NULL; + const struct tls_cipher_desc *cipher_desc; + struct tls_crypto_info *crypto_info; + char *iv, *rec_seq, *key, *salt; struct cipher_context *cctx; + struct tls_prot_info *prot; struct crypto_aead **aead; + struct tls_context *ctx; struct crypto_tfm *tfm; - char *iv, *rec_seq, *key, *salt; - const struct tls_cipher_desc *cipher_desc; - u16 nonce_size; int rc = 0; - if (!ctx) { - rc = -EINVAL; - goto out; - } + ctx = tls_get_ctx(sk); + prot = &ctx->prot_info; if (tx) { - if (!ctx->priv_ctx_tx) { - sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); - if (!sw_ctx_tx) { - rc = -ENOMEM; - goto out; - } - ctx->priv_ctx_tx = sw_ctx_tx; - } else { - sw_ctx_tx = - (struct tls_sw_context_tx *)ctx->priv_ctx_tx; - } - } else { - if (!ctx->priv_ctx_rx) { - sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); - if (!sw_ctx_rx) { - rc = -ENOMEM; - goto out; - } - ctx->priv_ctx_rx = sw_ctx_rx; - } else { - sw_ctx_rx = - (struct tls_sw_context_rx *)ctx->priv_ctx_rx; - } - } + ctx->priv_ctx_tx = init_ctx_tx(ctx, sk); + if (!ctx->priv_ctx_tx) + return -ENOMEM; - if (tx) { - crypto_init_wait(&sw_ctx_tx->async_wait); - spin_lock_init(&sw_ctx_tx->encrypt_compl_lock); + sw_ctx_tx = ctx->priv_ctx_tx; crypto_info = &ctx->crypto_send.info; cctx = &ctx->tx; aead = &sw_ctx_tx->aead_send; - INIT_LIST_HEAD(&sw_ctx_tx->tx_list); - INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); - sw_ctx_tx->tx_work.sk = sk; } else { - crypto_init_wait(&sw_ctx_rx->async_wait); - spin_lock_init(&sw_ctx_rx->decrypt_compl_lock); - init_waitqueue_head(&sw_ctx_rx->wq); + ctx->priv_ctx_rx = init_ctx_rx(ctx); + if (!ctx->priv_ctx_rx) + return -ENOMEM; + + sw_ctx_rx = ctx->priv_ctx_rx; crypto_info = &ctx->crypto_recv.info; cctx = &ctx->rx; - skb_queue_head_init(&sw_ctx_rx->rx_list); - skb_queue_head_init(&sw_ctx_rx->async_hold); aead = &sw_ctx_rx->aead_recv; } @@ -2660,58 +2701,25 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) goto free_priv; } - nonce_size = cipher_desc->nonce; + rc = init_prot_info(prot, crypto_info, cipher_desc); + if (rc) + goto free_priv; iv = crypto_info_iv(crypto_info, cipher_desc); key = crypto_info_key(crypto_info, cipher_desc); salt = crypto_info_salt(crypto_info, cipher_desc); rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc); - if (crypto_info->version == TLS_1_3_VERSION) { - nonce_size = 0; - prot->aad_size = TLS_HEADER_SIZE; - prot->tail_size = 1; - } else { - prot->aad_size = TLS_AAD_SPACE_SIZE; - prot->tail_size = 0; - } - - /* Sanity-check the sizes for stack allocations. */ - if (nonce_size > MAX_IV_SIZE || prot->aad_size > TLS_MAX_AAD_SIZE) { - rc = -EINVAL; - goto free_priv; - } - - prot->version = crypto_info->version; - prot->cipher_type = crypto_info->cipher_type; - prot->prepend_size = TLS_HEADER_SIZE + nonce_size; - prot->tag_size = cipher_desc->tag; - prot->overhead_size = prot->prepend_size + - prot->tag_size + prot->tail_size; - prot->iv_size = cipher_desc->iv; - prot->salt_size = cipher_desc->salt; - cctx->iv = kmalloc(cipher_desc->iv + cipher_desc->salt, GFP_KERNEL); - if (!cctx->iv) { - rc = -ENOMEM; - goto free_priv; - } - /* Note: 128 & 256 bit salt are the same size */ - prot->rec_seq_size = cipher_desc->rec_seq; memcpy(cctx->iv, salt, cipher_desc->salt); memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv); - - cctx->rec_seq = kmemdup(rec_seq, cipher_desc->rec_seq, GFP_KERNEL); - if (!cctx->rec_seq) { - rc = -ENOMEM; - goto free_iv; - } + memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq); if (!*aead) { *aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0); if (IS_ERR(*aead)) { rc = PTR_ERR(*aead); *aead = NULL; - goto free_rec_seq; + goto free_priv; } } @@ -2743,12 +2751,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) free_aead: crypto_free_aead(*aead); *aead = NULL; -free_rec_seq: - kfree(cctx->rec_seq); - cctx->rec_seq = NULL; -free_iv: - kfree(cctx->iv); - cctx->iv = NULL; free_priv: if (tx) { kfree(ctx->priv_ctx_tx); |