summaryrefslogtreecommitdiff
path: root/drivers/nvme/host/tcp.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/nvme/host/tcp.c')
-rw-r--r--drivers/nvme/host/tcp.c138
1 files changed, 132 insertions, 6 deletions
diff --git a/drivers/nvme/host/tcp.c b/drivers/nvme/host/tcp.c
index 5af6c7df44a6..696fc2a7da52 100644
--- a/drivers/nvme/host/tcp.c
+++ b/drivers/nvme/host/tcp.c
@@ -8,9 +8,13 @@
#include <linux/init.h>
#include <linux/slab.h>
#include <linux/err.h>
+#include <linux/key.h>
#include <linux/nvme-tcp.h>
+#include <linux/nvme-keyring.h>
#include <net/sock.h>
#include <net/tcp.h>
+#include <net/tls.h>
+#include <net/handshake.h>
#include <linux/blk-mq.h>
#include <crypto/hash.h>
#include <net/busy_poll.h>
@@ -31,6 +35,16 @@ static int so_priority;
module_param(so_priority, int, 0644);
MODULE_PARM_DESC(so_priority, "nvme tcp socket optimize priority");
+#ifdef CONFIG_NVME_TCP_TLS
+/*
+ * TLS handshake timeout
+ */
+static int tls_handshake_timeout = 10;
+module_param(tls_handshake_timeout, int, 0644);
+MODULE_PARM_DESC(tls_handshake_timeout,
+ "nvme TLS handshake timeout in seconds (default 10)");
+#endif
+
#ifdef CONFIG_DEBUG_LOCK_ALLOC
/* lockdep can detect a circular dependency of the form
* sk_lock -> mmap_lock (page fault) -> fs locks -> sk_lock
@@ -146,7 +160,10 @@ struct nvme_tcp_queue {
struct ahash_request *snd_hash;
__le32 exp_ddgst;
__le32 recv_ddgst;
-
+#ifdef CONFIG_NVME_TCP_TLS
+ struct completion tls_complete;
+ int tls_err;
+#endif
struct page_frag_cache pf_cache;
void (*state_change)(struct sock *);
@@ -1509,7 +1526,92 @@ static void nvme_tcp_set_queue_io_cpu(struct nvme_tcp_queue *queue)
queue->io_cpu = cpumask_next_wrap(n - 1, cpu_online_mask, -1, false);
}
-static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid)
+#ifdef CONFIG_NVME_TCP_TLS
+static void nvme_tcp_tls_done(void *data, int status, key_serial_t pskid)
+{
+ struct nvme_tcp_queue *queue = data;
+ struct nvme_tcp_ctrl *ctrl = queue->ctrl;
+ int qid = nvme_tcp_queue_id(queue);
+ struct key *tls_key;
+
+ dev_dbg(ctrl->ctrl.device, "queue %d: TLS handshake done, key %x, status %d\n",
+ qid, pskid, status);
+
+ if (status) {
+ queue->tls_err = -status;
+ goto out_complete;
+ }
+
+ tls_key = key_lookup(pskid);
+ if (IS_ERR(tls_key)) {
+ dev_warn(ctrl->ctrl.device, "queue %d: Invalid key %x\n",
+ qid, pskid);
+ queue->tls_err = -ENOKEY;
+ } else {
+ ctrl->ctrl.tls_key = tls_key;
+ queue->tls_err = 0;
+ }
+
+out_complete:
+ complete(&queue->tls_complete);
+}
+
+static int nvme_tcp_start_tls(struct nvme_ctrl *nctrl,
+ struct nvme_tcp_queue *queue,
+ key_serial_t pskid)
+{
+ int qid = nvme_tcp_queue_id(queue);
+ int ret;
+ struct tls_handshake_args args;
+ unsigned long tmo = tls_handshake_timeout * HZ;
+ key_serial_t keyring = nvme_keyring_id();
+
+ dev_dbg(nctrl->device, "queue %d: start TLS with key %x\n",
+ qid, pskid);
+ memset(&args, 0, sizeof(args));
+ args.ta_sock = queue->sock;
+ args.ta_done = nvme_tcp_tls_done;
+ args.ta_data = queue;
+ args.ta_my_peerids[0] = pskid;
+ args.ta_num_peerids = 1;
+ args.ta_keyring = keyring;
+ args.ta_timeout_ms = tls_handshake_timeout * 1000;
+ queue->tls_err = -EOPNOTSUPP;
+ init_completion(&queue->tls_complete);
+ ret = tls_client_hello_psk(&args, GFP_KERNEL);
+ if (ret) {
+ dev_err(nctrl->device, "queue %d: failed to start TLS: %d\n",
+ qid, ret);
+ return ret;
+ }
+ ret = wait_for_completion_interruptible_timeout(&queue->tls_complete, tmo);
+ if (ret <= 0) {
+ if (ret == 0)
+ ret = -ETIMEDOUT;
+
+ dev_err(nctrl->device,
+ "queue %d: TLS handshake failed, error %d\n",
+ qid, ret);
+ tls_handshake_cancel(queue->sock->sk);
+ } else {
+ dev_dbg(nctrl->device,
+ "queue %d: TLS handshake complete, error %d\n",
+ qid, queue->tls_err);
+ ret = queue->tls_err;
+ }
+ return ret;
+}
+#else
+static int nvme_tcp_start_tls(struct nvme_ctrl *nctrl,
+ struct nvme_tcp_queue *queue,
+ key_serial_t pskid)
+{
+ return -EPROTONOSUPPORT;
+}
+#endif
+
+static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid,
+ key_serial_t pskid)
{
struct nvme_tcp_ctrl *ctrl = to_tcp_ctrl(nctrl);
struct nvme_tcp_queue *queue = &ctrl->queues[qid];
@@ -1632,6 +1734,13 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid)
goto err_rcv_pdu;
}
+ /* If PSKs are configured try to start TLS */
+ if (pskid) {
+ ret = nvme_tcp_start_tls(nctrl, queue, pskid);
+ if (ret)
+ goto err_init_connect;
+ }
+
ret = nvme_tcp_init_connection(queue);
if (ret)
goto err_init_connect;
@@ -1781,10 +1890,22 @@ out_stop_queues:
static int nvme_tcp_alloc_admin_queue(struct nvme_ctrl *ctrl)
{
int ret;
+ key_serial_t pskid = 0;
+
+ if (ctrl->opts->tls) {
+ pskid = nvme_tls_psk_default(NULL,
+ ctrl->opts->host->nqn,
+ ctrl->opts->subsysnqn);
+ if (!pskid) {
+ dev_err(ctrl->device, "no valid PSK found\n");
+ ret = -ENOKEY;
+ goto out_free_queue;
+ }
+ }
- ret = nvme_tcp_alloc_queue(ctrl, 0);
+ ret = nvme_tcp_alloc_queue(ctrl, 0, pskid);
if (ret)
- return ret;
+ goto out_free_queue;
ret = nvme_tcp_alloc_async_req(to_tcp_ctrl(ctrl));
if (ret)
@@ -1801,8 +1922,13 @@ static int __nvme_tcp_alloc_io_queues(struct nvme_ctrl *ctrl)
{
int i, ret;
+ if (ctrl->opts->tls && !ctrl->tls_key) {
+ dev_err(ctrl->device, "no PSK negotiated\n");
+ return -ENOKEY;
+ }
for (i = 1; i < ctrl->queue_count; i++) {
- ret = nvme_tcp_alloc_queue(ctrl, i);
+ ret = nvme_tcp_alloc_queue(ctrl, i,
+ key_serial(ctrl->tls_key));
if (ret)
goto out_free_queues;
}
@@ -2630,7 +2756,7 @@ static struct nvmf_transport_ops nvme_tcp_transport = {
NVMF_OPT_HOST_TRADDR | NVMF_OPT_CTRL_LOSS_TMO |
NVMF_OPT_HDR_DIGEST | NVMF_OPT_DATA_DIGEST |
NVMF_OPT_NR_WRITE_QUEUES | NVMF_OPT_NR_POLL_QUEUES |
- NVMF_OPT_TOS | NVMF_OPT_HOST_IFACE,
+ NVMF_OPT_TOS | NVMF_OPT_HOST_IFACE | NVMF_OPT_TLS,
.create_ctrl = nvme_tcp_create_ctrl,
};