diff options
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r-- | drivers/net/wireguard/messages.h | 2 | ||||
-rw-r--r-- | drivers/net/wireguard/noise.c | 22 | ||||
-rw-r--r-- | drivers/net/wireguard/noise.h | 14 | ||||
-rw-r--r-- | drivers/net/wireguard/queueing.c | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/queueing.h | 10 | ||||
-rw-r--r-- | drivers/net/wireguard/receive.c | 65 | ||||
-rw-r--r-- | drivers/net/wireguard/selftest/counter.c | 17 | ||||
-rw-r--r-- | drivers/net/wireguard/selftest/ratelimiter.c | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/send.c | 37 | ||||
-rw-r--r-- | drivers/net/wireguard/socket.c | 12 |
10 files changed, 94 insertions, 93 deletions
diff --git a/drivers/net/wireguard/messages.h b/drivers/net/wireguard/messages.h index b8a7b9ce32ba..208da72673fc 100644 --- a/drivers/net/wireguard/messages.h +++ b/drivers/net/wireguard/messages.h @@ -32,7 +32,7 @@ enum cookie_values { }; enum counter_values { - COUNTER_BITS_TOTAL = 2048, + COUNTER_BITS_TOTAL = 8192, COUNTER_REDUNDANT_BITS = BITS_PER_LONG, COUNTER_WINDOW_SIZE = COUNTER_BITS_TOTAL - COUNTER_REDUNDANT_BITS }; diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c index 708dc61c974f..626433690abb 100644 --- a/drivers/net/wireguard/noise.c +++ b/drivers/net/wireguard/noise.c @@ -104,6 +104,7 @@ static struct noise_keypair *keypair_create(struct wg_peer *peer) if (unlikely(!keypair)) return NULL; + spin_lock_init(&keypair->receiving_counter.lock); keypair->internal_id = atomic64_inc_return(&keypair_counter); keypair->entry.type = INDEX_HASHTABLE_KEYPAIR; keypair->entry.peer = peer; @@ -358,25 +359,16 @@ out: memzero_explicit(output, BLAKE2S_HASH_SIZE + 1); } -static void symmetric_key_init(struct noise_symmetric_key *key) -{ - spin_lock_init(&key->counter.receive.lock); - atomic64_set(&key->counter.counter, 0); - memset(key->counter.receive.backtrack, 0, - sizeof(key->counter.receive.backtrack)); - key->birthdate = ktime_get_coarse_boottime_ns(); - key->is_valid = true; -} - static void derive_keys(struct noise_symmetric_key *first_dst, struct noise_symmetric_key *second_dst, const u8 chaining_key[NOISE_HASH_LEN]) { + u64 birthdate = ktime_get_coarse_boottime_ns(); kdf(first_dst->key, second_dst->key, NULL, NULL, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, chaining_key); - symmetric_key_init(first_dst); - symmetric_key_init(second_dst); + first_dst->birthdate = second_dst->birthdate = birthdate; + first_dst->is_valid = second_dst->is_valid = true; } static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], @@ -715,6 +707,7 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src, u8 e[NOISE_PUBLIC_KEY_LEN]; u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN]; u8 static_private[NOISE_PUBLIC_KEY_LEN]; + u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]; down_read(&wg->static_identity.lock); @@ -733,6 +726,8 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src, memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN); memcpy(ephemeral_private, handshake->ephemeral_private, NOISE_PUBLIC_KEY_LEN); + memcpy(preshared_key, handshake->preshared_key, + NOISE_SYMMETRIC_KEY_LEN); up_read(&handshake->lock); if (state != HANDSHAKE_CREATED_INITIATION) @@ -750,7 +745,7 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src, goto fail; /* psk */ - mix_psk(chaining_key, hash, key, handshake->preshared_key); + mix_psk(chaining_key, hash, key, preshared_key); /* {} */ if (!message_decrypt(NULL, src->encrypted_nothing, @@ -783,6 +778,7 @@ out: memzero_explicit(chaining_key, NOISE_HASH_LEN); memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN); memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN); + memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN); up_read(&wg->static_identity.lock); return ret_peer; } diff --git a/drivers/net/wireguard/noise.h b/drivers/net/wireguard/noise.h index f532d59d3f19..c527253dba80 100644 --- a/drivers/net/wireguard/noise.h +++ b/drivers/net/wireguard/noise.h @@ -15,18 +15,14 @@ #include <linux/mutex.h> #include <linux/kref.h> -union noise_counter { - struct { - u64 counter; - unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG]; - spinlock_t lock; - } receive; - atomic64_t counter; +struct noise_replay_counter { + u64 counter; + spinlock_t lock; + unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG]; }; struct noise_symmetric_key { u8 key[NOISE_SYMMETRIC_KEY_LEN]; - union noise_counter counter; u64 birthdate; bool is_valid; }; @@ -34,7 +30,9 @@ struct noise_symmetric_key { struct noise_keypair { struct index_hashtable_entry entry; struct noise_symmetric_key sending; + atomic64_t sending_counter; struct noise_symmetric_key receiving; + struct noise_replay_counter receiving_counter; __le32 remote_index; bool i_am_the_initiator; struct kref refcount; diff --git a/drivers/net/wireguard/queueing.c b/drivers/net/wireguard/queueing.c index 5c964fcb994e..71b8e80b58e1 100644 --- a/drivers/net/wireguard/queueing.c +++ b/drivers/net/wireguard/queueing.c @@ -35,8 +35,10 @@ int wg_packet_queue_init(struct crypt_queue *queue, work_func_t function, if (multicore) { queue->worker = wg_packet_percpu_multicore_worker_alloc( function, queue); - if (!queue->worker) + if (!queue->worker) { + ptr_ring_cleanup(&queue->ring, NULL); return -ENOMEM; + } } else { INIT_WORK(&queue->work, function); } diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h index 3432232afe06..c58df439dbbe 100644 --- a/drivers/net/wireguard/queueing.h +++ b/drivers/net/wireguard/queueing.h @@ -87,12 +87,20 @@ static inline bool wg_check_packet_protocol(struct sk_buff *skb) return real_protocol && skb->protocol == real_protocol; } -static inline void wg_reset_packet(struct sk_buff *skb) +static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating) { + u8 l4_hash = skb->l4_hash; + u8 sw_hash = skb->sw_hash; + u32 hash = skb->hash; skb_scrub_packet(skb, true); memset(&skb->headers_start, 0, offsetof(struct sk_buff, headers_end) - offsetof(struct sk_buff, headers_start)); + if (encapsulating) { + skb->l4_hash = l4_hash; + skb->sw_hash = sw_hash; + skb->hash = hash; + } skb->queue_mapping = 0; skb->nohdr = 0; skb->peeked = 0; diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c index da3b782ab7d3..91438144e4f7 100644 --- a/drivers/net/wireguard/receive.c +++ b/drivers/net/wireguard/receive.c @@ -226,40 +226,39 @@ void wg_packet_handshake_receive_worker(struct work_struct *work) static void keep_key_fresh(struct wg_peer *peer) { struct noise_keypair *keypair; - bool send = false; + bool send; if (peer->sent_lastminute_handshake) return; rcu_read_lock_bh(); keypair = rcu_dereference_bh(peer->keypairs.current_keypair); - if (likely(keypair && READ_ONCE(keypair->sending.is_valid)) && - keypair->i_am_the_initiator && - unlikely(wg_birthdate_has_expired(keypair->sending.birthdate, - REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT))) - send = true; + send = keypair && READ_ONCE(keypair->sending.is_valid) && + keypair->i_am_the_initiator && + wg_birthdate_has_expired(keypair->sending.birthdate, + REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT); rcu_read_unlock_bh(); - if (send) { + if (unlikely(send)) { peer->sent_lastminute_handshake = true; wg_packet_send_queued_handshake_initiation(peer, false); } } -static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key) +static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair) { struct scatterlist sg[MAX_SKB_FRAGS + 8]; struct sk_buff *trailer; unsigned int offset; int num_frags; - if (unlikely(!key)) + if (unlikely(!keypair)) return false; - if (unlikely(!READ_ONCE(key->is_valid) || - wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) || - key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) { - WRITE_ONCE(key->is_valid, false); + if (unlikely(!READ_ONCE(keypair->receiving.is_valid) || + wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) || + keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) { + WRITE_ONCE(keypair->receiving.is_valid, false); return false; } @@ -284,7 +283,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key) if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, - key->key)) + keypair->receiving.key)) return false; /* Another ugly situation of pushing and pulling the header so as to @@ -299,41 +298,41 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key) } /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */ -static bool counter_validate(union noise_counter *counter, u64 their_counter) +static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter) { unsigned long index, index_current, top, i; bool ret = false; - spin_lock_bh(&counter->receive.lock); + spin_lock_bh(&counter->lock); - if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || + if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES)) goto out; ++their_counter; if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < - counter->receive.counter)) + counter->counter)) goto out; index = their_counter >> ilog2(BITS_PER_LONG); - if (likely(their_counter > counter->receive.counter)) { - index_current = counter->receive.counter >> ilog2(BITS_PER_LONG); + if (likely(their_counter > counter->counter)) { + index_current = counter->counter >> ilog2(BITS_PER_LONG); top = min_t(unsigned long, index - index_current, COUNTER_BITS_TOTAL / BITS_PER_LONG); for (i = 1; i <= top; ++i) - counter->receive.backtrack[(i + index_current) & + counter->backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; - counter->receive.counter = their_counter; + counter->counter = their_counter; } index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1; ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), - &counter->receive.backtrack[index]); + &counter->backtrack[index]); out: - spin_unlock_bh(&counter->receive.lock); + spin_unlock_bh(&counter->lock); return ret; } @@ -393,13 +392,11 @@ static void wg_packet_consume_data_done(struct wg_peer *peer, len = ntohs(ip_hdr(skb)->tot_len); if (unlikely(len < sizeof(struct iphdr))) goto dishonest_packet_size; - if (INET_ECN_is_ce(PACKET_CB(skb)->ds)) - IP_ECN_set_ce(ip_hdr(skb)); + INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ip_hdr(skb)->tos); } else if (skb->protocol == htons(ETH_P_IPV6)) { len = ntohs(ipv6_hdr(skb)->payload_len) + sizeof(struct ipv6hdr); - if (INET_ECN_is_ce(PACKET_CB(skb)->ds)) - IP6_ECN_set_ce(skb, ipv6_hdr(skb)); + INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ipv6_get_dsfield(ipv6_hdr(skb))); } else { goto dishonest_packet_type; } @@ -475,19 +472,19 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget) if (unlikely(state != PACKET_STATE_CRYPTED)) goto next; - if (unlikely(!counter_validate(&keypair->receiving.counter, + if (unlikely(!counter_validate(&keypair->receiving_counter, PACKET_CB(skb)->nonce))) { net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, - keypair->receiving.counter.receive.counter); + keypair->receiving_counter.counter); goto next; } if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb))) goto next; - wg_reset_packet(skb); + wg_reset_packet(skb, false); wg_packet_consume_data_done(peer, skb, &endpoint); free = false; @@ -514,10 +511,12 @@ void wg_packet_decrypt_worker(struct work_struct *work) struct sk_buff *skb; while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) { - enum packet_state state = likely(decrypt_packet(skb, - &PACKET_CB(skb)->keypair->receiving)) ? + enum packet_state state = + likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ? PACKET_STATE_CRYPTED : PACKET_STATE_DEAD; wg_queue_enqueue_per_peer_napi(skb, state); + if (need_resched()) + cond_resched(); } } diff --git a/drivers/net/wireguard/selftest/counter.c b/drivers/net/wireguard/selftest/counter.c index f4fbb9072ed7..ec3c156bf91b 100644 --- a/drivers/net/wireguard/selftest/counter.c +++ b/drivers/net/wireguard/selftest/counter.c @@ -6,18 +6,24 @@ #ifdef DEBUG bool __init wg_packet_counter_selftest(void) { + struct noise_replay_counter *counter; unsigned int test_num = 0, i; - union noise_counter counter; bool success = true; -#define T_INIT do { \ - memset(&counter, 0, sizeof(union noise_counter)); \ - spin_lock_init(&counter.receive.lock); \ + counter = kmalloc(sizeof(*counter), GFP_KERNEL); + if (unlikely(!counter)) { + pr_err("nonce counter self-test malloc: FAIL\n"); + return false; + } + +#define T_INIT do { \ + memset(counter, 0, sizeof(*counter)); \ + spin_lock_init(&counter->lock); \ } while (0) #define T_LIM (COUNTER_WINDOW_SIZE + 1) #define T(n, v) do { \ ++test_num; \ - if (counter_validate(&counter, n) != (v)) { \ + if (counter_validate(counter, n) != (v)) { \ pr_err("nonce counter self-test %u: FAIL\n", \ test_num); \ success = false; \ @@ -99,6 +105,7 @@ bool __init wg_packet_counter_selftest(void) if (success) pr_info("nonce counter self-tests: pass\n"); + kfree(counter); return success; } #endif diff --git a/drivers/net/wireguard/selftest/ratelimiter.c b/drivers/net/wireguard/selftest/ratelimiter.c index bcd6462e4540..007cd4457c5f 100644 --- a/drivers/net/wireguard/selftest/ratelimiter.c +++ b/drivers/net/wireguard/selftest/ratelimiter.c @@ -120,9 +120,9 @@ bool __init wg_ratelimiter_selftest(void) enum { TRIALS_BEFORE_GIVING_UP = 5000 }; bool success = false; int test = 0, trials; - struct sk_buff *skb4, *skb6; + struct sk_buff *skb4, *skb6 = NULL; struct iphdr *hdr4; - struct ipv6hdr *hdr6; + struct ipv6hdr *hdr6 = NULL; if (IS_ENABLED(CONFIG_KASAN) || IS_ENABLED(CONFIG_UBSAN)) return true; diff --git a/drivers/net/wireguard/send.c b/drivers/net/wireguard/send.c index 7348c10cbae3..f74b9341ab0f 100644 --- a/drivers/net/wireguard/send.c +++ b/drivers/net/wireguard/send.c @@ -124,20 +124,17 @@ void wg_packet_send_handshake_cookie(struct wg_device *wg, static void keep_key_fresh(struct wg_peer *peer) { struct noise_keypair *keypair; - bool send = false; + bool send; rcu_read_lock_bh(); keypair = rcu_dereference_bh(peer->keypairs.current_keypair); - if (likely(keypair && READ_ONCE(keypair->sending.is_valid)) && - (unlikely(atomic64_read(&keypair->sending.counter.counter) > - REKEY_AFTER_MESSAGES) || - (keypair->i_am_the_initiator && - unlikely(wg_birthdate_has_expired(keypair->sending.birthdate, - REKEY_AFTER_TIME))))) - send = true; + send = keypair && READ_ONCE(keypair->sending.is_valid) && + (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES || + (keypair->i_am_the_initiator && + wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME))); rcu_read_unlock_bh(); - if (send) + if (unlikely(send)) wg_packet_send_queued_handshake_initiation(peer, false); } @@ -170,6 +167,11 @@ static bool encrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair) struct sk_buff *trailer; int num_frags; + /* Force hash calculation before encryption so that flow analysis is + * consistent over the inner packet. + */ + skb_get_hash(skb); + /* Calculate lengths. */ padding_len = calculate_skb_padding(skb); trailer_len = padding_len + noise_encrypted_len(0); @@ -281,6 +283,8 @@ void wg_packet_tx_worker(struct work_struct *work) wg_noise_keypair_put(keypair, false); wg_peer_put(peer); + if (need_resched()) + cond_resched(); } } @@ -296,7 +300,7 @@ void wg_packet_encrypt_worker(struct work_struct *work) skb_list_walk_safe(first, skb, next) { if (likely(encrypt_packet(skb, PACKET_CB(first)->keypair))) { - wg_reset_packet(skb); + wg_reset_packet(skb, true); } else { state = PACKET_STATE_DEAD; break; @@ -304,7 +308,8 @@ void wg_packet_encrypt_worker(struct work_struct *work) } wg_queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, state); - + if (need_resched()) + cond_resched(); } } @@ -344,7 +349,6 @@ void wg_packet_purge_staged_packets(struct wg_peer *peer) void wg_packet_send_staged_packets(struct wg_peer *peer) { - struct noise_symmetric_key *key; struct noise_keypair *keypair; struct sk_buff_head packets; struct sk_buff *skb; @@ -364,10 +368,9 @@ void wg_packet_send_staged_packets(struct wg_peer *peer) rcu_read_unlock_bh(); if (unlikely(!keypair)) goto out_nokey; - key = &keypair->sending; - if (unlikely(!READ_ONCE(key->is_valid))) + if (unlikely(!READ_ONCE(keypair->sending.is_valid))) goto out_nokey; - if (unlikely(wg_birthdate_has_expired(key->birthdate, + if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate, REJECT_AFTER_TIME))) goto out_invalid; @@ -382,7 +385,7 @@ void wg_packet_send_staged_packets(struct wg_peer *peer) */ PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb); PACKET_CB(skb)->nonce = - atomic64_inc_return(&key->counter.counter) - 1; + atomic64_inc_return(&keypair->sending_counter) - 1; if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES)) goto out_invalid; } @@ -394,7 +397,7 @@ void wg_packet_send_staged_packets(struct wg_peer *peer) return; out_invalid: - WRITE_ONCE(key->is_valid, false); + WRITE_ONCE(keypair->sending.is_valid, false); out_nokey: wg_noise_keypair_put(keypair, false); diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c index b0d6541582d3..f9018027fc13 100644 --- a/drivers/net/wireguard/socket.c +++ b/drivers/net/wireguard/socket.c @@ -76,12 +76,6 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", wg->dev->name, &endpoint->addr, ret); goto err; - } else if (unlikely(rt->dst.dev == skb->dev)) { - ip_rt_put(rt); - ret = -ELOOP; - net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", - wg->dev->name, &endpoint->addr); - goto err; } if (cache) dst_cache_set_ip4(cache, &rt->dst, fl.saddr); @@ -149,12 +143,6 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", wg->dev->name, &endpoint->addr, ret); goto err; - } else if (unlikely(dst->dev == skb->dev)) { - dst_release(dst); - ret = -ELOOP; - net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", - wg->dev->name, &endpoint->addr); - goto err; } if (cache) dst_cache_set_ip6(cache, dst, &fl.saddr); |