summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/linux/filter.h1
-rw-r--r--include/net/sock.h1
-rw-r--r--kernel/bpf/sockmap.c198
-rw-r--r--net/core/filter.c2
-rw-r--r--net/ipv4/tcp.c10
5 files changed, 207 insertions, 5 deletions
diff --git a/include/linux/filter.h b/include/linux/filter.h
index c2f167db8bd5..961cc5d53956 100644
--- a/include/linux/filter.h
+++ b/include/linux/filter.h
@@ -521,6 +521,7 @@ struct sk_msg_buff {
__u32 key;
__u32 flags;
struct bpf_map *map;
+ struct list_head list;
};
/* Compute the linear packet data range [data, data_end) which
diff --git a/include/net/sock.h b/include/net/sock.h
index 709311132d4c..b8ff435fa96e 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1085,6 +1085,7 @@ struct proto {
#endif
bool (*stream_memory_free)(const struct sock *sk);
+ bool (*stream_memory_read)(const struct sock *sk);
/* Memory pressure */
void (*enter_memory_pressure)(struct sock *sk);
void (*leave_memory_pressure)(struct sock *sk);
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 69c5bccabd22..402e15466e9f 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -41,6 +41,8 @@
#include <linux/mm.h>
#include <net/strparser.h>
#include <net/tcp.h>
+#include <linux/ptr_ring.h>
+#include <net/inet_common.h>
#define SOCK_CREATE_FLAG_MASK \
(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
@@ -82,6 +84,7 @@ struct smap_psock {
int sg_size;
int eval;
struct sk_msg_buff *cork;
+ struct list_head ingress;
struct strparser strp;
struct bpf_prog *bpf_tx_msg;
@@ -103,6 +106,8 @@ struct smap_psock {
};
static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+ int nonblock, int flags, int *addr_len);
static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
@@ -112,6 +117,21 @@ static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
return rcu_dereference_sk_user_data(sk);
}
+static bool bpf_tcp_stream_read(const struct sock *sk)
+{
+ struct smap_psock *psock;
+ bool empty = true;
+
+ rcu_read_lock();
+ psock = smap_psock_sk(sk);
+ if (unlikely(!psock))
+ goto out;
+ empty = list_empty(&psock->ingress);
+out:
+ rcu_read_unlock();
+ return !empty;
+}
+
static struct proto tcp_bpf_proto;
static int bpf_tcp_init(struct sock *sk)
{
@@ -135,6 +155,8 @@ static int bpf_tcp_init(struct sock *sk)
if (psock->bpf_tx_msg) {
tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
+ tcp_bpf_proto.recvmsg = bpf_tcp_recvmsg;
+ tcp_bpf_proto.stream_memory_read = bpf_tcp_stream_read;
}
sk->sk_prot = &tcp_bpf_proto;
@@ -170,6 +192,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
{
void (*close_fun)(struct sock *sk, long timeout);
struct smap_psock_map_entry *e, *tmp;
+ struct sk_msg_buff *md, *mtmp;
struct smap_psock *psock;
struct sock *osk;
@@ -188,6 +211,12 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
close_fun = psock->save_close;
write_lock_bh(&sk->sk_callback_lock);
+ list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
+ list_del(&md->list);
+ free_start_sg(psock->sock, md);
+ kfree(md);
+ }
+
list_for_each_entry_safe(e, tmp, &psock->maps, list) {
osk = cmpxchg(e->entry, sk, NULL);
if (osk == sk) {
@@ -468,6 +497,72 @@ verdict:
return _rc;
}
+static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
+ struct smap_psock *psock,
+ struct sk_msg_buff *md, int flags)
+{
+ bool apply = apply_bytes;
+ size_t size, copied = 0;
+ struct sk_msg_buff *r;
+ int err = 0, i;
+
+ r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
+ if (unlikely(!r))
+ return -ENOMEM;
+
+ lock_sock(sk);
+ r->sg_start = md->sg_start;
+ i = md->sg_start;
+
+ do {
+ r->sg_data[i] = md->sg_data[i];
+
+ size = (apply && apply_bytes < md->sg_data[i].length) ?
+ apply_bytes : md->sg_data[i].length;
+
+ if (!sk_wmem_schedule(sk, size)) {
+ if (!copied)
+ err = -ENOMEM;
+ break;
+ }
+
+ sk_mem_charge(sk, size);
+ r->sg_data[i].length = size;
+ md->sg_data[i].length -= size;
+ md->sg_data[i].offset += size;
+ copied += size;
+
+ if (md->sg_data[i].length) {
+ get_page(sg_page(&r->sg_data[i]));
+ r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
+ } else {
+ i++;
+ if (i == MAX_SKB_FRAGS)
+ i = 0;
+ r->sg_end = i;
+ }
+
+ if (apply) {
+ apply_bytes -= size;
+ if (!apply_bytes)
+ break;
+ }
+ } while (i != md->sg_end);
+
+ md->sg_start = i;
+
+ if (!err) {
+ list_add_tail(&r->list, &psock->ingress);
+ sk->sk_data_ready(sk);
+ } else {
+ free_start_sg(sk, r);
+ kfree(r);
+ }
+
+ release_sock(sk);
+ return err;
+}
+
static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
struct sk_msg_buff *md,
int flags)
@@ -475,6 +570,7 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
struct smap_psock *psock;
struct scatterlist *sg;
int i, err, free = 0;
+ bool ingress = !!(md->flags & BPF_F_INGRESS);
sg = md->sg_data;
@@ -487,9 +583,14 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
goto out_rcu;
rcu_read_unlock();
- lock_sock(sk);
- err = bpf_tcp_push(sk, send, md, flags, false);
- release_sock(sk);
+
+ if (ingress) {
+ err = bpf_tcp_ingress(sk, send, psock, md, flags);
+ } else {
+ lock_sock(sk);
+ err = bpf_tcp_push(sk, send, md, flags, false);
+ release_sock(sk);
+ }
smap_release_sock(psock, sk);
if (unlikely(err))
goto out;
@@ -623,6 +724,89 @@ out_err:
return err;
}
+static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+ int nonblock, int flags, int *addr_len)
+{
+ struct iov_iter *iter = &msg->msg_iter;
+ struct smap_psock *psock;
+ int copied = 0;
+
+ if (unlikely(flags & MSG_ERRQUEUE))
+ return inet_recv_error(sk, msg, len, addr_len);
+
+ rcu_read_lock();
+ psock = smap_psock_sk(sk);
+ if (unlikely(!psock))
+ goto out;
+
+ if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
+ goto out;
+ rcu_read_unlock();
+
+ if (!skb_queue_empty(&sk->sk_receive_queue))
+ return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+
+ lock_sock(sk);
+ while (copied != len) {
+ struct scatterlist *sg;
+ struct sk_msg_buff *md;
+ int i;
+
+ md = list_first_entry_or_null(&psock->ingress,
+ struct sk_msg_buff, list);
+ if (unlikely(!md))
+ break;
+ i = md->sg_start;
+ do {
+ struct page *page;
+ int n, copy;
+
+ sg = &md->sg_data[i];
+ copy = sg->length;
+ page = sg_page(sg);
+
+ if (copied + copy > len)
+ copy = len - copied;
+
+ n = copy_page_to_iter(page, sg->offset, copy, iter);
+ if (n != copy) {
+ md->sg_start = i;
+ release_sock(sk);
+ smap_release_sock(psock, sk);
+ return -EFAULT;
+ }
+
+ copied += copy;
+ sg->offset += copy;
+ sg->length -= copy;
+ sk_mem_uncharge(sk, copy);
+
+ if (!sg->length) {
+ i++;
+ if (i == MAX_SKB_FRAGS)
+ i = 0;
+ put_page(page);
+ }
+ if (copied == len)
+ break;
+ } while (i != md->sg_end);
+ md->sg_start = i;
+
+ if (!sg->length && md->sg_start == md->sg_end) {
+ list_del(&md->list);
+ kfree(md);
+ }
+ }
+
+ release_sock(sk);
+ smap_release_sock(psock, sk);
+ return copied;
+out:
+ rcu_read_unlock();
+ return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+}
+
+
static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
@@ -1107,6 +1291,7 @@ static void sock_map_remove_complete(struct bpf_stab *stab)
static void smap_gc_work(struct work_struct *w)
{
struct smap_psock_map_entry *e, *tmp;
+ struct sk_msg_buff *md, *mtmp;
struct smap_psock *psock;
psock = container_of(w, struct smap_psock, gc_work);
@@ -1131,6 +1316,12 @@ static void smap_gc_work(struct work_struct *w)
kfree(psock->cork);
}
+ list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
+ list_del(&md->list);
+ free_start_sg(psock->sock, md);
+ kfree(md);
+ }
+
list_for_each_entry_safe(e, tmp, &psock->maps, list) {
list_del(&e->list);
kfree(e);
@@ -1160,6 +1351,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
INIT_WORK(&psock->tx_work, smap_tx_work);
INIT_WORK(&psock->gc_work, smap_gc_work);
INIT_LIST_HEAD(&psock->maps);
+ INIT_LIST_HEAD(&psock->ingress);
refcount_set(&psock->refcnt, 1);
rcu_assign_sk_user_data(sock, psock);
diff --git a/net/core/filter.c b/net/core/filter.c
index afd825534ac4..a5a995e5b380 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -1894,7 +1894,7 @@ BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
struct bpf_map *, map, u32, key, u64, flags)
{
/* If user passes invalid input drop the packet. */
- if (unlikely(flags))
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
return SK_DROP;
msg->key = key;
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 0c31be306572..bccc4c270087 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -485,6 +485,14 @@ static void tcp_tx_timestamp(struct sock *sk, u16 tsflags)
}
}
+static inline bool tcp_stream_is_readable(const struct tcp_sock *tp,
+ int target, struct sock *sk)
+{
+ return (tp->rcv_nxt - tp->copied_seq >= target) ||
+ (sk->sk_prot->stream_memory_read ?
+ sk->sk_prot->stream_memory_read(sk) : false);
+}
+
/*
* Wait for a TCP event.
*
@@ -554,7 +562,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait)
tp->urg_data)
target++;
- if (tp->rcv_nxt - tp->copied_seq >= target)
+ if (tcp_stream_is_readable(tp, target, sk))
mask |= EPOLLIN | EPOLLRDNORM;
if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {