diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/sunrpc/Makefile | 2 | ||||
-rw-r--r-- | net/sunrpc/auth.c | 2 | ||||
-rw-r--r-- | net/sunrpc/auth_tls.c | 175 | ||||
-rw-r--r-- | net/sunrpc/clnt.c | 22 | ||||
-rw-r--r-- | net/sunrpc/rpcb_clnt.c | 39 | ||||
-rw-r--r-- | net/sunrpc/sysfs.c | 1 | ||||
-rw-r--r-- | net/sunrpc/sysfs.h | 7 | ||||
-rw-r--r-- | net/sunrpc/xprtsock.c | 434 |
8 files changed, 658 insertions, 24 deletions
diff --git a/net/sunrpc/Makefile b/net/sunrpc/Makefile index 1c8de397d6ad..f89c10fe7e6a 100644 --- a/net/sunrpc/Makefile +++ b/net/sunrpc/Makefile @@ -9,7 +9,7 @@ obj-$(CONFIG_SUNRPC_GSS) += auth_gss/ obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/ sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \ - auth.o auth_null.o auth_unix.o \ + auth.o auth_null.o auth_tls.o auth_unix.o \ svc.o svcsock.o svcauth.o svcauth_unix.o \ addr.o rpcb_clnt.o timer.o xdr.o \ sunrpc_syms.o cache.o rpc_pipe.o sysfs.o \ diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c index fb75a883503f..2f16f9d17966 100644 --- a/net/sunrpc/auth.c +++ b/net/sunrpc/auth.c @@ -32,7 +32,7 @@ static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS; static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = { [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops, [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops, - NULL, /* others can be loadable modules */ + [RPC_AUTH_TLS] = (const struct rpc_authops __force __rcu *)&authtls_ops, }; static LIST_HEAD(cred_unused); diff --git a/net/sunrpc/auth_tls.c b/net/sunrpc/auth_tls.c new file mode 100644 index 000000000000..de7678f8a23d --- /dev/null +++ b/net/sunrpc/auth_tls.c @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2021, 2022 Oracle. All rights reserved. + * + * The AUTH_TLS credential is used only to probe a remote peer + * for RPC-over-TLS support. + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/sunrpc/clnt.h> + +static const char *starttls_token = "STARTTLS"; +static const size_t starttls_len = 8; + +static struct rpc_auth tls_auth; +static struct rpc_cred tls_cred; + +static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + const void *obj) +{ +} + +static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + void *obj) +{ + return 0; +} + +static const struct rpc_procinfo rpcproc_tls_probe = { + .p_encode = tls_encode_probe, + .p_decode = tls_decode_probe, +}; + +static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data) +{ + task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT; + rpc_call_start(task); +} + +static void rpc_tls_probe_call_done(struct rpc_task *task, void *data) +{ +} + +static const struct rpc_call_ops rpc_tls_probe_ops = { + .rpc_call_prepare = rpc_tls_probe_call_prepare, + .rpc_call_done = rpc_tls_probe_call_done, +}; + +static int tls_probe(struct rpc_clnt *clnt) +{ + struct rpc_message msg = { + .rpc_proc = &rpcproc_tls_probe, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = &msg, + .rpc_op_cred = &tls_cred, + .callback_ops = &rpc_tls_probe_ops, + .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN, + }; + struct rpc_task *task; + int status; + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; +} + +static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args, + struct rpc_clnt *clnt) +{ + refcount_inc(&tls_auth.au_count); + return &tls_auth; +} + +static void tls_destroy(struct rpc_auth *auth) +{ +} + +static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth, + struct auth_cred *acred, int flags) +{ + return get_rpccred(&tls_cred); +} + +static void tls_destroy_cred(struct rpc_cred *cred) +{ +} + +static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) +{ + return 1; +} + +static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4 * XDR_UNIT); + if (!p) + return -EMSGSIZE; + /* Credential */ + *p++ = rpc_auth_tls; + *p++ = xdr_zero; + /* Verifier */ + *p++ = rpc_auth_null; + *p = xdr_zero; + return 0; +} + +static int tls_refresh(struct rpc_task *task) +{ + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); + return 0; +} + +static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr) +{ + __be32 *p; + void *str; + + p = xdr_inline_decode(xdr, XDR_UNIT); + if (!p) + return -EIO; + if (*p != rpc_auth_null) + return -EIO; + if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len) + return -EIO; + if (memcmp(str, starttls_token, starttls_len)) + return -EIO; + return 0; +} + +const struct rpc_authops authtls_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_TLS, + .au_name = "NULL", + .create = tls_create, + .destroy = tls_destroy, + .lookup_cred = tls_lookup_cred, + .ping = tls_probe, +}; + +static struct rpc_auth tls_auth = { + .au_cslack = NUL_CALLSLACK, + .au_rslack = NUL_REPLYSLACK, + .au_verfsize = NUL_REPLYSLACK, + .au_ralign = NUL_REPLYSLACK, + .au_ops = &authtls_ops, + .au_flavor = RPC_AUTH_TLS, + .au_count = REFCOUNT_INIT(1), +}; + +static const struct rpc_credops tls_credops = { + .cr_name = "AUTH_TLS", + .crdestroy = tls_destroy_cred, + .crmatch = tls_match, + .crmarshal = tls_marshal, + .crwrap_req = rpcauth_wrap_req_encode, + .crrefresh = tls_refresh, + .crvalidate = tls_validate, + .crunwrap_resp = rpcauth_unwrap_resp_decode, +}; + +static struct rpc_cred tls_cred = { + .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru), + .cr_auth = &tls_auth, + .cr_ops = &tls_credops, + .cr_count = REFCOUNT_INIT(2), + .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, +}; diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c index d2ee56634308..d7c697af3762 100644 --- a/net/sunrpc/clnt.c +++ b/net/sunrpc/clnt.c @@ -385,6 +385,7 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, if (!clnt) goto out_err; clnt->cl_parent = parent ? : clnt; + clnt->cl_xprtsec = args->xprtsec; err = rpc_alloc_clid(clnt); if (err) @@ -434,7 +435,7 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, if (parent) refcount_inc(&parent->cl_count); - trace_rpc_clnt_new(clnt, xprt, program->name, args->servername); + trace_rpc_clnt_new(clnt, xprt, args); return clnt; out_no_path: @@ -532,6 +533,7 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args) .addrlen = args->addrsize, .servername = args->servername, .bc_xprt = args->bc_xprt, + .xprtsec = args->xprtsec, }; char servername[48]; struct rpc_clnt *clnt; @@ -565,8 +567,12 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args) servername[0] = '\0'; switch (args->address->sa_family) { case AF_LOCAL: - snprintf(servername, sizeof(servername), "%s", - sun->sun_path); + if (sun->sun_path[0]) + snprintf(servername, sizeof(servername), "%s", + sun->sun_path); + else + snprintf(servername, sizeof(servername), "@%s", + sun->sun_path+1); break; case AF_INET: snprintf(servername, sizeof(servername), "%pI4", @@ -727,6 +733,7 @@ int rpc_switch_client_transport(struct rpc_clnt *clnt, struct rpc_clnt *parent; int err; + args->xprtsec = clnt->cl_xprtsec; xprt = xprt_create_transport(args); if (IS_ERR(xprt)) return PTR_ERR(xprt); @@ -1717,6 +1724,11 @@ call_start(struct rpc_task *task) trace_rpc_request(task); + if (task->tk_client->cl_shutdown) { + rpc_call_rpcerror(task, -EIO); + return; + } + /* Increment call count (version might not be valid for ping) */ if (clnt->cl_program->version[clnt->cl_vers]) clnt->cl_program->version[clnt->cl_vers]->counts[idx]++; @@ -2826,6 +2838,9 @@ static int rpc_ping(struct rpc_clnt *clnt) struct rpc_task *task; int status; + if (clnt->cl_auth->au_ops->ping) + return clnt->cl_auth->au_ops->ping(clnt); + task = rpc_call_null_helper(clnt, NULL, NULL, 0, NULL, NULL); if (IS_ERR(task)) return PTR_ERR(task); @@ -3046,6 +3061,7 @@ int rpc_clnt_add_xprt(struct rpc_clnt *clnt, if (!xprtargs->ident) xprtargs->ident = ident; + xprtargs->xprtsec = clnt->cl_xprtsec; xprt = xprt_create_transport(xprtargs); if (IS_ERR(xprt)) { ret = PTR_ERR(xprt); diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c index 5a8e6d46809a..5988a5c5ff3f 100644 --- a/net/sunrpc/rpcb_clnt.c +++ b/net/sunrpc/rpcb_clnt.c @@ -36,6 +36,7 @@ #include "netns.h" #define RPCBIND_SOCK_PATHNAME "/var/run/rpcbind.sock" +#define RPCBIND_SOCK_ABSTRACT_NAME "\0/run/rpcbind.sock" #define RPCBIND_PROGRAM (100000u) #define RPCBIND_PORT (111u) @@ -216,21 +217,22 @@ static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt, sn->rpcb_users = 1; } +/* Evaluate to actual length of the `sockaddr_un' structure. */ +# define SUN_LEN(ptr) (offsetof(struct sockaddr_un, sun_path) \ + + 1 + strlen((ptr)->sun_path + 1)) + /* * Returns zero on success, otherwise a negative errno value * is returned. */ -static int rpcb_create_local_unix(struct net *net) +static int rpcb_create_af_local(struct net *net, + const struct sockaddr_un *addr) { - static const struct sockaddr_un rpcb_localaddr_rpcbind = { - .sun_family = AF_LOCAL, - .sun_path = RPCBIND_SOCK_PATHNAME, - }; struct rpc_create_args args = { .net = net, .protocol = XPRT_TRANSPORT_LOCAL, - .address = (struct sockaddr *)&rpcb_localaddr_rpcbind, - .addrsize = sizeof(rpcb_localaddr_rpcbind), + .address = (struct sockaddr *)addr, + .addrsize = SUN_LEN(addr), .servername = "localhost", .program = &rpcb_program, .version = RPCBVERS_2, @@ -269,6 +271,26 @@ out: return result; } +static int rpcb_create_local_abstract(struct net *net) +{ + static const struct sockaddr_un rpcb_localaddr_abstract = { + .sun_family = AF_LOCAL, + .sun_path = RPCBIND_SOCK_ABSTRACT_NAME, + }; + + return rpcb_create_af_local(net, &rpcb_localaddr_abstract); +} + +static int rpcb_create_local_unix(struct net *net) +{ + static const struct sockaddr_un rpcb_localaddr_unix = { + .sun_family = AF_LOCAL, + .sun_path = RPCBIND_SOCK_PATHNAME, + }; + + return rpcb_create_af_local(net, &rpcb_localaddr_unix); +} + /* * Returns zero on success, otherwise a negative errno value * is returned. @@ -332,7 +354,8 @@ int rpcb_create_local(struct net *net) if (rpcb_get_local(net)) goto out; - if (rpcb_create_local_unix(net) != 0) + if (rpcb_create_local_abstract(net) != 0 && + rpcb_create_local_unix(net) != 0) result = rpcb_create_local_net(net); out: diff --git a/net/sunrpc/sysfs.c b/net/sunrpc/sysfs.c index 0d0db4e1064e..5c8ecdaaa985 100644 --- a/net/sunrpc/sysfs.c +++ b/net/sunrpc/sysfs.c @@ -239,6 +239,7 @@ static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj, if (!xprt) return 0; if (!(xprt->xprt_class->ident == XPRT_TRANSPORT_TCP || + xprt->xprt_class->ident == XPRT_TRANSPORT_TCP_TLS || xprt->xprt_class->ident == XPRT_TRANSPORT_RDMA)) { xprt_put(xprt); return -EOPNOTSUPP; diff --git a/net/sunrpc/sysfs.h b/net/sunrpc/sysfs.h index 6620cebd1037..d2dd77a0a0e9 100644 --- a/net/sunrpc/sysfs.h +++ b/net/sunrpc/sysfs.h @@ -5,13 +5,6 @@ #ifndef __SUNRPC_SYSFS_H #define __SUNRPC_SYSFS_H -struct rpc_sysfs_client { - struct kobject kobject; - struct net *net; - struct rpc_clnt *clnt; - struct rpc_xprt_switch *xprt_switch; -}; - struct rpc_sysfs_xprt_switch { struct kobject kobject; struct net *net; diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index 5f9030b81c9e..9f010369100a 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -47,6 +47,9 @@ #include <net/checksum.h> #include <net/udp.h> #include <net/tcp.h> +#include <net/tls.h> +#include <net/handshake.h> + #include <linux/bvec.h> #include <linux/highmem.h> #include <linux/uio.h> @@ -96,6 +99,7 @@ static struct ctl_table_header *sunrpc_table_header; static struct xprt_class xs_local_transport; static struct xprt_class xs_udp_transport; static struct xprt_class xs_tcp_transport; +static struct xprt_class xs_tcp_tls_transport; static struct xprt_class xs_bc_tcp_transport; /* @@ -187,6 +191,11 @@ static struct ctl_table xs_tunables_table[] = { */ #define XS_IDLE_DISC_TO (5U * 60 * HZ) +/* + * TLS handshake timeout. + */ +#define XS_TLS_HANDSHAKE_TO (10U * HZ) + #if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # undef RPC_DEBUG_DATA # define RPCDBG_FACILITY RPCDBG_TRANS @@ -253,7 +262,12 @@ static void xs_format_common_peer_addresses(struct rpc_xprt *xprt) switch (sap->sa_family) { case AF_LOCAL: sun = xs_addr_un(xprt); - strscpy(buf, sun->sun_path, sizeof(buf)); + if (sun->sun_path[0]) { + strscpy(buf, sun->sun_path, sizeof(buf)); + } else { + buf[0] = '@'; + strscpy(buf+1, sun->sun_path+1, sizeof(buf)-1); + } xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL); break; @@ -342,13 +356,56 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp) return want; } +static int +xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, + struct cmsghdr *cmsg, int ret) +{ + if (cmsg->cmsg_level == SOL_TLS && + cmsg->cmsg_type == TLS_GET_RECORD_TYPE) { + u8 content_type = *((u8 *)CMSG_DATA(cmsg)); + + switch (content_type) { + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + msg->msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + ret = -ENOTCONN; + break; + default: + ret = -EAGAIN; + } + } + return ret; +} + +static int +xs_sock_recv_cmsg(struct socket *sock, struct msghdr *msg, int flags) +{ + union { + struct cmsghdr cmsg; + u8 buf[CMSG_SPACE(sizeof(u8))]; + } u; + int ret; + + msg->msg_control = &u; + msg->msg_controllen = sizeof(u); + ret = sock_recvmsg(sock, msg, flags); + if (msg->msg_controllen != sizeof(u)) + ret = xs_sock_process_cmsg(sock, msg, &u.cmsg, ret); + return ret; +} + static ssize_t xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek) { ssize_t ret; if (seek != 0) iov_iter_advance(&msg->msg_iter, seek); - ret = sock_recvmsg(sock, msg, flags); + ret = xs_sock_recv_cmsg(sock, msg, flags); return ret > 0 ? ret + seek : ret; } @@ -374,7 +431,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags, size_t count) { iov_iter_discard(&msg->msg_iter, ITER_DEST, count); - return sock_recvmsg(sock, msg, flags); + return xs_sock_recv_cmsg(sock, msg, flags); } #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE @@ -695,6 +752,8 @@ static void xs_poll_check_readable(struct sock_xprt *transport) { clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state)) + return; if (!xs_poll_socket_readable(transport)) return; if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) @@ -1191,6 +1250,8 @@ static void xs_reset_transport(struct sock_xprt *transport) if (atomic_read(&transport->xprt.swapper)) sk_clear_memalloc(sk); + tls_handshake_cancel(sk); + kernel_sock_shutdown(sock, SHUT_RDWR); mutex_lock(&transport->recv_mutex); @@ -1380,6 +1441,10 @@ static void xs_data_ready(struct sock *sk) trace_xs_data_ready(xprt); transport->old_data_ready(sk); + + if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state)) + return; + /* Any data means we had a useful conversation, so * then we don't need to delay the next reconnect */ @@ -2360,6 +2425,267 @@ out_unlock: current_restore_flags(pflags, PF_MEMALLOC); } +/* + * Transfer the connected socket to @upper_transport, then mark that + * xprt CONNECTED. + */ +static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt, + struct sock_xprt *upper_transport) +{ + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + struct rpc_xprt *upper_xprt = &upper_transport->xprt; + + if (!upper_transport->inet) { + struct socket *sock = lower_transport->sock; + struct sock *sk = sock->sk; + + /* Avoid temporary address, they are bad for long-lived + * connections such as NFS mounts. + * RFC4941, section 3.6 suggests that: + * Individual applications, which have specific + * knowledge about the normal duration of connections, + * MAY override this as appropriate. + */ + if (xs_addr(upper_xprt)->sa_family == PF_INET6) + ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC); + + xs_tcp_set_socket_timeouts(upper_xprt, sock); + tcp_sock_set_nodelay(sk); + + lock_sock(sk); + + /* @sk is already connected, so it now has the RPC callbacks. + * Reach into @lower_transport to save the original ones. + */ + upper_transport->old_data_ready = lower_transport->old_data_ready; + upper_transport->old_state_change = lower_transport->old_state_change; + upper_transport->old_write_space = lower_transport->old_write_space; + upper_transport->old_error_report = lower_transport->old_error_report; + sk->sk_user_data = upper_xprt; + + /* socket options */ + sock_reset_flag(sk, SOCK_LINGER); + + xprt_clear_connected(upper_xprt); + + upper_transport->sock = sock; + upper_transport->inet = sk; + upper_transport->file = lower_transport->file; + + release_sock(sk); + + /* Reset lower_transport before shutting down its clnt */ + mutex_lock(&lower_transport->recv_mutex); + lower_transport->inet = NULL; + lower_transport->sock = NULL; + lower_transport->file = NULL; + + xprt_clear_connected(lower_xprt); + xs_sock_reset_connection_flags(lower_xprt); + xs_stream_reset_connect(lower_transport); + mutex_unlock(&lower_transport->recv_mutex); + } + + if (!xprt_bound(upper_xprt)) + return -ENOTCONN; + + xs_set_memalloc(upper_xprt); + + if (!xprt_test_and_set_connected(upper_xprt)) { + upper_xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + + upper_xprt->stat.connect_count++; + upper_xprt->stat.connect_time += (long)jiffies - + upper_xprt->stat.connect_start; + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + } + return 0; +} + +/** + * xs_tls_handshake_done - TLS handshake completion handler + * @data: address of xprt to wake + * @status: status of handshake + * @peerid: serial number of key containing the remote's identity + * + */ +static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid) +{ + struct rpc_xprt *lower_xprt = data; + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + + lower_transport->xprt_err = status ? -EACCES : 0; + complete(&lower_transport->handshake_done); + xprt_put(lower_xprt); +} + +static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec) +{ + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + struct tls_handshake_args args = { + .ta_sock = lower_transport->sock, + .ta_done = xs_tls_handshake_done, + .ta_data = xprt_get(lower_xprt), + .ta_peername = lower_xprt->servername, + }; + struct sock *sk = lower_transport->inet; + int rc; + + init_completion(&lower_transport->handshake_done); + set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state); + lower_transport->xprt_err = -ETIMEDOUT; + switch (xprtsec->policy) { + case RPC_XPRTSEC_TLS_ANON: + rc = tls_client_hello_anon(&args, GFP_KERNEL); + if (rc) + goto out_put_xprt; + break; + case RPC_XPRTSEC_TLS_X509: + args.ta_my_cert = xprtsec->cert_serial; + args.ta_my_privkey = xprtsec->privkey_serial; + rc = tls_client_hello_x509(&args, GFP_KERNEL); + if (rc) + goto out_put_xprt; + break; + default: + rc = -EACCES; + goto out_put_xprt; + } + + rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done, + XS_TLS_HANDSHAKE_TO); + if (rc <= 0) { + if (!tls_handshake_cancel(sk)) { + if (rc == 0) + rc = -ETIMEDOUT; + goto out_put_xprt; + } + } + + rc = lower_transport->xprt_err; + +out: + xs_stream_reset_connect(lower_transport); + clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state); + return rc; + +out_put_xprt: + xprt_put(lower_xprt); + goto out; +} + +/** + * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket + * @work: queued work item + * + * Invoked by a work queue tasklet. + * + * For RPC-with-TLS, there is a two-stage connection process. + * + * The "upper-layer xprt" is visible to the RPC consumer. Once it has + * been marked connected, the consumer knows that a TCP connection and + * a TLS session have been established. + * + * A "lower-layer xprt", created in this function, handles the mechanics + * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and + * then driving the TLS handshake. Once all that is complete, the upper + * layer xprt is marked connected. + */ +static void xs_tcp_tls_setup_socket(struct work_struct *work) +{ + struct sock_xprt *upper_transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_clnt *upper_clnt = upper_transport->clnt; + struct rpc_xprt *upper_xprt = &upper_transport->xprt; + struct rpc_create_args args = { + .net = upper_xprt->xprt_net, + .protocol = upper_xprt->prot, + .address = (struct sockaddr *)&upper_xprt->addr, + .addrsize = upper_xprt->addrlen, + .timeout = upper_clnt->cl_timeout, + .servername = upper_xprt->servername, + .program = upper_clnt->cl_program, + .prognumber = upper_clnt->cl_prog, + .version = upper_clnt->cl_vers, + .authflavor = RPC_AUTH_TLS, + .cred = upper_clnt->cl_cred, + .xprtsec = { + .policy = RPC_XPRTSEC_NONE, + }, + }; + unsigned int pflags = current->flags; + struct rpc_clnt *lower_clnt; + struct rpc_xprt *lower_xprt; + int status; + + if (atomic_read(&upper_xprt->swapper)) + current->flags |= PF_MEMALLOC; + + xs_stream_start_connect(upper_transport); + + /* This implicitly sends an RPC_AUTH_TLS probe */ + lower_clnt = rpc_create(&args); + if (IS_ERR(lower_clnt)) { + trace_rpc_tls_unavailable(upper_clnt, upper_xprt); + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt)); + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + goto out_unlock; + } + + /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on + * the lower xprt. + */ + rcu_read_lock(); + lower_xprt = rcu_dereference(lower_clnt->cl_xprt); + rcu_read_unlock(); + status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec); + if (status) { + trace_rpc_tls_not_started(upper_clnt, upper_xprt); + goto out_close; + } + + status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport); + if (status) + goto out_close; + + trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0); + if (!xprt_test_and_set_connected(upper_xprt)) { + upper_xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + + upper_xprt->stat.connect_count++; + upper_xprt->stat.connect_time += (long)jiffies - + upper_xprt->stat.connect_start; + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + } + rpc_shutdown_client(lower_clnt); + +out_unlock: + current_restore_flags(pflags, PF_MEMALLOC); + upper_transport->clnt = NULL; + xprt_unlock_connect(upper_xprt, upper_transport); + return; + +out_close: + rpc_shutdown_client(lower_clnt); + + /* xprt_force_disconnect() wakes tasks with a fixed tk_status code. + * Wake them first here to ensure they get our tk_status code. + */ + xprt_wake_pending_tasks(upper_xprt, status); + xs_tcp_force_close(upper_xprt); + xprt_clear_connecting(upper_xprt); + goto out_unlock; +} + /** * xs_connect - connect a socket to a remote endpoint * @xprt: pointer to transport structure @@ -2391,6 +2717,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task) } else dprintk("RPC: xs_connect scheduled xprt %p\n", xprt); + transport->clnt = task->tk_client; queue_delayed_work(xprtiod_workqueue, &transport->connect_worker, delay); @@ -2858,7 +3185,7 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) switch (sun->sun_family) { case AF_LOCAL: - if (sun->sun_path[0] != '/') { + if (sun->sun_path[0] != '/' && sun->sun_path[0] != '\0') { dprintk("RPC: bad AF_LOCAL address: %s\n", sun->sun_path); ret = ERR_PTR(-EINVAL); @@ -3045,6 +3372,94 @@ out_err: } /** + * xs_setup_tcp_tls - Set up transport to use a TCP with TLS + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct rpc_xprt *ret; + unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries; + + if (args->flags & XPRT_CREATE_INFINITE_SLOTS) + max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + max_slot_table_size); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->xprt_class = &xs_tcp_transport; + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_tcp_ops; + xprt->timeout = &xs_tcp_default_timeout; + + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + xprt->connect_timeout = xprt->timeout->to_initval * + (xprt->timeout->to_retries + 1); + + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + + switch (args->xprtsec.policy) { + case RPC_XPRTSEC_TLS_ANON: + case RPC_XPRTSEC_TLS_X509: + xprt->xprtsec = args->xprtsec; + INIT_DELAYED_WORK(&transport->connect_worker, + xs_tcp_tls_setup_socket); + break; + default: + ret = ERR_PTR(-EACCES); + goto out_err; + } + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + if (xprt_bound(xprt)) + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + else + dprintk("RPC: set up xprt to %s (autobind) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +/** * xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket * @args: rpc transport creation arguments * @@ -3153,6 +3568,15 @@ static struct xprt_class xs_tcp_transport = { .netid = { "tcp", "tcp6", "" }, }; +static struct xprt_class xs_tcp_tls_transport = { + .list = LIST_HEAD_INIT(xs_tcp_tls_transport.list), + .name = "tcp-with-tls", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_TCP_TLS, + .setup = xs_setup_tcp_tls, + .netid = { "tcp", "tcp6", "" }, +}; + static struct xprt_class xs_bc_tcp_transport = { .list = LIST_HEAD_INIT(xs_bc_tcp_transport.list), .name = "tcp NFSv4.1 backchannel", @@ -3174,6 +3598,7 @@ int init_socket_xprt(void) xprt_register_transport(&xs_local_transport); xprt_register_transport(&xs_udp_transport); xprt_register_transport(&xs_tcp_transport); + xprt_register_transport(&xs_tcp_tls_transport); xprt_register_transport(&xs_bc_tcp_transport); return 0; @@ -3193,6 +3618,7 @@ void cleanup_socket_xprt(void) xprt_unregister_transport(&xs_local_transport); xprt_unregister_transport(&xs_udp_transport); xprt_unregister_transport(&xs_tcp_transport); + xprt_unregister_transport(&xs_tcp_tls_transport); xprt_unregister_transport(&xs_bc_tcp_transport); } |