Update dependencies
This commit is contained in:
294
vendor/github.com/tailscale/wireguard-go/device/allowedips.go
generated
vendored
Normal file
294
vendor/github.com/tailscale/wireguard-go/device/allowedips.go
generated
vendored
Normal file
@@ -0,0 +1,294 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/bits"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// parentIndirection addresses the *trieEntry slot that holds a node:
// parentBit points at the slot to overwrite, and parentBitType records
// which slot it is (0 or 1 for a child index, 2 for a trie root pointer).
type parentIndirection struct {
	parentBit     **trieEntry
	parentBitType uint8
}

// trieEntry is a node in the binary radix trie of allowed IP prefixes.
type trieEntry struct {
	peer        *Peer             // owning peer; nil for synthetic interior nodes
	child       [2]*trieEntry     // children selected by the bit just past this prefix
	parent      parentIndirection // back-pointer to the slot that holds this node
	cidr        uint8             // prefix length in bits
	bitAtByte   uint8             // byte index of the decision bit (cidr / 8)
	bitAtShift  uint8             // shift of the decision bit within that byte
	bits        []byte            // masked address bytes of this prefix
	perPeerElem *list.Element     // element in peer.trieEntries, for fast removal
}
|
||||
|
||||
// commonBits returns the length, in bits, of the common prefix shared by
// two equal-length addresses (4 bytes for IPv4, 16 for IPv6). It panics on
// any other length.
func commonBits(ip1, ip2 []byte) uint8 {
	switch len(ip1) {
	case net.IPv4len:
		xor := binary.BigEndian.Uint32(ip1) ^ binary.BigEndian.Uint32(ip2)
		return uint8(bits.LeadingZeros32(xor))
	case net.IPv6len:
		// Compare the high 64 bits first; only fall through to the low
		// half when the high halves are identical.
		hi := binary.BigEndian.Uint64(ip1) ^ binary.BigEndian.Uint64(ip2)
		if hi != 0 {
			return uint8(bits.LeadingZeros64(hi))
		}
		lo := binary.BigEndian.Uint64(ip1[8:]) ^ binary.BigEndian.Uint64(ip2[8:])
		return 64 + uint8(bits.LeadingZeros64(lo))
	default:
		panic("Wrong size bit string")
	}
}
|
||||
|
||||
// addToPeerEntries links node into its peer's per-peer entry list so the
// entry can later be enumerated and removed without walking the trie.
func (node *trieEntry) addToPeerEntries() {
	node.perPeerElem = node.peer.trieEntries.PushBack(node)
}

// removeFromPeerEntries unlinks node from its peer's entry list, if linked.
func (node *trieEntry) removeFromPeerEntries() {
	if node.perPeerElem != nil {
		node.peer.trieEntries.Remove(node.perPeerElem)
		node.perPeerElem = nil
	}
}

// choose returns the child index (0 or 1) selected by the bit of ip
// immediately following this node's prefix.
func (node *trieEntry) choose(ip []byte) byte {
	return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}

// maskSelf zeroes every bit of node.bits beyond the prefix length, so
// prefix comparisons never see host bits.
func (node *trieEntry) maskSelf() {
	mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
	for i := 0; i < len(mask); i++ {
		node.bits[i] &= mask[i]
	}
}

// zeroizePointers clears the link fields of a node that has been detached
// from the trie.
func (node *trieEntry) zeroizePointers() {
	// Make the garbage collector's life slightly easier
	node.peer = nil
	node.child[0] = nil
	node.child[1] = nil
	node.parent.parentBit = nil
}
|
||||
|
||||
// nodePlacement walks down from node to find where ip/cidr belongs.
// It returns the deepest existing node whose prefix contains ip and whose
// cidr does not exceed the requested one; exact is true when that node's
// cidr matches exactly (i.e. the prefix is already present).
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
	for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
		parent = node
		if parent.cidr == cidr {
			exact = true
			return
		}
		// Descend toward the child on ip's side of the decision bit.
		bit := node.choose(ip)
		node = node.child[bit]
	}
	return
}
|
||||
|
||||
// insert adds ip/cidr to the trie rooted at the slot this parentIndirection
// addresses, mapping it to peer. An existing exact entry is re-pointed at
// peer; otherwise a new leaf (and, when needed, a peer-less interior node
// for the common prefix) is spliced in so the structure remains a valid
// binary radix trie.
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
	// Empty subtree: the new entry becomes this slot's root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		// Prefix already present: transfer ownership to the new peer.
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the existing subtree the new prefix must be merged with.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		if down == nil {
			// Free child slot: attach the new leaf directly.
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// Shrink cidr to the longest prefix shared with the displaced subtree.
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	if newNode.cidr == cidr {
		// The new node itself is the common prefix: it adopts down as a child.
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// Otherwise create a peer-less interior node for the common prefix and
	// hang both the displaced subtree and the new leaf off it.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}
|
||||
|
||||
// lookup walks the trie and returns the peer of the most specific prefix
// containing ip, or nil when nothing matches.
func (node *trieEntry) lookup(ip []byte) *Peer {
	var found *Peer
	size := uint8(len(ip))
	for node != nil && commonBits(node.bits, ip) >= node.cidr {
		if node.peer != nil {
			// Remember the deepest peer-bearing match seen so far.
			found = node.peer
		}
		if node.bitAtByte == size {
			// The decision bit would lie past the end of ip: this is a
			// full-length (host) match and cannot be refined further.
			break
		}
		bit := node.choose(ip)
		node = node.child[bit]
	}
	return found
}
|
||||
|
||||
// AllowedIPs maps IP prefixes to peers using one binary radix trie per
// address family.
type AllowedIPs struct {
	IPv4  *trieEntry
	IPv6  *trieEntry
	mutex sync.RWMutex // guards both tries and every node's links
}

// EntriesForPeer invokes cb once per prefix assigned to peer, stopping
// early when cb returns false.
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
	table.mutex.RLock()
	defer table.mutex.RUnlock()

	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
		node := elem.Value.(*trieEntry)
		a, _ := netip.AddrFromSlice(node.bits)
		if !cb(netip.PrefixFrom(a, int(node.cidr))) {
			return
		}
	}
}
|
||||
|
||||
// RemoveByPeer removes every trie entry that routes to peer, compacting the
// trie as it goes so that no peer-less single-child interior nodes remain.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	var next *list.Element
	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
		// Capture next before removeFromPeerEntries unlinks elem.
		next = elem.Next()
		node := elem.Value.(*trieEntry)

		node.removeFromPeerEntries()
		node.peer = nil
		// A node with two children must remain as an interior node.
		if node.child[0] != nil && node.child[1] != nil {
			continue
		}
		// Splice the (at most one) child into the slot that held node.
		bit := 0
		if node.child[0] == nil {
			bit = 1
		}
		child := node.child[bit]
		if child != nil {
			child.parent = node.parent
		}
		*node.parent.parentBit = child
		// parentBitType > 1 means node hung off a trie root, which has no
		// enclosing trieEntry to compact.
		if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
			node.zeroizePointers()
			continue
		}
		// Recover the parent trieEntry from the address of its child slot:
		// parentBit points into parent.child, offset by parentBitType slots.
		parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
		if parent.peer != nil {
			// Parent still routes to a peer; it must stay.
			node.zeroizePointers()
			continue
		}
		// Parent is now a peer-less node with a single remaining child:
		// splice that child up one level too.
		child = parent.child[node.parent.parentBitType^1]
		if child != nil {
			child.parent = parent.parent
		}
		*parent.parent.parentBit = child
		node.zeroizePointers()
		parent.zeroizePointers()
	}
}
|
||||
|
||||
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
if prefix.Addr().Is6() {
|
||||
ip := prefix.Addr().As16()
|
||||
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||
} else if prefix.Addr().Is4() {
|
||||
ip := prefix.Addr().As4()
|
||||
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||
} else {
|
||||
panic(errors.New("inserting unknown address type"))
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup returns the peer for the most specific allowed-IP prefix
// containing ip, which must be a raw 4-byte (IPv4) or 16-byte (IPv6)
// address; any other length panics.
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
	table.mutex.RLock()
	defer table.mutex.RUnlock()
	switch len(ip) {
	case net.IPv6len:
		return table.IPv6.lookup(ip)
	case net.IPv4len:
		return table.IPv4.lookup(ip)
	default:
		panic(errors.New("looking up unknown address type"))
	}
}
|
||||
137
vendor/github.com/tailscale/wireguard-go/device/channels.go
generated
vendored
Normal file
137
vendor/github.com/tailscale/wireguard-go/device/channels.go
generated
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
// An outboundQueue is ref-counted using its wg field.
// An outboundQueue created with newOutboundQueue has one reference.
// Every additional writer must call wg.Add(1).
// Every completed writer must call wg.Done().
// When no further writers will be added,
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
	c  chan *QueueOutboundElementsContainer
	wg sync.WaitGroup
}

// newOutboundQueue returns an outboundQueue holding one reference; a
// goroutine closes c once every reference has been released via wg.Done.
func newOutboundQueue() *outboundQueue {
	q := &outboundQueue{
		c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
	}
	q.wg.Add(1)
	go func() {
		q.wg.Wait()
		close(q.c)
	}()
	return q
}

// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
	c  chan *QueueInboundElementsContainer
	wg sync.WaitGroup
}

// newInboundQueue returns an inboundQueue holding one reference; see
// newOutboundQueue for the close protocol.
func newInboundQueue() *inboundQueue {
	q := &inboundQueue{
		c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
	}
	q.wg.Add(1)
	go func() {
		q.wg.Wait()
		close(q.c)
	}()
	return q
}

// A handshakeQueue is similar to an outboundQueue; see those docs.
type handshakeQueue struct {
	c  chan QueueHandshakeElement
	wg sync.WaitGroup
}

// newHandshakeQueue returns a handshakeQueue holding one reference; see
// newOutboundQueue for the close protocol.
func newHandshakeQueue() *handshakeQueue {
	q := &handshakeQueue{
		c: make(chan QueueHandshakeElement, QueueHandshakeSize),
	}
	q.wg.Add(1)
	go func() {
		q.wg.Wait()
		close(q.c)
	}()
	return q
}
|
||||
|
||||
// autodrainingInboundQueue wraps an inbound channel whose buffered contents
// are returned to the device pools when the queue is garbage collected.
type autodrainingInboundQueue struct {
	c chan *QueueInboundElementsContainer
}

// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which is it hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
	q := &autodrainingInboundQueue{
		c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
	}
	runtime.SetFinalizer(q, device.flushInboundQueue)
	return q
}

// flushInboundQueue non-blockingly drains q, returning each buffered
// element's message buffer, the element itself, and its container to the
// device pools. Installed as the queue's finalizer.
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
	for {
		select {
		case elemsContainer := <-q.c:
			// NOTE(review): the container lock is taken and never released;
			// presumably safe because the container goes straight back to
			// the pool — confirm against PutInboundElementsContainer.
			elemsContainer.Lock()
			for _, elem := range elemsContainer.elems {
				device.PutMessageBuffer(elem.buffer)
				device.PutInboundElement(elem)
			}
			device.PutInboundElementsContainer(elemsContainer)
		default:
			return
		}
	}
}
|
||||
|
||||
// autodrainingOutboundQueue wraps an outbound channel whose buffered
// contents are returned to the device pools when the queue is GC'd.
type autodrainingOutboundQueue struct {
	c chan *QueueOutboundElementsContainer
}

// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which is it hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending a sentinel nil values.
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
	q := &autodrainingOutboundQueue{
		c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
	}
	runtime.SetFinalizer(q, device.flushOutboundQueue)
	return q
}

// flushOutboundQueue non-blockingly drains q, recycling buffers, elements,
// and containers into the device pools. Installed as the queue's finalizer.
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
	for {
		select {
		case elemsContainer := <-q.c:
			// NOTE(review): the container lock is taken and never released;
			// presumably safe because the container goes straight back to
			// the pool — confirm against PutOutboundElementsContainer.
			elemsContainer.Lock()
			for _, elem := range elemsContainer.elems {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
			device.PutOutboundElementsContainer(elemsContainer)
		default:
			return
		}
	}
}
|
||||
40
vendor/github.com/tailscale/wireguard-go/device/constants.go
generated
vendored
Normal file
40
vendor/github.com/tailscale/wireguard-go/device/constants.go
generated
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
/* Specification constants */

const (
	RekeyAfterMessages      = (1 << 60)                 // rekey after this many transport messages
	RejectAfterMessages     = (1 << 64) - (1 << 13) - 1 // hard per-keypair message limit
	RekeyAfterTime          = time.Second * 120
	RekeyAttemptTime        = time.Second * 90
	RekeyTimeout            = time.Second * 5
	MaxTimerHandshakes      = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
	RekeyTimeoutJitterMaxMs = 334
	RejectAfterTime         = time.Second * 180
	KeepaliveTimeout        = time.Second * 10
	CookieRefreshTime       = time.Second * 120
	HandshakeInitationRate  = time.Second / 50
	PaddingMultiple         = 16
)

const (
	MinMessageSize = MessageKeepaliveSize                   // minimum size of transport message (keepalive)
	MaxMessageSize = MaxSegmentSize                         // maximum size of transport message
	MaxContentSize = MaxSegmentSize - MessageTransportSize  // maximum size of transport message content
)

/* Implementation constants */

const (
	UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
	MaxPeers           = 1 << 16     // maximum number of configured peers
)
|
||||
248
vendor/github.com/tailscale/wireguard-go/device/cookie.go
generated
vendored
Normal file
248
vendor/github.com/tailscale/wireguard-go/device/cookie.go
generated
vendored
Normal file
@@ -0,0 +1,248 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// CookieChecker holds the state used to verify mac1 and mac2 fields on
// received messages and to produce cookie replies. Keys are derived in
// Init from a public key.
type CookieChecker struct {
	sync.RWMutex
	mac1 struct {
		key [blake2s.Size]byte // BLAKE2s-256(WGLabelMAC1 || pk); see Init
	}
	mac2 struct {
		secret        [blake2s.Size]byte // rotating secret for deriving per-source cookies
		secretSet     time.Time          // when secret was last refreshed
		encryptionKey [chacha20poly1305.KeySize]byte // BLAKE2s-256(WGLabelCookie || pk)
	}
}

