diff options
Diffstat (limited to 'net/mptcp/pm_netlink.c')
-rw-r--r-- | net/mptcp/pm_netlink.c | 291 |
1 files changed, 242 insertions, 49 deletions
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 83976b9ee99b..e7b1abb4f0c2 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -26,6 +26,7 @@ struct mptcp_pm_addr_entry { struct list_head list; struct mptcp_addr_info addr; struct rcu_head rcu; + struct socket *lsk; }; struct mptcp_pm_add_entry { @@ -90,14 +91,14 @@ static bool address_zero(const struct mptcp_addr_info *addr) memset(&zero, 0, sizeof(zero)); zero.family = addr->family; - return addresses_equal(addr, &zero, false); + return addresses_equal(addr, &zero, true); } static void local_address(const struct sock_common *skc, struct mptcp_addr_info *addr) { - addr->port = 0; addr->family = skc->skc_family; + addr->port = htons(skc->skc_num); if (addr->family == AF_INET) addr->addr.s_addr = skc->skc_rcv_saddr; #if IS_ENABLED(CONFIG_MPTCP_IPV6) @@ -130,7 +131,7 @@ static bool lookup_subflow_by_saddr(const struct list_head *list, skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow); local_address(skc, &cur); - if (addresses_equal(&cur, saddr, false)) + if (addresses_equal(&cur, saddr, saddr->port)) return true; } @@ -196,11 +197,46 @@ select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos) return ret; } +unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk) +{ + struct pm_nl_pernet *pernet; + + pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); + return READ_ONCE(pernet->add_addr_signal_max); +} +EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max); + +unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk) +{ + struct pm_nl_pernet *pernet; + + pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); + return READ_ONCE(pernet->add_addr_accept_max); +} +EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max); + +unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk) +{ + struct pm_nl_pernet *pernet; + + pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); + return READ_ONCE(pernet->subflows_max); +} +EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max); + +static unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk) +{ + struct pm_nl_pernet *pernet; + + pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); + return READ_ONCE(pernet->local_addr_max); +} + static void check_work_pending(struct mptcp_sock *msk) { - if (msk->pm.add_addr_signaled == msk->pm.add_addr_signal_max && - (msk->pm.local_addr_used == msk->pm.local_addr_max || - msk->pm.subflows == msk->pm.subflows_max)) + if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) && + (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) || + msk->pm.subflows == mptcp_pm_get_subflows_max(msk))) WRITE_ONCE(msk->pm.work_pending, false); } @@ -211,13 +247,34 @@ lookup_anno_list_by_saddr(struct mptcp_sock *msk, struct mptcp_pm_add_entry *entry; list_for_each_entry(entry, &msk->pm.anno_list, list) { - if (addresses_equal(&entry->addr, addr, false)) + if (addresses_equal(&entry->addr, addr, true)) return entry; } return NULL; } +bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk) +{ + struct mptcp_pm_add_entry *entry; + struct mptcp_addr_info saddr; + bool ret = false; + + local_address((struct sock_common *)sk, &saddr); + + spin_lock_bh(&msk->pm.lock); + list_for_each_entry(entry, &msk->pm.anno_list, list) { + if (addresses_equal(&entry->addr, &saddr, true)) { + ret = true; + goto out; + } + } + +out: + spin_unlock_bh(&msk->pm.lock); + return ret; +} + static void mptcp_pm_add_timer(struct timer_list *timer) { struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer); @@ -327,17 +384,24 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) { struct sock *sk = (struct sock *)msk; struct mptcp_pm_addr_entry *local; + unsigned int add_addr_signal_max; + unsigned int local_addr_max; struct pm_nl_pernet *pernet; + unsigned int subflows_max; pernet = net_generic(sock_net(sk), pm_nl_pernet_id); + add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk); + local_addr_max = mptcp_pm_get_local_addr_max(msk); + subflows_max = mptcp_pm_get_subflows_max(msk); + pr_debug("local %d:%d signal %d:%d subflows %d:%d\n", - msk->pm.local_addr_used, msk->pm.local_addr_max, - msk->pm.add_addr_signaled, msk->pm.add_addr_signal_max, - msk->pm.subflows, msk->pm.subflows_max); + msk->pm.local_addr_used, local_addr_max, + msk->pm.add_addr_signaled, add_addr_signal_max, + msk->pm.subflows, subflows_max); /* check first for announce */ - if (msk->pm.add_addr_signaled < msk->pm.add_addr_signal_max) { + if (msk->pm.add_addr_signaled < add_addr_signal_max) { local = select_signal_address(pernet, msk->pm.add_addr_signaled); @@ -349,15 +413,15 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) } } else { /* pick failed, avoid fourther attempts later */ - msk->pm.local_addr_used = msk->pm.add_addr_signal_max; + msk->pm.local_addr_used = add_addr_signal_max; } check_work_pending(msk); } /* check if should create a new subflow */ - if (msk->pm.local_addr_used < msk->pm.local_addr_max && - msk->pm.subflows < msk->pm.subflows_max) { + if (msk->pm.local_addr_used < local_addr_max && + msk->pm.subflows < subflows_max) { local = select_local_address(pernet, msk); if (local) { struct mptcp_addr_info remote = { 0 }; @@ -373,7 +437,7 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) } /* lookup failed, avoid fourther attempts later */ - msk->pm.local_addr_used = msk->pm.local_addr_max; + msk->pm.local_addr_used = local_addr_max; check_work_pending(msk); } } @@ -391,17 +455,22 @@ void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk) void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) { struct sock *sk = (struct sock *)msk; + unsigned int add_addr_accept_max; struct mptcp_addr_info remote; struct mptcp_addr_info local; + unsigned int subflows_max; bool use_port = false; + add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk); + subflows_max = mptcp_pm_get_subflows_max(msk); + pr_debug("accepted %d:%d remote family %d", - msk->pm.add_addr_accepted, msk->pm.add_addr_accept_max, + msk->pm.add_addr_accepted, add_addr_accept_max, msk->pm.remote.family); msk->pm.add_addr_accepted++; msk->pm.subflows++; - if (msk->pm.add_addr_accepted >= msk->pm.add_addr_accept_max || - msk->pm.subflows >= msk->pm.subflows_max) + if (msk->pm.add_addr_accepted >= add_addr_accept_max || + msk->pm.subflows >= subflows_max) WRITE_ONCE(msk->pm.accept_addr, false); /* connect to the specified remote address, using whatever @@ -427,8 +496,7 @@ void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk) { struct mptcp_subflow_context *subflow; - if (!mptcp_pm_should_add_signal_ipv6(msk) && - !mptcp_pm_should_add_signal_port(msk)) + if (!mptcp_pm_should_add_signal(msk)) return; __mptcp_flush_join_list(msk); @@ -438,10 +506,9 @@ void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk) u8 add_addr; spin_unlock_bh(&msk->pm.lock); - if (mptcp_pm_should_add_signal_ipv6(msk)) - pr_debug("send ack for add_addr6"); - if (mptcp_pm_should_add_signal_port(msk)) - pr_debug("send ack for add_addr_port"); + pr_debug("send ack for add_addr%s%s", + mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "", + mptcp_pm_should_add_signal_port(msk) ? " [port]" : ""); lock_sock(ssk); tcp_send_ack(ssk); @@ -572,6 +639,7 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, struct mptcp_pm_addr_entry *entry) { struct mptcp_pm_addr_entry *cur; + unsigned int addr_max; int ret = -EINVAL; spin_lock_bh(&pernet->lock); @@ -614,10 +682,14 @@ find_next: if (entry->addr.id > pernet->next_id) pernet->next_id = entry->addr.id; - if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) - pernet->add_addr_signal_max++; - if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) - pernet->local_addr_max++; + if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) { + addr_max = pernet->add_addr_signal_max; + WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1); + } + if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) { + addr_max = pernet->local_addr_max; + WRITE_ONCE(pernet->local_addr_max, addr_max + 1); + } pernet->addrs++; list_add_tail_rcu(&entry->list, &pernet->local_addr_list); @@ -628,6 +700,53 @@ out: return ret; } +static int mptcp_pm_nl_create_listen_socket(struct sock *sk, + struct mptcp_pm_addr_entry *entry) +{ + struct sockaddr_storage addr; + struct mptcp_sock *msk; + struct socket *ssock; + int backlog = 1024; + int err; + + err = sock_create_kern(sock_net(sk), entry->addr.family, + SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk); + if (err) + return err; + + msk = mptcp_sk(entry->lsk->sk); + if (!msk) { + err = -EINVAL; + goto out; + } + + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + err = -EINVAL; + goto out; + } + + mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family); + err = kernel_bind(ssock, (struct sockaddr *)&addr, + sizeof(struct sockaddr_in)); + if (err) { + pr_warn("kernel_bind error, err=%d", err); + goto out; + } + + err = kernel_listen(ssock, backlog); + if (err) { + pr_warn("kernel_listen error, err=%d", err); + goto out; + } + + return 0; + +out: + sock_release(entry->lsk); + return err; +} + int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) { struct mptcp_pm_addr_entry *entry; @@ -654,7 +773,7 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) rcu_read_lock(); list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { - if (addresses_equal(&entry->addr, &skc_local, false)) { + if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) { ret = entry->addr.id; break; } @@ -672,6 +791,8 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) entry->addr.ifindex = 0; entry->addr.flags = 0; entry->addr.id = 0; + entry->addr.port = 0; + entry->lsk = NULL; ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); if (ret < 0) kfree(entry); @@ -682,19 +803,12 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) void mptcp_pm_nl_data_init(struct mptcp_sock *msk) { struct mptcp_pm_data *pm = &msk->pm; - struct pm_nl_pernet *pernet; bool subflows; - pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); - - pm->add_addr_signal_max = READ_ONCE(pernet->add_addr_signal_max); - pm->add_addr_accept_max = READ_ONCE(pernet->add_addr_accept_max); - pm->local_addr_max = READ_ONCE(pernet->local_addr_max); - pm->subflows_max = READ_ONCE(pernet->subflows_max); - subflows = !!pm->subflows_max; - WRITE_ONCE(pm->work_pending, (!!pm->local_addr_max && subflows) || - !!pm->add_addr_signal_max); - WRITE_ONCE(pm->accept_addr, !!pm->add_addr_accept_max && subflows); + subflows = !!mptcp_pm_get_subflows_max(msk); + WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) || + !!mptcp_pm_get_add_addr_signal_max(msk)); + WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows); WRITE_ONCE(pm->accept_subflow, subflows); } @@ -797,6 +911,9 @@ skip_family: if (tb[MPTCP_PM_ADDR_ATTR_FLAGS]) entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]); + if (tb[MPTCP_PM_ADDR_ATTR_PORT]) + entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT])); + return 0; } @@ -805,6 +922,31 @@ static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info) return net_generic(genl_info_net(info), pm_nl_pernet_id); } +static int mptcp_nl_add_subflow_or_signal_addr(struct net *net) +{ + struct mptcp_sock *msk; + long s_slot = 0, s_num = 0; + + while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { + struct sock *sk = (struct sock *)msk; + + if (!READ_ONCE(msk->fully_established)) + goto next; + + lock_sock(sk); + spin_lock_bh(&msk->pm.lock); + mptcp_pm_create_subflow_or_signal_addr(msk); + spin_unlock_bh(&msk->pm.lock); + release_sock(sk); + +next: + sock_put(sk); + cond_resched(); + } + + return 0; +} + static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) { struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; @@ -823,13 +965,25 @@ static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) } *entry = addr; + if (entry->addr.port) { + ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry); + if (ret) { + GENL_SET_ERR_MSG(info, "create listen socket error"); + kfree(entry); + return ret; + } + } ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); if (ret < 0) { GENL_SET_ERR_MSG(info, "too many addresses or duplicate one"); + if (entry->lsk) + sock_release(entry->lsk); kfree(entry); return ret; } + mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk)); + return 0; } @@ -907,11 +1061,44 @@ next: return 0; } +struct addr_entry_release_work { + struct rcu_work rwork; + struct mptcp_pm_addr_entry *entry; +}; + +static void mptcp_pm_release_addr_entry(struct work_struct *work) +{ + struct addr_entry_release_work *w; + struct mptcp_pm_addr_entry *entry; + + w = container_of(to_rcu_work(work), struct addr_entry_release_work, rwork); + entry = w->entry; + if (entry) { + if (entry->lsk) + sock_release(entry->lsk); + kfree(entry); + } + kfree(w); +} + +static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry) +{ + struct addr_entry_release_work *w; + + w = kmalloc(sizeof(*w), GFP_ATOMIC); + if (w) { + INIT_RCU_WORK(&w->rwork, mptcp_pm_release_addr_entry); + w->entry = entry; + queue_rcu_work(system_wq, &w->rwork); + } +} + static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info) { struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; struct pm_nl_pernet *pernet = genl_info_pm_nl(info); struct mptcp_pm_addr_entry addr, *entry; + unsigned int addr_max; int ret; ret = mptcp_pm_parse_addr(attr, info, false, &addr); @@ -925,10 +1112,14 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info) spin_unlock_bh(&pernet->lock); return -EINVAL; } - if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) - pernet->add_addr_signal_max--; - if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) - pernet->local_addr_max--; + if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) { + addr_max = pernet->add_addr_signal_max; + WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1); + } + if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) { + addr_max = pernet->local_addr_max; + WRITE_ONCE(pernet->local_addr_max, addr_max - 1); + } pernet->addrs--; list_del_rcu(&entry->list); @@ -936,7 +1127,7 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info) spin_unlock_bh(&pernet->lock); mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr); - kfree_rcu(entry, rcu); + mptcp_pm_free_addr_entry(entry); return ret; } @@ -950,15 +1141,15 @@ static void __flush_addrs(struct net *net, struct list_head *list) struct mptcp_pm_addr_entry, list); mptcp_nl_remove_subflow_and_signal_addr(net, &cur->addr); list_del_rcu(&cur->list); - kfree_rcu(cur, rcu); + mptcp_pm_free_addr_entry(cur); } } static void __reset_counters(struct pm_nl_pernet *pernet) { - pernet->add_addr_signal_max = 0; - pernet->add_addr_accept_max = 0; - pernet->local_addr_max = 0; + WRITE_ONCE(pernet->add_addr_signal_max, 0); + WRITE_ONCE(pernet->add_addr_accept_max, 0); + WRITE_ONCE(pernet->local_addr_max, 0); pernet->addrs = 0; } @@ -989,6 +1180,8 @@ static int mptcp_nl_fill_addr(struct sk_buff *skb, if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family)) goto nla_put_failure; + if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port))) + goto nla_put_failure; if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id)) goto nla_put_failure; if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags)) |