Diffstat (limited to 'kernel/bpf/sockmap.c')
-rw-r--r--	kernel/bpf/sockmap.c	75
1 file changed, 45 insertions, 30 deletions
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 98e621a29e8e..488ef9663c01 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -236,7 +236,7 @@ static int bpf_tcp_init(struct sock *sk)
 }
 
 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
-static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
 
 static void bpf_tcp_release(struct sock *sk)
 {
@@ -248,7 +248,7 @@ static void bpf_tcp_release(struct sock *sk)
 		goto out;
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 		psock->cork = NULL;
 	}
@@ -330,14 +330,14 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 	close_fun = psock->save_close;
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 		psock->cork = NULL;
 	}
 
 	list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
 		list_del(&md->list);
-		free_start_sg(psock->sock, md);
+		free_start_sg(psock->sock, md, true);
 		kfree(md);
 	}
 
@@ -369,7 +369,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 			/* If another thread deleted this object skip deletion.
 			 * The refcnt on psock may or may not be zero.
 			 */
-			if (l) {
+			if (l && l == link) {
 				hlist_del_rcu(&link->hash_node);
 				smap_release_sock(psock, link->sk);
 				free_htab_elem(htab, link);
@@ -570,14 +570,16 @@ static void free_bytes_sg(struct sock *sk, int bytes,
 	md->sg_start = i;
 }
 
-static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
+static int free_sg(struct sock *sk, int start,
+		   struct sk_msg_buff *md, bool charge)
 {
 	struct scatterlist *sg = md->sg_data;
 	int i = start, free = 0;
 
 	while (sg[i].length) {
 		free += sg[i].length;
-		sk_mem_uncharge(sk, sg[i].length);
+		if (charge)
+			sk_mem_uncharge(sk, sg[i].length);
 		if (!md->skb)
 			put_page(sg_page(&sg[i]));
 		sg[i].length = 0;
@@ -594,9 +596,9 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
 	return free;
 }
 
-static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
 {
-	int free = free_sg(sk, md->sg_start, md);
+	int free = free_sg(sk, md->sg_start, md, charge);
 
 	md->sg_start = md->sg_end;
 	return free;
@@ -604,7 +606,7 @@ static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
 
 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
 {
-	return free_sg(sk, md->sg_curr, md);
+	return free_sg(sk, md->sg_curr, md, true);
 }
 
 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
@@ -718,7 +720,7 @@ static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
 		list_add_tail(&r->list, &psock->ingress);
 		sk->sk_data_ready(sk);
 	} else {
-		free_start_sg(sk, r);
+		free_start_sg(sk, r, true);
 		kfree(r);
 	}
 
@@ -752,14 +754,10 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
 		release_sock(sk);
 	}
 	smap_release_sock(psock, sk);
-	if (unlikely(err))
-		goto out;
-	return 0;
+	return err;
 out_rcu:
 	rcu_read_unlock();
-out:
-	free_bytes_sg(NULL, send, md, false);
-	return err;
+	return 0;
 }
 
 static inline void bpf_md_init(struct smap_psock *psock)
@@ -822,7 +820,7 @@ more_data:
 	case __SK_PASS:
 		err = bpf_tcp_push(sk, send, m, flags, true);
 		if (unlikely(err)) {
-			*copied -= free_start_sg(sk, m);
+			*copied -= free_start_sg(sk, m, true);
 			break;
 		}
 
@@ -845,16 +843,17 @@ more_data:
 		lock_sock(sk);
 
 		if (unlikely(err < 0)) {
-			free_start_sg(sk, m);
+			int free = free_start_sg(sk, m, false);
+
 			psock->sg_size = 0;
 			if (!cork)
-				*copied -= send;
+				*copied -= free;
 		} else {
 			psock->sg_size -= send;
 		}
 
 		if (cork) {
-			free_start_sg(sk, m);
+			free_start_sg(sk, m, true);
 			psock->sg_size = 0;
 			kfree(m);
 			m = NULL;
@@ -912,6 +911,8 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 
 	if (unlikely(flags & MSG_ERRQUEUE))
 		return inet_recv_error(sk, msg, len, addr_len);
+	if (!skb_queue_empty(&sk->sk_receive_queue))
+		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 
 	rcu_read_lock();
 	psock = smap_psock_sk(sk);
@@ -922,9 +923,6 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		goto out;
 	rcu_read_unlock();
 
-	if (!skb_queue_empty(&sk->sk_receive_queue))
-		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
-
 	lock_sock(sk);
 bytes_ready:
 	while (copied != len) {
@@ -1122,7 +1120,7 @@ wait_for_memory:
 		err = sk_stream_wait_memory(sk, &timeo);
 		if (err) {
 			if (m && m != psock->cork)
-				free_start_sg(sk, m);
+				free_start_sg(sk, m, true);
 			goto out_err;
 		}
 	}
@@ -1427,12 +1425,15 @@ out:
 static void smap_write_space(struct sock *sk)
 {
 	struct smap_psock *psock;
+	void (*write_space)(struct sock *sk);
 
 	rcu_read_lock();
 	psock = smap_psock_sk(sk);
 	if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
 		schedule_work(&psock->tx_work);
+	write_space = psock->save_write_space;
 	rcu_read_unlock();
+	write_space(sk);
 }
 
 static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
@@ -1461,10 +1462,16 @@ static void smap_destroy_psock(struct rcu_head *rcu)
 	schedule_work(&psock->gc_work);
 }
 
+static bool psock_is_smap_sk(struct sock *sk)
+{
+	return inet_csk(sk)->icsk_ulp_ops == &bpf_tcp_ulp_ops;
+}
+
 static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
 {
 	if (refcount_dec_and_test(&psock->refcnt)) {
-		tcp_cleanup_ulp(sock);
+		if (psock_is_smap_sk(sock))
+			tcp_cleanup_ulp(sock);
 		write_lock_bh(&sock->sk_callback_lock);
 		smap_stop_sock(psock, sock);
 		write_unlock_bh(&sock->sk_callback_lock);
@@ -1578,13 +1585,13 @@ static void smap_gc_work(struct work_struct *w)
 		bpf_prog_put(psock->bpf_tx_msg);
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 	}
 
 	list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
 		list_del(&md->list);
-		free_start_sg(psock->sock, md);
+		free_start_sg(psock->sock, md, true);
 		kfree(md);
 	}
 
@@ -1891,6 +1898,10 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
 	 * doesn't update user data.
 	 */
 	if (psock) {
+		if (!psock_is_smap_sk(sock)) {
+			err = -EBUSY;
+			goto out_progs;
+		}
 		if (READ_ONCE(psock->bpf_parse) && parse) {
 			err = -EBUSY;
 			goto out_progs;
 		}
@@ -2140,7 +2151,9 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
 		return ERR_PTR(-EPERM);
 
 	/* check sanity of attributes */
-	if (attr->max_entries == 0 || attr->value_size != 4 ||
+	if (attr->max_entries == 0 ||
+	    attr->key_size == 0 ||
+	    attr->value_size != 4 ||
 	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
 		return ERR_PTR(-EINVAL);
 
@@ -2267,8 +2280,10 @@ static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
 	}
 	l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
 			     htab->map.numa_node);
-	if (!l_new)
+	if (!l_new) {
+		atomic_dec(&htab->count);
 		return ERR_PTR(-ENOMEM);
+	}
 
 	memcpy(l_new->key, key, key_size);
 	l_new->sk = sk;
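The common thread in the free_sg()/free_start_sg() changes above is the new "charge" flag: scatterlist data that was never charged against the socket's memory accounting (for example, on the failed-redirect path) must be freed without an uncharge, otherwise the accounting is credited twice. The standalone sketch below only models that pattern; the toy_* names and the mem_charged counter are illustrative stand-ins, not types or fields from sockmap.c.

/* Standalone model of the conditional-uncharge pattern (not kernel code). */
#include <stdio.h>
#include <stddef.h>

struct toy_sock {
	long mem_charged;	/* bytes currently charged to the socket */
};

struct toy_seg {
	size_t length;		/* zero-length entry terminates the list */
};

/* Mirrors the shape of the patched free_sg(): walk the segments and only
 * credit the accounting counter when 'charge' is set. */
static size_t toy_free_sg(struct toy_sock *sk, struct toy_seg *sg, int charge)
{
	size_t freed = 0;
	size_t i = 0;

	while (sg[i].length) {
		freed += sg[i].length;
		if (charge)
			sk->mem_charged -= sg[i].length; /* "uncharge" */
		sg[i].length = 0;
		i++;
	}
	return freed;
}

int main(void)
{
	struct toy_sock sk = { .mem_charged = 6144 };	/* 6 KiB charged */
	struct toy_seg charged[]   = { { 4096 }, { 2048 }, { 0 } };
	struct toy_seg uncharged[] = { { 4096 }, { 0 } };

	/* Normal path: the data was charged, so freeing uncharges it. */
	toy_free_sg(&sk, charged, 1);

	/* Error path: the data was never charged; passing charge=0 keeps
	 * the counter from going negative, which is what the flag avoids. */
	toy_free_sg(&sk, uncharged, 0);

	printf("mem_charged after frees: %ld\n", sk.mem_charged); /* 0 */
	return 0;
}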