// CookieGenerator holds the state used to stamp mac1/mac2 onto outgoing
// messages and to decrypt cookie replies. Keys are derived in Init from a
// public key.
type CookieGenerator struct {
	sync.RWMutex
	mac1 struct {
		key [blake2s.Size]byte // BLAKE2s-256(WGLabelMAC1 || pk); see Init
	}
	mac2 struct {
		cookie        [blake2s.Size128]byte // last cookie accepted by ConsumeReply
		cookieSet     time.Time             // when cookie was accepted
		hasLastMAC1   bool                  // whether lastMAC1 holds a valid value
		lastMAC1      [blake2s.Size128]byte // mac1 of the last message stamped by AddMacs
		encryptionKey [chacha20poly1305.KeySize]byte // BLAKE2s-256(WGLabelCookie || pk)
	}
}
|
||||
|
||||
// Init derives the mac1 key and mac2 encryption key from pk, and resets
// the cookie secret timestamp so a fresh secret is generated on first use.
func (st *CookieChecker) Init(pk NoisePublicKey) {
	st.Lock()
	defer st.Unlock()

	// mac1 state

	func() {
		hash, _ := blake2s.New256(nil)
		hash.Write([]byte(WGLabelMAC1))
		hash.Write(pk[:])
		hash.Sum(st.mac1.key[:0])
	}()

	// mac2 state

	func() {
		hash, _ := blake2s.New256(nil)
		hash.Write([]byte(WGLabelCookie))
		hash.Write(pk[:])
		hash.Sum(st.mac2.encryptionKey[:0])
	}()

	// The zero time makes the secret register as expired everywhere it is
	// checked against CookieRefreshTime.
	st.mac2.secretSet = time.Time{}
}
|
||||
|
||||
// CheckMAC1 verifies the mac1 field of msg — the second-to-last 16 bytes —
// against a keyed BLAKE2s-128 of everything preceding it.
func (st *CookieChecker) CheckMAC1(msg []byte) bool {
	st.RLock()
	defer st.RUnlock()

	size := len(msg)
	smac2 := size - blake2s.Size128  // offset of mac2 (last 16 bytes)
	smac1 := smac2 - blake2s.Size128 // offset of mac1

	var mac1 [blake2s.Size128]byte

	mac, _ := blake2s.New128(st.mac1.key[:])
	mac.Write(msg[:smac1])
	mac.Sum(mac1[:0])

	// Constant-time comparison.
	return hmac.Equal(mac1[:], msg[smac1:smac2])
}
|
||||
|
||||
// CheckMAC2 verifies the mac2 (cookie) field of msg: a BLAKE2s-128 over the
// message including mac1, keyed with a cookie derived from the current
// secret and src (the sender's source address bytes). Returns false when
// the secret has expired.
func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
	st.RLock()
	defer st.RUnlock()

	if time.Since(st.mac2.secretSet) > CookieRefreshTime {
		return false
	}

	// derive cookie key

	var cookie [blake2s.Size128]byte
	func() {
		mac, _ := blake2s.New128(st.mac2.secret[:])
		mac.Write(src)
		mac.Sum(cookie[:0])
	}()

	// calculate mac of packet (including mac1)

	smac2 := len(msg) - blake2s.Size128

	var mac2 [blake2s.Size128]byte
	func() {
		mac, _ := blake2s.New128(cookie[:])
		mac.Write(msg[:smac2])
		mac.Sum(mac2[:0])
	}()

	// Constant-time comparison.
	return hmac.Equal(mac2[:], msg[smac2:])
}
|
||||
|
||||
// CreateReply builds a MessageCookieReply for msg: it derives a cookie from
// the (possibly refreshed) secret and src, then encrypts it with
// XChaCha20-Poly1305 using the message's mac1 as associated data.
// recv is echoed back as the receiver index.
func (st *CookieChecker) CreateReply(
	msg []byte,
	recv uint32,
	src []byte,
) (*MessageCookieReply, error) {
	st.RLock()

	// refresh cookie secret

	if time.Since(st.mac2.secretSet) > CookieRefreshTime {
		// Upgrade to a write lock to rotate the secret. The condition is
		// not re-checked after reacquiring, so concurrent callers may each
		// rotate once — presumably benign; confirm against upstream intent.
		st.RUnlock()
		st.Lock()
		_, err := rand.Read(st.mac2.secret[:])
		if err != nil {
			st.Unlock()
			return nil, err
		}
		st.mac2.secretSet = time.Now()
		st.Unlock()
		st.RLock()
	}

	// derive cookie

	var cookie [blake2s.Size128]byte
	func() {
		mac, _ := blake2s.New128(st.mac2.secret[:])
		mac.Write(src)
		mac.Sum(cookie[:0])
	}()

	// encrypt cookie

	size := len(msg)

	smac2 := size - blake2s.Size128  // offset of mac2
	smac1 := smac2 - blake2s.Size128 // offset of mac1

	reply := new(MessageCookieReply)
	reply.Type = MessageCookieReplyType
	reply.Receiver = recv

	_, err := rand.Read(reply.Nonce[:])
	if err != nil {
		st.RUnlock()
		return nil, err
	}

	// The sender's mac1 is bound in as associated data so only the peer
	// that sent msg can use this reply.
	xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
	xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])

	st.RUnlock()

	return reply, nil
}
|
||||
|
||||
// Init derives the mac1 key and mac2 encryption key from pk, and resets
// the cookie timestamp so any previously stored cookie counts as expired.
func (st *CookieGenerator) Init(pk NoisePublicKey) {
	st.Lock()
	defer st.Unlock()

	// mac1 key: BLAKE2s-256(WGLabelMAC1 || pk)
	func() {
		hash, _ := blake2s.New256(nil)
		hash.Write([]byte(WGLabelMAC1))
		hash.Write(pk[:])
		hash.Sum(st.mac1.key[:0])
	}()

	// mac2 encryption key: BLAKE2s-256(WGLabelCookie || pk)
	func() {
		hash, _ := blake2s.New256(nil)
		hash.Write([]byte(WGLabelCookie))
		hash.Write(pk[:])
		hash.Sum(st.mac2.encryptionKey[:0])
	}()

	st.mac2.cookieSet = time.Time{}
}
|
||||
|
||||
// ConsumeReply decrypts a cookie reply using the mac1 of the last message
// we sent as associated data, and stores the resulting cookie for use in
// future mac2 fields. It reports false when no mac1 has been recorded or
// decryption fails.
func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
	st.Lock()
	defer st.Unlock()

	if !st.mac2.hasLastMAC1 {
		return false
	}

	var cookie [blake2s.Size128]byte

	xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
	_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
	if err != nil {
		return false
	}

	st.mac2.cookieSet = time.Now()
	st.mac2.cookie = cookie
	return true
}
|
||||
|
||||
// AddMacs fills in the trailing mac1 and mac2 fields of msg in place.
// mac1 is always computed and remembered (for matching a future cookie
// reply); mac2 is written only while a fresh cookie is held — otherwise
// the field is left as-is (zeroed by the caller's message construction).
func (st *CookieGenerator) AddMacs(msg []byte) {
	size := len(msg)

	smac2 := size - blake2s.Size128  // offset of mac2 (last 16 bytes)
	smac1 := smac2 - blake2s.Size128 // offset of mac1

	mac1 := msg[smac1:smac2]
	mac2 := msg[smac2:]

	st.Lock()
	defer st.Unlock()

	// set mac1

	func() {
		mac, _ := blake2s.New128(st.mac1.key[:])
		mac.Write(msg[:smac1])
		mac.Sum(mac1[:0])
	}()
	// Remember this mac1 so a cookie reply referencing it can be decrypted.
	copy(st.mac2.lastMAC1[:], mac1)
	st.mac2.hasLastMAC1 = true

	// set mac2

	if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
		return
	}

	func() {
		mac, _ := blake2s.New128(st.mac2.cookie[:])
		mac.Write(msg[:smac2])
		mac.Sum(mac2[:0])
	}()
}
|
||||
536
vendor/github.com/tailscale/wireguard-go/device/device.go
generated
vendored
Normal file
536
vendor/github.com/tailscale/wireguard-go/device/device.go
generated
vendored
Normal file
@@ -0,0 +1,536 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
"github.com/tailscale/wireguard-go/ratelimiter"
|
||||
"github.com/tailscale/wireguard-go/rwcancel"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
)
|
||||
|
||||
// Device is a WireGuard device: it owns the peer table, key material,
// routing tries, worker queues, memory pools, and the underlying bind and
// TUN device.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that that intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		stopping sync.WaitGroup
		// mu protects state changes.
		sync.Mutex
	}

	net struct {
		stopping sync.WaitGroup
		sync.RWMutex
		bind          conn.Bind // bind interface
		netlinkCancel *rwcancel.RWCancel
		port          uint16 // listening port
		fwmark        uint32 // mark value (0 = disabled)
		brokenRoaming bool
	}

	staticIdentity struct {
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey  NoisePublicKey
	}

	peers struct {
		sync.RWMutex // protects keyMap
		keyMap       map[NoisePublicKey]*Peer
	}

	rate struct {
		underLoadUntil atomic.Int64 // UnixNano deadline; see IsUnderLoad
		limiter        ratelimiter.Ratelimiter
	}

	allowedips    AllowedIPs
	indexTable    IndexTable
	cookieChecker CookieChecker

	// pool holds the WaitPools used to recycle packet buffers and queue
	// elements; populated by PopulatePools.
	pool struct {
		inboundElementsContainer  *WaitPool
		outboundElementsContainer *WaitPool
		messageBuffers            *WaitPool
		inboundElements           *WaitPool
		outboundElements          *WaitPool
	}

	// queue holds the ref-counted worker queues; see channels.go.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	tun struct {
		device tun.Device
		mtu    atomic.Int32
	}

	ipcMutex sync.RWMutex
	closed   chan struct{}
	log      *Logger
}
|
||||
|
||||
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
//	down -----+
//	  ↑↓      ↓
//	  up -> closed
type deviceState uint32

//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
	deviceStateDown deviceState = iota
	deviceStateUp
	deviceStateClosed
)
|
||||
|
||||
// deviceState returns device.state.state as a deviceState.
// See those docs for how to interpret this value; it is a lock-free
// snapshot and may be stale by the time the caller acts on it.
func (device *Device) deviceState() deviceState {
	return deviceState(device.state.state.Load())
}

// isClosed reports whether the device is closed (or is closing).
// See device.state.state comments for how to interpret this value.
func (device *Device) isClosed() bool {
	return device.deviceState() == deviceStateClosed
}

// isUp reports whether the device is up (or is attempting to come up).
// See device.state.state comments for how to interpret this value.
func (device *Device) isUp() bool {
	return device.deviceState() == deviceStateUp
}
|
||||
|
||||
// removePeerLocked removes peer's allowed IPs from the routing table, stops
// the peer, and deletes it from the key map.
// Must hold device.peers.Lock()
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
	// stop routing and processing of packets
	device.allowedips.RemoveByPeer(peer)
	peer.Stop()

	// remove from peer map
	delete(device.peers.keyMap, key)
}
|
||||
|
||||
// changeState attempts to change the device state to match want.
// A closed device ignores all requests. When bringing the device up fails,
// the fallthrough drives it all the way back down, and the down error (if
// any) is reported only when up itself produced no error.
func (device *Device) changeState(want deviceState) (err error) {
	device.state.Lock()
	defer device.state.Unlock()
	old := device.deviceState()
	if old == deviceStateClosed {
		// once closed, always closed
		device.log.Verbosef("Interface closed, ignored requested state %s", want)
		return nil
	}
	switch want {
	case old:
		// Already in the requested state; nothing to do.
		return nil
	case deviceStateUp:
		// Publish the intended state before acting on it (see state docs).
		device.state.state.Store(uint32(deviceStateUp))
		err = device.upLocked()
		if err == nil {
			break
		}
		fallthrough // up failed; bring the device all the way back down
	case deviceStateDown:
		device.state.state.Store(uint32(deviceStateDown))
		errDown := device.downLocked()
		if err == nil {
			err = errDown
		}
	}
	device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
	return
}
|
||||
|
||||
// upLocked attempts to bring the device up and reports whether it succeeded.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) upLocked() error {
	if err := device.BindUpdate(); err != nil {
		device.log.Errorf("Unable to update bind: %v", err)
		return err
	}

	// The IPC set operation waits for peers to be created before calling Start() on them,
	// so if there's a concurrent IPC set request happening, we should wait for it to complete.
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()

	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.Start()
		// Kick persistent-keepalive peers immediately so NAT mappings open.
		if peer.persistentKeepaliveInterval.Load() > 0 {
			peer.SendKeepalive()
		}
	}
	device.peers.RUnlock()
	return nil
}
|
||||
|
||||
// downLocked attempts to bring the device down.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
// All peers are stopped even if closing the bind fails; the bind error is
// returned.
func (device *Device) downLocked() error {
	err := device.BindClose()
	if err != nil {
		device.log.Errorf("Bind close failed: %v", err)
	}

	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.Stop()
	}
	device.peers.RUnlock()
	return err
}
|
||||
|
||||
// Up brings the device up; see changeState for error semantics.
func (device *Device) Up() error {
	return device.changeState(deviceStateUp)
}

// Down brings the device down; see changeState for error semantics.
func (device *Device) Down() error {
	return device.changeState(deviceStateDown)
}
|
||||
|
||||
// IsUnderLoad reports whether the device is currently under handshake load
// (handshake queue at least 1/8 full) or has been within the last
// UnderLoadAfterTime.
func (device *Device) IsUnderLoad() bool {
	// check if currently under load
	now := time.Now()
	underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
	if underLoad {
		// Extend the under-load window from now.
		device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
		return true
	}
	// check if recently under load
	return device.rate.underLoadUntil.Load() > now.UnixNano()
}
|
||||
|
||||
// SetPrivateKey installs a new static private key, removes any peer whose
// public key equals our new public key, recomputes every peer's
// static-static shared secret, and expires all current keypairs so they
// are renegotiated under the new identity. A no-op when sk equals the
// current key. Always returns nil.
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
	// lock required resources

	device.staticIdentity.Lock()
	defer device.staticIdentity.Unlock()

	if sk.Equals(device.staticIdentity.privateKey) {
		return nil
	}

	device.peers.Lock()
	defer device.peers.Unlock()

	// Hold every peer's handshake read lock for the duration so no
	// handshake observes a half-updated identity.
	lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		peer.handshake.mutex.RLock()
		lockedPeers = append(lockedPeers, peer)
	}

	// remove peers with matching public keys

	publicKey := sk.publicKey()
	for key, peer := range device.peers.keyMap {
		if peer.handshake.remoteStatic.Equals(publicKey) {
			// Drop the read lock around removal, then retake it so the
			// unlock loop below stays balanced for every entry in lockedPeers.
			peer.handshake.mutex.RUnlock()
			removePeerLocked(device, peer, key)
			peer.handshake.mutex.RLock()
		}
	}

	// update key material

	device.staticIdentity.privateKey = sk
	device.staticIdentity.publicKey = publicKey
	device.cookieChecker.Init(publicKey)

	// do static-static DH pre-computations

	expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		handshake := &peer.handshake
		handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
		expiredPeers = append(expiredPeers, peer)
	}

	for _, peer := range lockedPeers {
		peer.handshake.mutex.RUnlock()
	}
	for _, peer := range expiredPeers {
		peer.ExpireCurrentKeypairs()
	}

	return nil
}
|
||||
|
||||
// NewDevice constructs a Device around tunDevice and bind, initializes its
// pools, queues, and tables, and starts the per-CPU encryption, decryption,
// and handshake workers plus the TUN reader goroutines. The device starts
// in the down state.
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
	device := new(Device)
	device.state.state.Store(uint32(deviceStateDown))
	device.closed = make(chan struct{})
	device.log = logger
	device.net.bind = bind
	device.tun.device = tunDevice
	mtu, err := device.tun.device.MTU()
	if err != nil {
		// Non-fatal: fall back to the default MTU.
		device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
		mtu = DefaultMTU
	}
	device.tun.mtu.Store(int32(mtu))
	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
	device.rate.limiter.Init()
	device.indexTable.Init()

	device.PopulatePools()

	// create queues

	device.queue.handshake = newHandshakeQueue()
	device.queue.encryption = newOutboundQueue()
	device.queue.decryption = newInboundQueue()

	// start workers

	cpus := runtime.NumCPU()
	device.state.stopping.Wait()
	device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
	for i := 0; i < cpus; i++ {
		go device.RoutineEncryption(i + 1)
		go device.RoutineDecryption(i + 1)
		go device.RoutineHandshake(i + 1)
	}

	device.state.stopping.Add(1)      // RoutineReadFromTUN
	device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
	go device.RoutineReadFromTUN()
	go device.RoutineTUNEventReader()

	return device
}
|
||||
|
||||
// BatchSize returns the BatchSize for the device as a whole which is the max of
|
||||
// the bind batch size and the tun batch size. The batch size reported by device
|
||||
// is the size used to construct memory pools, and is the allowed batch size for
|
||||
// the lifetime of the device.
|
||||
func (device *Device) BatchSize() int {
|
||||
size := device.net.bind.BatchSize()
|
||||
dSize := device.tun.device.BatchSize()
|
||||
if size < dSize {
|
||||
size = dSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||
device.peers.RLock()
|
||||
defer device.peers.RUnlock()
|
||||
|
||||
return device.peers.keyMap[pk]
|
||||
}
|
||||
|
||||
func (device *Device) RemovePeer(key NoisePublicKey) {
|
||||
device.peers.Lock()
|
||||
defer device.peers.Unlock()
|
||||
// stop peer and remove from routing
|
||||
|
||||
peer, ok := device.peers.keyMap[key]
|
||||
if ok {
|
||||
removePeerLocked(device, peer, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) RemoveAllPeers() {
|
||||
device.peers.Lock()
|
||||
defer device.peers.Unlock()
|
||||
|
||||
for key, peer := range device.peers.keyMap {
|
||||
removePeerLocked(device, peer, key)
|
||||
}
|
||||
|
||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||
}
|
||||
|
||||
// Close shuts the device down permanently: it closes the TUN device, brings
// the device down, removes all peers, drains the worker queues, and finally
// closes the channel returned by Wait. Close is idempotent.
func (device *Device) Close() {
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()
	device.state.Lock()
	defer device.state.Unlock()
	if device.isClosed() {
		return
	}
	device.state.state.Store(uint32(deviceStateClosed))
	device.log.Verbosef("Device closing")

	device.tun.device.Close()
	device.downLocked()

	// Remove peers before closing queues,
	// because peers assume that queues are active.
	device.RemoveAllPeers()

	// We kept a reference to the encryption and decryption queues,
	// in case we started any new peers that might write to them.
	// No new peers are coming; we are done with these queues.
	device.queue.encryption.wg.Done()
	device.queue.decryption.wg.Done()
	device.queue.handshake.wg.Done()
	device.state.stopping.Wait()

	device.rate.limiter.Close()

	device.log.Verbosef("Device closed")
	close(device.closed)
}
|
||||
|
||||
// Wait returns a channel that is closed once the device has fully shut down
// (see Close).
func (device *Device) Wait() chan struct{} {
	return device.closed
}
|
||||
|
||||
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
||||
if !device.isUp() {
|
||||
return
|
||||
}
|
||||
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.keypairs.RLock()
|
||||
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
|
||||
peer.keypairs.RUnlock()
|
||||
if sendKeepalive {
|
||||
peer.SendKeepalive()
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
}
|
||||
|
||||
// closeBindLocked closes the device's net.bind.
// The caller must hold the net mutex.
func closeBindLocked(device *Device) error {
	var err error
	netc := &device.net
	if netc.netlinkCancel != nil {
		netc.netlinkCancel.Cancel()
	}
	if netc.bind != nil {
		err = netc.bind.Close()
	}
	// Wait for the receive goroutines tracked on net.stopping to exit
	// before returning.
	netc.stopping.Wait()
	return err
}
|
||||
|
||||
func (device *Device) Bind() conn.Bind {
|
||||
device.net.Lock()
|
||||
defer device.net.Unlock()
|
||||
return device.net.bind
|
||||
}
|
||||
|
||||
func (device *Device) BindSetMark(mark uint32) error {
|
||||
device.net.Lock()
|
||||
defer device.net.Unlock()
|
||||
|
||||
// check if modified
|
||||
if device.net.fwmark == mark {
|
||||
return nil
|
||||
}
|
||||
|
||||
// update fwmark on existing bind
|
||||
device.net.fwmark = mark
|
||||
if device.isUp() && device.net.bind != nil {
|
||||
if err := device.net.bind.SetMark(mark); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// clear cached source addresses
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.markEndpointSrcForClearing()
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindUpdate closes the device's existing sockets and, if the device is up,
// reopens the bind on the configured port, restarts the route listener,
// reapplies the fwmark, clears cached endpoint sources, and launches one
// receive goroutine per receive function reported by the bind.
func (device *Device) BindUpdate() error {
	device.net.Lock()
	defer device.net.Unlock()

	// close existing sockets
	if err := closeBindLocked(device); err != nil {
		return err
	}

	// open new sockets
	if !device.isUp() {
		return nil
	}

	// bind to new port
	var err error
	var recvFns []conn.ReceiveFunc
	netc := &device.net

	recvFns, netc.port, err = netc.bind.Open(netc.port)
	if err != nil {
		netc.port = 0
		return err
	}

	netc.netlinkCancel, err = device.startRouteListener(netc.bind)
	if err != nil {
		// Roll back the successful Open above.
		netc.bind.Close()
		netc.port = 0
		return err
	}

	// set fwmark
	if netc.fwmark != 0 {
		err = netc.bind.SetMark(netc.fwmark)
		if err != nil {
			return err
		}
	}

	// clear cached source addresses
	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.markEndpointSrcForClearing()
	}
	device.peers.RUnlock()

	// start receiving routines
	device.net.stopping.Add(len(recvFns))
	device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
	device.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
	batchSize := netc.bind.BatchSize()
	for _, fn := range recvFns {
		go device.RoutineReceiveIncoming(batchSize, fn)
	}

	device.log.Verbosef("UDP bind has been updated")
	return nil
}
|
||||
|
||||
func (device *Device) BindClose() error {
|
||||
device.net.Lock()
|
||||
err := closeBindLocked(device)
|
||||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
16
vendor/github.com/tailscale/wireguard-go/device/devicestate_string.go
generated
vendored
Normal file
16
vendor/github.com/tailscale/wireguard-go/device/devicestate_string.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
|
||||
|
||||
package device
|
||||
|
||||
import "strconv"
|
||||
|
||||
// Generated lookup data: names "Down", "Up", "Closed" packed into one string,
// with _deviceState_index giving each name's start/end offsets.
const _deviceState_name = "DownUpClosed"

var _deviceState_index = [...]uint8{0, 4, 6, 12}

// String returns the name of the device state, or a numeric fallback of the
// form "deviceState(N)" for values outside the known range.
func (i deviceState) String() string {
	if i >= deviceState(len(_deviceState_index)-1) {
		return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
}
|
||||
98
vendor/github.com/tailscale/wireguard-go/device/indextable.go
generated
vendored
Normal file
98
vendor/github.com/tailscale/wireguard-go/device/indextable.go
generated
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// IndexTableEntry associates a local 32-bit receiver index with its owning
// peer and with either an in-progress handshake or an established keypair
// (the methods below set exactly one of the two pointers, leaving the other
// nil).
type IndexTableEntry struct {
	peer      *Peer
	handshake *Handshake
	keypair   *Keypair
}

// IndexTable maps local receiver indices to their entries.
// It is safe for concurrent use.
type IndexTable struct {
	sync.RWMutex
	table map[uint32]IndexTableEntry
}
|
||||
|
||||
// randUint32 returns a uniformly random uint32 drawn from crypto/rand,
// along with any error from the underlying reader.
func randUint32() (uint32, error) {
	var buf [4]byte
	_, err := rand.Read(buf[:])
	// Arbitrary endianness; both are intrinsified by the Go compiler.
	return binary.LittleEndian.Uint32(buf[:]), err
}
|
||||
|
||||
func (table *IndexTable) Init() {
|
||||
table.Lock()
|
||||
defer table.Unlock()
|
||||
table.table = make(map[uint32]IndexTableEntry)
|
||||
}
|
||||
|
||||
func (table *IndexTable) Delete(index uint32) {
|
||||
table.Lock()
|
||||
defer table.Unlock()
|
||||
delete(table.table, index)
|
||||
}
|
||||
|
||||
func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
|
||||
table.Lock()
|
||||
defer table.Unlock()
|
||||
entry, ok := table.table[index]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
table.table[index] = IndexTableEntry{
|
||||
peer: entry.peer,
|
||||
keypair: keypair,
|
||||
handshake: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// NewIndexForHandshake allocates a fresh random local index, registers it for
// the given peer and handshake, and returns it. It loops until an unused
// random index is found and fails only if the CSPRNG fails.
func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) {
	for {
		// generate random index

		index, err := randUint32()
		if err != nil {
			return index, err
		}

		// check if index used: cheap read-locked pre-check to avoid taking
		// the write lock for indices that are obviously taken

		table.RLock()
		_, ok := table.table[index]
		table.RUnlock()
		if ok {
			continue
		}

		// check again while locked (the index may have been claimed between
		// the RUnlock above and this Lock), then insert

		table.Lock()
		_, found := table.table[index]
		if found {
			table.Unlock()
			continue
		}
		table.table[index] = IndexTableEntry{
			peer:      peer,
			handshake: handshake,
			keypair:   nil,
		}
		table.Unlock()
		return index, nil
	}
}
|
||||
|
||||
func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
|
||||
table.RLock()
|
||||
defer table.RUnlock()
|
||||
return table.table[id]
|
||||
}
|
||||
22
vendor/github.com/tailscale/wireguard-go/device/ip.go
generated
vendored
Normal file
22
vendor/github.com/tailscale/wireguard-go/device/ip.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// Byte offsets of fields within an IPv4 packet header.
const (
	IPv4offsetTotalLength = 2
	IPv4offsetSrc         = 12
	IPv4offsetDst         = IPv4offsetSrc + net.IPv4len
)

// Byte offsets of fields within an IPv6 packet header.
const (
	IPv6offsetPayloadLength = 4
	IPv6offsetSrc           = 8
	IPv6offsetDst           = IPv6offsetSrc + net.IPv6len
)
|
||||
52
vendor/github.com/tailscale/wireguard-go/device/keypair.go
generated
vendored
Normal file
52
vendor/github.com/tailscale/wireguard-go/device/keypair.go
generated
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/replay"
|
||||
)
|
||||
|
||||
/* Due to limitations in Go and /x/crypto there is currently
|
||||
 * no way to ensure that key material is securely erased in memory.
|
||||
*
|
||||
* Since this may harm the forward secrecy property,
|
||||
* we plan to resolve this issue; whenever Go allows us to do so.
|
||||
*/
|
||||
|
||||
// Keypair is one negotiated pair of AEAD transport keys, together with the
// replay filter and the local/remote receiver indices used on the wire.
type Keypair struct {
	sendNonce    atomic.Uint64 // counter for outgoing transport nonces
	send         cipher.AEAD
	receive      cipher.AEAD
	replayFilter replay.Filter
	isInitiator  bool
	created      time.Time
	localIndex   uint32
	remoteIndex  uint32
}

// Keypairs tracks a peer's current, previous, and pending-next keypairs,
// guarded by the embedded RWMutex (next is additionally an atomic pointer).
type Keypairs struct {
	sync.RWMutex
	current  *Keypair
	previous *Keypair
	next     atomic.Pointer[Keypair]
}
|
||||
|
||||
func (kp *Keypairs) Current() *Keypair {
|
||||
kp.RLock()
|
||||
defer kp.RUnlock()
|
||||
return kp.current
|
||||
}
|
||||
|
||||
func (device *Device) DeleteKeypair(key *Keypair) {
|
||||
if key != nil {
|
||||
device.indexTable.Delete(key.localIndex)
|
||||
}
|
||||
}
|
||||
48
vendor/github.com/tailscale/wireguard-go/device/logger.go
generated
vendored
Normal file
48
vendor/github.com/tailscale/wireguard-go/device/logger.go
generated
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// A Logger provides logging for a Device.
// The functions are Printf-style functions.
// They must be safe for concurrent use.
// They do not require a trailing newline in the format.
// If nil, that level of logging will be silent.
type Logger struct {
	Verbosef func(format string, args ...any)
	Errorf   func(format string, args ...any)
}

// Log levels for use with NewLogger.
const (
	LogLevelSilent = iota
	LogLevelError
	LogLevelVerbose
)

// DiscardLogf is a Logger function that drops every logged line.
func DiscardLogf(format string, args ...any) {}

// NewLogger constructs a Logger that writes to stdout.
// It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
	logger := &Logger{Verbosef: DiscardLogf, Errorf: DiscardLogf}
	newPrintf := func(prefix string) func(string, ...any) {
		return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
	}
	if level >= LogLevelError {
		logger.Errorf = newPrintf("ERROR")
	}
	if level >= LogLevelVerbose {
		logger.Verbosef = newPrintf("DEBUG")
	}
	return logger
}
|
||||
19
vendor/github.com/tailscale/wireguard-go/device/mobilequirks.go
generated
vendored
Normal file
19
vendor/github.com/tailscale/wireguard-go/device/mobilequirks.go
generated
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
// though it will try to deal with it, and race maybe, if called after.
//
// It marks the device's net as brokenRoaming and, for each existing peer that
// already has an endpoint value, disables endpoint roaming for that peer.
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
	device.net.brokenRoaming = true
	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.endpoint.Lock()
		peer.endpoint.disableRoaming = peer.endpoint.val != nil
		peer.endpoint.Unlock()
	}
	device.peers.RUnlock()
}
|
||||
108
vendor/github.com/tailscale/wireguard-go/device/noise-helpers.go
generated
vendored
Normal file
108
vendor/github.com/tailscale/wireguard-go/device/noise-helpers.go
generated
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"hash"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
/* KDF related functions.
 * HMAC-based Key Derivation Function (HKDF)
 * https://tools.ietf.org/html/rfc5869
 */

// HMAC1 writes HMAC-BLAKE2s(key, in0) into sum.
func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) {
	mac := hmac.New(func() hash.Hash {
		h, _ := blake2s.New256(nil)
		return h
	}, key)
	mac.Write(in0)
	mac.Sum(sum[:0])
}

