diff options
author | Jakub Kicinski <kuba@kernel.org> | 2021-12-07 11:36:36 -0800 |
---|---|---|
committer | Jakub Kicinski <kuba@kernel.org> | 2021-12-07 11:36:37 -0800 |
commit | 59d58d93af94d7b546f12438b4a3d781b7190095 (patch) | |
tree | 23cdba78ff4fc7a076e445b95a55db44cd213efc | |
parent | c0e5e11af12b76d0dbed700c1088c6827cdcf56c (diff) | |
parent | 4f6e14bd19d6de7831f31cfb3210f2ea93eeb038 (diff) |
Merge branch 'mptcp-new-features-for-mptcp-sockets-and-netlink-pm'
Mat Martineau says:
====================
mptcp: New features for MPTCP sockets and netlink PM
This collection of patches adds MPTCP socket support for a few socket
options, ioctls, and one ancillary data type (specifics for each are
listed below). There's also a patch modifying the netlink MPTCP path
manager API to allow setting the backup flag on a configured interface
using the endpoint ID instead of the full IP address.
Patches 1 & 2: TCP_INQ cmsg and selftests.
Patches 2 & 3: SIOCINQ, OUTQ, and OUTQNSD ioctls and selftests.
Patch 5: Change backup flag using endpoint ID.
Patches 6 & 7: IP_TOS socket option and selftests.
Patches 8-10: TCP_CORK and TCP_NODELAY socket options. Includes a tcp
change to expose __tcp_sock_set_cork() and __tcp_sock_set_nodelay() for
use by MPTCP.
====================
Link: https://lore.kernel.org/r/20211203223541.69364-1-mathew.j.martineau@linux.intel.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
-rw-r--r-- | include/linux/tcp.h | 2 | ||||
-rw-r--r-- | net/ipv4/tcp.c | 4 | ||||
-rw-r--r-- | net/mptcp/pm_netlink.c | 14 | ||||
-rw-r--r-- | net/mptcp/protocol.c | 91 | ||||
-rw-r--r-- | net/mptcp/protocol.h | 4 | ||||
-rw-r--r-- | net/mptcp/sockopt.c | 132 | ||||
-rw-r--r-- | tools/testing/selftests/net/mptcp/.gitignore | 1 | ||||
-rw-r--r-- | tools/testing/selftests/net/mptcp/Makefile | 2 | ||||
-rw-r--r-- | tools/testing/selftests/net/mptcp/mptcp_connect.c | 60 | ||||
-rw-r--r-- | tools/testing/selftests/net/mptcp/mptcp_inq.c | 603 | ||||
-rw-r--r-- | tools/testing/selftests/net/mptcp/mptcp_sockopt.c | 63 | ||||
-rwxr-xr-x | tools/testing/selftests/net/mptcp/mptcp_sockopt.sh | 44 |
12 files changed, 1007 insertions, 13 deletions
diff --git a/include/linux/tcp.h b/include/linux/tcp.h index 48d8a363319e..78b91bb92f0d 100644 --- a/include/linux/tcp.h +++ b/include/linux/tcp.h @@ -512,11 +512,13 @@ static inline u16 tcp_mss_clamp(const struct tcp_sock *tp, u16 mss) int tcp_skb_shift(struct sk_buff *to, struct sk_buff *from, int pcount, int shiftlen); +void __tcp_sock_set_cork(struct sock *sk, bool on); void tcp_sock_set_cork(struct sock *sk, bool on); int tcp_sock_set_keepcnt(struct sock *sk, int val); int tcp_sock_set_keepidle_locked(struct sock *sk, int val); int tcp_sock_set_keepidle(struct sock *sk, int val); int tcp_sock_set_keepintvl(struct sock *sk, int val); +void __tcp_sock_set_nodelay(struct sock *sk, bool on); void tcp_sock_set_nodelay(struct sock *sk); void tcp_sock_set_quickack(struct sock *sk, int val); int tcp_sock_set_syncnt(struct sock *sk, int val); diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 6ab82e1a1d41..20054618c87e 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -3207,7 +3207,7 @@ static void tcp_enable_tx_delay(void) * TCP_CORK can be set together with TCP_NODELAY and it is stronger than * TCP_NODELAY. */ -static void __tcp_sock_set_cork(struct sock *sk, bool on) +void __tcp_sock_set_cork(struct sock *sk, bool on) { struct tcp_sock *tp = tcp_sk(sk); @@ -3235,7 +3235,7 @@ EXPORT_SYMBOL(tcp_sock_set_cork); * However, when TCP_NODELAY is set we make an explicit push, which overrides * even TCP_CORK for currently queued segments. */ -static void __tcp_sock_set_nodelay(struct sock *sk, bool on) +void __tcp_sock_set_nodelay(struct sock *sk, bool on) { if (on) { tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF|TCP_NAGLE_PUSH; diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 7b96be1e9f14..4ff8d55cbe82 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -1702,22 +1702,28 @@ next: static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info) { + struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }, *entry; struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; struct pm_nl_pernet *pernet = genl_info_pm_nl(info); - struct mptcp_pm_addr_entry addr, *entry; struct net *net = sock_net(skb->sk); - u8 bkup = 0; + u8 bkup = 0, lookup_by_id = 0; int ret; - ret = mptcp_pm_parse_addr(attr, info, true, &addr); + ret = mptcp_pm_parse_addr(attr, info, false, &addr); if (ret < 0) return ret; if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP) bkup = 1; + if (addr.addr.family == AF_UNSPEC) { + lookup_by_id = 1; + if (!addr.addr.id) + return -EOPNOTSUPP; + } list_for_each_entry(entry, &pernet->local_addr_list, list) { - if (addresses_equal(&entry->addr, &addr.addr, true)) { + if ((!lookup_by_id && addresses_equal(&entry->addr, &addr.addr, true)) || + (lookup_by_id && entry->addr.id == addr.addr.id)) { mptcp_nl_addr_backup(net, &entry->addr, bkup); if (bkup) diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index b100048e43fe..f124cca125d2 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -22,6 +22,7 @@ #endif #include <net/mptcp.h> #include <net/xfrm.h> +#include <asm/ioctls.h> #include "protocol.h" #include "mib.h" @@ -46,6 +47,7 @@ struct mptcp_skb_cb { enum { MPTCP_CMSG_TS = BIT(0), + MPTCP_CMSG_INQ = BIT(1), }; static struct percpu_counter mptcp_sockets_allocated ____cacheline_aligned_in_smp; @@ -738,6 +740,7 @@ static bool __mptcp_ofo_queue(struct mptcp_sock *msk) MPTCP_SKB_CB(skb)->map_seq, msk->ack_seq, delta); MPTCP_SKB_CB(skb)->offset += delta; + MPTCP_SKB_CB(skb)->map_seq += delta; __skb_queue_tail(&sk->sk_receive_queue, skb); } msk->ack_seq = end_seq; @@ -1499,7 +1502,7 @@ static void mptcp_update_post_push(struct mptcp_sock *msk, msk->snd_nxt = snd_nxt_new; } -static void mptcp_check_and_set_pending(struct sock *sk) +void mptcp_check_and_set_pending(struct sock *sk) { if (mptcp_send_head(sk) && !test_bit(MPTCP_PUSH_PENDING, &mptcp_sk(sk)->flags)) @@ -1784,8 +1787,10 @@ static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk, copied += count; if (count < data_len) { - if (!(flags & MSG_PEEK)) + if (!(flags & MSG_PEEK)) { MPTCP_SKB_CB(skb)->offset += count; + MPTCP_SKB_CB(skb)->map_seq += count; + } break; } @@ -1965,6 +1970,27 @@ static bool __mptcp_move_skbs(struct mptcp_sock *msk) return !skb_queue_empty(&msk->receive_queue); } +static unsigned int mptcp_inq_hint(const struct sock *sk) +{ + const struct mptcp_sock *msk = mptcp_sk(sk); + const struct sk_buff *skb; + + skb = skb_peek(&msk->receive_queue); + if (skb) { + u64 hint_val = msk->ack_seq - MPTCP_SKB_CB(skb)->map_seq; + + if (hint_val >= INT_MAX) + return INT_MAX; + + return (unsigned int)hint_val; + } + + if (sk->sk_state == TCP_CLOSE || (sk->sk_shutdown & RCV_SHUTDOWN)) + return 1; + + return 0; +} + static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock, int flags, int *addr_len) { @@ -1989,6 +2015,9 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, len = min_t(size_t, len, INT_MAX); target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + if (unlikely(msk->recvmsg_inq)) + cmsg_flags = MPTCP_CMSG_INQ; + while (copied < len) { int bytes_read; @@ -2062,6 +2091,12 @@ out_err: if (cmsg_flags && copied >= 0) { if (cmsg_flags & MPTCP_CMSG_TS) tcp_recv_timestamp(msg, sk, &tss); + + if (cmsg_flags & MPTCP_CMSG_INQ) { + unsigned int inq = mptcp_inq_hint(sk); + + put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq); + } } pr_debug("msk=%p rx queue empty=%d:%d copied=%d", @@ -3177,6 +3212,57 @@ static int mptcp_forward_alloc_get(const struct sock *sk) return sk->sk_forward_alloc + mptcp_sk(sk)->rmem_fwd_alloc; } +static int mptcp_ioctl_outq(const struct mptcp_sock *msk, u64 v) +{ + const struct sock *sk = (void *)msk; + u64 delta; + + if (sk->sk_state == TCP_LISTEN) + return -EINVAL; + + if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) + return 0; + + delta = msk->write_seq - v; + if (delta > INT_MAX) + delta = INT_MAX; + + return (int)delta; +} + +static int mptcp_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + bool slow; + int answ; + + switch (cmd) { + case SIOCINQ: + if (sk->sk_state == TCP_LISTEN) + return -EINVAL; + + lock_sock(sk); + __mptcp_move_skbs(msk); + answ = mptcp_inq_hint(sk); + release_sock(sk); + break; + case SIOCOUTQ: + slow = lock_sock_fast(sk); + answ = mptcp_ioctl_outq(msk, READ_ONCE(msk->snd_una)); + unlock_sock_fast(sk, slow); + break; + case SIOCOUTQNSD: + slow = lock_sock_fast(sk); + answ = mptcp_ioctl_outq(msk, msk->snd_nxt); + unlock_sock_fast(sk, slow); + break; + default: + return -ENOIOCTLCMD; + } + + return put_user(answ, (int __user *)arg); +} + static struct proto mptcp_prot = { .name = "MPTCP", .owner = THIS_MODULE, @@ -3189,6 +3275,7 @@ static struct proto mptcp_prot = { .shutdown = mptcp_shutdown, .destroy = mptcp_destroy, .sendmsg = mptcp_sendmsg, + .ioctl = mptcp_ioctl, .recvmsg = mptcp_recvmsg, .release_cb = mptcp_release_cb, .hash = mptcp_hash, diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index d87cc040352e..e1469155fb15 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -249,6 +249,9 @@ struct mptcp_sock { bool rcv_fastclose; bool use_64bit_ack; /* Set when we received a 64-bit DSN */ bool csum_enabled; + u8 recvmsg_inq:1, + cork:1, + nodelay:1; spinlock_t join_list_lock; struct work_struct work; struct sk_buff *ooo_last_skb; @@ -554,6 +557,7 @@ unsigned int mptcp_stale_loss_cnt(const struct net *net); void mptcp_subflow_fully_established(struct mptcp_subflow_context *subflow, struct mptcp_options_received *mp_opt); bool __mptcp_retransmit_pending_data(struct sock *sk); +void mptcp_check_and_set_pending(struct sock *sk); void __mptcp_push_pending(struct sock *sk, unsigned int flags); bool mptcp_subflow_data_available(struct sock *sk); void __init mptcp_subflow_init(void); diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c index fb43e145cb57..3c3db22fd36a 100644 --- a/net/mptcp/sockopt.c +++ b/net/mptcp/sockopt.c @@ -557,6 +557,7 @@ static bool mptcp_supported_sockopt(int level, int optname) case TCP_TIMESTAMP: case TCP_NOTSENT_LOWAT: case TCP_TX_DELAY: + case TCP_INQ: return true; } @@ -568,7 +569,6 @@ static bool mptcp_supported_sockopt(int level, int optname) /* TCP_FASTOPEN_KEY, TCP_FASTOPEN TCP_FASTOPEN_CONNECT, TCP_FASTOPEN_NO_COOKIE, * are not supported fastopen is currently unsupported */ - /* TCP_INQ is currently unsupported, needs some recvmsg work */ } return false; } @@ -616,6 +616,66 @@ static int mptcp_setsockopt_sol_tcp_congestion(struct mptcp_sock *msk, sockptr_t return ret; } +static int mptcp_setsockopt_sol_tcp_cork(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + int val; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + lock_sock(sk); + sockopt_seq_inc(msk); + msk->cork = !!val; + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + lock_sock(ssk); + __tcp_sock_set_cork(ssk, !!val); + release_sock(ssk); + } + if (!val) + mptcp_check_and_set_pending(sk); + release_sock(sk); + + return 0; +} + +static int mptcp_setsockopt_sol_tcp_nodelay(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + int val; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + lock_sock(sk); + sockopt_seq_inc(msk); + msk->nodelay = !!val; + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + lock_sock(ssk); + __tcp_sock_set_nodelay(ssk, !!val); + release_sock(ssk); + } + if (val) + mptcp_check_and_set_pending(sk); + release_sock(sk); + + return 0; +} + static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int optname, sockptr_t optval, unsigned int optlen) { @@ -698,11 +758,29 @@ static int mptcp_setsockopt_v4(struct mptcp_sock *msk, int optname, static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname, sockptr_t optval, unsigned int optlen) { + struct sock *sk = (void *)msk; + int ret, val; + switch (optname) { + case TCP_INQ: + ret = mptcp_get_int_option(msk, optval, optlen, &val); + if (ret) + return ret; + if (val < 0 || val > 1) + return -EINVAL; + + lock_sock(sk); + msk->recvmsg_inq = !!val; + release_sock(sk); + return 0; case TCP_ULP: return -EOPNOTSUPP; case TCP_CONGESTION: return mptcp_setsockopt_sol_tcp_congestion(msk, optval, optlen); + case TCP_CORK: + return mptcp_setsockopt_sol_tcp_cork(msk, optval, optlen); + case TCP_NODELAY: + return mptcp_setsockopt_sol_tcp_nodelay(msk, optval, optlen); } return -EOPNOTSUPP; @@ -1032,6 +1110,35 @@ static int mptcp_getsockopt_subflow_addrs(struct mptcp_sock *msk, char __user *o return 0; } +static int mptcp_put_int_option(struct mptcp_sock *msk, char __user *optval, + int __user *optlen, int val) +{ + int len; + + if (get_user(len, optlen)) + return -EFAULT; + if (len < 0) + return -EINVAL; + + if (len < sizeof(int) && len > 0 && val >= 0 && val <= 255) { + unsigned char ucval = (unsigned char)val; + + len = 1; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &ucval, 1)) + return -EFAULT; + } else { + len = min_t(unsigned int, len, sizeof(int)); + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + } + + return 0; +} + static int mptcp_getsockopt_sol_tcp(struct mptcp_sock *msk, int optname, char __user *optval, int __user *optlen) { @@ -1042,10 +1149,29 @@ static int mptcp_getsockopt_sol_tcp(struct mptcp_sock *msk, int optname, case TCP_CC_INFO: return mptcp_getsockopt_first_sf_only(msk, SOL_TCP, optname, optval, optlen); + case TCP_INQ: + return mptcp_put_int_option(msk, optval, optlen, msk->recvmsg_inq); + case TCP_CORK: + return mptcp_put_int_option(msk, optval, optlen, msk->cork); + case TCP_NODELAY: + return mptcp_put_int_option(msk, optval, optlen, msk->nodelay); } return -EOPNOTSUPP; } +static int mptcp_getsockopt_v4(struct mptcp_sock *msk, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = (void *)msk; + + switch (optname) { + case IP_TOS: + return mptcp_put_int_option(msk, optval, optlen, inet_sk(sk)->tos); + } + + return -EOPNOTSUPP; +} + static int mptcp_getsockopt_sol_mptcp(struct mptcp_sock *msk, int optname, char __user *optval, int __user *optlen) { @@ -1081,6 +1207,8 @@ int mptcp_getsockopt(struct sock *sk, int level, int optname, if (ssk) return tcp_getsockopt(ssk, level, optname, optval, option); + if (level == SOL_IP) + return mptcp_getsockopt_v4(msk, optname, optval, option); if (level == SOL_TCP) return mptcp_getsockopt_sol_tcp(msk, optname, optval, option); if (level == SOL_MPTCP) @@ -1129,6 +1257,8 @@ static void sync_socket_options(struct mptcp_sock *msk, struct sock *ssk) if (inet_csk(sk)->icsk_ca_ops != inet_csk(ssk)->icsk_ca_ops) tcp_set_congestion_control(ssk, msk->ca_name, false, true); + __tcp_sock_set_cork(ssk, !!msk->cork); + __tcp_sock_set_nodelay(ssk, !!msk->nodelay); inet_sk(ssk)->transparent = inet_sk(sk)->transparent; inet_sk(ssk)->freebind = inet_sk(sk)->freebind; diff --git a/tools/testing/selftests/net/mptcp/.gitignore b/tools/testing/selftests/net/mptcp/.gitignore index 7569d892967a..49daae73c41e 100644 --- a/tools/testing/selftests/net/mptcp/.gitignore +++ b/tools/testing/selftests/net/mptcp/.gitignore @@ -1,5 +1,6 @@ # SPDX-License-Identifier: GPL-2.0-only mptcp_connect +mptcp_inq mptcp_sockopt pm_nl_ctl *.pcap diff --git a/tools/testing/selftests/net/mptcp/Makefile b/tools/testing/selftests/net/mptcp/Makefile index bbf4e448bad9..0356c4501c99 100644 --- a/tools/testing/selftests/net/mptcp/Makefile +++ b/tools/testing/selftests/net/mptcp/Makefile @@ -8,7 +8,7 @@ CFLAGS = -Wall -Wl,--no-as-needed -O2 -g -I$(top_srcdir)/usr/include TEST_PROGS := mptcp_connect.sh pm_netlink.sh mptcp_join.sh diag.sh \ simult_flows.sh mptcp_sockopt.sh -TEST_GEN_FILES = mptcp_connect pm_nl_ctl mptcp_sockopt +TEST_GEN_FILES = mptcp_connect pm_nl_ctl mptcp_sockopt mptcp_inq TEST_FILES := settings diff --git a/tools/testing/selftests/net/mptcp/mptcp_connect.c b/tools/testing/selftests/net/mptcp/mptcp_connect.c index ada9b80774d4..98de28ac3ba8 100644 --- a/tools/testing/selftests/net/mptcp/mptcp_connect.c +++ b/tools/testing/selftests/net/mptcp/mptcp_connect.c @@ -73,12 +73,20 @@ static uint32_t cfg_mark; struct cfg_cmsg_types { unsigned int cmsg_enabled:1; unsigned int timestampns:1; + unsigned int tcp_inq:1; }; struct cfg_sockopt_types { unsigned int transparent:1; }; +struct tcp_inq_state { + unsigned int last; + bool expect_eof; +}; + +static struct tcp_inq_state tcp_inq; + static struct cfg_cmsg_types cfg_cmsg_types; static struct cfg_sockopt_types cfg_sockopt_types; @@ -389,7 +397,9 @@ static size_t do_write(const int fd, char *buf, const size_t len) static void process_cmsg(struct msghdr *msgh) { struct __kernel_timespec ts; + bool inq_found = false; bool ts_found = false; + unsigned int inq = 0; struct cmsghdr *cmsg; for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) { @@ -398,12 +408,27 @@ static void process_cmsg(struct msghdr *msgh) ts_found = true; continue; } + if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) { + memcpy(&inq, CMSG_DATA(cmsg), sizeof(inq)); + inq_found = true; + continue; + } + } if (cfg_cmsg_types.timestampns) { if (!ts_found) xerror("TIMESTAMPNS not present\n"); } + + if (cfg_cmsg_types.tcp_inq) { + if (!inq_found) + xerror("TCP_INQ not present\n"); + + if (inq > 1024) + xerror("tcp_inq %u is larger than one kbyte\n", inq); + tcp_inq.last = inq; + } } static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len) @@ -420,10 +445,23 @@ static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len) .msg_controllen = sizeof(msg_buf), }; int flags = 0; + unsigned int last_hint = tcp_inq.last; int ret = recvmsg(fd, &msg, flags); - if (ret <= 0) + if (ret <= 0) { + if (ret == 0 && tcp_inq.expect_eof) + return ret; + + if (ret == 0 && cfg_cmsg_types.tcp_inq) + if (last_hint != 1 && last_hint != 0) + xerror("EOF but last tcp_inq hint was %u\n", last_hint); + return ret; + } + + if (tcp_inq.expect_eof) + xerror("expected EOF, last_hint %u, now %u\n", + last_hint, tcp_inq.last); if (msg.msg_controllen && !cfg_cmsg_types.cmsg_enabled) xerror("got %lu bytes of cmsg data, expected 0\n", @@ -435,6 +473,19 @@ static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len) if (msg.msg_controllen) process_cmsg(&msg); + if (cfg_cmsg_types.tcp_inq) { + if ((size_t)ret < len && last_hint > (unsigned int)ret) { + if (ret + 1 != (int)last_hint) { + int next = read(fd, msg_buf, sizeof(msg_buf)); + + xerror("read %u of %u, last_hint was %u tcp_inq hint now %u next_read returned %d/%m\n", + ret, (unsigned int)len, last_hint, tcp_inq.last, next); + } else { + tcp_inq.expect_eof = true; + } + } + } + return ret; } @@ -944,6 +995,8 @@ static void apply_cmsg_types(int fd, const struct cfg_cmsg_types *cmsg) if (cmsg->timestampns) xsetsockopt(fd, SOL_SOCKET, SO_TIMESTAMPNS_NEW, &on, sizeof(on)); + if (cmsg->tcp_inq) + xsetsockopt(fd, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)); } static void parse_cmsg_types(const char *type) @@ -965,6 +1018,11 @@ static void parse_cmsg_types(const char *type) return; } + if (strncmp(type, "TCPINQ", len) == 0) { + cfg_cmsg_types.tcp_inq = 1; + return; + } + fprintf(stderr, "Unrecognized cmsg option %s\n", type); exit(1); } diff --git a/tools/testing/selftests/net/mptcp/mptcp_inq.c b/tools/testing/selftests/net/mptcp/mptcp_inq.c new file mode 100644 index 000000000000..b8debd4fb5ed --- /dev/null +++ b/tools/testing/selftests/net/mptcp/mptcp_inq.c @@ -0,0 +1,603 @@ +// SPDX-License-Identifier: GPL-2.0 + +#define _GNU_SOURCE + +#include <assert.h> +#include <errno.h> +#include <fcntl.h> +#include <limits.h> +#include <string.h> +#include <stdarg.h> +#include <stdbool.h> +#include <stdint.h> +#include <inttypes.h> +#include <stdio.h> +#include <stdlib.h> +#include <strings.h> +#include <unistd.h> +#include <time.h> + +#include <sys/ioctl.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/wait.h> + +#include <netdb.h> +#include <netinet/in.h> + +#include <linux/tcp.h> +#include <linux/sockios.h> + +#ifndef IPPROTO_MPTCP +#define IPPROTO_MPTCP 262 +#endif +#ifndef SOL_MPTCP +#define SOL_MPTCP 284 +#endif + +static int pf = AF_INET; +static int proto_tx = IPPROTO_MPTCP; +static int proto_rx = IPPROTO_MPTCP; + +static void die_perror(const char *msg) +{ + perror(msg); + exit(1); +} + +static void die_usage(int r) +{ + fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n"); + exit(r); +} + +static void xerror(const char *fmt, ...) +{ + va_list ap; + + va_start(ap, fmt); + vfprintf(stderr, fmt, ap); + va_end(ap); + fputc('\n', stderr); + exit(1); +} + +static const char *getxinfo_strerr(int err) +{ + if (err == EAI_SYSTEM) + return strerror(errno); + + return gai_strerror(err); +} + +static void xgetaddrinfo(const char *node, const char *service, + const struct addrinfo *hints, + struct addrinfo **res) +{ + int err = getaddrinfo(node, service, hints, res); + + if (err) { + const char *errstr = getxinfo_strerr(err); + + fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n", + node ? node : "", service ? service : "", errstr); + exit(1); + } +} + +static int sock_listen_mptcp(const char * const listenaddr, + const char * const port) +{ + int sock; + struct addrinfo hints = { + .ai_protocol = IPPROTO_TCP, + .ai_socktype = SOCK_STREAM, + .ai_flags = AI_PASSIVE | AI_NUMERICHOST + }; + + hints.ai_family = pf; + + struct addrinfo *a, *addr; + int one = 1; + + xgetaddrinfo(listenaddr, port, &hints, &addr); + hints.ai_family = pf; + + for (a = addr; a; a = a->ai_next) { + sock = socket(a->ai_family, a->ai_socktype, proto_rx); + if (sock < 0) + continue; + + if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one, + sizeof(one))) + perror("setsockopt"); + + if (bind(sock, a->ai_addr, a->ai_addrlen) == 0) + break; /* success */ + + perror("bind"); + close(sock); + sock = -1; + } + + freeaddrinfo(addr); + + if (sock < 0) + xerror("could not create listen socket"); + + if (listen(sock, 20)) + die_perror("listen"); + + return sock; +} + +static int sock_connect_mptcp(const char * const remoteaddr, + const char * const port, int proto) +{ + struct addrinfo hints = { + .ai_protocol = IPPROTO_TCP, + .ai_socktype = SOCK_STREAM, + }; + struct addrinfo *a, *addr; + int sock = -1; + + hints.ai_family = pf; + + xgetaddrinfo(remoteaddr, port, &hints, &addr); + for (a = addr; a; a = a->ai_next) { + sock = socket(a->ai_family, a->ai_socktype, proto); + if (sock < 0) + continue; + + if (connect(sock, a->ai_addr, a->ai_addrlen) == 0) + break; /* success */ + + die_perror("connect"); + } + + if (sock < 0) + xerror("could not create connect socket"); + + freeaddrinfo(addr); + return sock; +} + +static int protostr_to_num(const char *s) +{ + if (strcasecmp(s, "tcp") == 0) + return IPPROTO_TCP; + if (strcasecmp(s, "mptcp") == 0) + return IPPROTO_MPTCP; + + die_usage(1); + return 0; +} + +static void parse_opts(int argc, char **argv) +{ + int c; + + while ((c = getopt(argc, argv, "h6t:r:")) != -1) { + switch (c) { + case 'h': + die_usage(0); + break; + case '6': + pf = AF_INET6; + break; + case 't': + proto_tx = protostr_to_num(optarg); + break; + case 'r': + proto_rx = protostr_to_num(optarg); + break; + default: + die_usage(1); + break; + } + } +} + +/* wait up to timeout milliseconds */ +static void wait_for_ack(int fd, int timeout, size_t total) +{ + int i; + + for (i = 0; i < timeout; i++) { + int nsd, ret, queued = -1; + struct timespec req; + + ret = ioctl(fd, TIOCOUTQ, &queued); + if (ret < 0) + die_perror("TIOCOUTQ"); + + ret = ioctl(fd, SIOCOUTQNSD, &nsd); + if (ret < 0) + die_perror("SIOCOUTQNSD"); + + if ((size_t)queued > total) + xerror("TIOCOUTQ %u, but only %zu expected\n", queued, total); + assert(nsd <= queued); + + if (queued == 0) + return; + + /* wait for peer to ack rx of all data */ + req.tv_sec = 0; + req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */ + nanosleep(&req, NULL); + } + + xerror("still tx data queued after %u ms\n", timeout); +} + +static void connect_one_server(int fd, int unixfd) +{ + size_t len, i, total, sent; + char buf[4096], buf2[4096]; + ssize_t ret; + + len = rand() % (sizeof(buf) - 1); + + if (len < 128) + len = 128; + + for (i = 0; i < len ; i++) { + buf[i] = rand() % 26; + buf[i] += 'A'; + } + + buf[i] = '\n'; + + /* un-block server */ + ret = read(unixfd, buf2, 4); + assert(ret == 4); + + assert(strncmp(buf2, "xmit", 4) == 0); + + ret = write(unixfd, &len, sizeof(len)); + assert(ret == (ssize_t)sizeof(len)); + + ret = write(fd, buf, len); + if (ret < 0) + die_perror("write"); + + if (ret != (ssize_t)len) + xerror("short write"); + + ret = read(unixfd, buf2, 4); + assert(strncmp(buf2, "huge", 4) == 0); + + total = rand() % (16 * 1024 * 1024); + total += (1 * 1024 * 1024); + sent = total; + + ret = write(unixfd, &total, sizeof(total)); + assert(ret == (ssize_t)sizeof(total)); + + wait_for_ack(fd, 5000, len); + + while (total > 0) { + if (total > sizeof(buf)) + len = sizeof(buf); + else + len = total; + + ret = write(fd, buf, len); + if (ret < 0) + die_perror("write"); + total -= ret; + + /* we don't have to care about buf content, only + * number of total bytes sent + */ + } + + ret = read(unixfd, buf2, 4); + assert(ret == 4); + assert(strncmp(buf2, "shut", 4) == 0); + + wait_for_ack(fd, 5000, sent); + + ret = write(fd, buf, 1); + assert(ret == 1); + close(fd); + ret = write(unixfd, "closed", 6); + assert(ret == 6); + + close(unixfd); +} + +static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv) +{ + struct cmsghdr *cmsg; + + for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) { + memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv)); + return; + } + } + + xerror("could not find TCP_CM_INQ cmsg type"); +} + +static void process_one_client(int fd, int unixfd) +{ + unsigned int tcp_inq; + size_t expect_len; + char msg_buf[4096]; + char buf[4096]; + char tmp[16]; + struct iovec iov = { + .iov_base = buf, + .iov_len = 1, + }; + struct msghdr msg = { + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = msg_buf, + .msg_controllen = sizeof(msg_buf), + }; + ssize_t ret, tot; + + ret = write(unixfd, "xmit", 4); + assert(ret == 4); + + ret = read(unixfd, &expect_len, sizeof(expect_len)); + assert(ret == (ssize_t)sizeof(expect_len)); + + if (expect_len > sizeof(buf)) + xerror("expect len %zu exceeds buffer size", expect_len); + + for (;;) { + struct timespec req; + unsigned int queued; + + ret = ioctl(fd, FIONREAD, &queued); + if (ret < 0) + die_perror("FIONREAD"); + if (queued > expect_len) + xerror("FIONREAD returned %u, but only %zu expected\n", + queued, expect_len); + if (queued == expect_len) + break; + + req.tv_sec = 0; + req.tv_nsec = 1000 * 1000ul; + nanosleep(&req, NULL); + } + + /* read one byte, expect cmsg to return expected - 1 */ + ret = recvmsg(fd, &msg, 0); + if (ret < 0) + die_perror("recvmsg"); + + if (msg.msg_controllen == 0) + xerror("msg_controllen is 0"); + + get_tcp_inq(&msg, &tcp_inq); + + assert((size_t)tcp_inq == (expect_len - 1)); + + iov.iov_len = sizeof(buf); + ret = recvmsg(fd, &msg, 0); + if (ret < 0) + die_perror("recvmsg"); + + /* should have gotten exact remainder of all pending data */ + assert(ret == (ssize_t)tcp_inq); + + /* should be 0, all drained */ + get_tcp_inq(&msg, &tcp_inq); + assert(tcp_inq == 0); + + /* request a large swath of data. */ + ret = write(unixfd, "huge", 4); + assert(ret == 4); + + ret = read(unixfd, &expect_len, sizeof(expect_len)); + assert(ret == (ssize_t)sizeof(expect_len)); + + /* peer should send us a few mb of data */ + if (expect_len <= sizeof(buf)) + xerror("expect len %zu too small\n", expect_len); + + tot = 0; + do { + iov.iov_len = sizeof(buf); + ret = recvmsg(fd, &msg, 0); + if (ret < 0) + die_perror("recvmsg"); + + tot += ret; + + get_tcp_inq(&msg, &tcp_inq); + + if (tcp_inq > expect_len - tot) + xerror("inq %d, remaining %d total_len %d\n", + tcp_inq, expect_len - tot, (int)expect_len); + + assert(tcp_inq <= expect_len - tot); + } while ((size_t)tot < expect_len); + + ret = write(unixfd, "shut", 4); + assert(ret == 4); + + /* wait for hangup. Should have received one more byte of data. */ + ret = read(unixfd, tmp, sizeof(tmp)); + assert(ret == 6); + assert(strncmp(tmp, "closed", 6) == 0); + + sleep(1); + + iov.iov_len = 1; + ret = recvmsg(fd, &msg, 0); + if (ret < 0) + die_perror("recvmsg"); + assert(ret == 1); + + get_tcp_inq(&msg, &tcp_inq); + + /* tcp_inq should be 1 due to received fin. */ + assert(tcp_inq == 1); + + iov.iov_len = 1; + ret = recvmsg(fd, &msg, 0); + if (ret < 0) + die_perror("recvmsg"); + + /* expect EOF */ + assert(ret == 0); + get_tcp_inq(&msg, &tcp_inq); + assert(tcp_inq == 1); + + close(fd); +} + +static int xaccept(int s) +{ + int fd = accept(s, NULL, 0); + + if (fd < 0) + die_perror("accept"); + + return fd; +} + +static int server(int unixfd) +{ + int fd = -1, r, on = 1; + + switch (pf) { + case AF_INET: + fd = sock_listen_mptcp("127.0.0.1", "15432"); + break; + case AF_INET6: + fd = sock_listen_mptcp("::1", "15432"); + break; + default: + xerror("Unknown pf %d\n", pf); + break; + } + + r = write(unixfd, "conn", 4); + assert(r == 4); + + alarm(15); + r = xaccept(fd); + + if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on))) + die_perror("setsockopt"); + + process_one_client(r, unixfd); + + return 0; +} + +static int client(int unixfd) +{ + int fd = -1; + + alarm(15); + + switch (pf) { + case AF_INET: + fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx); + break; + case AF_INET6: + fd = sock_connect_mptcp("::1", "15432", proto_tx); + break; + default: + xerror("Unknown pf %d\n", pf); + } + + connect_one_server(fd, unixfd); + + return 0; +} + +static void init_rng(void) +{ + int fd = open("/dev/urandom", O_RDONLY); + unsigned int foo; + + if (fd > 0) { + int ret = read(fd, &foo, sizeof(foo)); + + if (ret < 0) + srand(fd + foo); + close(fd); + } + + srand(foo); +} + +static pid_t xfork(void) +{ + pid_t p = fork(); + + if (p < 0) + die_perror("fork"); + else if (p == 0) + init_rng(); + + return p; +} + +static int rcheck(int wstatus, const char *what) +{ + if (WIFEXITED(wstatus)) { + if (WEXITSTATUS(wstatus) == 0) + return 0; + fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus)); + return WEXITSTATUS(wstatus); + } else if (WIFSIGNALED(wstatus)) { + xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus)); + } else if (WIFSTOPPED(wstatus)) { + xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus)); + } + + return 111; +} + +int main(int argc, char *argv[]) +{ + int e1, e2, wstatus; + pid_t s, c, ret; + int unixfds[2]; + + parse_opts(argc, argv); + + e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds); + if (e1 < 0) + die_perror("pipe"); + + s = xfork(); + if (s == 0) + return server(unixfds[1]); + + close(unixfds[1]); + + /* wait until server bound a socket */ + e1 = read(unixfds[0], &e1, 4); + assert(e1 == 4); + + c = xfork(); + if (c == 0) + return client(unixfds[0]); + + close(unixfds[0]); + + ret = waitpid(s, &wstatus, 0); + if (ret == -1) + die_perror("waitpid"); + e1 = rcheck(wstatus, "server"); + ret = waitpid(c, &wstatus, 0); + if (ret == -1) + die_perror("waitpid"); + e2 = rcheck(wstatus, "client"); + + return e1 ? e1 : e2; +} diff --git a/tools/testing/selftests/net/mptcp/mptcp_sockopt.c b/tools/testing/selftests/net/mptcp/mptcp_sockopt.c index 417b11cafafe..ac9a4d9c1764 100644 --- a/tools/testing/selftests/net/mptcp/mptcp_sockopt.c +++ b/tools/testing/selftests/net/mptcp/mptcp_sockopt.c @@ -4,6 +4,7 @@ #include <assert.h> #include <errno.h> +#include <fcntl.h> #include <limits.h> #include <string.h> #include <stdarg.h> @@ -13,6 +14,7 @@ #include <stdio.h> #include <stdlib.h> #include <strings.h> +#include <time.h> #include <unistd.h> #include <sys/socket.h> @@ -594,6 +596,44 @@ static int server(int pipefd) return 0; } +static void test_ip_tos_sockopt(int fd) +{ + uint8_t tos_in, tos_out; + socklen_t s; + int r; + + tos_in = rand() & 0xfc; + r = setsockopt(fd, SOL_IP, IP_TOS, &tos_in, sizeof(tos_out)); + if (r != 0) + die_perror("setsockopt IP_TOS"); + + tos_out = 0; + s = sizeof(tos_out); + r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s); + if (r != 0) + die_perror("getsockopt IP_TOS"); + + if (tos_in != tos_out) + xerror("tos %x != %x socklen_t %d\n", tos_in, tos_out, s); + + if (s != 1) + xerror("tos should be 1 byte"); + + s = 0; + r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s); + if (r != 0) + die_perror("getsockopt IP_TOS 0"); + if (s != 0) + xerror("expect socklen_t == 0"); + + s = -1; + r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s); + if (r != -1 && errno != EINVAL) + die_perror("getsockopt IP_TOS did not indicate -EINVAL"); + if (s != -1) + xerror("expect socklen_t == -1"); +} + static int client(int pipefd) { int fd = -1; @@ -611,6 +651,8 @@ static int client(int pipefd) xerror("Unknown pf %d\n", pf); } + test_ip_tos_sockopt(fd); + connect_one_server(fd, pipefd); return 0; @@ -642,6 +684,25 @@ static int rcheck(int wstatus, const char *what) return 111; } +static void init_rng(void) +{ + int fd = open("/dev/urandom", O_RDONLY); + + if (fd >= 0) { + unsigned int foo; + ssize_t ret; + + /* can't fail */ + ret = read(fd, &foo, sizeof(foo)); + assert(ret == sizeof(foo)); + + close(fd); + srand(foo); + } else { + srand(time(NULL)); + } +} + int main(int argc, char *argv[]) { int e1, e2, wstatus; @@ -650,6 +711,8 @@ int main(int argc, char *argv[]) parse_opts(argc, argv); + init_rng(); + e1 = pipe(pipefds); if (e1 < 0) die_perror("pipe"); diff --git a/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh b/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh index 41de643788b8..0879da915014 100755 --- a/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh +++ b/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh @@ -178,7 +178,7 @@ do_transfer() timeout ${timeout_test} \ ip netns exec ${listener_ns} \ - $mptcp_connect -t ${timeout_poll} -l -M 1 -p $port -s ${srv_proto} -c TIMESTAMPNS \ + $mptcp_connect -t ${timeout_poll} -l -M 1 -p $port -s ${srv_proto} -c TIMESTAMPNS,TCPINQ \ ${local_addr} < "$sin" > "$sout" & spid=$! @@ -186,7 +186,7 @@ do_transfer() timeout ${timeout_test} \ ip netns exec ${connector_ns} \ - $mptcp_connect -t ${timeout_poll} -M 2 -p $port -s ${cl_proto} -c TIMESTAMPNS \ + $mptcp_connect -t ${timeout_poll} -M 2 -p $port -s ${cl_proto} -c TIMESTAMPNS,TCPINQ \ $connect_addr < "$cin" > "$cout" & cpid=$! @@ -279,6 +279,45 @@ run_tests() fi } +do_tcpinq_test() +{ + ip netns exec "$ns1" ./mptcp_inq "$@" + lret=$? + if [ $lret -ne 0 ];then + ret=$lret + echo "FAIL: mptcp_inq $@" 1>&2 + return $lret + fi + + echo "PASS: TCP_INQ cmsg/ioctl $@" + return $lret +} + +do_tcpinq_tests() +{ + local lret=0 + + ip netns exec "$ns1" iptables -F + ip netns exec "$ns1" ip6tables -F + + for args in "-t tcp" "-r tcp"; do + do_tcpinq_test $args + lret=$? + if [ $lret -ne 0 ] ; then + return $lret + fi + do_tcpinq_test -6 $args + lret=$? + if [ $lret -ne 0 ] ; then + return $lret + fi + done + + do_tcpinq_test -r tcp -t tcp + + return $? +} + sin=$(mktemp) sout=$(mktemp) cin=$(mktemp) @@ -300,4 +339,5 @@ if [ $ret -eq 0 ];then echo "PASS: SOL_MPTCP getsockopt has expected information" fi +do_tcpinq_tests exit $ret |