diff options
Diffstat (limited to 'net/can/raw.c')
| -rw-r--r-- | net/can/raw.c | 79 | 
1 files changed, 44 insertions, 35 deletions
diff --git a/net/can/raw.c b/net/can/raw.c index 15c79b079184..d50c3f3d892f 100644 --- a/net/can/raw.c +++ b/net/can/raw.c @@ -84,6 +84,8 @@ struct raw_sock {  	struct sock sk;  	int bound;  	int ifindex; +	struct net_device *dev; +	netdevice_tracker dev_tracker;  	struct list_head notifier;  	int loopback;  	int recv_own_msgs; @@ -277,21 +279,24 @@ static void raw_notify(struct raw_sock *ro, unsigned long msg,  	if (!net_eq(dev_net(dev), sock_net(sk)))  		return; -	if (ro->ifindex != dev->ifindex) +	if (ro->dev != dev)  		return;  	switch (msg) {  	case NETDEV_UNREGISTER:  		lock_sock(sk);  		/* remove current filters & unregister */ -		if (ro->bound) +		if (ro->bound) {  			raw_disable_allfilters(dev_net(dev), dev, sk); +			netdev_put(dev, &ro->dev_tracker); +		}  		if (ro->count > 1)  			kfree(ro->filter);  		ro->ifindex = 0;  		ro->bound = 0; +		ro->dev = NULL;  		ro->count = 0;  		release_sock(sk); @@ -337,6 +342,7 @@ static int raw_init(struct sock *sk)  	ro->bound            = 0;  	ro->ifindex          = 0; +	ro->dev              = NULL;  	/* set default filter to single entry dfilter */  	ro->dfilter.can_id   = 0; @@ -383,18 +389,14 @@ static int raw_release(struct socket *sock)  	list_del(&ro->notifier);  	spin_unlock(&raw_notifier_lock); +	rtnl_lock();  	lock_sock(sk);  	/* remove current filters & unregister */  	if (ro->bound) { -		if (ro->ifindex) { -			struct net_device *dev; - -			dev = dev_get_by_index(sock_net(sk), ro->ifindex); -			if (dev) { -				raw_disable_allfilters(dev_net(dev), dev, sk); -				dev_put(dev); -			} +		if (ro->dev) { +			raw_disable_allfilters(dev_net(ro->dev), ro->dev, sk); +			netdev_put(ro->dev, &ro->dev_tracker);  		} else {  			raw_disable_allfilters(sock_net(sk), NULL, sk);  		} @@ -405,6 +407,7 @@ static int raw_release(struct socket *sock)  	ro->ifindex = 0;  	ro->bound = 0; +	ro->dev = NULL;  	ro->count = 0;  	free_percpu(ro->uniq); @@ -412,6 +415,8 @@ static int raw_release(struct socket *sock)  	sock->sk = NULL;  	release_sock(sk); +	rtnl_unlock(); +  	sock_put(sk);  	return 0; @@ -422,6 +427,7 @@ static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len)  	struct sockaddr_can *addr = (struct sockaddr_can *)uaddr;  	struct sock *sk = sock->sk;  	struct raw_sock *ro = raw_sk(sk); +	struct net_device *dev = NULL;  	int ifindex;  	int err = 0;  	int notify_enetdown = 0; @@ -431,24 +437,23 @@ static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len)  	if (addr->can_family != AF_CAN)  		return -EINVAL; +	rtnl_lock();  	lock_sock(sk);  	if (ro->bound && addr->can_ifindex == ro->ifindex)  		goto out;  	if (addr->can_ifindex) { -		struct net_device *dev; -  		dev = dev_get_by_index(sock_net(sk), addr->can_ifindex);  		if (!dev) {  			err = -ENODEV;  			goto out;  		}  		if (dev->type != ARPHRD_CAN) { -			dev_put(dev);  			err = -ENODEV; -			goto out; +			goto out_put_dev;  		} +  		if (!(dev->flags & IFF_UP))  			notify_enetdown = 1; @@ -456,7 +461,9 @@ static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len)  		/* filters set by default/setsockopt */  		err = raw_enable_allfilters(sock_net(sk), dev, sk); -		dev_put(dev); +		if (err) +			goto out_put_dev; +  	} else {  		ifindex = 0; @@ -467,26 +474,30 @@ static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len)  	if (!err) {  		if (ro->bound) {  			/* unregister old filters */ -			if (ro->ifindex) { -				struct net_device *dev; - -				dev = dev_get_by_index(sock_net(sk), -						       ro->ifindex); -				if (dev) { -					raw_disable_allfilters(dev_net(dev), -							       dev, sk); -					dev_put(dev); -				} +			if (ro->dev) { +				raw_disable_allfilters(dev_net(ro->dev), +						       ro->dev, sk); +				/* drop reference to old ro->dev */ +				netdev_put(ro->dev, &ro->dev_tracker);  			} else {  				raw_disable_allfilters(sock_net(sk), NULL, sk);  			}  		}  		ro->ifindex = ifindex;  		ro->bound = 1; +		/* bind() ok -> hold a reference for new ro->dev */ +		ro->dev = dev; +		if (ro->dev) +			netdev_hold(ro->dev, &ro->dev_tracker, GFP_KERNEL);  	} - out: +out_put_dev: +	/* remove potential reference from dev_get_by_index() */ +	if (dev) +		dev_put(dev); +out:  	release_sock(sk); +	rtnl_unlock();  	if (notify_enetdown) {  		sk->sk_err = ENETDOWN; @@ -553,9 +564,9 @@ static int raw_setsockopt(struct socket *sock, int level, int optname,  		rtnl_lock();  		lock_sock(sk); -		if (ro->bound && ro->ifindex) { -			dev = dev_get_by_index(sock_net(sk), ro->ifindex); -			if (!dev) { +		dev = ro->dev; +		if (ro->bound && dev) { +			if (dev->reg_state != NETREG_REGISTERED) {  				if (count > 1)  					kfree(filter);  				err = -ENODEV; @@ -596,7 +607,6 @@ static int raw_setsockopt(struct socket *sock, int level, int optname,  		ro->count  = count;   out_fil: -		dev_put(dev);  		release_sock(sk);  		rtnl_unlock(); @@ -614,9 +624,9 @@ static int raw_setsockopt(struct socket *sock, int level, int optname,  		rtnl_lock();  		lock_sock(sk); -		if (ro->bound && ro->ifindex) { -			dev = dev_get_by_index(sock_net(sk), ro->ifindex); -			if (!dev) { +		dev = ro->dev; +		if (ro->bound && dev) { +			if (dev->reg_state != NETREG_REGISTERED) {  				err = -ENODEV;  				goto out_err;  			} @@ -640,7 +650,6 @@ static int raw_setsockopt(struct socket *sock, int level, int optname,  		ro->err_mask = err_mask;   out_err: -		dev_put(dev);  		release_sock(sk);  		rtnl_unlock(); @@ -873,7 +882,7 @@ static int raw_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)  	skb->dev = dev;  	skb->priority = sk->sk_priority; -	skb->mark = sk->sk_mark; +	skb->mark = READ_ONCE(sk->sk_mark);  	skb->tstamp = sockc.transmit_time;  	skb_setup_tx_timestamp(skb, sockc.tsflags);  | 