// HMAC2 writes HMAC-BLAKE2s(key, in0 || in1) into sum.
func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
	mac := hmac.New(func() hash.Hash {
		h, _ := blake2s.New256(nil)
		return h
	}, key)
	mac.Write(in0)
	mac.Write(in1)
	mac.Sum(sum[:0])
}
|
||||
|
||||
// KDF1 performs HKDF with one output block: t0 = HKDF(key, input)[0].
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
	HMAC1(t0, key, input)
	HMAC1(t0, t0[:], []byte{0x1})
}

// KDF2 performs HKDF with two output blocks, writing them to t0 and t1.
// The intermediate pseudo-random key is zeroed before returning.
func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
	var prk [blake2s.Size]byte
	HMAC1(&prk, key, input)
	HMAC1(t0, prk[:], []byte{0x1})
	HMAC2(t1, prk[:], t0[:], []byte{0x2})
	setZero(prk[:])
}

// KDF3 performs HKDF with three output blocks, writing them to t0, t1, t2.
// The intermediate pseudo-random key is zeroed before returning.
func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
	var prk [blake2s.Size]byte
	HMAC1(&prk, key, input)
	HMAC1(t0, prk[:], []byte{0x1})
	HMAC2(t1, prk[:], t0[:], []byte{0x2})
	HMAC2(t2, prk[:], t1[:], []byte{0x3})
	setZero(prk[:])
}
|
||||
|
||||
// isZero reports whether val consists entirely of zero bytes (true for an
// empty or nil slice), in constant time with respect to its contents.
func isZero(val []byte) bool {
	return subtle.ConstantTimeCompare(val, make([]byte, len(val))) == 1
}
|
||||
|
||||
/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */
// setZero overwrites every byte of arr with zero.
func setZero(arr []byte) {
	for i := 0; i < len(arr); i++ {
		arr[i] = 0
	}
}
|
||||
|
||||
// clamp applies the standard Curve25519 scalar clamping to sk: the low three
// bits are cleared and the high bits are fixed so the scalar is a valid
// X25519 private key.
func (sk *NoisePrivateKey) clamp() {
	sk[0] &= 248
	sk[31] = (sk[31] & 127) | 64
}

// newPrivateKey generates a fresh, clamped Curve25519 private key from
// crypto/rand, returning any error from the random source.
func newPrivateKey() (sk NoisePrivateKey, err error) {
	_, err = rand.Read(sk[:])
	sk.clamp()
	return
}
|
||||
|
||||
// publicKey derives the Curve25519 public key corresponding to sk.
func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
	apk := (*[NoisePublicKeySize]byte)(&pk)
	ask := (*[NoisePrivateKeySize]byte)(sk)
	curve25519.ScalarBaseMult(apk, ask)
	return
}

var errInvalidPublicKey = errors.New("invalid public key")

// sharedSecret computes the X25519 shared secret between sk and pk.
// An all-zero result (produced by low-order public keys) is rejected with
// errInvalidPublicKey.
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
	apk := (*[NoisePublicKeySize]byte)(&pk)
	ask := (*[NoisePrivateKeySize]byte)(sk)
	curve25519.ScalarMult(&ss, ask, apk)
	if isZero(ss[:]) {
		return ss, errInvalidPublicKey
	}
	return ss, nil
}
|
||||
625
vendor/github.com/tailscale/wireguard-go/device/noise-protocol.go
generated
vendored
Normal file
625
vendor/github.com/tailscale/wireguard-go/device/noise-protocol.go
generated
vendored
Normal file
@@ -0,0 +1,625 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tai64n"
|
||||
)
|
||||
|
||||
// handshakeState tracks the progress of a Noise handshake for one peer.
type handshakeState int

const (
	handshakeZeroed = handshakeState(iota)
	handshakeInitiationCreated
	handshakeInitiationConsumed
	handshakeResponseCreated
	handshakeResponseConsumed
)

// String returns the state's identifier name, or "Handshake(UNKNOWN:N)" for
// values outside the defined set. Used for logging and error messages.
func (hs handshakeState) String() string {
	names := [...]string{
		handshakeZeroed:             "handshakeZeroed",
		handshakeInitiationCreated:  "handshakeInitiationCreated",
		handshakeInitiationConsumed: "handshakeInitiationConsumed",
		handshakeResponseCreated:    "handshakeResponseCreated",
		handshakeResponseConsumed:   "handshakeResponseConsumed",
	}
	if hs >= 0 && int(hs) < len(names) {
		return names[hs]
	}
	return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
}
|
||||
|
||||
// Protocol magic strings from the WireGuard specification.
const (
	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com"
	WGLabelMAC1       = "mac1----"
	WGLabelCookie     = "cookie--"
)

// Wire message type identifiers (first byte of every message).
const (
	MessageInitiationType  = 1
	MessageResponseType    = 2
	MessageCookieReplyType = 3
	MessageTransportType   = 4
)

// Fixed wire sizes, in bytes.
const (
	MessageInitiationSize      = 148                                           // size of handshake initiation message
	MessageResponseSize        = 92                                            // size of response message
	MessageCookieReplySize     = 64                                            // size of cookie reply message
	MessageTransportHeaderSize = 16                                            // size of data preceding content in transport message
	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
	MessageKeepaliveSize       = MessageTransportSize                          // size of keepalive
	MessageHandshakeSize       = MessageInitiationSize                         // size of largest handshake related message
)

// Byte offsets of fields within a transport message.
const (
	MessageTransportOffsetReceiver = 4
	MessageTransportOffsetCounter  = 8
	MessageTransportOffsetContent  = 16
)
|
||||
|
||||
/* Type is an 8-bit field, followed by 3 nul bytes,
 * by marshalling the messages in little-endian byteorder
 * we can treat these as a 32-bit unsigned int (for now)
 *
 */

// MessageInitiation is the first handshake message (initiator to responder).
type MessageInitiation struct {
	Type      uint32
	Sender    uint32
	Ephemeral NoisePublicKey
	Static    [NoisePublicKeySize + poly1305.TagSize]byte
	Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}

// MessageResponse is the second handshake message (responder to initiator).
type MessageResponse struct {
	Type      uint32
	Sender    uint32
	Receiver  uint32
	Ephemeral NoisePublicKey
	Empty     [poly1305.TagSize]byte
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}

// MessageTransport is an encrypted data-carrying message.
type MessageTransport struct {
	Type     uint32
	Receiver uint32
	Counter  uint64
	Content  []byte
}

// MessageCookieReply carries an encrypted cookie used for DoS mitigation.
type MessageCookieReply struct {
	Type     uint32
	Receiver uint32
	Nonce    [chacha20poly1305.NonceSizeX]byte
	Cookie   [blake2s.Size128 + poly1305.TagSize]byte
}

// Handshake holds the in-progress Noise handshake state for one peer,
// guarded by its mutex (except where noted below).
type Handshake struct {
	state                     handshakeState
	mutex                     sync.RWMutex
	hash                      [blake2s.Size]byte       // hash value
	chainKey                  [blake2s.Size]byte       // chain key
	presharedKey              NoisePresharedKey        // psk
	localEphemeral            NoisePrivateKey          // ephemeral secret key
	localIndex                uint32                   // used to clear hash-table
	remoteIndex               uint32                   // index for sending
	remoteStatic              NoisePublicKey           // long term key, never changes, can be accessed without mutex
	remoteEphemeral           NoisePublicKey           // ephemeral public key
	precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret
	lastTimestamp             tai64n.Timestamp
	lastInitiationConsumption time.Time
	lastSentHandshake         time.Time
}
|
||||
|
||||
// Protocol-wide initial values, computed once in init below.
var (
	InitialChainKey [blake2s.Size]byte
	InitialHash     [blake2s.Size]byte
	ZeroNonce       [chacha20poly1305.NonceSize]byte
)

// mixKey writes KDF1(c, data) into dst (chain-key ratchet step).
func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
	KDF1(dst, c[:], data)
}

// mixHash writes BLAKE2s(h || data) into dst (transcript hash step).
func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
	hash, _ := blake2s.New256(nil)
	hash.Write(h[:])
	hash.Write(data)
	hash.Sum(dst[:0])
	hash.Reset()
}

// Clear zeroes the handshake's secret material and resets its state machine
// to handshakeZeroed.
func (h *Handshake) Clear() {
	setZero(h.localEphemeral[:])
	setZero(h.remoteEphemeral[:])
	setZero(h.chainKey[:])
	setZero(h.hash[:])
	h.localIndex = 0
	h.state = handshakeZeroed
}

// mixHash folds data into the handshake's transcript hash.
func (h *Handshake) mixHash(data []byte) {
	mixHash(&h.hash, &h.hash, data)
}

// mixKey folds data into the handshake's chain key.
func (h *Handshake) mixKey(data []byte) {
	mixKey(&h.chainKey, &h.chainKey, data)
}

/* Do basic precomputations
 */
