From 130b392c6cd6b2aed1b7eb32253d4920babb4891 Mon Sep 17 00:00:00 2001 From: Dave Watson Date: Wed, 30 Jan 2019 21:58:31 +0000 Subject: net: tls: Add tls 1.3 support TLS 1.3 has minor changes from TLS 1.2 at the record layer. * Header now hardcodes the same version and application content type in the header. * The real content type is appended after the data, before encryption (or after decryption). * The IV is xored with the sequence number, instead of concatinating four bytes of IV with the explicit IV. * Zero-padding: No exlicit length is given, we search backwards from the end of the decrypted data for the first non-zero byte, which is the content type. Currently recv supports reading zero-padding, but there is no way for send to add zero padding. Signed-off-by: Dave Watson Signed-off-by: David S. Miller --- net/tls/tls_device.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'net/tls/tls_device.c') diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index d753e362d2d9..7ee9008b2187 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -257,7 +257,8 @@ static int tls_push_record(struct sock *sk, tls_fill_prepend(ctx, skb_frag_address(frag), record->len - ctx->tx.prepend_size, - record_type); + record_type, + ctx->crypto_send.info.version); /* HW doesn't care about the data in the tag, because it fills it. */ dummy_tag_frag.page = skb_frag_page(frag); @@ -270,7 +271,7 @@ static int tls_push_record(struct sock *sk, spin_unlock_irq(&offload_ctx->lock); offload_ctx->open_record = NULL; set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); - tls_advance_record_sn(sk, &ctx->tx); + tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version); for (i = 0; i < record->num_frags; i++) { frag = &record->frags[i]; -- cgit v1.2.3-70-g09d2 From 4509de14680084141d3514c3b87bd9d070fc366d Mon Sep 17 00:00:00 2001 From: Vakul Garg Date: Thu, 14 Feb 2019 07:11:35 +0000 Subject: net/tls: Move protocol constants from cipher context to tls context Each tls context maintains two cipher contexts (one each for tx and rx directions). For each tls session, the constants such as protocol version, ciphersuite, iv size, associated data size etc are same for both the directions and need to be stored only once per tls context. Hence these are moved from 'struct cipher_context' to 'struct tls_prot_info' and stored only once in 'struct tls_context'. Signed-off-by: Vakul Garg Signed-off-by: David S. Miller --- include/net/tls.h | 46 +++++++++----- net/tls/tls_device.c | 24 ++++--- net/tls/tls_main.c | 17 ++++- net/tls/tls_sw.c | 172 +++++++++++++++++++++++++++------------------------ 4 files changed, 149 insertions(+), 110 deletions(-) (limited to 'net/tls/tls_device.c') diff --git a/include/net/tls.h b/include/net/tls.h index a93a8ed8f716..a8b37226a287 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -199,15 +199,8 @@ enum { }; struct cipher_context { - u16 prepend_size; - u16 tag_size; - u16 overhead_size; - u16 iv_size; char *iv; - u16 rec_seq_size; char *rec_seq; - u16 aad_size; - u16 tail_size; }; union tls_crypto_context { @@ -218,7 +211,21 @@ union tls_crypto_context { }; }; +struct tls_prot_info { + u16 version; + u16 cipher_type; + u16 prepend_size; + u16 tag_size; + u16 overhead_size; + u16 iv_size; + u16 rec_seq_size; + u16 aad_size; + u16 tail_size; +}; + struct tls_context { + struct tls_prot_info prot_info; + union tls_crypto_context crypto_send; union tls_crypto_context crypto_recv; @@ -401,16 +408,26 @@ static inline bool tls_bigint_increment(unsigned char *seq, int len) return (i == -1); } +static inline struct tls_context *tls_get_ctx(const struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + return icsk->icsk_ulp_data; +} + static inline void tls_advance_record_sn(struct sock *sk, struct cipher_context *ctx, int version) { - if (tls_bigint_increment(ctx->rec_seq, ctx->rec_seq_size)) + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + + if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size)) tls_err_abort(sk, EBADMSG); if (version != TLS_1_3_VERSION) { tls_bigint_increment(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - ctx->iv_size); + prot->iv_size); } } @@ -420,9 +437,10 @@ static inline void tls_fill_prepend(struct tls_context *ctx, unsigned char record_type, int version) { - size_t pkt_len, iv_size = ctx->tx.iv_size; + struct tls_prot_info *prot = &ctx->prot_info; + size_t pkt_len, iv_size = prot->iv_size; - pkt_len = plaintext_len + ctx->tx.tag_size; + pkt_len = plaintext_len + prot->tag_size; if (version != TLS_1_3_VERSION) { pkt_len += iv_size; @@ -475,12 +493,6 @@ static inline void xor_iv_with_seq(int version, char *iv, char *seq) } } -static inline struct tls_context *tls_get_ctx(const struct sock *sk) -{ - struct inet_connection_sock *icsk = inet_csk(sk); - - return icsk->icsk_ulp_data; -} static inline struct tls_sw_context_rx *tls_sw_ctx_rx( const struct tls_context *tls_ctx) diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 7ee9008b2187..a5c17c47d08a 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -247,6 +247,7 @@ static int tls_push_record(struct sock *sk, int flags, unsigned char record_type) { + struct tls_prot_info *prot = &ctx->prot_info; struct tcp_sock *tp = tcp_sk(sk); struct page_frag dummy_tag_frag; skb_frag_t *frag; @@ -256,7 +257,7 @@ static int tls_push_record(struct sock *sk, frag = &record->frags[0]; tls_fill_prepend(ctx, skb_frag_address(frag), - record->len - ctx->tx.prepend_size, + record->len - prot->prepend_size, record_type, ctx->crypto_send.info.version); @@ -264,7 +265,7 @@ static int tls_push_record(struct sock *sk, dummy_tag_frag.page = skb_frag_page(frag); dummy_tag_frag.offset = 0; - tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size); + tls_append_frag(record, &dummy_tag_frag, prot->tag_size); record->end_seq = tp->write_seq + record->len; spin_lock_irq(&offload_ctx->lock); list_add_tail(&record->list, &offload_ctx->records_list); @@ -347,6 +348,7 @@ static int tls_push_data(struct sock *sk, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST; int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE); @@ -376,10 +378,10 @@ static int tls_push_data(struct sock *sk, * we need to leave room for an authentication tag. */ max_open_record_len = TLS_MAX_PAYLOAD_SIZE + - tls_ctx->tx.prepend_size; + prot->prepend_size; do { rc = tls_do_allocation(sk, ctx, pfrag, - tls_ctx->tx.prepend_size); + prot->prepend_size); if (rc) { rc = sk_stream_wait_memory(sk, &timeo); if (!rc) @@ -397,7 +399,7 @@ handle_error: size = orig_size; destroy_record(record); ctx->open_record = NULL; - } else if (record->len > tls_ctx->tx.prepend_size) { + } else if (record->len > prot->prepend_size) { goto last_record; } @@ -658,6 +660,8 @@ int tls_device_decrypted(struct sock *sk, struct sk_buff *skb) int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) { u16 nonce_size, tag_size, iv_size, rec_seq_size; + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_record_info *start_marker_record; struct tls_offload_context_tx *offload_ctx; struct tls_crypto_info *crypto_info; @@ -703,10 +707,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) goto free_offload_ctx; } - ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size; - ctx->tx.tag_size = tag_size; - ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size; - ctx->tx.iv_size = iv_size; + prot->prepend_size = TLS_HEADER_SIZE + nonce_size; + prot->tag_size = tag_size; + prot->overhead_size = prot->prepend_size + prot->tag_size; + prot->iv_size = iv_size; ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, GFP_KERNEL); if (!ctx->tx.iv) { @@ -716,7 +720,7 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); - ctx->tx.rec_seq_size = rec_seq_size; + prot->rec_seq_size = rec_seq_size; ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); if (!ctx->tx.rec_seq) { rc = -ENOMEM; diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index d1c2fd9a3f63..caff15b2f9b2 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -435,6 +435,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, unsigned int optlen, int tx) { struct tls_crypto_info *crypto_info; + struct tls_crypto_info *alt_crypto_info; struct tls_context *ctx = tls_get_ctx(sk); size_t optsize; int rc = 0; @@ -445,10 +446,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, goto out; } - if (tx) + if (tx) { crypto_info = &ctx->crypto_send.info; - else + alt_crypto_info = &ctx->crypto_recv.info; + } else { crypto_info = &ctx->crypto_recv.info; + alt_crypto_info = &ctx->crypto_send.info; + } /* Currently we don't support set crypto info more than one time */ if (TLS_CRYPTO_INFO_READY(crypto_info)) { @@ -469,6 +473,15 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, goto err_crypto_info; } + /* Ensure that TLS version and ciphers are same in both directions */ + if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { + if (alt_crypto_info->version != crypto_info->version || + alt_crypto_info->cipher_type != crypto_info->cipher_type) { + rc = -EINVAL; + goto err_crypto_info; + } + } + switch (crypto_info->cipher_type) { case TLS_CIPHER_AES_GCM_128: case TLS_CIPHER_AES_GCM_256: { diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index ae4784734547..71be8acfbc9b 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -127,7 +127,7 @@ static int padding_length(struct tls_sw_context_rx *ctx, int sub = 0; /* Determine zero-padding length */ - if (tls_ctx->crypto_recv.info.version == TLS_1_3_VERSION) { + if (tls_ctx->prot_info.version == TLS_1_3_VERSION) { char content_type = 0; int err; int back = 17; @@ -155,6 +155,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) struct scatterlist *sgin = aead_req->src; struct tls_sw_context_rx *ctx; struct tls_context *tls_ctx; + struct tls_prot_info *prot; struct scatterlist *sg; struct sk_buff *skb; unsigned int pages; @@ -163,6 +164,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) skb = (struct sk_buff *)req->data; tls_ctx = tls_get_ctx(skb->sk); ctx = tls_sw_ctx_rx(tls_ctx); + prot = &tls_ctx->prot_info; /* Propagate if there was an err */ if (err) { @@ -171,8 +173,8 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) } else { struct strp_msg *rxm = strp_msg(skb); rxm->full_len -= padding_length(ctx, tls_ctx, skb); - rxm->offset += tls_ctx->rx.prepend_size; - rxm->full_len -= tls_ctx->rx.overhead_size; + rxm->offset += prot->prepend_size; + rxm->full_len -= prot->overhead_size; } /* After using skb->sk to propagate sk through crypto async callback @@ -209,13 +211,14 @@ static int tls_do_decryption(struct sock *sk, bool async) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); int ret; aead_request_set_tfm(aead_req, ctx->aead_recv); - aead_request_set_ad(aead_req, tls_ctx->rx.aad_size); + aead_request_set_ad(aead_req, prot->aad_size); aead_request_set_crypt(aead_req, sgin, sgout, - data_len + tls_ctx->rx.tag_size, + data_len + prot->tag_size, (u8 *)iv_recv); if (async) { @@ -253,12 +256,13 @@ static int tls_do_decryption(struct sock *sk, static void tls_trim_both_msgs(struct sock *sk, int target_size) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; sk_msg_trim(sk, &rec->msg_plaintext, target_size); if (target_size > 0) - target_size += tls_ctx->tx.overhead_size; + target_size += prot->overhead_size; sk_msg_trim(sk, &rec->msg_encrypted, target_size); } @@ -275,6 +279,7 @@ static int tls_alloc_encrypted_msg(struct sock *sk, int len) static int tls_clone_plaintext_msg(struct sock *sk, int required) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; struct sk_msg *msg_pl = &rec->msg_plaintext; @@ -290,7 +295,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) /* Skip initial bytes in msg_en's data to be able to use * same offset of both plain and encrypted data. */ - skip = tls_ctx->tx.prepend_size + msg_pl->sg.size; + skip = prot->prepend_size + msg_pl->sg.size; return sk_msg_clone(sk, msg_pl, msg_en, skip, len); } @@ -298,6 +303,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) static struct tls_rec *tls_get_rec(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct sk_msg *msg_pl, *msg_en; struct tls_rec *rec; @@ -316,13 +322,11 @@ static struct tls_rec *tls_get_rec(struct sock *sk) sk_msg_init(msg_en); sg_init_table(rec->sg_aead_in, 2); - sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, - tls_ctx->tx.aad_size); + sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size); sg_unmark_end(&rec->sg_aead_in[1]); sg_init_table(rec->sg_aead_out, 2); - sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, - tls_ctx->tx.aad_size); + sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size); sg_unmark_end(&rec->sg_aead_out[1]); return rec; @@ -411,6 +415,7 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) struct aead_request *aead_req = (struct aead_request *)req; struct sock *sk = req->data; struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct scatterlist *sge; struct sk_msg *msg_en; @@ -422,8 +427,8 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) msg_en = &rec->msg_encrypted; sge = sk_msg_elem(msg_en, msg_en->sg.curr); - sge->offset -= tls_ctx->tx.prepend_size; - sge->length += tls_ctx->tx.prepend_size; + sge->offset -= prot->prepend_size; + sge->length += prot->prepend_size; /* Check if error is previously set on socket */ if (err || sk->sk_err) { @@ -470,22 +475,23 @@ static int tls_do_encryption(struct sock *sk, struct aead_request *aead_req, size_t data_len, u32 start) { + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_rec *rec = ctx->open_rec; struct sk_msg *msg_en = &rec->msg_encrypted; struct scatterlist *sge = sk_msg_elem(msg_en, start); int rc; memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data)); - xor_iv_with_seq(tls_ctx->crypto_send.info.version, rec->iv_data, + xor_iv_with_seq(prot->version, rec->iv_data, tls_ctx->tx.rec_seq); - sge->offset += tls_ctx->tx.prepend_size; - sge->length -= tls_ctx->tx.prepend_size; + sge->offset += prot->prepend_size; + sge->length -= prot->prepend_size; msg_en->sg.curr = start; aead_request_set_tfm(aead_req, ctx->aead_send); - aead_request_set_ad(aead_req, tls_ctx->tx.aad_size); + aead_request_set_ad(aead_req, prot->aad_size); aead_request_set_crypt(aead_req, rec->sg_aead_in, rec->sg_aead_out, data_len, rec->iv_data); @@ -500,8 +506,8 @@ static int tls_do_encryption(struct sock *sk, rc = crypto_aead_encrypt(aead_req); if (!rc || rc != -EINPROGRESS) { atomic_dec(&ctx->encrypt_pending); - sge->offset -= tls_ctx->tx.prepend_size; - sge->length += tls_ctx->tx.prepend_size; + sge->offset -= prot->prepend_size; + sge->length += prot->prepend_size; } if (!rc) { @@ -513,8 +519,7 @@ static int tls_do_encryption(struct sock *sk, /* Unhook the record from context if encryption is not failure */ ctx->open_rec = NULL; - tls_advance_record_sn(sk, &tls_ctx->tx, - tls_ctx->crypto_send.info.version); + tls_advance_record_sn(sk, &tls_ctx->tx, prot->version); return rc; } @@ -640,6 +645,7 @@ static int tls_push_record(struct sock *sk, int flags, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec, *tmp = NULL; u32 i, split_point, uninitialized_var(orig_end); @@ -658,12 +664,12 @@ static int tls_push_record(struct sock *sk, int flags, split = split_point && split_point < msg_pl->sg.size; if (split) { rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, - split_point, tls_ctx->tx.overhead_size, + split_point, prot->overhead_size, &orig_end); if (rc < 0) return rc; sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + prot->overhead_size); } rec->tx_flags = flags; @@ -673,7 +679,7 @@ static int tls_push_record(struct sock *sk, int flags, sk_msg_iter_var_prev(i); rec->content_type = record_type; - if (tls_ctx->crypto_send.info.version == TLS_1_3_VERSION) { + if (prot->version == TLS_1_3_VERSION) { /* Add content type to end of message. No padding added */ sg_set_buf(&rec->sg_content_type, &rec->content_type, 1); sg_mark_end(&rec->sg_content_type); @@ -694,22 +700,20 @@ static int tls_push_record(struct sock *sk, int flags, i = msg_en->sg.start; sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); - tls_make_aad(rec->aad_space, msg_pl->sg.size + tls_ctx->tx.tail_size, - tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, - record_type, - tls_ctx->crypto_send.info.version); + tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size, + tls_ctx->tx.rec_seq, prot->rec_seq_size, + record_type, prot->version); tls_fill_prepend(tls_ctx, page_address(sg_page(&msg_en->sg.data[i])) + msg_en->sg.data[i].offset, - msg_pl->sg.size + tls_ctx->tx.tail_size, - record_type, - tls_ctx->crypto_send.info.version); + msg_pl->sg.size + prot->tail_size, + record_type, prot->version); tls_ctx->pending_open_record_frags = false; rc = tls_do_encryption(sk, tls_ctx, ctx, req, - msg_pl->sg.size + tls_ctx->tx.tail_size, i); + msg_pl->sg.size + prot->tail_size, i); if (rc < 0) { if (rc != -EINPROGRESS) { tls_err_abort(sk, EBADMSG); @@ -723,8 +727,7 @@ static int tls_push_record(struct sock *sk, int flags, } else if (split) { msg_pl = &tmp->msg_plaintext; msg_en = &tmp->msg_encrypted; - sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size); tls_ctx->pending_open_record_frags = true; ctx->open_rec = tmp; } @@ -859,6 +862,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) { long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); bool async_capable = ctx->async_capable; unsigned char record_type = TLS_RECORD_TYPE_DATA; @@ -925,7 +929,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) } required_size = msg_pl->sg.size + try_to_copy + - tls_ctx->tx.overhead_size; + prot->overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; @@ -994,8 +998,8 @@ fallback_to_reg_send: */ try_to_copy -= required_size - msg_pl->sg.size; full_record = true; - sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + sk_msg_trim(sk, msg_en, + msg_pl->sg.size + prot->overhead_size); } if (try_to_copy) { @@ -1081,6 +1085,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; unsigned char record_type = TLS_RECORD_TYPE_DATA; struct sk_msg *msg_pl; struct tls_rec *rec; @@ -1130,8 +1135,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, full_record = true; } - required_size = msg_pl->sg.size + copy + - tls_ctx->tx.overhead_size; + required_size = msg_pl->sg.size + copy + prot->overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; @@ -1330,6 +1334,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct strp_msg *rxm = strp_msg(skb); int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; struct aead_request *aead_req; @@ -1337,16 +1342,16 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, u8 *aad, *iv, *mem = NULL; struct scatterlist *sgin = NULL; struct scatterlist *sgout = NULL; - const int data_len = rxm->full_len - tls_ctx->rx.overhead_size + - tls_ctx->rx.tail_size; + const int data_len = rxm->full_len - prot->overhead_size + + prot->tail_size; if (*zc && (out_iov || out_sg)) { if (out_iov) n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1; else n_sgout = sg_nents(out_sg); - n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size, - rxm->full_len - tls_ctx->rx.prepend_size); + n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size, + rxm->full_len - prot->prepend_size); } else { n_sgout = 0; *zc = false; @@ -1363,7 +1368,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); mem_size = aead_size + (nsg * sizeof(struct scatterlist)); - mem_size = mem_size + tls_ctx->rx.aad_size; + mem_size = mem_size + prot->aad_size; mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); /* Allocate a single block of memory which contains @@ -1379,37 +1384,35 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, sgin = (struct scatterlist *)(mem + aead_size); sgout = sgin + n_sgin; aad = (u8 *)(sgout + n_sgout); - iv = aad + tls_ctx->rx.aad_size; + iv = aad + prot->aad_size; /* Prepare IV */ err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - tls_ctx->rx.iv_size); + prot->iv_size); if (err < 0) { kfree(mem); return err; } - if (tls_ctx->crypto_recv.info.version == TLS_1_3_VERSION) + if (prot->version == TLS_1_3_VERSION) memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv)); else memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); - xor_iv_with_seq(tls_ctx->crypto_recv.info.version, iv, - tls_ctx->rx.rec_seq); + xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq); /* Prepare AAD */ - tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size + - tls_ctx->rx.tail_size, - tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size, - ctx->control, - tls_ctx->crypto_recv.info.version); + tls_make_aad(aad, rxm->full_len - prot->overhead_size + + prot->tail_size, + tls_ctx->rx.rec_seq, prot->rec_seq_size, + ctx->control, prot->version); /* Prepare sgin */ sg_init_table(sgin, n_sgin); - sg_set_buf(&sgin[0], aad, tls_ctx->rx.aad_size); + sg_set_buf(&sgin[0], aad, prot->aad_size); err = skb_to_sgvec(skb, &sgin[1], - rxm->offset + tls_ctx->rx.prepend_size, - rxm->full_len - tls_ctx->rx.prepend_size); + rxm->offset + prot->prepend_size, + rxm->full_len - prot->prepend_size); if (err < 0) { kfree(mem); return err; @@ -1418,7 +1421,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, if (n_sgout) { if (out_iov) { sg_init_table(sgout, n_sgout); - sg_set_buf(&sgout[0], aad, tls_ctx->rx.aad_size); + sg_set_buf(&sgout[0], aad, prot->aad_size); *chunk = 0; err = tls_setup_from_iter(sk, out_iov, data_len, @@ -1459,7 +1462,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - int version = tls_ctx->crypto_recv.info.version; + struct tls_prot_info *prot = &tls_ctx->prot_info; + int version = prot->version; struct strp_msg *rxm = strp_msg(skb); int err = 0; @@ -1480,8 +1484,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, rxm->full_len -= padding_length(ctx, tls_ctx, skb); - rxm->offset += tls_ctx->rx.prepend_size; - rxm->full_len -= tls_ctx->rx.overhead_size; + rxm->offset += prot->prepend_size; + rxm->full_len -= prot->overhead_size; tls_advance_record_sn(sk, &tls_ctx->rx, version); ctx->decrypted = true; ctx->saved_data_ready(sk); @@ -1605,6 +1609,7 @@ int tls_sw_recvmsg(struct sock *sk, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct sk_psock *psock; unsigned char control = 0; ssize_t decrypted = 0; @@ -1667,11 +1672,11 @@ int tls_sw_recvmsg(struct sock *sk, rxm = strp_msg(skb); - to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size; + to_decrypt = rxm->full_len - prot->overhead_size; if (to_decrypt <= len && !is_kvec && !is_peek && ctx->control == TLS_RECORD_TYPE_DATA && - tls_ctx->crypto_recv.info.version != TLS_1_3_VERSION) + prot->version != TLS_1_3_VERSION) zc = true; /* Do not use async mode if record is non-data */ @@ -1875,6 +1880,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; struct strp_msg *rxm = strp_msg(skb); size_t cipher_overhead; @@ -1882,17 +1888,17 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) int ret; /* Verify that we have a full TLS header, or wait for more data */ - if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) + if (rxm->offset + prot->prepend_size > skb->len) return 0; /* Sanity-check size of on-stack buffer. */ - if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) { + if (WARN_ON(prot->prepend_size > sizeof(header))) { ret = -EINVAL; goto read_failure; } /* Linearize header to local buffer */ - ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); + ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size); if (ret < 0) goto read_failure; @@ -1901,12 +1907,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) data_len = ((header[4] & 0xFF) | (header[3] << 8)); - cipher_overhead = tls_ctx->rx.tag_size; - if (tls_ctx->crypto_recv.info.version != TLS_1_3_VERSION) - cipher_overhead += tls_ctx->rx.iv_size; + cipher_overhead = prot->tag_size; + if (prot->version != TLS_1_3_VERSION) + cipher_overhead += prot->iv_size; if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead + - tls_ctx->rx.tail_size) { + prot->tail_size) { ret = -EMSGSIZE; goto read_failure; } @@ -2066,6 +2072,8 @@ static void tx_work_handler(struct work_struct *work) int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) { + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_crypto_info *crypto_info; struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; @@ -2171,18 +2179,20 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) if (crypto_info->version == TLS_1_3_VERSION) { nonce_size = 0; - cctx->aad_size = TLS_HEADER_SIZE; - cctx->tail_size = 1; + prot->aad_size = TLS_HEADER_SIZE; + prot->tail_size = 1; } else { - cctx->aad_size = TLS_AAD_SPACE_SIZE; - cctx->tail_size = 0; + prot->aad_size = TLS_AAD_SPACE_SIZE; + prot->tail_size = 0; } - cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; - cctx->tag_size = tag_size; - cctx->overhead_size = cctx->prepend_size + cctx->tag_size + - cctx->tail_size; - cctx->iv_size = iv_size; + prot->version = crypto_info->version; + prot->cipher_type = crypto_info->cipher_type; + prot->prepend_size = TLS_HEADER_SIZE + nonce_size; + prot->tag_size = tag_size; + prot->overhead_size = prot->prepend_size + + prot->tag_size + prot->tail_size; + prot->iv_size = iv_size; cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, GFP_KERNEL); if (!cctx->iv) { @@ -2192,7 +2202,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) /* Note: 128 & 256 bit salt are the same size */ memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); - cctx->rec_seq_size = rec_seq_size; + prot->rec_seq_size = rec_seq_size; cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); if (!cctx->rec_seq) { rc = -ENOMEM; @@ -2215,7 +2225,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) if (rc) goto free_aead; - rc = crypto_aead_setauthsize(*aead, cctx->tag_size); + rc = crypto_aead_setauthsize(*aead, prot->tag_size); if (rc) goto free_aead; -- cgit v1.2.3-70-g09d2 From 94850257cf0f88b20db7644f28bfedc7d284de15 Mon Sep 17 00:00:00 2001 From: Boris Pismenny Date: Wed, 27 Feb 2019 17:38:03 +0200 Subject: tls: Fix tls_device handling of partial records Cleanup the handling of partial records while fixing a bug where the tls_push_pending_closed_record function is using the software tls context instead of the hardware context. The bug resulted in the following crash: [ 88.791229] BUG: unable to handle kernel NULL pointer dereference at 0000000000000000 [ 88.793271] #PF error: [normal kernel read fault] [ 88.794449] PGD 800000022a426067 P4D 800000022a426067 PUD 22a156067 PMD 0 [ 88.795958] Oops: 0000 [#1] SMP PTI [ 88.796884] CPU: 2 PID: 4973 Comm: openssl Not tainted 5.0.0-rc4+ #3 [ 88.798314] Hardware name: QEMU Standard PC (i440FX + PIIX, 1996), BIOS Bochs 01/01/2011 [ 88.800067] RIP: 0010:tls_tx_records+0xef/0x1d0 [tls] [ 88.801256] Code: 00 02 48 89 43 08 e8 a0 0b 96 d9 48 89 df e8 48 dd 4d d9 4c 89 f8 4d 8b bf 98 00 00 00 48 05 98 00 00 00 48 89 04 24 49 39 c7 <49> 8b 1f 4d 89 fd 0f 84 af 00 00 00 41 8b 47 10 85 c0 0f 85 8d 00 [ 88.805179] RSP: 0018:ffffbd888186fca8 EFLAGS: 00010213 [ 88.806458] RAX: ffff9af1ed657c98 RBX: ffff9af1e88a1980 RCX: 0000000000000000 [ 88.808050] RDX: 0000000000000000 RSI: 0000000000000000 RDI: ffff9af1e88a1980 [ 88.809724] RBP: ffff9af1e88a1980 R08: 0000000000000017 R09: ffff9af1ebeeb700 [ 88.811294] R10: 0000000000000000 R11: 0000000000000000 R12: 0000000000000000 [ 88.812917] R13: ffff9af1e88a1980 R14: ffff9af1ec13f800 R15: 0000000000000000 [ 88.814506] FS: 00007fcad2240740(0000) GS:ffff9af1f7880000(0000) knlGS:0000000000000000 [ 88.816337] CS: 0010 DS: 0000 ES: 0000 CR0: 0000000080050033 [ 88.817717] CR2: 0000000000000000 CR3: 0000000228b3e000 CR4: 00000000001406e0 [ 88.819328] Call Trace: [ 88.820123] tls_push_data+0x628/0x6a0 [tls] [ 88.821283] ? remove_wait_queue+0x20/0x60 [ 88.822383] ? n_tty_read+0x683/0x910 [ 88.823363] tls_device_sendmsg+0x53/0xa0 [tls] [ 88.824505] sock_sendmsg+0x36/0x50 [ 88.825492] sock_write_iter+0x87/0x100 [ 88.826521] __vfs_write+0x127/0x1b0 [ 88.827499] vfs_write+0xad/0x1b0 [ 88.828454] ksys_write+0x52/0xc0 [ 88.829378] do_syscall_64+0x5b/0x180 [ 88.830369] entry_SYSCALL_64_after_hwframe+0x44/0xa9 [ 88.831603] RIP: 0033:0x7fcad1451680 [ 1248.470626] BUG: unable to handle kernel NULL pointer dereference at 0000000000000000 [ 1248.472564] #PF error: [normal kernel read fault] [ 1248.473790] PGD 0 P4D 0 [ 1248.474642] Oops: 0000 [#1] SMP PTI [ 1248.475651] CPU: 3 PID: 7197 Comm: openssl Tainted: G OE 5.0.0-rc4+ #3 [ 1248.477426] Hardware name: QEMU Standard PC (i440FX + PIIX, 1996), BIOS Bochs 01/01/2011 [ 1248.479310] RIP: 0010:tls_tx_records+0x110/0x1f0 [tls] [ 1248.480644] Code: 00 02 48 89 43 08 e8 4f cb 63 d7 48 89 df e8 f7 9c 1b d7 4c 89 f8 4d 8b bf 98 00 00 00 48 05 98 00 00 00 48 89 04 24 49 39 c7 <49> 8b 1f 4d 89 fd 0f 84 af 00 00 00 41 8b 47 10 85 c0 0f 85 8d 00 [ 1248.484825] RSP: 0018:ffffaa0a41543c08 EFLAGS: 00010213 [ 1248.486154] RAX: ffff955a2755dc98 RBX: ffff955a36031980 RCX: 0000000000000006 [ 1248.487855] RDX: 0000000000000000 RSI: 000000000000002b RDI: 0000000000000286 [ 1248.489524] RBP: ffff955a36031980 R08: 0000000000000000 R09: 00000000000002b1 [ 1248.491394] R10: 0000000000000003 R11: 00000000ad55ad55 R12: 0000000000000000 [ 1248.493162] R13: 0000000000000000 R14: ffff955a2abe6c00 R15: 0000000000000000 [ 1248.494923] FS: 0000000000000000(0000) GS:ffff955a378c0000(0000) knlGS:0000000000000000 [ 1248.496847] CS: 0010 DS: 0000 ES: 0000 CR0: 0000000080050033 [ 1248.498357] CR2: 0000000000000000 CR3: 000000020c40e000 CR4: 00000000001406e0 [ 1248.500136] Call Trace: [ 1248.500998] ? tcp_check_oom+0xd0/0xd0 [ 1248.502106] tls_sk_proto_close+0x127/0x1e0 [tls] [ 1248.503411] inet_release+0x3c/0x60 [ 1248.504530] __sock_release+0x3d/0xb0 [ 1248.505611] sock_close+0x11/0x20 [ 1248.506612] __fput+0xb4/0x220 [ 1248.507559] task_work_run+0x88/0xa0 [ 1248.508617] do_exit+0x2cb/0xbc0 [ 1248.509597] ? core_sys_select+0x17a/0x280 [ 1248.510740] do_group_exit+0x39/0xb0 [ 1248.511789] get_signal+0x1d0/0x630 [ 1248.512823] do_signal+0x36/0x620 [ 1248.513822] exit_to_usermode_loop+0x5c/0xc6 [ 1248.515003] do_syscall_64+0x157/0x180 [ 1248.516094] entry_SYSCALL_64_after_hwframe+0x44/0xa9 [ 1248.517456] RIP: 0033:0x7fb398bd3f53 [ 1248.518537] Code: Bad RIP value. Fixes: a42055e8d2c3 ("net/tls: Add support for async encryption of records for performance") Signed-off-by: Boris Pismenny Signed-off-by: Eran Ben Elisha Signed-off-by: David S. Miller --- include/net/tls.h | 20 ++++---------------- net/tls/tls_device.c | 9 +++++---- net/tls/tls_main.c | 13 ------------- 3 files changed, 9 insertions(+), 33 deletions(-) (limited to 'net/tls/tls_device.c') diff --git a/include/net/tls.h b/include/net/tls.h index 9f4117ae2297..a528a082da73 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -199,10 +199,6 @@ struct tls_offload_context_tx { (ALIGN(sizeof(struct tls_offload_context_tx), sizeof(void *)) + \ TLS_DRIVER_STATE_SIZE) -enum { - TLS_PENDING_CLOSED_RECORD -}; - struct cipher_context { char *iv; char *rec_seq; @@ -335,17 +331,14 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx, int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, int flags); -int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, - int flags, long *timeo); - static inline struct tls_msg *tls_msg(struct sk_buff *skb) { return (struct tls_msg *)strp_msg(skb); } -static inline bool tls_is_pending_closed_record(struct tls_context *ctx) +static inline bool tls_is_partially_sent_record(struct tls_context *ctx) { - return test_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); + return !!ctx->partially_sent_record; } static inline int tls_complete_pending_work(struct sock *sk, @@ -357,17 +350,12 @@ static inline int tls_complete_pending_work(struct sock *sk, if (unlikely(sk->sk_write_pending)) rc = wait_on_pending_writer(sk, timeo); - if (!rc && tls_is_pending_closed_record(ctx)) - rc = tls_push_pending_closed_record(sk, ctx, flags, timeo); + if (!rc && tls_is_partially_sent_record(ctx)) + rc = tls_push_partial_record(sk, ctx, flags); return rc; } -static inline bool tls_is_partially_sent_record(struct tls_context *ctx) -{ - return !!ctx->partially_sent_record; -} - static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx) { return tls_ctx->pending_open_record_frags; diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index a5c17c47d08a..3e5e8e021a87 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -271,7 +271,6 @@ static int tls_push_record(struct sock *sk, list_add_tail(&record->list, &offload_ctx->records_list); spin_unlock_irq(&offload_ctx->lock); offload_ctx->open_record = NULL; - set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version); for (i = 0; i < record->num_frags; i++) { @@ -368,9 +367,11 @@ static int tls_push_data(struct sock *sk, return -sk->sk_err; timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); - rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo); - if (rc < 0) - return rc; + if (tls_is_partially_sent_record(tls_ctx)) { + rc = tls_push_partial_record(sk, tls_ctx, flags); + if (rc < 0) + return rc; + } pfrag = sk_page_frag(sk); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index caff15b2f9b2..7e05af75536d 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -209,19 +209,6 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, return tls_push_sg(sk, ctx, sg, offset, flags); } -int tls_push_pending_closed_record(struct sock *sk, - struct tls_context *tls_ctx, - int flags, long *timeo) -{ - struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - - if (tls_is_partially_sent_record(tls_ctx) || - !list_empty(&ctx->tx_list)) - return tls_tx_records(sk, flags); - else - return tls_ctx->push_pending_record(sk, flags); -} - static void tls_write_space(struct sock *sk) { struct tls_context *ctx = tls_get_ctx(sk); -- cgit v1.2.3-70-g09d2 From 7463d3a2db0efea3701aab5eeb310e0d8157aff7 Mon Sep 17 00:00:00 2001 From: Boris Pismenny Date: Wed, 27 Feb 2019 17:38:04 +0200 Subject: tls: Fix write space handling TLS device cannot use the sw context. This patch returns the original tls device write space handler and moves the sw/device specific portions to the relevant files. Also, we remove the write_space call for the tls_sw flow, because it handles partial records in its delayed tx work handler. Fixes: a42055e8d2c3 ("net/tls: Add support for async encryption of records for performance") Signed-off-by: Boris Pismenny Reviewed-by: Eran Ben Elisha Signed-off-by: David S. Miller --- include/net/tls.h | 3 +++ net/tls/tls_device.c | 17 +++++++++++++++++ net/tls/tls_main.c | 15 ++++++--------- net/tls/tls_sw.c | 13 +++++++++++++ 4 files changed, 39 insertions(+), 9 deletions(-) (limited to 'net/tls/tls_device.c') diff --git a/include/net/tls.h b/include/net/tls.h index a528a082da73..a5a938583295 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -519,6 +519,9 @@ static inline bool tls_sw_has_ctx_tx(const struct sock *sk) return !!tls_sw_ctx_tx(ctx); } +void tls_sw_write_space(struct sock *sk, struct tls_context *ctx); +void tls_device_write_space(struct sock *sk, struct tls_context *ctx); + static inline struct tls_offload_context_rx * tls_offload_ctx_rx(const struct tls_context *tls_ctx) { diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 3e5e8e021a87..4a1da837a733 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -546,6 +546,23 @@ static int tls_device_push_pending_record(struct sock *sk, int flags) return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA); } +void tls_device_write_space(struct sock *sk, struct tls_context *ctx) +{ + int rc = 0; + + if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) { + gfp_t sk_allocation = sk->sk_allocation; + + sk->sk_allocation = GFP_ATOMIC; + rc = tls_push_partial_record(sk, ctx, + MSG_DONTWAIT | MSG_NOSIGNAL); + sk->sk_allocation = sk_allocation; + } + + if (!rc) + ctx->sk_write_space(sk); +} + void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn) { struct tls_context *tls_ctx = tls_get_ctx(sk); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 7e05af75536d..17e8667917aa 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -212,7 +212,6 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, static void tls_write_space(struct sock *sk) { struct tls_context *ctx = tls_get_ctx(sk); - struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); /* If in_tcp_sendpages call lower protocol write space handler * to ensure we wake up any waiting operations there. For example @@ -223,14 +222,12 @@ static void tls_write_space(struct sock *sk) return; } - /* Schedule the transmission if tx list is ready */ - if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) { - /* Schedule the transmission */ - if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) - schedule_delayed_work(&tx_ctx->tx_work.work, 0); - } - - ctx->sk_write_space(sk); +#ifdef CONFIG_TLS_DEVICE + if (ctx->tx_conf == TLS_HW) + tls_device_write_space(sk, ctx); + else +#endif + tls_sw_write_space(sk, ctx); } static void tls_ctx_free(struct tls_context *ctx) diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 1cc830582fa8..917caacd4d31 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -2126,6 +2126,19 @@ static void tx_work_handler(struct work_struct *work) release_sock(sk); } +void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) +{ + struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); + + /* Schedule the transmission if tx list is ready */ + if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) { + /* Schedule the transmission */ + if (!test_and_set_bit(BIT_TX_SCHEDULED, + &tx_ctx->tx_bitmask)) + schedule_delayed_work(&tx_ctx->tx_work.work, 0); + } +} + int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) { struct tls_context *tls_ctx = tls_get_ctx(sk); -- cgit v1.2.3-70-g09d2