diff options
Diffstat (limited to 'net/vmw_vsock')
-rw-r--r-- | net/vmw_vsock/af_vsock.c | 397 | ||||
-rw-r--r-- | net/vmw_vsock/hyperv_transport.c | 94 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport.c | 177 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 223 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport.c | 142 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport.h | 3 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport_notify.h | 1 |
7 files changed, 594 insertions, 443 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 582a3e4dfce2..74db4cd637a7 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -126,19 +126,18 @@ static struct proto vsock_proto = { */ #define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ) -static const struct vsock_transport *transport; +#define VSOCK_DEFAULT_BUFFER_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128 + +/* Transport used for host->guest communication */ +static const struct vsock_transport *transport_h2g; +/* Transport used for guest->host communication */ +static const struct vsock_transport *transport_g2h; +/* Transport used for DGRAM communication */ +static const struct vsock_transport *transport_dgram; static DEFINE_MUTEX(vsock_register_mutex); -/**** EXPORTS ****/ - -/* Get the ID of the local context. This is transport dependent. */ - -int vm_sockets_get_local_cid(void) -{ - return transport->get_local_cid(); -} -EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid); - /**** UTILS ****/ /* Each bound VSocket is stored in the bind hash table and each connected @@ -188,7 +187,7 @@ static int vsock_auto_bind(struct vsock_sock *vsk) return __vsock_bind(sk, &local_addr); } -static int __init vsock_init_tables(void) +static void vsock_init_tables(void) { int i; @@ -197,7 +196,6 @@ static int __init vsock_init_tables(void) for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) INIT_LIST_HEAD(&vsock_connected_table[i]); - return 0; } static void __vsock_insert_bound(struct list_head *list, @@ -230,9 +228,15 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr) { struct vsock_sock *vsk; - list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) - if (addr->svm_port == vsk->local_addr.svm_port) + list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) { + if (vsock_addr_equals_addr(addr, &vsk->local_addr)) + return sk_vsock(vsk); + + if (addr->svm_port == vsk->local_addr.svm_port && + (vsk->local_addr.svm_cid == VMADDR_CID_ANY || + addr->svm_cid == VMADDR_CID_ANY)) return sk_vsock(vsk); + } return NULL; } @@ -382,6 +386,88 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected) } EXPORT_SYMBOL_GPL(vsock_enqueue_accept); +static void vsock_deassign_transport(struct vsock_sock *vsk) +{ + if (!vsk->transport) + return; + + vsk->transport->destruct(vsk); + module_put(vsk->transport->module); + vsk->transport = NULL; +} + +/* Assign a transport to a socket and call the .init transport callback. + * + * Note: for stream socket this must be called when vsk->remote_addr is set + * (e.g. during the connect() or when a connection request on a listener + * socket is received). + * The vsk->remote_addr is used to decide which transport to use: + * - remote CID <= VMADDR_CID_HOST will use guest->host transport; + * - remote CID == local_cid (guest->host transport) will use guest->host + * transport for loopback (host->guest transports don't support loopback); + * - remote CID > VMADDR_CID_HOST will use host->guest transport; + */ +int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + const struct vsock_transport *new_transport; + struct sock *sk = sk_vsock(vsk); + unsigned int remote_cid = vsk->remote_addr.svm_cid; + int ret; + + switch (sk->sk_type) { + case SOCK_DGRAM: + new_transport = transport_dgram; + break; + case SOCK_STREAM: + if (remote_cid <= VMADDR_CID_HOST || + (transport_g2h && + remote_cid == transport_g2h->get_local_cid())) + new_transport = transport_g2h; + else + new_transport = transport_h2g; + break; + default: + return -ESOCKTNOSUPPORT; + } + + if (vsk->transport) { + if (vsk->transport == new_transport) + return 0; + + vsk->transport->release(vsk); + vsock_deassign_transport(vsk); + } + + /* We increase the module refcnt to prevent the transport unloading + * while there are open sockets assigned to it. + */ + if (!new_transport || !try_module_get(new_transport->module)) + return -ENODEV; + + ret = new_transport->init(vsk, psk); + if (ret) { + module_put(new_transport->module); + return ret; + } + + vsk->transport = new_transport; + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_assign_transport); + +bool vsock_find_cid(unsigned int cid) +{ + if (transport_g2h && cid == transport_g2h->get_local_cid()) + return true; + + if (transport_h2g && cid == VMADDR_CID_HOST) + return true; + + return false; +} +EXPORT_SYMBOL_GPL(vsock_find_cid); + static struct sock *vsock_dequeue_accept(struct sock *listener) { struct vsock_sock *vlistener; @@ -418,7 +504,12 @@ static bool vsock_is_pending(struct sock *sk) static int vsock_send_shutdown(struct sock *sk, int mode) { - return transport->shutdown(vsock_sk(sk), mode); + struct vsock_sock *vsk = vsock_sk(sk); + + if (!vsk->transport) + return -ENODEV; + + return vsk->transport->shutdown(vsk, mode); } static void vsock_pending_work(struct work_struct *work) @@ -439,7 +530,7 @@ static void vsock_pending_work(struct work_struct *work) if (vsock_is_pending(sk)) { vsock_remove_pending(listener, sk); - listener->sk_ack_backlog--; + sk_acceptq_removed(listener); } else if (!vsk->rejected) { /* We are not on the pending list and accept() did not reject * us, so we must have been accepted by our user process. We @@ -528,13 +619,12 @@ static int __vsock_bind_stream(struct vsock_sock *vsk, static int __vsock_bind_dgram(struct vsock_sock *vsk, struct sockaddr_vm *addr) { - return transport->dgram_bind(vsk, addr); + return vsk->transport->dgram_bind(vsk, addr); } static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) { struct vsock_sock *vsk = vsock_sk(sk); - u32 cid; int retval; /* First ensure this socket isn't already bound. */ @@ -544,10 +634,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) /* Now bind to the provided address or select appropriate values if * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). Note that * like AF_INET prevents binding to a non-local IP address (in most - * cases), we only allow binding to the local CID. + * cases), we only allow binding to a local CID. */ - cid = transport->get_local_cid(); - if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY) + if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid)) return -EADDRNOTAVAIL; switch (sk->sk_socket->type) { @@ -571,12 +660,12 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) static void vsock_connect_timeout(struct work_struct *work); -struct sock *__vsock_create(struct net *net, - struct socket *sock, - struct sock *parent, - gfp_t priority, - unsigned short type, - int kern) +static struct sock *__vsock_create(struct net *net, + struct socket *sock, + struct sock *parent, + gfp_t priority, + unsigned short type, + int kern) { struct sock *sk; struct vsock_sock *psk; @@ -620,28 +709,24 @@ struct sock *__vsock_create(struct net *net, vsk->trusted = psk->trusted; vsk->owner = get_cred(psk->owner); vsk->connect_timeout = psk->connect_timeout; + vsk->buffer_size = psk->buffer_size; + vsk->buffer_min_size = psk->buffer_min_size; + vsk->buffer_max_size = psk->buffer_max_size; } else { vsk->trusted = capable(CAP_NET_ADMIN); vsk->owner = get_current_cred(); vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT; + vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE; + vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE; + vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE; } - if (transport->init(vsk, psk) < 0) { - sk_free(sk); - return NULL; - } - - if (sock) - vsock_insert_unbound(vsk); - return sk; } -EXPORT_SYMBOL_GPL(__vsock_create); static void __vsock_release(struct sock *sk, int level) { if (sk) { - struct sk_buff *skb; struct sock *pending; struct vsock_sock *vsk; @@ -651,7 +736,10 @@ static void __vsock_release(struct sock *sk, int level) /* The release call is supposed to use lock_sock_nested() * rather than lock_sock(), if a sock lock should be acquired. */ - transport->release(vsk); + if (vsk->transport) + vsk->transport->release(vsk); + else if (sk->sk_type == SOCK_STREAM) + vsock_remove_sock(vsk); /* When "level" is SINGLE_DEPTH_NESTING, use the nested * version to avoid the warning "possible recursive locking @@ -662,8 +750,7 @@ static void __vsock_release(struct sock *sk, int level) sock_orphan(sk); sk->sk_shutdown = SHUTDOWN_MASK; - while ((skb = skb_dequeue(&sk->sk_receive_queue))) - kfree_skb(skb); + skb_queue_purge(&sk->sk_receive_queue); /* Clean up any sockets that never were accepted. */ while ((pending = vsock_dequeue_accept(sk)) != NULL) { @@ -680,7 +767,7 @@ static void vsock_sk_destruct(struct sock *sk) { struct vsock_sock *vsk = vsock_sk(sk); - transport->destruct(vsk); + vsock_deassign_transport(vsk); /* When clearing these addresses, there's no need to set the family and * possibly register the address family with the kernel. @@ -702,15 +789,22 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) return err; } +struct sock *vsock_create_connected(struct sock *parent) +{ + return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL, + parent->sk_type, 0); +} +EXPORT_SYMBOL_GPL(vsock_create_connected); + s64 vsock_stream_has_data(struct vsock_sock *vsk) { - return transport->stream_has_data(vsk); + return vsk->transport->stream_has_data(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_data); s64 vsock_stream_has_space(struct vsock_sock *vsk) { - return transport->stream_has_space(vsk); + return vsk->transport->stream_has_space(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_space); @@ -879,6 +973,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; } else if (sock->type == SOCK_STREAM) { + const struct vsock_transport *transport = vsk->transport; lock_sock(sk); /* Listening sockets that have connections in their accept @@ -889,7 +984,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, mask |= EPOLLIN | EPOLLRDNORM; /* If there is something in the queue then we can read. */ - if (transport->stream_is_active(vsk) && + if (transport && transport->stream_is_active(vsk) && !(sk->sk_shutdown & RCV_SHUTDOWN)) { bool data_ready_now = false; int ret = transport->notify_poll_in( @@ -954,6 +1049,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, struct sock *sk; struct vsock_sock *vsk; struct sockaddr_vm *remote_addr; + const struct vsock_transport *transport; if (msg->msg_flags & MSG_OOB) return -EOPNOTSUPP; @@ -962,6 +1058,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, err = 0; sk = sock->sk; vsk = vsock_sk(sk); + transport = vsk->transport; lock_sock(sk); @@ -1046,8 +1143,8 @@ static int vsock_dgram_connect(struct socket *sock, if (err) goto out; - if (!transport->dgram_allow(remote_addr->svm_cid, - remote_addr->svm_port)) { + if (!vsk->transport->dgram_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { err = -EINVAL; goto out; } @@ -1063,7 +1160,9 @@ out: static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, int flags) { - return transport->dgram_dequeue(vsock_sk(sock->sk), msg, len, flags); + struct vsock_sock *vsk = vsock_sk(sock->sk); + + return vsk->transport->dgram_dequeue(vsk, msg, len, flags); } static const struct proto_ops vsock_dgram_ops = { @@ -1089,6 +1188,8 @@ static const struct proto_ops vsock_dgram_ops = { static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) { + const struct vsock_transport *transport = vsk->transport; + if (!transport->cancel_pkt) return -EOPNOTSUPP; @@ -1125,6 +1226,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, int err; struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; struct sockaddr_vm *remote_addr; long timeout; DEFINE_WAIT(wait); @@ -1159,19 +1261,26 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, goto out; } + /* Set the remote address that we are connecting to. */ + memcpy(&vsk->remote_addr, remote_addr, + sizeof(vsk->remote_addr)); + + err = vsock_assign_transport(vsk, NULL); + if (err) + goto out; + + transport = vsk->transport; + /* The hypervisor and well-known contexts do not have socket * endpoints. */ - if (!transport->stream_allow(remote_addr->svm_cid, + if (!transport || + !transport->stream_allow(remote_addr->svm_cid, remote_addr->svm_port)) { err = -ENETUNREACH; goto out; } - /* Set the remote address that we are connecting to. */ - memcpy(&vsk->remote_addr, remote_addr, - sizeof(vsk->remote_addr)); - err = vsock_auto_bind(vsk); if (err) goto out; @@ -1301,7 +1410,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags, err = -listener->sk_err; if (connected) { - listener->sk_ack_backlog--; + sk_acceptq_removed(listener); lock_sock_nested(connected, SINGLE_DEPTH_NESTING); vconnected = vsock_sk(connected); @@ -1366,6 +1475,23 @@ out: return err; } +static void vsock_update_buffer_size(struct vsock_sock *vsk, + const struct vsock_transport *transport, + u64 val) +{ + if (val > vsk->buffer_max_size) + val = vsk->buffer_max_size; + + if (val < vsk->buffer_min_size) + val = vsk->buffer_min_size; + + if (val != vsk->buffer_size && + transport && transport->notify_buffer_size) + transport->notify_buffer_size(vsk, &val); + + vsk->buffer_size = val; +} + static int vsock_stream_setsockopt(struct socket *sock, int level, int optname, @@ -1375,6 +1501,7 @@ static int vsock_stream_setsockopt(struct socket *sock, int err; struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; u64 val; if (level != AF_VSOCK) @@ -1395,23 +1522,26 @@ static int vsock_stream_setsockopt(struct socket *sock, err = 0; sk = sock->sk; vsk = vsock_sk(sk); + transport = vsk->transport; lock_sock(sk); switch (optname) { case SO_VM_SOCKETS_BUFFER_SIZE: COPY_IN(val); - transport->set_buffer_size(vsk, val); + vsock_update_buffer_size(vsk, transport, val); break; case SO_VM_SOCKETS_BUFFER_MAX_SIZE: COPY_IN(val); - transport->set_max_buffer_size(vsk, val); + vsk->buffer_max_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); break; case SO_VM_SOCKETS_BUFFER_MIN_SIZE: COPY_IN(val); - transport->set_min_buffer_size(vsk, val); + vsk->buffer_min_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); break; case SO_VM_SOCKETS_CONNECT_TIMEOUT: { @@ -1478,17 +1608,17 @@ static int vsock_stream_getsockopt(struct socket *sock, switch (optname) { case SO_VM_SOCKETS_BUFFER_SIZE: - val = transport->get_buffer_size(vsk); + val = vsk->buffer_size; COPY_OUT(val); break; case SO_VM_SOCKETS_BUFFER_MAX_SIZE: - val = transport->get_max_buffer_size(vsk); + val = vsk->buffer_max_size; COPY_OUT(val); break; case SO_VM_SOCKETS_BUFFER_MIN_SIZE: - val = transport->get_min_buffer_size(vsk); + val = vsk->buffer_min_size; COPY_OUT(val); break; @@ -1519,6 +1649,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, { struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; ssize_t total_written; long timeout; int err; @@ -1527,6 +1658,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, sk = sock->sk; vsk = vsock_sk(sk); + transport = vsk->transport; total_written = 0; err = 0; @@ -1548,7 +1680,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, goto out; } - if (sk->sk_state != TCP_ESTABLISHED || + if (!transport || sk->sk_state != TCP_ESTABLISHED || !vsock_addr_bound(&vsk->local_addr)) { err = -ENOTCONN; goto out; @@ -1658,6 +1790,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, { struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; int err; size_t target; ssize_t copied; @@ -1668,11 +1801,12 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, sk = sock->sk; vsk = vsock_sk(sk); + transport = vsk->transport; err = 0; lock_sock(sk); - if (sk->sk_state != TCP_ESTABLISHED) { + if (!transport || sk->sk_state != TCP_ESTABLISHED) { /* Recvmsg is supposed to return 0 if a peer performs an * orderly shutdown. Differentiate between that case and when a * peer has not connected or a local shutdown occured with the @@ -1846,6 +1980,10 @@ static const struct proto_ops vsock_stream_ops = { static int vsock_create(struct net *net, struct socket *sock, int protocol, int kern) { + struct vsock_sock *vsk; + struct sock *sk; + int ret; + if (!sock) return -EINVAL; @@ -1865,7 +2003,23 @@ static int vsock_create(struct net *net, struct socket *sock, sock->state = SS_UNCONNECTED; - return __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern) ? 0 : -ENOMEM; + sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern); + if (!sk) + return -ENOMEM; + + vsk = vsock_sk(sk); + + if (sock->type == SOCK_DGRAM) { + ret = vsock_assign_transport(vsk, NULL); + if (ret < 0) { + sock_put(sk); + return ret; + } + } + + vsock_insert_unbound(vsk); + + return 0; } static const struct net_proto_family vsock_family_ops = { @@ -1878,11 +2032,20 @@ static long vsock_dev_do_ioctl(struct file *filp, unsigned int cmd, void __user *ptr) { u32 __user *p = ptr; + u32 cid = VMADDR_CID_ANY; int retval = 0; switch (cmd) { case IOCTL_VM_SOCKETS_GET_LOCAL_CID: - if (put_user(transport->get_local_cid(), p) != 0) + /* To be compatible with the VMCI behavior, we prioritize the + * guest CID instead of well-know host CID (VMADDR_CID_HOST). + */ + if (transport_g2h) + cid = transport_g2h->get_local_cid(); + else if (transport_h2g) + cid = transport_h2g->get_local_cid(); + + if (put_user(cid, p) != 0) retval = -EFAULT; break; @@ -1922,24 +2085,13 @@ static struct miscdevice vsock_device = { .fops = &vsock_device_ops, }; -int __vsock_core_init(const struct vsock_transport *t, struct module *owner) +static int __init vsock_init(void) { - int err = mutex_lock_interruptible(&vsock_register_mutex); + int err = 0; - if (err) - return err; - - if (transport) { - err = -EBUSY; - goto err_busy; - } - - /* Transport must be the owner of the protocol so that it can't - * unload while there are open sockets. - */ - vsock_proto.owner = owner; - transport = t; + vsock_init_tables(); + vsock_proto.owner = THIS_MODULE; vsock_device.minor = MISC_DYNAMIC_MINOR; err = misc_register(&vsock_device); if (err) { @@ -1960,7 +2112,6 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner) goto err_unregister_proto; } - mutex_unlock(&vsock_register_mutex); return 0; err_unregister_proto: @@ -1968,44 +2119,86 @@ err_unregister_proto: err_deregister_misc: misc_deregister(&vsock_device); err_reset_transport: - transport = NULL; -err_busy: - mutex_unlock(&vsock_register_mutex); return err; } -EXPORT_SYMBOL_GPL(__vsock_core_init); -void vsock_core_exit(void) +static void __exit vsock_exit(void) { - mutex_lock(&vsock_register_mutex); - misc_deregister(&vsock_device); sock_unregister(AF_VSOCK); proto_unregister(&vsock_proto); - - /* We do not want the assignment below re-ordered. */ - mb(); - transport = NULL; - - mutex_unlock(&vsock_register_mutex); } -EXPORT_SYMBOL_GPL(vsock_core_exit); -const struct vsock_transport *vsock_core_get_transport(void) +const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk) { - /* vsock_register_mutex not taken since only the transport uses this - * function and only while registered. - */ - return transport; + return vsk->transport; } EXPORT_SYMBOL_GPL(vsock_core_get_transport); -static void __exit vsock_exit(void) +int vsock_core_register(const struct vsock_transport *t, int features) +{ + const struct vsock_transport *t_h2g, *t_g2h, *t_dgram; + int err = mutex_lock_interruptible(&vsock_register_mutex); + + if (err) + return err; + + t_h2g = transport_h2g; + t_g2h = transport_g2h; + t_dgram = transport_dgram; + + if (features & VSOCK_TRANSPORT_F_H2G) { + if (t_h2g) { + err = -EBUSY; + goto err_busy; + } + t_h2g = t; + } + + if (features & VSOCK_TRANSPORT_F_G2H) { + if (t_g2h) { + err = -EBUSY; + goto err_busy; + } + t_g2h = t; + } + + if (features & VSOCK_TRANSPORT_F_DGRAM) { + if (t_dgram) { + err = -EBUSY; + goto err_busy; + } + t_dgram = t; + } + + transport_h2g = t_h2g; + transport_g2h = t_g2h; + transport_dgram = t_dgram; + +err_busy: + mutex_unlock(&vsock_register_mutex); + return err; +} +EXPORT_SYMBOL_GPL(vsock_core_register); + +void vsock_core_unregister(const struct vsock_transport *t) { - /* Do nothing. This function makes this module removable. */ + mutex_lock(&vsock_register_mutex); + + if (transport_h2g == t) + transport_h2g = NULL; + + if (transport_g2h == t) + transport_g2h = NULL; + + if (transport_dgram == t) + transport_dgram = NULL; + + mutex_unlock(&vsock_register_mutex); } +EXPORT_SYMBOL_GPL(vsock_core_unregister); -module_init(vsock_init_tables); +module_init(vsock_init); module_exit(vsock_exit); MODULE_AUTHOR("VMware, Inc."); diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index c443db7af8d4..3c7d07a99fc5 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -13,15 +13,16 @@ #include <linux/hyperv.h> #include <net/sock.h> #include <net/af_vsock.h> +#include <asm/hyperv-tlfs.h> /* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some - * stricter requirements on the hv_sock ring buffer size of six 4K pages. Newer - * hosts don't have this limitation; but, keep the defaults the same for compat. + * stricter requirements on the hv_sock ring buffer size of six 4K pages. + * hyperv-tlfs defines HV_HYP_PAGE_SIZE as 4K. Newer hosts don't have this + * limitation; but, keep the defaults the same for compat. */ -#define PAGE_SIZE_4K 4096 -#define RINGBUFFER_HVS_RCV_SIZE (PAGE_SIZE_4K * 6) -#define RINGBUFFER_HVS_SND_SIZE (PAGE_SIZE_4K * 6) -#define RINGBUFFER_HVS_MAX_SIZE (PAGE_SIZE_4K * 64) +#define RINGBUFFER_HVS_RCV_SIZE (HV_HYP_PAGE_SIZE * 6) +#define RINGBUFFER_HVS_SND_SIZE (HV_HYP_PAGE_SIZE * 6) +#define RINGBUFFER_HVS_MAX_SIZE (HV_HYP_PAGE_SIZE * 64) /* The MTU is 16KB per the host side's design */ #define HVS_MTU_SIZE (1024 * 16) @@ -54,7 +55,8 @@ struct hvs_recv_buf { * ringbuffer APIs that allow us to directly copy data from userspace buffer * to VMBus ringbuffer. */ -#define HVS_SEND_BUF_SIZE (PAGE_SIZE_4K - sizeof(struct vmpipe_proto_header)) +#define HVS_SEND_BUF_SIZE \ + (HV_HYP_PAGE_SIZE - sizeof(struct vmpipe_proto_header)) struct hvs_send_buf { /* The header before the payload data */ @@ -163,6 +165,8 @@ static const guid_t srv_id_template = GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3); +static bool hvs_check_transport(struct vsock_sock *vsk); + static bool is_valid_srv_id(const guid_t *id) { return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4); @@ -186,7 +190,8 @@ static void hvs_remote_addr_init(struct sockaddr_vm *remote, static u32 host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT; struct sock *sk; - vsock_addr_init(remote, VMADDR_CID_ANY, VMADDR_PORT_ANY); + /* Remote peer is always the host */ + vsock_addr_init(remote, VMADDR_CID_HOST, VMADDR_PORT_ANY); while (1) { /* Wrap around ? */ @@ -358,13 +363,24 @@ static void hvs_open_connection(struct vmbus_channel *chan) if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) goto out; - new = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, - sk->sk_type, 0); + new = vsock_create_connected(sk); if (!new) goto out; new->sk_state = TCP_SYN_SENT; vnew = vsock_sk(new); + + hvs_addr_init(&vnew->local_addr, if_type); + hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr); + + ret = vsock_assign_transport(vnew, vsock_sk(sk)); + /* Transport assigned (looking at remote_addr) must be the + * same where we received the request. + */ + if (ret || !hvs_check_transport(vnew)) { + sock_put(new); + goto out; + } hvs_new = vnew->trans; hvs_new->chan = chan; } else { @@ -393,10 +409,10 @@ static void hvs_open_connection(struct vmbus_channel *chan) } else { sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE); sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE); - sndbuf = ALIGN(sndbuf, PAGE_SIZE); + sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE); rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE); rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE); - rcvbuf = ALIGN(rcvbuf, PAGE_SIZE); + rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE); } ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb, @@ -426,10 +442,7 @@ static void hvs_open_connection(struct vmbus_channel *chan) if (conn_from_host) { new->sk_state = TCP_ESTABLISHED; - sk->sk_ack_backlog++; - - hvs_addr_init(&vnew->local_addr, if_type); - hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr); + sk_acceptq_added(sk); hvs_new->vm_srv_id = *if_type; hvs_new->host_srv_id = *if_instance; @@ -670,7 +683,7 @@ static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg, ssize_t ret = 0; ssize_t bytes_written = 0; - BUILD_BUG_ON(sizeof(*send_buf) != PAGE_SIZE_4K); + BUILD_BUG_ON(sizeof(*send_buf) != HV_HYP_PAGE_SIZE); send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL); if (!send_buf) @@ -843,37 +856,9 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written, return 0; } -static void hvs_set_buffer_size(struct vsock_sock *vsk, u64 val) -{ - /* Ignored. */ -} - -static void hvs_set_min_buffer_size(struct vsock_sock *vsk, u64 val) -{ - /* Ignored. */ -} - -static void hvs_set_max_buffer_size(struct vsock_sock *vsk, u64 val) -{ - /* Ignored. */ -} - -static u64 hvs_get_buffer_size(struct vsock_sock *vsk) -{ - return -ENOPROTOOPT; -} - -static u64 hvs_get_min_buffer_size(struct vsock_sock *vsk) -{ - return -ENOPROTOOPT; -} - -static u64 hvs_get_max_buffer_size(struct vsock_sock *vsk) -{ - return -ENOPROTOOPT; -} - static struct vsock_transport hvs_transport = { + .module = THIS_MODULE, + .get_local_cid = hvs_get_local_cid, .init = hvs_sock_init, @@ -906,14 +891,13 @@ static struct vsock_transport hvs_transport = { .notify_send_pre_enqueue = hvs_notify_send_pre_enqueue, .notify_send_post_enqueue = hvs_notify_send_post_enqueue, - .set_buffer_size = hvs_set_buffer_size, - .set_min_buffer_size = hvs_set_min_buffer_size, - .set_max_buffer_size = hvs_set_max_buffer_size, - .get_buffer_size = hvs_get_buffer_size, - .get_min_buffer_size = hvs_get_min_buffer_size, - .get_max_buffer_size = hvs_get_max_buffer_size, }; +static bool hvs_check_transport(struct vsock_sock *vsk) +{ + return vsk->transport == &hvs_transport; +} + static int hvs_probe(struct hv_device *hdev, const struct hv_vmbus_device_id *dev_id) { @@ -962,7 +946,7 @@ static int __init hvs_init(void) if (ret != 0) return ret; - ret = vsock_core_init(&hvs_transport); + ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H); if (ret) { vmbus_driver_unregister(&hvs_drv); return ret; @@ -973,7 +957,7 @@ static int __init hvs_init(void) static void __exit hvs_exit(void) { - vsock_core_exit(); + vsock_core_unregister(&hvs_transport); vmbus_driver_unregister(&hvs_drv); } diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 082a30936690..1458c5c8b64d 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -86,33 +86,6 @@ out_rcu: return ret; } -static void virtio_transport_loopback_work(struct work_struct *work) -{ - struct virtio_vsock *vsock = - container_of(work, struct virtio_vsock, loopback_work); - LIST_HEAD(pkts); - - spin_lock_bh(&vsock->loopback_list_lock); - list_splice_init(&vsock->loopback_list, &pkts); - spin_unlock_bh(&vsock->loopback_list_lock); - - mutex_lock(&vsock->rx_lock); - - if (!vsock->rx_run) - goto out; - - while (!list_empty(&pkts)) { - struct virtio_vsock_pkt *pkt; - - pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list); - list_del_init(&pkt->list); - - virtio_transport_recv_pkt(pkt); - } -out: - mutex_unlock(&vsock->rx_lock); -} - static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock, struct virtio_vsock_pkt *pkt) { @@ -370,59 +343,6 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock) return val < virtqueue_get_vring_size(vq); } -static void virtio_transport_rx_work(struct work_struct *work) -{ - struct virtio_vsock *vsock = - container_of(work, struct virtio_vsock, rx_work); - struct virtqueue *vq; - - vq = vsock->vqs[VSOCK_VQ_RX]; - - mutex_lock(&vsock->rx_lock); - - if (!vsock->rx_run) - goto out; - - do { - virtqueue_disable_cb(vq); - for (;;) { - struct virtio_vsock_pkt *pkt; - unsigned int len; - - if (!virtio_transport_more_replies(vsock)) { - /* Stop rx until the device processes already - * pending replies. Leave rx virtqueue - * callbacks disabled. - */ - goto out; - } - - pkt = virtqueue_get_buf(vq, &len); - if (!pkt) { - break; - } - - vsock->rx_buf_nr--; - - /* Drop short/long packets */ - if (unlikely(len < sizeof(pkt->hdr) || - len > sizeof(pkt->hdr) + pkt->len)) { - virtio_transport_free_pkt(pkt); - continue; - } - - pkt->len = len - sizeof(pkt->hdr); - virtio_transport_deliver_tap_pkt(pkt); - virtio_transport_recv_pkt(pkt); - } - } while (!virtqueue_enable_cb(vq)); - -out: - if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2) - virtio_vsock_rx_fill(vsock); - mutex_unlock(&vsock->rx_lock); -} - /* event_lock must be held */ static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock, struct virtio_vsock_event *event) @@ -542,6 +462,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq) static struct virtio_transport virtio_transport = { .transport = { + .module = THIS_MODULE, + .get_local_cid = virtio_transport_get_local_cid, .init = virtio_transport_do_socket_init, @@ -574,18 +496,92 @@ static struct virtio_transport virtio_transport = { .notify_send_pre_block = virtio_transport_notify_send_pre_block, .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, - - .set_buffer_size = virtio_transport_set_buffer_size, - .set_min_buffer_size = virtio_transport_set_min_buffer_size, - .set_max_buffer_size = virtio_transport_set_max_buffer_size, - .get_buffer_size = virtio_transport_get_buffer_size, - .get_min_buffer_size = virtio_transport_get_min_buffer_size, - .get_max_buffer_size = virtio_transport_get_max_buffer_size, + .notify_buffer_size = virtio_transport_notify_buffer_size, }, .send_pkt = virtio_transport_send_pkt, }; +static void virtio_transport_loopback_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, loopback_work); + LIST_HEAD(pkts); + + spin_lock_bh(&vsock->loopback_list_lock); + list_splice_init(&vsock->loopback_list, &pkts); + spin_unlock_bh(&vsock->loopback_list_lock); + + mutex_lock(&vsock->rx_lock); + + if (!vsock->rx_run) + goto out; + + while (!list_empty(&pkts)) { + struct virtio_vsock_pkt *pkt; + + pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list); + list_del_init(&pkt->list); + + virtio_transport_recv_pkt(&virtio_transport, pkt); + } +out: + mutex_unlock(&vsock->rx_lock); +} + +static void virtio_transport_rx_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, rx_work); + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_RX]; + + mutex_lock(&vsock->rx_lock); + + if (!vsock->rx_run) + goto out; + + do { + virtqueue_disable_cb(vq); + for (;;) { + struct virtio_vsock_pkt *pkt; + unsigned int len; + + if (!virtio_transport_more_replies(vsock)) { + /* Stop rx until the device processes already + * pending replies. Leave rx virtqueue + * callbacks disabled. + */ + goto out; + } + + pkt = virtqueue_get_buf(vq, &len); + if (!pkt) { + break; + } + + vsock->rx_buf_nr--; + + /* Drop short/long packets */ + if (unlikely(len < sizeof(pkt->hdr) || + len > sizeof(pkt->hdr) + pkt->len)) { + virtio_transport_free_pkt(pkt); + continue; + } + + pkt->len = len - sizeof(pkt->hdr); + virtio_transport_deliver_tap_pkt(pkt); + virtio_transport_recv_pkt(&virtio_transport, pkt); + } + } while (!virtqueue_enable_cb(vq)); + +out: + if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2) + virtio_vsock_rx_fill(vsock); + mutex_unlock(&vsock->rx_lock); +} + static int virtio_vsock_probe(struct virtio_device *vdev) { vq_callback_t *callbacks[] = { @@ -776,7 +772,8 @@ static int __init virtio_vsock_init(void) if (!virtio_vsock_workqueue) return -ENOMEM; - ret = vsock_core_init(&virtio_transport.transport); + ret = vsock_core_register(&virtio_transport.transport, + VSOCK_TRANSPORT_F_G2H); if (ret) goto out_wq; @@ -787,7 +784,7 @@ static int __init virtio_vsock_init(void) return 0; out_vci: - vsock_core_exit(); + vsock_core_unregister(&virtio_transport.transport); out_wq: destroy_workqueue(virtio_vsock_workqueue); return ret; @@ -796,7 +793,7 @@ out_wq: static void __exit virtio_vsock_exit(void) { unregister_virtio_driver(&virtio_vsock_driver); - vsock_core_exit(); + vsock_core_unregister(&virtio_transport.transport); destroy_workqueue(virtio_vsock_workqueue); } diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index fb2060dffb0a..e5ea29c6bca7 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -29,9 +29,10 @@ /* Threshold for detecting small packets to copy */ #define GOOD_COPY_LEN 128 -static const struct virtio_transport *virtio_transport_get_ops(void) +static const struct virtio_transport * +virtio_transport_get_ops(struct vsock_sock *vsk) { - const struct vsock_transport *t = vsock_core_get_transport(); + const struct vsock_transport *t = vsock_core_get_transport(vsk); return container_of(t, struct virtio_transport, transport); } @@ -168,7 +169,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, struct virtio_vsock_pkt *pkt; u32 pkt_len = info->pkt_len; - src_cid = vm_sockets_get_local_cid(); + src_cid = virtio_transport_get_ops(vsk)->transport.get_local_cid(); src_port = vsk->local_addr.svm_port; if (!info->remote_cid) { dst_cid = vsk->remote_addr.svm_cid; @@ -201,7 +202,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, virtio_transport_inc_tx_pkt(vvs, pkt); - return virtio_transport_get_ops()->send_pkt(pkt); + return virtio_transport_get_ops(vsk)->send_pkt(pkt); } static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, @@ -268,6 +269,55 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk, } static ssize_t +virtio_transport_stream_do_peek(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct virtio_vsock_pkt *pkt; + size_t bytes, total = 0, off; + int err = -EFAULT; + + spin_lock_bh(&vvs->rx_lock); + + list_for_each_entry(pkt, &vvs->rx_queue, list) { + off = pkt->off; + + if (total == len) + break; + + while (total < len && off < pkt->len) { + bytes = len - total; + if (bytes > pkt->len - off) + bytes = pkt->len - off; + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. + */ + spin_unlock_bh(&vvs->rx_lock); + + err = memcpy_to_msg(msg, pkt->buf + off, bytes); + if (err) + goto out; + + spin_lock_bh(&vvs->rx_lock); + + total += bytes; + off += bytes; + } + } + + spin_unlock_bh(&vvs->rx_lock); + + return total; + +out: + if (total) + err = total; + return err; +} + +static ssize_t virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, struct msghdr *msg, size_t len) @@ -339,9 +389,9 @@ virtio_transport_stream_dequeue(struct vsock_sock *vsk, size_t len, int flags) { if (flags & MSG_PEEK) - return -EOPNOTSUPP; - - return virtio_transport_stream_do_dequeue(vsk, msg, len); + return virtio_transport_stream_do_peek(vsk, msg, len); + else + return virtio_transport_stream_do_dequeue(vsk, msg, len); } EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); @@ -403,20 +453,16 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk, vsk->trans = vvs; vvs->vsk = vsk; - if (psk) { + if (psk && psk->trans) { struct virtio_vsock_sock *ptrans = psk->trans; - vvs->buf_size = ptrans->buf_size; - vvs->buf_size_min = ptrans->buf_size_min; - vvs->buf_size_max = ptrans->buf_size_max; vvs->peer_buf_alloc = ptrans->peer_buf_alloc; - } else { - vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; - vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; - vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; } - vvs->buf_alloc = vvs->buf_size; + if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE) + vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE; + + vvs->buf_alloc = vsk->buffer_size; spin_lock_init(&vvs->rx_lock); spin_lock_init(&vvs->tx_lock); @@ -426,71 +472,20 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk, } EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); -u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) -{ - struct virtio_vsock_sock *vvs = vsk->trans; - - return vvs->buf_size; -} -EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); - -u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) -{ - struct virtio_vsock_sock *vvs = vsk->trans; - - return vvs->buf_size_min; -} -EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); - -u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) +/* sk_lock held by the caller */ +void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val) { struct virtio_vsock_sock *vvs = vsk->trans; - return vvs->buf_size_max; -} -EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); + if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE) + *val = VIRTIO_VSOCK_MAX_BUF_SIZE; -void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) -{ - struct virtio_vsock_sock *vvs = vsk->trans; - - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) - val = VIRTIO_VSOCK_MAX_BUF_SIZE; - if (val < vvs->buf_size_min) - vvs->buf_size_min = val; - if (val > vvs->buf_size_max) - vvs->buf_size_max = val; - vvs->buf_size = val; - vvs->buf_alloc = val; + vvs->buf_alloc = *val; virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, NULL); } -EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); - -void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) -{ - struct virtio_vsock_sock *vvs = vsk->trans; - - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) - val = VIRTIO_VSOCK_MAX_BUF_SIZE; - if (val > vvs->buf_size) - vvs->buf_size = val; - vvs->buf_size_min = val; -} -EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); - -void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) -{ - struct virtio_vsock_sock *vvs = vsk->trans; - - if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) - val = VIRTIO_VSOCK_MAX_BUF_SIZE; - if (val < vvs->buf_size) - vvs->buf_size = val; - vvs->buf_size_max = val; -} -EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); +EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size); int virtio_transport_notify_poll_in(struct vsock_sock *vsk, @@ -582,9 +577,7 @@ EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) { - struct virtio_vsock_sock *vvs = vsk->trans; - - return vvs->buf_size; + return vsk->buffer_size; } EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); @@ -696,9 +689,9 @@ static int virtio_transport_reset(struct vsock_sock *vsk, /* Normally packets are associated with a socket. There may be no socket if an * attempt was made to connect to a socket that does not exist. */ -static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) +static int virtio_transport_reset_no_sock(const struct virtio_transport *t, + struct virtio_vsock_pkt *pkt) { - const struct virtio_transport *t; struct virtio_vsock_pkt *reply; struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_RST, @@ -718,7 +711,6 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) if (!reply) return -ENOMEM; - t = virtio_transport_get_ops(); if (!t) { virtio_transport_free_pkt(reply); return -ENOTCONN; @@ -994,13 +986,39 @@ virtio_transport_send_response(struct vsock_sock *vsk, return virtio_transport_send_pkt_info(vsk, &info); } +static bool virtio_transport_space_update(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct virtio_vsock_sock *vvs = vsk->trans; + bool space_available; + + /* Listener sockets are not associated with any transport, so we are + * not able to take the state to see if there is space available in the + * remote peer, but since they are only used to receive requests, we + * can assume that there is always space available in the other peer. + */ + if (!vvs) + return true; + + /* buf_alloc and fwd_cnt is always included in the hdr */ + spin_lock_bh(&vvs->tx_lock); + vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); + vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); + space_available = virtio_transport_has_space(vsk); + spin_unlock_bh(&vvs->tx_lock); + return space_available; +} + /* Handle server socket */ static int -virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) +virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt, + struct virtio_transport *t) { struct vsock_sock *vsk = vsock_sk(sk); struct vsock_sock *vchild; struct sock *child; + int ret; if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { virtio_transport_reset(vsk, pkt); @@ -1012,14 +1030,13 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) return -ENOMEM; } - child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, - sk->sk_type, 0); + child = vsock_create_connected(sk); if (!child) { virtio_transport_reset(vsk, pkt); return -ENOMEM; } - sk->sk_ack_backlog++; + sk_acceptq_added(sk); lock_sock_nested(child, SINGLE_DEPTH_NESTING); @@ -1031,6 +1048,20 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); + ret = vsock_assign_transport(vchild, vsk); + /* Transport assigned (looking at remote_addr) must be the same + * where we received the request. + */ + if (ret || vchild->transport != &t->transport) { + release_sock(child); + virtio_transport_reset(vsk, pkt); + sock_put(child); + return ret; + } + + if (virtio_transport_space_update(child, pkt)) + child->sk_write_space(child); + vsock_insert_connected(vchild); vsock_enqueue_accept(sk, child); virtio_transport_send_response(vchild, pkt); @@ -1041,26 +1072,11 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) return 0; } -static bool virtio_transport_space_update(struct sock *sk, - struct virtio_vsock_pkt *pkt) -{ - struct vsock_sock *vsk = vsock_sk(sk); - struct virtio_vsock_sock *vvs = vsk->trans; - bool space_available; - - /* buf_alloc and fwd_cnt is always included in the hdr */ - spin_lock_bh(&vvs->tx_lock); - vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); - vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); - space_available = virtio_transport_has_space(vsk); - spin_unlock_bh(&vvs->tx_lock); - return space_available; -} - /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex * lock. */ -void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) +void virtio_transport_recv_pkt(struct virtio_transport *t, + struct virtio_vsock_pkt *pkt) { struct sockaddr_vm src, dst; struct vsock_sock *vsk; @@ -1082,7 +1098,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) le32_to_cpu(pkt->hdr.fwd_cnt)); if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { - (void)virtio_transport_reset_no_sock(pkt); + (void)virtio_transport_reset_no_sock(t, pkt); goto free_pkt; } @@ -1093,7 +1109,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) if (!sk) { sk = vsock_find_bound_socket(&dst); if (!sk) { - (void)virtio_transport_reset_no_sock(pkt); + (void)virtio_transport_reset_no_sock(t, pkt); goto free_pkt; } } @@ -1112,7 +1128,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) switch (sk->sk_state) { case TCP_LISTEN: - virtio_transport_recv_listen(sk, pkt); + virtio_transport_recv_listen(sk, pkt, t); virtio_transport_free_pkt(pkt); break; case TCP_SYN_SENT: @@ -1130,6 +1146,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) virtio_transport_free_pkt(pkt); break; } + release_sock(sk); /* Release refcnt obtained when we fetched this socket out of the diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index 8c9c4ed90fa7..644d32e43d23 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -57,6 +57,7 @@ static bool vmci_transport_old_proto_override(bool *old_pkt_proto); static u16 vmci_transport_new_proto_supported_versions(void); static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto, bool old_pkt_proto); +static bool vmci_check_transport(struct vsock_sock *vsk); struct vmci_transport_recv_pkt_info { struct work_struct work; @@ -74,15 +75,6 @@ static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; static int PROTOCOL_OVERRIDE = -1; -#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN 128 -#define VMCI_TRANSPORT_DEFAULT_QP_SIZE 262144 -#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX 262144 - -/* The default peer timeout indicates how long we will wait for a peer response - * to a control message. - */ -#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ) - /* Helper function to convert from a VMCI error code to a VSock error code. */ static s32 vmci_transport_error_to_vsock_error(s32 vmci_error) @@ -1013,8 +1005,7 @@ static int vmci_transport_recv_listen(struct sock *sk, return -ECONNREFUSED; } - pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, - sk->sk_type, 0); + pending = vsock_create_connected(sk); if (!pending) { vmci_transport_send_reset(sk, pkt); return -ENOMEM; @@ -1027,14 +1018,24 @@ static int vmci_transport_recv_listen(struct sock *sk, vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context, pkt->src_port); + err = vsock_assign_transport(vpending, vsock_sk(sk)); + /* Transport assigned (looking at remote_addr) must be the same + * where we received the request. + */ + if (err || !vmci_check_transport(vpending)) { + vmci_transport_send_reset(sk, pkt); + sock_put(pending); + return err; + } + /* If the proposed size fits within our min/max, accept it. Otherwise * propose our own size. */ - if (pkt->u.size >= vmci_trans(vpending)->queue_pair_min_size && - pkt->u.size <= vmci_trans(vpending)->queue_pair_max_size) { + if (pkt->u.size >= vpending->buffer_min_size && + pkt->u.size <= vpending->buffer_max_size) { qp_size = pkt->u.size; } else { - qp_size = vmci_trans(vpending)->queue_pair_size; + qp_size = vpending->buffer_size; } /* Figure out if we are using old or new requests based on the @@ -1098,12 +1099,12 @@ static int vmci_transport_recv_listen(struct sock *sk, } vsock_add_pending(sk, pending); - sk->sk_ack_backlog++; + sk_acceptq_added(sk); pending->sk_state = TCP_SYN_SENT; vmci_trans(vpending)->produce_size = vmci_trans(vpending)->consume_size = qp_size; - vmci_trans(vpending)->queue_pair_size = qp_size; + vpending->buffer_size = qp_size; vmci_trans(vpending)->notify_ops->process_request(pending); @@ -1397,8 +1398,8 @@ static int vmci_transport_recv_connecting_client_negotiate( vsk->ignore_connecting_rst = false; /* Verify that we're OK with the proposed queue pair size */ - if (pkt->u.size < vmci_trans(vsk)->queue_pair_min_size || - pkt->u.size > vmci_trans(vsk)->queue_pair_max_size) { + if (pkt->u.size < vsk->buffer_min_size || + pkt->u.size > vsk->buffer_max_size) { err = -EINVAL; goto destroy; } @@ -1503,8 +1504,7 @@ vmci_transport_recv_connecting_client_invalid(struct sock *sk, vsk->sent_request = false; vsk->ignore_connecting_rst = true; - err = vmci_transport_send_conn_request( - sk, vmci_trans(vsk)->queue_pair_size); + err = vmci_transport_send_conn_request(sk, vsk->buffer_size); if (err < 0) err = vmci_transport_error_to_vsock_error(err); else @@ -1588,21 +1588,6 @@ static int vmci_transport_socket_init(struct vsock_sock *vsk, INIT_LIST_HEAD(&vmci_trans(vsk)->elem); vmci_trans(vsk)->sk = &vsk->sk; spin_lock_init(&vmci_trans(vsk)->lock); - if (psk) { - vmci_trans(vsk)->queue_pair_size = - vmci_trans(psk)->queue_pair_size; - vmci_trans(vsk)->queue_pair_min_size = - vmci_trans(psk)->queue_pair_min_size; - vmci_trans(vsk)->queue_pair_max_size = - vmci_trans(psk)->queue_pair_max_size; - } else { - vmci_trans(vsk)->queue_pair_size = - VMCI_TRANSPORT_DEFAULT_QP_SIZE; - vmci_trans(vsk)->queue_pair_min_size = - VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN; - vmci_trans(vsk)->queue_pair_max_size = - VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX; - } return 0; } @@ -1818,8 +1803,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk) if (vmci_transport_old_proto_override(&old_pkt_proto) && old_pkt_proto) { - err = vmci_transport_send_conn_request( - sk, vmci_trans(vsk)->queue_pair_size); + err = vmci_transport_send_conn_request(sk, vsk->buffer_size); if (err < 0) { sk->sk_state = TCP_CLOSE; return err; @@ -1827,8 +1811,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk) } else { int supported_proto_versions = vmci_transport_new_proto_supported_versions(); - err = vmci_transport_send_conn_request2( - sk, vmci_trans(vsk)->queue_pair_size, + err = vmci_transport_send_conn_request2(sk, vsk->buffer_size, supported_proto_versions); if (err < 0) { sk->sk_state = TCP_CLOSE; @@ -1881,46 +1864,6 @@ static bool vmci_transport_stream_is_active(struct vsock_sock *vsk) return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle); } -static u64 vmci_transport_get_buffer_size(struct vsock_sock *vsk) -{ - return vmci_trans(vsk)->queue_pair_size; -} - -static u64 vmci_transport_get_min_buffer_size(struct vsock_sock *vsk) -{ - return vmci_trans(vsk)->queue_pair_min_size; -} - -static u64 vmci_transport_get_max_buffer_size(struct vsock_sock *vsk) -{ - return vmci_trans(vsk)->queue_pair_max_size; -} - -static void vmci_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) -{ - if (val < vmci_trans(vsk)->queue_pair_min_size) - vmci_trans(vsk)->queue_pair_min_size = val; - if (val > vmci_trans(vsk)->queue_pair_max_size) - vmci_trans(vsk)->queue_pair_max_size = val; - vmci_trans(vsk)->queue_pair_size = val; -} - -static void vmci_transport_set_min_buffer_size(struct vsock_sock *vsk, - u64 val) -{ - if (val > vmci_trans(vsk)->queue_pair_size) - vmci_trans(vsk)->queue_pair_size = val; - vmci_trans(vsk)->queue_pair_min_size = val; -} - -static void vmci_transport_set_max_buffer_size(struct vsock_sock *vsk, - u64 val) -{ - if (val < vmci_trans(vsk)->queue_pair_size) - vmci_trans(vsk)->queue_pair_size = val; - vmci_trans(vsk)->queue_pair_max_size = val; -} - static int vmci_transport_notify_poll_in( struct vsock_sock *vsk, size_t target, @@ -2076,7 +2019,8 @@ static u32 vmci_transport_get_local_cid(void) return vmci_get_context_id(); } -static const struct vsock_transport vmci_transport = { +static struct vsock_transport vmci_transport = { + .module = THIS_MODULE, .init = vmci_transport_socket_init, .destruct = vmci_transport_destruct, .release = vmci_transport_release, @@ -2103,15 +2047,26 @@ static const struct vsock_transport vmci_transport = { .notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue, .notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue, .shutdown = vmci_transport_shutdown, - .set_buffer_size = vmci_transport_set_buffer_size, - .set_min_buffer_size = vmci_transport_set_min_buffer_size, - .set_max_buffer_size = vmci_transport_set_max_buffer_size, - .get_buffer_size = vmci_transport_get_buffer_size, - .get_min_buffer_size = vmci_transport_get_min_buffer_size, - .get_max_buffer_size = vmci_transport_get_max_buffer_size, .get_local_cid = vmci_transport_get_local_cid, }; +static bool vmci_check_transport(struct vsock_sock *vsk) +{ + return vsk->transport == &vmci_transport; +} + +void vmci_vsock_transport_cb(bool is_host) +{ + int features; + + if (is_host) + features = VSOCK_TRANSPORT_F_H2G; + else + features = VSOCK_TRANSPORT_F_G2H; + + vsock_core_register(&vmci_transport, features); +} + static int __init vmci_transport_init(void) { int err; @@ -2128,7 +2083,6 @@ static int __init vmci_transport_init(void) pr_err("Unable to create datagram handle. (%d)\n", err); return vmci_transport_error_to_vsock_error(err); } - err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED, vmci_transport_qp_resumed_cb, NULL, &vmci_transport_qp_resumed_sub_id); @@ -2139,12 +2093,21 @@ static int __init vmci_transport_init(void) goto err_destroy_stream_handle; } - err = vsock_core_init(&vmci_transport); + /* Register only with dgram feature, other features (H2G, G2H) will be + * registered when the first host or guest becomes active. + */ + err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM); if (err < 0) goto err_unsubscribe; + err = vmci_register_vsock_callback(vmci_vsock_transport_cb); + if (err < 0) + goto err_unregister; + return 0; +err_unregister: + vsock_core_unregister(&vmci_transport); err_unsubscribe: vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id); err_destroy_stream_handle: @@ -2170,7 +2133,8 @@ static void __exit vmci_transport_exit(void) vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; } - vsock_core_exit(); + vmci_register_vsock_callback(NULL); + vsock_core_unregister(&vmci_transport); } module_exit(vmci_transport_exit); diff --git a/net/vmw_vsock/vmci_transport.h b/net/vmw_vsock/vmci_transport.h index 1ca1e8640b31..b7b072194282 100644 --- a/net/vmw_vsock/vmci_transport.h +++ b/net/vmw_vsock/vmci_transport.h @@ -108,9 +108,6 @@ struct vmci_transport { struct vmci_qp *qpair; u64 produce_size; u64 consume_size; - u64 queue_pair_size; - u64 queue_pair_min_size; - u64 queue_pair_max_size; u32 detach_sub_id; union vmci_transport_notify notify; const struct vmci_transport_notify_ops *notify_ops; diff --git a/net/vmw_vsock/vmci_transport_notify.h b/net/vmw_vsock/vmci_transport_notify.h index 7843f08d4290..a1aa5a998c0e 100644 --- a/net/vmw_vsock/vmci_transport_notify.h +++ b/net/vmw_vsock/vmci_transport_notify.h @@ -11,7 +11,6 @@ #include <linux/types.h> #include <linux/vmw_vmci_defs.h> #include <linux/vmw_vmci_api.h> -#include <linux/vm_sockets.h> #include "vmci_transport.h" |