func init() {
	InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
	mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
}
|
||||
|
||||
// CreateMessageInitiation constructs the first handshake message for peer,
// advancing the handshake state to handshakeInitiationCreated. Note that the
// MAC1/MAC2 fields are not populated here.
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	// create ephemeral key
	var err error
	handshake.hash = InitialHash
	handshake.chainKey = InitialChainKey
	handshake.localEphemeral, err = newPrivateKey()
	if err != nil {
		return nil, err
	}

	handshake.mixHash(handshake.remoteStatic[:])

	msg := MessageInitiation{
		Type:      MessageInitiationType,
		Ephemeral: handshake.localEphemeral.publicKey(),
	}

	handshake.mixKey(msg.Ephemeral[:])
	handshake.mixHash(msg.Ephemeral[:])

	// encrypt static key
	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
	if err != nil {
		return nil, err
	}
	var key [chacha20poly1305.KeySize]byte
	KDF2(
		&handshake.chainKey,
		&key,
		handshake.chainKey[:],
		ss[:],
	)
	aead, _ := chacha20poly1305.New(key[:])
	aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
	handshake.mixHash(msg.Static[:])

	// encrypt timestamp
	// An all-zero precomputed static-static secret indicates an invalid
	// remote public key; refuse to continue.
	if isZero(handshake.precomputedStaticStatic[:]) {
		return nil, errInvalidPublicKey
	}
	KDF2(
		&handshake.chainKey,
		&key,
		handshake.chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	timestamp := tai64n.Now()
	aead, _ = chacha20poly1305.New(key[:])
	aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])

	// assign index
	device.indexTable.Delete(handshake.localIndex)
	msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
	if err != nil {
		return nil, err
	}
	handshake.localIndex = msg.Sender

	handshake.mixHash(msg.Timestamp[:])
	handshake.state = handshakeInitiationCreated
	return &msg, nil
}
|
||||
|
||||
// ConsumeMessageInitiation validates and processes a received handshake
// initiation. On success it updates the matched peer's handshake state to
// handshakeInitiationConsumed and returns the peer; on any failure
// (wrong type, decryption failure, unknown or stopped peer, replay, or
// flood) it returns nil.
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
	var (
		hash     [blake2s.Size]byte
		chainKey [blake2s.Size]byte
	)

	if msg.Type != MessageInitiationType {
		return nil
	}

	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
	mixHash(&hash, &hash, msg.Ephemeral[:])
	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])

	// decrypt static key
	var peerPK NoisePublicKey
	var key [chacha20poly1305.KeySize]byte
	ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
	if err != nil {
		return nil
	}
	KDF2(&chainKey, &key, chainKey[:], ss[:])
	aead, _ := chacha20poly1305.New(key[:])
	_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
	if err != nil {
		return nil
	}
	mixHash(&hash, &hash, msg.Static[:])

	// lookup peer

	peer := device.LookupPeer(peerPK)
	if peer == nil || !peer.isRunning.Load() {
		return nil
	}

	handshake := &peer.handshake

	// verify identity

	var timestamp tai64n.Timestamp

	handshake.mutex.RLock()

	if isZero(handshake.precomputedStaticStatic[:]) {
		handshake.mutex.RUnlock()
		return nil
	}
	KDF2(
		&chainKey,
		&key,
		chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	aead, _ = chacha20poly1305.New(key[:])
	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
	if err != nil {
		handshake.mutex.RUnlock()
		return nil
	}
	mixHash(&hash, &hash, msg.Timestamp[:])

	// protect against replay & flood

	replay := !timestamp.After(handshake.lastTimestamp)
	flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
	handshake.mutex.RUnlock()
	if replay {
		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
		return nil
	}
	if flood {
		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
		return nil
	}

	// update handshake state

	handshake.mutex.Lock()

	handshake.hash = hash
	handshake.chainKey = chainKey
	handshake.remoteIndex = msg.Sender
	handshake.remoteEphemeral = msg.Ephemeral
	// Only advance the timestamp / consumption clock; a concurrent consumer
	// may already have recorded a later value.
	if timestamp.After(handshake.lastTimestamp) {
		handshake.lastTimestamp = timestamp
	}
	now := time.Now()
	if now.After(handshake.lastInitiationConsumption) {
		handshake.lastInitiationConsumption = now
	}
	handshake.state = handshakeInitiationConsumed

	handshake.mutex.Unlock()

	// Scrub local copies of secret-derived material.
	setZero(hash[:])
	setZero(chainKey[:])

	return peer
}
|
||||
|
||||
// CreateMessageResponse constructs the second handshake message for peer,
// which must have a freshly consumed initiation. It advances the handshake
// state to handshakeResponseCreated. MAC1/MAC2 are not populated here.
func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	if handshake.state != handshakeInitiationConsumed {
		return nil, errors.New("handshake initiation must be consumed first")
	}

	// assign index

	var err error
	device.indexTable.Delete(handshake.localIndex)
	handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
	if err != nil {
		return nil, err
	}

	var msg MessageResponse
	msg.Type = MessageResponseType
	msg.Sender = handshake.localIndex
	msg.Receiver = handshake.remoteIndex

	// create ephemeral key

	handshake.localEphemeral, err = newPrivateKey()
	if err != nil {
		return nil, err
	}
	msg.Ephemeral = handshake.localEphemeral.publicKey()
	handshake.mixHash(msg.Ephemeral[:])
	handshake.mixKey(msg.Ephemeral[:])

	// mix in the ephemeral-ephemeral and ephemeral-static DH results
	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
	if err != nil {
		return nil, err
	}
	handshake.mixKey(ss[:])
	ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
	if err != nil {
		return nil, err
	}
	handshake.mixKey(ss[:])

	// add preshared key

	var tau [blake2s.Size]byte
	var key [chacha20poly1305.KeySize]byte

	KDF3(
		&handshake.chainKey,
		&tau,
		&key,
		handshake.chainKey[:],
		handshake.presharedKey[:],
	)

	handshake.mixHash(tau[:])

	// The "Empty" field authenticates the transcript so far.
	aead, _ := chacha20poly1305.New(key[:])
	aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
	handshake.mixHash(msg.Empty[:])

	handshake.state = handshakeResponseCreated

	return &msg, nil
}
|
||||
|
||||
// ConsumeMessageResponse validates and processes a received handshake
// response. On success it advances the matched handshake to
// handshakeResponseConsumed and returns its peer; on any failure (wrong
// type, unknown receiver index, wrong state, DH failure, or authentication
// failure) it returns nil.
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
	if msg.Type != MessageResponseType {
		return nil
	}

	// lookup handshake by receiver

	lookup := device.indexTable.Lookup(msg.Receiver)
	handshake := lookup.handshake
	if handshake == nil {
		return nil
	}

	var (
		hash     [blake2s.Size]byte
		chainKey [blake2s.Size]byte
	)

	// The validation work runs in a closure so both read locks are released
	// via defer before the write-locked state update below.
	ok := func() bool {
		// lock handshake state

		handshake.mutex.RLock()
		defer handshake.mutex.RUnlock()

		if handshake.state != handshakeInitiationCreated {
			return false
		}

		// lock private key for reading

		device.staticIdentity.RLock()
		defer device.staticIdentity.RUnlock()

		// finish 3-way DH

		mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
		mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])

		ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
		if err != nil {
			return false
		}
		mixKey(&chainKey, &chainKey, ss[:])
		setZero(ss[:])

		ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
		if err != nil {
			return false
		}
		mixKey(&chainKey, &chainKey, ss[:])
		setZero(ss[:])

		// add preshared key (psk)

		var tau [blake2s.Size]byte
		var key [chacha20poly1305.KeySize]byte
		KDF3(
			&chainKey,
			&tau,
			&key,
			chainKey[:],
			handshake.presharedKey[:],
		)
		mixHash(&hash, &hash, tau[:])

		// authenticate transcript

		aead, _ := chacha20poly1305.New(key[:])
		_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
		if err != nil {
			return false
		}
		mixHash(&hash, &hash, msg.Empty[:])
		return true
	}()

	if !ok {
		return nil
	}

	// update handshake state

	handshake.mutex.Lock()

	handshake.hash = hash
	handshake.chainKey = chainKey
	handshake.remoteIndex = msg.Sender
	handshake.state = handshakeResponseConsumed

	handshake.mutex.Unlock()

	// Scrub local copies of secret-derived material.
	setZero(hash[:])
	setZero(chainKey[:])

	return lookup.peer
}
|
||||
|
||||
/* BeginSymmetricSession derives a new transport keypair from the current
 * handshake state, zeroes the handshake secrets, and rotates the peer's
 * keypair slots (previous/current/next).
 *
 * It returns an error if the handshake is not in a completed state
 * (responseConsumed for the initiator, responseCreated for the responder).
 */
func (peer *Peer) BeginSymmetricSession() error {
	device := peer.device
	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	// derive keys

	var isInitiator bool
	var sendKey [chacha20poly1305.KeySize]byte
	var recvKey [chacha20poly1305.KeySize]byte

	if handshake.state == handshakeResponseConsumed {
		// we initiated: first derived key is ours to send with
		KDF2(
			&sendKey,
			&recvKey,
			handshake.chainKey[:],
			nil,
		)
		isInitiator = true
	} else if handshake.state == handshakeResponseCreated {
		// we responded: key order is mirrored relative to the initiator
		KDF2(
			&recvKey,
			&sendKey,
			handshake.chainKey[:],
			nil,
		)
		isInitiator = false
	} else {
		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
	}

	// zero handshake

	setZero(handshake.chainKey[:])
	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
	setZero(handshake.localEphemeral[:])
	peer.handshake.state = handshakeZeroed

	// create AEAD instances

	keypair := new(Keypair)
	keypair.send, _ = chacha20poly1305.New(sendKey[:])
	keypair.receive, _ = chacha20poly1305.New(recvKey[:])

	// raw key material is no longer needed once the AEADs are constructed
	setZero(sendKey[:])
	setZero(recvKey[:])

	keypair.created = time.Now()
	keypair.replayFilter.Reset()
	keypair.isInitiator = isInitiator
	keypair.localIndex = peer.handshake.localIndex
	keypair.remoteIndex = peer.handshake.remoteIndex

	// remap index: the handshake's index now refers to the keypair

	device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
	handshake.localIndex = 0

	// rotate key pairs

	keypairs := &peer.keypairs
	keypairs.Lock()
	defer keypairs.Unlock()

	previous := keypairs.previous
	next := keypairs.next.Load()
	current := keypairs.current

	if isInitiator {
		// the initiator promotes its new keypair to current immediately
		if next != nil {
			keypairs.next.Store(nil)
			keypairs.previous = next
			device.DeleteKeypair(current)
		} else {
			keypairs.previous = current
		}
		device.DeleteKeypair(previous)
		keypairs.current = keypair
	} else {
		// the responder stages the keypair in next until the initiator
		// confirms it by sending traffic (see ReceivedWithKeypair)
		keypairs.next.Store(keypair)
		device.DeleteKeypair(next)
		keypairs.previous = nil
		device.DeleteKeypair(previous)
	}

	return nil
}
|
||||
|
||||
// ReceivedWithKeypair promotes a staged "next" keypair to "current" after
// the first packet authenticated with it has been received, confirming the
// remote side possesses the keys. It returns true if a rotation occurred.
//
// The keypairs.next value is checked both before and after taking the lock
// (double-checked locking) so the common no-op case stays lock-free.
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
	keypairs := &peer.keypairs

	// fast path: not the staged keypair, nothing to do
	if keypairs.next.Load() != receivedKeypair {
		return false
	}
	keypairs.Lock()
	defer keypairs.Unlock()
	// re-check under the lock; another goroutine may have rotated already
	if keypairs.next.Load() != receivedKeypair {
		return false
	}
	old := keypairs.previous
	keypairs.previous = keypairs.current
	peer.device.DeleteKeypair(old)
	keypairs.current = keypairs.next.Load()
	keypairs.next.Store(nil)
	return true
}
|
||||
78
vendor/github.com/tailscale/wireguard-go/device/noise-types.go
generated
vendored
Normal file
78
vendor/github.com/tailscale/wireguard-go/device/noise-types.go
generated
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Sizes, in bytes, of the Noise key material used by this package.
const (
	NoisePublicKeySize    = 32
	NoisePrivateKeySize   = 32
	NoisePresharedKeySize = 32
)
|
||||
|
||||
// Fixed-size byte-array types for the Noise protocol key material.
type (
	NoisePublicKey    [NoisePublicKeySize]byte
	NoisePrivateKey   [NoisePrivateKeySize]byte
	NoisePresharedKey [NoisePresharedKeySize]byte
	NoiseNonce        uint64 // padded to 12-bytes
)
|
||||
|
||||
// loadExactHex decodes src as a hex string and copies the decoded bytes
// into dst. It returns an error if src is not valid hex, or if the decoded
// length does not exactly match len(dst).
func loadExactHex(dst []byte, src string) error {
	decoded, err := hex.DecodeString(src)
	switch {
	case err != nil:
		return err
	case len(decoded) != len(dst):
		return errors.New("hex string does not fit the slice")
	}
	copy(dst, decoded)
	return nil
}
|
||||
|
||||
func (key NoisePrivateKey) IsZero() bool {
|
||||
var zero NoisePrivateKey
|
||||
return key.Equals(zero)
|
||||
}
|
||||
|
||||
// Equals reports whether key and tar are equal, in constant time so that
// the comparison does not leak key material through timing.
func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
	return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
|
||||
|
||||
// FromHex parses src as a hex-encoded private key and clamps it per
// Curve25519 convention. Note: clamp runs even when decoding failed,
// leaving the key clamped-but-unset; callers must check err.
func (key *NoisePrivateKey) FromHex(src string) (err error) {
	err = loadExactHex(key[:], src)
	key.clamp()
	return
}
|
||||
|
||||
// FromMaybeZeroHex parses src as a hex-encoded private key, but — unlike
// FromHex — leaves an all-zero key unclamped so a deliberately zero value
// survives the round trip.
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
	err = loadExactHex(key[:], src)
	if key.IsZero() {
		return
	}
	key.clamp()
	return
}
|
||||
|
||||
// FromHex parses src as a hex-encoded public key.
func (key *NoisePublicKey) FromHex(src string) error {
	return loadExactHex(key[:], src)
}
|
||||
|
||||
func (key NoisePublicKey) IsZero() bool {
|
||||
var zero NoisePublicKey
|
||||
return key.Equals(zero)
|
||||
}
|
||||
|
||||
// Equals reports whether key and tar are equal, in constant time.
func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
	return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
