summaryrefslogtreecommitdiff
path: root/net/sctp
diff options
context:
space:
mode:
Diffstat (limited to 'net/sctp')
-rw-r--r--net/sctp/input.c19
-rw-r--r--net/sctp/ipv6.c2
-rw-r--r--net/sctp/protocol.c4
-rw-r--r--net/sctp/sm_statefuns.c23
-rw-r--r--net/sctp/socket.c31
-rw-r--r--net/sctp/sysctl.c40
6 files changed, 68 insertions, 51 deletions
diff --git a/net/sctp/input.c b/net/sctp/input.c
index 17fcaa9b0df9..a8a254a5008e 100644
--- a/net/sctp/input.c
+++ b/net/sctp/input.c
@@ -735,15 +735,19 @@ static int __sctp_hash_endpoint(struct sctp_endpoint *ep)
struct sock *sk = ep->base.sk;
struct net *net = sock_net(sk);
struct sctp_hashbucket *head;
+ int err = 0;
ep->hashent = sctp_ep_hashfn(net, ep->base.bind_addr.port);
head = &sctp_ep_hashtable[ep->hashent];
+ write_lock(&head->lock);
if (sk->sk_reuseport) {
bool any = sctp_is_ep_boundall(sk);
struct sctp_endpoint *ep2;
struct list_head *list;
- int cnt = 0, err = 1;
+ int cnt = 0;
+
+ err = 1;
list_for_each(list, &ep->base.bind_addr.address_list)
cnt++;
@@ -761,24 +765,24 @@ static int __sctp_hash_endpoint(struct sctp_endpoint *ep)
if (!err) {
err = reuseport_add_sock(sk, sk2, any);
if (err)
- return err;
+ goto out;
break;
} else if (err < 0) {
- return err;
+ goto out;
}
}
if (err) {
err = reuseport_alloc(sk, any);
if (err)
- return err;
+ goto out;
}
}
- write_lock(&head->lock);
hlist_add_head(&ep->node, &head->chain);
+out:
write_unlock(&head->lock);
- return 0;
+ return err;
}
/* Add an endpoint to the hash. Local BH-safe. */
@@ -803,10 +807,9 @@ static void __sctp_unhash_endpoint(struct sctp_endpoint *ep)
head = &sctp_ep_hashtable[ep->hashent];
+ write_lock(&head->lock);
if (rcu_access_pointer(sk->sk_reuseport_cb))
reuseport_detach_sock(sk);
-
- write_lock(&head->lock);
hlist_del_init(&ep->node);
write_unlock(&head->lock);
}
diff --git a/net/sctp/ipv6.c b/net/sctp/ipv6.c
index 24368f755ab1..f7b809c0d142 100644
--- a/net/sctp/ipv6.c
+++ b/net/sctp/ipv6.c
@@ -415,7 +415,7 @@ out:
if (!IS_ERR_OR_NULL(dst)) {
struct rt6_info *rt;
- rt = (struct rt6_info *)dst;
+ rt = dst_rt6_info(dst);
t->dst_cookie = rt6_get_cookie(rt);
pr_debug("rt6_dst:%pI6/%d rt6_src:%pI6\n",
&rt->rt6i_dst.addr, rt->rt6i_dst.plen,
diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c
index e849f368ed91..5a7436a13b74 100644
--- a/net/sctp/protocol.c
+++ b/net/sctp/protocol.c
@@ -552,7 +552,7 @@ static void sctp_v4_get_saddr(struct sctp_sock *sk,
struct flowi *fl)
{
union sctp_addr *saddr = &t->saddr;
- struct rtable *rt = (struct rtable *)t->dst;
+ struct rtable *rt = dst_rtable(t->dst);
if (rt) {
saddr->v4.sin_family = AF_INET;
@@ -1085,7 +1085,7 @@ static inline int sctp_v4_xmit(struct sk_buff *skb, struct sctp_transport *t)
skb_reset_inner_mac_header(skb);
skb_reset_inner_transport_header(skb);
skb_set_inner_ipproto(skb, IPPROTO_SCTP);
- udp_tunnel_xmit_skb((struct rtable *)dst, sk, skb, fl4->saddr,
+ udp_tunnel_xmit_skb(dst_rtable(dst), sk, skb, fl4->saddr,
fl4->daddr, dscp, ip4_dst_hoplimit(dst), df,
sctp_sk(sk)->udp_port, t->encap_port, false, false);
return 0;
diff --git a/net/sctp/sm_statefuns.c b/net/sctp/sm_statefuns.c
index 08fdf1251f46..7d315a18612b 100644
--- a/net/sctp/sm_statefuns.c
+++ b/net/sctp/sm_statefuns.c
@@ -38,6 +38,7 @@
#include <linux/inet.h>
#include <linux/slab.h>
#include <net/sock.h>
+#include <net/proto_memory.h>
#include <net/inet_ecn.h>
#include <linux/skbuff.h>
#include <net/sctp/sctp.h>
@@ -2259,12 +2260,6 @@ enum sctp_disposition sctp_sf_do_5_2_4_dupcook(
}
}
- /* Update socket peer label if first association. */
- if (security_sctp_assoc_request(new_asoc, chunk->head_skb ?: chunk->skb)) {
- sctp_association_free(new_asoc);
- return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands);
- }
-
/* Set temp so that it won't be added into hashtable */
new_asoc->temp = 1;
@@ -2273,6 +2268,22 @@ enum sctp_disposition sctp_sf_do_5_2_4_dupcook(
*/
action = sctp_tietags_compare(new_asoc, asoc);
+ /* In cases C and E the association doesn't enter the ESTABLISHED
+ * state, so there is no need to call security_sctp_assoc_request().
+ */
+ switch (action) {
+ case 'A': /* Association restart. */
+ case 'B': /* Collision case B. */
+ case 'D': /* Collision case D. */
+ /* Update socket peer label if first association. */
+ if (security_sctp_assoc_request((struct sctp_association *)asoc,
+ chunk->head_skb ?: chunk->skb)) {
+ sctp_association_free(new_asoc);
+ return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands);
+ }
+ break;
+ }
+
switch (action) {
case 'A': /* Association restart. */
retval = sctp_sf_do_dupcook_a(net, ep, asoc, chunk, commands,
diff --git a/net/sctp/socket.c b/net/sctp/socket.c
index c67679a41044..32f76f1298da 100644
--- a/net/sctp/socket.c
+++ b/net/sctp/socket.c
@@ -4834,10 +4834,14 @@ int sctp_inet_connect(struct socket *sock, struct sockaddr *uaddr,
return sctp_connect(sock->sk, uaddr, addr_len, flags);
}
-/* FIXME: Write comments. */
+/* Only called when shutdown a listening SCTP socket. */
static int sctp_disconnect(struct sock *sk, int flags)
{
- return -EOPNOTSUPP; /* STUB */
+ if (!sctp_style(sk, TCP))
+ return -EOPNOTSUPP;
+
+ sk->sk_shutdown |= RCV_SHUTDOWN;
+ return 0;
}
/* 4.1.4 accept() - TCP Style Syntax
@@ -4847,7 +4851,7 @@ static int sctp_disconnect(struct sock *sk, int flags)
* descriptor will be returned from accept() to represent the newly
* formed association.
*/
-static struct sock *sctp_accept(struct sock *sk, int flags, int *err, bool kern)
+static struct sock *sctp_accept(struct sock *sk, struct proto_accept_arg *arg)
{
struct sctp_sock *sp;
struct sctp_endpoint *ep;
@@ -4866,12 +4870,13 @@ static struct sock *sctp_accept(struct sock *sk, int flags, int *err, bool kern)
goto out;
}
- if (!sctp_sstate(sk, LISTENING)) {
+ if (!sctp_sstate(sk, LISTENING) ||
+ (sk->sk_shutdown & RCV_SHUTDOWN)) {
error = -EINVAL;
goto out;
}
- timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK);
+ timeo = sock_rcvtimeo(sk, arg->flags & O_NONBLOCK);
error = sctp_wait_for_accept(sk, timeo);
if (error)
@@ -4882,7 +4887,7 @@ static struct sock *sctp_accept(struct sock *sk, int flags, int *err, bool kern)
*/
asoc = list_entry(ep->asocs.next, struct sctp_association, asocs);
- newsk = sp->pf->create_accept_sk(sk, asoc, kern);
+ newsk = sp->pf->create_accept_sk(sk, asoc, arg->kern);
if (!newsk) {
error = -ENOMEM;
goto out;
@@ -4899,7 +4904,7 @@ static struct sock *sctp_accept(struct sock *sk, int flags, int *err, bool kern)
out:
release_sock(sk);
- *err = error;
+ arg->err = error;
return newsk;
}
@@ -7119,6 +7124,7 @@ static int sctp_getsockopt_assoc_ids(struct sock *sk, int len,
struct sctp_sock *sp = sctp_sk(sk);
struct sctp_association *asoc;
struct sctp_assoc_ids *ids;
+ size_t ids_size;
u32 num = 0;
if (sctp_style(sk, TCP))
@@ -7131,11 +7137,11 @@ static int sctp_getsockopt_assoc_ids(struct sock *sk, int len,
num++;
}
- if (len < sizeof(struct sctp_assoc_ids) + sizeof(sctp_assoc_t) * num)
+ ids_size = struct_size(ids, gaids_assoc_id, num);
+ if (len < ids_size)
return -EINVAL;
- len = sizeof(struct sctp_assoc_ids) + sizeof(sctp_assoc_t) * num;
-
+ len = ids_size;
ids = kmalloc(len, GFP_USER | __GFP_NOWARN);
if (unlikely(!ids))
return -ENOMEM;
@@ -9276,7 +9282,7 @@ void sctp_data_ready(struct sock *sk)
if (skwq_has_sleeper(wq))
wake_up_interruptible_sync_poll(&wq->wait, EPOLLIN |
EPOLLRDNORM | EPOLLRDBAND);
- sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
+ sk_wake_async_rcu(sk, SOCK_WAKE_WAITD, POLL_IN);
rcu_read_unlock();
}
@@ -9392,7 +9398,8 @@ static int sctp_wait_for_accept(struct sock *sk, long timeo)
}
err = -EINVAL;
- if (!sctp_sstate(sk, LISTENING))
+ if (!sctp_sstate(sk, LISTENING) ||
+ (sk->sk_shutdown & RCV_SHUTDOWN))
break;
err = 0;
diff --git a/net/sctp/sysctl.c b/net/sctp/sysctl.c
index f65d6f92afcb..e5a5af343c4c 100644
--- a/net/sctp/sysctl.c
+++ b/net/sctp/sysctl.c
@@ -43,19 +43,19 @@ static unsigned long max_autoclose_max =
(MAX_SCHEDULE_TIMEOUT / HZ > UINT_MAX)
? UINT_MAX : MAX_SCHEDULE_TIMEOUT / HZ;
-static int proc_sctp_do_hmac_alg(struct ctl_table *ctl, int write,
+static int proc_sctp_do_hmac_alg(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos);
-static int proc_sctp_do_rto_min(struct ctl_table *ctl, int write,
+static int proc_sctp_do_rto_min(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos);
-static int proc_sctp_do_rto_max(struct ctl_table *ctl, int write, void *buffer,
+static int proc_sctp_do_rto_max(const struct ctl_table *ctl, int write, void *buffer,
size_t *lenp, loff_t *ppos);
-static int proc_sctp_do_udp_port(struct ctl_table *ctl, int write, void *buffer,
+static int proc_sctp_do_udp_port(const struct ctl_table *ctl, int write, void *buffer,
size_t *lenp, loff_t *ppos);
-static int proc_sctp_do_alpha_beta(struct ctl_table *ctl, int write,
+static int proc_sctp_do_alpha_beta(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos);
-static int proc_sctp_do_auth(struct ctl_table *ctl, int write,
+static int proc_sctp_do_auth(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos);
-static int proc_sctp_do_probe_interval(struct ctl_table *ctl, int write,
+static int proc_sctp_do_probe_interval(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos);
static struct ctl_table sctp_table[] = {
@@ -80,8 +80,6 @@ static struct ctl_table sctp_table[] = {
.mode = 0644,
.proc_handler = proc_dointvec,
},
-
- { /* sentinel */ }
};
/* The following index defines are used in sctp_sysctl_net_register().
@@ -384,11 +382,9 @@ static struct ctl_table sctp_net_table[] = {
.extra1 = SYSCTL_ZERO,
.extra2 = &pf_expose_max,
},
-
- { /* sentinel */ }
};
-static int proc_sctp_do_hmac_alg(struct ctl_table *ctl, int write,
+static int proc_sctp_do_hmac_alg(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos)
{
struct net *net = current->nsproxy->net_ns;
@@ -433,7 +429,7 @@ static int proc_sctp_do_hmac_alg(struct ctl_table *ctl, int write,
return ret;
}
-static int proc_sctp_do_rto_min(struct ctl_table *ctl, int write,
+static int proc_sctp_do_rto_min(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos)
{
struct net *net = current->nsproxy->net_ns;
@@ -461,7 +457,7 @@ static int proc_sctp_do_rto_min(struct ctl_table *ctl, int write,
return ret;
}
-static int proc_sctp_do_rto_max(struct ctl_table *ctl, int write,
+static int proc_sctp_do_rto_max(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos)
{
struct net *net = current->nsproxy->net_ns;
@@ -489,7 +485,7 @@ static int proc_sctp_do_rto_max(struct ctl_table *ctl, int write,
return ret;
}
-static int proc_sctp_do_alpha_beta(struct ctl_table *ctl, int write,
+static int proc_sctp_do_alpha_beta(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos)
{
if (write)
@@ -499,7 +495,7 @@ static int proc_sctp_do_alpha_beta(struct ctl_table *ctl, int write,
return proc_dointvec_minmax(ctl, write, buffer, lenp, ppos);
}
-static int proc_sctp_do_auth(struct ctl_table *ctl, int write,
+static int proc_sctp_do_auth(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos)
{
struct net *net = current->nsproxy->net_ns;
@@ -528,7 +524,7 @@ static int proc_sctp_do_auth(struct ctl_table *ctl, int write,
return ret;
}
-static int proc_sctp_do_udp_port(struct ctl_table *ctl, int write,
+static int proc_sctp_do_udp_port(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos)
{
struct net *net = current->nsproxy->net_ns;
@@ -569,7 +565,7 @@ static int proc_sctp_do_udp_port(struct ctl_table *ctl, int write,
return ret;
}
-static int proc_sctp_do_probe_interval(struct ctl_table *ctl, int write,
+static int proc_sctp_do_probe_interval(const struct ctl_table *ctl, int write,
void *buffer, size_t *lenp, loff_t *ppos)
{
struct net *net = current->nsproxy->net_ns;
@@ -597,6 +593,7 @@ static int proc_sctp_do_probe_interval(struct ctl_table *ctl, int write,
int sctp_sysctl_net_register(struct net *net)
{
+ size_t table_size = ARRAY_SIZE(sctp_net_table);
struct ctl_table *table;
int i;
@@ -604,7 +601,7 @@ int sctp_sysctl_net_register(struct net *net)
if (!table)
return -ENOMEM;
- for (i = 0; table[i].data; i++)
+ for (i = 0; i < table_size; i++)
table[i].data += (char *)(&net->sctp) - (char *)&init_net.sctp;
table[SCTP_RTO_MIN_IDX].extra2 = &net->sctp.rto_max;
@@ -613,8 +610,7 @@ int sctp_sysctl_net_register(struct net *net)
table[SCTP_PS_RETRANS_IDX].extra1 = &net->sctp.pf_retrans;
net->sctp.sysctl_header = register_net_sysctl_sz(net, "net/sctp",
- table,
- ARRAY_SIZE(sctp_net_table));
+ table, table_size);
if (net->sctp.sysctl_header == NULL) {
kfree(table);
return -ENOMEM;
@@ -624,7 +620,7 @@ int sctp_sysctl_net_register(struct net *net)
void sctp_sysctl_net_unregister(struct net *net)
{
- struct ctl_table *table;
+ const struct ctl_table *table;
table = net->sctp.sysctl_header->ctl_table_arg;
unregister_net_sysctl_table(net->sctp.sysctl_header);