Update dependencies
vendor/github.com/tailscale/wireguard-go/LICENSE (generated, vendored, new file, 17 lines)
@@ -0,0 +1,17 @@
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
vendor/github.com/tailscale/wireguard-go/conn/bind_std.go (generated, vendored, new file, 547 lines)
@@ -0,0 +1,547 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"runtime"
	"strconv"
	"sync"
	"syscall"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

var (
	_ Bind     = (*StdNetBind)(nil)
	_ Endpoint = (*StdNetEndpoint)(nil)
)

// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
	mu            sync.Mutex // protects all fields except as specified
	ipv4          *net.UDPConn
	ipv6          *net.UDPConn
	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
	ipv4TxOffload bool
	ipv4RxOffload bool
	ipv6TxOffload bool
	ipv6RxOffload bool

	// these two fields are not guarded by mu
	udpAddrPool sync.Pool
	msgsPool    sync.Pool

	blackhole4 bool
	blackhole6 bool
}

func NewStdNetBind() Bind {
	return &StdNetBind{
		udpAddrPool: sync.Pool{
			New: func() any {
				return &net.UDPAddr{
					IP: make([]byte, 16),
				}
			},
		},

		msgsPool: sync.Pool{
			New: func() any {
				msgs := make([]ipv6.Message, IdealBatchSize)
				for i := range msgs {
					msgs[i].Buffers = make(net.Buffers, 1)
					msgs[i].OOB = make([]byte, controlSize)
				}
				return &msgs
			},
		},
	}
}

type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort
	// src is the current sticky source address and interface index, if
	// supported. Typically this is a PKTINFO structure from/for control
	// messages, see unix.PKTINFO for an example.
	src []byte
}

var (
	_ Bind     = (*StdNetBind)(nil)
	_ Endpoint = &StdNetEndpoint{}
)
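Note on msgsPool: it stores a *[]ipv6.Message rather than the slice itself, so Get/Put move a single pointer and avoid allocating a fresh interface value for the slice header on every cycle. A minimal standalone sketch of the same pattern, with a hypothetical message type standing in for ipv6.Message:

package main

import (
	"fmt"
	"sync"
)

// message is a stand-in for ipv6.Message in this sketch.
type message struct {
	buf []byte
}

// pool stores *[]message: boxing a pointer into sync.Pool avoids an
// allocation per Put that storing the bare slice header would incur.
var pool = sync.Pool{
	New: func() any {
		msgs := make([]message, 8)
		for i := range msgs {
			msgs[i].buf = make([]byte, 2048)
		}
		return &msgs
	},
}

func main() {
	msgs := pool.Get().(*[]message)
	(*msgs)[0].buf = (*msgs)[0].buf[:copy((*msgs)[0].buf, "hello")]
	fmt.Printf("%s\n", (*msgs)[0].buf)
	// Restore full lengths before returning, mirroring putMessages below.
	for i := range *msgs {
		(*msgs)[i].buf = (*msgs)[i].buf[:cap((*msgs)[i].buf)]
	}
	pool.Put(msgs)
}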
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
	e, err := netip.ParseAddrPort(s)
	if err != nil {
		return nil, err
	}
	return &StdNetEndpoint{
		AddrPort: e,
	}, nil
}

func (e *StdNetEndpoint) ClearSrc() {
	if e.src != nil {
		// Truncate src, no need to reallocate.
		e.src = e.src[:0]
	}
}

func (e *StdNetEndpoint) DstIP() netip.Addr {
	return e.AddrPort.Addr()
}

// See control_default.go, control_linux.go, etc. for implementations of
// SrcIP and SrcIfidx.

func (e *StdNetEndpoint) DstToBytes() []byte {
	b, _ := e.AddrPort.MarshalBinary()
	return b
}

func (e *StdNetEndpoint) DstToString() string {
	return e.AddrPort.String()
}

func listenNet(network string, port int) (*net.UDPConn, int, error) {
	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
	if err != nil {
		return nil, 0, err
	}

	// Retrieve port.
	laddr := conn.LocalAddr()
	uaddr, err := net.ResolveUDPAddr(
		laddr.Network(),
		laddr.String(),
	)
	if err != nil {
		return nil, 0, err
	}
	return conn.(*net.UDPConn), uaddr.Port, nil
}
// errEADDRINUSE is syscall.EADDRINUSE, boxed into an interface once
// in erraddrinuse.go on almost all platforms. For other platforms,
// it's at least non-nil.
var errEADDRINUSE error = errors.New("")

func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err error
	var tries int

	if s.ipv4 != nil || s.ipv6 != nil {
		return nil, 0, ErrBindAlreadyOpen
	}

	// Attempt to open ipv4 and ipv6 listeners on the same port.
	// If uport is 0, we can retry on failure.
again:
	port := int(uport)
	var v4conn, v6conn *net.UDPConn
	var v4pc *ipv4.PacketConn
	var v6pc *ipv6.PacketConn

	v4conn, port, err = listenNet("udp4", port)
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		return nil, 0, err
	}

	// Listen on the same port as we're using for ipv4.
	v6conn, port, err = listenNet("udp6", port)
	if uport == 0 && errors.Is(err, errEADDRINUSE) && tries < 100 {
		v4conn.Close()
		tries++
		goto again
	}
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		v4conn.Close()
		return nil, 0, err
	}
	var fns []ReceiveFunc
	if v4conn != nil {
		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
		if runtime.GOOS == "linux" {
			v4pc = ipv4.NewPacketConn(v4conn)
			s.ipv4PC = v4pc
		}
		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
		s.ipv4 = v4conn
	}
	if v6conn != nil {
		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
		if runtime.GOOS == "linux" {
			v6pc = ipv6.NewPacketConn(v6conn)
			s.ipv6PC = v6pc
		}
		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
		s.ipv6 = v6conn
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}

	return fns, uint16(port), nil
}
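A minimal sketch of driving this Bind from outside the package: open on an ephemeral port (0), then issue one blocking batch read per ReceiveFunc. The buffer sizes, the assumption that at least one address family opened (fns[0]), and the error handling are choices of this sketch, not prescribed usage:

package main

import (
	"fmt"
	"log"

	"github.com/tailscale/wireguard-go/conn"
)

func main() {
	bind := conn.NewStdNetBind()
	fns, port, err := bind.Open(0) // 0 picks a random port; EADDRINUSE is retried internally
	if err != nil {
		log.Fatal(err)
	}
	defer bind.Close()
	fmt.Println("listening on port", port)

	batch := bind.BatchSize()
	bufs := make([][]byte, batch)
	for i := range bufs {
		bufs[i] = make([]byte, 1500)
	}
	sizes := make([]int, batch)
	eps := make([]conn.Endpoint, batch)

	// One blocking read on the first ReceiveFunc (one fn per address family).
	n, err := fns[0](bufs, sizes, eps)
	if err != nil {
		log.Fatal(err)
	}
	for i := 0; i < n; i++ {
		fmt.Printf("%d bytes from %s\n", sizes[i], eps[i].DstToString())
	}
}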
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
	for i := range *msgs {
		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
	}
	s.msgsPool.Put(msgs)
}

func (s *StdNetBind) getMessages() *[]ipv6.Message {
	return s.msgsPool.Get().(*[]ipv6.Message)
}

var (
	// If compilation fails here these are no longer the same underlying type.
	_ ipv6.Message = ipv4.Message{}
)

type batchReader interface {
	ReadBatch([]ipv6.Message, int) (int, error)
}

type batchWriter interface {
	WriteBatch([]ipv6.Message, int) (int, error)
}

func (s *StdNetBind) receiveIP(
	br batchReader,
	conn *net.UDPConn,
	rxOffload bool,
	bufs [][]byte,
	sizes []int,
	eps []Endpoint,
) (n int, err error) {
	msgs := s.getMessages()
	for i := range bufs {
		(*msgs)[i].Buffers[0] = bufs[i]
		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
	}
	defer s.putMessages(msgs)
	var numMsgs int
	if runtime.GOOS == "linux" {
		if rxOffload {
			readAt := len(*msgs) - 2
			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
			if err != nil {
				return 0, err
			}
			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
			if err != nil {
				return 0, err
			}
		} else {
			numMsgs, err = br.ReadBatch(*msgs, 0)
			if err != nil {
				return 0, err
			}
		}
	} else {
		msg := &(*msgs)[0]
		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
		if err != nil {
			return 0, err
		}
		numMsgs = 1
	}
	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		if sizes[i] == 0 {
			continue
		}
		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
		getSrcFromControl(msg.OOB[:msg.NN], ep)
		eps[i] = ep
	}
	return numMsgs, nil
}

func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}

func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
	if runtime.GOOS == "linux" {
		return IdealBatchSize
	}
	return 1
}

func (s *StdNetBind) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err1, err2 error
	if s.ipv4 != nil {
		err1 = s.ipv4.Close()
		s.ipv4 = nil
		s.ipv4PC = nil
	}
	if s.ipv6 != nil {
		err2 = s.ipv6.Close()
		s.ipv6 = nil
		s.ipv6PC = nil
	}
	s.blackhole4 = false
	s.blackhole6 = false
	s.ipv4TxOffload = false
	s.ipv4RxOffload = false
	s.ipv6TxOffload = false
	s.ipv6RxOffload = false
	if err1 != nil {
		return err1
	}
	return err2
}
type ErrUDPGSODisabled struct {
	onLaddr  string
	RetryErr error
}

func (e ErrUDPGSODisabled) Error() string {
	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}

func (e ErrUDPGSODisabled) Unwrap() error {
	return e.RetryErr
}

func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
	s.mu.Lock()
	blackhole := s.blackhole4
	conn := s.ipv4
	offload := s.ipv4TxOffload
	br := batchWriter(s.ipv4PC)
	is6 := false
	if endpoint.DstIP().Is6() {
		blackhole = s.blackhole6
		conn = s.ipv6
		br = s.ipv6PC
		is6 = true
		offload = s.ipv6TxOffload
	}
	s.mu.Unlock()

	if blackhole {
		return nil
	}
	if conn == nil {
		return syscall.EAFNOSUPPORT
	}

	msgs := s.getMessages()
	defer s.putMessages(msgs)
	ua := s.udpAddrPool.Get().(*net.UDPAddr)
	defer s.udpAddrPool.Put(ua)
	if is6 {
		as16 := endpoint.DstIP().As16()
		copy(ua.IP, as16[:])
		ua.IP = ua.IP[:16]
	} else {
		as4 := endpoint.DstIP().As4()
		copy(ua.IP, as4[:])
		ua.IP = ua.IP[:4]
	}
	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
	var (
		retried bool
		err     error
	)
retry:
	if offload {
		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
		err = s.send(conn, br, (*msgs)[:n])
		if err != nil && offload && errShouldDisableUDPGSO(err) {
			offload = false
			s.mu.Lock()
			if is6 {
				s.ipv6TxOffload = false
			} else {
				s.ipv4TxOffload = false
			}
			s.mu.Unlock()
			retried = true
			goto retry
		}
	} else {
		for i := range bufs {
			(*msgs)[i].Addr = ua
			(*msgs)[i].Buffers[0] = bufs[i]
			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
		}
		err = s.send(conn, br, (*msgs)[:len(bufs)])
	}
	if retried {
		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
	}
	return err
}
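Because ErrUDPGSODisabled implements Unwrap, a caller can tell "GSO was turned off but the non-GSO retry succeeded" apart from a real send failure. A hedged sketch of such a caller (the function name and logging policy are inventions of this sketch):

package example

import (
	"errors"
	"log"

	"github.com/tailscale/wireguard-go/conn"
)

// sendAndReport wraps Bind.Send, logging when UDP GSO had to be disabled.
// It returns nil when the packets were ultimately handed to the socket.
func sendAndReport(bind conn.Bind, bufs [][]byte, ep conn.Endpoint) error {
	err := bind.Send(bufs, ep)
	var gsoErr conn.ErrUDPGSODisabled
	if errors.As(err, &gsoErr) {
		log.Printf("UDP GSO disabled on this socket: %v", gsoErr)
		return gsoErr.RetryErr // nil if the non-GSO retry succeeded
	}
	return err
}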
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
	var (
		n     int
		err   error
		start int
	)
	if runtime.GOOS == "linux" {
		for {
			n, err = pc.WriteBatch(msgs[start:], 0)
			if err != nil || n == len(msgs[start:]) {
				break
			}
			start += n
		}
	} else {
		for _, msg := range msgs {
			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
			if err != nil {
				break
			}
		}
	}
	return err
}
const (
	// Exceeding these values results in EMSGSIZE. They account for layer3 and
	// layer4 headers. IPv6 does not need to account for itself as the payload
	// length field is self excluding.
	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
	maxIPv6PayloadLen = 1<<16 - 1 - 8

	// This is a hard limit imposed by the kernel.
	udpSegmentMaxDatagrams = 64
)

type setGSOFunc func(control *[]byte, gsoSize uint16)

func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
	var (
		base     = -1 // index of msg we are currently coalescing into
		gsoSize  int  // segmentation size of msgs[base]
		dgramCnt int  // number of dgrams coalesced into msgs[base]
		endBatch bool // tracking flag to start a new batch on next iteration of bufs
	)
	maxPayloadLen := maxIPv4PayloadLen
	if ep.DstIP().Is6() {
		maxPayloadLen = maxIPv6PayloadLen
	}
	for i, buf := range bufs {
		if i > 0 {
			msgLen := len(buf)
			baseLenBefore := len(msgs[base].Buffers[0])
			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
			if msgLen+baseLenBefore <= maxPayloadLen &&
				msgLen <= gsoSize &&
				msgLen <= freeBaseCap &&
				dgramCnt < udpSegmentMaxDatagrams &&
				!endBatch {
				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
				if i == len(bufs)-1 {
					setGSO(&msgs[base].OOB, uint16(gsoSize))
				}
				dgramCnt++
				if msgLen < gsoSize {
					// A smaller than gsoSize packet on the tail is legal, but
					// it must end the batch.
					endBatch = true
				}
				continue
			}
		}
		if dgramCnt > 1 {
			setGSO(&msgs[base].OOB, uint16(gsoSize))
		}
		// Reset prior to incrementing base since we are preparing to start a
		// new potential batch.
		endBatch = false
		base++
		gsoSize = len(buf)
		setSrcControl(&msgs[base].OOB, ep)
		msgs[base].Buffers[0] = buf
		msgs[base].Addr = addr
		dgramCnt = 1
	}
	return base + 1
}
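To make the merge rule concrete: equal-length datagrams coalesce into one kernel message, a shorter tail datagram may join but closes the batch, and anything after that starts a new base. The following hypothetical helper (not part of the vendored package; it deliberately ignores the capacity and max-payload checks) mirrors just that rule:

// gsoBatchPlan returns how many kernel messages coalesceMessages would
// produce for a sequence of datagram lengths, under the simplified rule
// above. gsoBatchPlan([]int{1200, 1200, 1200, 500, 1200}) == 2: the 500-byte
// tail merges into the first batch but ends it, so the final 1200-byte
// datagram starts a second message.
func gsoBatchPlan(lens []int) (numMsgs int) {
	gsoSize, dgramCnt, endBatch := 0, 0, false
	for i, l := range lens {
		if i > 0 && l <= gsoSize && dgramCnt < 64 && !endBatch {
			dgramCnt++
			if l < gsoSize {
				endBatch = true // a short tail is legal but closes the batch
			}
			continue
		}
		numMsgs++ // start a new base message
		gsoSize, dgramCnt, endBatch = l, 1, false
	}
	return numMsgs
}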
type getGSOFunc func(control []byte) (int, error)

func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			return n, err
		}
		var (
			gsoSize    int
			start      int
			end        = msg.N
			numToSplit = 1
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			if n > i {
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}
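The split arithmetic is a ceiling division: a coalesced read of N bytes with GSO size g yields ceil(N/g) segments, all of size g except possibly the last. A hypothetical helper (not in the vendored package) mirroring that arithmetic:

// segmentSizes lists the segment lengths splitCoalescedMessages would carve
// out of one coalesced read. segmentSizes(4000, 1500) == []int{1500, 1500, 1000}.
func segmentSizes(total, gsoSize int) []int {
	if gsoSize <= 0 {
		return []int{total} // no GRO data: the read is a single datagram
	}
	var out []int
	for start := 0; start < total; start += gsoSize {
		end := start + gsoSize
		if end > total {
			end = total
		}
		out = append(out, end-start)
	}
	return out
}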
vendor/github.com/tailscale/wireguard-go/conn/bind_windows.go (generated, vendored, new file, 601 lines)
@@ -0,0 +1,601 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"encoding/binary"
	"io"
	"net"
	"net/netip"
	"strconv"
	"sync"
	"sync/atomic"
	"unsafe"

	"golang.org/x/sys/windows"

	"github.com/tailscale/wireguard-go/conn/winrio"
)

const (
	packetsPerRing = 1024
	bytesPerPacket = 2048 - 32
	receiveSpins   = 15
)

type ringPacket struct {
	addr WinRingEndpoint
	data [bytesPerPacket]byte
}

type ringBuffer struct {
	packets    uintptr
	head, tail uint32
	id         winrio.BufferId
	iocp       windows.Handle
	isFull     bool
	cq         winrio.Cq
	mu         sync.Mutex
	overlapped windows.Overlapped
}

func (rb *ringBuffer) Push() *ringPacket {
	for rb.isFull {
		panic("ring is full")
	}
	ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
	rb.tail += 1
	if rb.tail%packetsPerRing == rb.head%packetsPerRing {
		rb.isFull = true
	}
	return ret
}

func (rb *ringBuffer) Return(count uint32) {
	if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
		return
	}
	rb.head += count
	rb.isFull = false
}
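The head and tail fields above are free-running counters reduced modulo packetsPerRing; since head == tail (mod ring size) is ambiguous between "empty" and "full", the isFull flag disambiguates. A minimal model of the same bookkeeping, with the registered-I/O packet storage replaced by a plain slice (a sketch, not the vendored implementation):

// ring models ringBuffer's index arithmetic over an ordinary slice.
type ring struct {
	slots      []int
	head, tail uint32
	isFull     bool
}

func (r *ring) push(v int) {
	if r.isFull {
		panic("ring is full")
	}
	n := uint32(len(r.slots))
	r.slots[r.tail%n] = v
	r.tail++
	if r.tail%n == r.head%n {
		r.isFull = true // counters collide only when every slot is in use
	}
}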
type afWinRingBind struct {
	sock      windows.Handle
	rx, tx    ringBuffer
	rq        winrio.Rq
	mu        sync.Mutex
	blackhole bool
}

// WinRingBind uses Windows registered I/O for fast ring buffered networking.
type WinRingBind struct {
	v4, v6 afWinRingBind
	mu     sync.RWMutex
	isOpen atomic.Uint32 // 0, 1, or 2
}

func NewDefaultBind() Bind { return NewWinRingBind() }

func NewWinRingBind() Bind {
	if !winrio.Initialize() {
		return NewStdNetBind()
	}
	return new(WinRingBind)
}

type WinRingEndpoint struct {
	family uint16
	data   [30]byte
}

var (
	_ Bind     = (*WinRingBind)(nil)
	_ Endpoint = (*WinRingEndpoint)(nil)
)

func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
	host, port, err := net.SplitHostPort(s)
	if err != nil {
		return nil, err
	}
	host16, err := windows.UTF16PtrFromString(host)
	if err != nil {
		return nil, err
	}
	port16, err := windows.UTF16PtrFromString(port)
	if err != nil {
		return nil, err
	}
	hints := windows.AddrinfoW{
		Flags:    windows.AI_NUMERICHOST,
		Family:   windows.AF_UNSPEC,
		Socktype: windows.SOCK_DGRAM,
		Protocol: windows.IPPROTO_UDP,
	}
	var addrinfo *windows.AddrinfoW
	err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
	if err != nil {
		return nil, err
	}
	defer windows.FreeAddrInfoW(addrinfo)
	if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
		return nil, windows.ERROR_INVALID_ADDRESS
	}
	var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
	copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
	return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
}

func (*WinRingEndpoint) ClearSrc() {}

func (e *WinRingEndpoint) DstIP() netip.Addr {
	switch e.family {
	case windows.AF_INET:
		return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
	case windows.AF_INET6:
		return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
	}
	return netip.Addr{}
}

func (e *WinRingEndpoint) SrcIP() netip.Addr {
	return netip.Addr{} // not supported
}

func (e *WinRingEndpoint) DstToBytes() []byte {
	switch e.family {
	case windows.AF_INET:
		b := make([]byte, 0, 6)
		b = append(b, e.data[2:6]...)
		b = append(b, e.data[1], e.data[0])
		return b
	case windows.AF_INET6:
		b := make([]byte, 0, 18)
		b = append(b, e.data[6:22]...)
		b = append(b, e.data[1], e.data[0])
		return b
	}
	return nil
}

func (e *WinRingEndpoint) DstToString() string {
	switch e.family {
	case windows.AF_INET:
		return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
	case windows.AF_INET6:
		var zone string
		if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
			zone = strconv.FormatUint(uint64(scope), 10)
		}
		return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
	}
	return ""
}

func (e *WinRingEndpoint) SrcToString() string {
	return ""
}
func (ring *ringBuffer) CloseAndZero() {
	if ring.cq != 0 {
		winrio.CloseCompletionQueue(ring.cq)
		ring.cq = 0
	}
	if ring.iocp != 0 {
		windows.CloseHandle(ring.iocp)
		ring.iocp = 0
	}
	if ring.id != 0 {
		winrio.DeregisterBuffer(ring.id)
		ring.id = 0
	}
	if ring.packets != 0 {
		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
		ring.packets = 0
	}
	ring.head = 0
	ring.tail = 0
	ring.isFull = false
}

func (bind *afWinRingBind) CloseAndZero() {
	bind.rx.CloseAndZero()
	bind.tx.CloseAndZero()
	if bind.sock != 0 {
		windows.CloseHandle(bind.sock)
		bind.sock = 0
	}
	bind.blackhole = false
}

func (bind *WinRingBind) closeAndZero() {
	bind.isOpen.Store(0)
	bind.v4.CloseAndZero()
	bind.v6.CloseAndZero()
}

func (ring *ringBuffer) Open() error {
	var err error
	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
	if err != nil {
		return err
	}
	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
	if err != nil {
		return err
	}
	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
	if err != nil {
		return err
	}
	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
	if err != nil {
		return err
	}
	return nil
}

func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
	var err error
	bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
	if err != nil {
		return nil, err
	}
	err = bind.rx.Open()
	if err != nil {
		return nil, err
	}
	err = bind.tx.Open()
	if err != nil {
		return nil, err
	}
	bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
	if err != nil {
		return nil, err
	}
	err = windows.Bind(bind.sock, sa)
	if err != nil {
		return nil, err
	}
	sa, err = windows.Getsockname(bind.sock)
	if err != nil {
		return nil, err
	}
	return sa, nil
}
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
	bind.mu.Lock()
	defer bind.mu.Unlock()
	defer func() {
		if err != nil {
			bind.closeAndZero()
		}
	}()
	if bind.isOpen.Load() != 0 {
		return nil, 0, ErrBindAlreadyOpen
	}
	var sa windows.Sockaddr
	sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
	if err != nil {
		return nil, 0, err
	}
	sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
	if err != nil {
		return nil, 0, err
	}
	selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
	for i := 0; i < packetsPerRing; i++ {
		err = bind.v4.InsertReceiveRequest()
		if err != nil {
			return nil, 0, err
		}
		err = bind.v6.InsertReceiveRequest()
		if err != nil {
			return nil, 0, err
		}
	}
	bind.isOpen.Store(1)
	return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}

func (bind *WinRingBind) Close() error {
	bind.mu.RLock()
	if bind.isOpen.Load() != 1 {
		bind.mu.RUnlock()
		return nil
	}
	bind.isOpen.Store(2)
	windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
	bind.mu.RUnlock()
	bind.mu.Lock()
	defer bind.mu.Unlock()
	bind.closeAndZero()
	return nil
}

// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (bind *WinRingBind) BatchSize() int {
	// TODO: implement batching in and out of the ring
	return 1
}

func (bind *WinRingBind) SetMark(mark uint32) error {
	return nil
}
func (bind *afWinRingBind) InsertReceiveRequest() error {
	packet := bind.rx.Push()
	dataBuffer := &winrio.Buffer{
		Id:     bind.rx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
		Length: uint32(len(packet.data)),
	}
	addressBuffer := &winrio.Buffer{
		Id:     bind.rx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
		Length: uint32(unsafe.Sizeof(packet.addr)),
	}
	bind.mu.Lock()
	defer bind.mu.Unlock()
	return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
}

//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
	if isOpen.Load() != 1 {
		return 0, nil, net.ErrClosed
	}
	bind.rx.mu.Lock()
	defer bind.rx.mu.Unlock()

	var err error
	var count uint32
	var results [1]winrio.Result
retry:
	count = 0
	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
		if tries > 0 {
			if isOpen.Load() != 1 {
				return 0, nil, net.ErrClosed
			}
			procyield(1)
		}
		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
	}
	if count == 0 {
		err = winrio.Notify(bind.rx.cq)
		if err != nil {
			return 0, nil, err
		}
		var bytes uint32
		var key uintptr
		var overlapped *windows.Overlapped
		err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
		if err != nil {
			return 0, nil, err
		}
		if isOpen.Load() != 1 {
			return 0, nil, net.ErrClosed
		}
		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
		if count == 0 {
			return 0, nil, io.ErrNoProgress
		}
	}
	bind.rx.Return(1)
	err = bind.InsertReceiveRequest()
	if err != nil {
		return 0, nil, err
	}
	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
	// attacker bandwidth, just like the rest of the receive path.
	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
		if isOpen.Load() != 1 {
			return 0, nil, net.ErrClosed
		}
		goto retry
	}
	if results[0].Status != 0 {
		return 0, nil, windows.Errno(results[0].Status)
	}
	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
	ep := packet.addr
	n := copy(buf, packet.data[:results[0].BytesTransferred])
	return n, &ep, nil
}
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
	bind.mu.RLock()
	defer bind.mu.RUnlock()
	n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
	sizes[0] = n
	eps[0] = ep
	return 1, err
}

func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
	bind.mu.RLock()
	defer bind.mu.RUnlock()
	n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
	sizes[0] = n
	eps[0] = ep
	return 1, err
}

func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
	if isOpen.Load() != 1 {
		return net.ErrClosed
	}
	if len(buf) > bytesPerPacket {
		return io.ErrShortBuffer
	}
	bind.tx.mu.Lock()
	defer bind.tx.mu.Unlock()
	var results [packetsPerRing]winrio.Result
	count := winrio.DequeueCompletion(bind.tx.cq, results[:])
	if count == 0 && bind.tx.isFull {
		err := winrio.Notify(bind.tx.cq)
		if err != nil {
			return err
		}
		var bytes uint32
		var key uintptr
		var overlapped *windows.Overlapped
		err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
		if err != nil {
			return err
		}
		if isOpen.Load() != 1 {
			return net.ErrClosed
		}
		count = winrio.DequeueCompletion(bind.tx.cq, results[:])
		if count == 0 {
			return io.ErrNoProgress
		}
	}
	if count > 0 {
		bind.tx.Return(count)
	}
	packet := bind.tx.Push()
	packet.addr = *nend
	copy(packet.data[:], buf)
	dataBuffer := &winrio.Buffer{
		Id:     bind.tx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
		Length: uint32(len(buf)),
	}
	addressBuffer := &winrio.Buffer{
		Id:     bind.tx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
		Length: uint32(unsafe.Sizeof(packet.addr)),
	}
	bind.mu.Lock()
	defer bind.mu.Unlock()
	return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}

func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
	nend, ok := endpoint.(*WinRingEndpoint)
	if !ok {
		return ErrWrongEndpointType
	}
	bind.mu.RLock()
	defer bind.mu.RUnlock()
	for _, buf := range bufs {
		switch nend.family {
		case windows.AF_INET:
			if bind.v4.blackhole {
				continue
			}
			if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
				return err
			}
		case windows.AF_INET6:
			if bind.v6.blackhole {
				continue
			}
			if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
				return err
			}
		}
	}
	return nil
}
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	sysconn, err := s.ipv4.SyscallConn()
	if err != nil {
		return err
	}
	err2 := sysconn.Control(func(fd uintptr) {
		err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
	})
	if err2 != nil {
		return err2
	}
	if err != nil {
		return err
	}
	s.blackhole4 = blackhole
	return nil
}

func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	sysconn, err := s.ipv6.SyscallConn()
	if err != nil {
		return err
	}
	err2 := sysconn.Control(func(fd uintptr) {
		err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
	})
	if err2 != nil {
		return err2
	}
	if err != nil {
		return err
	}
	s.blackhole6 = blackhole
	return nil
}

func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
	bind.mu.RLock()
	defer bind.mu.RUnlock()
	if bind.isOpen.Load() != 1 {
		return net.ErrClosed
	}
	err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
	if err != nil {
		return err
	}
	bind.v4.blackhole = blackhole
	return nil
}

func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
	bind.mu.RLock()
	defer bind.mu.RUnlock()
	if bind.isOpen.Load() != 1 {
		return net.ErrClosed
	}
	err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
	if err != nil {
		return err
	}
	bind.v6.blackhole = blackhole
	return nil
}

func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
	const IP_UNICAST_IF = 31
	/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
	var bytes [4]byte
	binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
	interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
	err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
	if err != nil {
		return err
	}
	return nil
}
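The IPv4 quirk here is that IP_UNICAST_IF wants the interface index in network byte order, unlike its IPv6 counterpart below, which takes a host-order int. So on a little-endian machine, index 7 is passed as 0x07000000. A standalone sketch of the same conversion (assumes "encoding/binary" is imported; the LittleEndian read matches the unsafe native-order reinterpretation only on little-endian hosts, which is the case the comment is about):

// toNetByteOrder reproduces the conversion above: write the index
// big-endian, then read the bytes back in native (little-endian) order.
func toNetByteOrder(ifindex uint32) uint32 {
	var b [4]byte
	binary.BigEndian.PutUint32(b[:], ifindex)
	return binary.LittleEndian.Uint32(b[:]) // 7 -> 0x07000000
}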
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
	const IPV6_UNICAST_IF = 31
	return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
}
vendor/github.com/tailscale/wireguard-go/conn/boundif_android.go (generated, vendored, new file, 34 lines)
@@ -0,0 +1,34 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
	sysconn, err := s.ipv4.SyscallConn()
	if err != nil {
		return -1, err
	}
	err = sysconn.Control(func(f uintptr) {
		fd = int(f)
	})
	if err != nil {
		return -1, err
	}
	return
}

func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
	sysconn, err := s.ipv6.SyscallConn()
	if err != nil {
		return -1, err
	}
	err = sysconn.Control(func(f uintptr) {
		fd = int(f)
	})
	if err != nil {
		return -1, err
	}
	return
}
vendor/github.com/tailscale/wireguard-go/conn/conn.go (generated, vendored, new file, 147 lines)
@@ -0,0 +1,147 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

// Package conn implements WireGuard's network connections.
package conn

import (
	"errors"
	"fmt"
	"net/netip"
	"reflect"
	"runtime"
	"strings"
)

const (
	IdealBatchSize = 128 // maximum number of packets handled per read and write
)

// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)

// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
// depending on the platform-specific implementation.
type Bind interface {
	// Open puts the Bind into a listening state on a given port and reports the actual
	// port that it bound to. Passing zero results in a random selection.
	// fns is the set of functions that will be called to receive packets.
	Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)

	// Close closes the Bind listener.
	// All fns returned by Open must return net.ErrClosed after a call to Close.
	Close() error

	// SetMark sets the mark for each packet sent through this Bind.
	// This mark is passed to the kernel as the socket option SO_MARK.
	SetMark(mark uint32) error

	// Send writes one or more packets in bufs to address ep. The length of
	// bufs must not exceed BatchSize().
	Send(bufs [][]byte, ep Endpoint) error

	// ParseEndpoint creates a new endpoint from a string.
	ParseEndpoint(s string) (Endpoint, error)

	// BatchSize is the number of buffers expected to be passed to
	// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
	BatchSize() int
}
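A hedged sketch of a consumer honoring the ReceiveFunc contract documented above, written as if it lived inside package conn (the helper name and callback shape are inventions of this sketch):

// drain performs one read and dispatches the results, skipping the
// zero-size entries the contract says callers must ignore.
func drain(fn ReceiveFunc, packets [][]byte, sizes []int, eps []Endpoint, handle func([]byte, Endpoint)) error {
	n, err := fn(packets, sizes, eps)
	if err != nil {
		return err
	}
	for i := 0; i < n; i++ {
		if sizes[i] == 0 {
			continue // documented: some elements of sizes may be zero
		}
		handle(packets[i][:sizes[i]], eps[i])
	}
	return nil
}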
// BindSocketToInterface is implemented by Bind objects that support being
// tied to a single network interface. Used by wireguard-windows.
type BindSocketToInterface interface {
	BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
	BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
}

// PeekLookAtSocketFd is implemented by Bind objects that support having their
// file descriptor peeked at. Used by wireguard-android.
type PeekLookAtSocketFd interface {
	PeekLookAtSocketFd4() (fd int, err error)
	PeekLookAtSocketFd6() (fd int, err error)
}

// An Endpoint maintains the source/destination caching for a peer.
//
//	dst: the remote address of a peer ("endpoint" in uapi terminology)
//	src: the local address from which datagrams originate going to the peer
type Endpoint interface {
	ClearSrc()           // clears the source address
	SrcToString() string // returns the local source address (ip:port)
	DstToString() string // returns the destination address (ip:port)
	DstToBytes() []byte  // used for mac2 cookie calculations
	DstIP() netip.Addr
	SrcIP() netip.Addr
}
// PeerAwareEndpoint is an optional Endpoint specialization for
// integrations that want to know the outcome of cryptorouting
// identification.
//
// It lets an integration that receives a packet from a source it had
// not pre-identified learn the identification that WireGuard can
// derive from the session or handshake.
//
// If GetPeerEndpoint returns nil, WireGuard will be unable to respond
// to the peer until a new endpoint is written by a later packet.
type PeerAwareEndpoint interface {
	GetPeerEndpoint(peerPublicKey [32]byte) Endpoint
}
var (
	ErrBindAlreadyOpen   = errors.New("bind is already open")
	ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)

func (fn ReceiveFunc) PrettyName() string {
	name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
	// 0. cheese/taco.beansIPv6.func12.func21218-fm
	name = strings.TrimSuffix(name, "-fm")
	// 1. cheese/taco.beansIPv6.func12.func21218
	if idx := strings.LastIndexByte(name, '/'); idx != -1 {
		name = name[idx+1:]
		// 2. taco.beansIPv6.func12.func21218
	}
	for {
		var idx int
		for idx = len(name) - 1; idx >= 0; idx-- {
			if name[idx] < '0' || name[idx] > '9' {
				break
			}
		}
		if idx == len(name)-1 {
			break
		}
		const dotFunc = ".func"
		if !strings.HasSuffix(name[:idx+1], dotFunc) {
			break
		}
		name = name[:idx+1-len(dotFunc)]
		// 3. taco.beansIPv6.func12
		// 4. taco.beansIPv6
	}
	if idx := strings.LastIndexByte(name, '.'); idx != -1 {
		name = name[idx+1:]
		// 5. beansIPv6
	}
	if name == "" {
		return fmt.Sprintf("%p", fn)
	}
	if strings.HasSuffix(name, "IPv4") {
		return "v4"
	}
	if strings.HasSuffix(name, "IPv6") {
		return "v6"
	}
	return name
}
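A trace of PrettyName for the closure StdNetBind.makeReceiveIPv4 returns (the exact runtime symbol name is an assumption, not verified against a build):

// input : github.com/tailscale/wireguard-go/conn.(*StdNetBind).makeReceiveIPv4.func1-fm
// step 0: trim "-fm"            -> ...makeReceiveIPv4.func1
// step 1: drop path before '/'  -> conn.(*StdNetBind).makeReceiveIPv4.func1
// loop  : strip ".funcN" chains -> conn.(*StdNetBind).makeReceiveIPv4
// last  : keep after final '.'  -> makeReceiveIPv4
// suffix: ends in "IPv4"        -> "v4"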
vendor/github.com/tailscale/wireguard-go/conn/control_default.go (generated, vendored, new file, 51 lines)
@@ -0,0 +1,51 @@
//go:build !(linux && !android)

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import "net/netip"

func (e *StdNetEndpoint) SrcIP() netip.Addr {
	return netip.Addr{}
}

func (e *StdNetEndpoint) SrcIfidx() int32 {
	return 0
}

func (e *StdNetEndpoint) SrcToString() string {
	return ""
}

// TODO: macOS, FreeBSD, and other BSDs likely do support the sticky sockets
// ({get,set}SrcControl) feature set, but use alternatively named flags and
// need ports and testing.

// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found. It is a no-op on this platform.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}

// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. It is a no-op on this platform.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}

// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
	return 0, nil
}

// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}

// controlSize is the recommended buffer size for pooling sticky and UDP
// offloading control data.
const controlSize = 0

const StdNetSupportsStickySockets = false
vendor/github.com/tailscale/wireguard-go/conn/control_linux.go (generated, vendored, new file, 159 lines)
@@ -0,0 +1,159 @@
//go:build linux && !android

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"fmt"
	"net/netip"
	"unsafe"

	"golang.org/x/sys/unix"
)

func (e *StdNetEndpoint) SrcIP() netip.Addr {
	switch len(e.src) {
	case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
		return netip.AddrFrom4(info.Spec_dst)
	case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
		// TODO: set zone. in order to do so we need to check if the address is
		// link local, and if it is perform a syscall to turn the ifindex into a
		// zone string because netip uses string zones.
		return netip.AddrFrom16(info.Addr)
	}
	return netip.Addr{}
}

func (e *StdNetEndpoint) SrcIfidx() int32 {
	switch len(e.src) {
	case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
		return info.Ifindex
	case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
		return int32(info.Ifindex)
	}
	return 0
}

func (e *StdNetEndpoint) SrcToString() string {
	return e.SrcIP().String()
}

// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
	ep.ClearSrc()

	var (
		hdr  unix.Cmsghdr
		data []byte
		rem  []byte = control
		err  error
	)

	for len(rem) > unix.SizeofCmsghdr {
		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
		if err != nil {
			return
		}

		if hdr.Level == unix.IPPROTO_IP &&
			hdr.Type == unix.IP_PKTINFO {

			if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
				ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
			}
			ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]

			hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
			copy(ep.src, hdrBuf)
			copy(ep.src[unix.CmsgLen(0):], data)
			return
		}

		if hdr.Level == unix.IPPROTO_IPV6 &&
			hdr.Type == unix.IPV6_PKTINFO {

			if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
				ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
			}

			ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]

			hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
			copy(ep.src, hdrBuf)
			copy(ep.src[unix.CmsgLen(0):], data)
			return
		}
	}
}

// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
	if cap(*control) < len(ep.src) {
		return
	}
	*control = (*control)[:0]
	*control = append(*control, ep.src...)
}

const (
	sizeOfGSOData = 2
)

// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
	var (
		hdr  unix.Cmsghdr
		data []byte
		rem  = control
		err  error
	)

	for len(rem) > unix.SizeofCmsghdr {
		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
		if err != nil {
			return 0, fmt.Errorf("error parsing socket control message: %w", err)
		}
		if hdr.Level == socketOptionLevelUDP && hdr.Type == socketOptionUDPGRO && len(data) >= sizeOfGSOData {
			var gso uint16
			copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
			return int(gso), nil
		}
	}
	return 0, nil
}

// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
	existingLen := len(*control)
	avail := cap(*control) - existingLen
	space := unix.CmsgSpace(sizeOfGSOData)
	if avail < space {
		return
	}
	*control = (*control)[:cap(*control)]
	gsoControl := (*control)[existingLen:]
	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
	hdr.Level = socketOptionLevelUDP
	hdr.Type = socketOptionUDPSegment
	hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
	copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
	*control = (*control)[:existingLen+space]
}

// controlSize is the recommended buffer size for pooling sticky and UDP
// offloading control data.
var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData)

const StdNetSupportsStickySockets = true
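The cmsg arithmetic above is worth a worked example. Control messages are length-prefixed TLVs whose header and payload are each padded to the platform's alignment; on linux/amd64 a Cmsghdr is 16 bytes and alignment is 8, so CmsgSpace(2) = 16 + 8 = 24 and CmsgSpace(SizeofInet6Pktinfo) = 16 + 24 = 40 (SizeofInet6Pktinfo is 20). A sketch recomputing controlSize from those pieces (values assume linux/amd64):

// controlBufferSize recomputes the controlSize value above: room for one
// PKTINFO cmsg plus one UDP GRO/GSO cmsg — 40 + 24 = 64 bytes on amd64.
func controlBufferSize() int {
	return unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData)
}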
vendor/github.com/tailscale/wireguard-go/conn/controlfns.go (generated, vendored, new file, 43 lines)
@@ -0,0 +1,43 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"net"
	"syscall"
)

// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
// the max supported by a default configuration of macOS. Some platforms will
// silently clamp the value to other maximums, such as linux clamping to
// net.core.{r,w}mem_max (see _linux.go for an additional implementation that
// works around this limitation).
const socketBufferSize = 7 << 20

// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform-specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error

// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}

// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
	return &net.ListenConfig{
		Control: func(network, address string, c syscall.RawConn) error {
			for _, fn := range controlFns {
				if err := fn(network, address, c); err != nil {
					return err
				}
			}
			return nil
		},
	}
}
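A hedged sketch of the controlFn pattern, as it might appear in one of the platform files below: RawConn.Control runs a callback with the raw fd, and any setsockopt error is smuggled out through a captured variable. SO_REUSEADDR is an arbitrary illustrative option, not one this package actually sets:

var exampleFn controlFn = func(network, address string, c syscall.RawConn) error {
	var serr error
	if err := c.Control(func(fd uintptr) {
		serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
	}); err != nil {
		return err // Control itself failed (e.g. the conn is closed)
	}
	return serr // the setsockopt result
}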
vendor/github.com/tailscale/wireguard-go/conn/controlfns_linux.go (generated, vendored, new file, 69 lines)
@@ -0,0 +1,69 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"fmt"
	"runtime"
	"syscall"

	"golang.org/x/sys/unix"
)

func init() {
	controlFns = append(controlFns,

		// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
		// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
		// fail silently - the result of failure is lower performance on very fast
		// links or high latency links.
		func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				// Set up to *mem_max
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
				// Set beyond *mem_max if CAP_NET_ADMIN
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
			})
		},

		// Enable receiving of the packet information (IP_PKTINFO for IPv4,
		// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
		func(network, address string, c syscall.RawConn) error {
			var err error
			switch network {
			case "udp4":
				if runtime.GOOS != "android" {
					c.Control(func(fd uintptr) {
						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
					})
				}
			case "udp6":
				c.Control(func(fd uintptr) {
					if runtime.GOOS != "android" {
						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
						if err != nil {
							return
						}
					}
					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
				})
			default:
				err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
			}
			return err
		},

		// Attempt to enable UDP_GRO
		func(network, address string, c syscall.RawConn) error {
			c.Control(func(fd uintptr) {
				_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO, 1)
			})
			return nil
		},
	)
}
vendor/github.com/tailscale/wireguard-go/conn/controlfns_unix.go (generated, vendored, new file, 35 lines)
@@ -0,0 +1,35 @@
//go:build !windows && !linux && !wasm && !plan9 && !tamago

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"syscall"

	"golang.org/x/sys/unix"
)

func init() {
	controlFns = append(controlFns,
		func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
			})
		},

		func(network, address string, c syscall.RawConn) error {
			var err error
			if network == "udp6" {
				c.Control(func(fd uintptr) {
					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
				})
			}
			return err
		},
	)
}
vendor/github.com/tailscale/wireguard-go/conn/controlfns_windows.go (generated, vendored, new file, 23 lines)
@@ -0,0 +1,23 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"syscall"

	"golang.org/x/sys/windows"
)

func init() {
	controlFns = append(controlFns,
		func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
				_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
			})
		},
	)
}
10
vendor/github.com/tailscale/wireguard-go/conn/default.go
generated
vendored
Normal file
@@ -0,0 +1,10 @@
//go:build !windows

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

func NewDefaultBind() Bind { return NewStdNetBind() }
14
vendor/github.com/tailscale/wireguard-go/conn/erraddrinuse.go
generated
vendored
Normal file
@@ -0,0 +1,14 @@
//go:build !plan9

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import "syscall"

func init() {
    errEADDRINUSE = syscall.EADDRINUSE
}
12
vendor/github.com/tailscale/wireguard-go/conn/errors_default.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
//go:build !linux

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

func errShouldDisableUDPGSO(err error) bool {
    return false
}
26
vendor/github.com/tailscale/wireguard-go/conn/errors_linux.go
generated
vendored
Normal file
@@ -0,0 +1,26 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
    "errors"
    "os"

    "golang.org/x/sys/unix"
)

func errShouldDisableUDPGSO(err error) bool {
    var serr *os.SyscallError
    if errors.As(err, &serr) {
        // EIO is returned by udp_send_skb() if the device driver does not have
        // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
        // See:
        // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
        // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
        return serr.Err == unix.EIO
    }
    return false
}
15
vendor/github.com/tailscale/wireguard-go/conn/features_default.go
generated
vendored
Normal file
@@ -0,0 +1,15 @@
//go:build !linux
// +build !linux

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import "net"

func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
    return
}
42
vendor/github.com/tailscale/wireguard-go/conn/features_linux.go
generated
vendored
Normal file
@@ -0,0 +1,42 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
    "net"

    "golang.org/x/sys/unix"
)

const (
    // TODO: upstream to x/sys/unix
    socketOptionLevelUDP   = 17
    socketOptionUDPSegment = 103
    socketOptionUDPGRO     = 104
)

func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
    rc, err := conn.SyscallConn()
    if err != nil {
        return
    }
    err = rc.Control(func(fd uintptr) {
        _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPSegment)
        if errSyscall != nil {
            return
        }
        txOffload = true
        opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO)
        if errSyscall != nil {
            return
        }
        rxOffload = opt == 1
    })
    if err != nil {
        return false, false
    }
    return txOffload, rxOffload
}
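A minimal usage sketch, assuming Linux, with imports and error handling elided:

    // Probe a freshly opened socket once and cache the answer; offload
    // support is a property of the running kernel, not of individual packets.
    udpConn, _ := net.ListenUDP("udp4", &net.UDPAddr{Port: 0})
    tx, rx := supportsUDPOffload(udpConn)
    fmt.Printf("UDP_SEGMENT (GSO) tx=%v, UDP_GRO rx=%v\n", tx, rx)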
12
vendor/github.com/tailscale/wireguard-go/conn/mark_default.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
//go:build !linux && !openbsd && !freebsd

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

func (s *StdNetBind) SetMark(mark uint32) error {
    return nil
}
65
vendor/github.com/tailscale/wireguard-go/conn/mark_unix.go
generated
vendored
Normal file
@@ -0,0 +1,65 @@
//go:build linux || openbsd || freebsd

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
    "runtime"

    "golang.org/x/sys/unix"
)

var fwmarkIoctl int

func init() {
    switch runtime.GOOS {
    case "linux", "android":
        fwmarkIoctl = 36 /* unix.SO_MARK */
    case "freebsd":
        fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
    case "openbsd":
        fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
    }
}

func (s *StdNetBind) SetMark(mark uint32) error {
    var operr error
    if fwmarkIoctl == 0 {
        return nil
    }
    if s.ipv4 != nil {
        fd, err := s.ipv4.SyscallConn()
        if err != nil {
            return err
        }
        err = fd.Control(func(fd uintptr) {
            operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
        })
        if err == nil {
            err = operr
        }
        if err != nil {
            return err
        }
    }
    if s.ipv6 != nil {
        fd, err := s.ipv6.SyscallConn()
        if err != nil {
            return err
        }
        err = fd.Control(func(fd uintptr) {
            operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
        })
        if err == nil {
            err = operr
        }
        if err != nil {
            return err
        }
    }
    return nil
}
254
vendor/github.com/tailscale/wireguard-go/conn/winrio/rio_windows.go
generated
vendored
Normal file
@@ -0,0 +1,254 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package winrio

import (
    "log"
    "sync"
    "syscall"
    "unsafe"

    "golang.org/x/sys/windows"
)

const (
    MsgDontNotify = 1
    MsgDefer      = 2
    MsgWaitAll    = 4
    MsgCommitOnly = 8

    MaxCqSize = 0x8000000

    invalidBufferId = 0xFFFFFFFF
    invalidCq       = 0
    invalidRq       = 0
    corruptCq       = 0xFFFFFFFF
)

var extensionFunctionTable struct {
    cbSize                   uint32
    rioReceive               uintptr
    rioReceiveEx             uintptr
    rioSend                  uintptr
    rioSendEx                uintptr
    rioCloseCompletionQueue  uintptr
    rioCreateCompletionQueue uintptr
    rioCreateRequestQueue    uintptr
    rioDequeueCompletion     uintptr
    rioDeregisterBuffer      uintptr
    rioNotify                uintptr
    rioRegisterBuffer        uintptr
    rioResizeCompletionQueue uintptr
    rioResizeRequestQueue    uintptr
}

type Cq uintptr

type Rq uintptr

type BufferId uintptr

type Buffer struct {
    Id     BufferId
    Offset uint32
    Length uint32
}

type Result struct {
    Status           int32
    BytesTransferred uint32
    SocketContext    uint64
    RequestContext   uint64
}

type notificationCompletionType uint32

const (
    eventCompletion notificationCompletionType = 1
    iocpCompletion  notificationCompletionType = 2
)

type eventNotificationCompletion struct {
    completionType notificationCompletionType
    event          windows.Handle
    notifyReset    uint32
}

type iocpNotificationCompletion struct {
    completionType notificationCompletionType
    iocp           windows.Handle
    key            uintptr
    overlapped     *windows.Overlapped
}

var (
    initialized sync.Once
    available   bool
)

func Initialize() bool {
    initialized.Do(func() {
        var (
            err    error
            socket windows.Handle
            cq     Cq
        )
        defer func() {
            if err == nil {
                return
            }
            if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
                return
            }
            log.Printf("Registered I/O is unavailable: %v", err)
        }()
        socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
        if err != nil {
            return
        }
        defer windows.CloseHandle(socket)
        WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
        const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
        ob := uint32(0)
        err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
            (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
            (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
            &ob, nil, 0)
        if err != nil {
            return
        }

        // While we should be able to stop here, after getting the function pointers, some anti-virus
        // software actually causes failures in RIOCreateRequestQueue, so keep going to be certain
        // this is supported.
        var iocp windows.Handle
        iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
        if err != nil {
            return
        }
        defer windows.CloseHandle(iocp)
        var overlapped windows.Overlapped
        cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
        if err != nil {
            return
        }
        defer CloseCompletionQueue(cq)
        _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
        if err != nil {
            return
        }
        available = true
    })
    return available
}

func Socket(af, typ, proto int32) (windows.Handle, error) {
    return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
}

func CloseCompletionQueue(cq Cq) {
    _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
}

func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
    notificationCompletion := &eventNotificationCompletion{
        completionType: eventCompletion,
        event:          event,
    }
    if notifyReset {
        notificationCompletion.notifyReset = 1
    }
    ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
    if ret == invalidCq {
        return 0, err
    }
    return Cq(ret), nil
}

func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
    notificationCompletion := &iocpNotificationCompletion{
        completionType: iocpCompletion,
        iocp:           iocp,
        key:            key,
        overlapped:     overlapped,
    }
    ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
    if ret == invalidCq {
        return 0, err
    }
    return Cq(ret), nil
}

func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
    ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
    if ret == invalidCq {
        return 0, err
    }
    return Cq(ret), nil
}

func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
    ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
    if ret == invalidRq {
        return 0, err
    }
    return Rq(ret), nil
}

func DequeueCompletion(cq Cq, results []Result) uint32 {
    var array uintptr
    if len(results) > 0 {
        array = uintptr(unsafe.Pointer(&results[0]))
    }
    ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
    if ret == corruptCq {
        panic("cq is corrupt")
    }
    return uint32(ret)
}

func DeregisterBuffer(id BufferId) {
    _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
}

func RegisterBuffer(buffer []byte) (BufferId, error) {
    var buf unsafe.Pointer
    if len(buffer) > 0 {
        buf = unsafe.Pointer(&buffer[0])
    }
    return RegisterPointer(buf, uint32(len(buffer)))
}

func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
    ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
    if ret == invalidBufferId {
        return 0, err
    }
    return BufferId(ret), nil
}

func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
    ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
    if ret == 0 {
        return err
    }
    return nil
}

func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
    ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
    if ret == 0 {
        return err
    }
    return nil
}

func Notify(cq Cq) error {
    ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
    if ret != 0 {
        return windows.Errno(ret)
    }
    return nil
}
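A minimal sketch of the intended call sequence for this package, assuming Windows with RIO support, golang.org/x/sys/windows imported as windows, and error handling elided; the buffer and queue sizes are illustrative:

    // Guard all Registered I/O usage behind Initialize, which probes for
    // RIO support once and caches the result.
    if !winrio.Initialize() {
        return // fall back to a conventional UDP socket path
    }
    sock, _ := winrio.Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
    cq, _ := winrio.CreatePolledCompletionQueue(64)
    defer winrio.CloseCompletionQueue(cq)
    rq, _ := winrio.CreateRequestQueue(sock, 32, 1, 32, 1, cq, cq, 0)

    // RIO only transfers to and from pre-registered memory.
    packet := make([]byte, 2048)
    id, _ := winrio.RegisterBuffer(packet)
    defer winrio.DeregisterBuffer(id)

    buf := winrio.Buffer{Id: id, Offset: 0, Length: uint32(len(packet))}
    _ = winrio.ReceiveEx(rq, &buf, 1, nil, nil, nil, nil, 0, 0)

    // Poll completions; each Result reports bytes transferred plus the
    // socket/request contexts supplied when the operation was queued.
    results := make([]winrio.Result, 16)
    n := winrio.DequeueCompletion(cq, results)
    _ = n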
294
vendor/github.com/tailscale/wireguard-go/device/allowedips.go
generated
vendored
Normal file
@@ -0,0 +1,294 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
    "container/list"
    "encoding/binary"
    "errors"
    "math/bits"
    "net"
    "net/netip"
    "sync"
    "unsafe"
)

type parentIndirection struct {
    parentBit     **trieEntry
    parentBitType uint8
}

type trieEntry struct {
    peer        *Peer
    child       [2]*trieEntry
    parent      parentIndirection
    cidr        uint8
    bitAtByte   uint8
    bitAtShift  uint8
    bits        []byte
    perPeerElem *list.Element
}

func commonBits(ip1, ip2 []byte) uint8 {
    size := len(ip1)
    if size == net.IPv4len {
        a := binary.BigEndian.Uint32(ip1)
        b := binary.BigEndian.Uint32(ip2)
        x := a ^ b
        return uint8(bits.LeadingZeros32(x))
    } else if size == net.IPv6len {
        a := binary.BigEndian.Uint64(ip1)
        b := binary.BigEndian.Uint64(ip2)
        x := a ^ b
        if x != 0 {
            return uint8(bits.LeadingZeros64(x))
        }
        a = binary.BigEndian.Uint64(ip1[8:])
        b = binary.BigEndian.Uint64(ip2[8:])
        x = a ^ b
        return 64 + uint8(bits.LeadingZeros64(x))
    } else {
        panic("Wrong size bit string")
    }
}
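// A worked example of commonBits: for the IPv4 addresses 10.0.0.1 and
// 10.0.2.1,
//
//	10.0.0.1 = 00001010 00000000 00000000 00000001
//	10.0.2.1 = 00001010 00000000 00000010 00000001
//	XOR      = 00000000 00000000 00000010 00000000
//
// the XOR has 22 leading zero bits, so commonBits returns 22: the two
// addresses agree on their first 22 bits. The trie below keys its branching
// decisions on exactly this shared-prefix length.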

func (node *trieEntry) addToPeerEntries() {
    node.perPeerElem = node.peer.trieEntries.PushBack(node)
}

func (node *trieEntry) removeFromPeerEntries() {
    if node.perPeerElem != nil {
        node.peer.trieEntries.Remove(node.perPeerElem)
        node.perPeerElem = nil
    }
}

func (node *trieEntry) choose(ip []byte) byte {
    return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}

func (node *trieEntry) maskSelf() {
    mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
    for i := 0; i < len(mask); i++ {
        node.bits[i] &= mask[i]
    }
}

func (node *trieEntry) zeroizePointers() {
    // Make the garbage collector's life slightly easier
    node.peer = nil
    node.child[0] = nil
    node.child[1] = nil
    node.parent.parentBit = nil
}

func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
    for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
        parent = node
        if parent.cidr == cidr {
            exact = true
            return
        }
        bit := node.choose(ip)
        node = node.child[bit]
    }
    return
}

func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
    if *trie.parentBit == nil {
        node := &trieEntry{
            peer:       peer,
            parent:     trie,
            bits:       ip,
            cidr:       cidr,
            bitAtByte:  cidr / 8,
            bitAtShift: 7 - (cidr % 8),
        }
        node.maskSelf()
        node.addToPeerEntries()
        *trie.parentBit = node
        return
    }
    node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
    if exact {
        node.removeFromPeerEntries()
        node.peer = peer
        node.addToPeerEntries()
        return
    }

    newNode := &trieEntry{
        peer:       peer,
        bits:       ip,
        cidr:       cidr,
        bitAtByte:  cidr / 8,
        bitAtShift: 7 - (cidr % 8),
    }
    newNode.maskSelf()
    newNode.addToPeerEntries()

    var down *trieEntry
    if node == nil {
        down = *trie.parentBit
    } else {
        bit := node.choose(ip)
        down = node.child[bit]
        if down == nil {
            newNode.parent = parentIndirection{&node.child[bit], bit}
            node.child[bit] = newNode
            return
        }
    }
    common := commonBits(down.bits, ip)
    if common < cidr {
        cidr = common
    }
    parent := node

    if newNode.cidr == cidr {
        bit := newNode.choose(down.bits)
        down.parent = parentIndirection{&newNode.child[bit], bit}
        newNode.child[bit] = down
        if parent == nil {
            newNode.parent = trie
            *trie.parentBit = newNode
        } else {
            bit := parent.choose(newNode.bits)
            newNode.parent = parentIndirection{&parent.child[bit], bit}
            parent.child[bit] = newNode
        }
        return
    }

    node = &trieEntry{
        bits:       append([]byte{}, newNode.bits...),
        cidr:       cidr,
        bitAtByte:  cidr / 8,
        bitAtShift: 7 - (cidr % 8),
    }
    node.maskSelf()

    bit := node.choose(down.bits)
    down.parent = parentIndirection{&node.child[bit], bit}
    node.child[bit] = down
    bit = node.choose(newNode.bits)
    newNode.parent = parentIndirection{&node.child[bit], bit}
    node.child[bit] = newNode
    if parent == nil {
        node.parent = trie
        *trie.parentBit = node
    } else {
        bit := parent.choose(node.bits)
        node.parent = parentIndirection{&parent.child[bit], bit}
        parent.child[bit] = node
    }
}

func (node *trieEntry) lookup(ip []byte) *Peer {
    var found *Peer
    size := uint8(len(ip))
    for node != nil && commonBits(node.bits, ip) >= node.cidr {
        if node.peer != nil {
            found = node.peer
        }
        if node.bitAtByte == size {
            break
        }
        bit := node.choose(ip)
        node = node.child[bit]
    }
    return found
}

type AllowedIPs struct {
    IPv4  *trieEntry
    IPv6  *trieEntry
    mutex sync.RWMutex
}

func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
    table.mutex.RLock()
    defer table.mutex.RUnlock()

    for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
        node := elem.Value.(*trieEntry)
        a, _ := netip.AddrFromSlice(node.bits)
        if !cb(netip.PrefixFrom(a, int(node.cidr))) {
            return
        }
    }
}

func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
    table.mutex.Lock()
    defer table.mutex.Unlock()

    var next *list.Element
    for elem := peer.trieEntries.Front(); elem != nil; elem = next {
        next = elem.Next()
        node := elem.Value.(*trieEntry)

        node.removeFromPeerEntries()
        node.peer = nil
        if node.child[0] != nil && node.child[1] != nil {
            continue
        }
        bit := 0
        if node.child[0] == nil {
            bit = 1
        }
        child := node.child[bit]
        if child != nil {
            child.parent = node.parent
        }
        *node.parent.parentBit = child
        if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
            node.zeroizePointers()
            continue
        }
        parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
        if parent.peer != nil {
            node.zeroizePointers()
            continue
        }
        child = parent.child[node.parent.parentBitType^1]
        if child != nil {
            child.parent = parent.parent
        }
        *parent.parent.parentBit = child
        node.zeroizePointers()
        parent.zeroizePointers()
    }
}

func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
    table.mutex.Lock()
    defer table.mutex.Unlock()

    if prefix.Addr().Is6() {
        ip := prefix.Addr().As16()
        parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
    } else if prefix.Addr().Is4() {
        ip := prefix.Addr().As4()
        parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
    } else {
        panic(errors.New("inserting unknown address type"))
    }
}

func (table *AllowedIPs) Lookup(ip []byte) *Peer {
    table.mutex.RLock()
    defer table.mutex.RUnlock()
    switch len(ip) {
    case net.IPv6len:
        return table.IPv6.lookup(ip)
    case net.IPv4len:
        return table.IPv4.lookup(ip)
    default:
        panic(errors.New("looking up unknown address type"))
    }
}
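A minimal usage sketch, assuming peerA and peerB are *Peer values already registered with the device:

    var table AllowedIPs
    table.Insert(netip.MustParsePrefix("10.0.0.0/16"), peerB)
    table.Insert(netip.MustParsePrefix("10.0.0.0/24"), peerA)

    // Lookup is longest-prefix match: 10.0.0.7 falls inside both prefixes,
    // but the more specific /24 wins.
    _ = table.Lookup([]byte{10, 0, 0, 7})  // returns peerA
    _ = table.Lookup([]byte{10, 0, 99, 7}) // returns peerB (only the /16 matches)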
137
vendor/github.com/tailscale/wireguard-go/device/channels.go
generated
vendored
Normal file
@@ -0,0 +1,137 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
    "runtime"
    "sync"
)

// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
// An outboundQueue is ref-counted using its wg field.
// An outboundQueue created with newOutboundQueue has one reference.
// Every additional writer must call wg.Add(1).
// Every completed writer must call wg.Done().
// When no further writers will be added,
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
    c  chan *QueueOutboundElementsContainer
    wg sync.WaitGroup
}

func newOutboundQueue() *outboundQueue {
    q := &outboundQueue{
        c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
    }
    q.wg.Add(1)
    go func() {
        q.wg.Wait()
        close(q.c)
    }()
    return q
}
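// A minimal sketch of the ref-count protocol documented above (the writer
// body is illustrative):
//
//	q := newOutboundQueue()
//	for i := 0; i < nWriters; i++ {
//		q.wg.Add(1) // take a reference before the writer starts
//		go func() {
//			defer q.wg.Done() // drop it when this writer is finished
//			// ... send *QueueOutboundElementsContainer values on q.c ...
//		}()
//	}
//	q.wg.Done() // drop the initial reference; q.c closes once all writers exit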

// An inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
    c  chan *QueueInboundElementsContainer
    wg sync.WaitGroup
}

func newInboundQueue() *inboundQueue {
    q := &inboundQueue{
        c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
    }
    q.wg.Add(1)
    go func() {
        q.wg.Wait()
        close(q.c)
    }()
    return q
}

// A handshakeQueue is similar to an outboundQueue; see those docs.
type handshakeQueue struct {
    c  chan QueueHandshakeElement
    wg sync.WaitGroup
}

func newHandshakeQueue() *handshakeQueue {
    q := &handshakeQueue{
        c: make(chan QueueHandshakeElement, QueueHandshakeSize),
    }
    q.wg.Add(1)
    go func() {
        q.wg.Wait()
        close(q.c)
    }()
    return q
}

type autodrainingInboundQueue struct {
    c chan *QueueInboundElementsContainer
}

// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which it is hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
    q := &autodrainingInboundQueue{
        c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
    }
    runtime.SetFinalizer(q, device.flushInboundQueue)
    return q
}

func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
    for {
        select {
        case elemsContainer := <-q.c:
            elemsContainer.Lock()
            for _, elem := range elemsContainer.elems {
                device.PutMessageBuffer(elem.buffer)
                device.PutInboundElement(elem)
            }
            device.PutInboundElementsContainer(elemsContainer)
        default:
            return
        }
    }
}

type autodrainingOutboundQueue struct {
    c chan *QueueOutboundElementsContainer
}

// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which it is hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending sentinel nil values.
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
    q := &autodrainingOutboundQueue{
        c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
    }
    runtime.SetFinalizer(q, device.flushOutboundQueue)
    return q
}

func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
    for {
        select {
        case elemsContainer := <-q.c:
            elemsContainer.Lock()
            for _, elem := range elemsContainer.elems {
                device.PutMessageBuffer(elem.buffer)
                device.PutOutboundElement(elem)
            }
            device.PutOutboundElementsContainer(elemsContainer)
        default:
            return
        }
    }
}
40
vendor/github.com/tailscale/wireguard-go/device/constants.go
generated
vendored
Normal file
@@ -0,0 +1,40 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
    "time"
)

/* Specification constants */

const (
    RekeyAfterMessages      = (1 << 60)
    RejectAfterMessages     = (1 << 64) - (1 << 13) - 1
    RekeyAfterTime          = time.Second * 120
    RekeyAttemptTime        = time.Second * 90
    RekeyTimeout            = time.Second * 5
    MaxTimerHandshakes      = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
    RekeyTimeoutJitterMaxMs = 334
    RejectAfterTime         = time.Second * 180
    KeepaliveTimeout        = time.Second * 10
    CookieRefreshTime       = time.Second * 120
    HandshakeInitationRate  = time.Second / 50
    PaddingMultiple         = 16
)

const (
    MinMessageSize = MessageKeepaliveSize                  // minimum size of transport message (keepalive)
    MaxMessageSize = MaxSegmentSize                        // maximum size of transport message
    MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
)

/* Implementation constants */

const (
    UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
    MaxPeers           = 1 << 16     // maximum number of configured peers
)
248
vendor/github.com/tailscale/wireguard-go/device/cookie.go
generated
vendored
Normal file
@@ -0,0 +1,248 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
    "crypto/hmac"
    "crypto/rand"
    "sync"
    "time"

    "golang.org/x/crypto/blake2s"
    "golang.org/x/crypto/chacha20poly1305"
)

type CookieChecker struct {
    sync.RWMutex
    mac1 struct {
        key [blake2s.Size]byte
    }
    mac2 struct {
        secret        [blake2s.Size]byte
        secretSet     time.Time
        encryptionKey [chacha20poly1305.KeySize]byte
    }
}

type CookieGenerator struct {
    sync.RWMutex
    mac1 struct {
        key [blake2s.Size]byte
    }
    mac2 struct {
        cookie        [blake2s.Size128]byte
        cookieSet     time.Time
        hasLastMAC1   bool
        lastMAC1      [blake2s.Size128]byte
        encryptionKey [chacha20poly1305.KeySize]byte
    }
}

func (st *CookieChecker) Init(pk NoisePublicKey) {
    st.Lock()
    defer st.Unlock()

    // mac1 state

    func() {
        hash, _ := blake2s.New256(nil)
        hash.Write([]byte(WGLabelMAC1))
        hash.Write(pk[:])
        hash.Sum(st.mac1.key[:0])
    }()

    // mac2 state

    func() {
        hash, _ := blake2s.New256(nil)
        hash.Write([]byte(WGLabelCookie))
        hash.Write(pk[:])
        hash.Sum(st.mac2.encryptionKey[:0])
    }()

    st.mac2.secretSet = time.Time{}
}

func (st *CookieChecker) CheckMAC1(msg []byte) bool {
    st.RLock()
    defer st.RUnlock()

    size := len(msg)
    smac2 := size - blake2s.Size128
    smac1 := smac2 - blake2s.Size128

    var mac1 [blake2s.Size128]byte

    mac, _ := blake2s.New128(st.mac1.key[:])
    mac.Write(msg[:smac1])
    mac.Sum(mac1[:0])

    return hmac.Equal(mac1[:], msg[smac1:smac2])
}
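// The trailer layout implied by the offsets above: every handshake message
// ends in two 16-byte (blake2s.Size128) MACs,
//
//	msg = body || mac1 (16 bytes) || mac2 (16 bytes)
//	smac1 = len(msg) - 32 // start of mac1
//	smac2 = len(msg) - 16 // start of mac2
//
// so CheckMAC1 recomputes keyed blake2s-128 over everything before mac1 and
// compares it against the stored mac1 in constant time via hmac.Equal.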

func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
    st.RLock()
    defer st.RUnlock()

    if time.Since(st.mac2.secretSet) > CookieRefreshTime {
        return false
    }

    // derive cookie key

    var cookie [blake2s.Size128]byte
    func() {
        mac, _ := blake2s.New128(st.mac2.secret[:])
        mac.Write(src)
        mac.Sum(cookie[:0])
    }()

    // calculate mac of packet (including mac1)

    smac2 := len(msg) - blake2s.Size128

    var mac2 [blake2s.Size128]byte
    func() {
        mac, _ := blake2s.New128(cookie[:])
        mac.Write(msg[:smac2])
        mac.Sum(mac2[:0])
    }()

    return hmac.Equal(mac2[:], msg[smac2:])
}

func (st *CookieChecker) CreateReply(
    msg []byte,
    recv uint32,
    src []byte,
) (*MessageCookieReply, error) {
    st.RLock()

    // refresh cookie secret

    if time.Since(st.mac2.secretSet) > CookieRefreshTime {
        st.RUnlock()
        st.Lock()
        _, err := rand.Read(st.mac2.secret[:])
        if err != nil {
            st.Unlock()
            return nil, err
        }
        st.mac2.secretSet = time.Now()
        st.Unlock()
        st.RLock()
    }

    // derive cookie

    var cookie [blake2s.Size128]byte
    func() {
        mac, _ := blake2s.New128(st.mac2.secret[:])
        mac.Write(src)
        mac.Sum(cookie[:0])
    }()

    // encrypt cookie

    size := len(msg)

    smac2 := size - blake2s.Size128
    smac1 := smac2 - blake2s.Size128

    reply := new(MessageCookieReply)
    reply.Type = MessageCookieReplyType
    reply.Receiver = recv

    _, err := rand.Read(reply.Nonce[:])
    if err != nil {
        st.RUnlock()
        return nil, err
    }

    xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
    xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])

    st.RUnlock()

    return reply, nil
}
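// A note on the locking in CreateReply above: Go's sync.RWMutex cannot be
// upgraded atomically, so the secret refresh releases the read lock and
// re-acquires the write lock. Under contention the secret may therefore be
// refreshed more than once in quick succession; that is harmless, since each
// refresh stores a fresh random secret with a current timestamp, and the
// cookie is derived only after the read lock has been re-taken.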

func (st *CookieGenerator) Init(pk NoisePublicKey) {
    st.Lock()
    defer st.Unlock()

    func() {
        hash, _ := blake2s.New256(nil)
        hash.Write([]byte(WGLabelMAC1))
        hash.Write(pk[:])
        hash.Sum(st.mac1.key[:0])
    }()

    func() {
        hash, _ := blake2s.New256(nil)
        hash.Write([]byte(WGLabelCookie))
        hash.Write(pk[:])
        hash.Sum(st.mac2.encryptionKey[:0])
    }()

    st.mac2.cookieSet = time.Time{}
}

func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
    st.Lock()
    defer st.Unlock()

    if !st.mac2.hasLastMAC1 {
        return false
    }

    var cookie [blake2s.Size128]byte

    xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
    _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
    if err != nil {
        return false
    }

    st.mac2.cookieSet = time.Now()
    st.mac2.cookie = cookie
    return true
}

func (st *CookieGenerator) AddMacs(msg []byte) {
    size := len(msg)

    smac2 := size - blake2s.Size128
    smac1 := smac2 - blake2s.Size128

    mac1 := msg[smac1:smac2]
    mac2 := msg[smac2:]

    st.Lock()
    defer st.Unlock()

    // set mac1

    func() {
        mac, _ := blake2s.New128(st.mac1.key[:])
        mac.Write(msg[:smac1])
        mac.Sum(mac1[:0])
    }()
    copy(st.mac2.lastMAC1[:], mac1)
    st.mac2.hasLastMAC1 = true

    // set mac2

    if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
        return
    }

    func() {
        mac, _ := blake2s.New128(st.mac2.cookie[:])
        mac.Write(msg[:smac2])
        mac.Sum(mac2[:0])
    }()
}
536
vendor/github.com/tailscale/wireguard-go/device/device.go
generated
vendored
Normal file
@@ -0,0 +1,536 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
    "runtime"
    "sync"
    "sync/atomic"
    "time"

    "github.com/tailscale/wireguard-go/conn"
    "github.com/tailscale/wireguard-go/ratelimiter"
    "github.com/tailscale/wireguard-go/rwcancel"
    "github.com/tailscale/wireguard-go/tun"
)

type Device struct {
    state struct {
        // state holds the device's state. It is accessed atomically.
        // Use the device.deviceState method to read it.
        // device.deviceState does not acquire the mutex, so it captures only a snapshot.
        // During state transitions, the state variable is updated before the device itself.
        // The state is thus either the current state of the device or
        // the intended future state of the device.
        // For example, while executing a call to Up, state will be deviceStateUp.
        // There is no guarantee that the intended future state of the device
        // will become the actual state; Up can fail.
        // The device can also change state multiple times between time of check and time of use.
        // Unsynchronized uses of state must therefore be advisory/best-effort only.
        state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
        // stopping blocks until all inputs to Device have been closed.
        stopping sync.WaitGroup
        // mu protects state changes.
        sync.Mutex
    }

    net struct {
        stopping sync.WaitGroup
        sync.RWMutex
        bind          conn.Bind // bind interface
        netlinkCancel *rwcancel.RWCancel
        port          uint16 // listening port
        fwmark        uint32 // mark value (0 = disabled)
        brokenRoaming bool
    }

    staticIdentity struct {
        sync.RWMutex
        privateKey NoisePrivateKey
        publicKey  NoisePublicKey
    }

    peers struct {
        sync.RWMutex // protects keyMap
        keyMap map[NoisePublicKey]*Peer
    }

    rate struct {
        underLoadUntil atomic.Int64
        limiter        ratelimiter.Ratelimiter
    }

    allowedips    AllowedIPs
    indexTable    IndexTable
    cookieChecker CookieChecker

    pool struct {
        inboundElementsContainer  *WaitPool
        outboundElementsContainer *WaitPool
        messageBuffers            *WaitPool
        inboundElements           *WaitPool
        outboundElements          *WaitPool
    }

    queue struct {
        encryption *outboundQueue
        decryption *inboundQueue
        handshake  *handshakeQueue
    }

    tun struct {
        device tun.Device
        mtu    atomic.Int32
    }

    ipcMutex sync.RWMutex
    closed   chan struct{}
    log      *Logger
}

// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
//	down -----+
//	  ↑↓      ↓
//	  up -> closed
type deviceState uint32

//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
    deviceStateDown deviceState = iota
    deviceStateUp
    deviceStateClosed
)

// deviceState returns device.state.state as a deviceState.
// See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState {
    return deviceState(device.state.state.Load())
}

// isClosed reports whether the device is closed (or is closing).
// See device.state.state comments for how to interpret this value.
func (device *Device) isClosed() bool {
    return device.deviceState() == deviceStateClosed
}

// isUp reports whether the device is up (or is attempting to come up).
// See device.state.state comments for how to interpret this value.
func (device *Device) isUp() bool {
    return device.deviceState() == deviceStateUp
}

// Must hold device.peers.Lock()
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
    // stop routing and processing of packets
    device.allowedips.RemoveByPeer(peer)
    peer.Stop()

    // remove from peer map
    delete(device.peers.keyMap, key)
}

// changeState attempts to change the device state to match want.
func (device *Device) changeState(want deviceState) (err error) {
    device.state.Lock()
    defer device.state.Unlock()
    old := device.deviceState()
    if old == deviceStateClosed {
        // once closed, always closed
        device.log.Verbosef("Interface closed, ignored requested state %s", want)
        return nil
    }
    switch want {
    case old:
        return nil
    case deviceStateUp:
        device.state.state.Store(uint32(deviceStateUp))
        err = device.upLocked()
        if err == nil {
            break
        }
        fallthrough // up failed; bring the device all the way back down
    case deviceStateDown:
        device.state.state.Store(uint32(deviceStateDown))
        errDown := device.downLocked()
        if err == nil {
            err = errDown
        }
    }
    device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
    return
}

// upLocked attempts to bring the device up and reports whether it succeeded.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) upLocked() error {
    if err := device.BindUpdate(); err != nil {
        device.log.Errorf("Unable to update bind: %v", err)
        return err
    }

    // The IPC set operation waits for peers to be created before calling Start() on them,
    // so if there's a concurrent IPC set request happening, we should wait for it to complete.
    device.ipcMutex.Lock()
    defer device.ipcMutex.Unlock()

    device.peers.RLock()
    for _, peer := range device.peers.keyMap {
        peer.Start()
        if peer.persistentKeepaliveInterval.Load() > 0 {
            peer.SendKeepalive()
        }
    }
    device.peers.RUnlock()
    return nil
}

// downLocked attempts to bring the device down.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) downLocked() error {
    err := device.BindClose()
    if err != nil {
        device.log.Errorf("Bind close failed: %v", err)
    }

    device.peers.RLock()
    for _, peer := range device.peers.keyMap {
        peer.Stop()
    }
    device.peers.RUnlock()
    return err
}

func (device *Device) Up() error {
    return device.changeState(deviceStateUp)
}

func (device *Device) Down() error {
    return device.changeState(deviceStateDown)
}

func (device *Device) IsUnderLoad() bool {
    // check if currently under load
    now := time.Now()
    underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
    if underLoad {
        device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
        return true
    }
    // check if recently under load
    return device.rate.underLoadUntil.Load() > now.UnixNano()
}

func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
    // lock required resources

    device.staticIdentity.Lock()
    defer device.staticIdentity.Unlock()

    if sk.Equals(device.staticIdentity.privateKey) {
        return nil
    }

    device.peers.Lock()
    defer device.peers.Unlock()

    lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
    for _, peer := range device.peers.keyMap {
        peer.handshake.mutex.RLock()
        lockedPeers = append(lockedPeers, peer)
    }

    // remove peers with matching public keys

    publicKey := sk.publicKey()
    for key, peer := range device.peers.keyMap {
        if peer.handshake.remoteStatic.Equals(publicKey) {
            peer.handshake.mutex.RUnlock()
            removePeerLocked(device, peer, key)
            peer.handshake.mutex.RLock()
        }
    }

    // update key material

    device.staticIdentity.privateKey = sk
    device.staticIdentity.publicKey = publicKey
    device.cookieChecker.Init(publicKey)

    // do static-static DH pre-computations

    expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
    for _, peer := range device.peers.keyMap {
        handshake := &peer.handshake
        handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
        expiredPeers = append(expiredPeers, peer)
    }

    for _, peer := range lockedPeers {
        peer.handshake.mutex.RUnlock()
    }
    for _, peer := range expiredPeers {
        peer.ExpireCurrentKeypairs()
    }

    return nil
}

func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
    device := new(Device)
    device.state.state.Store(uint32(deviceStateDown))
    device.closed = make(chan struct{})
    device.log = logger
    device.net.bind = bind
    device.tun.device = tunDevice
    mtu, err := device.tun.device.MTU()
    if err != nil {
        device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
        mtu = DefaultMTU
    }
    device.tun.mtu.Store(int32(mtu))
    device.peers.keyMap = make(map[NoisePublicKey]*Peer)
    device.rate.limiter.Init()
    device.indexTable.Init()

    device.PopulatePools()

    // create queues

    device.queue.handshake = newHandshakeQueue()
    device.queue.encryption = newOutboundQueue()
    device.queue.decryption = newInboundQueue()

    // start workers

    cpus := runtime.NumCPU()
    device.state.stopping.Wait()
    device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
    for i := 0; i < cpus; i++ {
        go device.RoutineEncryption(i + 1)
        go device.RoutineDecryption(i + 1)
        go device.RoutineHandshake(i + 1)
    }

    device.state.stopping.Add(1)      // RoutineReadFromTUN
    device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
    go device.RoutineReadFromTUN()
    go device.RoutineTUNEventReader()

    return device
}

// BatchSize returns the BatchSize for the device as a whole, which is the max of
// the bind batch size and the tun batch size. The batch size reported by device
// is the size used to construct memory pools, and is the allowed batch size for
// the lifetime of the device.
func (device *Device) BatchSize() int {
    size := device.net.bind.BatchSize()
    dSize := device.tun.device.BatchSize()
    if size < dSize {
        size = dSize
    }
    return size
}

func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
    device.peers.RLock()
    defer device.peers.RUnlock()

    return device.peers.keyMap[pk]
}

func (device *Device) RemovePeer(key NoisePublicKey) {
    device.peers.Lock()
    defer device.peers.Unlock()
    // stop peer and remove from routing

    peer, ok := device.peers.keyMap[key]
    if ok {
        removePeerLocked(device, peer, key)
    }
}

func (device *Device) RemoveAllPeers() {
    device.peers.Lock()
    defer device.peers.Unlock()

    for key, peer := range device.peers.keyMap {
        removePeerLocked(device, peer, key)
    }

    device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}

func (device *Device) Close() {
    device.ipcMutex.Lock()
    defer device.ipcMutex.Unlock()
    device.state.Lock()
    defer device.state.Unlock()
    if device.isClosed() {
        return
    }
    device.state.state.Store(uint32(deviceStateClosed))
    device.log.Verbosef("Device closing")

    device.tun.device.Close()
    device.downLocked()

    // Remove peers before closing queues,
    // because peers assume that queues are active.
    device.RemoveAllPeers()

    // We kept a reference to the encryption and decryption queues,
    // in case we started any new peers that might write to them.
    // No new peers are coming; we are done with these queues.
    device.queue.encryption.wg.Done()
    device.queue.decryption.wg.Done()
    device.queue.handshake.wg.Done()
    device.state.stopping.Wait()

    device.rate.limiter.Close()

    device.log.Verbosef("Device closed")
    close(device.closed)
}

func (device *Device) Wait() chan struct{} {
    return device.closed
}
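// A minimal lifecycle sketch, assuming tunDev, bind, and logger are a
// tun.Device, conn.Bind, and *Logger already in hand:
//
//	dev := NewDevice(tunDev, bind, logger)
//	if err := dev.Up(); err != nil {
//		logger.Errorf("device up failed: %v", err)
//	}
//	// ... traffic flows; peers are configured via the IPC interface ...
//	dev.Close()
//	<-dev.Wait() // the channel returned by Wait closes once shutdown completes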
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
    if !device.isUp() {
        return
    }

    device.peers.RLock()
    for _, peer := range device.peers.keyMap {
        peer.keypairs.RLock()
        sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
        peer.keypairs.RUnlock()
        if sendKeepalive {
            peer.SendKeepalive()
        }
    }
    device.peers.RUnlock()
}

// closeBindLocked closes the device's net.bind.
// The caller must hold the net mutex.
func closeBindLocked(device *Device) error {
    var err error
    netc := &device.net
    if netc.netlinkCancel != nil {
        netc.netlinkCancel.Cancel()
    }
    if netc.bind != nil {
        err = netc.bind.Close()
    }
    netc.stopping.Wait()
    return err
}

func (device *Device) Bind() conn.Bind {
    device.net.Lock()
    defer device.net.Unlock()
    return device.net.bind
}

func (device *Device) BindSetMark(mark uint32) error {
    device.net.Lock()
    defer device.net.Unlock()

    // check if modified
    if device.net.fwmark == mark {
        return nil
    }

    // update fwmark on existing bind
    device.net.fwmark = mark
    if device.isUp() && device.net.bind != nil {
        if err := device.net.bind.SetMark(mark); err != nil {
            return err
        }
    }

    // clear cached source addresses
    device.peers.RLock()
    for _, peer := range device.peers.keyMap {
        peer.markEndpointSrcForClearing()
    }
    device.peers.RUnlock()

    return nil
}

func (device *Device) BindUpdate() error {
    device.net.Lock()
    defer device.net.Unlock()

    // close existing sockets
    if err := closeBindLocked(device); err != nil {
        return err
    }

    // open new sockets
    if !device.isUp() {
        return nil
    }

    // bind to new port
    var err error
    var recvFns []conn.ReceiveFunc
    netc := &device.net

    recvFns, netc.port, err = netc.bind.Open(netc.port)
    if err != nil {
        netc.port = 0
        return err
    }

    netc.netlinkCancel, err = device.startRouteListener(netc.bind)
    if err != nil {
        netc.bind.Close()
        netc.port = 0
        return err
    }

    // set fwmark
    if netc.fwmark != 0 {
        err = netc.bind.SetMark(netc.fwmark)
        if err != nil {
            return err
        }
    }

    // clear cached source addresses
    device.peers.RLock()
    for _, peer := range device.peers.keyMap {
        peer.markEndpointSrcForClearing()
    }
    device.peers.RUnlock()

    // start receiving routines
    device.net.stopping.Add(len(recvFns))
    device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
    device.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
    batchSize := netc.bind.BatchSize()
    for _, fn := range recvFns {
        go device.RoutineReceiveIncoming(batchSize, fn)
    }

    device.log.Verbosef("UDP bind has been updated")
    return nil
}

func (device *Device) BindClose() error {
    device.net.Lock()
    err := closeBindLocked(device)
    device.net.Unlock()
    return err
}
16
vendor/github.com/tailscale/wireguard-go/device/devicestate_string.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.

package device

import "strconv"

const _deviceState_name = "DownUpClosed"

var _deviceState_index = [...]uint8{0, 4, 6, 12}

func (i deviceState) String() string {
    if i >= deviceState(len(_deviceState_index)-1) {
        return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
    }
    return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
}
98
vendor/github.com/tailscale/wireguard-go/device/indextable.go
generated
vendored
Normal file
@@ -0,0 +1,98 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
    "crypto/rand"
    "encoding/binary"
    "sync"
)

type IndexTableEntry struct {
    peer      *Peer
    handshake *Handshake
    keypair   *Keypair
}

type IndexTable struct {
    sync.RWMutex
    table map[uint32]IndexTableEntry
}

func randUint32() (uint32, error) {
    var integer [4]byte
    _, err := rand.Read(integer[:])
    // Arbitrary endianness; both are intrinsified by the Go compiler.
    return binary.LittleEndian.Uint32(integer[:]), err
}

func (table *IndexTable) Init() {
    table.Lock()
    defer table.Unlock()
    table.table = make(map[uint32]IndexTableEntry)
}

func (table *IndexTable) Delete(index uint32) {
    table.Lock()
    defer table.Unlock()
    delete(table.table, index)
}

func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
    table.Lock()
    defer table.Unlock()
    entry, ok := table.table[index]
    if !ok {
        return
    }
    table.table[index] = IndexTableEntry{
        peer:      entry.peer,
        keypair:   keypair,
        handshake: nil,
    }
}

func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) {
    for {
        // generate random index

        index, err := randUint32()
        if err != nil {
            return index, err
        }

        // check if index used

        table.RLock()
        _, ok := table.table[index]
        table.RUnlock()
        if ok {
            continue
        }

        // check again while locked

        table.Lock()
        _, found := table.table[index]
        if found {
            table.Unlock()
            continue
        }
        table.table[index] = IndexTableEntry{
            peer:      peer,
            handshake: handshake,
            keypair:   nil,
        }
        table.Unlock()
        return index, nil
    }
}
|
||||
func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
|
||||
table.RLock()
|
||||
defer table.RUnlock()
|
||||
return table.table[id]
|
||||
}
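
Illustrative sketch (not part of the vendored file) of the index table's lifecycle, assuming it sits in package device with "errors" imported:

	func exampleIndexTable(peer *Peer, hs *Handshake) error {
		var table IndexTable
		table.Init()

		// Allocate a collision-free random 32-bit index for the handshake.
		index, err := table.NewIndexForHandshake(peer, hs)
		if err != nil {
			return err
		}

		// Lookup returns the entry by value; a missing index yields zero fields.
		if entry := table.Lookup(index); entry.handshake != hs {
			return errors.New("unexpected entry")
		}

		table.Delete(index)
		return nil
	}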
22
vendor/github.com/tailscale/wireguard-go/device/ip.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"net"
)

const (
	IPv4offsetTotalLength = 2
	IPv4offsetSrc         = 12
	IPv4offsetDst         = IPv4offsetSrc + net.IPv4len
)

const (
	IPv6offsetPayloadLength = 4
	IPv6offsetSrc           = 8
	IPv6offsetDst           = IPv6offsetSrc + net.IPv6len
)
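
A short sketch (not part of the vendored file) of how these offsets get used; the receive path later in this change reads the IPv4 total-length field exactly this way. Assumes "encoding/binary" is imported.

	func ipv4TotalLength(pkt []byte) uint16 {
		// Total length lives at bytes 2..3 of the IPv4 header, big-endian.
		return binary.BigEndian.Uint16(pkt[IPv4offsetTotalLength : IPv4offsetTotalLength+2])
	}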
52
vendor/github.com/tailscale/wireguard-go/device/keypair.go
generated
vendored
Normal file
@@ -0,0 +1,52 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"crypto/cipher"
	"sync"
	"sync/atomic"
	"time"

	"github.com/tailscale/wireguard-go/replay"
)

/* Due to limitations in Go and /x/crypto there is currently
 * no way to ensure that key material is securely erased in memory.
 *
 * Since this may harm the forward secrecy property,
 * we plan to resolve this issue whenever Go allows us to do so.
 */

type Keypair struct {
	sendNonce    atomic.Uint64
	send         cipher.AEAD
	receive      cipher.AEAD
	replayFilter replay.Filter
	isInitiator  bool
	created      time.Time
	localIndex   uint32
	remoteIndex  uint32
}

type Keypairs struct {
	sync.RWMutex
	current  *Keypair
	previous *Keypair
	next     atomic.Pointer[Keypair]
}

func (kp *Keypairs) Current() *Keypair {
	kp.RLock()
	defer kp.RUnlock()
	return kp.current
}

func (device *Device) DeleteKeypair(key *Keypair) {
	if key != nil {
		device.indexTable.Delete(key.localIndex)
	}
}
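
A small sketch (not part of the vendored file) of consulting the current keypair, e.g. to decide whether rekeying is due; keypairExhausted is a hypothetical helper, and RejectAfterMessages comes from the device constants.

	func keypairExhausted(peer *Peer) bool {
		kp := peer.keypairs.Current()
		// A nil keypair means no session yet; an exhausted nonce means rekey.
		return kp == nil || kp.sendNonce.Load() >= RejectAfterMessages
	}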
48
vendor/github.com/tailscale/wireguard-go/device/logger.go
generated
vendored
Normal file
@@ -0,0 +1,48 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"log"
	"os"
)

// A Logger provides logging for a Device.
// The functions are Printf-style functions.
// They must be safe for concurrent use.
// They do not require a trailing newline in the format.
// If nil, that level of logging will be silent.
type Logger struct {
	Verbosef func(format string, args ...any)
	Errorf   func(format string, args ...any)
}

// Log levels for use with NewLogger.
const (
	LogLevelSilent = iota
	LogLevelError
	LogLevelVerbose
)

// Function for use in Logger for discarding logged lines.
func DiscardLogf(format string, args ...any) {}

// NewLogger constructs a Logger that writes to stdout.
// It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
	logger := &Logger{DiscardLogf, DiscardLogf}
	logf := func(prefix string) func(string, ...any) {
		return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
	}
	if level >= LogLevelVerbose {
		logger.Verbosef = logf("DEBUG")
	}
	if level >= LogLevelError {
		logger.Errorf = logf("ERROR")
	}
	return logger
}
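
A quick usage sketch (not part of the vendored file): a verbose logger prefixed with an interface name, so every line reads "DEBUG: (wg0) ..." or "ERROR: (wg0) ..." with date and time.

	logger := NewLogger(LogLevelVerbose, "(wg0) ")
	logger.Verbosef("starting up")
	logger.Errorf("bind failed on port %d", 51820)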
19
vendor/github.com/tailscale/wireguard-go/device/mobilequirks.go
generated
vendored
Normal file
@@ -0,0 +1,19 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
// though it will try to cope, possibly racily, if called after.
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
	device.net.brokenRoaming = true
	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.endpoint.Lock()
		peer.endpoint.disableRoaming = peer.endpoint.val != nil
		peer.endpoint.Unlock()
	}
	device.peers.RUnlock()
}
108
vendor/github.com/tailscale/wireguard-go/device/noise-helpers.go
generated
vendored
Normal file
@@ -0,0 +1,108 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"crypto/hmac"
	"crypto/rand"
	"crypto/subtle"
	"errors"
	"hash"

	"golang.org/x/crypto/blake2s"
	"golang.org/x/crypto/curve25519"
)

/* KDF related functions.
 * HMAC-based Key Derivation Function (HKDF)
 * https://tools.ietf.org/html/rfc5869
 */

func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) {
	mac := hmac.New(func() hash.Hash {
		h, _ := blake2s.New256(nil)
		return h
	}, key)
	mac.Write(in0)
	mac.Sum(sum[:0])
}

func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
	mac := hmac.New(func() hash.Hash {
		h, _ := blake2s.New256(nil)
		return h
	}, key)
	mac.Write(in0)
	mac.Write(in1)
	mac.Sum(sum[:0])
}

func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
	HMAC1(t0, key, input)
	HMAC1(t0, t0[:], []byte{0x1})
}

func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
	var prk [blake2s.Size]byte
	HMAC1(&prk, key, input)
	HMAC1(t0, prk[:], []byte{0x1})
	HMAC2(t1, prk[:], t0[:], []byte{0x2})
	setZero(prk[:])
}

func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
	var prk [blake2s.Size]byte
	HMAC1(&prk, key, input)
	HMAC1(t0, prk[:], []byte{0x1})
	HMAC2(t1, prk[:], t0[:], []byte{0x2})
	HMAC2(t2, prk[:], t1[:], []byte{0x3})
	setZero(prk[:])
}

func isZero(val []byte) bool {
	acc := 1
	for _, b := range val {
		acc &= subtle.ConstantTimeByteEq(b, 0)
	}
	return acc == 1
}

/* This function is not used as pervasively as it should be, because reliably zeroing memory is mostly impossible in Go at the moment. */
func setZero(arr []byte) {
	for i := range arr {
		arr[i] = 0
	}
}

func (sk *NoisePrivateKey) clamp() {
	sk[0] &= 248
	sk[31] = (sk[31] & 127) | 64
}

func newPrivateKey() (sk NoisePrivateKey, err error) {
	_, err = rand.Read(sk[:])
	sk.clamp()
	return
}

func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
	apk := (*[NoisePublicKeySize]byte)(&pk)
	ask := (*[NoisePrivateKeySize]byte)(sk)
	curve25519.ScalarBaseMult(apk, ask)
	return
}

var errInvalidPublicKey = errors.New("invalid public key")

func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
	apk := (*[NoisePublicKeySize]byte)(&pk)
	ask := (*[NoisePrivateKeySize]byte)(sk)
	curve25519.ScalarMult(&ss, ask, apk)
	if isZero(ss[:]) {
		return ss, errInvalidPublicKey
	}
	return ss, nil
}
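
A minimal sketch (not part of the vendored file) of the HKDF extract-and-expand pattern the handshake code builds on: the chaining key is updated in place while a fresh symmetric key is produced from a DH result.

	func deriveKeys(chainKey *[blake2s.Size]byte, ss [NoisePublicKeySize]byte) (key [blake2s.Size]byte) {
		// One KDF2 step: t0 feeds back into the chaining key, t1 is the new key.
		KDF2(chainKey, &key, chainKey[:], ss[:])
		return key
	}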
625
vendor/github.com/tailscale/wireguard-go/device/noise-protocol.go
generated
vendored
Normal file
@@ -0,0 +1,625 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"errors"
	"fmt"
	"sync"
	"time"

	"golang.org/x/crypto/blake2s"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/poly1305"

	"github.com/tailscale/wireguard-go/tai64n"
)

type handshakeState int

const (
	handshakeZeroed = handshakeState(iota)
	handshakeInitiationCreated
	handshakeInitiationConsumed
	handshakeResponseCreated
	handshakeResponseConsumed
)

func (hs handshakeState) String() string {
	switch hs {
	case handshakeZeroed:
		return "handshakeZeroed"
	case handshakeInitiationCreated:
		return "handshakeInitiationCreated"
	case handshakeInitiationConsumed:
		return "handshakeInitiationConsumed"
	case handshakeResponseCreated:
		return "handshakeResponseCreated"
	case handshakeResponseConsumed:
		return "handshakeResponseConsumed"
	default:
		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
	}
}

const (
	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com"
	WGLabelMAC1       = "mac1----"
	WGLabelCookie     = "cookie--"
)

const (
	MessageInitiationType  = 1
	MessageResponseType    = 2
	MessageCookieReplyType = 3
	MessageTransportType   = 4
)

const (
	MessageInitiationSize      = 148                                           // size of handshake initiation message
	MessageResponseSize        = 92                                            // size of response message
	MessageCookieReplySize     = 64                                            // size of cookie reply message
	MessageTransportHeaderSize = 16                                            // size of data preceding content in transport message
	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
	MessageKeepaliveSize       = MessageTransportSize                          // size of keepalive
	MessageHandshakeSize       = MessageInitiationSize                         // size of largest handshake related message
)

const (
	MessageTransportOffsetReceiver = 4
	MessageTransportOffsetCounter  = 8
	MessageTransportOffsetContent  = 16
)

/* Type is an 8-bit field, followed by 3 nul bytes;
 * by marshalling the messages in little-endian byte order
 * we can treat these as a 32-bit unsigned int (for now).
 */

type MessageInitiation struct {
	Type      uint32
	Sender    uint32
	Ephemeral NoisePublicKey
	Static    [NoisePublicKeySize + poly1305.TagSize]byte
	Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}

type MessageResponse struct {
	Type      uint32
	Sender    uint32
	Receiver  uint32
	Ephemeral NoisePublicKey
	Empty     [poly1305.TagSize]byte
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}

type MessageTransport struct {
	Type     uint32
	Receiver uint32
	Counter  uint64
	Content  []byte
}

type MessageCookieReply struct {
	Type     uint32
	Receiver uint32
	Nonce    [chacha20poly1305.NonceSizeX]byte
	Cookie   [blake2s.Size128 + poly1305.TagSize]byte
}

type Handshake struct {
	state                     handshakeState
	mutex                     sync.RWMutex
	hash                      [blake2s.Size]byte       // hash value
	chainKey                  [blake2s.Size]byte       // chain key
	presharedKey              NoisePresharedKey        // psk
	localEphemeral            NoisePrivateKey          // ephemeral secret key
	localIndex                uint32                   // used to clear hash-table
	remoteIndex               uint32                   // index for sending
	remoteStatic              NoisePublicKey           // long term key, never changes, can be accessed without mutex
	remoteEphemeral           NoisePublicKey           // ephemeral public key
	precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret
	lastTimestamp             tai64n.Timestamp
	lastInitiationConsumption time.Time
	lastSentHandshake         time.Time
}

var (
	InitialChainKey [blake2s.Size]byte
	InitialHash     [blake2s.Size]byte
	ZeroNonce       [chacha20poly1305.NonceSize]byte
)

func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
	KDF1(dst, c[:], data)
}

func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
	hash, _ := blake2s.New256(nil)
	hash.Write(h[:])
	hash.Write(data)
	hash.Sum(dst[:0])
	hash.Reset()
}

func (h *Handshake) Clear() {
	setZero(h.localEphemeral[:])
	setZero(h.remoteEphemeral[:])
	setZero(h.chainKey[:])
	setZero(h.hash[:])
	h.localIndex = 0
	h.state = handshakeZeroed
}

func (h *Handshake) mixHash(data []byte) {
	mixHash(&h.hash, &h.hash, data)
}

func (h *Handshake) mixKey(data []byte) {
	mixKey(&h.chainKey, &h.chainKey, data)
}

/* Do basic precomputations
 */
func init() {
	InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
	mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
}

func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	// create ephemeral key
	var err error
	handshake.hash = InitialHash
	handshake.chainKey = InitialChainKey
	handshake.localEphemeral, err = newPrivateKey()
	if err != nil {
		return nil, err
	}

	handshake.mixHash(handshake.remoteStatic[:])

	msg := MessageInitiation{
		Type:      MessageInitiationType,
		Ephemeral: handshake.localEphemeral.publicKey(),
	}

	handshake.mixKey(msg.Ephemeral[:])
	handshake.mixHash(msg.Ephemeral[:])

	// encrypt static key
	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
	if err != nil {
		return nil, err
	}
	var key [chacha20poly1305.KeySize]byte
	KDF2(
		&handshake.chainKey,
		&key,
		handshake.chainKey[:],
		ss[:],
	)
	aead, _ := chacha20poly1305.New(key[:])
	aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
	handshake.mixHash(msg.Static[:])

	// encrypt timestamp
	if isZero(handshake.precomputedStaticStatic[:]) {
		return nil, errInvalidPublicKey
	}
	KDF2(
		&handshake.chainKey,
		&key,
		handshake.chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	timestamp := tai64n.Now()
	aead, _ = chacha20poly1305.New(key[:])
	aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])

	// assign index
	device.indexTable.Delete(handshake.localIndex)
	msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
	if err != nil {
		return nil, err
	}
	handshake.localIndex = msg.Sender

	handshake.mixHash(msg.Timestamp[:])
	handshake.state = handshakeInitiationCreated
	return &msg, nil
}

func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
	var (
		hash     [blake2s.Size]byte
		chainKey [blake2s.Size]byte
	)

	if msg.Type != MessageInitiationType {
		return nil
	}

	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
	mixHash(&hash, &hash, msg.Ephemeral[:])
	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])

	// decrypt static key
	var peerPK NoisePublicKey
	var key [chacha20poly1305.KeySize]byte
	ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
	if err != nil {
		return nil
	}
	KDF2(&chainKey, &key, chainKey[:], ss[:])
	aead, _ := chacha20poly1305.New(key[:])
	_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
	if err != nil {
		return nil
	}
	mixHash(&hash, &hash, msg.Static[:])

	// lookup peer
	peer := device.LookupPeer(peerPK)
	if peer == nil || !peer.isRunning.Load() {
		return nil
	}

	handshake := &peer.handshake

	// verify identity
	var timestamp tai64n.Timestamp

	handshake.mutex.RLock()

	if isZero(handshake.precomputedStaticStatic[:]) {
		handshake.mutex.RUnlock()
		return nil
	}
	KDF2(
		&chainKey,
		&key,
		chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	aead, _ = chacha20poly1305.New(key[:])
	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
	if err != nil {
		handshake.mutex.RUnlock()
		return nil
	}
	mixHash(&hash, &hash, msg.Timestamp[:])

	// protect against replay & flood
	replay := !timestamp.After(handshake.lastTimestamp)
	flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
	handshake.mutex.RUnlock()
	if replay {
		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
		return nil
	}
	if flood {
		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
		return nil
	}

	// update handshake state
	handshake.mutex.Lock()

	handshake.hash = hash
	handshake.chainKey = chainKey
	handshake.remoteIndex = msg.Sender
	handshake.remoteEphemeral = msg.Ephemeral
	if timestamp.After(handshake.lastTimestamp) {
		handshake.lastTimestamp = timestamp
	}
	now := time.Now()
	if now.After(handshake.lastInitiationConsumption) {
		handshake.lastInitiationConsumption = now
	}
	handshake.state = handshakeInitiationConsumed

	handshake.mutex.Unlock()

	setZero(hash[:])
	setZero(chainKey[:])

	return peer
}

func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	if handshake.state != handshakeInitiationConsumed {
		return nil, errors.New("handshake initiation must be consumed first")
	}

	// assign index
	var err error
	device.indexTable.Delete(handshake.localIndex)
	handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
	if err != nil {
		return nil, err
	}

	var msg MessageResponse
	msg.Type = MessageResponseType
	msg.Sender = handshake.localIndex
	msg.Receiver = handshake.remoteIndex

	// create ephemeral key
	handshake.localEphemeral, err = newPrivateKey()
	if err != nil {
		return nil, err
	}
	msg.Ephemeral = handshake.localEphemeral.publicKey()
	handshake.mixHash(msg.Ephemeral[:])
	handshake.mixKey(msg.Ephemeral[:])

	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
	if err != nil {
		return nil, err
	}
	handshake.mixKey(ss[:])
	ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
	if err != nil {
		return nil, err
	}
	handshake.mixKey(ss[:])

	// add preshared key
	var tau [blake2s.Size]byte
	var key [chacha20poly1305.KeySize]byte

	KDF3(
		&handshake.chainKey,
		&tau,
		&key,
		handshake.chainKey[:],
		handshake.presharedKey[:],
	)

	handshake.mixHash(tau[:])

	aead, _ := chacha20poly1305.New(key[:])
	aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
	handshake.mixHash(msg.Empty[:])

	handshake.state = handshakeResponseCreated

	return &msg, nil
}

func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
	if msg.Type != MessageResponseType {
		return nil
	}

	// lookup handshake by receiver
	lookup := device.indexTable.Lookup(msg.Receiver)
	handshake := lookup.handshake
	if handshake == nil {
		return nil
	}

	var (
		hash     [blake2s.Size]byte
		chainKey [blake2s.Size]byte
	)

	ok := func() bool {
		// lock handshake state
		handshake.mutex.RLock()
		defer handshake.mutex.RUnlock()

		if handshake.state != handshakeInitiationCreated {
			return false
		}

		// lock private key for reading
		device.staticIdentity.RLock()
		defer device.staticIdentity.RUnlock()

		// finish 3-way DH
		mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
		mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])

		ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
		if err != nil {
			return false
		}
		mixKey(&chainKey, &chainKey, ss[:])
		setZero(ss[:])

		ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
		if err != nil {
			return false
		}
		mixKey(&chainKey, &chainKey, ss[:])
		setZero(ss[:])

		// add preshared key (psk)
		var tau [blake2s.Size]byte
		var key [chacha20poly1305.KeySize]byte
		KDF3(
			&chainKey,
			&tau,
			&key,
			chainKey[:],
			handshake.presharedKey[:],
		)
		mixHash(&hash, &hash, tau[:])

		// authenticate transcript
		aead, _ := chacha20poly1305.New(key[:])
		_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
		if err != nil {
			return false
		}
		mixHash(&hash, &hash, msg.Empty[:])
		return true
	}()

	if !ok {
		return nil
	}

	// update handshake state
	handshake.mutex.Lock()

	handshake.hash = hash
	handshake.chainKey = chainKey
	handshake.remoteIndex = msg.Sender
	handshake.state = handshakeResponseConsumed

	handshake.mutex.Unlock()

	setZero(hash[:])
	setZero(chainKey[:])

	return lookup.peer
}

/* Derives a new keypair from the current handshake state.
 */
func (peer *Peer) BeginSymmetricSession() error {
	device := peer.device
	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	// derive keys
	var isInitiator bool
	var sendKey [chacha20poly1305.KeySize]byte
	var recvKey [chacha20poly1305.KeySize]byte

	if handshake.state == handshakeResponseConsumed {
		KDF2(
			&sendKey,
			&recvKey,
			handshake.chainKey[:],
			nil,
		)
		isInitiator = true
	} else if handshake.state == handshakeResponseCreated {
		KDF2(
			&recvKey,
			&sendKey,
			handshake.chainKey[:],
			nil,
		)
		isInitiator = false
	} else {
		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
	}

	// zero handshake
	setZero(handshake.chainKey[:])
	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
	setZero(handshake.localEphemeral[:])
	peer.handshake.state = handshakeZeroed

	// create AEAD instances
	keypair := new(Keypair)
	keypair.send, _ = chacha20poly1305.New(sendKey[:])
	keypair.receive, _ = chacha20poly1305.New(recvKey[:])

	setZero(sendKey[:])
	setZero(recvKey[:])

	keypair.created = time.Now()
	keypair.replayFilter.Reset()
	keypair.isInitiator = isInitiator
	keypair.localIndex = peer.handshake.localIndex
	keypair.remoteIndex = peer.handshake.remoteIndex

	// remap index
	device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
	handshake.localIndex = 0

	// rotate key pairs
	keypairs := &peer.keypairs
	keypairs.Lock()
	defer keypairs.Unlock()

	previous := keypairs.previous
	next := keypairs.next.Load()
	current := keypairs.current

	if isInitiator {
		if next != nil {
			keypairs.next.Store(nil)
			keypairs.previous = next
			device.DeleteKeypair(current)
		} else {
			keypairs.previous = current
		}
		device.DeleteKeypair(previous)
		keypairs.current = keypair
	} else {
		keypairs.next.Store(keypair)
		device.DeleteKeypair(next)
		keypairs.previous = nil
		device.DeleteKeypair(previous)
	}

	return nil
}

func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
	keypairs := &peer.keypairs

	if keypairs.next.Load() != receivedKeypair {
		return false
	}
	keypairs.Lock()
	defer keypairs.Unlock()
	if keypairs.next.Load() != receivedKeypair {
		return false
	}
	old := keypairs.previous
	keypairs.previous = keypairs.current
	peer.device.DeleteKeypair(old)
	keypairs.current = keypairs.next.Load()
	keypairs.next.Store(nil)
	return true
}
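
An illustrative sketch (not part of the vendored file) of the overall message flow these functions implement. It is a simplification: it passes message structs directly instead of marshalling them onto the wire, and real code runs each side on its own device. exampleHandshake is a hypothetical driver; "errors" is assumed imported.

	func exampleHandshake(initiator, responder *Device, peerOnInitiator *Peer) error {
		msg1, err := initiator.CreateMessageInitiation(peerOnInitiator)
		if err != nil {
			return err
		}
		peerOnResponder := responder.ConsumeMessageInitiation(msg1)
		if peerOnResponder == nil {
			return errors.New("initiation rejected") // replay, flood, or unknown peer
		}
		msg2, err := responder.CreateMessageResponse(peerOnResponder)
		if err != nil {
			return err
		}
		if initiator.ConsumeMessageResponse(msg2) == nil {
			return errors.New("response rejected")
		}
		// Each side then derives transport keys from its own handshake state.
		return peerOnInitiator.BeginSymmetricSession()
	}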
78
vendor/github.com/tailscale/wireguard-go/device/noise-types.go
generated
vendored
Normal file
@@ -0,0 +1,78 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"crypto/subtle"
	"encoding/hex"
	"errors"
)

const (
	NoisePublicKeySize    = 32
	NoisePrivateKeySize   = 32
	NoisePresharedKeySize = 32
)

type (
	NoisePublicKey    [NoisePublicKeySize]byte
	NoisePrivateKey   [NoisePrivateKeySize]byte
	NoisePresharedKey [NoisePresharedKeySize]byte
	NoiseNonce        uint64 // padded to 12-bytes
)

func loadExactHex(dst []byte, src string) error {
	slice, err := hex.DecodeString(src)
	if err != nil {
		return err
	}
	if len(slice) != len(dst) {
		return errors.New("hex string does not fit the slice")
	}
	copy(dst, slice)
	return nil
}

func (key NoisePrivateKey) IsZero() bool {
	var zero NoisePrivateKey
	return key.Equals(zero)
}

func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
	return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}

func (key *NoisePrivateKey) FromHex(src string) (err error) {
	err = loadExactHex(key[:], src)
	key.clamp()
	return
}

func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
	err = loadExactHex(key[:], src)
	if key.IsZero() {
		return
	}
	key.clamp()
	return
}

func (key *NoisePublicKey) FromHex(src string) error {
	return loadExactHex(key[:], src)
}

func (key NoisePublicKey) IsZero() bool {
	var zero NoisePublicKey
	return key.Equals(zero)
}

func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
	return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}

func (key *NoisePresharedKey) FromHex(src string) error {
	return loadExactHex(key[:], src)
}
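
A brief sketch (not part of the vendored file) of loading keys from hex, as a configuration parser would. exampleLoadKeys is hypothetical, "strings" is assumed imported, and the repeated "ab" string is a stand-in for a real 64-hex-digit (32-byte) private key.

	func exampleLoadKeys() error {
		var priv NoisePrivateKey
		if err := priv.FromHex(strings.Repeat("ab", 32)); err != nil {
			return err // wrong length or invalid hex
		}
		pub := priv.publicKey() // matching curve25519 public key, post-clamp
		_ = pub
		return nil
	}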
299
vendor/github.com/tailscale/wireguard-go/device/peer.go
generated
vendored
Normal file
@@ -0,0 +1,299 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"container/list"
	"errors"
	"sync"
	"sync/atomic"
	"time"

	"github.com/tailscale/wireguard-go/conn"
)

type Peer struct {
	isRunning         atomic.Bool
	keypairs          Keypairs
	handshake         Handshake
	device            *Device
	stopping          sync.WaitGroup // routines pending stop
	txBytes           atomic.Uint64  // bytes sent to peer (endpoint)
	rxBytes           atomic.Uint64  // bytes received from peer
	lastHandshakeNano atomic.Int64   // nanoseconds since epoch

	endpoint struct {
		sync.Mutex
		val            conn.Endpoint
		clearSrcOnTx   bool // signal to val.ClearSrc() prior to next packet transmission
		disableRoaming bool
	}

	timers struct {
		retransmitHandshake     *Timer
		sendKeepalive           *Timer
		newHandshake            *Timer
		zeroKeyMaterial         *Timer
		persistentKeepalive     *Timer
		handshakeAttempts       atomic.Uint32
		needAnotherKeepalive    atomic.Bool
		sentLastMinuteHandshake atomic.Bool
	}

	state struct {
		sync.Mutex // protects against concurrent Start/Stop
	}

	queue struct {
		staged   chan *QueueOutboundElementsContainer // staged packets before a handshake is available
		outbound *autodrainingOutboundQueue           // sequential ordering of udp transmission
		inbound  *autodrainingInboundQueue            // sequential ordering of tun writing
	}

	cookieGenerator             CookieGenerator
	trieEntries                 list.List
	persistentKeepaliveInterval atomic.Uint32
}

func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
	if device.isClosed() {
		return nil, errors.New("device closed")
	}

	// lock resources
	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	device.peers.Lock()
	defer device.peers.Unlock()

	// check if over limit
	if len(device.peers.keyMap) >= MaxPeers {
		return nil, errors.New("too many peers")
	}

	// create peer
	peer := new(Peer)

	peer.cookieGenerator.Init(pk)
	peer.device = device
	peer.queue.outbound = newAutodrainingOutboundQueue(device)
	peer.queue.inbound = newAutodrainingInboundQueue(device)
	peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)

	// map public key
	_, ok := device.peers.keyMap[pk]
	if ok {
		return nil, errors.New("adding existing peer")
	}

	// pre-compute DH
	handshake := &peer.handshake
	handshake.mutex.Lock()
	handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
	handshake.remoteStatic = pk
	handshake.mutex.Unlock()

	// reset endpoint
	peer.endpoint.Lock()
	peer.endpoint.val = nil
	peer.endpoint.disableRoaming = false
	peer.endpoint.clearSrcOnTx = false
	peer.endpoint.Unlock()

	// init timers
	peer.timersInit()

	// add
	device.peers.keyMap[pk] = peer

	return peer, nil
}

func (peer *Peer) SendBuffers(buffers [][]byte) error {
	peer.device.net.RLock()
	defer peer.device.net.RUnlock()

	if peer.device.isClosed() {
		return nil
	}

	peer.endpoint.Lock()
	endpoint := peer.endpoint.val
	if endpoint == nil {
		peer.endpoint.Unlock()
		return errors.New("no known endpoint for peer")
	}
	if peer.endpoint.clearSrcOnTx {
		endpoint.ClearSrc()
		peer.endpoint.clearSrcOnTx = false
	}
	peer.endpoint.Unlock()

	err := peer.device.net.bind.Send(buffers, endpoint)
	if err == nil {
		var totalLen uint64
		for _, b := range buffers {
			totalLen += uint64(len(b))
		}
		peer.txBytes.Add(totalLen)
	}
	return err
}

func (peer *Peer) String() string {
	// The awful goo that follows is identical to:
	//
	//   base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
	//   abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
	//   return fmt.Sprintf("peer(%s)", abbreviatedKey)
	//
	// except that it is considerably more efficient.
	src := peer.handshake.remoteStatic
	b64 := func(input byte) byte {
		return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
	}
	b := []byte("peer(____…____)")
	const first = len("peer(")
	const second = len("peer(____…")
	b[first+0] = b64((src[0] >> 2) & 63)
	b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
	b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
	b[first+3] = b64(src[2] & 63)
	b[second+0] = b64(src[29] & 63)
	b[second+1] = b64((src[30] >> 2) & 63)
	b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
	b[second+3] = b64((src[31] << 2) & 63)
	return string(b)
}

func (peer *Peer) Start() {
	// should never start a peer on a closed device
	if peer.device.isClosed() {
		return
	}

	// prevent simultaneous start/stop operations
	peer.state.Lock()
	defer peer.state.Unlock()

	if peer.isRunning.Load() {
		return
	}

	device := peer.device
	device.log.Verbosef("%v - Starting", peer)

	// reset routine state
	peer.stopping.Wait()
	peer.stopping.Add(2)

	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	peer.handshake.mutex.Unlock()

	peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes

	peer.timersStart()

	device.flushInboundQueue(peer.queue.inbound)
	device.flushOutboundQueue(peer.queue.outbound)

	// Use the device batch size, not the bind batch size, as the device size is
	// the size of the batch pools.
	batchSize := peer.device.BatchSize()
	go peer.RoutineSequentialSender(batchSize)
	go peer.RoutineSequentialReceiver(batchSize)

	peer.isRunning.Store(true)
}

func (peer *Peer) ZeroAndFlushAll() {
	device := peer.device

	// clear key pairs
	keypairs := &peer.keypairs
	keypairs.Lock()
	device.DeleteKeypair(keypairs.previous)
	device.DeleteKeypair(keypairs.current)
	device.DeleteKeypair(keypairs.next.Load())
	keypairs.previous = nil
	keypairs.current = nil
	keypairs.next.Store(nil)
	keypairs.Unlock()

	// clear handshake state
	handshake := &peer.handshake
	handshake.mutex.Lock()
	device.indexTable.Delete(handshake.localIndex)
	handshake.Clear()
	handshake.mutex.Unlock()

	peer.FlushStagedPackets()
}

func (peer *Peer) ExpireCurrentKeypairs() {
	handshake := &peer.handshake
	handshake.mutex.Lock()
	peer.device.indexTable.Delete(handshake.localIndex)
	handshake.Clear()
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	handshake.mutex.Unlock()

	keypairs := &peer.keypairs
	keypairs.Lock()
	if keypairs.current != nil {
		keypairs.current.sendNonce.Store(RejectAfterMessages)
	}
	if next := keypairs.next.Load(); next != nil {
		next.sendNonce.Store(RejectAfterMessages)
	}
	keypairs.Unlock()
}

func (peer *Peer) Stop() {
	peer.state.Lock()
	defer peer.state.Unlock()

	if !peer.isRunning.Swap(false) {
		return
	}

	peer.device.log.Verbosef("%v - Stopping", peer)

	peer.timersStop()
	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
	peer.queue.inbound.c <- nil
	peer.queue.outbound.c <- nil
	peer.stopping.Wait()
	peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us

	peer.ZeroAndFlushAll()
}

func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
	peer.endpoint.Lock()
	defer peer.endpoint.Unlock()
	if peer.endpoint.disableRoaming {
		return
	}
	peer.endpoint.clearSrcOnTx = false
	if ep, ok := endpoint.(conn.PeerAwareEndpoint); ok {
		endpoint = ep.GetPeerEndpoint(peer.handshake.remoteStatic)
	}
	peer.endpoint.val = endpoint
}

func (peer *Peer) markEndpointSrcForClearing() {
	peer.endpoint.Lock()
	defer peer.endpoint.Unlock()
	if peer.endpoint.val == nil {
		return
	}
	peer.endpoint.clearSrcOnTx = true
}
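
An illustrative sketch (not part of the vendored file) of the peer lifecycle these methods define. examplePeerLifecycle is a hypothetical driver assuming an initialized *Device.

	func examplePeerLifecycle(dev *Device, pk NoisePublicKey) error {
		// NewPeer fails on duplicates, MaxPeers, or a closed device;
		// Start/Stop are serialized by peer.state and safe to call repeatedly.
		peer, err := dev.NewPeer(pk)
		if err != nil {
			return err
		}
		peer.Start()
		defer peer.Stop()
		return nil
	}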
121
vendor/github.com/tailscale/wireguard-go/device/pools.go
generated
vendored
Normal file
@@ -0,0 +1,121 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"sync"
)

type WaitPool struct {
	pool  sync.Pool
	cond  sync.Cond
	lock  sync.Mutex
	count uint32 // Get calls not yet Put back
	max   uint32
}

func NewWaitPool(max uint32, new func() any) *WaitPool {
	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
	p.cond = sync.Cond{L: &p.lock}
	return p
}

func (p *WaitPool) Get() any {
	if p.max != 0 {
		p.lock.Lock()
		for p.count >= p.max {
			p.cond.Wait()
		}
		p.count++
		p.lock.Unlock()
	}
	return p.pool.Get()
}

func (p *WaitPool) Put(x any) {
	p.pool.Put(x)
	if p.max == 0 {
		return
	}
	p.lock.Lock()
	defer p.lock.Unlock()
	p.count--
	p.cond.Signal()
}

func (device *Device) PopulatePools() {
	device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		s := make([]*QueueInboundElement, 0, device.BatchSize())
		return &QueueInboundElementsContainer{elems: s}
	})
	device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		s := make([]*QueueOutboundElement, 0, device.BatchSize())
		return &QueueOutboundElementsContainer{elems: s}
	})
	device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		return new([MaxMessageSize]byte)
	})
	device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		return new(QueueInboundElement)
	})
	device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
		return new(QueueOutboundElement)
	})
}

func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
	c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
	c.Mutex = sync.Mutex{}
	return c
}

func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
	for i := range c.elems {
		c.elems[i] = nil
	}
	c.elems = c.elems[:0]
	device.pool.inboundElementsContainer.Put(c)
}

func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
	c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
	c.Mutex = sync.Mutex{}
	return c
}

func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
	for i := range c.elems {
		c.elems[i] = nil
	}
	c.elems = c.elems[:0]
	device.pool.outboundElementsContainer.Put(c)
}

func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
	return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}

func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
	device.pool.messageBuffers.Put(msg)
}

func (device *Device) GetInboundElement() *QueueInboundElement {
	return device.pool.inboundElements.Get().(*QueueInboundElement)
}

func (device *Device) PutInboundElement(elem *QueueInboundElement) {
	elem.clearPointers()
	device.pool.inboundElements.Put(elem)
}

func (device *Device) GetOutboundElement() *QueueOutboundElement {
	return device.pool.outboundElements.Get().(*QueueOutboundElement)
}

func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
	elem.clearPointers()
	device.pool.outboundElements.Put(elem)
}
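
An illustrative sketch (not part of the vendored file) of the WaitPool's backpressure: with a nonzero max, Get blocks once that many items are outstanding, until a Put signals the condition variable. exampleWaitPool is hypothetical.

	func exampleWaitPool() {
		// With max = 1, the second Get blocks until the first buffer is returned.
		pool := NewWaitPool(1, func() any { return new([MaxMessageSize]byte) })
		buf := pool.Get().(*[MaxMessageSize]byte)
		go func() { pool.Put(buf) }() // releasing buf unblocks the waiting Get
		buf2 := pool.Get().(*[MaxMessageSize]byte)
		pool.Put(buf2)
	}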
19
vendor/github.com/tailscale/wireguard-go/device/queueconstants_android.go
generated
vendored
Normal file
@@ -0,0 +1,19 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import "github.com/tailscale/wireguard-go/conn"

/* Reduce memory consumption for Android */

const (
	QueueStagedSize            = conn.IdealBatchSize
	QueueOutboundSize          = 1024
	QueueInboundSize           = 1024
	QueueHandshakeSize         = 1024
	MaxSegmentSize             = 2200
	PreallocatedBuffersPerPool = 4096
)
19
vendor/github.com/tailscale/wireguard-go/device/queueconstants_default.go
generated
vendored
Normal file
@@ -0,0 +1,19 @@
//go:build !android && !ios && !windows

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import "github.com/tailscale/wireguard-go/conn"

const (
	QueueStagedSize            = conn.IdealBatchSize
	QueueOutboundSize          = 1024
	QueueInboundSize           = 1024
	QueueHandshakeSize         = 1024
	MaxSegmentSize             = (1 << 16) - 1 // largest possible UDP datagram
	PreallocatedBuffersPerPool = 0             // Disable and allow for infinite memory growth
)
21
vendor/github.com/tailscale/wireguard-go/device/queueconstants_ios.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
//go:build ios

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
// These are vars instead of consts, because heavier network extensions might want to reduce
// them further.
var (
	QueueStagedSize                   = 128
	QueueOutboundSize                 = 1024
	QueueInboundSize                  = 1024
	QueueHandshakeSize                = 1024
	PreallocatedBuffersPerPool uint32 = 1024
)

const MaxSegmentSize = 1700
15
vendor/github.com/tailscale/wireguard-go/device/queueconstants_windows.go
generated
vendored
Normal file
@@ -0,0 +1,15 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

const (
	QueueStagedSize            = 128
	QueueOutboundSize          = 1024
	QueueInboundSize           = 1024
	QueueHandshakeSize         = 1024
	MaxSegmentSize             = 2048 - 32 // largest possible UDP datagram
	PreallocatedBuffersPerPool = 0         // Disable and allow for infinite memory growth
)
|
||||
537
vendor/github.com/tailscale/wireguard-go/device/receive.go
generated
vendored
Normal file
537
vendor/github.com/tailscale/wireguard-go/device/receive.go
generated
vendored
Normal file
@@ -0,0 +1,537 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
type QueueHandshakeElement struct {
|
||||
msgType uint32
|
||||
packet []byte
|
||||
endpoint conn.Endpoint
|
||||
buffer *[MaxMessageSize]byte
|
||||
}
|
||||
|
||||
type QueueInboundElement struct {
|
||||
buffer *[MaxMessageSize]byte
|
||||
packet []byte
|
||||
counter uint64
|
||||
keypair *Keypair
|
||||
endpoint conn.Endpoint
|
||||
}
|
||||
|
||||
type QueueInboundElementsContainer struct {
|
||||
sync.Mutex
|
||||
elems []*QueueInboundElement
|
||||
}
|
||||
|
||||
// clearPointers clears elem fields that contain pointers.
|
||||
// This makes the garbage collector's life easier and
|
||||
// avoids accidentally keeping other objects around unnecessarily.
|
||||
// It also reduces the possible collateral damage from use-after-free bugs.
|
||||
func (elem *QueueInboundElement) clearPointers() {
|
||||
elem.buffer = nil
|
||||
elem.packet = nil
|
||||
elem.keypair = nil
|
||||
elem.endpoint = nil
|
||||
}
|
||||
|
||||
/* Called when a new authenticated message has been received
|
||||
*
|
||||
* NOTE: Not thread safe, but called by sequential receiver!
|
||||
*/
|
||||
func (peer *Peer) keepKeyFreshReceiving() {
|
||||
if peer.timers.sentLastMinuteHandshake.Load() {
|
||||
return
|
||||
}
|
||||
keypair := peer.keypairs.Current()
|
||||
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
||||
peer.timers.sentLastMinuteHandshake.Store(true)
|
||||
peer.SendHandshakeInitiation(false)
|
||||
}
|
||||
}
|
||||
|
||||
/* Receives incoming datagrams for the device
|
||||
*
|
||||
* Every time the bind is updated a new routine is started for
|
||||
* IPv4 and IPv6 (separately)
|
||||
*/
|
||||
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
|
||||
recvName := recv.PrettyName()
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
||||
device.queue.decryption.wg.Done()
|
||||
device.queue.handshake.wg.Done()
|
||||
device.net.stopping.Done()
|
||||
}()
|
||||
|
||||
device.log.Verbosef("Routine: receive incoming %s - started", recvName)
|
||||
|
||||
// receive datagrams until conn is closed
|
||||
|
||||
var (
|
||||
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
|
||||
bufs = make([][]byte, maxBatchSize)
|
||||
err error
|
||||
sizes = make([]int, maxBatchSize)
|
||||
count int
|
||||
endpoints = make([]conn.Endpoint, maxBatchSize)
|
||||
deathSpiral int
|
||||
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
|
||||
)
|
||||
|
||||
for i := range bufsArrs {
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for i := 0; i < maxBatchSize; i++ {
|
||||
if bufsArrs[i] != nil {
|
||||
device.PutMessageBuffer(bufsArrs[i])
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
count, err = recv(bufs, sizes, endpoints)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
|
||||
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
|
||||
return
|
||||
}
|
||||
if deathSpiral < 10 {
|
||||
deathSpiral++
|
||||
time.Sleep(time.Second / 3)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
deathSpiral = 0
|
||||
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// check size of packet
|
||||
|
||||
packet := bufsArrs[i][:size]
|
||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||
|
||||
switch msgType {
|
||||
|
||||
// check if transport
|
||||
|
||||
case MessageTransportType:
|
||||
|
||||
// check size
|
||||
|
||||
if len(packet) < MessageTransportSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// lookup key pair
|
||||
|
||||
receiver := binary.LittleEndian.Uint32(
|
||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||
)
|
||||
value := device.indexTable.Lookup(receiver)
|
||||
keypair := value.keypair
|
||||
if keypair == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// check keypair expiry
|
||||
|
||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create work element
|
||||
peer := value.peer
|
||||
elem := device.GetInboundElement()
|
||||
elem.packet = packet
|
||||
elem.buffer = bufsArrs[i]
|
||||
elem.keypair = keypair
|
||||
elem.endpoint = endpoints[i]
|
||||
elem.counter = 0
|
||||
|
||||
elemsForPeer, ok := elemsByPeer[peer]
|
||||
if !ok {
|
||||
elemsForPeer = device.GetInboundElementsContainer()
|
||||
elemsForPeer.Lock()
|
||||
elemsByPeer[peer] = elemsForPeer
|
||||
}
|
||||
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
continue
|
||||
|
||||
// otherwise it is a fixed size & handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
if len(packet) != MessageInitiationSize {
|
||||
continue
|
||||
}
|
||||
|
||||
case MessageResponseType:
|
||||
if len(packet) != MessageResponseSize {
|
||||
continue
|
||||
}
|
||||
|
||||
case MessageCookieReplyType:
|
||||
if len(packet) != MessageCookieReplySize {
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received message with unknown type")
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case device.queue.handshake.c <- QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: bufsArrs[i],
|
||||
packet: packet,
|
||||
endpoint: endpoints[i],
|
||||
}:
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
default:
|
||||
}
|
||||
}
|
||||
for peer, elemsContainer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elemsContainer
|
||||
device.queue.decryption.c <- elemsContainer
|
||||
} else {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
device.PutInboundElementsContainer(elemsContainer)
|
||||
}
|
||||
delete(elemsByPeer, peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) RoutineDecryption(id int) {
|
||||
var nonce [chacha20poly1305.NonceSize]byte
|
||||
|
||||
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
|
||||
device.log.Verbosef("Routine: decryption worker %d - started", id)
|
||||
|
||||
for elemsContainer := range device.queue.decryption.c {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
// split message into fields
|
||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||
content := elem.packet[MessageTransportOffsetContent:]
|
||||
|
||||
// decrypt and release to consumer
|
||||
var err error
|
||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||
// copy counter to nonce
|
||||
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
||||
elem.packet, err = elem.keypair.receive.Open(
|
||||
content[:0],
|
||||
nonce[:],
|
||||
content,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
elem.packet = nil
|
||||
}
|
||||
}
|
||||
elemsContainer.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
/* Handles incoming packets related to handshake
|
||||
*/
|
||||
func (device *Device) RoutineHandshake(id int) {
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: handshake worker %d - stopped", id)
|
||||
device.queue.encryption.wg.Done()
|
||||
}()
|
||||
device.log.Verbosef("Routine: handshake worker %d - started", id)
|
||||
|
||||
for elem := range device.queue.handshake.c {
|
||||
|
||||
// handle cookie fields and ratelimiting
|
||||
|
||||
switch elem.msgType {
|
||||
|
||||
case MessageCookieReplyType:
|
||||
|
||||
// unmarshal packet
|
||||
|
||||
var reply MessageCookieReply
|
||||
reader := bytes.NewReader(elem.packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||
if err != nil {
|
||||
device.log.Verbosef("Failed to decode cookie reply")
|
||||
goto skip
|
||||
}
|
||||
|
||||
// lookup peer from index
|
||||
|
||||
entry := device.indexTable.Lookup(reply.Receiver)
|
||||
|
||||
if entry.peer == nil {
|
||||
goto skip
|
||||
}
|
||||
|
||||
// consume reply
|
||||
|
||||
if peer := entry.peer; peer.isRunning.Load() {
|
||||
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
|
||||
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
||||
device.log.Verbosef("Could not decrypt invalid cookie response")
|
||||
}
|
||||
}
|
||||
|
||||
goto skip
|
||||
|
||||
case MessageInitiationType, MessageResponseType:
|
||||
|
||||
// check mac fields and maybe ratelimit
|
||||
|
||||
if !device.cookieChecker.CheckMAC1(elem.packet) {
|
||||
device.log.Verbosef("Received packet with invalid mac1")
|
||||
goto skip
|
||||
}
|
||||
|
||||
// endpoints destination address is the source of the datagram
			if device.IsUnderLoad() {

				// verify MAC2 field

				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
					device.SendHandshakeCookie(&elem)
					goto skip
				}

				// check ratelimiter

				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
					goto skip
				}
			}

		default:
			device.log.Errorf("Invalid packet ended up in the handshake queue")
			goto skip
		}

		// handle handshake initiation/response content

		switch elem.msgType {
		case MessageInitiationType:

			// unmarshal

			var msg MessageInitiation
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode initiation message")
				goto skip
			}

			// consume initiation

			peer := device.ConsumeMessageInitiation(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake initiation", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			peer.SendHandshakeResponse()

		case MessageResponseType:

			// unmarshal

			var msg MessageResponse
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode response message")
				goto skip
			}

			// consume response

			peer := device.ConsumeMessageResponse(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake response", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// derive keypair

			err = peer.BeginSymmetricSession()

			if err != nil {
				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
				goto skip
			}

			peer.timersSessionDerived()
			peer.timersHandshakeComplete()
			peer.SendKeepalive()
		}
	skip:
		device.PutMessageBuffer(elem.buffer)
	}
}

func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
	device := peer.device
	defer func() {
		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
		peer.stopping.Done()
	}()
	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)

	bufs := make([][]byte, 0, maxBatchSize)

	for elemsContainer := range peer.queue.inbound.c {
		if elemsContainer == nil {
			return
		}
		elemsContainer.Lock()
		validTailPacket := -1
		dataPacketReceived := false
		for i, elem := range elemsContainer.elems {
			if elem.packet == nil {
				// decryption failed
				continue
			}

			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
				continue
			}

			validTailPacket = i
			if peer.ReceivedWithKeypair(elem.keypair) {
				peer.SetEndpointFromPacket(elem.endpoint)
				peer.timersHandshakeComplete()
				peer.SendStagedPackets()
			}
			peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))

			if len(elem.packet) == 0 {
				device.log.Verbosef("%v - Receiving keepalive packet", peer)
				continue
			}
			dataPacketReceived = true

			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
				length := binary.BigEndian.Uint16(field)
				if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
					continue
				}
				elem.packet = elem.packet[:length]
				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
				if device.allowedips.Lookup(src) != peer {
					device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
					continue
				}

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
				length := binary.BigEndian.Uint16(field)
				length += ipv6.HeaderLen
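				// The IPv6 header stores only the payload length, so the
				// 40-byte fixed header is added back to get the total size
				// (IPv4's Total Length field already includes its header).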
				if int(length) > len(elem.packet) {
					continue
				}
				elem.packet = elem.packet[:length]
				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
				if device.allowedips.Lookup(src) != peer {
					device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
					continue
				}

			default:
				device.log.Verbosef("Packet with invalid IP version from %v", peer)
				continue
			}

			bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
		}
		if validTailPacket >= 0 {
			peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
			peer.keepKeyFreshReceiving()
			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()
		}
		if dataPacketReceived {
			peer.timersDataReceived()
		}
		if len(bufs) > 0 {
			_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
			if err != nil && !device.isClosed() {
				device.log.Errorf("Failed to write packets to TUN device: %v", err)
			}
		}
		for _, elem := range elemsContainer.elems {
			device.PutMessageBuffer(elem.buffer)
			device.PutInboundElement(elem)
		}
		bufs = bufs[:0]
		device.PutInboundElementsContainer(elemsContainer)
	}
}
547
vendor/github.com/tailscale/wireguard-go/device/send.go
generated
vendored
Normal file
@@ -0,0 +1,547 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"bytes"
	"encoding/binary"
	"errors"
	"net"
	"os"
	"sync"
	"time"

	"github.com/tailscale/wireguard-go/conn"
	"github.com/tailscale/wireguard-go/tun"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

/* Outbound flow
 *
 * 1. TUN queue
 * 2. Routing (sequential)
 * 3. Nonce assignment (sequential)
 * 4. Encryption (parallel)
 * 5. Transmission (sequential)
 *
 * The functions in this file occur (roughly) in the order in
 * which the packets are processed.
 *
 * Locking, Producers and Consumers
 *
 * The order of packets (per peer) must be maintained,
 * but encryption of packets happens out-of-order:
 *
 * The sequential consumers will attempt to take the lock,
 * and workers release the lock when they have completed work (encryption) on the packet.
 *
 * If the element is inserted into the "encryption queue",
 * the content is preceded by enough "junk" to contain the transport header
 * (to allow the construction of transport messages in-place)
 */

type QueueOutboundElement struct {
	buffer  *[MaxMessageSize]byte // slice holding the packet data
	packet  []byte                // slice of "buffer" (always!)
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}

type QueueOutboundElementsContainer struct {
	sync.Mutex
	elems []*QueueOutboundElement
}

func (device *Device) NewOutboundElement() *QueueOutboundElement {
	elem := device.GetOutboundElement()
	elem.buffer = device.GetMessageBuffer()
	elem.nonce = 0
	// keypair and peer were cleared (if necessary) by clearPointers.
	return elem
}

// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
// It also reduces the possible collateral damage from use-after-free bugs.
func (elem *QueueOutboundElement) clearPointers() {
	elem.buffer = nil
	elem.packet = nil
	elem.keypair = nil
	elem.peer = nil
}

/* Queues a keepalive if no packets are queued for peer
 */
func (peer *Peer) SendKeepalive() {
	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
		elem := peer.device.NewOutboundElement()
		elemsContainer := peer.device.GetOutboundElementsContainer()
		elemsContainer.elems = append(elemsContainer.elems, elem)
		select {
		case peer.queue.staged <- elemsContainer:
			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
		default:
			peer.device.PutMessageBuffer(elem.buffer)
			peer.device.PutOutboundElement(elem)
			peer.device.PutOutboundElementsContainer(elemsContainer)
		}
	}
	peer.SendStagedPackets()
}

func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
	if !isRetry {
		peer.timers.handshakeAttempts.Store(0)
	}

	peer.handshake.mutex.RLock()
	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
		peer.handshake.mutex.RUnlock()
		return nil
	}
	peer.handshake.mutex.RUnlock()
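
	// Double-checked locking: re-verify under the write lock that no other
	// goroutine sent an initiation between the RUnlock above and the Lock below.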
	peer.handshake.mutex.Lock()
	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
		peer.handshake.mutex.Unlock()
		return nil
	}
	peer.handshake.lastSentHandshake = time.Now()
	peer.handshake.mutex.Unlock()

	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)

	msg, err := peer.device.CreateMessageInitiation(peer)
	if err != nil {
		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
		return err
	}

	var buf [MessageInitiationSize]byte
	writer := bytes.NewBuffer(buf[:0])
	binary.Write(writer, binary.LittleEndian, msg)
	packet := writer.Bytes()
	peer.cookieGenerator.AddMacs(packet)

	peer.timersAnyAuthenticatedPacketTraversal()
	peer.timersAnyAuthenticatedPacketSent()

	err = peer.SendBuffers([][]byte{packet})
	if err != nil {
		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
	}
	peer.timersHandshakeInitiated()

	return err
}

func (peer *Peer) SendHandshakeResponse() error {
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now()
	peer.handshake.mutex.Unlock()

	peer.device.log.Verbosef("%v - Sending handshake response", peer)

	response, err := peer.device.CreateMessageResponse(peer)
	if err != nil {
		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
		return err
	}

	var buf [MessageResponseSize]byte
	writer := bytes.NewBuffer(buf[:0])
	binary.Write(writer, binary.LittleEndian, response)
	packet := writer.Bytes()
	peer.cookieGenerator.AddMacs(packet)

	err = peer.BeginSymmetricSession()
	if err != nil {
		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
		return err
	}

	peer.timersSessionDerived()
	peer.timersAnyAuthenticatedPacketTraversal()
	peer.timersAnyAuthenticatedPacketSent()

	// TODO: allocation could be avoided
	err = peer.SendBuffers([][]byte{packet})
	if err != nil {
		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
	}
	return err
}

func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())

	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
	if err != nil {
		device.log.Errorf("Failed to create cookie reply: %v", err)
		return err
	}

	var buf [MessageCookieReplySize]byte
	writer := bytes.NewBuffer(buf[:0])
	binary.Write(writer, binary.LittleEndian, reply)
	// TODO: allocation could be avoided
	device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
	return nil
}

func (peer *Peer) keepKeyFreshSending() {
	keypair := peer.keypairs.Current()
	if keypair == nil {
		return
	}
	nonce := keypair.sendNonce.Load()
	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
		peer.SendHandshakeInitiation(false)
	}
}

func (device *Device) RoutineReadFromTUN() {
	defer func() {
		device.log.Verbosef("Routine: TUN reader - stopped")
		device.state.stopping.Done()
		device.queue.encryption.wg.Done()
	}()

	device.log.Verbosef("Routine: TUN reader - started")

	var (
		batchSize   = device.BatchSize()
		readErr     error
		elems       = make([]*QueueOutboundElement, batchSize)
		bufs        = make([][]byte, batchSize)
		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
		count       = 0
		sizes       = make([]int, batchSize)
		offset      = MessageTransportHeaderSize
	)

	for i := range elems {
		elems[i] = device.NewOutboundElement()
		bufs[i] = elems[i].buffer[:]
	}

	defer func() {
		for _, elem := range elems {
			if elem != nil {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
		}
	}()

	for {
		// read packets
		count, readErr = device.tun.device.Read(bufs, sizes, offset)
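		// A single Read may return multiple packets (vectored I/O): up to
		// len(bufs) packets are filled, sizes[i] records each packet's
		// length, and offset leaves headroom for the transport header
		// that is built in-place later.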
		for i := 0; i < count; i++ {
			if sizes[i] < 1 {
				continue
			}

			elem := elems[i]
			elem.packet = bufs[i][offset : offset+sizes[i]]

			// lookup peer
			var peer *Peer
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
				peer = device.allowedips.Lookup(dst)

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
				peer = device.allowedips.Lookup(dst)

			default:
				device.log.Verbosef("Received packet with unknown IP version")
			}

			if peer == nil {
				continue
			}
			elemsForPeer, ok := elemsByPeer[peer]
			if !ok {
				elemsForPeer = device.GetOutboundElementsContainer()
				elemsByPeer[peer] = elemsForPeer
			}
			elemsForPeer.elems = append(elemsForPeer.elems, elem)
			elems[i] = device.NewOutboundElement()
			bufs[i] = elems[i].buffer[:]
		}

		for peer, elemsForPeer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.StagePackets(elemsForPeer)
				peer.SendStagedPackets()
			} else {
				for _, elem := range elemsForPeer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutOutboundElement(elem)
				}
				device.PutOutboundElementsContainer(elemsForPeer)
			}
			delete(elemsByPeer, peer)
		}

		if readErr != nil {
			if errors.Is(readErr, tun.ErrTooManySegments) {
				// TODO: record stat for this
				// This will happen if MSS is surprisingly small (< 576)
				// coincident with reasonably high throughput.
				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
				continue
			}
			if !device.isClosed() {
				if !errors.Is(readErr, os.ErrClosed) {
					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
				}
				go device.Close()
			}
			return
		}
	}
}
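
// StagePackets uses a drop-oldest policy: if the staged queue is full, the
// oldest staged batch is recycled to the pools and the new batch is retried,
// so fresh packets are preferred over stale ones.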
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
	for {
		select {
		case peer.queue.staged <- elems:
			return
		default:
		}
		select {
		case tooOld := <-peer.queue.staged:
			for _, elem := range tooOld.elems {
				peer.device.PutMessageBuffer(elem.buffer)
				peer.device.PutOutboundElement(elem)
			}
			peer.device.PutOutboundElementsContainer(tooOld)
		default:
		}
	}
}

func (peer *Peer) SendStagedPackets() {
top:
	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
		return
	}

	keypair := peer.keypairs.Current()
	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		peer.SendHandshakeInitiation(false)
		return
	}

	for {
		var elemsContainerOOO *QueueOutboundElementsContainer
		select {
		case elemsContainer := <-peer.queue.staged:
			i := 0
			for _, elem := range elemsContainer.elems {
				elem.peer = peer
				elem.nonce = keypair.sendNonce.Add(1) - 1
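				// Atomically reserve the next nonce: Add(1)-1 yields the
				// pre-increment value, so concurrent senders never reuse
				// a counter for this keypair.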
				if elem.nonce >= RejectAfterMessages {
					keypair.sendNonce.Store(RejectAfterMessages)
					if elemsContainerOOO == nil {
						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
					}
					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
					continue
				} else {
					elemsContainer.elems[i] = elem
					i++
				}

				elem.keypair = keypair
			}
			elemsContainer.Lock()
			elemsContainer.elems = elemsContainer.elems[:i]

			if elemsContainerOOO != nil {
				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
			}

			if len(elemsContainer.elems) == 0 {
				peer.device.PutOutboundElementsContainer(elemsContainer)
				goto top
			}

			// add to parallel and sequential queue
			if peer.isRunning.Load() {
				peer.queue.outbound.c <- elemsContainer
				peer.device.queue.encryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					peer.device.PutMessageBuffer(elem.buffer)
					peer.device.PutOutboundElement(elem)
				}
				peer.device.PutOutboundElementsContainer(elemsContainer)
			}

			if elemsContainerOOO != nil {
				goto top
			}
		default:
			return
		}
	}
}

func (peer *Peer) FlushStagedPackets() {
	for {
		select {
		case elemsContainer := <-peer.queue.staged:
			for _, elem := range elemsContainer.elems {
				peer.device.PutMessageBuffer(elem.buffer)
				peer.device.PutOutboundElement(elem)
			}
			peer.device.PutOutboundElementsContainer(elemsContainer)
		default:
			return
		}
	}
}

func calculatePaddingSize(packetSize, mtu int) int {
	lastUnit := packetSize
	if mtu == 0 {
		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
	}
	if lastUnit > mtu {
		lastUnit %= mtu
	}
	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
	if paddedSize > mtu {
		paddedSize = mtu
	}
	return paddedSize - lastUnit
}
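
// For example, with PaddingMultiple = 16 and mtu = 1420: a 100-byte packet
// is rounded up to 112, so 12 bytes of padding are added; a 1419-byte packet
// would round up to 1424, but the padded size is capped at the MTU, so only
// 1 byte is added. Padding obscures exact plaintext lengths on the wire.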

/* Encrypts the elements in the queue
 * and marks them for sequential consumption (by releasing the mutex)
 *
 * Obs. One instance per core
 */
func (device *Device) RoutineEncryption(id int) {
	var paddingZeros [PaddingMultiple]byte
	var nonce [chacha20poly1305.NonceSize]byte

	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
	device.log.Verbosef("Routine: encryption worker %d - started", id)

	for elemsContainer := range device.queue.encryption.c {
		for _, elem := range elemsContainer.elems {
			// populate header fields
			header := elem.buffer[:MessageTransportHeaderSize]

			fieldType := header[0:4]
			fieldReceiver := header[4:8]
			fieldNonce := header[8:16]

			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)

			// pad content to multiple of 16
			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)

			// encrypt content and release to consumer

			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
			elem.packet = elem.keypair.send.Seal(
				header,
				nonce[:],
				elem.packet,
				nil,
			)
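			// Seal appends the ciphertext and 16-byte Poly1305 tag after
			// the transport header passed as dst, so elem.packet now holds
			// the complete transport message ready for transmission.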
		}
		elemsContainer.Unlock()
	}
}

func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
	device := peer.device
	defer func() {
		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
		peer.stopping.Done()
	}()
	device.log.Verbosef("%v - Routine: sequential sender - started", peer)

	bufs := make([][]byte, 0, maxBatchSize)

	for elemsContainer := range peer.queue.outbound.c {
		bufs = bufs[:0]
		if elemsContainer == nil {
			return
		}
		if !peer.isRunning.Load() {
			// peer has been stopped; return re-usable elems to the shared pool.
			// This is an optimization only. It is possible for the peer to be stopped
			// immediately after this check, in which case, elem will get processed.
			// The timers and SendBuffers code are resilient to a few stragglers.
			// TODO: rework peer shutdown order to ensure
			// that we never accidentally keep timers alive longer than necessary.
			elemsContainer.Lock()
			for _, elem := range elemsContainer.elems {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
			device.PutOutboundElementsContainer(elemsContainer)
			continue
		}
		dataSent := false
		elemsContainer.Lock()
		for _, elem := range elemsContainer.elems {
			if len(elem.packet) != MessageKeepaliveSize {
				dataSent = true
			}
			bufs = append(bufs, elem.packet)
		}

		peer.timersAnyAuthenticatedPacketTraversal()
		peer.timersAnyAuthenticatedPacketSent()

		err := peer.SendBuffers(bufs)
		if dataSent {
			peer.timersDataSent()
		}
		for _, elem := range elemsContainer.elems {
			device.PutMessageBuffer(elem.buffer)
			device.PutOutboundElement(elem)
		}
		device.PutOutboundElementsContainer(elemsContainer)
		if err != nil {
			var errGSO conn.ErrUDPGSODisabled
			if errors.As(err, &errGSO) {
				device.log.Verbosef(err.Error())
				err = errGSO.RetryErr
			}
		}
		if err != nil {
			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
			continue
		}

		peer.keepKeyFreshSending()
	}
}
12
vendor/github.com/tailscale/wireguard-go/device/sticky_default.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
//go:build !linux

package device

import (
	"github.com/tailscale/wireguard-go/conn"
	"github.com/tailscale/wireguard-go/rwcancel"
)

func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
	return nil, nil
}
224
vendor/github.com/tailscale/wireguard-go/device/sticky_linux.go
generated
vendored
Normal file
@@ -0,0 +1,224 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 *
 * This implements userspace semantics of "sticky sockets", modeled after
 * WireGuard's kernelspace implementation. This is more or less a straight port
 * of the sticky-sockets.c example code:
 * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
 *
 * Currently there is no way to achieve this within the net package:
 * See e.g. https://github.com/golang/go/issues/17930
 * So this code remains platform dependent.
 */

package device

import (
	"sync"
	"unsafe"

	"golang.org/x/sys/unix"

	"github.com/tailscale/wireguard-go/conn"
	"github.com/tailscale/wireguard-go/rwcancel"
)

func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
	if !conn.StdNetSupportsStickySockets {
		return nil, nil
	}
	if _, ok := bind.(*conn.StdNetBind); !ok {
		return nil, nil
	}

	netlinkSock, err := createNetlinkRouteSocket()
	if err != nil {
		return nil, err
	}
	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
	if err != nil {
		unix.Close(netlinkSock)
		return nil, err
	}

	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)

	return netlinkCancel, nil
}

func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
	type peerEndpointPtr struct {
		peer     *Peer
		endpoint *conn.Endpoint
	}
	var reqPeer map[uint32]peerEndpointPtr
	var reqPeerLock sync.Mutex

	defer netlinkCancel.Close()
	defer unix.Close(netlinkSock)

	for msg := make([]byte, 1<<16); ; {
		var err error
		var msgn int
		for {
			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
			if err == nil || !rwcancel.RetryAfterError(err) {
				break
			}
			if !netlinkCancel.ReadyRead() {
				return
			}
		}
		if err != nil {
			return
		}

		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {

			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))

			if uint(hdr.Len) > uint(len(remain)) {
				break
			}

			switch hdr.Type {
			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
					if uint(len(remain)) < uint(hdr.Len) {
						break
					}
					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
						for {
							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
								break
							}
							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
								break
							}
							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
								reqPeerLock.Lock()
								if reqPeer == nil {
									reqPeerLock.Unlock()
									break
								}
								pePtr, ok := reqPeer[hdr.Seq]
								reqPeerLock.Unlock()
								if !ok {
									break
								}
								pePtr.peer.endpoint.Lock()
								if &pePtr.peer.endpoint.val != pePtr.endpoint {
									pePtr.peer.endpoint.Unlock()
									break
								}
								if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
									pePtr.peer.endpoint.Unlock()
									break
								}
								pePtr.peer.endpoint.clearSrcOnTx = true
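								// The kernel's route reply names a different
								// output interface than the cached sticky
								// source, so mark the source address to be
								// cleared on the next transmit.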
								pePtr.peer.endpoint.Unlock()
							}
							attr = attr[attrhdr.Len:]
						}
					}
					break
				}
				reqPeerLock.Lock()
				reqPeer = make(map[uint32]peerEndpointPtr)
				reqPeerLock.Unlock()
				go func() {
					device.peers.RLock()
					i := uint32(1)
					for _, peer := range device.peers.keyMap {
						peer.endpoint.Lock()
						if peer.endpoint.val == nil {
							peer.endpoint.Unlock()
							continue
						}
						nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
						if nativeEP == nil {
							peer.endpoint.Unlock()
							continue
						}
						if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
							peer.endpoint.Unlock()
							break
						}
						nlmsg := struct {
							hdr     unix.NlMsghdr
							msg     unix.RtMsg
							dsthdr  unix.RtAttr
							dst     [4]byte
							srchdr  unix.RtAttr
							src     [4]byte
							markhdr unix.RtAttr
							mark    uint32
						}{
							unix.NlMsghdr{
								Type:  uint16(unix.RTM_GETROUTE),
								Flags: unix.NLM_F_REQUEST,
								Seq:   i,
							},
							unix.RtMsg{
								Family:  unix.AF_INET,
								Dst_len: 32,
								Src_len: 32,
							},
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_DST,
							},
							nativeEP.DstIP().As4(),
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_SRC,
							},
							nativeEP.SrcIP().As4(),
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_MARK,
							},
							device.net.fwmark,
						}
						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
						reqPeerLock.Lock()
						reqPeer[i] = peerEndpointPtr{
							peer:     peer,
							endpoint: &peer.endpoint.val,
						}
						reqPeerLock.Unlock()
						peer.endpoint.Unlock()
						i++
						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
						if err != nil {
							break
						}
					}
					device.peers.RUnlock()
				}()
			}
			remain = remain[hdr.Len:]
		}
	}
}

func createNetlinkRouteSocket() (int, error) {
	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
	if err != nil {
		return -1, err
	}
	saddr := &unix.SockaddrNetlink{
		Family: unix.AF_NETLINK,
		Groups: unix.RTMGRP_IPV4_ROUTE,
	}
	err = unix.Bind(sock, saddr)
	if err != nil {
		unix.Close(sock)
		return -1, err
	}
	return sock, nil
}
221
vendor/github.com/tailscale/wireguard-go/device/timers.go
generated
vendored
Normal file
@@ -0,0 +1,221 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 *
 * This is based heavily on timers.c from the kernel implementation.
 */

package device

import (
	"sync"
	"time"
	_ "unsafe"
)

//go:linkname fastrandn runtime.fastrandn
func fastrandn(n uint32) uint32
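
// fastrandn links against the runtime's fast random number generator and
// returns a uniform value in [0, n); it is used below to add jitter to
// handshake retransmission timers without the overhead of math/rand.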

// A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct {
	*time.Timer
	modifyingLock sync.RWMutex
	runningLock   sync.Mutex
	isPending     bool
}

func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
	timer := &Timer{}
	timer.Timer = time.AfterFunc(time.Hour, func() {
		timer.runningLock.Lock()
		defer timer.runningLock.Unlock()

		timer.modifyingLock.Lock()
		if !timer.isPending {
			timer.modifyingLock.Unlock()
			return
		}
		timer.isPending = false
		timer.modifyingLock.Unlock()

		expirationFunction(peer)
	})
	timer.Stop()
	return timer
}

func (timer *Timer) Mod(d time.Duration) {
	timer.modifyingLock.Lock()
	timer.isPending = true
	timer.Reset(d)
	timer.modifyingLock.Unlock()
}

func (timer *Timer) Del() {
	timer.modifyingLock.Lock()
	timer.isPending = false
	timer.Stop()
	timer.modifyingLock.Unlock()
}

func (timer *Timer) DelSync() {
	timer.Del()
	timer.runningLock.Lock()
	timer.Del()
	timer.runningLock.Unlock()
}
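
// DelSync appears to call Del twice deliberately: the first call stops a
// pending timer, acquiring runningLock then waits out any expiration
// function already executing, and the second Del clears any re-arm that
// function may have performed via Mod.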

func (timer *Timer) IsPending() bool {
	timer.modifyingLock.RLock()
	defer timer.modifyingLock.RUnlock()
	return timer.isPending
}

func (peer *Peer) timersActive() bool {
	return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
}

func expiredRetransmitHandshake(peer *Peer) {
	if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
		peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)

		if peer.timersActive() {
			peer.timers.sendKeepalive.Del()
		}

		/* We drop all packets without a keypair and don't try again,
		 * if we try unsuccessfully for too long to make a handshake.
		 */
		peer.FlushStagedPackets()

		/* We set a timer for destroying any residue that might be left
		 * of a partial exchange.
		 */
		if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() {
			peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
		}
	} else {
		peer.timers.handshakeAttempts.Add(1)
		peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)

		/* We clear the endpoint source address, in case this is the cause of trouble. */
		peer.markEndpointSrcForClearing()

		peer.SendHandshakeInitiation(true)
	}
}

func expiredSendKeepalive(peer *Peer) {
	peer.SendKeepalive()
	if peer.timers.needAnotherKeepalive.Load() {
		peer.timers.needAnotherKeepalive.Store(false)
		if peer.timersActive() {
			peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
		}
	}
}

func expiredNewHandshake(peer *Peer) {
	peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
	/* We clear the endpoint source address, in case this is the cause of trouble. */
	peer.markEndpointSrcForClearing()
	peer.SendHandshakeInitiation(false)
}

func expiredZeroKeyMaterial(peer *Peer) {
	peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
	peer.ZeroAndFlushAll()
}

func expiredPersistentKeepalive(peer *Peer) {
	if peer.persistentKeepaliveInterval.Load() > 0 {
		peer.SendKeepalive()
	}
}

/* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() {
	if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
		peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
	}
}

/* Should be called after an authenticated data packet is received. */
func (peer *Peer) timersDataReceived() {
	if peer.timersActive() {
		if !peer.timers.sendKeepalive.IsPending() {
			peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
		} else {
			peer.timers.needAnotherKeepalive.Store(true)
		}
	}
}

/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */
func (peer *Peer) timersAnyAuthenticatedPacketSent() {
	if peer.timersActive() {
		peer.timers.sendKeepalive.Del()
	}
}

/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */
func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
	if peer.timersActive() {
		peer.timers.newHandshake.Del()
	}
}

/* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() {
	if peer.timersActive() {
		peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
	}
}

/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
func (peer *Peer) timersHandshakeComplete() {
	if peer.timersActive() {
		peer.timers.retransmitHandshake.Del()
	}
	peer.timers.handshakeAttempts.Store(0)
	peer.timers.sentLastMinuteHandshake.Store(false)
	peer.lastHandshakeNano.Store(time.Now().UnixNano())
}

/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
func (peer *Peer) timersSessionDerived() {
	if peer.timersActive() {
		peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
	}
}

/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
	keepalive := peer.persistentKeepaliveInterval.Load()
	if keepalive > 0 && peer.timersActive() {
		peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
	}
}

func (peer *Peer) timersInit() {
	peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
	peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
	peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
	peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
	peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
}

func (peer *Peer) timersStart() {
	peer.timers.handshakeAttempts.Store(0)
	peer.timers.sentLastMinuteHandshake.Store(false)
	peer.timers.needAnotherKeepalive.Store(false)
}

func (peer *Peer) timersStop() {
	peer.timers.retransmitHandshake.DelSync()
	peer.timers.sendKeepalive.DelSync()
	peer.timers.newHandshake.DelSync()
	peer.timers.zeroKeyMaterial.DelSync()
	peer.timers.persistentKeepalive.DelSync()
}
53
vendor/github.com/tailscale/wireguard-go/device/tun.go
generated
vendored
Normal file
@@ -0,0 +1,53 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"fmt"

	"github.com/tailscale/wireguard-go/tun"
)

const DefaultMTU = 1420

func (device *Device) RoutineTUNEventReader() {
	device.log.Verbosef("Routine: event worker - started")

	for event := range device.tun.device.Events() {
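		// Events delivers a bitmask, so a single value can carry several
		// flags at once; each flag is tested independently below.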
		if event&tun.EventMTUUpdate != 0 {
			mtu, err := device.tun.device.MTU()
			if err != nil {
				device.log.Errorf("Failed to load updated MTU of device: %v", err)
				continue
			}
			if mtu < 0 {
				device.log.Errorf("MTU not updated to negative value: %v", mtu)
				continue
			}
			var tooLarge string
			if mtu > MaxContentSize {
				tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
				mtu = MaxContentSize
			}
			old := device.tun.mtu.Swap(int32(mtu))
			if int(old) != mtu {
				device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
			}
		}

		if event&tun.EventUp != 0 {
			device.log.Verbosef("Interface up requested")
			device.Up()
		}

		if event&tun.EventDown != 0 {
			device.log.Verbosef("Interface down requested")
			device.Down()
		}
	}

	device.log.Verbosef("Routine: event worker - stopped")
}
457
vendor/github.com/tailscale/wireguard-go/device/uapi.go
generated
vendored
Normal file
@@ -0,0 +1,457 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/tailscale/wireguard-go/ipc"
)

type IPCError struct {
	code int64 // error code
	err  error // underlying/wrapped error
}

func (s IPCError) Error() string {
	return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
}

func (s IPCError) Unwrap() error {
	return s.err
}

func (s IPCError) ErrorCode() int64 {
	return s.code
}

func ipcErrorf(code int64, msg string, args ...any) *IPCError {
	return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}

var byteBufferPool = &sync.Pool{
	New: func() any { return new(bytes.Buffer) },
}

// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcGetOperation(w io.Writer) error {
	device.ipcMutex.RLock()
	defer device.ipcMutex.RUnlock()

	buf := byteBufferPool.Get().(*bytes.Buffer)
	buf.Reset()
	defer byteBufferPool.Put(buf)
	sendf := func(format string, args ...any) {
		fmt.Fprintf(buf, format, args...)
		buf.WriteByte('\n')
	}
	keyf := func(prefix string, key *[32]byte) {
		buf.Grow(len(key)*2 + 2 + len(prefix))
		buf.WriteString(prefix)
		buf.WriteByte('=')
		const hex = "0123456789abcdef"
		for i := 0; i < len(key); i++ {
			buf.WriteByte(hex[key[i]>>4])
			buf.WriteByte(hex[key[i]&0xf])
		}
		buf.WriteByte('\n')
	}

	func() {
		// lock required resources

		device.net.RLock()
		defer device.net.RUnlock()

		device.staticIdentity.RLock()
		defer device.staticIdentity.RUnlock()

		device.peers.RLock()
		defer device.peers.RUnlock()

		// serialize device related values

		if !device.staticIdentity.privateKey.IsZero() {
			keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
		}

		if device.net.port != 0 {
			sendf("listen_port=%d", device.net.port)
		}

		if device.net.fwmark != 0 {
			sendf("fwmark=%d", device.net.fwmark)
		}

		for _, peer := range device.peers.keyMap {
			// Serialize peer state.
			peer.handshake.mutex.RLock()
			keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
			keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
			peer.handshake.mutex.RUnlock()
			sendf("protocol_version=1")
			peer.endpoint.Lock()
			if peer.endpoint.val != nil {
				sendf("endpoint=%s", peer.endpoint.val.DstToString())
			}
			peer.endpoint.Unlock()

			nano := peer.lastHandshakeNano.Load()
			secs := nano / time.Second.Nanoseconds()
			nano %= time.Second.Nanoseconds()

			sendf("last_handshake_time_sec=%d", secs)
			sendf("last_handshake_time_nsec=%d", nano)
			sendf("tx_bytes=%d", peer.txBytes.Load())
			sendf("rx_bytes=%d", peer.rxBytes.Load())
			sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())

			device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
				sendf("allowed_ip=%s", prefix.String())
				return true
			})
		}
	}()

	// send lines (does not require resource locks)
	if _, err := w.Write(buf.Bytes()); err != nil {
		return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
	}

	return nil
}
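
// A typical "get" response, as produced above, is a series of key=value
// lines; a sketch with hypothetical, abridged values:
//
//	private_key=<hex>
//	listen_port=51820
//	public_key=<hex>
//	preshared_key=<hex>
//	protocol_version=1
//	allowed_ip=10.0.0.2/32
//	last_handshake_time_sec=0
//	last_handshake_time_nsec=0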

// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcSetOperation(r io.Reader) (err error) {
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()

	defer func() {
		if err != nil {
			device.log.Errorf("%v", err)
		}
	}()

	peer := new(ipcSetPeer)
	deviceConfig := true

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			// Blank line means terminate operation.
			peer.handlePostConfig()
			return nil
		}
		key, value, ok := strings.Cut(line, "=")
		if !ok {
			return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
		}

		if key == "public_key" {
			if deviceConfig {
				deviceConfig = false
			}
			peer.handlePostConfig()
			// Load/create the peer we are now configuring.
			err := device.handlePublicKeyLine(peer, value)
			if err != nil {
				return err
			}
			continue
		}

		var err error
		if deviceConfig {
			err = device.handleDeviceLine(key, value)
		} else {
			err = device.handlePeerLine(peer, key, value)
		}
		if err != nil {
			return err
		}
	}
	peer.handlePostConfig()

	if err := scanner.Err(); err != nil {
		return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
	}
	return nil
}

func (device *Device) handleDeviceLine(key, value string) error {
	switch key {
	case "private_key":
		var sk NoisePrivateKey
		err := sk.FromMaybeZeroHex(value)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
		}
		device.log.Verbosef("UAPI: Updating private key")
		device.SetPrivateKey(sk)

	case "listen_port":
		port, err := strconv.ParseUint(value, 10, 16)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
		}

		// update port and rebind
		device.log.Verbosef("UAPI: Updating listen port")

		device.net.Lock()
		device.net.port = uint16(port)
		device.net.Unlock()

		if err := device.BindUpdate(); err != nil {
			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
		}

	case "fwmark":
		mark, err := strconv.ParseUint(value, 10, 32)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
		}

		device.log.Verbosef("UAPI: Updating fwmark")
		if err := device.BindSetMark(uint32(mark)); err != nil {
			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
		}

	case "replace_peers":
		if value != "true" {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
		}
		device.log.Verbosef("UAPI: Removing all peers")
		device.RemoveAllPeers()

	default:
		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
	}

	return nil
}

// An ipcSetPeer is the current state of an IPC set operation on a peer.
type ipcSetPeer struct {
	*Peer        // Peer is the current peer being operated on
	dummy   bool // dummy reports whether this peer is a temporary, placeholder peer
	created bool // created reports whether this is a newly created peer
	pkaOn   bool // pkaOn reports whether the peer had the persistent keepalive turned on
}

func (peer *ipcSetPeer) handlePostConfig() {
	if peer.Peer == nil || peer.dummy {
		return
	}
	if peer.created {
		peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
	}
	if peer.device.isUp() {
		peer.Start()
		if peer.pkaOn {
			peer.SendKeepalive()
		}
		peer.SendStagedPackets()
	}
}

func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
	// Load/create the peer we are configuring.
	var publicKey NoisePublicKey
	err := publicKey.FromHex(value)
	if err != nil {
		return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
	}

	// Ignore peer with the same public key as this device.
	device.staticIdentity.RLock()
	peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
	device.staticIdentity.RUnlock()

	if peer.dummy {
		peer.Peer = &Peer{}
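		// A throwaway Peer lets the remaining peer lines parse cleanly
		// while everything set on it is silently discarded.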
	} else {
		peer.Peer = device.LookupPeer(publicKey)
	}

	peer.created = peer.Peer == nil
	if peer.created {
		peer.Peer, err = device.NewPeer(publicKey)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
		}
		device.log.Verbosef("%v - UAPI: Created", peer.Peer)
	}
	return nil
}

func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
	switch key {
	case "update_only":
		// allow disabling of creation
		if value != "true" {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
		}
		if peer.created && !peer.dummy {
			device.RemovePeer(peer.handshake.remoteStatic)
			peer.Peer = &Peer{}
			peer.dummy = true
		}

	case "remove":
		// remove currently selected peer from device
		if value != "true" {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
		}
		if !peer.dummy {
			device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
			device.RemovePeer(peer.handshake.remoteStatic)
		}
		peer.Peer = &Peer{}
		peer.dummy = true

	case "preshared_key":
		device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)

		peer.handshake.mutex.Lock()
		err := peer.handshake.presharedKey.FromHex(value)
		peer.handshake.mutex.Unlock()

		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
		}

	case "endpoint":
		device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
		endpoint, err := device.net.bind.ParseEndpoint(value)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
		}
		peer.endpoint.Lock()
		defer peer.endpoint.Unlock()
		peer.endpoint.val = endpoint

	case "persistent_keepalive_interval":
		device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)

		secs, err := strconv.ParseUint(value, 10, 16)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
		}

		old := peer.persistentKeepaliveInterval.Swap(uint32(secs))

		// Send an immediate keepalive if we're turning it on and it wasn't on before.
		peer.pkaOn = old == 0 && secs != 0

	case "replace_allowed_ips":
		device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
		if value != "true" {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
		}
		if peer.dummy {
			return nil
		}
		device.allowedips.RemoveByPeer(peer.Peer)

	case "allowed_ip":
		device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
		prefix, err := netip.ParsePrefix(value)
		if err != nil {
			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
		}
		if peer.dummy {
			return nil
		}
		device.allowedips.Insert(prefix, peer.Peer)

	case "protocol_version":
		if value != "1" {
			return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
		}

	default:
		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
	}

	return nil
}

func (device *Device) IpcGet() (string, error) {
	buf := new(strings.Builder)
	if err := device.IpcGetOperation(buf); err != nil {
		return "", err
	}
	return buf.String(), nil
}

func (device *Device) IpcSet(uapiConf string) error {
	return device.IpcSetOperation(strings.NewReader(uapiConf))
}
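
// IpcHandle speaks the line-oriented UAPI socket framing: a client sends
// "set=1\n" followed by the configuration body, or "get=1\n" followed by a
// blank line, and the device answers each operation with "errno=<code>\n\n".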
func (device *Device) IpcHandle(socket net.Conn) {
	defer socket.Close()

	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
		reader := bufio.NewReader(s)
		writer := bufio.NewWriter(s)
		return bufio.NewReadWriter(reader, writer)
	}(socket)

	for {
		op, err := buffered.ReadString('\n')
		if err != nil {
			return
		}

		// handle operation
		switch op {
		case "set=1\n":
			err = device.IpcSetOperation(buffered.Reader)
		case "get=1\n":
			var nextByte byte
			nextByte, err = buffered.ReadByte()
			if err != nil {
				return
			}
			if nextByte != '\n' {
				err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
				break
			}
			err = device.IpcGetOperation(buffered.Writer)
		default:
			device.log.Errorf("invalid UAPI operation: %v", op)
			return
		}

		// write status
		var status *IPCError
		if err != nil && !errors.As(err, &status) {
			// shouldn't happen
			status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
		}
		if status != nil {
			device.log.Errorf("%v", status)
			fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
		} else {
			fmt.Fprintf(buffered, "errno=0\n\n")
		}
		buffered.Flush()
	}
}
287
vendor/github.com/tailscale/wireguard-go/ipc/namedpipe/file.go
generated
vendored
Normal file
@@ -0,0 +1,287 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build windows

package namedpipe

import (
	"io"
	"os"
	"runtime"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"golang.org/x/sys/windows"
)

type timeoutChan chan struct{}

var (
	ioInitOnce       sync.Once
	ioCompletionPort windows.Handle
)

// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
	bytes uint32
	err   error
}

// ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct {
	o  windows.Overlapped
	ch chan ioResult
}

func initIo() {
	h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
	if err != nil {
		panic(err)
	}
	ioCompletionPort = h
	go ioCompletionProcessor(h)
}

// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
type file struct {
	handle        windows.Handle
	wg            sync.WaitGroup
	wgLock        sync.RWMutex
	closing       atomic.Bool
	socket        bool
	readDeadline  deadlineHandler
	writeDeadline deadlineHandler
}

type deadlineHandler struct {
	setLock     sync.Mutex
	channel     timeoutChan
	channelLock sync.RWMutex
	timer       *time.Timer
	timedout    atomic.Bool
}

// makeFile makes a new file from an existing file handle
func makeFile(h windows.Handle) (*file, error) {
	f := &file{handle: h}
	ioInitOnce.Do(initIo)
	_, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
	if err != nil {
		return nil, err
	}
	err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
	if err != nil {
		return nil, err
	}
|
||||
f.readDeadline.channel = make(timeoutChan)
|
||||
f.writeDeadline.channel = make(timeoutChan)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// closeHandle closes the resources associated with a Win32 handle
|
||||
func (f *file) closeHandle() {
|
||||
f.wgLock.Lock()
|
||||
// Atomically set that we are closing, releasing the resources only once.
|
||||
if f.closing.Swap(true) == false {
|
||||
f.wgLock.Unlock()
|
||||
// cancel all IO and wait for it to complete
|
||||
windows.CancelIoEx(f.handle, nil)
|
||||
f.wg.Wait()
|
||||
// at this point, no new IO can start
|
||||
windows.Close(f.handle)
|
||||
f.handle = 0
|
||||
} else {
|
||||
f.wgLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes a file.
|
||||
func (f *file) Close() error {
|
||||
f.closeHandle()
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareIo prepares for a new IO operation.
|
||||
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||
func (f *file) prepareIo() (*ioOperation, error) {
|
||||
f.wgLock.RLock()
|
||||
if f.closing.Load() {
|
||||
f.wgLock.RUnlock()
|
||||
return nil, os.ErrClosed
|
||||
}
|
||||
f.wg.Add(1)
|
||||
f.wgLock.RUnlock()
|
||||
c := &ioOperation{}
|
||||
c.ch = make(chan ioResult)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// ioCompletionProcessor processes completed async IOs forever
|
||||
func ioCompletionProcessor(h windows.Handle) {
|
||||
for {
|
||||
var bytes uint32
|
||||
var key uintptr
|
||||
var op *ioOperation
|
||||
err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
|
||||
if op == nil {
|
||||
panic(err)
|
||||
}
|
||||
op.ch <- ioResult{bytes, err}
|
||||
}
|
||||
}
|
||||
|
||||
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
||||
// the operation has actually completed.
|
||||
func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
||||
if err != windows.ERROR_IO_PENDING {
|
||||
return int(bytes), err
|
||||
}
|
||||
|
||||
if f.closing.Load() {
|
||||
windows.CancelIoEx(f.handle, &c.o)
|
||||
}
|
||||
|
||||
var timeout timeoutChan
|
||||
if d != nil {
|
||||
d.channelLock.Lock()
|
||||
timeout = d.channel
|
||||
d.channelLock.Unlock()
|
||||
}
|
||||
|
||||
var r ioResult
|
||||
select {
|
||||
case r = <-c.ch:
|
||||
err = r.err
|
||||
if err == windows.ERROR_OPERATION_ABORTED {
|
||||
if f.closing.Load() {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
} else if err != nil && f.socket {
|
||||
// err is from Win32. Query the overlapped structure to get the winsock error.
|
||||
var bytes, flags uint32
|
||||
err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
||||
}
|
||||
case <-timeout:
|
||||
windows.CancelIoEx(f.handle, &c.o)
|
||||
r = <-c.ch
|
||||
err = r.err
|
||||
if err == windows.ERROR_OPERATION_ABORTED {
|
||||
err = os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
// runtime.KeepAlive is needed, as c is passed via native
|
||||
// code to ioCompletionProcessor, c must remain alive
|
||||
// until the channel read is complete.
|
||||
runtime.KeepAlive(c)
|
||||
return int(r.bytes), err
|
||||
}
|
||||
|
||||
// Read reads from a file handle.
|
||||
func (f *file) Read(b []byte) (int, error) {
|
||||
c, err := f.prepareIo()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.readDeadline.timedout.Load() {
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
|
||||
var bytes uint32
|
||||
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
|
||||
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
||||
runtime.KeepAlive(b)
|
||||
|
||||
// Handle EOF conditions.
|
||||
if err == nil && n == 0 && len(b) != 0 {
|
||||
return 0, io.EOF
|
||||
} else if err == windows.ERROR_BROKEN_PIPE {
|
||||
return 0, io.EOF
|
||||
} else {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes to a file handle.
|
||||
func (f *file) Write(b []byte) (int, error) {
|
||||
c, err := f.prepareIo()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.writeDeadline.timedout.Load() {
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
|
||||
var bytes uint32
|
||||
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
|
||||
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
||||
runtime.KeepAlive(b)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (f *file) SetReadDeadline(deadline time.Time) error {
|
||||
return f.readDeadline.set(deadline)
|
||||
}
|
||||
|
||||
func (f *file) SetWriteDeadline(deadline time.Time) error {
|
||||
return f.writeDeadline.set(deadline)
|
||||
}
|
||||
|
||||
func (f *file) Flush() error {
|
||||
return windows.FlushFileBuffers(f.handle)
|
||||
}
|
||||
|
||||
func (f *file) Fd() uintptr {
|
||||
return uintptr(f.handle)
|
||||
}
|
||||
|
||||
func (d *deadlineHandler) set(deadline time.Time) error {
|
||||
d.setLock.Lock()
|
||||
defer d.setLock.Unlock()
|
||||
|
||||
if d.timer != nil {
|
||||
if !d.timer.Stop() {
|
||||
<-d.channel
|
||||
}
|
||||
d.timer = nil
|
||||
}
|
||||
d.timedout.Store(false)
|
||||
|
||||
select {
|
||||
case <-d.channel:
|
||||
d.channelLock.Lock()
|
||||
d.channel = make(chan struct{})
|
||||
d.channelLock.Unlock()
|
||||
default:
|
||||
}
|
||||
|
||||
if deadline.IsZero() {
|
||||
return nil
|
||||
}
|
||||
|
||||
timeoutIO := func() {
|
||||
d.timedout.Store(true)
|
||||
close(d.channel)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
duration := deadline.Sub(now)
|
||||
if deadline.After(now) {
|
||||
// Deadline is in the future, set a timer to wait
|
||||
d.timer = time.AfterFunc(duration, timeoutIO)
|
||||
} else {
|
||||
// Deadline is in the past. Cancel all pending IO now.
|
||||
timeoutIO()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
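deadlineHandler above signals expiry by closing a channel, so every IO pending on that deadline wakes at once; that is also why the channel must be re-made whenever a new deadline is set. A standalone sketch of just that broadcast pattern, with illustrative names that are not part of the package:

package main

import (
	"fmt"
	"sync"
	"time"
)

// deadline broadcasts expiry by closing a channel: one close wakes
// every waiter simultaneously.
type deadline struct {
	mu sync.Mutex
	ch chan struct{}
}

func newDeadline() *deadline { return &deadline{ch: make(chan struct{})} }

func (d *deadline) set(after time.Duration) {
	d.mu.Lock()
	ch := d.ch
	d.mu.Unlock()
	time.AfterFunc(after, func() { close(ch) })
}

func (d *deadline) expired() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.ch
}

func main() {
	d := newDeadline()
	d.set(50 * time.Millisecond)
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ { // three pending "IOs" all observe one close
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			<-d.expired()
			fmt.Println("waiter", i, "timed out")
		}(i)
	}
	wg.Wait()
}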
485
vendor/github.com/tailscale/wireguard-go/ipc/namedpipe/namedpipe.go
generated
vendored
Normal file
@@ -0,0 +1,485 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build windows

// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
package namedpipe

import (
	"context"
	"io"
	"net"
	"os"
	"runtime"
	"sync/atomic"
	"time"
	"unsafe"

	"golang.org/x/sys/windows"
)

type pipe struct {
	*file
	path string
}

type messageBytePipe struct {
	pipe
	writeClosed atomic.Bool
	readEOF     bool
}

type pipeAddress string

func (f *pipe) LocalAddr() net.Addr {
	return pipeAddress(f.path)
}

func (f *pipe) RemoteAddr() net.Addr {
	return pipeAddress(f.path)
}

func (f *pipe) SetDeadline(t time.Time) error {
	f.SetReadDeadline(t)
	f.SetWriteDeadline(t)
	return nil
}

// CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error {
	if !f.writeClosed.CompareAndSwap(false, true) {
		return io.ErrClosedPipe
	}
	err := f.file.Flush()
	if err != nil {
		f.writeClosed.Store(false)
		return err
	}
	_, err = f.file.Write(nil)
	if err != nil {
		f.writeClosed.Store(false)
		return err
	}
	return nil
}

// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) {
	if f.writeClosed.Load() {
		return 0, io.ErrClosedPipe
	}
	if len(b) == 0 {
		return 0, nil
	}
	return f.file.Write(b)
}

// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
// mode pipe will return io.EOF, as will all subsequent reads.
func (f *messageBytePipe) Read(b []byte) (int, error) {
	if f.readEOF {
		return 0, io.EOF
	}
	n, err := f.file.Read(b)
	if err == io.EOF {
		// If this was the result of a zero-byte read, then
		// it is possible that the read was due to a zero-size
		// message. Since we are simulating CloseWrite with a
		// zero-byte message, ensure that all future Read calls
		// also return EOF.
		f.readEOF = true
	} else if err == windows.ERROR_MORE_DATA {
		// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
		// and the message still has more bytes. Treat this as a success, since
		// this package presents all named pipes as byte streams.
		err = nil
	}
	return n, err
}

func (f *pipe) Handle() windows.Handle {
	return f.handle
}

func (s pipeAddress) Network() string {
	return "pipe"
}

func (s pipeAddress) String() string {
	return string(s)
}

// tryDialPipe attempts to dial the specified pipe until cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
	for {
		select {
		case <-ctx.Done():
			return 0, ctx.Err()
		default:
			path16, err := windows.UTF16PtrFromString(*path)
			if err != nil {
				return 0, err
			}
			h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
			if err == nil {
				return h, nil
			}
			if err != windows.ERROR_PIPE_BUSY {
				return h, &os.PathError{Err: err, Op: "open", Path: *path}
			}
			// Wait 10 msec and try again. This is a rather simplistic
			// approach, as we always retry every 10 milliseconds.
			time.Sleep(10 * time.Millisecond)
		}
	}
}

// DialConfig exposes various options for use in Dial and DialContext.
type DialConfig struct {
	ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
}

// DialTimeout connects to the specified named pipe by path, timing out if the
// connection takes longer than the specified duration. If timeout is zero, then
// we use a default timeout of 2 seconds.
func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
	if timeout == 0 {
		timeout = time.Second * 2
	}
	absTimeout := time.Now().Add(timeout)
	ctx, cancel := context.WithDeadline(context.Background(), absTimeout)
	defer cancel()
	conn, err := config.DialContext(ctx, path)
	if err == context.DeadlineExceeded {
		return nil, os.ErrDeadlineExceeded
	}
	return conn, err
}

// DialContext attempts to connect to the specified named pipe by path.
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
	var err error
	var h windows.Handle
	h, err = tryDialPipe(ctx, &path)
	if err != nil {
		return nil, err
	}

	if config.ExpectedOwner != nil {
		sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
		if err != nil {
			windows.Close(h)
			return nil, err
		}
		realOwner, _, err := sd.Owner()
		if err != nil {
			windows.Close(h)
			return nil, err
		}
		if !realOwner.Equals(config.ExpectedOwner) {
			windows.Close(h)
			return nil, windows.ERROR_ACCESS_DENIED
		}
	}

	var flags uint32
	err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil)
	if err != nil {
		windows.Close(h)
		return nil, err
	}

	f, err := makeFile(h)
	if err != nil {
		windows.Close(h)
		return nil, err
	}

	// If the pipe is in message mode, return a message byte pipe, which
	// supports CloseWrite.
	if flags&windows.PIPE_TYPE_MESSAGE != 0 {
		return &messageBytePipe{
			pipe: pipe{file: f, path: path},
		}, nil
	}
	return &pipe{file: f, path: path}, nil
}

var defaultDialer DialConfig

// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
	return defaultDialer.DialTimeout(path, timeout)
}

// DialContext calls DialConfig.DialContext using an empty configuration.
func DialContext(ctx context.Context, path string) (net.Conn, error) {
	return defaultDialer.DialContext(ctx, path)
}

type acceptResponse struct {
	f   *file
	err error
}

type pipeListener struct {
	firstHandle windows.Handle
	path        string
	config      ListenConfig
	acceptCh    chan chan acceptResponse
	closeCh     chan int
	doneCh      chan int
}

func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
	path16, err := windows.UTF16PtrFromString(path)
	if err != nil {
		return 0, &os.PathError{Op: "open", Path: path, Err: err}
	}

	var oa windows.OBJECT_ATTRIBUTES
	oa.Length = uint32(unsafe.Sizeof(oa))

	var ntPath windows.NTUnicodeString
	if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil {
		if ntstatus, ok := err.(windows.NTStatus); ok {
			err = ntstatus.Errno()
		}
		return 0, &os.PathError{Op: "open", Path: path, Err: err}
	}
	defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer)))
	oa.ObjectName = &ntPath

	// The security descriptor is only needed for the first pipe.
	if isFirstPipe {
		if sd != nil {
			oa.SecurityDescriptor = sd
		} else {
			// Construct the default named pipe security descriptor.
			var acl *windows.ACL
			if err := windows.RtlDefaultNpAcl(&acl); err != nil {
				return 0, err
			}
			defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
			sd, err = windows.NewSecurityDescriptor()
			if err != nil {
				return 0, err
			}
			if err = sd.SetDACL(acl, true, false); err != nil {
				return 0, err
			}
			oa.SecurityDescriptor = sd
		}
	}

	typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
	if c.MessageMode {
		typ |= windows.FILE_PIPE_MESSAGE_TYPE
	}

	disposition := uint32(windows.FILE_OPEN)
	access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
	if isFirstPipe {
		disposition = windows.FILE_CREATE
		// By not asking for read or write access, the named pipe file system
		// will put this pipe into an initially disconnected state, blocking
		// client connections until the next call with isFirstPipe == false.
		access = windows.SYNCHRONIZE
	}

	timeout := int64(-50 * 10000) // 50ms

	var (
		h    windows.Handle
		iosb windows.IO_STATUS_BLOCK
	)
	err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout)
	if err != nil {
		if ntstatus, ok := err.(windows.NTStatus); ok {
			err = ntstatus.Errno()
		}
		return 0, &os.PathError{Op: "open", Path: path, Err: err}
	}

	runtime.KeepAlive(ntPath)
	return h, nil
}

func (l *pipeListener) makeServerPipe() (*file, error) {
	h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
	if err != nil {
		return nil, err
	}
	f, err := makeFile(h)
	if err != nil {
		windows.Close(h)
		return nil, err
	}
	return f, nil
}

func (l *pipeListener) makeConnectedServerPipe() (*file, error) {
	p, err := l.makeServerPipe()
	if err != nil {
		return nil, err
	}

	// Wait for the client to connect.
	ch := make(chan error)
	go func(p *file) {
		ch <- connectPipe(p)
	}(p)

	select {
	case err = <-ch:
		if err != nil {
			p.Close()
			p = nil
		}
	case <-l.closeCh:
		// Abort the connect request by closing the handle.
		p.Close()
		p = nil
		err = <-ch
		if err == nil || err == os.ErrClosed {
			err = net.ErrClosed
		}
	}
	return p, err
}

func (l *pipeListener) listenerRoutine() {
	closed := false
	for !closed {
		select {
		case <-l.closeCh:
			closed = true
		case responseCh := <-l.acceptCh:
			var (
				p   *file
				err error
			)
			for {
				p, err = l.makeConnectedServerPipe()
				// If the connection was immediately closed by the client, try
				// again.
				if err != windows.ERROR_NO_DATA {
					break
				}
			}
			responseCh <- acceptResponse{p, err}
			closed = err == net.ErrClosed
		}
	}
	windows.Close(l.firstHandle)
	l.firstHandle = 0
	// Notify Close and Accept callers that the handle has been closed.
	close(l.doneCh)
}

// ListenConfig contains configuration for the pipe listener.
type ListenConfig struct {
	// SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used.
	SecurityDescriptor *windows.SECURITY_DESCRIPTOR

	// MessageMode determines whether the pipe is in byte or message mode. In either
	// case the pipe is read in byte mode by default. The only practical difference in
	// this implementation is that CloseWrite is only supported for message mode pipes;
	// CloseWrite is implemented as a zero-byte write, but zero-byte writes are only
	// transferred to the reader (and returned as io.EOF in this implementation)
	// when the pipe is in message mode.
	MessageMode bool

	// InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed.
	InputBufferSize int32

	// OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed.
	OutputBufferSize int32
}

// Listen creates a listener on a Windows named pipe path, such as \\.\pipe\mypipe.
// The pipe must not already exist.
func (c *ListenConfig) Listen(path string) (net.Listener, error) {
	h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
	if err != nil {
		return nil, err
	}
	l := &pipeListener{
		firstHandle: h,
		path:        path,
		config:      *c,
		acceptCh:    make(chan chan acceptResponse),
		closeCh:     make(chan int),
		doneCh:      make(chan int),
	}
	// The first connection is swallowed on Windows 7 & 8, so synthesize it.
	if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
		path16, err := windows.UTF16PtrFromString(path)
		if err == nil {
			h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
			if err == nil {
				windows.CloseHandle(h)
			}
		}
	}
	go l.listenerRoutine()
	return l, nil
}

var defaultListener ListenConfig

// Listen calls ListenConfig.Listen using an empty configuration.
func Listen(path string) (net.Listener, error) {
	return defaultListener.Listen(path)
}

func connectPipe(p *file) error {
	c, err := p.prepareIo()
	if err != nil {
		return err
	}
	defer p.wg.Done()

	err = windows.ConnectNamedPipe(p.handle, &c.o)
	_, err = p.asyncIo(c, nil, 0, err)
	if err != nil && err != windows.ERROR_PIPE_CONNECTED {
		return err
	}
	return nil
}

func (l *pipeListener) Accept() (net.Conn, error) {
	ch := make(chan acceptResponse)
	select {
	case l.acceptCh <- ch:
		response := <-ch
		err := response.err
		if err != nil {
			return nil, err
		}
		if l.config.MessageMode {
			return &messageBytePipe{
				pipe: pipe{file: response.f, path: l.path},
			}, nil
		}
		return &pipe{file: response.f, path: l.path}, nil
	case <-l.doneCh:
		return nil, net.ErrClosed
	}
}

func (l *pipeListener) Close() error {
	select {
	case l.closeCh <- 1:
		<-l.doneCh
	case <-l.doneCh:
	}
	return nil
}

func (l *pipeListener) Addr() net.Addr {
	return pipeAddress(l.path)
}
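A minimal end-to-end sketch of the exported API above, Windows-only; the pipe name is illustrative:

//go:build windows

package main

import (
	"fmt"
	"time"

	"github.com/tailscale/wireguard-go/ipc/namedpipe"
)

func main() {
	// Pipe name is an example; callers pick their own path.
	const path = `\\.\pipe\namedpipe-demo`

	l, err := namedpipe.Listen(path)
	if err != nil {
		panic(err)
	}
	defer l.Close()

	go func() {
		conn, err := l.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		conn.Write([]byte("hello\n"))
	}()

	conn, err := namedpipe.DialTimeout(path, 2*time.Second)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	buf := make([]byte, 64)
	n, _ := conn.Read(buf)
	fmt.Printf("%s", buf[:n])
}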
132
vendor/github.com/tailscale/wireguard-go/ipc/uapi_bsd.go
generated
vendored
Normal file
@@ -0,0 +1,132 @@
//go:build darwin || freebsd || openbsd

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package ipc

import (
	"errors"
	"net"
	"os"
	"unsafe"

	"golang.org/x/sys/unix"
)

type UAPIListener struct {
	listener net.Listener // unix socket listener
	connNew  chan net.Conn
	connErr  chan error
	kqueueFd int
	keventFd int
}

func (l *UAPIListener) Accept() (net.Conn, error) {
	for {
		select {
		case conn := <-l.connNew:
			return conn, nil

		case err := <-l.connErr:
			return nil, err
		}
	}
}

func (l *UAPIListener) Close() error {
	err1 := unix.Close(l.kqueueFd)
	err2 := unix.Close(l.keventFd)
	err3 := l.listener.Close()
	if err1 != nil {
		return err1
	}
	if err2 != nil {
		return err2
	}
	return err3
}

func (l *UAPIListener) Addr() net.Addr {
	return l.listener.Addr()
}

func UAPIListen(name string, file *os.File) (net.Listener, error) {
	// wrap file in listener

	listener, err := net.FileListener(file)
	if err != nil {
		return nil, err
	}

	uapi := &UAPIListener{
		listener: listener,
		connNew:  make(chan net.Conn, 1),
		connErr:  make(chan error, 1),
	}

	if unixListener, ok := listener.(*net.UnixListener); ok {
		unixListener.SetUnlinkOnClose(true)
	}

	socketPath := sockPath(name)

	// watch for deletion of socket

	uapi.kqueueFd, err = unix.Kqueue()
	if err != nil {
		return nil, err
	}
	uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0)
	if err != nil {
		unix.Close(uapi.kqueueFd)
		return nil, err
	}

	go func(l *UAPIListener) {
		event := unix.Kevent_t{
			Filter: unix.EVFILT_VNODE,
			Flags:  unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT,
			Fflags: unix.NOTE_WRITE,
		}
		// Allow this assignment to work with both the 32-bit and 64-bit version
		// of the above struct. If you know another way, please submit a patch.
		*(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd)
		events := make([]unix.Kevent_t, 1)
		n := 1
		var kerr error
		for {
			// start with lstat to avoid race condition
			if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
				l.connErr <- err
				return
			}
			if (kerr != nil || n != 1) && kerr != unix.EINTR {
				if kerr != nil {
					l.connErr <- kerr
				} else {
					l.connErr <- errors.New("kqueue returned empty")
				}
				return
			}
			n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil)
		}
	}(uapi)

	// watch for new connections

	go func(l *UAPIListener) {
		for {
			conn, err := l.listener.Accept()
			if err != nil {
				l.connErr <- err
				break
			}
			l.connNew <- conn
		}
	}(uapi)

	return uapi, nil
}
17
vendor/github.com/tailscale/wireguard-go/ipc/uapi_fake.go
generated
vendored
Normal file
@@ -0,0 +1,17 @@
//go:build wasm || plan9 || aix

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package ipc

// Made up sentinel error codes for {js,wasip1}/wasm, and plan9.
const (
	IpcErrorIO        = 1
	IpcErrorInvalid   = 2
	IpcErrorPortInUse = 3
	IpcErrorUnknown   = 4
	IpcErrorProtocol  = 5
)
129
vendor/github.com/tailscale/wireguard-go/ipc/uapi_linux.go
generated
vendored
Normal file
@@ -0,0 +1,129 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package ipc

import (
	"net"
	"os"

	"github.com/tailscale/wireguard-go/rwcancel"
	"golang.org/x/sys/unix"
)

type UAPIListener struct {
	listener        net.Listener // unix socket listener
	connNew         chan net.Conn
	connErr         chan error
	inotifyFd       int
	inotifyRWCancel *rwcancel.RWCancel
}

func (l *UAPIListener) Accept() (net.Conn, error) {
	for {
		select {
		case conn := <-l.connNew:
			return conn, nil

		case err := <-l.connErr:
			return nil, err
		}
	}
}

func (l *UAPIListener) Close() error {
	err1 := unix.Close(l.inotifyFd)
	err2 := l.inotifyRWCancel.Cancel()
	err3 := l.listener.Close()
	if err1 != nil {
		return err1
	}
	if err2 != nil {
		return err2
	}
	return err3
}

func (l *UAPIListener) Addr() net.Addr {
	return l.listener.Addr()
}

func UAPIListen(name string, file *os.File) (net.Listener, error) {
	// wrap file in listener

	listener, err := net.FileListener(file)
	if err != nil {
		return nil, err
	}

	if unixListener, ok := listener.(*net.UnixListener); ok {
		unixListener.SetUnlinkOnClose(true)
	}

	uapi := &UAPIListener{
		listener: listener,
		connNew:  make(chan net.Conn, 1),
		connErr:  make(chan error, 1),
	}

	// watch for deletion of socket

	socketPath := sockPath(name)

	uapi.inotifyFd, err = unix.InotifyInit()
	if err != nil {
		return nil, err
	}

	_, err = unix.InotifyAddWatch(
		uapi.inotifyFd,
		socketPath,
		unix.IN_ATTRIB|
			unix.IN_DELETE|
			unix.IN_DELETE_SELF,
	)
	if err != nil {
		return nil, err
	}

	uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd)
	if err != nil {
		unix.Close(uapi.inotifyFd)
		return nil, err
	}

	go func(l *UAPIListener) {
		var buf [0]byte
		// Close the cancellation pipe when this goroutine exits.
		defer uapi.inotifyRWCancel.Close()
		for {
			// start with lstat to avoid race condition
			if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
				l.connErr <- err
				return
			}
			_, err := uapi.inotifyRWCancel.Read(buf[:])
			if err != nil {
				l.connErr <- err
				return
			}
		}
	}(uapi)

	// watch for new connections

	go func(l *UAPIListener) {
		for {
			conn, err := l.listener.Accept()
			if err != nil {
				l.connErr <- err
				break
			}
			l.connNew <- conn
		}
	}(uapi)

	return uapi, nil
}
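The listener above notices socket deletion via inotify rather than polling. A standalone sketch of that watch pattern, with an illustrative path (the real listener watches its own socket path and funnels errors into connErr):

//go:build linux

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	fd, err := unix.InotifyInit()
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)

	// Path is illustrative only.
	const path = "/tmp/watched.sock"
	if _, err := unix.InotifyAddWatch(fd, path, unix.IN_ATTRIB|unix.IN_DELETE|unix.IN_DELETE_SELF); err != nil {
		panic(err)
	}

	// A blocking read returns once an event (e.g. deletion) fires.
	buf := make([]byte, unix.SizeofInotifyEvent+unix.NAME_MAX+1)
	if _, err := unix.Read(fd, buf); err != nil {
		panic(err)
	}
	fmt.Println("socket path changed or was deleted")
}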
17
vendor/github.com/tailscale/wireguard-go/ipc/uapi_tamago.go
generated
vendored
Normal file
@@ -0,0 +1,17 @@
//go:build tamago

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package ipc

// Made up sentinel error codes for the tamago platform.
const (
	IpcErrorIO        = 1
	IpcErrorInvalid   = 2
	IpcErrorPortInUse = 3
	IpcErrorUnknown   = 4
	IpcErrorProtocol  = 5
)
66
vendor/github.com/tailscale/wireguard-go/ipc/uapi_unix.go
generated
vendored
Normal file
@@ -0,0 +1,66 @@
//go:build linux || darwin || freebsd || openbsd

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package ipc

import (
	"errors"
	"fmt"
	"net"
	"os"

	"golang.org/x/sys/unix"
)

const (
	IpcErrorIO        = -int64(unix.EIO)
	IpcErrorProtocol  = -int64(unix.EPROTO)
	IpcErrorInvalid   = -int64(unix.EINVAL)
	IpcErrorPortInUse = -int64(unix.EADDRINUSE)
	IpcErrorUnknown   = -55 // ENOANO
)

// socketDirectory is variable because it is modified by a linker
// flag in wireguard-android.
var socketDirectory = "/var/run/wireguard"

func sockPath(iface string) string {
	return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
}

func UAPIOpen(name string) (*os.File, error) {
	if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
		return nil, err
	}

	socketPath := sockPath(name)
	addr, err := net.ResolveUnixAddr("unix", socketPath)
	if err != nil {
		return nil, err
	}

	oldUmask := unix.Umask(0o077)
	defer unix.Umask(oldUmask)

	listener, err := net.ListenUnix("unix", addr)
	if err == nil {
		return listener.File()
	}

	// Test the socket; if it is not in use, clean up and try again.
	if _, err := net.Dial("unix", socketPath); err == nil {
		return nil, errors.New("unix socket in use")
	}
	if err := os.Remove(socketPath); err != nil {
		return nil, err
	}
	listener, err = net.ListenUnix("unix", addr)
	if err != nil {
		return nil, err
	}
	return listener.File()
}
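UAPIOpen and the per-platform UAPIListen variants are designed to be chained into the device's IpcHandle loop shown earlier. A sketch of that wiring on Unix, assuming the import paths match this vendored tree and with device construction elided (on Windows, UAPIListen takes only the interface name):

//go:build linux || darwin || freebsd || openbsd

package main

import (
	"github.com/tailscale/wireguard-go/device"
	"github.com/tailscale/wireguard-go/ipc"
)

// serveUAPI opens the socket file, wraps it in a deletion-watching
// listener, and hands each connection to the device's UAPI handler.
func serveUAPI(dev *device.Device, iface string) error {
	f, err := ipc.UAPIOpen(iface)
	if err != nil {
		return err
	}
	l, err := ipc.UAPIListen(iface, f)
	if err != nil {
		return err
	}
	for {
		conn, err := l.Accept()
		if err != nil {
			return err
		}
		go dev.IpcHandle(conn)
	}
}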
88
vendor/github.com/tailscale/wireguard-go/ipc/uapi_windows.go
generated
vendored
Normal file
@@ -0,0 +1,88 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package ipc

import (
	"net"

	"github.com/tailscale/wireguard-go/ipc/namedpipe"
	"golang.org/x/sys/windows"
)

// TODO: replace these with actual standard windows error numbers from the win package
const (
	IpcErrorIO        = -int64(5)
	IpcErrorProtocol  = -int64(71)
	IpcErrorInvalid   = -int64(22)
	IpcErrorPortInUse = -int64(98)
	IpcErrorUnknown   = -int64(55)
)

type UAPIListener struct {
	listener net.Listener // named pipe listener
	connNew  chan net.Conn
	connErr  chan error
}

func (l *UAPIListener) Accept() (net.Conn, error) {
	for {
		select {
		case conn := <-l.connNew:
			return conn, nil

		case err := <-l.connErr:
			return nil, err
		}
	}
}

func (l *UAPIListener) Close() error {
	return l.listener.Close()
}

func (l *UAPIListener) Addr() net.Addr {
	return l.listener.Addr()
}

var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR

func init() {
	var err error
	UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
	if err != nil {
		panic(err)
	}
}

func UAPIListen(name string) (net.Listener, error) {
	listener, err := (&namedpipe.ListenConfig{
		SecurityDescriptor: UAPISecurityDescriptor,
	}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
	if err != nil {
		return nil, err
	}

	uapi := &UAPIListener{
		listener: listener,
		connNew:  make(chan net.Conn, 1),
		connErr:  make(chan error, 1),
	}

	go func(l *UAPIListener) {
		for {
			conn, err := l.listener.Accept()
			if err != nil {
				l.connErr <- err
				break
			}
			l.connNew <- conn
		}
	}(uapi)

	return uapi, nil
}
137
vendor/github.com/tailscale/wireguard-go/ratelimiter/ratelimiter.go
generated
vendored
Normal file
@@ -0,0 +1,137 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package ratelimiter

import (
	"net/netip"
	"sync"
	"time"
)

const (
	packetsPerSecond   = 20
	packetsBurstable   = 5
	garbageCollectTime = time.Second
	packetCost         = 1000000000 / packetsPerSecond
	maxTokens          = packetCost * packetsBurstable
)

type RatelimiterEntry struct {
	mu       sync.Mutex
	lastTime time.Time
	tokens   int64
}

type Ratelimiter struct {
	mu      sync.RWMutex
	timeNow func() time.Time

	stopReset chan struct{} // send to reset, close to stop
	table     map[netip.Addr]*RatelimiterEntry
}

func (rate *Ratelimiter) Close() {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	if rate.stopReset != nil {
		close(rate.stopReset)
	}
}

func (rate *Ratelimiter) Init() {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	if rate.timeNow == nil {
		rate.timeNow = time.Now
	}

	// stop any ongoing garbage collection routine
	if rate.stopReset != nil {
		close(rate.stopReset)
	}

	rate.stopReset = make(chan struct{})
	rate.table = make(map[netip.Addr]*RatelimiterEntry)

	stopReset := rate.stopReset // store in case Init is called again.

	// Start garbage collection routine.
	go func() {
		ticker := time.NewTicker(time.Second)
		ticker.Stop()
		for {
			select {
			case _, ok := <-stopReset:
				ticker.Stop()
				if !ok {
					return
				}
				ticker = time.NewTicker(time.Second)
			case <-ticker.C:
				if rate.cleanup() {
					ticker.Stop()
				}
			}
		}
	}()
}

func (rate *Ratelimiter) cleanup() (empty bool) {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	for key, entry := range rate.table {
		entry.mu.Lock()
		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
			delete(rate.table, key)
		}
		entry.mu.Unlock()
	}

	return len(rate.table) == 0
}

func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
	var entry *RatelimiterEntry
	// lookup entry
	rate.mu.RLock()
	entry = rate.table[ip]
	rate.mu.RUnlock()

	// make new entry if not found
	if entry == nil {
		entry = new(RatelimiterEntry)
		entry.tokens = maxTokens - packetCost
		entry.lastTime = rate.timeNow()
		rate.mu.Lock()
		rate.table[ip] = entry
		if len(rate.table) == 1 {
			rate.stopReset <- struct{}{}
		}
		rate.mu.Unlock()
		return true
	}

	// add tokens to entry
	entry.mu.Lock()
	now := rate.timeNow()
	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
	entry.lastTime = now
	if entry.tokens > maxTokens {
		entry.tokens = maxTokens
	}

	// subtract cost of packet
	if entry.tokens > packetCost {
		entry.tokens -= packetCost
		entry.mu.Unlock()
		return true
	}
	entry.mu.Unlock()
	return false
}
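This is a per-address token bucket denominated in nanoseconds: tokens accrue at wall-clock rate, each packet costs 1e9/packetsPerSecond (50ms of accrued time), and the bucket caps at packetsBurstable packets. A fresh entry therefore admits a burst of exactly 5 back-to-back packets. A quick demo, assuming this vendored import path:

package main

import (
	"fmt"
	"net/netip"

	"github.com/tailscale/wireguard-go/ratelimiter"
)

func main() {
	var rl ratelimiter.Ratelimiter
	rl.Init()
	defer rl.Close()

	ip := netip.MustParseAddr("192.0.2.1")

	// A fresh entry starts at maxTokens - packetCost, so a burst of
	// packetsBurstable (5) passes before back-to-back sends are refused.
	allowed := 0
	for i := 0; i < 10; i++ {
		if rl.Allow(ip) {
			allowed++
		}
	}
	fmt.Println("allowed in burst:", allowed) // expected: 5
}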
62
vendor/github.com/tailscale/wireguard-go/replay/replay.go
generated
vendored
Normal file
@@ -0,0 +1,62 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package replay

type block uint64

const (
	blockBitLog = 6                // 1<<6 == 64 bits
	blockBits   = 1 << blockBitLog // must be power of 2
	ringBlocks  = 1 << 7           // must be power of 2
	windowSize  = (ringBlocks - 1) * blockBits
	blockMask   = ringBlocks - 1
	bitMask     = blockBits - 1
)

// A Filter rejects replayed messages by checking if a message counter value is
// within a sliding window of previously received messages.
// The zero value for Filter is an empty filter ready to use.
// Filters are unsafe for concurrent use.
type Filter struct {
	last uint64
	ring [ringBlocks]block
}

// Reset resets the filter to empty state.
func (f *Filter) Reset() {
	f.last = 0
	f.ring[0] = 0
}

// ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected.
func (f *Filter) ValidateCounter(counter, limit uint64) bool {
	if counter >= limit {
		return false
	}
	indexBlock := counter >> blockBitLog
	if counter > f.last { // move window forward
		current := f.last >> blockBitLog
		diff := indexBlock - current
		if diff > ringBlocks {
			diff = ringBlocks // cap diff to clear the whole ring
		}
		for i := current + 1; i <= current+diff; i++ {
			f.ring[i&blockMask] = 0
		}
		f.last = counter
	} else if f.last-counter > windowSize { // behind current window
		return false
	}
	// check and set bit
	indexBlock &= blockMask
	indexBit := counter & bitMask
	old := f.ring[indexBlock]
	new := old | 1<<indexBit
	f.ring[indexBlock] = new
	return old != new
}
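The sliding window spans (ringBlocks-1)*blockBits = 8128 counters: anything newer advances the window, anything older than the window is rejected outright, and anything inside it is accepted at most once via the bitmap. A demo of those three cases, assuming this vendored import path:

package main

import (
	"fmt"

	"github.com/tailscale/wireguard-go/replay"
)

func main() {
	var f replay.Filter // zero value is ready to use
	const limit = 1 << 60

	fmt.Println(f.ValidateCounter(1, limit))    // true: new counter
	fmt.Println(f.ValidateCounter(1, limit))    // false: replay of a seen counter
	fmt.Println(f.ValidateCounter(9000, limit)) // true: window slides forward
	fmt.Println(f.ValidateCounter(1, limit))    // false: now behind the window
}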
117
vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel.go
generated
vendored
Normal file
@@ -0,0 +1,117 @@
//go:build !windows && !wasm && !plan9 && !tamago

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

// Package rwcancel implements cancelable read/write operations on
// a file descriptor.
package rwcancel

import (
	"errors"
	"os"
	"syscall"

	"golang.org/x/sys/unix"
)

type RWCancel struct {
	fd            int
	closingReader *os.File
	closingWriter *os.File
}

func NewRWCancel(fd int) (*RWCancel, error) {
	err := unix.SetNonblock(fd, true)
	if err != nil {
		return nil, err
	}
	rwcancel := RWCancel{fd: fd}

	rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe()
	if err != nil {
		return nil, err
	}

	return &rwcancel, nil
}

func RetryAfterError(err error) bool {
	return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR)
}

func (rw *RWCancel) ReadyRead() bool {
	closeFd := int32(rw.closingReader.Fd())

	pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}}
	var err error
	for {
		_, err = unix.Poll(pollFds, -1)
		if err == nil || !RetryAfterError(err) {
			break
		}
	}
	if err != nil {
		return false
	}
	if pollFds[1].Revents != 0 {
		return false
	}
	return pollFds[0].Revents != 0
}

func (rw *RWCancel) ReadyWrite() bool {
	closeFd := int32(rw.closingReader.Fd())
	pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
	var err error
	for {
		_, err = unix.Poll(pollFds, -1)
		if err == nil || !RetryAfterError(err) {
			break
		}
	}
	if err != nil {
		return false
	}

	if pollFds[1].Revents != 0 {
		return false
	}
	return pollFds[0].Revents != 0
}

func (rw *RWCancel) Read(p []byte) (n int, err error) {
	for {
		n, err := unix.Read(rw.fd, p)
		if err == nil || !RetryAfterError(err) {
			return n, err
		}
		if !rw.ReadyRead() {
			return 0, os.ErrClosed
		}
	}
}

func (rw *RWCancel) Write(p []byte) (n int, err error) {
	for {
		n, err := unix.Write(rw.fd, p)
		if err == nil || !RetryAfterError(err) {
			return n, err
		}
		if !rw.ReadyWrite() {
			return 0, os.ErrClosed
		}
	}
}

func (rw *RWCancel) Cancel() (err error) {
	_, err = rw.closingWriter.Write([]byte{0})
	return
}

func (rw *RWCancel) Close() {
	rw.closingReader.Close()
	rw.closingWriter.Close()
}
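The trick above is to make the target fd non-blocking and poll it alongside a self-pipe; writing one byte to the pipe makes any pending Read/Write bail out with os.ErrClosed. A sketch of cancelling a read that would otherwise block forever, assuming this vendored import path (build tag simplified from the package's own):

//go:build !windows

package main

import (
	"fmt"
	"os"
	"time"

	"github.com/tailscale/wireguard-go/rwcancel"
)

func main() {
	// A pipe whose read side never gets data: the Read below would
	// block forever without cancellation.
	r, w, err := os.Pipe()
	if err != nil {
		panic(err)
	}
	defer r.Close()
	defer w.Close()

	rw, err := rwcancel.NewRWCancel(int(r.Fd()))
	if err != nil {
		panic(err)
	}

	go func() {
		time.Sleep(100 * time.Millisecond)
		rw.Cancel() // unblocks the pending Read with os.ErrClosed
	}()

	buf := make([]byte, 1)
	_, err = rw.Read(buf)
	fmt.Println(err) // file already closed
}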
9
vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel_stub.go
generated
vendored
Normal file
@@ -0,0 +1,9 @@
//go:build windows || wasm || plan9 || tamago

// SPDX-License-Identifier: MIT

package rwcancel

type RWCancel struct{}

func (*RWCancel) Cancel() {}
41
vendor/github.com/tailscale/wireguard-go/tai64n/tai64n.go
generated
vendored
Normal file
@@ -0,0 +1,41 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tai64n

import (
	"bytes"
	"encoding/binary"
	"time"
)

const (
	TimestampSize = 12
	base          = uint64(0x400000000000000a)
	whitenerMask  = uint32(0x1000000 - 1)
)

type Timestamp [TimestampSize]byte

func stamp(t time.Time) Timestamp {
	var tai64n Timestamp
	secs := base + uint64(t.Unix())
	nano := uint32(t.Nanosecond()) &^ whitenerMask
	binary.BigEndian.PutUint64(tai64n[:], secs)
	binary.BigEndian.PutUint32(tai64n[8:], nano)
	return tai64n
}

func Now() Timestamp {
	return stamp(time.Now())
}

func (t1 Timestamp) After(t2 Timestamp) bool {
	return bytes.Compare(t1[:], t2[:]) > 0
}

func (t Timestamp) String() string {
	return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String()
}
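A TAI64N timestamp here is 8 big-endian bytes of seconds offset by base, then 4 big-endian bytes of nanoseconds with the low 24 bits cleared, so timestamps only leak ~16.8ms of timing granularity while still comparing correctly as raw bytes. A small demo, assuming this vendored import path:

package main

import (
	"fmt"
	"time"

	"github.com/tailscale/wireguard-go/tai64n"
)

func main() {
	t1 := tai64n.Now()
	time.Sleep(20 * time.Millisecond) // exceeds the 2^24ns (~16.8ms) whitening granularity
	t2 := tai64n.Now()

	// Big-endian layout makes lexicographic byte comparison agree
	// with chronological order.
	fmt.Println(t2.After(t1)) // true
	fmt.Println(t1.String())
}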
712
vendor/github.com/tailscale/wireguard-go/tun/checksum.go
generated
vendored
Normal file
@@ -0,0 +1,712 @@
package tun
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
// checksumGeneric64 is a reference implementation of checksum using 64 bit
|
||||
// arithmetic for use in testing or when an architecture-specific implementation
|
||||
// is not available.
|
||||
func checksumGeneric64(b []byte, initial uint16) uint16 {
|
||||
var ac uint64
|
||||
var carry uint64
|
||||
|
||||
if cpu.IsBigEndian {
|
||||
ac = uint64(initial)
|
||||
} else {
|
||||
ac = uint64(bits.ReverseBytes16(initial))
|
||||
}
|
||||
|
||||
for len(b) >= 128 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry)
|
||||
}
|
||||
b = b[128:]
|
||||
}
|
||||
if len(b) >= 64 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
|
||||
}
|
||||
b = b[64:]
|
||||
}
|
||||
if len(b) >= 32 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||
}
|
||||
b = b[32:]
|
||||
}
|
||||
if len(b) >= 16 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||
}
|
||||
b = b[16:]
|
||||
}
|
||||
if len(b) >= 8 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry)
|
||||
}
|
||||
b = b[8:]
|
||||
}
|
||||
if len(b) >= 4 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry)
|
||||
}
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry)
|
||||
}
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) >= 1 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, uint64(b[0]), carry)
|
||||
}
|
||||
}
|
||||
|
||||
folded := ipChecksumFold64(ac, carry)
|
||||
if !cpu.IsBigEndian {
|
||||
folded = bits.ReverseBytes16(folded)
|
||||
}
|
||||
return folded
|
||||
}
|
||||
|
||||
// checksumGeneric32 is a reference implementation of checksum using 32 bit
|
||||
// arithmetic for use in testing or when an architecture-specific implementation
|
||||
// is not available.
|
||||
func checksumGeneric32(b []byte, initial uint16) uint16 {
|
||||
var ac uint32
|
||||
var carry uint32
|
||||
|
||||
if cpu.IsBigEndian {
|
||||
ac = uint32(initial)
|
||||
} else {
|
||||
ac = uint32(bits.ReverseBytes16(initial))
|
||||
}
|
||||
|
||||
for len(b) >= 64 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
|
||||
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry)
	} else {
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry)
	}
	b = b[64:]
}
if len(b) >= 32 {
	if cpu.IsBigEndian {
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
	} else {
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
	}
	b = b[32:]
}
if len(b) >= 16 {
	if cpu.IsBigEndian {
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
	} else {
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
	}
	b = b[16:]
}
if len(b) >= 8 {
	if cpu.IsBigEndian {
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
	} else {
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
	}
	b = b[8:]
}
if len(b) >= 4 {
	if cpu.IsBigEndian {
		ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry)
	} else {
		ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry)
	}
	b = b[4:]
}
if len(b) >= 2 {
	if cpu.IsBigEndian {
		ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry)
	} else {
		ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry)
	}
	b = b[2:]
}
if len(b) >= 1 {
	if cpu.IsBigEndian {
		ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry)
	} else {
		ac, carry = bits.Add32(ac, uint32(b[0]), carry)
	}
}

folded := ipChecksumFold32(ac, carry)
if !cpu.IsBigEndian {
	folded = bits.ReverseBytes16(folded)
}
return folded
}

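// Worked example (a minimal sketch, not from the upstream source): for the
// bytes {0x00, 0x01, 0xf2, 0x03} the 16-bit big-endian words are 0x0001 and
// 0xf203, so the ones'-complement running sum is 0x0001 + 0xf203 = 0xf204 and
// the value a header field would carry is its complement, ^0xf204 = 0x0dfb.
func exampleChecksumGeneric32() uint16 {
	b := []byte{0x00, 0x01, 0xf2, 0x03}
	sum := checksumGeneric32(b, 0) // 0xf204, the unnegated running sum
	return ^sum                    // 0x0dfb, as it would appear on the wire
}
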
// checksumGeneric32Alternate is an alternate reference implementation of
// checksum using 32 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric32Alternate(b []byte, initial uint16) uint16 {
	var ac uint32

	if cpu.IsBigEndian {
		ac = uint32(initial)
	} else {
		ac = uint32(bits.ReverseBytes16(initial))
	}

	for len(b) >= 64 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
			ac += uint32(binary.BigEndian.Uint16(b[4:6]))
			ac += uint32(binary.BigEndian.Uint16(b[6:8]))
			ac += uint32(binary.BigEndian.Uint16(b[8:10]))
			ac += uint32(binary.BigEndian.Uint16(b[10:12]))
			ac += uint32(binary.BigEndian.Uint16(b[12:14]))
			ac += uint32(binary.BigEndian.Uint16(b[14:16]))
			ac += uint32(binary.BigEndian.Uint16(b[16:18]))
			ac += uint32(binary.BigEndian.Uint16(b[18:20]))
			ac += uint32(binary.BigEndian.Uint16(b[20:22]))
			ac += uint32(binary.BigEndian.Uint16(b[22:24]))
			ac += uint32(binary.BigEndian.Uint16(b[24:26]))
			ac += uint32(binary.BigEndian.Uint16(b[26:28]))
			ac += uint32(binary.BigEndian.Uint16(b[28:30]))
			ac += uint32(binary.BigEndian.Uint16(b[30:32]))
			ac += uint32(binary.BigEndian.Uint16(b[32:34]))
			ac += uint32(binary.BigEndian.Uint16(b[34:36]))
			ac += uint32(binary.BigEndian.Uint16(b[36:38]))
			ac += uint32(binary.BigEndian.Uint16(b[38:40]))
			ac += uint32(binary.BigEndian.Uint16(b[40:42]))
			ac += uint32(binary.BigEndian.Uint16(b[42:44]))
			ac += uint32(binary.BigEndian.Uint16(b[44:46]))
			ac += uint32(binary.BigEndian.Uint16(b[46:48]))
			ac += uint32(binary.BigEndian.Uint16(b[48:50]))
			ac += uint32(binary.BigEndian.Uint16(b[50:52]))
			ac += uint32(binary.BigEndian.Uint16(b[52:54]))
			ac += uint32(binary.BigEndian.Uint16(b[54:56]))
			ac += uint32(binary.BigEndian.Uint16(b[56:58]))
			ac += uint32(binary.BigEndian.Uint16(b[58:60]))
			ac += uint32(binary.BigEndian.Uint16(b[60:62]))
			ac += uint32(binary.BigEndian.Uint16(b[62:64]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
			ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
			ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
			ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
			ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
			ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
			ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
			ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
			ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
			ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
			ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
			ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
			ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
			ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
			ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
			ac += uint32(binary.LittleEndian.Uint16(b[32:34]))
			ac += uint32(binary.LittleEndian.Uint16(b[34:36]))
			ac += uint32(binary.LittleEndian.Uint16(b[36:38]))
			ac += uint32(binary.LittleEndian.Uint16(b[38:40]))
			ac += uint32(binary.LittleEndian.Uint16(b[40:42]))
			ac += uint32(binary.LittleEndian.Uint16(b[42:44]))
			ac += uint32(binary.LittleEndian.Uint16(b[44:46]))
			ac += uint32(binary.LittleEndian.Uint16(b[46:48]))
			ac += uint32(binary.LittleEndian.Uint16(b[48:50]))
			ac += uint32(binary.LittleEndian.Uint16(b[50:52]))
			ac += uint32(binary.LittleEndian.Uint16(b[52:54]))
			ac += uint32(binary.LittleEndian.Uint16(b[54:56]))
			ac += uint32(binary.LittleEndian.Uint16(b[56:58]))
			ac += uint32(binary.LittleEndian.Uint16(b[58:60]))
			ac += uint32(binary.LittleEndian.Uint16(b[60:62]))
			ac += uint32(binary.LittleEndian.Uint16(b[62:64]))
		}
		b = b[64:]
	}
	if len(b) >= 32 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
			ac += uint32(binary.BigEndian.Uint16(b[4:6]))
			ac += uint32(binary.BigEndian.Uint16(b[6:8]))
			ac += uint32(binary.BigEndian.Uint16(b[8:10]))
			ac += uint32(binary.BigEndian.Uint16(b[10:12]))
			ac += uint32(binary.BigEndian.Uint16(b[12:14]))
			ac += uint32(binary.BigEndian.Uint16(b[14:16]))
			ac += uint32(binary.BigEndian.Uint16(b[16:18]))
			ac += uint32(binary.BigEndian.Uint16(b[18:20]))
			ac += uint32(binary.BigEndian.Uint16(b[20:22]))
			ac += uint32(binary.BigEndian.Uint16(b[22:24]))
			ac += uint32(binary.BigEndian.Uint16(b[24:26]))
			ac += uint32(binary.BigEndian.Uint16(b[26:28]))
			ac += uint32(binary.BigEndian.Uint16(b[28:30]))
			ac += uint32(binary.BigEndian.Uint16(b[30:32]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
			ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
			ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
			ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
			ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
			ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
			ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
			ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
			ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
			ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
			ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
			ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
			ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
			ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
			ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
		}
		b = b[32:]
	}
	if len(b) >= 16 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
			ac += uint32(binary.BigEndian.Uint16(b[4:6]))
			ac += uint32(binary.BigEndian.Uint16(b[6:8]))
			ac += uint32(binary.BigEndian.Uint16(b[8:10]))
			ac += uint32(binary.BigEndian.Uint16(b[10:12]))
			ac += uint32(binary.BigEndian.Uint16(b[12:14]))
			ac += uint32(binary.BigEndian.Uint16(b[14:16]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
			ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
			ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
			ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
			ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
			ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
			ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
		}
		b = b[16:]
	}
	if len(b) >= 8 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
			ac += uint32(binary.BigEndian.Uint16(b[4:6]))
			ac += uint32(binary.BigEndian.Uint16(b[6:8]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
			ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
			ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
		}
		b = b[8:]
	}
	if len(b) >= 4 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
		}
		b = b[4:]
	}
	if len(b) >= 2 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b))
		}
		b = b[2:]
	}
	if len(b) >= 1 {
		if cpu.IsBigEndian {
			ac += uint32(b[0]) << 8
		} else {
			ac += uint32(b[0])
		}
	}

	folded := ipChecksumFold32(ac, 0)
	if !cpu.IsBigEndian {
		folded = bits.ReverseBytes16(folded)
	}
	return folded
}

// checksumGeneric64Alternate is an alternate reference implementation of
// checksum using 64 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric64Alternate(b []byte, initial uint16) uint16 {
	var ac uint64

	if cpu.IsBigEndian {
		ac = uint64(initial)
	} else {
		ac = uint64(bits.ReverseBytes16(initial))
	}

	for len(b) >= 64 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b[:4]))
			ac += uint64(binary.BigEndian.Uint32(b[4:8]))
			ac += uint64(binary.BigEndian.Uint32(b[8:12]))
			ac += uint64(binary.BigEndian.Uint32(b[12:16]))
			ac += uint64(binary.BigEndian.Uint32(b[16:20]))
			ac += uint64(binary.BigEndian.Uint32(b[20:24]))
			ac += uint64(binary.BigEndian.Uint32(b[24:28]))
			ac += uint64(binary.BigEndian.Uint32(b[28:32]))
			ac += uint64(binary.BigEndian.Uint32(b[32:36]))
			ac += uint64(binary.BigEndian.Uint32(b[36:40]))
			ac += uint64(binary.BigEndian.Uint32(b[40:44]))
			ac += uint64(binary.BigEndian.Uint32(b[44:48]))
			ac += uint64(binary.BigEndian.Uint32(b[48:52]))
			ac += uint64(binary.BigEndian.Uint32(b[52:56]))
			ac += uint64(binary.BigEndian.Uint32(b[56:60]))
			ac += uint64(binary.BigEndian.Uint32(b[60:64]))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b[:4]))
			ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
			ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
			ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
			ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
			ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
			ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
			ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
			ac += uint64(binary.LittleEndian.Uint32(b[32:36]))
			ac += uint64(binary.LittleEndian.Uint32(b[36:40]))
			ac += uint64(binary.LittleEndian.Uint32(b[40:44]))
			ac += uint64(binary.LittleEndian.Uint32(b[44:48]))
			ac += uint64(binary.LittleEndian.Uint32(b[48:52]))
			ac += uint64(binary.LittleEndian.Uint32(b[52:56]))
			ac += uint64(binary.LittleEndian.Uint32(b[56:60]))
			ac += uint64(binary.LittleEndian.Uint32(b[60:64]))
		}
		b = b[64:]
	}
	if len(b) >= 32 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b[:4]))
			ac += uint64(binary.BigEndian.Uint32(b[4:8]))
			ac += uint64(binary.BigEndian.Uint32(b[8:12]))
			ac += uint64(binary.BigEndian.Uint32(b[12:16]))
			ac += uint64(binary.BigEndian.Uint32(b[16:20]))
			ac += uint64(binary.BigEndian.Uint32(b[20:24]))
			ac += uint64(binary.BigEndian.Uint32(b[24:28]))
			ac += uint64(binary.BigEndian.Uint32(b[28:32]))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b[:4]))
			ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
			ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
			ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
			ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
			ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
			ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
			ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
		}
		b = b[32:]
	}
	if len(b) >= 16 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b[:4]))
			ac += uint64(binary.BigEndian.Uint32(b[4:8]))
			ac += uint64(binary.BigEndian.Uint32(b[8:12]))
			ac += uint64(binary.BigEndian.Uint32(b[12:16]))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b[:4]))
			ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
			ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
			ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
		}
		b = b[16:]
	}
	if len(b) >= 8 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b[:4]))
			ac += uint64(binary.BigEndian.Uint32(b[4:8]))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b[:4]))
			ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
		}
		b = b[8:]
	}
	if len(b) >= 4 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b))
		}
		b = b[4:]
	}
	if len(b) >= 2 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint16(b))
		} else {
			ac += uint64(binary.LittleEndian.Uint16(b))
		}
		b = b[2:]
	}
	if len(b) >= 1 {
		if cpu.IsBigEndian {
			ac += uint64(b[0]) << 8
		} else {
			ac += uint64(b[0])
		}
	}

	folded := ipChecksumFold64(ac, 0)
	if !cpu.IsBigEndian {
		folded = bits.ReverseBytes16(folded)
	}
	return folded
}

func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 {
	sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry))
	// if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff
	// therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is
	// no need to save the carry flag
	sum = (sum >> 16) + (sum & 0xffff) + carry
	// sum <= 0x1_fffe therefore this is the last fold needed:
	// if (sum >> 16) > 0 then
	// (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore
	// the addition will not overflow
	// otherwise (sum >> 16) == 0 and sum will be unchanged
	sum = (sum >> 16) + (sum & 0xffff)
	return uint16(sum)
}

func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 {
	sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry
	// sum <= 0x1_ffff:
	// 0xffff + 0xffff = 0x1_fffe
	// initialCarry is 0 or 1, for a combined maximum of 0x1_ffff
	sum = (sum >> 16) + (sum & 0xffff)
	// sum <= 0x1_0000 therefore this is the last fold needed:
	// if (sum >> 16) > 0 then
	// (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore
	// the addition will not overflow
	// otherwise (sum >> 16) == 0 and sum will be unchanged
	sum = (sum >> 16) + (sum & 0xffff)
	return uint16(sum)
}

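// Worked example of the fold above (illustrative, not from the upstream
// source): folding the 32-bit partial sum 0x0003_fffe with one pending carry.
// The first fold computes 0x0003 + 0xfffe + 1 = 0x1_0002; the second computes
// 0x0001 + 0x0002 = 0x0003, which now fits in 16 bits, so
// ipChecksumFold32(0x0003_fffe, 1) == 0x0003.
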
func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) {
	sum, carry = initial, carryIn
	switch len(addr) {
	case 4: // IPv4
		if cpu.IsBigEndian {
			sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry)
		} else {
			sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry)
		}
	case 16: // IPv6
		if cpu.IsBigEndian {
			sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry)
			sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry)
		} else {
			sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry)
			sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry)
		}
	default:
		panic("bad addr length")
	}
	return sum, carry
}

func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) {
	sum, carry = initial, carryIn
	switch len(addr) {
	case 4: // IPv4
		if cpu.IsBigEndian {
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
		} else {
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
		}
	case 16: // IPv6
		if cpu.IsBigEndian {
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry)
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry)
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry)
		} else {
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry)
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry)
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry)
		}
	default:
		panic("bad addr length")
	}
	return sum, carry
}

func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
	var sum uint64
	if cpu.IsBigEndian {
		sum = uint64(totalLen) + uint64(protocol)
	} else {
		sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8
	}
	sum, carry := addrPartialChecksum64(srcAddr, sum, 0)
	sum, carry = addrPartialChecksum64(dstAddr, sum, carry)

	foldedSum := ipChecksumFold64(sum, carry)
	if !cpu.IsBigEndian {
		foldedSum = bits.ReverseBytes16(foldedSum)
	}
	return foldedSum
}

func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
	var sum uint32
	if cpu.IsBigEndian {
		sum = uint32(totalLen) + uint32(protocol)
	} else {
		sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8
	}
	sum, carry := addrPartialChecksum32(srcAddr, sum, 0)
	sum, carry = addrPartialChecksum32(dstAddr, sum, carry)

	foldedSum := ipChecksumFold32(sum, carry)
	if !cpu.IsBigEndian {
		foldedSum = bits.ReverseBytes16(foldedSum)
	}
	return foldedSum
}

// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and
// dstAddr must be 4 or 16 bytes in length.
func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
	if strconv.IntSize < 64 {
		return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen)
	}
	return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen)
}
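
// Usage sketch (illustrative, not from the upstream source): computing a UDP
// checksum over an IPv4 packet. pkt is the full packet, udpOff is the byte
// offset of the UDP header, and the checksum field at pkt[udpOff+6:udpOff+8]
// is assumed to already be zeroed. 17 is the IP protocol number for UDP.
func exampleUDPChecksum(pkt []byte, udpOff int) uint16 {
	src := pkt[12:16] // IPv4 source address
	dst := pkt[16:20] // IPv4 destination address
	udpLen := uint16(len(pkt) - udpOff)
	pseudo := PseudoHeaderChecksum(17, src, dst, udpLen)
	return ^Checksum(pkt[udpOff:], pseudo) // value to write into the UDP header
}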
23
vendor/github.com/tailscale/wireguard-go/tun/checksum_amd64.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
package tun

import "golang.org/x/sys/cpu"

var checksum = checksumAMD64

// Checksum computes an IP checksum starting with the provided initial value.
// The length of data should be at least 128 bytes for best performance. Smaller
// buffers will still compute a correct result.
func Checksum(data []byte, initial uint16) uint16 {
	return checksum(data, initial)
}

func init() {
	if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
		checksum = checksumAVX2
		return
	}
	if cpu.X86.HasSSE2 {
		checksum = checksumSSE2
		return
	}
}
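
// Usage sketch (illustrative, assuming hdr holds a raw 20+ byte IPv4 header):
// a valid IPv4 header, including its stored checksum field, sums to 0xffff
// under the ones'-complement checksum, so the complement of Checksum is zero.
func ipv4HeaderChecksumOK(hdr []byte) bool {
	return ^Checksum(hdr, 0) == 0
}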
18
vendor/github.com/tailscale/wireguard-go/tun/checksum_generated_amd64.go
generated
vendored
Normal file
@@ -0,0 +1,18 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.

package tun

// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)
//
//go:noescape
func checksumAVX2(b []byte, initial uint16) uint16

// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)
//
//go:noescape
func checksumSSE2(b []byte, initial uint16) uint16

// checksumAMD64 computes an IP checksum using amd64 baseline instructions
//
//go:noescape
func checksumAMD64(b []byte, initial uint16) uint16
851
vendor/github.com/tailscale/wireguard-go/tun/checksum_generated_amd64.s
generated
vendored
Normal file
@@ -0,0 +1,851 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.

#include "textflag.h"

DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"
DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112

// func checksumAVX2(b []byte, initial uint16) uint16
// Requires: AVX, AVX2, BMI2
TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34
	MOVWQZX initial+24(FP), AX
	XCHGB   AH, AL
	MOVQ    b_base+0(FP), DX
	MOVQ    b_len+8(FP), BX

	// handle odd length buffers; they are difficult to handle in general
	TESTQ   $0x00000001, BX
	JZ      lengthIsEven
	MOVBQZX -1(DX)(BX*1), CX
	DECQ    BX
	ADDQ    CX, AX

lengthIsEven:
	// handle tiny buffers (<=31 bytes) specially
	CMPQ BX, $0x1f
	JGT  bufferIsNotTiny
	XORQ CX, CX
	XORQ SI, SI
	XORQ DI, DI

	// shift twice to start because length is guaranteed to be even
	// n = n >> 2; CF = originalN & 2
	SHRQ $0x02, BX
	JNC  handleTiny4

	// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
	MOVWQZX (DX), CX
	ADDQ    $0x02, DX

handleTiny4:
	// n = n >> 1; CF = originalN & 4
	SHRQ $0x01, BX
	JNC  handleTiny8

	// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
	MOVLQZX (DX), SI
	ADDQ    $0x04, DX

handleTiny8:
	// n = n >> 1; CF = originalN & 8
	SHRQ $0x01, BX
	JNC  handleTiny16

	// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
	MOVQ (DX), DI
	ADDQ $0x08, DX

handleTiny16:
	// n = n >> 1; CF = originalN & 16
	// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
	SHRQ $0x01, BX
	JNC  handleTinyFinish
	ADDQ (DX), AX
	ADCQ 8(DX), AX

handleTinyFinish:
	// CF should be included from the previous add, so we use ADCQ.
	// If we arrived via the JNC above, then CF=0 due to the branch condition,
	// so ADCQ will still produce the correct result.
	ADCQ CX, AX
	ADCQ SI, AX
	ADCQ DI, AX
	JMP  foldAndReturn

bufferIsNotTiny:
	// skip all SIMD for small buffers
	CMPQ BX, $0x00000100
	JGE  startSIMD

	// Accumulate carries in this register. It is never expected to overflow.
	XORQ SI, SI

	// We will perform an overlapped read for buffers with length not a multiple of 8.
	// Overlapped in this context means some memory will be read twice, but a shift will
	// eliminate the duplicated data. This extra read is performed at the end of the buffer to
	// preserve any alignment that may exist for the start of the buffer.
	MOVQ BX, CX
	SHRQ $0x03, BX
	ANDQ $0x07, CX
	JZ   handleRemaining8
	LEAQ (DX)(BX*8), DI
	MOVQ -8(DI)(CX*1), DI

	// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
	SHLQ $0x03, CX
	NEGQ CX
	ADDQ $0x40, CX
	SHRQ CL, DI
	ADDQ DI, AX
	ADCQ $0x00, SI

handleRemaining8:
	SHRQ $0x01, BX
	JNC  handleRemaining16
	ADDQ (DX), AX
	ADCQ $0x00, SI
	ADDQ $0x08, DX

handleRemaining16:
	SHRQ $0x01, BX
	JNC  handleRemaining32
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x10, DX

handleRemaining32:
	SHRQ $0x01, BX
	JNC  handleRemaining64
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x20, DX

handleRemaining64:
	SHRQ $0x01, BX
	JNC  handleRemaining128
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x40, DX

handleRemaining128:
	SHRQ $0x01, BX
	JNC  handleRemainingComplete
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ 64(DX), AX
	ADCQ 72(DX), AX
	ADCQ 80(DX), AX
	ADCQ 88(DX), AX
	ADCQ 96(DX), AX
	ADCQ 104(DX), AX
	ADCQ 112(DX), AX
	ADCQ 120(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x80, DX

handleRemainingComplete:
	ADDQ SI, AX
	JMP  foldAndReturn

startSIMD:
	VPXOR Y0, Y0, Y0
	VPXOR Y1, Y1, Y1
	VPXOR Y2, Y2, Y2
	VPXOR Y3, Y3, Y3
	MOVQ  BX, CX

	// Update number of bytes remaining after the loop completes
	ANDQ $0xff, BX

	// Number of 256 byte iterations
	SHRQ $0x08, CX
	JZ   smallLoop

bigLoop:
	VPMOVZXWD (DX), Y4
	VPADDD    Y4, Y0, Y0
	VPMOVZXWD 16(DX), Y4
	VPADDD    Y4, Y1, Y1
	VPMOVZXWD 32(DX), Y4
	VPADDD    Y4, Y2, Y2
	VPMOVZXWD 48(DX), Y4
	VPADDD    Y4, Y3, Y3
	VPMOVZXWD 64(DX), Y4
	VPADDD    Y4, Y0, Y0
	VPMOVZXWD 80(DX), Y4
	VPADDD    Y4, Y1, Y1
	VPMOVZXWD 96(DX), Y4
	VPADDD    Y4, Y2, Y2
	VPMOVZXWD 112(DX), Y4
	VPADDD    Y4, Y3, Y3
	VPMOVZXWD 128(DX), Y4
	VPADDD    Y4, Y0, Y0
	VPMOVZXWD 144(DX), Y4
	VPADDD    Y4, Y1, Y1
	VPMOVZXWD 160(DX), Y4
	VPADDD    Y4, Y2, Y2
	VPMOVZXWD 176(DX), Y4
	VPADDD    Y4, Y3, Y3
	VPMOVZXWD 192(DX), Y4
	VPADDD    Y4, Y0, Y0
	VPMOVZXWD 208(DX), Y4
	VPADDD    Y4, Y1, Y1
	VPMOVZXWD 224(DX), Y4
	VPADDD    Y4, Y2, Y2
	VPMOVZXWD 240(DX), Y4
	VPADDD    Y4, Y3, Y3
	ADDQ      $0x00000100, DX
	DECQ      CX
	JNZ       bigLoop
	CMPQ      BX, $0x10
	JLT       doneSmallLoop

	// now read a single 16 byte unit of data at a time
smallLoop:
	VPMOVZXWD (DX), Y4
	VPADDD    Y4, Y0, Y0
	ADDQ      $0x10, DX
	SUBQ      $0x10, BX
	CMPQ      BX, $0x10
	JGE       smallLoop

doneSmallLoop:
	CMPQ BX, $0x00
	JE   doneSIMD

	// There are between 1 and 15 bytes remaining. Perform an overlapped read.
	LEAQ      xmmLoadMasks<>+0(SB), CX
	VMOVDQU   -16(DX)(BX*1), X4
	VPAND     -16(CX)(BX*8), X4, X4
	VPMOVZXWD X4, Y4
	VPADDD    Y4, Y0, Y0

doneSIMD:
	// Multi-chain loop is done, combine the accumulators
	VPADDD Y1, Y0, Y0
	VPADDD Y2, Y0, Y0
	VPADDD Y3, Y0, Y0

	// extract the YMM into a pair of XMM and sum them
	VEXTRACTI128 $0x01, Y0, X1
	VPADDD       X0, X1, X0

	// extract the XMM into GP64
	VPEXTRQ $0x00, X0, CX
	VPEXTRQ $0x01, X0, DX

	// no more AVX code, clear upper registers to avoid SSE slowdowns
	VZEROUPPER
	ADDQ CX, AX
	ADCQ DX, AX

foldAndReturn:
	// add CF and fold
	RORXQ $0x20, AX, CX
	ADCL  CX, AX
	RORXL $0x10, AX, CX
	ADCW  CX, AX
	ADCW  $0x00, AX
	XCHGB AH, AL
	MOVW  AX, ret+32(FP)
	RET

// func checksumSSE2(b []byte, initial uint16) uint16
// Requires: SSE2
TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34
	MOVWQZX initial+24(FP), AX
	XCHGB   AH, AL
	MOVQ    b_base+0(FP), DX
	MOVQ    b_len+8(FP), BX

	// handle odd length buffers; they are difficult to handle in general
	TESTQ   $0x00000001, BX
	JZ      lengthIsEven
	MOVBQZX -1(DX)(BX*1), CX
	DECQ    BX
	ADDQ    CX, AX

lengthIsEven:
	// handle tiny buffers (<=31 bytes) specially
	CMPQ BX, $0x1f
	JGT  bufferIsNotTiny
	XORQ CX, CX
	XORQ SI, SI
	XORQ DI, DI

	// shift twice to start because length is guaranteed to be even
	// n = n >> 2; CF = originalN & 2
	SHRQ $0x02, BX
	JNC  handleTiny4

	// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
	MOVWQZX (DX), CX
	ADDQ    $0x02, DX

handleTiny4:
	// n = n >> 1; CF = originalN & 4
	SHRQ $0x01, BX
	JNC  handleTiny8

	// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
	MOVLQZX (DX), SI
	ADDQ    $0x04, DX

handleTiny8:
	// n = n >> 1; CF = originalN & 8
	SHRQ $0x01, BX
	JNC  handleTiny16

	// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
	MOVQ (DX), DI
	ADDQ $0x08, DX

handleTiny16:
	// n = n >> 1; CF = originalN & 16
	// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
	SHRQ $0x01, BX
	JNC  handleTinyFinish
	ADDQ (DX), AX
	ADCQ 8(DX), AX

handleTinyFinish:
	// CF should be included from the previous add, so we use ADCQ.
	// If we arrived via the JNC above, then CF=0 due to the branch condition,
	// so ADCQ will still produce the correct result.
	ADCQ CX, AX
	ADCQ SI, AX
	ADCQ DI, AX
	JMP  foldAndReturn

bufferIsNotTiny:
	// skip all SIMD for small buffers
	CMPQ BX, $0x00000100
	JGE  startSIMD

	// Accumulate carries in this register. It is never expected to overflow.
	XORQ SI, SI

	// We will perform an overlapped read for buffers with length not a multiple of 8.
	// Overlapped in this context means some memory will be read twice, but a shift will
	// eliminate the duplicated data. This extra read is performed at the end of the buffer to
	// preserve any alignment that may exist for the start of the buffer.
	MOVQ BX, CX
	SHRQ $0x03, BX
	ANDQ $0x07, CX
	JZ   handleRemaining8
	LEAQ (DX)(BX*8), DI
	MOVQ -8(DI)(CX*1), DI

	// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
	SHLQ $0x03, CX
	NEGQ CX
	ADDQ $0x40, CX
	SHRQ CL, DI
	ADDQ DI, AX
	ADCQ $0x00, SI

handleRemaining8:
	SHRQ $0x01, BX
	JNC  handleRemaining16
	ADDQ (DX), AX
	ADCQ $0x00, SI
	ADDQ $0x08, DX

handleRemaining16:
	SHRQ $0x01, BX
	JNC  handleRemaining32
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x10, DX

handleRemaining32:
	SHRQ $0x01, BX
	JNC  handleRemaining64
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x20, DX

handleRemaining64:
	SHRQ $0x01, BX
	JNC  handleRemaining128
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x40, DX

handleRemaining128:
	SHRQ $0x01, BX
	JNC  handleRemainingComplete
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ 64(DX), AX
	ADCQ 72(DX), AX
	ADCQ 80(DX), AX
	ADCQ 88(DX), AX
	ADCQ 96(DX), AX
	ADCQ 104(DX), AX
	ADCQ 112(DX), AX
	ADCQ 120(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x80, DX

handleRemainingComplete:
	ADDQ SI, AX
	JMP  foldAndReturn

startSIMD:
	PXOR X0, X0
	PXOR X1, X1
	PXOR X2, X2
	PXOR X3, X3
	PXOR X4, X4
	MOVQ BX, CX

	// Update number of bytes remaining after the loop completes
	ANDQ $0xff, BX

	// Number of 256 byte iterations
	SHRQ $0x08, CX
	JZ   smallLoop

bigLoop:
	MOVOU     (DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X0
	PADDD     X6, X2
	MOVOU     16(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X1
	PADDD     X6, X3
	MOVOU     32(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X2
	PADDD     X6, X0
	MOVOU     48(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X3
	PADDD     X6, X1
	MOVOU     64(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X0
	PADDD     X6, X2
	MOVOU     80(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X1
	PADDD     X6, X3
	MOVOU     96(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X2
	PADDD     X6, X0
	MOVOU     112(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X3
	PADDD     X6, X1
	MOVOU     128(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X0
	PADDD     X6, X2
	MOVOU     144(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X1
	PADDD     X6, X3
	MOVOU     160(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X2
	PADDD     X6, X0
	MOVOU     176(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X3
	PADDD     X6, X1
	MOVOU     192(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X0
	PADDD     X6, X2
	MOVOU     208(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X1
	PADDD     X6, X3
	MOVOU     224(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X2
	PADDD     X6, X0
	MOVOU     240(DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X3
	PADDD     X6, X1
	ADDQ      $0x00000100, DX
	DECQ      CX
	JNZ       bigLoop
	CMPQ      BX, $0x10
	JLT       doneSmallLoop

	// now read a single 16 byte unit of data at a time
smallLoop:
	MOVOU     (DX), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X0
	PADDD     X6, X1
	ADDQ      $0x10, DX
	SUBQ      $0x10, BX
	CMPQ      BX, $0x10
	JGE       smallLoop

doneSmallLoop:
	CMPQ BX, $0x00
	JE   doneSIMD

	// There are between 1 and 15 bytes remaining. Perform an overlapped read.
	LEAQ      xmmLoadMasks<>+0(SB), CX
	MOVOU     -16(DX)(BX*1), X5
	PAND      -16(CX)(BX*8), X5
	MOVOA     X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD     X5, X0
	PADDD     X6, X1

doneSIMD:
	// Multi-chain loop is done, combine the accumulators
	PADDD X1, X0
	PADDD X2, X0
	PADDD X3, X0

	// extract the XMM into GP64
	MOVQ   X0, CX
	PSRLDQ $0x08, X0
	MOVQ   X0, DX
	ADDQ   CX, AX
	ADCQ   DX, AX

foldAndReturn:
	// add CF and fold
	MOVL    AX, CX
	ADCQ    $0x00, CX
	SHRQ    $0x20, AX
	ADDQ    CX, AX
	MOVWQZX AX, CX
	SHRQ    $0x10, AX
	ADDQ    CX, AX
	MOVW    AX, CX
	SHRQ    $0x10, AX
	ADDW    CX, AX
	ADCW    $0x00, AX
	XCHGB   AH, AL
	MOVW    AX, ret+32(FP)
	RET

// func checksumAMD64(b []byte, initial uint16) uint16
TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34
	MOVWQZX initial+24(FP), AX
	XCHGB   AH, AL
	MOVQ    b_base+0(FP), DX
	MOVQ    b_len+8(FP), BX

	// handle odd length buffers; they are difficult to handle in general
	TESTQ   $0x00000001, BX
	JZ      lengthIsEven
	MOVBQZX -1(DX)(BX*1), CX
	DECQ    BX
	ADDQ    CX, AX

lengthIsEven:
	// handle tiny buffers (<=31 bytes) specially
	CMPQ BX, $0x1f
	JGT  bufferIsNotTiny
	XORQ CX, CX
	XORQ SI, SI
	XORQ DI, DI

	// shift twice to start because length is guaranteed to be even
	// n = n >> 2; CF = originalN & 2
	SHRQ $0x02, BX
	JNC  handleTiny4

	// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
	MOVWQZX (DX), CX
	ADDQ    $0x02, DX

handleTiny4:
	// n = n >> 1; CF = originalN & 4
	SHRQ $0x01, BX
	JNC  handleTiny8

	// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
	MOVLQZX (DX), SI
	ADDQ    $0x04, DX

handleTiny8:
	// n = n >> 1; CF = originalN & 8
	SHRQ $0x01, BX
	JNC  handleTiny16

	// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
	MOVQ (DX), DI
	ADDQ $0x08, DX

handleTiny16:
	// n = n >> 1; CF = originalN & 16
	// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
	SHRQ $0x01, BX
	JNC  handleTinyFinish
	ADDQ (DX), AX
	ADCQ 8(DX), AX

handleTinyFinish:
	// CF should be included from the previous add, so we use ADCQ.
	// If we arrived via the JNC above, then CF=0 due to the branch condition,
	// so ADCQ will still produce the correct result.
	ADCQ CX, AX
	ADCQ SI, AX
	ADCQ DI, AX
	JMP  foldAndReturn

bufferIsNotTiny:
	// Number of 256 byte iterations into loop counter
	MOVQ BX, CX

	// Update number of bytes remaining after the loop completes
	ANDQ $0xff, BX
	SHRQ $0x08, CX
	JZ   startCleanup
	CLC
	XORQ SI, SI
	XORQ DI, DI
	XORQ R8, R8
	XORQ R9, R9
	XORQ R10, R10
	XORQ R11, R11
	XORQ R12, R12

bigLoop:
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ $0x00, SI
	ADDQ 32(DX), DI
	ADCQ 40(DX), DI
	ADCQ 48(DX), DI
	ADCQ 56(DX), DI
	ADCQ $0x00, R8
	ADDQ 64(DX), R9
	ADCQ 72(DX), R9
	ADCQ 80(DX), R9
	ADCQ 88(DX), R9
	ADCQ $0x00, R10
	ADDQ 96(DX), R11
	ADCQ 104(DX), R11
	ADCQ 112(DX), R11
	ADCQ 120(DX), R11
	ADCQ $0x00, R12
	ADDQ 128(DX), AX
	ADCQ 136(DX), AX
	ADCQ 144(DX), AX
	ADCQ 152(DX), AX
	ADCQ $0x00, SI
	ADDQ 160(DX), DI
	ADCQ 168(DX), DI
	ADCQ 176(DX), DI
	ADCQ 184(DX), DI
	ADCQ $0x00, R8
	ADDQ 192(DX), R9
	ADCQ 200(DX), R9
	ADCQ 208(DX), R9
	ADCQ 216(DX), R9
	ADCQ $0x00, R10
	ADDQ 224(DX), R11
	ADCQ 232(DX), R11
	ADCQ 240(DX), R11
	ADCQ 248(DX), R11
	ADCQ $0x00, R12
	ADDQ $0x00000100, DX
	SUBQ $0x01, CX
	JNZ  bigLoop
	ADDQ SI, AX
	ADCQ DI, AX
	ADCQ R8, AX
	ADCQ R9, AX
	ADCQ R10, AX
	ADCQ R11, AX
	ADCQ R12, AX

	// accumulate CF (twice, in case the first time overflows)
	ADCQ $0x00, AX
	ADCQ $0x00, AX

startCleanup:
	// Accumulate carries in this register. It is never expected to overflow.
	XORQ SI, SI

	// We will perform an overlapped read for buffers with length not a multiple of 8.
	// Overlapped in this context means some memory will be read twice, but a shift will
	// eliminate the duplicated data. This extra read is performed at the end of the buffer to
	// preserve any alignment that may exist for the start of the buffer.
	MOVQ BX, CX
	SHRQ $0x03, BX
	ANDQ $0x07, CX
	JZ   handleRemaining8
	LEAQ (DX)(BX*8), DI
	MOVQ -8(DI)(CX*1), DI

	// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
	SHLQ $0x03, CX
	NEGQ CX
	ADDQ $0x40, CX
	SHRQ CL, DI
	ADDQ DI, AX
	ADCQ $0x00, SI

handleRemaining8:
	SHRQ $0x01, BX
	JNC  handleRemaining16
	ADDQ (DX), AX
	ADCQ $0x00, SI
	ADDQ $0x08, DX

handleRemaining16:
	SHRQ $0x01, BX
	JNC  handleRemaining32
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x10, DX

handleRemaining32:
	SHRQ $0x01, BX
	JNC  handleRemaining64
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x20, DX

handleRemaining64:
	SHRQ $0x01, BX
	JNC  handleRemaining128
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x40, DX

handleRemaining128:
	SHRQ $0x01, BX
	JNC  handleRemainingComplete
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ 64(DX), AX
	ADCQ 72(DX), AX
	ADCQ 80(DX), AX
	ADCQ 88(DX), AX
	ADCQ 96(DX), AX
	ADCQ 104(DX), AX
	ADCQ 112(DX), AX
	ADCQ 120(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x80, DX

handleRemainingComplete:
	ADDQ SI, AX

foldAndReturn:
	// add CF and fold
	MOVL    AX, CX
	ADCQ    $0x00, CX
	SHRQ    $0x20, AX
	ADDQ    CX, AX
	MOVWQZX AX, CX
	SHRQ    $0x10, AX
	ADDQ    CX, AX
	MOVW    AX, CX
	SHRQ    $0x10, AX
	ADDW    CX, AX
	ADCW    $0x00, AX
	XCHGB   AH, AL
	MOVW    AX, ret+32(FP)
	RET
15
vendor/github.com/tailscale/wireguard-go/tun/checksum_generic.go
generated
vendored
Normal file
@@ -0,0 +1,15 @@
// This file contains IP checksum algorithms that are not specific to any
// architecture and don't use hardware acceleration.

//go:build !amd64

package tun

import "strconv"

func Checksum(data []byte, initial uint16) uint16 {
	if strconv.IntSize < 64 {
		return checksumGeneric32(data, initial)
	}
	return checksumGeneric64(data, initial)
}
12
vendor/github.com/tailscale/wireguard-go/tun/errors.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
package tun

import (
	"errors"
)

var (
	// ErrTooManySegments is returned by Device.Read() when segmentation
	// overflows the length of supplied buffers. This error should not cause
	// reads to cease.
	ErrTooManySegments = errors.New("too many segments")
)
220
vendor/github.com/tailscale/wireguard-go/tun/offload.go
generated
vendored
Normal file
@@ -0,0 +1,220 @@
package tun

import (
	"encoding/binary"
	"fmt"
)

// GSOType represents the type of segmentation offload.
type GSOType int

const (
	GSONone GSOType = iota
	GSOTCPv4
	GSOTCPv6
	GSOUDPL4
)

func (g GSOType) String() string {
	switch g {
	case GSONone:
		return "GSONone"
	case GSOTCPv4:
		return "GSOTCPv4"
	case GSOTCPv6:
		return "GSOTCPv6"
	case GSOUDPL4:
		return "GSOUDPL4"
	default:
		return "unknown"
	}
}

// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO
// specification. It is a common representation of GSO metadata that can be
// applied to support packet GSO across tun.Device implementations.
type GSOOptions struct {
	// GSOType represents the type of segmentation offload.
	GSOType GSOType
	// HdrLen is the sum of the layer 3 and 4 header lengths. This field may be
	// zero when GSOType == GSONone.
	HdrLen uint16
	// CsumStart is the head byte index of the packet data to be checksummed,
	// i.e. the start of the TCP or UDP header.
	CsumStart uint16
	// CsumOffset is the offset from CsumStart where the 2-byte checksum value
	// should be placed.
	CsumOffset uint16
	// GSOSize is the size of each segment exclusive of HdrLen. The tail segment
	// may be smaller than this value.
	GSOSize uint16
	// NeedsCsum may be set where GSOType == GSONone. When set, the checksum
	// at CsumStart + CsumOffset must be a partial checksum, i.e. the
	// pseudo-header sum.
	NeedsCsum bool
}

const (
	ipv4SrcAddrOffset = 12
	ipv6SrcAddrOffset = 8
)

const tcpFlagsOffset = 13

const (
	tcpFlagFIN uint8 = 0x01
	tcpFlagPSH uint8 = 0x08
	tcpFlagACK uint8 = 0x10
)

const (
	// defined here in order to avoid importation of any platform-specific pkgs
	ipProtoTCP = 6
	ipProtoUDP = 17
)

// GSOSplit splits packets from 'in' into outBufs[<index>][outOffset:], writing
// the size of each element into sizes. It returns the number of buffers
// populated, and/or an error. Callers may pass an 'in' slice that overlaps with
// the first element of outBufs, i.e. &in[0] may be equal to
// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the
// value of options.NeedsCsum. Length of each outBufs element must be greater
// than or equal to the length of 'in', otherwise output may be silently
// truncated.
func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) {
	cSumAt := int(options.CsumStart) + int(options.CsumOffset)
	if cSumAt+1 >= len(in) {
		return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
	}

	if len(in) < int(options.HdrLen) {
		return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen)
	}

	// Handle the conditions where we are copying a single element to outBufs.
	payloadLen := len(in) - int(options.HdrLen)
	if options.GSOType == GSONone || payloadLen < int(options.GSOSize) {
		if len(in) > len(outBufs[0][outOffset:]) {
			return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:]))
		}
		if options.NeedsCsum {
			// The initial value at the checksum offset should be summed with
			// the checksum we compute. This is typically the pseudo-header sum.
			initial := binary.BigEndian.Uint16(in[cSumAt:])
			in[cSumAt], in[cSumAt+1] = 0, 0
			binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial))
		}
		sizes[0] = copy(outBufs[0][outOffset:], in)
		return 1, nil
	}

	if options.HdrLen < options.CsumStart {
		return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart)
	}

	ipVersion := in[0] >> 4
	switch ipVersion {
	case 4:
		if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 {
			return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
		}
		if len(in) < 20 {
			return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20)
		}
	case 6:
		if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 {
			return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
		}
		if len(in) < 40 {
			return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40)
		}
	default:
		return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
	}

	iphLen := int(options.CsumStart)
	srcAddrOffset := ipv6SrcAddrOffset
	addrLen := 16
	if ipVersion == 4 {
		srcAddrOffset = ipv4SrcAddrOffset
		addrLen = 4
	}
	transportCsumAt := int(options.CsumStart + options.CsumOffset)
	var firstTCPSeqNum uint32
	var protocol uint8
	if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 {
		protocol = ipProtoTCP
		if len(in) < int(options.CsumStart)+20 {
			return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)",
				len(in), options.CsumStart, 20)
		}
		firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:])
	} else {
		protocol = ipProtoUDP
	}
	nextSegmentDataAt := int(options.HdrLen)
	i := 0
	for ; nextSegmentDataAt < len(in); i++ {
		if i == len(outBufs) {
			return i - 1, ErrTooManySegments
		}
		nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize)
		if nextSegmentEnd > len(in) {
			nextSegmentEnd = len(in)
		}
		segmentDataLen := nextSegmentEnd - nextSegmentDataAt
		totalLen := int(options.HdrLen) + segmentDataLen
		sizes[i] = totalLen
		out := outBufs[i][outOffset:]

		copy(out, in[:iphLen])
		if ipVersion == 4 {
			// For IPv4 we are responsible for incrementing the ID field,
			// updating the total len field, and recalculating the header
			// checksum.
			if i > 0 {
				id := binary.BigEndian.Uint16(out[4:])
				id += uint16(i)
				binary.BigEndian.PutUint16(out[4:], id)
			}
			out[10], out[11] = 0, 0 // clear ipv4 header checksum
			binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
			ipv4CSum := ^Checksum(out[:iphLen], 0)
			binary.BigEndian.PutUint16(out[10:], ipv4CSum)
		} else {
			// For IPv6 we are responsible for updating the payload length field.
			binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
		}

		// copy transport header
		copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen])

		if protocol == ipProtoTCP {
			// set TCP seq and adjust TCP flags
			tcpSeq := firstTCPSeqNum + uint32(options.GSOSize*uint16(i))
			binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq)
			if nextSegmentEnd != len(in) {
				// FIN and PSH should only be set on last segment
				clearFlags := tcpFlagFIN | tcpFlagPSH
				out[options.CsumStart+tcpFlagsOffset] &^= clearFlags
			}
		} else {
			// set UDP header len
			binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart))
		}

		// payload
		copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd])

		// transport checksum
		out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
		transportHeaderLen := int(options.HdrLen - options.CsumStart)
		lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
		transportCSum := PseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
		transportCSum = ^Checksum(out[options.CsumStart:totalLen], transportCSum)
		binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum)

		nextSegmentDataAt += int(options.GSOSize)
	}
	return i, nil
}
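
// Usage sketch (assumed values, not from the upstream source): splitting a
// TCP/IPv4 superpacket with 20-byte IP and TCP headers into 1200-byte
// segments. bufs and sizes are caller-allocated, as in a Device batch read.
//
//	opts := GSOOptions{
//		GSOType:    GSOTCPv4,
//		HdrLen:     40,   // IPv4 header + TCP header
//		CsumStart:  20,   // offset of the TCP header
//		CsumOffset: 16,   // offset of the checksum field within the TCP header
//		GSOSize:    1200, // payload bytes per segment
//	}
//	n, err := GSOSplit(superPacket, opts, bufs, sizes, 0)
//	// On success, bufs[i][:sizes[i]] for i < n are complete packets, each
//	// with its IPv4 ID, length fields, and checksums rewritten.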
911
vendor/github.com/tailscale/wireguard-go/tun/offload_linux.go
generated
vendored
Normal file
@@ -0,0 +1,911 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tun

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"unsafe"

	"github.com/tailscale/wireguard-go/conn"
	"golang.org/x/sys/unix"
)

// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
	flags      uint8
	gsoType    uint8
	hdrLen     uint16
	gsoSize    uint16
	csumStart  uint16
	csumOffset uint16
}

func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) {
	var gsoType GSOType
	switch v.gsoType {
	case unix.VIRTIO_NET_HDR_GSO_NONE:
		gsoType = GSONone
	case unix.VIRTIO_NET_HDR_GSO_TCPV4:
		gsoType = GSOTCPv4
	case unix.VIRTIO_NET_HDR_GSO_TCPV6:
		gsoType = GSOTCPv6
	case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
		gsoType = GSOUDPL4
	default:
		return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType)
	}
	return GSOOptions{
		GSOType:    gsoType,
		HdrLen:     v.hdrLen,
		CsumStart:  v.csumStart,
		CsumOffset: v.csumOffset,
		GSOSize:    v.gsoSize,
		NeedsCsum:  v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0,
	}, nil
}

func (v *virtioNetHdr) decode(b []byte) error {
	if len(b) < virtioNetHdrLen {
		return io.ErrShortBuffer
	}
	copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
	return nil
}

func (v *virtioNetHdr) encode(b []byte) error {
	if len(b) < virtioNetHdrLen {
		return io.ErrShortBuffer
	}
	copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
	return nil
}

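encode and decode rely on virtioNetHdr having exactly the memory layout of its kernel counterpart, so a single copy through unsafe.Slice serializes it. A self-contained sketch of the same round-trip pattern, using a stand-in struct rather than the vendored type (byte order is host-native, which is what the virtio ABI expects on the same machine):

package main

import (
	"fmt"
	"unsafe"
)

// hdr mirrors the padding-free layout trick used above: because the struct
// has no padding, it can be (de)serialized with a single copy.
type hdr struct {
	flags      uint8
	gsoType    uint8
	hdrLen     uint16
	gsoSize    uint16
	csumStart  uint16
	csumOffset uint16
}

const hdrLen = int(unsafe.Sizeof(hdr{})) // 10 bytes, no padding

func main() {
	in := hdr{flags: 1, gsoType: 3, hdrLen: 54, gsoSize: 1400, csumStart: 34, csumOffset: 16}
	b := make([]byte, hdrLen)
	copy(b, unsafe.Slice((*byte)(unsafe.Pointer(&in)), hdrLen)) // encode
	var out hdr
	copy(unsafe.Slice((*byte)(unsafe.Pointer(&out)), hdrLen), b) // decode
	fmt.Println(out == in) // true
}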
const (
	// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
	// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
	virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)

// tcpFlowKey represents the key for a TCP flow.
type tcpFlowKey struct {
	srcAddr, dstAddr [16]byte
	srcPort, dstPort uint16
	rxAck            uint32 // varying ack values should not be coalesced. Treat them as separate flows.
	isV6             bool
}

// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
type tcpGROTable struct {
	itemsByFlow map[tcpFlowKey][]tcpGROItem
	itemsPool   [][]tcpGROItem
}

func newTCPGROTable() *tcpGROTable {
	t := &tcpGROTable{
		itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
		itemsPool:   make([][]tcpGROItem, conn.IdealBatchSize),
	}
	for i := range t.itemsPool {
		t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
	}
	return t
}

func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
	key := tcpFlowKey{}
	addrSize := dstAddrOffset - srcAddrOffset
	copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
	copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
	key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
	key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
	key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
	key.isV6 = addrSize == 16
	return key
}

// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
	key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
	items, ok := t.itemsByFlow[key]
	if ok {
		return items, ok
	}
	// TODO: insert() performs another map lookup. This could be rearranged to avoid.
	t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
	return nil, false
}

// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
	key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
	item := tcpGROItem{
		key:       key,
		bufsIndex: uint16(bufsIndex),
		gsoSize:   uint16(len(pkt[tcphOffset+tcphLen:])),
		iphLen:    uint8(tcphOffset),
		tcphLen:   uint8(tcphLen),
		sentSeq:   binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
		pshSet:    pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
	}
	items, ok := t.itemsByFlow[key]
	if !ok {
		items = t.newItems()
	}
	items = append(items, item)
	t.itemsByFlow[key] = items
}

func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
	items, _ := t.itemsByFlow[item.key]
	items[i] = item
}

func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
	items, _ := t.itemsByFlow[key]
	items = append(items[:i], items[i+1:]...)
	t.itemsByFlow[key] = items
}

// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
	key       tcpFlowKey
	sentSeq   uint32 // the sequence number
	bufsIndex uint16 // the index into the original bufs slice
	numMerged uint16 // the number of packets merged into this item
	gsoSize   uint16 // payload size
	iphLen    uint8  // ip header len
	tcphLen   uint8  // tcp header len
	pshSet    bool   // psh flag is set
}

func (t *tcpGROTable) newItems() []tcpGROItem {
	var items []tcpGROItem
	items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
	return items
}

func (t *tcpGROTable) reset() {
	for k, items := range t.itemsByFlow {
		items = items[:0]
		t.itemsPool = append(t.itemsPool, items)
		delete(t.itemsByFlow, k)
	}
}

// udpFlowKey represents the key for a UDP flow.
type udpFlowKey struct {
	srcAddr, dstAddr [16]byte
	srcPort, dstPort uint16
	isV6             bool
}

// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
type udpGROTable struct {
	itemsByFlow map[udpFlowKey][]udpGROItem
	itemsPool   [][]udpGROItem
}

func newUDPGROTable() *udpGROTable {
	u := &udpGROTable{
		itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
		itemsPool:   make([][]udpGROItem, conn.IdealBatchSize),
	}
	for i := range u.itemsPool {
		u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
	}
	return u
}

func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
	key := udpFlowKey{}
	addrSize := dstAddrOffset - srcAddrOffset
	copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
	copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
	key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
	key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
	key.isV6 = addrSize == 16
	return key
}

// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
	key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
	items, ok := u.itemsByFlow[key]
	if ok {
		return items, ok
	}
	// TODO: insert() performs another map lookup. This could be rearranged to avoid.
	u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
	return nil, false
}

// insert an item in the table for the provided packet and packet metadata.
func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
	key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
	item := udpGROItem{
		key:              key,
		bufsIndex:        uint16(bufsIndex),
		gsoSize:          uint16(len(pkt[udphOffset+udphLen:])),
		iphLen:           uint8(udphOffset),
		cSumKnownInvalid: cSumKnownInvalid,
	}
	items, ok := u.itemsByFlow[key]
	if !ok {
		items = u.newItems()
	}
	items = append(items, item)
	u.itemsByFlow[key] = items
}

func (u *udpGROTable) updateAt(item udpGROItem, i int) {
	items, _ := u.itemsByFlow[item.key]
	items[i] = item
}

// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type udpGROItem struct {
	key              udpFlowKey
	bufsIndex        uint16 // the index into the original bufs slice
	numMerged        uint16 // the number of packets merged into this item
	gsoSize          uint16 // payload size
	iphLen           uint8  // ip header len
	cSumKnownInvalid bool   // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
}

func (u *udpGROTable) newItems() []udpGROItem {
	var items []udpGROItem
	items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
	return items
}

func (u *udpGROTable) reset() {
	for k, items := range u.itemsByFlow {
		items = items[:0]
		u.itemsPool = append(u.itemsPool, items)
		delete(u.itemsByFlow, k)
	}
}
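Both GRO tables recycle their item slices through itemsPool rather than allocating per batch: reset() hands length-zero slices back to the pool and newItems() pops them off again, so the backing arrays are reused across vectors of packets. A minimal sketch of that slice-backed free list in isolation (a hypothetical pool type, for illustration only):

package main

import "fmt"

type pool struct{ free [][]int }

// get pops a recycled slice off the free list, or allocates a fresh one.
func (p *pool) get() []int {
	n := len(p.free)
	if n == 0 {
		return make([]int, 0, 8)
	}
	var s []int
	s, p.free = p.free[n-1], p.free[:n-1]
	return s
}

// put truncates s to zero length and returns it to the free list.
func (p *pool) put(s []int) { p.free = append(p.free, s[:0]) }

func main() {
	p := &pool{}
	s := p.get()
	s = append(s, 1, 2, 3)
	p.put(s)
	t := p.get() // reuses s's backing array
	fmt.Println(len(t), cap(t) >= 3) // 0 true
}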

// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int

const (
	coalescePrepend     canCoalesce = -1
	coalesceUnavailable canCoalesce = 0
	coalesceAppend      canCoalesce = 1
)

// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
// meet all requirements to be merged as part of a GRO operation, otherwise it
// returns false.
func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
	if len(pktA) < 9 || len(pktB) < 9 {
		return false
	}
	if pktA[0]>>4 == 6 {
		if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
			// cannot coalesce with unequal Traffic class values
			return false
		}
		if pktA[7] != pktB[7] {
			// cannot coalesce with unequal Hop limit values
			return false
		}
	} else {
		if pktA[1] != pktB[1] {
			// cannot coalesce with unequal ToS values
			return false
		}
		if pktA[6]>>5 != pktB[6]>>5 {
			// cannot coalesce with unequal DF or reserved bits. MF is checked
			// further up the stack.
			return false
		}
		if pktA[8] != pktB[8] {
			// cannot coalesce with unequal TTL values
			return false
		}
	}
	return true
}

// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
// packets involved in the current GRO evaluation. bufsOffset is the offset at
// which packet data begins within bufs.
func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
	pktTarget := bufs[item.bufsIndex][bufsOffset:]
	if !ipHeadersCanCoalesce(pkt, pktTarget) {
		return coalesceUnavailable
	}
	if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
		// A smaller than gsoSize packet has been appended previously.
		// Nothing can come after a smaller packet on the end.
		return coalesceUnavailable
	}
	if gsoSize > item.gsoSize {
		// We cannot have a larger packet following a smaller one.
		return coalesceUnavailable
	}
	return coalesceAppend
}

// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
	pktTarget := bufs[item.bufsIndex][bufsOffset:]
	if tcphLen != item.tcphLen {
		// cannot coalesce with unequal tcp options len
		return coalesceUnavailable
	}
	if tcphLen > 20 {
		if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
			// cannot coalesce with unequal tcp options
			return coalesceUnavailable
		}
	}
	if !ipHeadersCanCoalesce(pkt, pktTarget) {
		return coalesceUnavailable
	}
	// seq adjacency
	lhsLen := item.gsoSize
	lhsLen += item.numMerged * item.gsoSize
	if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
		if item.pshSet {
			// We cannot append to a segment that has the PSH flag set, PSH
			// can only be set on the final segment in a reassembled group.
			return coalesceUnavailable
		}
		if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
			// A smaller than gsoSize packet has been appended previously.
			// Nothing can come after a smaller packet on the end.
			return coalesceUnavailable
		}
		if gsoSize > item.gsoSize {
			// We cannot have a larger packet following a smaller one.
			return coalesceUnavailable
		}
		return coalesceAppend
	} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
		if pshSet {
			// We cannot prepend with a segment that has the PSH flag set, PSH
			// can only be set on the final segment in a reassembled group.
			return coalesceUnavailable
		}
		if gsoSize < item.gsoSize {
			// We cannot have a larger packet following a smaller one.
			return coalesceUnavailable
		}
		if gsoSize > item.gsoSize && item.numMerged > 0 {
			// There's at least one previous merge, and we're larger than all
			// previous. This would put multiple smaller packets on the end.
			return coalesceUnavailable
		}
		return coalescePrepend
	}
	return coalesceUnavailable
}
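The append/prepend tests above compare 32-bit sequence numbers with plain unsigned addition, which is what makes them hold across sequence-number wraparound: modulo-2^32 arithmetic needs no special casing. A tiny demonstration:

package main

import "fmt"

func main() {
	// Adjacency checks like seq == item.sentSeq+uint32(lhsLen) survive
	// wraparound because uint32 addition is modulo 2^32.
	var sentSeq uint32 = 0xFFFFFF00
	var lhsLen uint32 = 0x200 // 512 bytes already queued for this item
	next := sentSeq + lhsLen  // wraps to 0x00000100
	fmt.Printf("%#x\n", next) // prints 0x100
}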

func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
	srcAddrAt := ipv4SrcAddrOffset
	addrSize := 4
	if isV6 {
		srcAddrAt = ipv6SrcAddrOffset
		addrSize = 16
	}
	lenForPseudo := uint16(len(pkt) - int(iphLen))
	cSum := PseudoHeaderChecksum(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
	return ^Checksum(pkt[iphLen:], cSum) == 0
}

// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int

const (
	coalesceInsufficientCap coalesceResult = iota
	coalescePSHEnding
	coalesceItemInvalidCSum
	coalescePktInvalidCSum
	coalesceSuccess
)

// coalesceUDPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome.
func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
	headersLen := item.iphLen + udphLen
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)

	if cap(pktHead)-bufsOffset < coalescedLen {
		// We don't want to allocate a new underlying array if capacity is
		// too small.
		return coalesceInsufficientCap
	}
	if item.numMerged == 0 {
		if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
			return coalesceItemInvalidCSum
		}
	}
	if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
		return coalescePktInvalidCSum
	}
	extendBy := len(pkt) - int(headersLen)
	bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
	copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])

	item.numMerged++
	return coalesceSuccess
}

// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	var pktHead []byte // the packet that will end up at the front
	headersLen := item.iphLen + item.tcphLen
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)

	// Copy data
	if mode == coalescePrepend {
		pktHead = pkt
		if cap(pkt)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if pshSet {
			return coalescePSHEnding
		}
		if item.numMerged == 0 {
			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
			return coalescePktInvalidCSum
		}
		item.sentSeq = seq
		extendBy := coalescedLen - len(pktHead)
		bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
		copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
		// Flip the slice headers in bufs as part of prepend. The index of item
		// is already being tracked for writing.
		bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
	} else {
		pktHead = bufs[item.bufsIndex][bufsOffset:]
		if cap(pktHead)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if item.numMerged == 0 {
			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
			return coalescePktInvalidCSum
		}
		if pshSet {
			// We are appending a segment with PSH set.
			item.pshSet = pshSet
			pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
		}
		extendBy := len(pkt) - int(headersLen)
		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
	}

	if gsoSize > item.gsoSize {
		item.gsoSize = gsoSize
	}

	item.numMerged++
	return coalesceSuccess
}

const (
	ipv4FlagMoreFragments uint8 = 0x20
)

const (
	maxUint16 = 1<<16 - 1
)

type groResult int

const (
	groResultNoop groResult = iota
	groResultTableInsert
	groResultCoalesced
)

// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return groResultNoop
	}
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return groResultNoop
		}
	} else {
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return groResultNoop
		}
	}
	if len(pkt) < iphLen {
		return groResultNoop
	}
	tcphLen := int((pkt[iphLen+12] >> 4) * 4)
	if tcphLen < 20 || tcphLen > 60 {
		return groResultNoop
	}
	if len(pkt) < iphLen+tcphLen {
		return groResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return groResultNoop
		}
	}
	tcpFlags := pkt[iphLen+tcpFlagsOffset]
	var pshSet bool
	// not a candidate if any non-ACK flags (except PSH+ACK) are set
	if tcpFlags != tcpFlagACK {
		if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
			return groResultNoop
		}
		pshSet = true
	}
	gsoSize := uint16(len(pkt) - tcphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return groResultNoop
	}
	seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	if !existing {
		return groResultTableInsert
	}
	for i := len(items) - 1; i >= 0; i-- {
		// In the best case of packets arriving in order iterating in reverse is
		// more efficient if there are multiple items for a given flow. This
		// also enables a natural table.deleteAt() in the
		// coalesceItemInvalidCSum case without the need for index tracking.
		// This algorithm makes a best effort to coalesce in the event of
		// unordered packets, where pkt may land anywhere in items from a
		// sequence number perspective, however once an item is inserted into
		// the table it is never compared across other items later.
		item := items[i]
		can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
		if can != coalesceUnavailable {
			result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
			switch result {
			case coalesceSuccess:
				table.updateAt(item, i)
				return groResultCoalesced
			case coalesceItemInvalidCSum:
				// delete the item with an invalid csum
				table.deleteAt(item.key, i)
			case coalescePktInvalidCSum:
				// no point in inserting an item that we can't coalesce
				return groResultNoop
			default:
			}
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	return groResultTableInsert
}

// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + item.tcphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 16,
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				if item.key.isV6 {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
					pkt[10], pkt[11] = 0, 0
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^Checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Calculate the pseudo header checksum and place it at the TCP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the tcp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if item.key.isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum))
			} else {
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}
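Note that the accounting pass stores only the pseudo-header sum, not a finished checksum, at the TCP checksum offset: with VIRTIO_NET_HDR_F_NEEDS_CSUM set, the kernel completes the computation over the header and payload. A standalone sketch of that IPv4 pseudo-header sum, with hypothetical helper and parameter names (the vendored PseudoHeaderChecksum has a different signature):

package main

import "fmt"

// pseudoHeaderSum returns the folded, non-complemented ones' complement sum
// of the IPv4 pseudo header (src, dst, zero+proto, transport length). This is
// the value CHECKSUM_PARTIAL-style offload expects to find pre-seeded.
func pseudoHeaderSum(proto uint8, src, dst [4]byte, length uint16) uint16 {
	var sum uint32
	for i := 0; i < 4; i += 2 {
		sum += uint32(src[i])<<8 | uint32(src[i+1])
		sum += uint32(dst[i])<<8 | uint32(dst[i+1])
	}
	sum += uint32(proto)
	sum += uint32(length)
	for sum>>16 != 0 {
		sum = sum&0xFFFF + sum>>16
	}
	return uint16(sum)
}

func main() {
	src := [4]byte{192, 168, 0, 1}
	dst := [4]byte{192, 168, 0, 199}
	fmt.Printf("%#04x\n", pseudoHeaderSum(6, src, dst, 1460)) // proto 6 = TCP
}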

// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + udphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 6,
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
				if item.key.isV6 {
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					pkt[10], pkt[11] = 0, 0
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^Checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Recalculate the UDP len field value
				binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))

				// Calculate the pseudo header checksum and place it at the UDP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the udp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if item.key.isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum))
			} else {
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}

type groCandidateType uint8

const (
	notGROCandidate groCandidateType = iota
	tcp4GROCandidate
	tcp6GROCandidate
	udp4GROCandidate
	udp6GROCandidate
)

func packetIsGROCandidate(b []byte, gro groDisablementFlags) groCandidateType {
	if len(b) < 28 {
		return notGROCandidate
	}
	if b[0]>>4 == 4 {
		if b[0]&0x0F != 5 {
			// IPv4 packets w/IP options do not coalesce
			return notGROCandidate
		}
		if b[9] == unix.IPPROTO_TCP && len(b) >= 40 && gro.canTCPGRO() {
			return tcp4GROCandidate
		}
		if b[9] == unix.IPPROTO_UDP && gro.canUDPGRO() {
			return udp4GROCandidate
		}
	} else if b[0]>>4 == 6 {
		if b[6] == unix.IPPROTO_TCP && len(b) >= 60 && gro.canTCPGRO() {
			return tcp6GROCandidate
		}
		if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && gro.canUDPGRO() {
			return udp6GROCandidate
		}
	}
	return notGROCandidate
}
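packetIsGROCandidate keys everything off the first byte of the packet: the high nibble is the IP version, and for IPv4 the low nibble is the header length in 32-bit words (5 means a 20-byte header with no options). A tiny illustration of that decoding:

package main

import "fmt"

func version(b []byte) int { return int(b[0] >> 4) }

func main() {
	v4 := []byte{0x45} // version 4, IHL 5 (no IP options)
	v6 := []byte{0x60} // version 6
	fmt.Println(version(v4), v4[0]&0x0F) // 4 5
	fmt.Println(version(v6))             // 6
}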

const (
	udphLen = 8
)

// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return groResultNoop
	}
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return groResultNoop
		}
	} else {
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return groResultNoop
		}
	}
	if len(pkt) < iphLen {
		return groResultNoop
	}
	if len(pkt) < iphLen+udphLen {
		return groResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return groResultNoop
		}
	}
	gsoSize := uint16(len(pkt) - udphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return groResultNoop
	}
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
	if !existing {
		return groResultTableInsert
	}
	// With UDP we only check the last item, otherwise we could reorder packets
	// for a given flow. We must also always insert a new item, or successfully
	// coalesce with an existing item, for the same reason.
	item := items[len(items)-1]
	can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
	var pktCSumKnownInvalid bool
	if can == coalesceAppend {
		result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
		switch result {
		case coalesceSuccess:
			table.updateAt(item, len(items)-1)
			return groResultCoalesced
		case coalesceItemInvalidCSum:
			// If the existing item has an invalid csum we take no action. A new
			// item will be stored after it, and the existing item will never be
			// revisited as part of future coalescing candidacy checks.
		case coalescePktInvalidCSum:
			// We must insert a new item, but we also mark it as invalid csum
			// to prevent a repeat checksum validation.
			pktCSumKnownInvalid = true
		default:
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
	return groResultTableInsert
}

// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets. gro indicates if TCP and UDP GRO
// are supported/enabled.
func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error {
	for i := range bufs {
		if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
			return errors.New("invalid offset")
		}
		var result groResult
		switch packetIsGROCandidate(bufs[i][offset:], gro) {
		case tcp4GROCandidate:
			result = tcpGRO(bufs, offset, i, tcpTable, false)
		case tcp6GROCandidate:
			result = tcpGRO(bufs, offset, i, tcpTable, true)
		case udp4GROCandidate:
			result = udpGRO(bufs, offset, i, udpTable, false)
		case udp6GROCandidate:
			result = udpGRO(bufs, offset, i, udpTable, true)
		}
		switch result {
		case groResultNoop:
			hdr := virtioNetHdr{}
			err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
			if err != nil {
				return err
			}
			fallthrough
		case groResultTableInsert:
			*toWrite = append(*toWrite, i)
		}
	}
	errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
	errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
	return errors.Join(errTCP, errUDP)
}
24
vendor/github.com/tailscale/wireguard-go/tun/operateonfd.go
generated
vendored
Normal file
@@ -0,0 +1,24 @@
//go:build darwin || freebsd

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tun

import (
	"fmt"
)

func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) {
	sysconn, err := tun.tunFile.SyscallConn()
	if err != nil {
		tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error())
		return
	}
	err = sysconn.Control(fn)
	if err != nil {
		tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error())
	}
}
76
vendor/github.com/tailscale/wireguard-go/tun/tun.go
generated
vendored
Normal file
@@ -0,0 +1,76 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tun

import (
	"os"
)

type Event int

const (
	EventUp = 1 << iota
	EventDown
	EventMTUUpdate
)

type Device interface {
	// File returns the file descriptor of the device.
	File() *os.File

	// Read one or more packets from the Device (without any additional headers).
	// On a successful read it returns the number of packets read, and sets
	// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
	// A nonzero offset can be used to instruct the Device on where to begin
	// reading into each element of the bufs slice.
	Read(bufs [][]byte, sizes []int, offset int) (n int, err error)

	// Write one or more packets to the device (without any additional headers).
	// On a successful write it returns the number of packets written. A nonzero
	// offset can be used to instruct the Device on where to begin writing from
	// each packet contained within the bufs slice.
	Write(bufs [][]byte, offset int) (int, error)

	// MTU returns the MTU of the Device.
	MTU() (int, error)

	// Name returns the current name of the Device.
	Name() (string, error)

	// Events returns a channel of type Event, which is fed Device events.
	Events() <-chan Event

	// Close stops the Device and closes the Event channel.
	Close() error

	// BatchSize returns the preferred/max number of packets that can be read or
	// written in a single read/write call. BatchSize must not change over the
	// lifetime of a Device.
	BatchSize() int
}

// GRODevice is a Device extended with methods for disabling GRO. Certain OS
// versions may have offload bugs. Where these bugs negatively impact throughput
// or break connectivity entirely we can use these methods to disable the
// related offload.
//
// Linux has the following known GRO bugs.
//
// torvalds/linux@e269d79c7d35aa3808b1f3c1737d63dab504ddc8 broke virtio_net
// TCP & UDP GRO causing GRO writes to return EINVAL. The bug was then
// resolved later in
// torvalds/linux@89add40066f9ed9abe5f7f886fe5789ff7e0c50e. The offending
// commit was pulled into various LTS releases.
//
// UDP GRO writes end up blackholing/dropping packets destined for a
// vxlan/geneve interface on kernel versions prior to 6.8.5.
type GRODevice interface {
	Device
	// DisableUDPGRO disables UDP GRO if it is enabled.
	DisableUDPGRO()
	// DisableTCPGRO disables TCP GRO if it is enabled.
	DisableTCPGRO()
}
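The Read contract above (parallel bufs and sizes slices, a caller-chosen offset, at most BatchSize packets per call) might be driven like the sketch below. This is illustrative only: it assumes a Linux host (the interface name format differs per platform), root privileges, and an arbitrary interface name and offset.

package main

import (
	"log"

	"github.com/tailscale/wireguard-go/tun"
)

func main() {
	dev, err := tun.CreateTUN("tun-demo", 1420) // requires root / CAP_NET_ADMIN
	if err != nil {
		log.Fatal(err)
	}
	defer dev.Close()

	const offset = 16 // headroom; must satisfy the platform's minimum
	batch := dev.BatchSize()
	bufs := make([][]byte, batch)
	sizes := make([]int, batch) // len(sizes) must be >= len(bufs)
	for i := range bufs {
		bufs[i] = make([]byte, offset+65535)
	}
	n, err := dev.Read(bufs, sizes, offset)
	if err != nil {
		log.Fatal(err)
	}
	for i := 0; i < n; i++ {
		log.Printf("packet %d: %d bytes", i, sizes[i]) // payload at bufs[i][offset:offset+sizes[i]]
	}
}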
336
vendor/github.com/tailscale/wireguard-go/tun/tun_darwin.go
generated
vendored
Normal file
@@ -0,0 +1,336 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tun

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
	"syscall"
	"time"
	"unsafe"

	"golang.org/x/sys/unix"
)

const utunControlName = "com.apple.net.utun_control"

type NativeTun struct {
	name        string
	tunFile     *os.File
	events      chan Event
	errors      chan error
	routeSocket int
	closeOnce   sync.Once
}

func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
	for i := 0; i < 20; i++ {
		iface, err = net.InterfaceByIndex(index)
		if err != nil && errors.Is(err, unix.ENOMEM) {
			time.Sleep(time.Duration(i) * time.Second / 3)
			continue
		}
		return iface, err
	}
	return nil, err
}

func (tun *NativeTun) routineRouteListener(tunIfindex int) {
	var (
		statusUp  bool
		statusMTU int
	)

	defer close(tun.events)

	data := make([]byte, os.Getpagesize())
	for {
	retry:
		n, err := unix.Read(tun.routeSocket, data)
		if err != nil {
			if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
				goto retry
			}
			tun.errors <- err
			return
		}

		if n < 14 {
			continue
		}

		if data[3 /* type */] != unix.RTM_IFINFO {
			continue
		}
		ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */])))
		if ifindex != tunIfindex {
			continue
		}

		iface, err := retryInterfaceByIndex(ifindex)
		if err != nil {
			tun.errors <- err
			return
		}

		// Up / Down event
		up := (iface.Flags & net.FlagUp) != 0
		if up != statusUp && up {
			tun.events <- EventUp
		}
		if up != statusUp && !up {
			tun.events <- EventDown
		}
		statusUp = up

		// MTU changes
		if iface.MTU != statusMTU {
			tun.events <- EventMTUUpdate
		}
		statusMTU = iface.MTU
	}
}

func CreateTUN(name string, mtu int) (Device, error) {
	ifIndex := -1
	if name != "utun" {
		_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
		if err != nil || ifIndex < 0 {
			return nil, fmt.Errorf("Interface name must be utun[0-9]*")
		}
	}

	fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
	if err != nil {
		return nil, err
	}

	ctlInfo := &unix.CtlInfo{}
	copy(ctlInfo.Name[:], []byte(utunControlName))
	err = unix.IoctlCtlInfo(fd, ctlInfo)
	if err != nil {
		unix.Close(fd)
		return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
	}

	sc := &unix.SockaddrCtl{
		ID:   ctlInfo.Id,
		Unit: uint32(ifIndex) + 1,
	}

	err = unix.Connect(fd, sc)
	if err != nil {
		unix.Close(fd)
		return nil, err
	}

	err = unix.SetNonblock(fd, true)
	if err != nil {
		unix.Close(fd)
		return nil, err
	}
	tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)

	if err == nil && name == "utun" {
		fname := os.Getenv("WG_TUN_NAME_FILE")
		if fname != "" {
			os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
		}
	}

	return tun, err
}

func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
	tun := &NativeTun{
		tunFile: file,
		events:  make(chan Event, 10),
		errors:  make(chan error, 5),
	}

	name, err := tun.Name()
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	tunIfindex, err := func() (int, error) {
		iface, err := net.InterfaceByName(name)
		if err != nil {
			return -1, err
		}
		return iface.Index, nil
	}()
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	go tun.routineRouteListener(tunIfindex)

	if mtu > 0 {
		err = tun.setMTU(mtu)
		if err != nil {
			tun.Close()
			return nil, err
		}
	}

	return tun, nil
}

func (tun *NativeTun) Name() (string, error) {
	var err error
	tun.operateOnFd(func(fd uintptr) {
		tun.name, err = unix.GetsockoptString(
			int(fd),
			2, /* #define SYSPROTO_CONTROL 2 */
			2, /* #define UTUN_OPT_IFNAME 2 */
		)
	})

	if err != nil {
		return "", fmt.Errorf("GetSockoptString: %w", err)
	}

	return tun.name, nil
}

func (tun *NativeTun) File() *os.File {
	return tun.tunFile
}

func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}

func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
	// TODO: the BSDs look very similar in Read() and Write(). They should be
	// collapsed, with platform-specific files containing the varying parts of
	// their implementations.
	select {
	case err := <-tun.errors:
		return 0, err
	default:
		buf := bufs[0][offset-4:]
		n, err := tun.tunFile.Read(buf[:])
		if n < 4 {
			return 0, err
		}
		sizes[0] = n - 4
		return 1, err
	}
}

func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
	if offset < 4 {
		return 0, io.ErrShortBuffer
	}
	for i, buf := range bufs {
		buf = buf[offset-4:]
		buf[0] = 0x00
		buf[1] = 0x00
		buf[2] = 0x00
		switch buf[4] >> 4 {
		case 4:
			buf[3] = unix.AF_INET
		case 6:
			buf[3] = unix.AF_INET6
		default:
			return i, unix.EAFNOSUPPORT
		}
		if _, err := tun.tunFile.Write(buf); err != nil {
			return i, err
		}
	}
	return len(bufs), nil
}
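Write prepends the 4-byte protocol-family header that the darwin utun device expects: three zero bytes followed by the address family, chosen from the packet's IP version nibble. A standalone sketch of just that header derivation, with a hypothetical helper name (builds on unix platforms):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// utunHeader derives the 4-byte utun family header for a raw IP packet,
// mirroring the version-nibble switch in the Write method above.
func utunHeader(packet []byte) ([4]byte, error) {
	var h [4]byte // first three bytes stay zero
	switch packet[0] >> 4 {
	case 4:
		h[3] = unix.AF_INET
	case 6:
		h[3] = unix.AF_INET6
	default:
		return h, unix.EAFNOSUPPORT
	}
	return h, nil
}

func main() {
	v4 := []byte{0x45, 0x00} // start of an IPv4 header
	h, _ := utunHeader(v4)
	fmt.Println(h) // [0 0 0 2] where AF_INET == 2
}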

func (tun *NativeTun) Close() error {
	var err1, err2 error
	tun.closeOnce.Do(func() {
		err1 = tun.tunFile.Close()
		if tun.routeSocket != -1 {
			unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
			err2 = unix.Close(tun.routeSocket)
		} else if tun.events != nil {
			close(tun.events)
		}
	})
	if err1 != nil {
		return err1
	}
	return err2
}

func (tun *NativeTun) setMTU(n int) error {
	fd, err := socketCloexec(
		unix.AF_INET,
		unix.SOCK_DGRAM,
		0,
	)
	if err != nil {
		return err
	}

	defer unix.Close(fd)

	var ifr unix.IfreqMTU
	copy(ifr.Name[:], tun.name)
	ifr.MTU = int32(n)
	err = unix.IoctlSetIfreqMTU(fd, &ifr)
	if err != nil {
		return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err)
	}

	return nil
}

func (tun *NativeTun) MTU() (int, error) {
	fd, err := socketCloexec(
		unix.AF_INET,
		unix.SOCK_DGRAM,
		0,
	)
	if err != nil {
		return 0, err
	}

	defer unix.Close(fd)

	ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name)
	if err != nil {
		return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err)
	}

	return int(ifr.MTU), nil
}

func (tun *NativeTun) BatchSize() int {
	return 1
}

func socketCloexec(family, sotype, proto int) (fd int, err error) {
	// See go/src/net/sys_cloexec.go for background.
	syscall.ForkLock.RLock()
	defer syscall.ForkLock.RUnlock()

	fd, err = unix.Socket(family, sotype, proto)
	if err == nil {
		unix.CloseOnExec(fd)
	}
	return
}
435
vendor/github.com/tailscale/wireguard-go/tun/tun_freebsd.go
generated
vendored
Normal file
@@ -0,0 +1,435 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tun

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
	"syscall"
	"unsafe"

	"golang.org/x/sys/unix"
)

const (
	_TUNSIFHEAD = 0x80047460
	_TUNSIFMODE = 0x8004745e
	_TUNGIFNAME = 0x4020745d
	_TUNSIFPID  = 0x2000745f

	_SIOCGIFINFO_IN6        = 0xc048696c
	_SIOCSIFINFO_IN6        = 0xc048696d
	_ND6_IFF_AUTO_LINKLOCAL = 0x20
	_ND6_IFF_NO_DAD         = 0x100
)

// Iface requests with just the name
type ifreqName struct {
	Name [unix.IFNAMSIZ]byte
	_    [16]byte
}

// Iface requests with a pointer
type ifreqPtr struct {
	Name [unix.IFNAMSIZ]byte
	Data uintptr
	_    [16 - unsafe.Sizeof(uintptr(0))]byte
}

// Iface requests with MTU
type ifreqMtu struct {
	Name [unix.IFNAMSIZ]byte
	MTU  uint32
	_    [12]byte
}

// ND6 flag manipulation
type nd6Req struct {
	Name          [unix.IFNAMSIZ]byte
	Linkmtu       uint32
	Maxmtu        uint32
	Basereachable uint32
	Reachable     uint32
	Retrans       uint32
	Flags         uint32
	Recalctm      int
	Chlim         uint8
	Initialized   uint8
	Randomseed0   [8]byte
	Randomseed1   [8]byte
	Randomid      [8]byte
}

type NativeTun struct {
	name        string
	tunFile     *os.File
	events      chan Event
	errors      chan error
	routeSocket int
	closeOnce   sync.Once
}

func (tun *NativeTun) routineRouteListener(tunIfindex int) {
	var (
		statusUp  bool
		statusMTU int
	)

	defer close(tun.events)

	data := make([]byte, os.Getpagesize())
	for {
	retry:
		n, err := unix.Read(tun.routeSocket, data)
		if err != nil {
			if errors.Is(err, syscall.EINTR) {
				goto retry
			}
			tun.errors <- err
			return
		}

		if n < 14 {
			continue
		}

		if data[3 /* type */] != unix.RTM_IFINFO {
			continue
		}
		ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */])))
		if ifindex != tunIfindex {
			continue
		}

		iface, err := net.InterfaceByIndex(ifindex)
		if err != nil {
			tun.errors <- err
			return
		}

		// Up / Down event
		up := (iface.Flags & net.FlagUp) != 0
		if up != statusUp && up {
			tun.events <- EventUp
		}
		if up != statusUp && !up {
			tun.events <- EventDown
		}
		statusUp = up

		// MTU changes
		if iface.MTU != statusMTU {
			tun.events <- EventMTUUpdate
		}
		statusMTU = iface.MTU
	}
}

func tunName(fd uintptr) (string, error) {
	var ifreq ifreqName
	_, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq)))
	if err != 0 {
		return "", err
	}
	return unix.ByteSliceToString(ifreq.Name[:]), nil
}

// Destroy a named system interface
func tunDestroy(name string) error {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
	if err != nil {
		return err
	}
	defer unix.Close(fd)

	var ifr [32]byte
	copy(ifr[:], name)
	_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0])))
	if errno != 0 {
		return fmt.Errorf("failed to destroy interface %s: %w", name, errno)
	}

	return nil
}

func CreateTUN(name string, mtu int) (Device, error) {
	if len(name) > unix.IFNAMSIZ-1 {
		return nil, errors.New("interface name too long")
	}

	// See if interface already exists
	iface, _ := net.InterfaceByName(name)
	if iface != nil {
		return nil, fmt.Errorf("interface %s already exists", name)
	}

	tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0)
	if err != nil {
		return nil, err
	}

	tun := NativeTun{tunFile: tunFile}
	var assignedName string
	tun.operateOnFd(func(fd uintptr) {
		assignedName, err = tunName(fd)
	})
	if err != nil {
		tunFile.Close()
		return nil, err
	}

	// Enable ifhead mode, otherwise tun will complain if it gets a non-AF_INET packet
	ifheadmode := 1
	var errno syscall.Errno
	tun.operateOnFd(func(fd uintptr) {
		_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode)))
	})

	if errno != 0 {
		tunFile.Close()
		tunDestroy(assignedName)
		return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno)
	}

	// Get out of PTP mode.
	ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST
	tun.operateOnFd(func(fd uintptr) {
		_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags)))
	})

	if errno != 0 {
		tunFile.Close()
		tunDestroy(assignedName)
		return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno)
	}

	// Disable link-local v6, not just because WireGuard doesn't do that anyway, but
	// also because there are serious races with attaching and detaching LLv6 addresses
	// in relation to interface lifetime within the FreeBSD kernel.
	confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
	if err != nil {
		tunFile.Close()
		tunDestroy(assignedName)
		return nil, err
	}
	defer unix.Close(confd6)
	var ndireq nd6Req
	copy(ndireq.Name[:], assignedName)
	_, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
	if errno != 0 {
		tunFile.Close()
		tunDestroy(assignedName)
		return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno)
	}
	ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL
	ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD
	_, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
	if errno != 0 {
		tunFile.Close()
		tunDestroy(assignedName)
		return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno)
	}

	if name != "" {
		confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
		if err != nil {
			tunFile.Close()
			tunDestroy(assignedName)
			return nil, err
		}
		defer unix.Close(confd)
		var newnp [unix.IFNAMSIZ]byte
		copy(newnp[:], name)
		var ifr ifreqPtr
		copy(ifr.Name[:], assignedName)
		ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
		_, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr)))
		if errno != 0 {
			tunFile.Close()
			tunDestroy(assignedName)
			return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno)
		}
	}

	return CreateTUNFromFile(tunFile, mtu)
}

func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
	tun := &NativeTun{
		tunFile: file,
		events:  make(chan Event, 10),
		errors:  make(chan error, 1),
	}

	var errno syscall.Errno
	tun.operateOnFd(func(fd uintptr) {
		_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0))
	})
	if errno != 0 {
		tun.tunFile.Close()
		return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno)
	}

	name, err := tun.Name()
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	tunIfindex, err := func() (int, error) {
		iface, err := net.InterfaceByName(name)
		if err != nil {
			return -1, err
		}
		return iface.Index, nil
	}()
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	go tun.routineRouteListener(tunIfindex)

	err = tun.setMTU(mtu)
	if err != nil {
		tun.Close()
		return nil, err
	}

	return tun, nil
}

func (tun *NativeTun) Name() (string, error) {
	var name string
	var err error
	tun.operateOnFd(func(fd uintptr) {
		name, err = tunName(fd)
	})
	if err != nil {
		return "", err
	}
	tun.name = name
	return name, nil
}

func (tun *NativeTun) File() *os.File {
	return tun.tunFile
}

func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}

func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
	select {
	case err := <-tun.errors:
		return 0, err
	default:
		buf := bufs[0][offset-4:]
		n, err := tun.tunFile.Read(buf[:])
		if n < 4 {
			return 0, err
		}
		sizes[0] = n - 4
		return 1, err
	}
}

func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
	if offset < 4 {
		return 0, io.ErrShortBuffer
	}
	for i, buf := range bufs {
		buf = buf[offset-4:]
		if len(buf) < 5 {
			return i, io.ErrShortBuffer
		}
		buf[0] = 0x00
		buf[1] = 0x00
		buf[2] = 0x00
		switch buf[4] >> 4 {
		case 4:
			buf[3] = unix.AF_INET
		case 6:
			buf[3] = unix.AF_INET6
		default:
			return i, unix.EAFNOSUPPORT
		}
		if _, err := tun.tunFile.Write(buf); err != nil {
			return i, err
		}
	}
	return len(bufs), nil
}

func (tun *NativeTun) Close() error {
	var err1, err2, err3 error
	tun.closeOnce.Do(func() {
		err1 = tun.tunFile.Close()
		err2 = tunDestroy(tun.name)
		if tun.routeSocket != -1 {
			unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
			err3 = unix.Close(tun.routeSocket)
			tun.routeSocket = -1
		} else if tun.events != nil {
			close(tun.events)
		}
	})
	if err1 != nil {
		return err1
	}
	if err2 != nil {
		return err2
	}
	return err3
}

func (tun *NativeTun) setMTU(n int) error {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
	if err != nil {
		return err
	}
	defer unix.Close(fd)

	var ifr ifreqMtu
	copy(ifr.Name[:], tun.name)
	ifr.MTU = uint32(n)
	_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr)))
	if errno != 0 {
		return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno)
	}
	return nil
}

func (tun *NativeTun) MTU() (int, error) {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
	if err != nil {
		return 0, err
	}
	defer unix.Close(fd)

	var ifr ifreqMtu
	copy(ifr.Name[:], tun.name)
	_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr)))
	if errno != 0 {
		return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno)
	}
	return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}

func (tun *NativeTun) BatchSize() int {
	return 1
}
660
vendor/github.com/tailscale/wireguard-go/tun/tun_linux.go
generated
vendored
Normal file
@@ -0,0 +1,660 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tun

/* Implementation of the TUN device interface for linux
 */

import (
	"errors"
	"fmt"
	"os"
	"sync"
	"syscall"
	"time"
	"unsafe"

	"github.com/tailscale/wireguard-go/conn"
	"github.com/tailscale/wireguard-go/rwcancel"
	"golang.org/x/sys/unix"
)

const (
	cloneDevicePath = "/dev/net/tun"
	ifReqSize       = unix.IFNAMSIZ + 64
)

type NativeTun struct {
	tunFile                 *os.File
	index                   int32      // interface index
	errors                  chan error // async error handling
	events                  chan Event // device related events
	netlinkSock             int
	netlinkCancel           *rwcancel.RWCancel
	hackListenerClosed      sync.Mutex
	statusListenersShutdown chan struct{}
	batchSize               int
	vnetHdr                 bool

	closeOnce sync.Once

	nameOnce  sync.Once // guards calling initNameCache, which sets following fields
	nameCache string    // name of interface
	nameErr   error

	readOpMu sync.Mutex                    // readOpMu guards readBuff
	readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr is set, every read() is prefixed by a virtioNetHdr

	writeOpMu   sync.Mutex // writeOpMu guards the following fields
	toWrite     []int
	tcpGROTable *tcpGROTable
	udpGROTable *udpGROTable
	gro         groDisablementFlags
}

type groDisablementFlags int

const (
	tcpGRODisabled groDisablementFlags = 1 << iota
	udpGRODisabled
)

func (g *groDisablementFlags) disableTCPGRO() {
	*g |= tcpGRODisabled
}

func (g *groDisablementFlags) canTCPGRO() bool {
	return (*g)&tcpGRODisabled == 0
}

func (g *groDisablementFlags) disableUDPGRO() {
	*g |= udpGRODisabled
}

func (g *groDisablementFlags) canUDPGRO() bool {
	return (*g)&udpGRODisabled == 0
}
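
// Editorial sketch, not vendored code: groDisablementFlags is a small bit set,
// so independent capabilities can be cleared and queried without separate
// fields. For example, disabling UDP GRO leaves TCP GRO intact:
func groFlagsSketch() (tcpOK, udpOK bool) {
	var g groDisablementFlags
	g.disableUDPGRO()
	return g.canTCPGRO(), g.canUDPGRO() // true, false
}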

func (tun *NativeTun) File() *os.File {
	return tun.tunFile
}

func (tun *NativeTun) routineHackListener() {
	defer tun.hackListenerClosed.Unlock()
	/* This is needed for the detection to work across network namespaces
	 * If you are reading this and know a better method, please get in touch.
	 */
	last := 0
	const (
		up   = 1
		down = 2
	)
	for {
		sysconn, err := tun.tunFile.SyscallConn()
		if err != nil {
			return
		}
		err2 := sysconn.Control(func(fd uintptr) {
			_, err = unix.Write(int(fd), nil)
		})
		if err2 != nil {
			return
		}
		switch err {
		case unix.EINVAL:
			if last != up {
				// If the tunnel is up, it reports that write() is
				// allowed but we provided invalid data.
				tun.events <- EventUp
				last = up
			}
		case unix.EIO:
			if last != down {
				// If the tunnel is down, it reports that no I/O
				// is possible, without checking our provided data.
				tun.events <- EventDown
				last = down
			}
		default:
			return
		}
		select {
		case <-time.After(time.Second):
			// nothing
		case <-tun.statusListenersShutdown:
			return
		}
	}
}
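
// Editorial sketch, not vendored code: the core of the hack above is a
// zero-length write used purely for its errno. A standalone probe might look
// like this; it assumes the same EINVAL-means-up, EIO-means-down kernel
// behavior that routineHackListener relies on.
func tunUpProbeSketch(fd int) (up bool, ok bool) {
	_, err := unix.Write(fd, nil)
	switch err {
	case unix.EINVAL:
		return true, true // device accepted I/O but rejected empty data: up
	case unix.EIO:
		return false, true // device refuses I/O entirely: down
	default:
		return false, false // inconclusive
	}
}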

func createNetlinkSocket() (int, error) {
	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
	if err != nil {
		return -1, err
	}
	saddr := &unix.SockaddrNetlink{
		Family: unix.AF_NETLINK,
		Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
	}
	err = unix.Bind(sock, saddr)
	if err != nil {
		return -1, err
	}
	return sock, nil
}

func (tun *NativeTun) routineNetlinkListener() {
	defer func() {
		unix.Close(tun.netlinkSock)
		tun.hackListenerClosed.Lock()
		close(tun.events)
		tun.netlinkCancel.Close()
	}()

	for msg := make([]byte, 1<<16); ; {
		var err error
		var msgn int
		for {
			msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
			if err == nil || !rwcancel.RetryAfterError(err) {
				break
			}
			if !tun.netlinkCancel.ReadyRead() {
				tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
				return
			}
		}
		if err != nil {
			tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
			return
		}

		select {
		case <-tun.statusListenersShutdown:
			return
		default:
		}

		wasEverUp := false
		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {

			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))

			if int(hdr.Len) > len(remain) {
				break
			}

			switch hdr.Type {
			case unix.NLMSG_DONE:
				remain = []byte{}

			case unix.RTM_NEWLINK:
				info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
				remain = remain[hdr.Len:]

				if info.Index != tun.index {
					// not our interface
					continue
				}

				if info.Flags&unix.IFF_RUNNING != 0 {
					tun.events <- EventUp
					wasEverUp = true
				}

				if info.Flags&unix.IFF_RUNNING == 0 {
					// Don't emit EventDown before we've ever emitted EventUp.
					// This avoids a startup race with HackListener, which
					// might detect Up before we have finished reporting Down.
					if wasEverUp {
						tun.events <- EventDown
					}
				}

				tun.events <- EventMTUUpdate

			default:
				remain = remain[hdr.Len:]
			}
		}
	}
}

func getIFIndex(name string) (int32, error) {
	fd, err := unix.Socket(
		unix.AF_INET,
		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
		0,
	)
	if err != nil {
		return 0, err
	}

	defer unix.Close(fd)

	var ifr [ifReqSize]byte
	copy(ifr[:], name)
	_, _, errno := unix.Syscall(
		unix.SYS_IOCTL,
		uintptr(fd),
		uintptr(unix.SIOCGIFINDEX),
		uintptr(unsafe.Pointer(&ifr[0])),
	)

	if errno != 0 {
		return 0, errno
	}

	return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
}
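
// Editorial sketch, not vendored code: the raw ifreq buffer above (name in the
// first IFNAMSIZ bytes, result immediately after) can also be handled with the
// typed helpers from golang.org/x/sys/unix, as initFromFlags does later in
// this file. An equivalent index lookup:
func getIFIndexSketch(fd int, name string) (int32, error) {
	ifr, err := unix.NewIfreq(name)
	if err != nil {
		return 0, err
	}
	if err := unix.IoctlIfreq(fd, unix.SIOCGIFINDEX, ifr); err != nil {
		return 0, err
	}
	return int32(ifr.Uint32()), nil
}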

func (tun *NativeTun) setMTU(n int) error {
	name, err := tun.Name()
	if err != nil {
		return err
	}

	// open datagram socket
	fd, err := unix.Socket(
		unix.AF_INET,
		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
		0,
	)
	if err != nil {
		return err
	}

	defer unix.Close(fd)

	// do ioctl call
	var ifr [ifReqSize]byte
	copy(ifr[:], name)
	*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
	_, _, errno := unix.Syscall(
		unix.SYS_IOCTL,
		uintptr(fd),
		uintptr(unix.SIOCSIFMTU),
		uintptr(unsafe.Pointer(&ifr[0])),
	)

	if errno != 0 {
		return fmt.Errorf("failed to set MTU of TUN device: %w", errno)
	}

	return nil
}

func (tun *NativeTun) MTU() (int, error) {
	name, err := tun.Name()
	if err != nil {
		return 0, err
	}

	// open datagram socket
	fd, err := unix.Socket(
		unix.AF_INET,
		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
		0,
	)
	if err != nil {
		return 0, err
	}

	defer unix.Close(fd)

	// do ioctl call

	var ifr [ifReqSize]byte
	copy(ifr[:], name)
	_, _, errno := unix.Syscall(
		unix.SYS_IOCTL,
		uintptr(fd),
		uintptr(unix.SIOCGIFMTU),
		uintptr(unsafe.Pointer(&ifr[0])),
	)
	if errno != 0 {
		return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno)
	}

	return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
}

func (tun *NativeTun) Name() (string, error) {
	tun.nameOnce.Do(tun.initNameCache)
	return tun.nameCache, tun.nameErr
}

func (tun *NativeTun) initNameCache() {
	tun.nameCache, tun.nameErr = tun.nameSlow()
}

func (tun *NativeTun) nameSlow() (string, error) {
	sysconn, err := tun.tunFile.SyscallConn()
	if err != nil {
		return "", err
	}
	var ifr [ifReqSize]byte
	var errno syscall.Errno
	err = sysconn.Control(func(fd uintptr) {
		_, _, errno = unix.Syscall(
			unix.SYS_IOCTL,
			fd,
			uintptr(unix.TUNGETIFF),
			uintptr(unsafe.Pointer(&ifr[0])),
		)
	})
	if err != nil {
		return "", fmt.Errorf("failed to get name of TUN device: %w", err)
	}
	if errno != 0 {
		return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
	}
	return unix.ByteSliceToString(ifr[:]), nil
}

func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
	tun.writeOpMu.Lock()
	defer func() {
		tun.tcpGROTable.reset()
		tun.udpGROTable.reset()
		tun.writeOpMu.Unlock()
	}()
	var (
		errs  error
		total int
	)
	tun.toWrite = tun.toWrite[:0]
	if tun.vnetHdr {
		err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite)
		if err != nil {
			return 0, err
		}
		offset -= virtioNetHdrLen
	} else {
		for i := range bufs {
			tun.toWrite = append(tun.toWrite, i)
		}
	}
	for _, bufsI := range tun.toWrite {
		n, err := tun.tunFile.Write(bufs[bufsI][offset:])
		if errors.Is(err, syscall.EBADFD) {
			return total, os.ErrClosed
		}
		if err != nil {
			errs = errors.Join(errs, err)
		} else {
			total += n
		}
	}
	return total, errs
}

// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
	var hdr virtioNetHdr
	err := hdr.decode(in)
	if err != nil {
		return 0, err
	}
	in = in[virtioNetHdrLen:]

	options, err := hdr.toGSOOptions()
	if err != nil {
		return 0, err
	}

	// Don't trust HdrLen from the kernel as it can be equal to the length
	// of the entire first packet when the kernel is handling it as part of a
	// FORWARD path. Instead, parse the transport header length and add it onto
	// CsumStart, which is synonymous for IP header length.
	if options.GSOType == GSOUDPL4 {
		options.HdrLen = options.CsumStart + 8
	} else if options.GSOType != GSONone {
		if len(in) <= int(options.CsumStart+12) {
			return 0, errors.New("packet is too short")
		}

		tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4)
		if tcpHLen < 20 || tcpHLen > 60 {
			// A TCP header must be between 20 and 60 bytes in length.
			return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
		}
		options.HdrLen = options.CsumStart + tcpHLen
	}

	return GSOSplit(in, options, bufs, sizes, offset)
}
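
// Editorial sketch, not vendored code: the tcpHLen computation above reads the
// TCP Data Offset field, the high nibble of byte 12 of the TCP header, which
// counts 32-bit words. For a header with no options the nibble is 5, giving
// 5*4 = 20 bytes; the maximum nibble 15 gives 60 bytes.
func tcpHeaderLenSketch(tcpHeader []byte) uint16 {
	return uint16(tcpHeader[12]>>4) * 4
}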

func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
	tun.readOpMu.Lock()
	defer tun.readOpMu.Unlock()
	select {
	case err := <-tun.errors:
		return 0, err
	default:
		readInto := bufs[0][offset:]
		if tun.vnetHdr {
			readInto = tun.readBuff[:]
		}
		n, err := tun.tunFile.Read(readInto)
		if errors.Is(err, syscall.EBADFD) {
			err = os.ErrClosed
		}
		if err != nil {
			return 0, err
		}
		if tun.vnetHdr {
			return handleVirtioRead(readInto[:n], bufs, sizes, offset)
		} else {
			sizes[0] = n
			return 1, nil
		}
	}
}

func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}

func (tun *NativeTun) Close() error {
	var err1, err2 error
	tun.closeOnce.Do(func() {
		if tun.statusListenersShutdown != nil {
			close(tun.statusListenersShutdown)
			if tun.netlinkCancel != nil {
				err1 = tun.netlinkCancel.Cancel()
			}
		} else if tun.events != nil {
			close(tun.events)
		}
		err2 = tun.tunFile.Close()
	})
	if err1 != nil {
		return err1
	}
	return err2
}

func (tun *NativeTun) BatchSize() int {
	return tun.batchSize
}

// DisableUDPGRO disables UDP GRO if it is enabled. See the GRODevice interface
// for cases where it should be called.
func (tun *NativeTun) DisableUDPGRO() {
	tun.writeOpMu.Lock()
	tun.gro.disableUDPGRO()
	tun.writeOpMu.Unlock()
}

// DisableTCPGRO disables TCP GRO if it is enabled. See the GRODevice interface
// for cases where it should be called.
func (tun *NativeTun) DisableTCPGRO() {
	tun.writeOpMu.Lock()
	tun.gro.disableTCPGRO()
	tun.writeOpMu.Unlock()
}

const (
	// TODO: support TSO with ECN bits
	tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
	tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)

func (tun *NativeTun) initFromFlags(name string) error {
	sc, err := tun.tunFile.SyscallConn()
	if err != nil {
		return err
	}
	if e := sc.Control(func(fd uintptr) {
		var (
			ifr *unix.Ifreq
		)
		ifr, err = unix.NewIfreq(name)
		if err != nil {
			return
		}
		err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
		if err != nil {
			return
		}
		got := ifr.Uint16()
		if got&unix.IFF_VNET_HDR != 0 {
			// tunTCPOffloads were added in Linux v2.6. We require their support
			// if IFF_VNET_HDR is set.
			err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
			if err != nil {
				return
			}
			tun.vnetHdr = true
			tun.batchSize = conn.IdealBatchSize
			// tunUDPOffloads were added in Linux v6.2. We do not return an
			// error if they are unsupported at runtime.
			if unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) != nil {
				tun.gro.disableUDPGRO()
			}
		} else {
			tun.batchSize = 1
		}
	}); e != nil {
		return e
	}
	return err
}
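
// Editorial sketch, not vendored code: initFromFlags probes offload support by
// requesting the richest set first and falling back, rather than checking
// kernel versions. The same pattern in isolation:
func probeOffloadsSketch(fd int) (udpGSO bool, err error) {
	// TCP+UDP offloads need Linux v6.2+; try them first.
	if unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil {
		return true, nil
	}
	// Fall back to the TCP-only set, supported by much older kernels.
	err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads)
	return false, err
}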

// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
	nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
		}
		return nil, err
	}

	ifr, err := unix.NewIfreq(name)
	if err != nil {
		return nil, err
	}
	// IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
	// where a null write will return EINVAL indicating the TUN is up.
	ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
	err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
	if err != nil {
		return nil, err
	}

	err = unix.SetNonblock(nfd, true)
	if err != nil {
		unix.Close(nfd)
		return nil, err
	}

	// Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.

	fd := os.NewFile(uintptr(nfd), cloneDevicePath)
	return CreateTUNFromFile(fd, mtu)
}

// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
	tun := &NativeTun{
		tunFile:                 file,
		events:                  make(chan Event, 5),
		errors:                  make(chan error, 5),
		statusListenersShutdown: make(chan struct{}),
		tcpGROTable:             newTCPGROTable(),
		udpGROTable:             newUDPGROTable(),
		toWrite:                 make([]int, 0, conn.IdealBatchSize),
	}

	name, err := tun.Name()
	if err != nil {
		return nil, err
	}

	err = tun.initFromFlags(name)
	if err != nil {
		return nil, err
	}

	// start event listener
	tun.index, err = getIFIndex(name)
	if err != nil {
		return nil, err
	}

	tun.netlinkSock, err = createNetlinkSocket()
	if err != nil {
		return nil, err
	}
	tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
	if err != nil {
		unix.Close(tun.netlinkSock)
		return nil, err
	}

	tun.hackListenerClosed.Lock()
	go tun.routineNetlinkListener()
	go tun.routineHackListener() // cross namespace

	err = tun.setMTU(mtu)
	if err != nil {
		unix.Close(tun.netlinkSock)
		return nil, err
	}

	return tun, nil
}

// CreateUnmonitoredTUNFromFD creates a Device from the provided file
// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
	err := unix.SetNonblock(fd, true)
	if err != nil {
		return nil, "", err
	}
	file := os.NewFile(uintptr(fd), "/dev/tun")
	tun := &NativeTun{
		tunFile:     file,
		events:      make(chan Event, 5),
		errors:      make(chan error, 5),
		tcpGROTable: newTCPGROTable(),
		udpGROTable: newUDPGROTable(),
		toWrite:     make([]int, 0, conn.IdealBatchSize),
	}
	name, err := tun.Name()
	if err != nil {
		return nil, "", err
	}
	err = tun.initFromFlags(name)
	if err != nil {
		return nil, "", err
	}
	return tun, name, err
}
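
// Editorial sketch, not vendored code: typical use of the constructors above.
// BatchSize reports how many packets one Read may return, and the caller
// reserves `offset` bytes of headroom in every buffer; the buffer size and
// offset chosen here are illustrative assumptions.
func readLoopSketch(dev Device) error {
	offset := virtioNetHdrLen // headroom the device may use for framing
	bufs := make([][]byte, dev.BatchSize())
	sizes := make([]int, len(bufs))
	for i := range bufs {
		bufs[i] = make([]byte, offset+65535)
	}
	for {
		n, err := dev.Read(bufs, sizes, offset)
		if err != nil {
			return err
		}
		for i := 0; i < n; i++ {
			pkt := bufs[i][offset : offset+sizes[i]]
			_ = pkt // hand the packet off to the caller here
		}
	}
}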
333
vendor/github.com/tailscale/wireguard-go/tun/tun_openbsd.go
generated
vendored
Normal file
@@ -0,0 +1,333 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tun

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
	"syscall"
	"unsafe"

	"golang.org/x/sys/unix"
)

// Structure for iface mtu get/set ioctls
type ifreq_mtu struct {
	Name [unix.IFNAMSIZ]byte
	MTU  uint32
	Pad0 [12]byte
}

const _TUNSIFMODE = 0x8004745d

type NativeTun struct {
	name        string
	tunFile     *os.File
	events      chan Event
	errors      chan error
	routeSocket int
	closeOnce   sync.Once
}

func (tun *NativeTun) routineRouteListener(tunIfindex int) {
	var (
		statusUp  bool
		statusMTU int
	)

	defer close(tun.events)

	check := func() bool {
		iface, err := net.InterfaceByIndex(tunIfindex)
		if err != nil {
			tun.errors <- err
			return true
		}

		// Up / Down event
		up := (iface.Flags & net.FlagUp) != 0
		if up != statusUp && up {
			tun.events <- EventUp
		}
		if up != statusUp && !up {
			tun.events <- EventDown
		}
		statusUp = up

		// MTU changes
		if iface.MTU != statusMTU {
			tun.events <- EventMTUUpdate
		}
		statusMTU = iface.MTU
		return false
	}

	if check() {
		return
	}

	data := make([]byte, os.Getpagesize())
	for {
		n, err := unix.Read(tun.routeSocket, data)
		if err != nil {
			if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
				continue
			}
			tun.errors <- err
			return
		}

		if n < 8 {
			continue
		}

		if data[3 /* type */] != unix.RTM_IFINFO {
			continue
		}
		ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
		if ifindex != tunIfindex {
			continue
		}
		if check() {
			return
		}
	}
}

func CreateTUN(name string, mtu int) (Device, error) {
	ifIndex := -1
	if name != "tun" {
		_, err := fmt.Sscanf(name, "tun%d", &ifIndex)
		if err != nil || ifIndex < 0 {
			return nil, fmt.Errorf("Interface name must be tun[0-9]*")
		}
	}

	var tunfile *os.File
	var err error

	if ifIndex != -1 {
		tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
	} else {
		for ifIndex = 0; ifIndex < 256; ifIndex++ {
			tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
			if err == nil || !errors.Is(err, syscall.EBUSY) {
				break
			}
		}
	}

	if err != nil {
		return nil, err
	}

	tun, err := CreateTUNFromFile(tunfile, mtu)

	if err == nil && name == "tun" {
		fname := os.Getenv("WG_TUN_NAME_FILE")
		if fname != "" {
			os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
		}
	}

	return tun, err
}

func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
	tun := &NativeTun{
		tunFile: file,
		events:  make(chan Event, 10),
		errors:  make(chan error, 1),
	}

	name, err := tun.Name()
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	tunIfindex, err := func() (int, error) {
		iface, err := net.InterfaceByName(name)
		if err != nil {
			return -1, err
		}
		return iface.Index, nil
	}()
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
	if err != nil {
		tun.tunFile.Close()
		return nil, err
	}

	go tun.routineRouteListener(tunIfindex)

	currentMTU, err := tun.MTU()
	if err != nil || currentMTU != mtu {
		err = tun.setMTU(mtu)
		if err != nil {
			tun.Close()
			return nil, err
		}
	}

	return tun, nil
}

func (tun *NativeTun) Name() (string, error) {
	gostat, err := tun.tunFile.Stat()
	if err != nil {
		tun.name = ""
		return "", err
	}
	stat := gostat.Sys().(*syscall.Stat_t)
	tun.name = fmt.Sprintf("tun%d", stat.Rdev%256)
	return tun.name, nil
}
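
// Editorial sketch, not vendored code: Name above recovers the interface
// number from the character device's minor number, since /dev/tunN has minor
// number N on this platform. The same mapping in isolation:
func tunNameFromRdevSketch(rdev uint64) string {
	minor := rdev % 256 // low byte of the device number
	return fmt.Sprintf("tun%d", minor)
}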

func (tun *NativeTun) File() *os.File {
	return tun.tunFile
}

func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}

func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
	select {
	case err := <-tun.errors:
		return 0, err
	default:
		buf := bufs[0][offset-4:]
		n, err := tun.tunFile.Read(buf[:])
		if n < 4 {
			return 0, err
		}
		sizes[0] = n - 4
		return 1, err
	}
}

func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
	if offset < 4 {
		return 0, io.ErrShortBuffer
	}
	for i, buf := range bufs {
		buf = buf[offset-4:]
		buf[0] = 0x00
		buf[1] = 0x00
		buf[2] = 0x00
		switch buf[4] >> 4 {
		case 4:
			buf[3] = unix.AF_INET
		case 6:
			buf[3] = unix.AF_INET6
		default:
			return i, unix.EAFNOSUPPORT
		}
		if _, err := tun.tunFile.Write(buf); err != nil {
			return i, err
		}
	}
	return len(bufs), nil
}

func (tun *NativeTun) Close() error {
	var err1, err2 error
	tun.closeOnce.Do(func() {
		err1 = tun.tunFile.Close()
		if tun.routeSocket != -1 {
			unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
			err2 = unix.Close(tun.routeSocket)
			tun.routeSocket = -1
		} else if tun.events != nil {
			close(tun.events)
		}
	})
	if err1 != nil {
		return err1
	}
	return err2
}

func (tun *NativeTun) setMTU(n int) error {
	// open datagram socket

	var fd int

	fd, err := unix.Socket(
		unix.AF_INET,
		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
		0,
	)
	if err != nil {
		return err
	}

	defer unix.Close(fd)

	// do ioctl call

	var ifr ifreq_mtu
	copy(ifr.Name[:], tun.name)
	ifr.MTU = uint32(n)

	_, _, errno := unix.Syscall(
		unix.SYS_IOCTL,
		uintptr(fd),
		uintptr(unix.SIOCSIFMTU),
		uintptr(unsafe.Pointer(&ifr)),
	)

	if errno != 0 {
		return fmt.Errorf("failed to set MTU on %s", tun.name)
	}

	return nil
}

func (tun *NativeTun) MTU() (int, error) {
	// open datagram socket

	fd, err := unix.Socket(
		unix.AF_INET,
		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
		0,
	)
	if err != nil {
		return 0, err
	}

	defer unix.Close(fd)

	// do ioctl call
	var ifr ifreq_mtu
	copy(ifr.Name[:], tun.name)

	_, _, errno := unix.Syscall(
		unix.SYS_IOCTL,
		uintptr(fd),
		uintptr(unix.SIOCGIFMTU),
		uintptr(unsafe.Pointer(&ifr)),
	)
	if errno != 0 {
		return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
	}

	return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}

func (tun *NativeTun) BatchSize() int {
	return 1
}
242
vendor/github.com/tailscale/wireguard-go/tun/tun_windows.go
generated
vendored
Normal file
@@ -0,0 +1,242 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package tun

import (
	"errors"
	"fmt"
	"os"
	"sync"
	"sync/atomic"
	"time"
	_ "unsafe"

	"golang.org/x/sys/windows"
	"golang.zx2c4.com/wintun"
)

const (
	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
)

type rateJuggler struct {
	current       atomic.Uint64
	nextByteCount atomic.Uint64
	nextStartTime atomic.Int64
	changing      atomic.Bool
}

type NativeTun struct {
	wt        *wintun.Adapter
	name      string
	handle    windows.Handle
	rate      rateJuggler
	session   wintun.Session
	readWait  windows.Handle
	events    chan Event
	running   sync.WaitGroup
	closeOnce sync.Once
	close     atomic.Bool
	forcedMTU int
	outSizes  []int
}

var (
	WintunTunnelType          = "WireGuard"
	WintunStaticRequestedGUID *windows.GUID
)

//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

//go:linkname nanotime runtime.nanotime
func nanotime() int64

// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
func CreateTUN(ifname string, mtu int) (Device, error) {
	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}

// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
	if err != nil {
		return nil, fmt.Errorf("Error creating interface: %w", err)
	}

	forcedMTU := 1420
	if mtu > 0 {
		forcedMTU = mtu
	}

	tun := &NativeTun{
		wt:        wt,
		name:      ifname,
		handle:    windows.InvalidHandle,
		events:    make(chan Event, 10),
		forcedMTU: forcedMTU,
	}

	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
	if err != nil {
		tun.wt.Close()
		close(tun.events)
		return nil, fmt.Errorf("Error starting session: %w", err)
	}
	tun.readWait = tun.session.ReadWaitEvent()
	return tun, nil
}

func (tun *NativeTun) Name() (string, error) {
	return tun.name, nil
}

func (tun *NativeTun) File() *os.File {
	return nil
}

func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}

func (tun *NativeTun) Close() error {
	var err error
	tun.closeOnce.Do(func() {
		tun.close.Store(true)
		windows.SetEvent(tun.readWait)
		tun.running.Wait()
		tun.session.End()
		if tun.wt != nil {
			tun.wt.Close()
		}
		close(tun.events)
	})
	return err
}

func (tun *NativeTun) MTU() (int, error) {
	return tun.forcedMTU, nil
}

// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
	if tun.close.Load() {
		return
	}
	update := tun.forcedMTU != mtu
	tun.forcedMTU = mtu
	if update {
		tun.events <- EventMTUUpdate
	}
}

func (tun *NativeTun) BatchSize() int {
	// TODO: implement batching with wintun
	return 1
}

// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.

func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
	tun.running.Add(1)
	defer tun.running.Done()
retry:
	if tun.close.Load() {
		return 0, os.ErrClosed
	}
	start := nanotime()
	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
	for {
		if tun.close.Load() {
			return 0, os.ErrClosed
		}
		packet, err := tun.session.ReceivePacket()
		switch err {
		case nil:
			packetSize := len(packet)
			copy(bufs[0][offset:], packet)
			sizes[0] = packetSize
			tun.session.ReleaseReceivePacket(packet)
			tun.rate.update(uint64(packetSize))
			return 1, nil
		case windows.ERROR_NO_MORE_ITEMS:
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return 0, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return 0, errors.New("Send ring corrupt")
		}
		return 0, fmt.Errorf("Read failed: %w", err)
	}
}

func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
	tun.running.Add(1)
	defer tun.running.Done()
	if tun.close.Load() {
		return 0, os.ErrClosed
	}

	for i, buf := range bufs {
		packetSize := len(buf) - offset
		tun.rate.update(uint64(packetSize))

		packet, err := tun.session.AllocateSendPacket(packetSize)
		switch err {
		case nil:
			// TODO: Explore options to eliminate this copy.
			copy(packet, buf[offset:])
			tun.session.SendPacket(packet)
			continue
		case windows.ERROR_HANDLE_EOF:
			return i, os.ErrClosed
		case windows.ERROR_BUFFER_OVERFLOW:
			continue // Dropping when ring is full.
		default:
			return i, fmt.Errorf("Write failed: %w", err)
		}
	}
	return len(bufs), nil
}

// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
	tun.running.Add(1)
	defer tun.running.Done()
	if tun.close.Load() {
		return 0
	}
	return tun.wt.LUID()
}

// RunningVersion returns the running version of the Wintun driver.
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
	return wintun.RunningVersion()
}

func (rate *rateJuggler) update(packetLen uint64) {
	now := nanotime()
	total := rate.nextByteCount.Add(packetLen)
	period := uint64(now - rate.nextStartTime.Load())
	if period >= rateMeasurementGranularity {
		if !rate.changing.CompareAndSwap(false, true) {
			return
		}
		rate.nextStartTime.Store(now)
		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
		rate.nextByteCount.Store(0)
		rate.changing.Store(false)
	}
}
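
// Editorial note, not vendored code: update computes an observed rate in
// bytes per second once at least half a second has elapsed:
//
//	rate = bytesInPeriod * 1e9 / periodNanoseconds
//
// For example, 25,000,000 bytes observed over 500 ms yields 50,000,000 B/s
// (400 Mbit/s), below the 100,000,000 B/s (800 Mbit/s) spinloopRateThreshold,
// so Read would block on the event rather than spin.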