|
||||
|
||||
// FromHex parses src as a hex-encoded preshared key.
func (key *NoisePresharedKey) FromHex(src string) error {
	return loadExactHex(key[:], src)
}
|
||||
299
vendor/github.com/tailscale/wireguard-go/device/peer.go
generated
vendored
Normal file
299
vendor/github.com/tailscale/wireguard-go/device/peer.go
generated
vendored
Normal file
@@ -0,0 +1,299 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
)
|
||||
|
||||
// Peer holds all per-peer state: the transport keypairs, handshake state,
// roaming endpoint, timers, and the packet queues that feed the sequential
// sender/receiver goroutines.
type Peer struct {
	isRunning         atomic.Bool
	keypairs          Keypairs
	handshake         Handshake
	device            *Device
	stopping          sync.WaitGroup // routines pending stop
	txBytes           atomic.Uint64  // bytes send to peer (endpoint)
	rxBytes           atomic.Uint64  // bytes received from peer
	lastHandshakeNano atomic.Int64   // nano seconds since epoch

	// endpoint is the peer's last-known network address; guarded by its
	// embedded mutex because roaming can update it from the receive path.
	endpoint struct {
		sync.Mutex
		val            conn.Endpoint
		clearSrcOnTx   bool // signal to val.ClearSrc() prior to next packet transmission
		disableRoaming bool
	}

	timers struct {
		retransmitHandshake     *Timer
		sendKeepalive           *Timer
		newHandshake            *Timer
		zeroKeyMaterial         *Timer
		persistentKeepalive     *Timer
		handshakeAttempts       atomic.Uint32
		needAnotherKeepalive    atomic.Bool
		sentLastMinuteHandshake atomic.Bool
	}

	state struct {
		sync.Mutex // protects against concurrent Start/Stop
	}

	queue struct {
		staged   chan *QueueOutboundElementsContainer // staged packets before a handshake is available
		outbound *autodrainingOutboundQueue           // sequential ordering of udp transmission
		inbound  *autodrainingInboundQueue            // sequential ordering of tun writing
	}

	cookieGenerator             CookieGenerator
	trieEntries                 list.List
	persistentKeepaliveInterval atomic.Uint32
}
|
||||
|
||||
// NewPeer creates a peer for the given public key, registers it in the
// device's key map, and pre-computes the static-static DH result used by
// the handshake. It fails if the device is closed, the peer limit is
// reached, or a peer with this key already exists.
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
	if device.isClosed() {
		return nil, errors.New("device closed")
	}

	// lock resources
	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	device.peers.Lock()
	defer device.peers.Unlock()

	// check if over limit
	if len(device.peers.keyMap) >= MaxPeers {
		return nil, errors.New("too many peers")
	}

	// create peer
	peer := new(Peer)

	peer.cookieGenerator.Init(pk)
	peer.device = device
	peer.queue.outbound = newAutodrainingOutboundQueue(device)
	peer.queue.inbound = newAutodrainingInboundQueue(device)
	peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)

	// map public key — must not already exist
	_, ok := device.peers.keyMap[pk]
	if ok {
		return nil, errors.New("adding existing peer")
	}

	// pre-compute DH (static-static shared secret, reused by every handshake)
	handshake := &peer.handshake
	handshake.mutex.Lock()
	handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
	handshake.remoteStatic = pk
	handshake.mutex.Unlock()

	// reset endpoint
	peer.endpoint.Lock()
	peer.endpoint.val = nil
	peer.endpoint.disableRoaming = false
	peer.endpoint.clearSrcOnTx = false
	peer.endpoint.Unlock()

	// init timers
	peer.timersInit()

	// add to the device's key map
	device.peers.keyMap[pk] = peer

	return peer, nil
}
|
||||
|
||||
// SendBuffers transmits the given datagrams to the peer's current endpoint
// and accounts the bytes sent. It is a silent no-op if the device is
// closed, and fails if no endpoint is known for the peer.
func (peer *Peer) SendBuffers(buffers [][]byte) error {
	peer.device.net.RLock()
	defer peer.device.net.RUnlock()

	if peer.device.isClosed() {
		return nil
	}

	peer.endpoint.Lock()
	endpoint := peer.endpoint.val
	if endpoint == nil {
		peer.endpoint.Unlock()
		return errors.New("no known endpoint for peer")
	}
	// honor a pending request to forget the cached source address
	if peer.endpoint.clearSrcOnTx {
		endpoint.ClearSrc()
		peer.endpoint.clearSrcOnTx = false
	}
	peer.endpoint.Unlock()

	err := peer.device.net.bind.Send(buffers, endpoint)
	if err == nil {
		// only count bytes that were actually handed to the bind
		var totalLen uint64
		for _, b := range buffers {
			totalLen += uint64(len(b))
		}
		peer.txBytes.Add(totalLen)
	}
	return err
}
|
||||
|
||||
// String returns a short, stable identifier for the peer of the form
// "peer(AbCd…WxYz)", built from the first and last four base64 characters
// of the peer's static public key.
func (peer *Peer) String() string {
	// The awful goo that follows is identical to:
	//
	//   base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
	//   abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
	//   return fmt.Sprintf("peer(%s)", abbreviatedKey)
	//
	// except that it is considerably more efficient.
	src := peer.handshake.remoteStatic
	// b64 maps a 6-bit value to its standard base64 alphabet character
	// using branchless arithmetic (each masked term selects the offset
	// for one alphabet range).
	b64 := func(input byte) byte {
		return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
	}
	b := []byte("peer(____…____)")
	const first = len("peer(")
	const second = len("peer(____…")
	// first four base64 chars: bits from src[0..2]
	b[first+0] = b64((src[0] >> 2) & 63)
	b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
	b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
	b[first+3] = b64(src[2] & 63)
	// last four base64 chars: bits from src[29..31]
	b[second+0] = b64(src[29] & 63)
	b[second+1] = b64((src[30] >> 2) & 63)
	b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
	b[second+3] = b64((src[31] << 2) & 63)
	return string(b)
}
|
||||
|
||||
// Start launches the peer's sequential sender and receiver goroutines and
// arms its timers. It is idempotent: calling Start on a running peer, or on
// a peer whose device is closed, is a no-op.
func (peer *Peer) Start() {
	// should never start a peer on a closed device
	if peer.device.isClosed() {
		return
	}

	// prevent simultaneous start/stop operations
	peer.state.Lock()
	defer peer.state.Unlock()

	if peer.isRunning.Load() {
		return
	}

	device := peer.device
	device.log.Verbosef("%v - Starting", peer)

	// reset routine state: wait for any previous Stop to finish, then
	// account for the two goroutines started below
	peer.stopping.Wait()
	peer.stopping.Add(2)

	// backdate lastSentHandshake so a new handshake may be sent immediately
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	peer.handshake.mutex.Unlock()

	peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes

	peer.timersStart()

	device.flushInboundQueue(peer.queue.inbound)
	device.flushOutboundQueue(peer.queue.outbound)

	// Use the device batch size, not the bind batch size, as the device size is
	// the size of the batch pools.
	batchSize := peer.device.BatchSize()
	go peer.RoutineSequentialSender(batchSize)
	go peer.RoutineSequentialReceiver(batchSize)

	peer.isRunning.Store(true)
}
|
||||
|
||||
// ZeroAndFlushAll wipes all of the peer's key material (transport keypairs
// and handshake state) and drops any packets staged for transmission.
func (peer *Peer) ZeroAndFlushAll() {
	device := peer.device

	// clear key pairs

	keypairs := &peer.keypairs
	keypairs.Lock()
	device.DeleteKeypair(keypairs.previous)
	device.DeleteKeypair(keypairs.current)
	device.DeleteKeypair(keypairs.next.Load())
	keypairs.previous = nil
	keypairs.current = nil
	keypairs.next.Store(nil)
	keypairs.Unlock()

	// clear handshake state (and release its index table entry)

	handshake := &peer.handshake
	handshake.mutex.Lock()
	device.indexTable.Delete(handshake.localIndex)
	handshake.Clear()
	handshake.mutex.Unlock()

	peer.FlushStagedPackets()
}
|
||||
|
||||
// ExpireCurrentKeypairs clears the handshake and forces the current and
// staged keypairs to be treated as exhausted (by setting their send nonce
// to RejectAfterMessages), so that the next outbound packet triggers a
// fresh handshake.
func (peer *Peer) ExpireCurrentKeypairs() {
	handshake := &peer.handshake
	handshake.mutex.Lock()
	peer.device.indexTable.Delete(handshake.localIndex)
	handshake.Clear()
	// backdate lastSentHandshake so a new handshake may be sent immediately
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	handshake.mutex.Unlock()

	keypairs := &peer.keypairs
	keypairs.Lock()
	if keypairs.current != nil {
		keypairs.current.sendNonce.Store(RejectAfterMessages)
	}
	if next := keypairs.next.Load(); next != nil {
		next.sendNonce.Store(RejectAfterMessages)
	}
	keypairs.Unlock()
}
|
||||
|
||||
// Stop halts the peer's goroutines and timers, waits for them to exit,
// and wipes all key material. It is idempotent: stopping a peer that is
// not running is a no-op.
func (peer *Peer) Stop() {
	peer.state.Lock()
	defer peer.state.Unlock()

	// Swap returns the previous value; if it was already false we are done
	if !peer.isRunning.Swap(false) {
		return
	}

	peer.device.log.Verbosef("%v - Stopping", peer)

	peer.timersStop()
	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
	peer.queue.inbound.c <- nil
	peer.queue.outbound.c <- nil
	peer.stopping.Wait()
	peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us

	peer.ZeroAndFlushAll()
}
|
||||
|
||||
// SetEndpointFromPacket updates the peer's endpoint from the source of a
// received packet (roaming). It is a no-op when roaming is disabled for
// this peer, and clears any pending clear-source request.
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
	peer.endpoint.Lock()
	defer peer.endpoint.Unlock()
	if peer.endpoint.disableRoaming {
		return
	}
	peer.endpoint.clearSrcOnTx = false
	// let peer-aware binds substitute their own per-peer endpoint
	if ep, ok := endpoint.(conn.PeerAwareEndpoint); ok {
		endpoint = ep.GetPeerEndpoint(peer.handshake.remoteStatic)
	}
	peer.endpoint.val = endpoint
}
|
||||
|
||||
func (peer *Peer) markEndpointSrcForClearing() {
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
if peer.endpoint.val == nil {
|
||||
return
|
||||
}
|
||||
peer.endpoint.clearSrcOnTx = true
|
||||
}
|
||||
121
vendor/github.com/tailscale/wireguard-go/device/pools.go
generated
vendored
Normal file
121
vendor/github.com/tailscale/wireguard-go/device/pools.go
generated
vendored
Normal file
@@ -0,0 +1,121 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// WaitPool is a sync.Pool with an optional upper bound: when max is
// non-zero, Get blocks until fewer than max objects are checked out.
type WaitPool struct {
	pool  sync.Pool
	cond  sync.Cond
	lock  sync.Mutex // guards count; also the Locker behind cond
	count uint32     // Get calls not yet Put back
	max   uint32     // 0 means unbounded
}
|
||||
|
||||
// NewWaitPool returns a WaitPool that allocates with new and allows at
// most max outstanding objects (0 disables the bound).
func NewWaitPool(max uint32, new func() any) *WaitPool {
	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
	p.cond = sync.Cond{L: &p.lock}
	return p
}
|
||||
|
||||
// Get takes an object from the pool, blocking while the number of
// outstanding objects is at the configured maximum (if any).
func (p *WaitPool) Get() any {
	if p.max != 0 {
		p.lock.Lock()
		for p.count >= p.max {
			p.cond.Wait() // released by Put
		}
		p.count++
		p.lock.Unlock()
	}
	return p.pool.Get()
}
|
||||
|
||||
// Put returns an object to the pool and, when bounded, wakes one waiter
// blocked in Get.
func (p *WaitPool) Put(x any) {
	p.pool.Put(x)
	if p.max == 0 {
		return
	}
	p.lock.Lock()
	defer p.lock.Unlock()
	p.count--
	p.cond.Signal()
}
|
||||
|
||||
// PopulatePools initializes the device's object pools for message buffers
// and inbound/outbound queue elements and their batch containers. Each
// pool is bounded by PreallocatedBuffersPerPool (0 = unbounded).
func (device *Device) PopulatePools() {
	device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		s := make([]*QueueInboundElement, 0, device.BatchSize())
		return &QueueInboundElementsContainer{elems: s}
	})
	device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		s := make([]*QueueOutboundElement, 0, device.BatchSize())
		return &QueueOutboundElementsContainer{elems: s}
	})
	device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		return new([MaxMessageSize]byte)
	})
	device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		return new(QueueInboundElement)
	})
	device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		return new(QueueOutboundElement)
	})
}
|
||||
|
||||
// GetInboundElementsContainer takes a container from the pool with a
// freshly reset mutex (the previous user may have left it locked).
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
	c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
	c.Mutex = sync.Mutex{}
	return c
}
|
||||
|
||||
// PutInboundElementsContainer returns a container to the pool, nilling out
// element pointers first so the pool does not keep them alive.
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
	for i := range c.elems {
		c.elems[i] = nil
	}
	c.elems = c.elems[:0] // keep the backing array's capacity
	device.pool.inboundElementsContainer.Put(c)
}
|
||||
|
||||
// GetOutboundElementsContainer takes a container from the pool with a
// freshly reset mutex (the previous user may have left it locked).
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
	c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
	c.Mutex = sync.Mutex{}
	return c
}
|
||||
|
||||
// PutOutboundElementsContainer returns a container to the pool, nilling
// out element pointers first so the pool does not keep them alive.
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
	for i := range c.elems {
		c.elems[i] = nil
	}
	c.elems = c.elems[:0] // keep the backing array's capacity
	device.pool.outboundElementsContainer.Put(c)
}
|
||||
|
||||
// GetMessageBuffer takes a MaxMessageSize scratch buffer from the pool.
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
	return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
|
||||
|
||||
// PutMessageBuffer returns a scratch buffer to the pool.
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
	device.pool.messageBuffers.Put(msg)
}
|
||||
|
||||
// GetInboundElement takes a queue element from the pool.
func (device *Device) GetInboundElement() *QueueInboundElement {
	return device.pool.inboundElements.Get().(*QueueInboundElement)
}
|
||||
|
||||
// PutInboundElement clears the element's pointer fields and returns it to
// the pool (see clearPointers for why).
func (device *Device) PutInboundElement(elem *QueueInboundElement) {
	elem.clearPointers()
	device.pool.inboundElements.Put(elem)
}
|
||||
|
||||
// GetOutboundElement takes a queue element from the pool.
func (device *Device) GetOutboundElement() *QueueOutboundElement {
	return device.pool.outboundElements.Get().(*QueueOutboundElement)
}
|
||||
|
||||
// PutOutboundElement clears the element's pointer fields and returns it to
// the pool.
func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
	elem.clearPointers()
	device.pool.outboundElements.Put(elem)
}
|
||||
19
vendor/github.com/tailscale/wireguard-go/device/queueconstants_android.go
generated
vendored
Normal file
19
vendor/github.com/tailscale/wireguard-go/device/queueconstants_android.go
generated
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "github.com/tailscale/wireguard-go/conn"
|
||||
|
||||
/* Reduce memory consumption for Android */
|
||||
|
||||
// Queue and buffer sizing for Android builds (smaller MaxSegmentSize and a
// bounded buffer pool to reduce memory consumption).
const (
	QueueStagedSize            = conn.IdealBatchSize
	QueueOutboundSize          = 1024
	QueueInboundSize           = 1024
	QueueHandshakeSize         = 1024
	MaxSegmentSize             = 2200
	PreallocatedBuffersPerPool = 4096
)
|
||||
19
vendor/github.com/tailscale/wireguard-go/device/queueconstants_default.go
generated
vendored
Normal file
19
vendor/github.com/tailscale/wireguard-go/device/queueconstants_default.go
generated
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build !android && !ios && !windows
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "github.com/tailscale/wireguard-go/conn"
|
||||
|
||||
// Queue and buffer sizing for all platforms without a specialized file.
const (
	QueueStagedSize            = conn.IdealBatchSize
	QueueOutboundSize          = 1024
	QueueInboundSize           = 1024
	QueueHandshakeSize         = 1024
	MaxSegmentSize             = (1 << 16) - 1 // largest possible UDP datagram
	PreallocatedBuffersPerPool = 0             // Disable and allow for infinite memory growth
)
|
||||
21
vendor/github.com/tailscale/wireguard-go/device/queueconstants_ios.go
generated
vendored
Normal file
21
vendor/github.com/tailscale/wireguard-go/device/queueconstants_ios.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build ios
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
|
||||
// These are vars instead of consts, because heavier network extensions might want to reduce
|
||||
// them further.
|
||||
// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
// These are vars instead of consts, because heavier network extensions might want to reduce
// them further.
var (
	QueueStagedSize                   = 128
	QueueOutboundSize                 = 1024
	QueueInboundSize                  = 1024
	QueueHandshakeSize                = 1024
	PreallocatedBuffersPerPool uint32 = 1024
)

// MaxSegmentSize is kept small on iOS to bound per-buffer memory use.
const MaxSegmentSize = 1700
|
||||
15
vendor/github.com/tailscale/wireguard-go/device/queueconstants_windows.go
generated
vendored
Normal file
15
vendor/github.com/tailscale/wireguard-go/device/queueconstants_windows.go
generated
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
// Queue and buffer sizing for Windows builds.
const (
	QueueStagedSize            = 128
	QueueOutboundSize          = 1024
	QueueInboundSize           = 1024
	QueueHandshakeSize         = 1024
	MaxSegmentSize             = 2048 - 32 // NOTE(review): smaller than the 64 KiB UDP maximum; the inherited "largest possible UDP datagram" comment did not match this value
	PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
)
|
||||
537
vendor/github.com/tailscale/wireguard-go/device/receive.go
generated
vendored
Normal file
537
vendor/github.com/tailscale/wireguard-go/device/receive.go
generated
vendored
Normal file
@@ -0,0 +1,537 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
// QueueHandshakeElement carries one received handshake-related packet
// (initiation, response, or cookie reply) to the handshake workers.
type QueueHandshakeElement struct {
	msgType  uint32
	packet   []byte
	endpoint conn.Endpoint
	buffer   *[MaxMessageSize]byte // backing buffer; returned to the pool after processing
}
|
||||
|
||||
// QueueInboundElement carries one received transport packet through
// decryption and on to the sequential receiver.
type QueueInboundElement struct {
	buffer   *[MaxMessageSize]byte // backing buffer; packet aliases into it
	packet   []byte
	counter  uint64 // nonce counter extracted from the packet header
	keypair  *Keypair
	endpoint conn.Endpoint
}
|
||||
|
||||
// QueueInboundElementsContainer is a batch of inbound elements; its mutex
// is held while the batch is being decrypted and released when done.
type QueueInboundElementsContainer struct {
	sync.Mutex
	elems []*QueueInboundElement
}
|
||||
|
||||
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
// It also reduces the possible collateral damage from use-after-free bugs.
func (elem *QueueInboundElement) clearPointers() {
	elem.buffer = nil
	elem.packet = nil
	elem.keypair = nil
	elem.endpoint = nil
}
|
||||
|
||||
/* keepKeyFreshReceiving is called when a new authenticated message has been
 * received. If we initiated the current keypair and it is approaching its
 * lifetime (RejectAfterTime minus keepalive/rekey slack), it starts a new
 * handshake — at most once per keypair, gated by sentLastMinuteHandshake.
 *
 * NOTE: Not thread safe, but called by sequential receiver!
 */
