summaryrefslogtreecommitdiff
path: root/net/vmw_vsock/vmci_transport.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/vmci_transport.c')
-rw-r--r--net/vmw_vsock/vmci_transport.c142
1 files changed, 53 insertions, 89 deletions
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index 8c9c4ed90fa7..644d32e43d23 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -57,6 +57,7 @@ static bool vmci_transport_old_proto_override(bool *old_pkt_proto);
static u16 vmci_transport_new_proto_supported_versions(void);
static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto,
bool old_pkt_proto);
+static bool vmci_check_transport(struct vsock_sock *vsk);
struct vmci_transport_recv_pkt_info {
struct work_struct work;
@@ -74,15 +75,6 @@ static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
static int PROTOCOL_OVERRIDE = -1;
-#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN 128
-#define VMCI_TRANSPORT_DEFAULT_QP_SIZE 262144
-#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX 262144
-
-/* The default peer timeout indicates how long we will wait for a peer response
- * to a control message.
- */
-#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
-
/* Helper function to convert from a VMCI error code to a VSock error code. */
static s32 vmci_transport_error_to_vsock_error(s32 vmci_error)
@@ -1013,8 +1005,7 @@ static int vmci_transport_recv_listen(struct sock *sk,
return -ECONNREFUSED;
}
- pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
- sk->sk_type, 0);
+ pending = vsock_create_connected(sk);
if (!pending) {
vmci_transport_send_reset(sk, pkt);
return -ENOMEM;
@@ -1027,14 +1018,24 @@ static int vmci_transport_recv_listen(struct sock *sk,
vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
pkt->src_port);
+ err = vsock_assign_transport(vpending, vsock_sk(sk));
+ /* Transport assigned (looking at remote_addr) must be the same
+ * where we received the request.
+ */
+ if (err || !vmci_check_transport(vpending)) {
+ vmci_transport_send_reset(sk, pkt);
+ sock_put(pending);
+ return err;
+ }
+
/* If the proposed size fits within our min/max, accept it. Otherwise
* propose our own size.
*/
- if (pkt->u.size >= vmci_trans(vpending)->queue_pair_min_size &&
- pkt->u.size <= vmci_trans(vpending)->queue_pair_max_size) {
+ if (pkt->u.size >= vpending->buffer_min_size &&
+ pkt->u.size <= vpending->buffer_max_size) {
qp_size = pkt->u.size;
} else {
- qp_size = vmci_trans(vpending)->queue_pair_size;
+ qp_size = vpending->buffer_size;
}
/* Figure out if we are using old or new requests based on the
@@ -1098,12 +1099,12 @@ static int vmci_transport_recv_listen(struct sock *sk,
}
vsock_add_pending(sk, pending);
- sk->sk_ack_backlog++;
+ sk_acceptq_added(sk);
pending->sk_state = TCP_SYN_SENT;
vmci_trans(vpending)->produce_size =
vmci_trans(vpending)->consume_size = qp_size;
- vmci_trans(vpending)->queue_pair_size = qp_size;
+ vpending->buffer_size = qp_size;
vmci_trans(vpending)->notify_ops->process_request(pending);
@@ -1397,8 +1398,8 @@ static int vmci_transport_recv_connecting_client_negotiate(
vsk->ignore_connecting_rst = false;
/* Verify that we're OK with the proposed queue pair size */
- if (pkt->u.size < vmci_trans(vsk)->queue_pair_min_size ||
- pkt->u.size > vmci_trans(vsk)->queue_pair_max_size) {
+ if (pkt->u.size < vsk->buffer_min_size ||
+ pkt->u.size > vsk->buffer_max_size) {
err = -EINVAL;
goto destroy;
}
@@ -1503,8 +1504,7 @@ vmci_transport_recv_connecting_client_invalid(struct sock *sk,
vsk->sent_request = false;
vsk->ignore_connecting_rst = true;
- err = vmci_transport_send_conn_request(
- sk, vmci_trans(vsk)->queue_pair_size);
+ err = vmci_transport_send_conn_request(sk, vsk->buffer_size);
if (err < 0)
err = vmci_transport_error_to_vsock_error(err);
else
@@ -1588,21 +1588,6 @@ static int vmci_transport_socket_init(struct vsock_sock *vsk,
INIT_LIST_HEAD(&vmci_trans(vsk)->elem);
vmci_trans(vsk)->sk = &vsk->sk;
spin_lock_init(&vmci_trans(vsk)->lock);
- if (psk) {
- vmci_trans(vsk)->queue_pair_size =
- vmci_trans(psk)->queue_pair_size;
- vmci_trans(vsk)->queue_pair_min_size =
- vmci_trans(psk)->queue_pair_min_size;
- vmci_trans(vsk)->queue_pair_max_size =
- vmci_trans(psk)->queue_pair_max_size;
- } else {
- vmci_trans(vsk)->queue_pair_size =
- VMCI_TRANSPORT_DEFAULT_QP_SIZE;
- vmci_trans(vsk)->queue_pair_min_size =
- VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN;
- vmci_trans(vsk)->queue_pair_max_size =
- VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX;
- }
return 0;
}
@@ -1818,8 +1803,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk)
if (vmci_transport_old_proto_override(&old_pkt_proto) &&
old_pkt_proto) {
- err = vmci_transport_send_conn_request(
- sk, vmci_trans(vsk)->queue_pair_size);
+ err = vmci_transport_send_conn_request(sk, vsk->buffer_size);
if (err < 0) {
sk->sk_state = TCP_CLOSE;
return err;
@@ -1827,8 +1811,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk)
} else {
int supported_proto_versions =
vmci_transport_new_proto_supported_versions();
- err = vmci_transport_send_conn_request2(
- sk, vmci_trans(vsk)->queue_pair_size,
+ err = vmci_transport_send_conn_request2(sk, vsk->buffer_size,
supported_proto_versions);
if (err < 0) {
sk->sk_state = TCP_CLOSE;
@@ -1881,46 +1864,6 @@ static bool vmci_transport_stream_is_active(struct vsock_sock *vsk)
return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle);
}
-static u64 vmci_transport_get_buffer_size(struct vsock_sock *vsk)
-{
- return vmci_trans(vsk)->queue_pair_size;
-}
-
-static u64 vmci_transport_get_min_buffer_size(struct vsock_sock *vsk)
-{
- return vmci_trans(vsk)->queue_pair_min_size;
-}
-
-static u64 vmci_transport_get_max_buffer_size(struct vsock_sock *vsk)
-{
- return vmci_trans(vsk)->queue_pair_max_size;
-}
-
-static void vmci_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
-{
- if (val < vmci_trans(vsk)->queue_pair_min_size)
- vmci_trans(vsk)->queue_pair_min_size = val;
- if (val > vmci_trans(vsk)->queue_pair_max_size)
- vmci_trans(vsk)->queue_pair_max_size = val;
- vmci_trans(vsk)->queue_pair_size = val;
-}
-
-static void vmci_transport_set_min_buffer_size(struct vsock_sock *vsk,
- u64 val)
-{
- if (val > vmci_trans(vsk)->queue_pair_size)
- vmci_trans(vsk)->queue_pair_size = val;
- vmci_trans(vsk)->queue_pair_min_size = val;
-}
-
-static void vmci_transport_set_max_buffer_size(struct vsock_sock *vsk,
- u64 val)
-{
- if (val < vmci_trans(vsk)->queue_pair_size)
- vmci_trans(vsk)->queue_pair_size = val;
- vmci_trans(vsk)->queue_pair_max_size = val;
-}
-
static int vmci_transport_notify_poll_in(
struct vsock_sock *vsk,
size_t target,
@@ -2076,7 +2019,8 @@ static u32 vmci_transport_get_local_cid(void)
return vmci_get_context_id();
}
-static const struct vsock_transport vmci_transport = {
+static struct vsock_transport vmci_transport = {
+ .module = THIS_MODULE,
.init = vmci_transport_socket_init,
.destruct = vmci_transport_destruct,
.release = vmci_transport_release,
@@ -2103,15 +2047,26 @@ static const struct vsock_transport vmci_transport = {
.notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue,
.shutdown = vmci_transport_shutdown,
- .set_buffer_size = vmci_transport_set_buffer_size,
- .set_min_buffer_size = vmci_transport_set_min_buffer_size,
- .set_max_buffer_size = vmci_transport_set_max_buffer_size,
- .get_buffer_size = vmci_transport_get_buffer_size,
- .get_min_buffer_size = vmci_transport_get_min_buffer_size,
- .get_max_buffer_size = vmci_transport_get_max_buffer_size,
.get_local_cid = vmci_transport_get_local_cid,
};
+static bool vmci_check_transport(struct vsock_sock *vsk)
+{
+ return vsk->transport == &vmci_transport;
+}
+
+void vmci_vsock_transport_cb(bool is_host)
+{
+ int features;
+
+ if (is_host)
+ features = VSOCK_TRANSPORT_F_H2G;
+ else
+ features = VSOCK_TRANSPORT_F_G2H;
+
+ vsock_core_register(&vmci_transport, features);
+}
+
static int __init vmci_transport_init(void)
{
int err;
@@ -2128,7 +2083,6 @@ static int __init vmci_transport_init(void)
pr_err("Unable to create datagram handle. (%d)\n", err);
return vmci_transport_error_to_vsock_error(err);
}
-
err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED,
vmci_transport_qp_resumed_cb,
NULL, &vmci_transport_qp_resumed_sub_id);
@@ -2139,12 +2093,21 @@ static int __init vmci_transport_init(void)
goto err_destroy_stream_handle;
}
- err = vsock_core_init(&vmci_transport);
+ /* Register only with dgram feature, other features (H2G, G2H) will be
+ * registered when the first host or guest becomes active.
+ */
+ err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM);
if (err < 0)
goto err_unsubscribe;
+ err = vmci_register_vsock_callback(vmci_vsock_transport_cb);
+ if (err < 0)
+ goto err_unregister;
+
return 0;
+err_unregister:
+ vsock_core_unregister(&vmci_transport);
err_unsubscribe:
vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id);
err_destroy_stream_handle:
@@ -2170,7 +2133,8 @@ static void __exit vmci_transport_exit(void)
vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
}
- vsock_core_exit();
+ vmci_register_vsock_callback(NULL);
+ vsock_core_unregister(&vmci_transport);
}
module_exit(vmci_transport_exit);