func (peer *Peer) keepKeyFreshReceiving() {
	if peer.timers.sentLastMinuteHandshake.Load() {
		return
	}
	keypair := peer.keypairs.Current()
	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
		peer.timers.sentLastMinuteHandshake.Store(true)
		peer.SendHandshakeInitiation(false)
	}
}
|
||||
|
||||
/* RoutineReceiveIncoming receives incoming datagrams for the device.
 *
 * Every time the bind is updated a new routine is started for
 * IPv4 and IPv6 (separately). Transport packets are batched per peer and
 * handed to the decryption workers; handshake packets go to the handshake
 * queue (best-effort: dropped if that queue is full).
 */
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
	recvName := recv.PrettyName()
	defer func() {
		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
		device.queue.decryption.wg.Done()
		device.queue.handshake.wg.Done()
		device.net.stopping.Done()
	}()

	device.log.Verbosef("Routine: receive incoming %s - started", recvName)

	// receive datagrams until conn is closed

	var (
		bufsArrs  = make([]*[MaxMessageSize]byte, maxBatchSize)
		bufs      = make([][]byte, maxBatchSize)
		err       error
		sizes     = make([]int, maxBatchSize)
		count     int
		endpoints = make([]conn.Endpoint, maxBatchSize)
		// deathSpiral counts consecutive receive errors; after 10 the
		// routine gives up rather than spinning forever
		deathSpiral int
		elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
	)

	for i := range bufsArrs {
		bufsArrs[i] = device.GetMessageBuffer()
		bufs[i] = bufsArrs[i][:]
	}

	defer func() {
		// return any buffers still owned by this routine to the pool
		for i := 0; i < maxBatchSize; i++ {
			if bufsArrs[i] != nil {
				device.PutMessageBuffer(bufsArrs[i])
			}
		}
	}()

	for {
		count, err = recv(bufs, sizes, endpoints)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
				return
			}
			if deathSpiral < 10 {
				deathSpiral++
				time.Sleep(time.Second / 3)
				continue
			}
			return
		}
		deathSpiral = 0

		// handle each packet in the batch
		for i, size := range sizes[:count] {
			if size < MinMessageSize {
				continue
			}

			// check size of packet

			packet := bufsArrs[i][:size]
			msgType := binary.LittleEndian.Uint32(packet[:4])

			switch msgType {

			// check if transport

			case MessageTransportType:

				// check size

				if len(packet) < MessageTransportSize {
					continue
				}

				// lookup key pair by the receiver index in the header

				receiver := binary.LittleEndian.Uint32(
					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
				)
				value := device.indexTable.Lookup(receiver)
				keypair := value.keypair
				if keypair == nil {
					continue
				}

				// check keypair expiry

				if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
					continue
				}

				// create work element; the buffer's ownership moves to elem
				peer := value.peer
				elem := device.GetInboundElement()
				elem.packet = packet
				elem.buffer = bufsArrs[i]
				elem.keypair = keypair
				elem.endpoint = endpoints[i]
				elem.counter = 0

				elemsForPeer, ok := elemsByPeer[peer]
				if !ok {
					elemsForPeer = device.GetInboundElementsContainer()
					// locked here; unlocked by the decryption worker when done
					elemsForPeer.Lock()
					elemsByPeer[peer] = elemsForPeer
				}
				elemsForPeer.elems = append(elemsForPeer.elems, elem)
				// replace the buffer we just handed off
				bufsArrs[i] = device.GetMessageBuffer()
				bufs[i] = bufsArrs[i][:]
				continue

			// otherwise it is a fixed size & handshake related packet

			case MessageInitiationType:
				if len(packet) != MessageInitiationSize {
					continue
				}

			case MessageResponseType:
				if len(packet) != MessageResponseSize {
					continue
				}

			case MessageCookieReplyType:
				if len(packet) != MessageCookieReplySize {
					continue
				}

			default:
				device.log.Verbosef("Received message with unknown type")
				continue
			}

			// non-blocking send: handshake packets are dropped under load
			select {
			case device.queue.handshake.c <- QueueHandshakeElement{
				msgType:  msgType,
				buffer:   bufsArrs[i],
				packet:   packet,
				endpoint: endpoints[i],
			}:
				bufsArrs[i] = device.GetMessageBuffer()
				bufs[i] = bufsArrs[i][:]
			default:
			}
		}
		// dispatch the per-peer batches assembled above
		for peer, elemsContainer := range elemsByPeer {
			if peer.isRunning.Load() {
				// same container goes to both queues: the sequential
				// receiver waits on the lock the decryption worker releases
				peer.queue.inbound.c <- elemsContainer
				device.queue.decryption.c <- elemsContainer
			} else {
				// peer stopped: recycle everything instead of dispatching
				for _, elem := range elemsContainer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutInboundElement(elem)
				}
				device.PutInboundElementsContainer(elemsContainer)
			}
			delete(elemsByPeer, peer)
		}
	}
}
|
||||
|
||||
// RoutineDecryption is a worker that decrypts batches of transport packets
// in place. Decryption failures are recorded by nilling elem.packet (the
// sequential receiver handles rejection); the container's lock is released
// once the whole batch is processed, unblocking the sequential receiver.
func (device *Device) RoutineDecryption(id int) {
	var nonce [chacha20poly1305.NonceSize]byte

	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
	device.log.Verbosef("Routine: decryption worker %d - started", id)

	for elemsContainer := range device.queue.decryption.c {
		for _, elem := range elemsContainer.elems {
			// split message into fields
			counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
			content := elem.packet[MessageTransportOffsetContent:]

			// decrypt and release to consumer
			var err error
			elem.counter = binary.LittleEndian.Uint64(counter)
			// copy counter to nonce (bytes 4..11; first 4 bytes stay zero)
			binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
			// decrypt in place: output written over the ciphertext
			elem.packet, err = elem.keypair.receive.Open(
				content[:0],
				nonce[:],
				content,
				nil,
			)
			if err != nil {
				elem.packet = nil
			}
		}
		elemsContainer.Unlock()
	}
}
|
||||
|
||||
/* Handles incoming packets related to handshake
|
||||
*/
|
||||
/* Handles incoming packets related to handshake
 */
// RoutineHandshake is a parallel worker that consumes queued handshake
// packets: cookie replies, handshake initiations, and handshake responses.
// Every path through the loop ends at the "skip" label, which recycles the
// element's message buffer.
func (device *Device) RoutineHandshake(id int) {
	defer func() {
		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
		device.queue.encryption.wg.Done()
	}()
	device.log.Verbosef("Routine: handshake worker %d - started", id)

	for elem := range device.queue.handshake.c {

		// handle cookie fields and ratelimiting

		switch elem.msgType {

		case MessageCookieReplyType:

			// unmarshal packet

			var reply MessageCookieReply
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &reply)
			if err != nil {
				device.log.Verbosef("Failed to decode cookie reply")
				goto skip
			}

			// lookup peer from index

			entry := device.indexTable.Lookup(reply.Receiver)

			if entry.peer == nil {
				goto skip
			}

			// consume reply

			if peer := entry.peer; peer.isRunning.Load() {
				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
				if !peer.cookieGenerator.ConsumeReply(&reply) {
					device.log.Verbosef("Could not decrypt invalid cookie response")
				}
			}

			// Cookie replies carry no handshake payload; nothing more to do.
			goto skip

		case MessageInitiationType, MessageResponseType:

			// check mac fields and maybe ratelimit

			if !device.cookieChecker.CheckMAC1(elem.packet) {
				device.log.Verbosef("Received packet with invalid mac1")
				goto skip
			}

			// endpoints destination address is the source of the datagram

			if device.IsUnderLoad() {

				// verify MAC2 field

				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
					// Under load without a valid mac2: answer with a cookie
					// instead of doing expensive handshake crypto.
					device.SendHandshakeCookie(&elem)
					goto skip
				}

				// check ratelimiter

				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
					goto skip
				}
			}

		default:
			device.log.Errorf("Invalid packet ended up in the handshake queue")
			goto skip
		}

		// handle handshake initiation/response content

		switch elem.msgType {
		case MessageInitiationType:

			// unmarshal

			var msg MessageInitiation
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode initiation message")
				goto skip
			}

			// consume initiation

			peer := device.ConsumeMessageInitiation(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake initiation", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			peer.SendHandshakeResponse()

		case MessageResponseType:

			// unmarshal

			var msg MessageResponse
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode response message")
				goto skip
			}

			// consume response

			peer := device.ConsumeMessageResponse(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake response", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// derive keypair

			err = peer.BeginSymmetricSession()

			if err != nil {
				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
				goto skip
			}

			peer.timersSessionDerived()
			peer.timersHandshakeComplete()
			// The keepalive confirms the new session to the initiator.
			peer.SendKeepalive()
		}
	skip:
		// Common exit: every element's buffer is returned to the pool.
		device.PutMessageBuffer(elem.buffer)
	}
}
|
||||
|
||||
// RoutineSequentialReceiver consumes decrypted inbound batches for one
// peer in order. For each element it validates the replay counter,
// sanity-checks the inner IPv4/IPv6 header (length and allowed source
// address), and writes surviving packets to the TUN device in a single
// batched call. All buffers and elements are recycled at the end of each
// batch. A nil container on the channel signals shutdown.
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
	device := peer.device
	defer func() {
		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
		peer.stopping.Done()
	}()
	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)

	bufs := make([][]byte, 0, maxBatchSize)

	for elemsContainer := range peer.queue.inbound.c {
		if elemsContainer == nil {
			return
		}
		// Blocks until the decryption workers have released this batch.
		elemsContainer.Lock()
		validTailPacket := -1
		dataPacketReceived := false
		for i, elem := range elemsContainer.elems {
			if elem.packet == nil {
				// decryption failed
				continue
			}

			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
				continue
			}

			// Remember the last authenticated packet in the batch;
			// its endpoint/timers are applied once after the loop.
			validTailPacket = i
			if peer.ReceivedWithKeypair(elem.keypair) {
				peer.SetEndpointFromPacket(elem.endpoint)
				peer.timersHandshakeComplete()
				peer.SendStagedPackets()
			}
			peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))

			if len(elem.packet) == 0 {
				device.log.Verbosef("%v - Receiving keepalive packet", peer)
				continue
			}
			dataPacketReceived = true

			// Dispatch on the IP version nibble of the inner packet.
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
				length := binary.BigEndian.Uint16(field)
				if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
					continue
				}
				// Trim decryption padding down to the IP total length.
				elem.packet = elem.packet[:length]
				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
				if device.allowedips.Lookup(src) != peer {
					device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
					continue
				}

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
				length := binary.BigEndian.Uint16(field)
				length += ipv6.HeaderLen
				if int(length) > len(elem.packet) {
					continue
				}
				elem.packet = elem.packet[:length]
				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
				if device.allowedips.Lookup(src) != peer {
					device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
					continue
				}

			default:
				device.log.Verbosef("Packet with invalid IP version from %v", peer)
				continue
			}

			bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
		}
		if validTailPacket >= 0 {
			peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
			peer.keepKeyFreshReceiving()
			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()
		}
		if dataPacketReceived {
			peer.timersDataReceived()
		}
		if len(bufs) > 0 {
			_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
			if err != nil && !device.isClosed() {
				device.log.Errorf("Failed to write packets to TUN device: %v", err)
			}
		}
		// Recycle the batch regardless of individual packet outcomes.
		for _, elem := range elemsContainer.elems {
			device.PutMessageBuffer(elem.buffer)
			device.PutInboundElement(elem)
		}
		bufs = bufs[:0]
		device.PutInboundElementsContainer(elemsContainer)
	}
}
|
||||
547
vendor/github.com/tailscale/wireguard-go/device/send.go
generated
vendored
Normal file
547
vendor/github.com/tailscale/wireguard-go/device/send.go
generated
vendored
Normal file
@@ -0,0 +1,547 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
/* Outbound flow
|
||||
*
|
||||
* 1. TUN queue
|
||||
* 2. Routing (sequential)
|
||||
* 3. Nonce assignment (sequential)
|
||||
* 4. Encryption (parallel)
|
||||
* 5. Transmission (sequential)
|
||||
*
|
||||
* The functions in this file occur (roughly) in the order in
|
||||
* which the packets are processed.
|
||||
*
|
||||
* Locking, Producers and Consumers
|
||||
*
|
||||
* The order of packets (per peer) must be maintained,
|
||||
* but encryption of packets happen out-of-order:
|
||||
*
|
||||
* The sequential consumers will attempt to take the lock,
|
||||
* workers release lock when they have completed work (encryption) on the packet.
|
||||
*
|
||||
* If the element is inserted into the "encryption queue",
|
||||
* the content is preceded by enough "junk" to contain the transport header
|
||||
* (to allow the construction of transport messages in-place)
|
||||
*/
|
||||
|
||||
// QueueOutboundElement is a single outbound packet moving through the
// send pipeline: filled by the TUN reader, nonce-assigned by
// SendStagedPackets, encrypted by RoutineEncryption, and transmitted by
// RoutineSequentialSender.
type QueueOutboundElement struct {
	buffer  *[MaxMessageSize]byte // slice holding the packet data
	packet  []byte                // slice of "buffer" (always!)
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}
|
||||
|
||||
// QueueOutboundElementsContainer is a batch of outbound elements passed
// through the queues as a unit. The embedded mutex is held while the batch
// is being encrypted and released when it is ready for sequential sending.
type QueueOutboundElementsContainer struct {
	sync.Mutex
	elems []*QueueOutboundElement
}
|
||||
|
||||
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
||||
elem := device.GetOutboundElement()
|
||||
elem.buffer = device.GetMessageBuffer()
|
||||
elem.nonce = 0
|
||||
// keypair and peer were cleared (if necessary) by clearPointers.
|
||||
return elem
|
||||
}
|
||||
|
||||
// clearPointers clears elem fields that contain pointers.
|
||||
// This makes the garbage collector's life easier and
|
||||
// avoids accidentally keeping other objects around unnecessarily.
|
||||
// It also reduces the possible collateral damage from use-after-free bugs.
|
||||
func (elem *QueueOutboundElement) clearPointers() {
|
||||
elem.buffer = nil
|
||||
elem.packet = nil
|
||||
elem.keypair = nil
|
||||
elem.peer = nil
|
||||
}
|
||||
|
||||
/* Queues a keepalive if no packets are queued for peer
|
||||
*/
|
||||
func (peer *Peer) SendKeepalive() {
|
||||
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
||||
elem := peer.device.NewOutboundElement()
|
||||
elemsContainer := peer.device.GetOutboundElementsContainer()
|
||||
elemsContainer.elems = append(elemsContainer.elems, elem)
|
||||
select {
|
||||
case peer.queue.staged <- elemsContainer:
|
||||
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
||||
default:
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
}
|
||||
}
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
|
||||
// SendHandshakeInitiation creates, MACs, and transmits a handshake
// initiation message, rate-limited to at most one per RekeyTimeout.
// isRetry distinguishes timer-driven retries (which must not reset the
// attempt counter) from fresh attempts.
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
	if !isRetry {
		peer.timers.handshakeAttempts.Store(0)
	}

	// Double-checked rate limit: a cheap read-locked check first, then the
	// authoritative check under the write lock before updating the stamp.
	peer.handshake.mutex.RLock()
	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
		peer.handshake.mutex.RUnlock()
		return nil
	}
	peer.handshake.mutex.RUnlock()

	peer.handshake.mutex.Lock()
	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
		peer.handshake.mutex.Unlock()
		return nil
	}
	peer.handshake.lastSentHandshake = time.Now()
	peer.handshake.mutex.Unlock()

	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)

	msg, err := peer.device.CreateMessageInitiation(peer)
	if err != nil {
		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
		return err
	}

	// Serialize into a stack buffer, then append the mac1/mac2 fields.
	var buf [MessageInitiationSize]byte
	writer := bytes.NewBuffer(buf[:0])
	binary.Write(writer, binary.LittleEndian, msg)
	packet := writer.Bytes()
	peer.cookieGenerator.AddMacs(packet)

	peer.timersAnyAuthenticatedPacketTraversal()
	peer.timersAnyAuthenticatedPacketSent()

	err = peer.SendBuffers([][]byte{packet})
	if err != nil {
		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
	}
	// Arm the retransmission timer even on send failure.
	peer.timersHandshakeInitiated()

	return err
}
|
||||
|
||||
// SendHandshakeResponse creates, MACs, and transmits a handshake response
// to a previously consumed initiation, deriving the new symmetric session
// keys before the response is sent.
func (peer *Peer) SendHandshakeResponse() error {
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now()
	peer.handshake.mutex.Unlock()

	peer.device.log.Verbosef("%v - Sending handshake response", peer)

	response, err := peer.device.CreateMessageResponse(peer)
	if err != nil {
		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
		return err
	}

	// Serialize into a stack buffer, then append the mac1/mac2 fields.
	var buf [MessageResponseSize]byte
	writer := bytes.NewBuffer(buf[:0])
	binary.Write(writer, binary.LittleEndian, response)
	packet := writer.Bytes()
	peer.cookieGenerator.AddMacs(packet)

	// Derive the transport keypair before sending, so data received
	// immediately after the response can be decrypted.
	err = peer.BeginSymmetricSession()
	if err != nil {
		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
		return err
	}

	peer.timersSessionDerived()
	peer.timersAnyAuthenticatedPacketTraversal()
	peer.timersAnyAuthenticatedPacketSent()

	// TODO: allocation could be avoided
	err = peer.SendBuffers([][]byte{packet})
	if err != nil {
		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
	}
	return err
}
|
||||
|
||||
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
|
||||
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
|
||||
|
||||
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
|
||||
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
|
||||
if err != nil {
|
||||
device.log.Errorf("Failed to create cookie reply: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var buf [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
// TODO: allocation could be avoided
|
||||
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *Peer) keepKeyFreshSending() {
|
||||
keypair := peer.keypairs.Current()
|
||||
if keypair == nil {
|
||||
return
|
||||
}
|
||||
nonce := keypair.sendNonce.Load()
|
||||
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
||||
peer.SendHandshakeInitiation(false)
|
||||
}
|
||||
}
|
||||
|
||||
// RoutineReadFromTUN reads batches of outbound packets from the TUN
// device, routes each packet to a peer via the allowed-IPs table on its
// destination address, and stages the per-peer batches for encryption.
// Buffers handed to a peer are immediately replaced with fresh pooled
// elements so the read slots are always valid.
func (device *Device) RoutineReadFromTUN() {
	defer func() {
		device.log.Verbosef("Routine: TUN reader - stopped")
		device.state.stopping.Done()
		device.queue.encryption.wg.Done()
	}()

	device.log.Verbosef("Routine: TUN reader - started")

	var (
		batchSize   = device.BatchSize()
		readErr     error
		elems       = make([]*QueueOutboundElement, batchSize)
		bufs        = make([][]byte, batchSize)
		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
		count       = 0
		sizes       = make([]int, batchSize)
		offset      = MessageTransportHeaderSize // leave room for the transport header
	)

	for i := range elems {
		elems[i] = device.NewOutboundElement()
		bufs[i] = elems[i].buffer[:]
	}

	// On exit, return any elements still owned by this routine to the pools.
	defer func() {
		for _, elem := range elems {
			if elem != nil {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
		}
	}()

	for {
		// read packets
		count, readErr = device.tun.device.Read(bufs, sizes, offset)
		for i := 0; i < count; i++ {
			if sizes[i] < 1 {
				continue
			}

			elem := elems[i]
			elem.packet = bufs[i][offset : offset+sizes[i]]

			// lookup peer
			var peer *Peer
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
				peer = device.allowedips.Lookup(dst)

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
				peer = device.allowedips.Lookup(dst)

			default:
				device.log.Verbosef("Received packet with unknown IP version")
			}

			if peer == nil {
				continue
			}
			elemsForPeer, ok := elemsByPeer[peer]
			if !ok {
				elemsForPeer = device.GetOutboundElementsContainer()
				elemsByPeer[peer] = elemsForPeer
			}
			elemsForPeer.elems = append(elemsForPeer.elems, elem)
			// The element now belongs to the peer's batch; refill the slot.
			elems[i] = device.NewOutboundElement()
			bufs[i] = elems[i].buffer[:]
		}

		for peer, elemsForPeer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.StagePackets(elemsForPeer)
				peer.SendStagedPackets()
			} else {
				// Stopped peer: recycle instead of staging.
				for _, elem := range elemsForPeer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutOutboundElement(elem)
				}
				device.PutOutboundElementsContainer(elemsForPeer)
			}
			delete(elemsByPeer, peer)
		}

		// Packets from a partial read are processed above before the error
		// is handled here.
		if readErr != nil {
			if errors.Is(readErr, tun.ErrTooManySegments) {
				// TODO: record stat for this
				// This will happen if MSS is surprisingly small (< 576)
				// coincident with reasonably high throughput.
				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
				continue
			}
			if !device.isClosed() {
				if !errors.Is(readErr, os.ErrClosed) {
					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
				}
				go device.Close()
			}
			return
		}
	}
}
|
||||
|
||||
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
|
||||
for {
|
||||
select {
|
||||
case peer.queue.staged <- elems:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case tooOld := <-peer.queue.staged:
|
||||
for _, elem := range tooOld.elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsContainer(tooOld)
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendStagedPackets drains the peer's staged queue, assigns send nonces
// from the current keypair, and hands batches to both the parallel
// encryption queue and the peer's sequential outbound queue. Elements
// whose nonce would exceed RejectAfterMessages are re-staged and a fresh
// handshake is triggered (via the goto top re-check).
func (peer *Peer) SendStagedPackets() {
top:
	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
		return
	}

	keypair := peer.keypairs.Current()
	// An exhausted or expired keypair forces a new handshake instead of
	// sending with it.
	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		peer.SendHandshakeInitiation(false)
		return
	}

	for {
		// Collects elements whose nonce overflowed the keypair's budget.
		var elemsContainerOOO *QueueOutboundElementsContainer
		select {
		case elemsContainer := <-peer.queue.staged:
			i := 0
			for _, elem := range elemsContainer.elems {
				elem.peer = peer
				// Atomically claim the next nonce.
				elem.nonce = keypair.sendNonce.Add(1) - 1
				if elem.nonce >= RejectAfterMessages {
					// Pin the counter at the limit and divert this element.
					keypair.sendNonce.Store(RejectAfterMessages)
					if elemsContainerOOO == nil {
						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
					}
					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
					continue
				} else {
					// Compact the in-budget elements to the front.
					elemsContainer.elems[i] = elem
					i++
				}

				elem.keypair = keypair
			}
			// Lock now; encryption workers unlock when the batch is done.
			elemsContainer.Lock()
			elemsContainer.elems = elemsContainer.elems[:i]

			if elemsContainerOOO != nil {
				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
			}

			if len(elemsContainer.elems) == 0 {
				peer.device.PutOutboundElementsContainer(elemsContainer)
				goto top
			}

			// add to parallel and sequential queue
			if peer.isRunning.Load() {
				peer.queue.outbound.c <- elemsContainer
				peer.device.queue.encryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					peer.device.PutMessageBuffer(elem.buffer)
					peer.device.PutOutboundElement(elem)
				}
				peer.device.PutOutboundElementsContainer(elemsContainer)
			}

			if elemsContainerOOO != nil {
				// Re-check from the top: the keypair is exhausted, so this
				// triggers a handshake for the diverted elements.
				goto top
			}
		default:
			return
		}
	}
}
|
||||
|
||||
func (peer *Peer) FlushStagedPackets() {
|
||||
for {
|
||||
select {
|
||||
case elemsContainer := <-peer.queue.staged:
|
||||
for _, elem := range elemsContainer.elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func calculatePaddingSize(packetSize, mtu int) int {
|
||||
lastUnit := packetSize
|
||||
if mtu == 0 {
|
||||
return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
|
||||
}
|
||||
if lastUnit > mtu {
|
||||
lastUnit %= mtu
|
||||
}
|
||||
paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
|
||||
if paddedSize > mtu {
|
||||
paddedSize = mtu
|
||||
}
|
||||
return paddedSize - lastUnit
|
||||
}
|
||||
|
||||
/* Encrypts the elements in the queue
 * and marks them for sequential consumption (by releasing the mutex)
 *
 * Obs. One instance per core
 */
// RoutineEncryption writes the transport header (type, receiver index,
// nonce) in front of each packet, pads the plaintext to a multiple of
// PaddingMultiple, and seals it with the batch's keypair.
func (device *Device) RoutineEncryption(id int) {
	var paddingZeros [PaddingMultiple]byte
	var nonce [chacha20poly1305.NonceSize]byte

	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
	device.log.Verbosef("Routine: encryption worker %d - started", id)

	for elemsContainer := range device.queue.encryption.c {
		for _, elem := range elemsContainer.elems {
			// populate header fields
			header := elem.buffer[:MessageTransportHeaderSize]

			fieldType := header[0:4]
			fieldReceiver := header[4:8]
			fieldNonce := header[8:16]

			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)

			// pad content to multiple of 16
			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)

			// encrypt content and release to consumer

			// AEAD nonce: 4 zero bytes then the little-endian counter.
			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
			// Sealing with header as dst appends ciphertext directly after
			// the transport header in the same buffer.
			elem.packet = elem.keypair.send.Seal(
				header,
				nonce[:],
				elem.packet,
				nil,
			)
		}
		// Unlocking signals the sequential sender that the batch is ready.
		elemsContainer.Unlock()
	}
}
|
||||
|
||||
// RoutineSequentialSender transmits encrypted batches for one peer in
// order. It waits (via the container lock) for the encryption workers to
// finish each batch, sends all packets with one SendBuffers call, updates
// timers, and recycles the batch. A nil container signals shutdown.
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
	device := peer.device
	defer func() {
		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
		peer.stopping.Done()
	}()
	device.log.Verbosef("%v - Routine: sequential sender - started", peer)

	bufs := make([][]byte, 0, maxBatchSize)

	for elemsContainer := range peer.queue.outbound.c {
		bufs = bufs[:0]
		if elemsContainer == nil {
			return
		}
		if !peer.isRunning.Load() {
			// peer has been stopped; return re-usable elems to the shared pool.
			// This is an optimization only. It is possible for the peer to be stopped
			// immediately after this check, in which case, elem will get processed.
			// The timers and SendBuffers code are resilient to a few stragglers.
			// TODO: rework peer shutdown order to ensure
			// that we never accidentally keep timers alive longer than necessary.
			elemsContainer.Lock()
			for _, elem := range elemsContainer.elems {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
			device.PutOutboundElementsContainer(elemsContainer)
			continue
		}
		dataSent := false
		// Blocks until the encryption workers release this batch.
		elemsContainer.Lock()
		for _, elem := range elemsContainer.elems {
			if len(elem.packet) != MessageKeepaliveSize {
				dataSent = true
			}
			bufs = append(bufs, elem.packet)
		}

		peer.timersAnyAuthenticatedPacketTraversal()
		peer.timersAnyAuthenticatedPacketSent()

		err := peer.SendBuffers(bufs)
		if dataSent {
			peer.timersDataSent()
		}
		for _, elem := range elemsContainer.elems {
			device.PutMessageBuffer(elem.buffer)
			device.PutOutboundElement(elem)
		}
		device.PutOutboundElementsContainer(elemsContainer)
		if err != nil {
			// A disabled-GSO error wraps the retryable underlying error.
			var errGSO conn.ErrUDPGSODisabled
			if errors.As(err, &errGSO) {
				device.log.Verbosef(err.Error())
				err = errGSO.RetryErr
			}
		}
		if err != nil {
			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
			continue
		}

		peer.keepKeyFreshSending()
	}
}
|
||||
12
vendor/github.com/tailscale/wireguard-go/device/sticky_default.go
generated
vendored
Normal file
12
vendor/github.com/tailscale/wireguard-go/device/sticky_default.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !linux

package device

import (
	"github.com/tailscale/wireguard-go/conn"
	"github.com/tailscale/wireguard-go/rwcancel"
)

// startRouteListener is a no-op on non-Linux platforms; route monitoring
// for sticky sockets is only implemented via netlink on Linux.
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
	return nil, nil
}
|
||||
224
vendor/github.com/tailscale/wireguard-go/device/sticky_linux.go
generated
vendored
Normal file
224
vendor/github.com/tailscale/wireguard-go/device/sticky_linux.go
generated
vendored
Normal file
@@ -0,0 +1,224 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This implements userspace semantics of "sticky sockets", modeled after
|
||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||
* of the sticky-sockets.c example code:
|
||||
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
||||
*
|
||||
* Currently there is no way to achieve this within the net package:
|
||||
* See e.g. https://github.com/golang/go/issues/17930
|
||||
* So this code is remains platform dependent.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
"github.com/tailscale/wireguard-go/rwcancel"
|
||||
)
|
||||
|
||||
// startRouteListener begins monitoring IPv4 route changes over a netlink
// socket so sticky source addresses can be invalidated when routing
// changes. It returns (nil, nil) for binds that do not support sticky
// sockets. The returned RWCancel cancels the listener goroutine.
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
	if !conn.StdNetSupportsStickySockets {
		return nil, nil
	}
	if _, ok := bind.(*conn.StdNetBind); !ok {
		return nil, nil
	}

	netlinkSock, err := createNetlinkRouteSocket()
	if err != nil {
		return nil, err
	}
	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
	if err != nil {
		unix.Close(netlinkSock)
		return nil, err
	}

	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)

	return netlinkCancel, nil
}
|
||||
|
||||
// routineRouteListener consumes route-change notifications from the
// netlink socket. On a route change it issues RTM_GETROUTE queries for
// each peer's current endpoint (tracked in reqPeer keyed by netlink
// sequence number); when a reply shows the route's output interface no
// longer matches a peer's sticky source interface, the cached source is
// marked for clearing on next transmit.
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
	type peerEndpointPtr struct {
		peer     *Peer
		endpoint *conn.Endpoint
	}
	// reqPeer maps a netlink sequence number to the peer/endpoint the
	// corresponding RTM_GETROUTE request was issued for.
	var reqPeer map[uint32]peerEndpointPtr
	var reqPeerLock sync.Mutex

	defer netlinkCancel.Close()
	defer unix.Close(netlinkSock)

	for msg := make([]byte, 1<<16); ; {
		var err error
		var msgn int
		// Retry reads interrupted by signals until cancelled.
		for {
			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
			if err == nil || !rwcancel.RetryAfterError(err) {
				break
			}
			if !netlinkCancel.ReadyRead() {
				return
			}
		}
		if err != nil {
			return
		}

		// Walk every netlink message in the datagram.
		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {

			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))

			if uint(hdr.Len) > uint(len(remain)) {
				break
			}

			switch hdr.Type {
			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
					// This is a reply to one of our own RTM_GETROUTE
					// requests (identified by sequence number).
					if uint(len(remain)) < uint(hdr.Len) {
						break
					}
					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
						// Scan route attributes for the output interface.
						for {
							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
								break
							}
							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
								break
							}
							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
								reqPeerLock.Lock()
								if reqPeer == nil {
									reqPeerLock.Unlock()
									break
								}
								pePtr, ok := reqPeer[hdr.Seq]
								reqPeerLock.Unlock()
								if !ok {
									break
								}
								pePtr.peer.endpoint.Lock()
								// Ignore if the endpoint was replaced since
								// the request was issued.
								if &pePtr.peer.endpoint.val != pePtr.endpoint {
									pePtr.peer.endpoint.Unlock()
									break
								}
								if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
									pePtr.peer.endpoint.Unlock()
									break
								}
								// Route changed: drop the sticky source on
								// the next transmit.
								pePtr.peer.endpoint.clearSrcOnTx = true
								pePtr.peer.endpoint.Unlock()
							}
							attr = attr[attrhdr.Len:]
						}
					}
					break
				}
				// A kernel-originated route change: re-query every peer's
				// route. Replace the request map, then fan out requests
				// asynchronously.
				reqPeerLock.Lock()
				reqPeer = make(map[uint32]peerEndpointPtr)
				reqPeerLock.Unlock()
				go func() {
					device.peers.RLock()
					i := uint32(1)
					for _, peer := range device.peers.keyMap {
						peer.endpoint.Lock()
						if peer.endpoint.val == nil {
							peer.endpoint.Unlock()
							continue
						}
						nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
						if nativeEP == nil {
							peer.endpoint.Unlock()
							continue
						}
						if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
							peer.endpoint.Unlock()
							break
						}
						// RTM_GETROUTE request for this peer's dst/src pair,
						// tagged with sequence number i.
						nlmsg := struct {
							hdr     unix.NlMsghdr
							msg     unix.RtMsg
							dsthdr  unix.RtAttr
							dst     [4]byte
							srchdr  unix.RtAttr
							src     [4]byte
							markhdr unix.RtAttr
							mark    uint32
						}{
							unix.NlMsghdr{
								Type:  uint16(unix.RTM_GETROUTE),
								Flags: unix.NLM_F_REQUEST,
								Seq:   i,
							},
							unix.RtMsg{
								Family:  unix.AF_INET,
								Dst_len: 32,
								Src_len: 32,
							},
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_DST,
							},
							nativeEP.DstIP().As4(),
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_SRC,
							},
							nativeEP.SrcIP().As4(),
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_MARK,
							},
							device.net.fwmark,
						}
						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
						reqPeerLock.Lock()
						reqPeer[i] = peerEndpointPtr{
							peer:     peer,
							endpoint: &peer.endpoint.val,
						}
						reqPeerLock.Unlock()
						peer.endpoint.Unlock()
						i++
						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
						if err != nil {
							break
						}
					}
					device.peers.RUnlock()
				}()
			}
			remain = remain[hdr.Len:]
		}
	}
}
|
||||
|
||||
func createNetlinkRouteSocket() (int, error) {
|
||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
saddr := &unix.SockaddrNetlink{
|
||||
Family: unix.AF_NETLINK,
|
||||
Groups: unix.RTMGRP_IPV4_ROUTE,
|
||||
}
|
||||
err = unix.Bind(sock, saddr)
|
||||
if err != nil {
|
||||
unix.Close(sock)
|
||||
return -1, err
|
||||
}
|
||||
return sock, nil
|
||||
}
|
||||
221
vendor/github.com/tailscale/wireguard-go/device/timers.go
generated
vendored
Normal file
221
vendor/github.com/tailscale/wireguard-go/device/timers.go
generated
vendored
Normal file
@@ -0,0 +1,221 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This is based heavily on timers.c from the kernel implementation.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
)
|
||||
|
||||
// fastrandn returns a pseudo-random uint32 uniformly distributed in [0, n),
// linked against the Go runtime's private fastrandn. It is cheap and
// allocation-free but NOT cryptographically secure; this file uses it only
// for timer jitter.
//
//go:linkname fastrandn runtime.fastrandn
func fastrandn(n uint32) uint32
|
||||
|
||||
// A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct {
	*time.Timer
	modifyingLock sync.RWMutex // guards isPending and the Reset/Stop transitions
	runningLock   sync.Mutex   // held for the duration of the expiration callback; DelSync acquires it to wait the callback out
	isPending     bool         // true while the timer is armed and its callback has not yet consumed the arm
}
|
||||
|
||||
// NewTimer creates a stopped Timer whose expiration runs expirationFunction(peer).
// The underlying time.Timer is created with a placeholder 1h duration and
// immediately stopped; it is only armed via Mod. The callback holds runningLock
// for its whole run (so DelSync can block until it finishes) and consumes
// isPending under modifyingLock: if the timer was deleted after the callback
// was already scheduled (isPending false), it returns without calling
// expirationFunction.
func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
	timer := &Timer{}
	timer.Timer = time.AfterFunc(time.Hour, func() {
		timer.runningLock.Lock()
		defer timer.runningLock.Unlock()

		timer.modifyingLock.Lock()
		if !timer.isPending {
			// Timer was deleted between firing and us running; do nothing.
			timer.modifyingLock.Unlock()
			return
		}
		timer.isPending = false
		timer.modifyingLock.Unlock()

		expirationFunction(peer)
	})
	timer.Stop()
	return timer
}
|
||||
|
||||
func (timer *Timer) Mod(d time.Duration) {
|
||||
timer.modifyingLock.Lock()
|
||||
timer.isPending = true
|
||||
timer.Reset(d)
|
||||
timer.modifyingLock.Unlock()
|
||||
}
|
||||
|
||||
func (timer *Timer) Del() {
|
||||
timer.modifyingLock.Lock()
|
||||
timer.isPending = false
|
||||
timer.Stop()
|
||||
timer.modifyingLock.Unlock()
|
||||
}
|
||||
|
||||
// DelSync disarms the timer and waits for any in-flight expiration callback
// to finish before disarming again. Analogous to the kernel's del_timer_sync.
// Acquiring runningLock guarantees a running callback has completed; the
// second Del then clears any pending state such a concurrently running
// callback may have set before we held the lock.
func (timer *Timer) DelSync() {
	timer.Del()
	timer.runningLock.Lock()
	timer.Del()
	timer.runningLock.Unlock()
}
|
||||
|
||||
func (timer *Timer) IsPending() bool {
|
||||
timer.modifyingLock.RLock()
|
||||
defer timer.modifyingLock.RUnlock()
|
||||
return timer.isPending
|
||||
}
|
||||
|
||||
func (peer *Peer) timersActive() bool {
|
||||
return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
|
||||
}
|
||||
|
||||
// expiredRetransmitHandshake fires when a sent handshake initiation has gone
// unanswered. After more than MaxTimerHandshakes attempts it gives up
// (flushing staged packets and scheduling key zeroing); otherwise it retries
// the handshake with the attempt counter incremented.
func expiredRetransmitHandshake(peer *Peer) {
	if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
		peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)

		if peer.timersActive() {
			peer.timers.sendKeepalive.Del()
		}

		/* We drop all packets without a keypair and don't try again,
		 * if we try unsuccessfully for too long to make a handshake.
		 */
		peer.FlushStagedPackets()

		/* We set a timer for destroying any residue that might be left
		 * of a partial exchange.
		 */
		if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() {
			peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
		}
	} else {
		peer.timers.handshakeAttempts.Add(1)
		peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)

		/* We clear the endpoint address src address, in case this is the cause of trouble. */
		peer.markEndpointSrcForClearing()

		peer.SendHandshakeInitiation(true)
	}
}
|
||||
|
||||
func expiredSendKeepalive(peer *Peer) {
|
||||
peer.SendKeepalive()
|
||||
if peer.timers.needAnotherKeepalive.Load() {
|
||||
peer.timers.needAnotherKeepalive.Store(false)
|
||||
if peer.timersActive() {
|
||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expiredNewHandshake fires when data was sent but nothing authenticated came
// back within KeepaliveTimeout+RekeyTimeout; it initiates a fresh handshake.
func expiredNewHandshake(peer *Peer) {
	peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
	/* We clear the endpoint address src address, in case this is the cause of trouble. */
	peer.markEndpointSrcForClearing()
	peer.SendHandshakeInitiation(false)
}
|
||||
|
||||
// expiredZeroKeyMaterial fires RejectAfterTime*3 after the last time this
// timer was armed (see timersSessionDerived) and wipes all of the peer's key
// material.
func expiredZeroKeyMaterial(peer *Peer) {
	peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
	peer.ZeroAndFlushAll()
}
|
||||
|
||||
func expiredPersistentKeepalive(peer *Peer) {
|
||||
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||
peer.SendKeepalive()
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be called after an authenticated data packet is sent. */
|
||||
func (peer *Peer) timersDataSent() {
|
||||
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
|
||||
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be called after an authenticated data packet is received. */
|
||||
func (peer *Peer) timersDataReceived() {
|
||||
if peer.timersActive() {
|
||||
if !peer.timers.sendKeepalive.IsPending() {
|
||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||
} else {
|
||||
peer.timers.needAnotherKeepalive.Store(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */
|
||||
func (peer *Peer) timersAnyAuthenticatedPacketSent() {
|
||||
if peer.timersActive() {
|
||||
peer.timers.sendKeepalive.Del()
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */
|
||||
func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
|
||||
if peer.timersActive() {
|
||||
peer.timers.newHandshake.Del()
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be called after a handshake initiation message is sent. */
|
||||
func (peer *Peer) timersHandshakeInitiated() {
|
||||
if peer.timersActive() {
|
||||
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
func (peer *Peer) timersHandshakeComplete() {
	if peer.timersActive() {
		peer.timers.retransmitHandshake.Del()
	}
	// Reset retry bookkeeping and record when the handshake completed.
	peer.timers.handshakeAttempts.Store(0)
	peer.timers.sentLastMinuteHandshake.Store(false)
	peer.lastHandshakeNano.Store(time.Now().UnixNano())
}
|
||||
|
||||
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
|
||||
func (peer *Peer) timersSessionDerived() {
|
||||
if peer.timersActive() {
|
||||
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
||||
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
||||
keepalive := peer.persistentKeepaliveInterval.Load()
|
||||
if keepalive > 0 && peer.timersActive() {
|
||||
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// timersInit creates the peer's five protocol timers; NewTimer returns them
// stopped, so nothing fires until a timer is armed via Mod.
func (peer *Peer) timersInit() {
	peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
	peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
	peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
	peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
	peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
}
|
||||
|
||||
// timersStart resets per-session timer bookkeeping when the peer is started.
func (peer *Peer) timersStart() {
	peer.timers.handshakeAttempts.Store(0)
	peer.timers.sentLastMinuteHandshake.Store(false)
	peer.timers.needAnotherKeepalive.Store(false)
}
|
||||
|
||||
// timersStop synchronously stops all five timers, waiting for any in-flight
// expiration callback to finish (DelSync) so none runs after this returns.
func (peer *Peer) timersStop() {
	peer.timers.retransmitHandshake.DelSync()
	peer.timers.sendKeepalive.DelSync()
	peer.timers.newHandshake.DelSync()
	peer.timers.zeroKeyMaterial.DelSync()
	peer.timers.persistentKeepalive.DelSync()
}
|
||||
53
vendor/github.com/tailscale/wireguard-go/device/tun.go
generated
vendored
Normal file
53
vendor/github.com/tailscale/wireguard-go/device/tun.go
generated
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
)
|
||||
|
||||
const DefaultMTU = 1420
|
||||
|
||||
func (device *Device) RoutineTUNEventReader() {
|
||||
device.log.Verbosef("Routine: event worker - started")
|
||||
|
||||
for event := range device.tun.device.Events() {
|
||||
if event&tun.EventMTUUpdate != 0 {
|
||||
mtu, err := device.tun.device.MTU()
|
||||
if err != nil {
|
||||
device.log.Errorf("Failed to load updated MTU of device: %v", err)
|
||||
continue
|
||||
}
|
||||
if mtu < 0 {
|
||||
device.log.Errorf("MTU not updated to negative value: %v", mtu)
|
||||
continue
|
||||
}
|
||||
var tooLarge string
|
||||
if mtu > MaxContentSize {
|
||||
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
|
||||
mtu = MaxContentSize
|
||||
}
|
||||
old := device.tun.mtu.Swap(int32(mtu))
|
||||
if int(old) != mtu {
|
||||
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
|
||||
}
|
||||
}
|
||||
|
||||
if event&tun.EventUp != 0 {
|
||||
device.log.Verbosef("Interface up requested")
|
||||
device.Up()
|
||||
}
|
||||
|
||||
if event&tun.EventDown != 0 {
|
||||
device.log.Verbosef("Interface down requested")
|
||||
device.Down()
|
||||
}
|
||||
}
|
||||
|
||||
device.log.Verbosef("Routine: event worker - stopped")
|
||||
}
|
||||
457
vendor/github.com/tailscale/wireguard-go/device/uapi.go
generated
vendored
Normal file
457
vendor/github.com/tailscale/wireguard-go/device/uapi.go
generated
vendored
Normal file
@@ -0,0 +1,457 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/ipc"
|
||||
)
|
||||
|
||||
// IPCError pairs a UAPI errno-style status code with an underlying error.
// It is the error type the configuration protocol reports back to clients
// (see IpcHandle, which writes "errno=<code>" from ErrorCode).
type IPCError struct {
	code int64 // error code
	err  error // underlying/wrapped error
}

// Error implements the error interface.
func (s IPCError) Error() string {
	return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
}

// Unwrap returns the wrapped underlying error, enabling errors.Is/As.
func (s IPCError) Unwrap() error {
	return s.err
}

// ErrorCode returns the numeric status code reported over UAPI.
func (s IPCError) ErrorCode() int64 {
	return s.code
}

// ipcErrorf constructs an *IPCError with the given code and a formatted
// message; %w in msg wraps an underlying error as usual.
func ipcErrorf(code int64, msg string, args ...any) *IPCError {
	return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
|
||||
|
||||
// byteBufferPool recycles the scratch buffers used to serialize UAPI "get"
// responses, avoiding a fresh allocation per operation.
var byteBufferPool = &sync.Pool{
	New: func() any { return new(bytes.Buffer) },
}
|
||||
|
||||
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
// The full response is serialized into a pooled buffer while the device's
// net/staticIdentity/peers locks are held, then written to w after those
// locks are released.
func (device *Device) IpcGetOperation(w io.Writer) error {
	device.ipcMutex.RLock()
	defer device.ipcMutex.RUnlock()

	buf := byteBufferPool.Get().(*bytes.Buffer)
	buf.Reset()
	defer byteBufferPool.Put(buf)
	// sendf appends one "key=value\n" line to the buffer.
	sendf := func(format string, args ...any) {
		fmt.Fprintf(buf, format, args...)
		buf.WriteByte('\n')
	}
	// keyf appends "prefix=<64 lowercase hex chars>\n" for a 32-byte key.
	keyf := func(prefix string, key *[32]byte) {
		buf.Grow(len(key)*2 + 2 + len(prefix))
		buf.WriteString(prefix)
		buf.WriteByte('=')
		const hex = "0123456789abcdef"
		for i := 0; i < len(key); i++ {
			buf.WriteByte(hex[key[i]>>4])
			buf.WriteByte(hex[key[i]&0xf])
		}
		buf.WriteByte('\n')
	}

	// The anonymous function scopes the deferred unlocks so all resource
	// locks are released before the buffer is written out below.
	func() {
		// lock required resources

		device.net.RLock()
		defer device.net.RUnlock()

		device.staticIdentity.RLock()
		defer device.staticIdentity.RUnlock()

		device.peers.RLock()
		defer device.peers.RUnlock()

		// serialize device related values

		if !device.staticIdentity.privateKey.IsZero() {
			keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
		}

		if device.net.port != 0 {
			sendf("listen_port=%d", device.net.port)
		}

		if device.net.fwmark != 0 {
			sendf("fwmark=%d", device.net.fwmark)
		}

		for _, peer := range device.peers.keyMap {
			// Serialize peer state.
			peer.handshake.mutex.RLock()
			keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
			keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
			peer.handshake.mutex.RUnlock()
			sendf("protocol_version=1")
			peer.endpoint.Lock()
			if peer.endpoint.val != nil {
				sendf("endpoint=%s", peer.endpoint.val.DstToString())
			}
			peer.endpoint.Unlock()

			// Split the nanosecond timestamp into whole seconds plus remainder.
			nano := peer.lastHandshakeNano.Load()
			secs := nano / time.Second.Nanoseconds()
			nano %= time.Second.Nanoseconds()

			sendf("last_handshake_time_sec=%d", secs)
			sendf("last_handshake_time_nsec=%d", nano)
			sendf("tx_bytes=%d", peer.txBytes.Load())
			sendf("rx_bytes=%d", peer.rxBytes.Load())
			sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())

			device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
				sendf("allowed_ip=%s", prefix.String())
				return true
			})
		}
	}()

	// send lines (does not require resource locks)
	if _, err := w.Write(buf.Bytes()); err != nil {
		return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
	}

	return nil
}
|
||||
|
||||
// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
// Input is a stream of "key=value" lines; lines before the first "public_key"
// configure the device, later lines configure the most recent peer. A blank
// line (or EOF) terminates the operation. Any error is logged and returned.
func (device *Device) IpcSetOperation(r io.Reader) (err error) {
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()

	defer func() {
		if err != nil {
			device.log.Errorf("%v", err)
		}
	}()

	peer := new(ipcSetPeer)
	deviceConfig := true // still in the device section until the first public_key

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			// Blank line means terminate operation.
			peer.handlePostConfig()
			return nil
		}
		key, value, ok := strings.Cut(line, "=")
		if !ok {
			return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
		}

		if key == "public_key" {
			if deviceConfig {
				deviceConfig = false
			}
			// Finalize the previous peer (if any) before switching.
			peer.handlePostConfig()
			// Load/create the peer we are now configuring.
			err := device.handlePublicKeyLine(peer, value)
			if err != nil {
				return err
			}
			continue
		}

		var err error
		if deviceConfig {
			err = device.handleDeviceLine(key, value)
		} else {
			err = device.handlePeerLine(peer, key, value)
		}
		if err != nil {
			return err
		}
	}
	// EOF: finalize the last configured peer.
	peer.handlePostConfig()

	if err := scanner.Err(); err != nil {
		return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
	}
	return nil
}
|
||||
|
||||
// handleDeviceLine applies one device-section "key=value" pair from a UAPI
// "set" operation: private_key, listen_port, fwmark, or replace_peers.
// Unknown keys are rejected with IpcErrorInvalid.
func (device *Device) handleDeviceLine(key, value string) error {
	switch key {
	case "private_key":
		var sk NoisePrivateKey
		// An all-zero hex value clears the private key.
		err := sk.FromMaybeZeroHex(value)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
		}
		device.log.Verbosef("UAPI: Updating private key")
		device.SetPrivateKey(sk)

	case "listen_port":
		port, err := strconv.ParseUint(value, 10, 16)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
		}

		// update port and rebind
		device.log.Verbosef("UAPI: Updating listen port")

		device.net.Lock()
		device.net.port = uint16(port)
		device.net.Unlock()

		if err := device.BindUpdate(); err != nil {
			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
		}

	case "fwmark":
		mark, err := strconv.ParseUint(value, 10, 32)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
		}

		device.log.Verbosef("UAPI: Updating fwmark")
		// NOTE(review): reuses IpcErrorPortInUse for a fwmark failure —
		// matches upstream behavior; confirm before changing.
		if err := device.BindSetMark(uint32(mark)); err != nil {
			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
		}

	case "replace_peers":
		if value != "true" {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
		}
		device.log.Verbosef("UAPI: Removing all peers")
		device.RemoveAllPeers()

	default:
		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
	}

	return nil
}
|
||||
|
||||
// An ipcSetPeer is the current state of an IPC set operation on a peer.
type ipcSetPeer struct {
	*Peer // Peer is the current peer being operated on
	dummy   bool // dummy reports whether this peer is a temporary, placeholder peer
	created bool // created reports whether this is a newly created peer
	pkaOn   bool // pkaOn reports whether the peer had the persistent keepalive turn on
}
|
||||
|
||||
// handlePostConfig finalizes the peer after its configuration lines have been
// consumed: it applies the broken-roaming workaround for newly created peers
// and, if the device is up, starts the peer, sends an immediate keepalive
// when persistent keepalive was just enabled, and flushes staged packets.
// Placeholder (dummy) or absent peers are skipped.
func (peer *ipcSetPeer) handlePostConfig() {
	if peer.Peer == nil || peer.dummy {
		return
	}
	if peer.created {
		// Pin the endpoint (disable roaming) only when the bind is marked
		// brokenRoaming and an endpoint was explicitly configured.
		peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
	}
	if peer.device.isUp() {
		peer.Start()
		if peer.pkaOn {
			peer.SendKeepalive()
		}
		peer.SendStagedPackets()
	}
}
|
||||
|
||||
// handlePublicKeyLine selects (or creates) the peer identified by the hex
// public key in value, storing it in peer for subsequent peer-section lines.
// A public key equal to the device's own is handled with a throwaway dummy
// peer so the rest of the section parses but has no effect.
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
	// Load/create the peer we are configuring.
	var publicKey NoisePublicKey
	err := publicKey.FromHex(value)
	if err != nil {
		return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
	}

	// Ignore peer with the same public key as this device.
	device.staticIdentity.RLock()
	peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
	device.staticIdentity.RUnlock()

	if peer.dummy {
		peer.Peer = &Peer{}
	} else {
		peer.Peer = device.LookupPeer(publicKey)
	}

	peer.created = peer.Peer == nil
	if peer.created {
		peer.Peer, err = device.NewPeer(publicKey)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
		}
		device.log.Verbosef("%v - UAPI: Created", peer.Peer)
	}
	return nil
}
|
||||
|
||||
// handlePeerLine applies one peer-section "key=value" pair from a UAPI "set"
// operation to the currently selected peer. Mutating keys are no-ops when the
// peer is a dummy placeholder; unknown keys are rejected with IpcErrorInvalid.
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
	switch key {
	case "update_only":
		// allow disabling of creation
		if value != "true" {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
		}
		// If the peer did not already exist, undo its creation and continue
		// with a dummy so the rest of this peer's lines are ignored.
		if peer.created && !peer.dummy {
			device.RemovePeer(peer.handshake.remoteStatic)
			peer.Peer = &Peer{}
			peer.dummy = true
		}

	case "remove":
		// remove currently selected peer from device
		if value != "true" {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
		}
		if !peer.dummy {
			device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
			device.RemovePeer(peer.handshake.remoteStatic)
		}
		peer.Peer = &Peer{}
		peer.dummy = true

	case "preshared_key":
		device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)

		peer.handshake.mutex.Lock()
		err := peer.handshake.presharedKey.FromHex(value)
		peer.handshake.mutex.Unlock()

		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
		}

	case "endpoint":
		device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
		endpoint, err := device.net.bind.ParseEndpoint(value)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
		}
		peer.endpoint.Lock()
		defer peer.endpoint.Unlock()
		peer.endpoint.val = endpoint

	case "persistent_keepalive_interval":
		device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)

		secs, err := strconv.ParseUint(value, 10, 16)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
		}

		old := peer.persistentKeepaliveInterval.Swap(uint32(secs))

		// Send immediate keepalive if we're turning it on and before it wasn't on.
		peer.pkaOn = old == 0 && secs != 0

	case "replace_allowed_ips":
		device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
		if value != "true" {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
		}
		if peer.dummy {
			return nil
		}
		device.allowedips.RemoveByPeer(peer.Peer)

	case "allowed_ip":
		device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
		prefix, err := netip.ParsePrefix(value)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
		}
		if peer.dummy {
			return nil
		}
		device.allowedips.Insert(prefix, peer.Peer)

	case "protocol_version":
		if value != "1" {
			return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
		}

	default:
		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
	}

	return nil
}
|
||||
|
||||
func (device *Device) IpcGet() (string, error) {
|
||||
buf := new(strings.Builder)
|
||||
if err := device.IpcGetOperation(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (device *Device) IpcSet(uapiConf string) error {
|
||||
return device.IpcSetOperation(strings.NewReader(uapiConf))
|
||||
}
|
||||
|
||||
// IpcHandle serves the UAPI protocol on a single connection until it closes:
// it reads "set=1\n" or "get=1\n" operation headers, dispatches to
// IpcSetOperation/IpcGetOperation, and writes an "errno=<code>\n\n" status
// after each operation.
func (device *Device) IpcHandle(socket net.Conn) {
	defer socket.Close()

	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
		reader := bufio.NewReader(s)
		writer := bufio.NewWriter(s)
		return bufio.NewReadWriter(reader, writer)
	}(socket)

	for {
		op, err := buffered.ReadString('\n')
		if err != nil {
			return
		}

		// handle operation
		switch op {
		case "set=1\n":
			err = device.IpcSetOperation(buffered.Reader)
		case "get=1\n":
			// "get" must be followed by an empty line.
			var nextByte byte
			nextByte, err = buffered.ReadByte()
			if err != nil {
				return
			}
			if nextByte != '\n' {
				err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
				break
			}
			err = device.IpcGetOperation(buffered.Writer)
		default:
			device.log.Errorf("invalid UAPI operation: %v", op)
			return
		}

		// write status
		var status *IPCError
		if err != nil && !errors.As(err, &status) {
			// shouldn't happen
			status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
		}
		if status != nil {
			device.log.Errorf("%v", status)
			fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
		} else {
			fmt.Fprintf(buffered, "errno=0\n\n")
		}
		buffered.Flush()
	}
}
|
||||
Reference in New Issue
Block a user