Update dependencies

bluepython508 committed 2024-11-01 17:33:34 +00:00
commit 5cdfab398d (parent 033ac0b400)
3596 changed files with 1033483 additions and 259 deletions

@@ -0,0 +1,712 @@
package tun
import (
"encoding/binary"
"math/bits"
"strconv"
"golang.org/x/sys/cpu"
)
// checksumGeneric64 is a reference implementation of checksum using 64 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
func checksumGeneric64(b []byte, initial uint16) uint16 {
var ac uint64
var carry uint64
if cpu.IsBigEndian {
ac = uint64(initial)
} else {
ac = uint64(bits.ReverseBytes16(initial))
}
for len(b) >= 128 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry)
}
b = b[128:]
}
if len(b) >= 64 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry)
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry)
} else {
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry)
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry)
} else {
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry)
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry)
} else {
ac, carry = bits.Add64(ac, uint64(b[0]), carry)
}
}
folded := ipChecksumFold64(ac, carry)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric32 is a reference implementation of checksum using 32 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
func checksumGeneric32(b []byte, initial uint16) uint16 {
var ac uint32
var carry uint32
if cpu.IsBigEndian {
ac = uint32(initial)
} else {
ac = uint32(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry)
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry)
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry)
} else {
ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry)
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry)
} else {
ac, carry = bits.Add32(ac, uint32(b[0]), carry)
}
}
folded := ipChecksumFold32(ac, carry)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric32Alternate is an alternate reference implementation of
// checksum using 32 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric32Alternate(b []byte, initial uint16) uint16 {
var ac uint32
if cpu.IsBigEndian {
ac = uint32(initial)
} else {
ac = uint32(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
ac += uint32(binary.BigEndian.Uint16(b[32:34]))
ac += uint32(binary.BigEndian.Uint16(b[34:36]))
ac += uint32(binary.BigEndian.Uint16(b[36:38]))
ac += uint32(binary.BigEndian.Uint16(b[38:40]))
ac += uint32(binary.BigEndian.Uint16(b[40:42]))
ac += uint32(binary.BigEndian.Uint16(b[42:44]))
ac += uint32(binary.BigEndian.Uint16(b[44:46]))
ac += uint32(binary.BigEndian.Uint16(b[46:48]))
ac += uint32(binary.BigEndian.Uint16(b[48:50]))
ac += uint32(binary.BigEndian.Uint16(b[50:52]))
ac += uint32(binary.BigEndian.Uint16(b[52:54]))
ac += uint32(binary.BigEndian.Uint16(b[54:56]))
ac += uint32(binary.BigEndian.Uint16(b[56:58]))
ac += uint32(binary.BigEndian.Uint16(b[58:60]))
ac += uint32(binary.BigEndian.Uint16(b[60:62]))
ac += uint32(binary.BigEndian.Uint16(b[62:64]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
ac += uint32(binary.LittleEndian.Uint16(b[32:34]))
ac += uint32(binary.LittleEndian.Uint16(b[34:36]))
ac += uint32(binary.LittleEndian.Uint16(b[36:38]))
ac += uint32(binary.LittleEndian.Uint16(b[38:40]))
ac += uint32(binary.LittleEndian.Uint16(b[40:42]))
ac += uint32(binary.LittleEndian.Uint16(b[42:44]))
ac += uint32(binary.LittleEndian.Uint16(b[44:46]))
ac += uint32(binary.LittleEndian.Uint16(b[46:48]))
ac += uint32(binary.LittleEndian.Uint16(b[48:50]))
ac += uint32(binary.LittleEndian.Uint16(b[50:52]))
ac += uint32(binary.LittleEndian.Uint16(b[52:54]))
ac += uint32(binary.LittleEndian.Uint16(b[54:56]))
ac += uint32(binary.LittleEndian.Uint16(b[56:58]))
ac += uint32(binary.LittleEndian.Uint16(b[58:60]))
ac += uint32(binary.LittleEndian.Uint16(b[60:62]))
ac += uint32(binary.LittleEndian.Uint16(b[62:64]))
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b))
} else {
ac += uint32(binary.LittleEndian.Uint16(b))
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac += uint32(b[0]) << 8
} else {
ac += uint32(b[0])
}
}
folded := ipChecksumFold32(ac, 0)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric64Alternate is an alternate reference implementation of
// checksum using 64 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric64Alternate(b []byte, initial uint16) uint16 {
var ac uint64
if cpu.IsBigEndian {
ac = uint64(initial)
} else {
ac = uint64(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
ac += uint64(binary.LittleEndian.Uint32(b[32:36]))
ac += uint64(binary.LittleEndian.Uint32(b[36:40]))
ac += uint64(binary.LittleEndian.Uint32(b[40:44]))
ac += uint64(binary.LittleEndian.Uint32(b[44:48]))
ac += uint64(binary.LittleEndian.Uint32(b[48:52]))
ac += uint64(binary.LittleEndian.Uint32(b[52:56]))
ac += uint64(binary.LittleEndian.Uint32(b[56:60]))
ac += uint64(binary.LittleEndian.Uint32(b[60:64]))
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b))
} else {
ac += uint64(binary.LittleEndian.Uint32(b))
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint16(b))
} else {
ac += uint64(binary.LittleEndian.Uint16(b))
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac += uint64(b[0]) << 8
} else {
ac += uint64(b[0])
}
}
folded := ipChecksumFold64(ac, 0)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 {
sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry))
// if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff
// therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is
// no need to save the carry flag
sum = (sum >> 16) + (sum & 0xffff) + carry
// sum <= 0x1_fffe therefore this is the last fold needed:
// if (sum >> 16) > 0 then
// (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore
// the addition will not overflow
// otherwise (sum >> 16) == 0 and sum will be unchanged
sum = (sum >> 16) + (sum & 0xffff)
return uint16(sum)
}
func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 {
sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry
// sum <= 0x1_ffff:
// 0xffff + 0xffff = 0x1_fffe
// initialCarry is 0 or 1, for a combined maximum of 0x1_ffff
sum = (sum >> 16) + (sum & 0xffff)
// sum <= 0x1_0000 therefore this is the last fold needed:
// if (sum >> 16) > 0 then
// (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore
// the addition will not overflow
// otherwise (sum >> 16) == 0 and sum will be unchanged
sum = (sum >> 16) + (sum & 0xffff)
return uint16(sum)
}
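// Illustrative sketch (not part of the vendored file): one's-complement
// folding can equivalently be done with a naive loop, which makes a handy
// test oracle for the fold functions above; for any accumulator x,
// ipChecksumFold64(x, 0) == foldNaive(x):
//
//	func foldNaive(sum uint64) uint16 {
//		for sum>>16 != 0 {
//			sum = sum>>16 + sum&0xffff
//		}
//		return uint16(sum)
//	}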
func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) {
sum, carry = initial, carryIn
switch len(addr) {
case 4: // IPv4
if cpu.IsBigEndian {
sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry)
} else {
sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry)
}
case 16: // IPv6
if cpu.IsBigEndian {
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry)
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry)
} else {
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry)
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry)
}
default:
panic("bad addr length")
}
return sum, carry
}
func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) {
sum, carry = initial, carryIn
switch len(addr) {
case 4: // IPv4
if cpu.IsBigEndian {
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
} else {
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
}
case 16: // IPv6
if cpu.IsBigEndian {
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry)
} else {
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry)
}
default:
panic("bad addr length")
}
return sum, carry
}
func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
var sum uint64
if cpu.IsBigEndian {
sum = uint64(totalLen) + uint64(protocol)
} else {
sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8
}
sum, carry := addrPartialChecksum64(srcAddr, sum, 0)
sum, carry = addrPartialChecksum64(dstAddr, sum, carry)
foldedSum := ipChecksumFold64(sum, carry)
if !cpu.IsBigEndian {
foldedSum = bits.ReverseBytes16(foldedSum)
}
return foldedSum
}
func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
var sum uint32
if cpu.IsBigEndian {
sum = uint32(totalLen) + uint32(protocol)
} else {
sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8
}
sum, carry := addrPartialChecksum32(srcAddr, sum, 0)
sum, carry = addrPartialChecksum32(dstAddr, sum, carry)
foldedSum := ipChecksumFold32(sum, carry)
if !cpu.IsBigEndian {
foldedSum = bits.ReverseBytes16(foldedSum)
}
return foldedSum
}
// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and
// dstAddr must be 4 or 16 bytes in length.
func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
if strconv.IntSize < 64 {
return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen)
}
return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen)
}
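// Illustrative sketch (not part of the vendored file): the pseudo-header sum
// serves as the initial value when checksumming a transport segment, mirroring
// the call pattern in GSOSplit; tcpSegment, srcIP and dstIP are assumed
// placeholders, and 16 is the checksum field offset within a TCP header:
//
//	pseudo := PseudoHeaderChecksum(6 /* TCP */, srcIP, dstIP, uint16(len(tcpSegment)))
//	tcpSegment[16], tcpSegment[17] = 0, 0 // zero the checksum field first
//	binary.BigEndian.PutUint16(tcpSegment[16:], ^Checksum(tcpSegment, pseudo))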

@@ -0,0 +1,23 @@
package tun
import "golang.org/x/sys/cpu"
var checksum = checksumAMD64
// Checksum computes an IP checksum starting with the provided initial value.
// The length of data should be at least 128 bytes for best performance. Smaller
// buffers still yield a correct result.
func Checksum(data []byte, initial uint16) uint16 {
return checksum(data, initial)
}
func init() {
if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
checksum = checksumAVX2
return
}
if cpu.X86.HasSSE2 {
checksum = checksumSSE2
return
}
}
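// Illustrative sketch (not part of the vendored file): recomputing an IPv4
// header checksum in place, mirroring the usage in GSOSplit; hdr and iphLen
// are assumed placeholders:
//
//	hdr[10], hdr[11] = 0, 0 // zero the stored checksum before summing
//	binary.BigEndian.PutUint16(hdr[10:], ^Checksum(hdr[:iphLen], 0))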

@@ -0,0 +1,18 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
package tun
// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)
//
//go:noescape
func checksumAVX2(b []byte, initial uint16) uint16
// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)
//
//go:noescape
func checksumSSE2(b []byte, initial uint16) uint16
// checksumAMD64 computes an IP checksum using amd64 baseline instructions
//
//go:noescape
func checksumAMD64(b []byte, initial uint16) uint16

@@ -0,0 +1,851 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
#include "textflag.h"
DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"
DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112
// func checksumAVX2(b []byte, initial uint16) uint16
// Requires: AVX, AVX2, BMI2
TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
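// initial is in network (big-endian) byte order; swap it so the running sum
// matches the little-endian loads below (the result is swapped back before
// returning)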
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// skip all SIMD for small buffers
CMPQ BX, $0x00000100
JGE startSIMD
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
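// For example, with 3 leftover bytes the final 8 bytes of the buffer are
// loaded and shifted right by 64-3*8 = 40 bits, keeping only the 3 bytes not
// already covered by the 8-byte loop.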
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
JMP foldAndReturn
startSIMD:
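// Four independent vector accumulators (Y0-Y3) keep multiple dependency
// chains in flight; VPMOVZXWD zero-extends eight 16-bit words into 32-bit
// lanes before each add.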
VPXOR Y0, Y0, Y0
VPXOR Y1, Y1, Y1
VPXOR Y2, Y2, Y2
VPXOR Y3, Y3, Y3
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
// Number of 256 byte iterations
SHRQ $0x08, CX
JZ smallLoop
bigLoop:
VPMOVZXWD (DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 16(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 32(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 48(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 64(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 80(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 96(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 112(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 128(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 144(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 160(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 176(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 192(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 208(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 224(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 240(DX), Y4
VPADDD Y4, Y3, Y3
ADDQ $0x00000100, DX
DECQ CX
JNZ bigLoop
CMPQ BX, $0x10
JLT doneSmallLoop
// now read a single 16 byte unit of data at a time
smallLoop:
VPMOVZXWD (DX), Y4
VPADDD Y4, Y0, Y0
ADDQ $0x10, DX
SUBQ $0x10, BX
CMPQ BX, $0x10
JGE smallLoop
doneSmallLoop:
CMPQ BX, $0x00
JE doneSIMD
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
LEAQ xmmLoadMasks<>+0(SB), CX
VMOVDQU -16(DX)(BX*1), X4
VPAND -16(CX)(BX*8), X4, X4
VPMOVZXWD X4, Y4
VPADDD Y4, Y0, Y0
doneSIMD:
// Multi-chain loop is done, combine the accumulators
VPADDD Y1, Y0, Y0
VPADDD Y2, Y0, Y0
VPADDD Y3, Y0, Y0
// extract the YMM into a pair of XMM and sum them
VEXTRACTI128 $0x01, Y0, X1
VPADDD X0, X1, X0
// extract the XMM into GP64
VPEXTRQ $0x00, X0, CX
VPEXTRQ $0x01, X0, DX
// no more AVX code, clear upper registers to avoid SSE slowdowns
VZEROUPPER
ADDQ CX, AX
ADCQ DX, AX
foldAndReturn:
// add CF and fold
RORXQ $0x20, AX, CX
ADCL CX, AX
RORXL $0x10, AX, CX
ADCW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
// func checksumSSE2(b []byte, initial uint16) uint16
// Requires: SSE2
TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// skip all SIMD for small buffers
CMPQ BX, $0x00000100
JGE startSIMD
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
JMP foldAndReturn
startSIMD:
PXOR X0, X0
PXOR X1, X1
PXOR X2, X2
PXOR X3, X3
PXOR X4, X4
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
// Number of 256 byte iterations
SHRQ $0x08, CX
JZ smallLoop
bigLoop:
MOVOU (DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 16(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 32(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 48(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 64(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 80(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 96(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 112(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 128(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 144(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 160(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 176(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 192(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 208(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 224(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 240(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
ADDQ $0x00000100, DX
DECQ CX
JNZ bigLoop
CMPQ BX, $0x10
JLT doneSmallLoop
// now read a single 16 byte unit of data at a time
smallLoop:
MOVOU (DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X1
ADDQ $0x10, DX
SUBQ $0x10, BX
CMPQ BX, $0x10
JGE smallLoop
doneSmallLoop:
CMPQ BX, $0x00
JE doneSIMD
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
LEAQ xmmLoadMasks<>+0(SB), CX
MOVOU -16(DX)(BX*1), X5
PAND -16(CX)(BX*8), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X1
doneSIMD:
// Multi-chain loop is done, combine the accumulators
PADDD X1, X0
PADDD X2, X0
PADDD X3, X0
// extract the XMM into GP64
MOVQ X0, CX
PSRLDQ $0x08, X0
MOVQ X0, DX
ADDQ CX, AX
ADCQ DX, AX
foldAndReturn:
// add CF and fold
MOVL AX, CX
ADCQ $0x00, CX
SHRQ $0x20, AX
ADDQ CX, AX
MOVWQZX AX, CX
SHRQ $0x10, AX
ADDQ CX, AX
MOVW AX, CX
SHRQ $0x10, AX
ADDW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
// func checksumAMD64(b []byte, initial uint16) uint16
TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// Number of 256 byte iterations into loop counter
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
SHRQ $0x08, CX
JZ startCleanup
CLC
XORQ SI, SI
XORQ DI, DI
XORQ R8, R8
XORQ R9, R9
XORQ R10, R10
XORQ R11, R11
XORQ R12, R12
bigLoop:
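// Four independent add chains (AX, DI, R9, R11), each with its own carry
// accumulator (SI, R8, R10, R12), so the carry-flag dependency does not
// serialize the whole loop.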
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ 32(DX), DI
ADCQ 40(DX), DI
ADCQ 48(DX), DI
ADCQ 56(DX), DI
ADCQ $0x00, R8
ADDQ 64(DX), R9
ADCQ 72(DX), R9
ADCQ 80(DX), R9
ADCQ 88(DX), R9
ADCQ $0x00, R10
ADDQ 96(DX), R11
ADCQ 104(DX), R11
ADCQ 112(DX), R11
ADCQ 120(DX), R11
ADCQ $0x00, R12
ADDQ 128(DX), AX
ADCQ 136(DX), AX
ADCQ 144(DX), AX
ADCQ 152(DX), AX
ADCQ $0x00, SI
ADDQ 160(DX), DI
ADCQ 168(DX), DI
ADCQ 176(DX), DI
ADCQ 184(DX), DI
ADCQ $0x00, R8
ADDQ 192(DX), R9
ADCQ 200(DX), R9
ADCQ 208(DX), R9
ADCQ 216(DX), R9
ADCQ $0x00, R10
ADDQ 224(DX), R11
ADCQ 232(DX), R11
ADCQ 240(DX), R11
ADCQ 248(DX), R11
ADCQ $0x00, R12
ADDQ $0x00000100, DX
SUBQ $0x01, CX
JNZ bigLoop
ADDQ SI, AX
ADCQ DI, AX
ADCQ R8, AX
ADCQ R9, AX
ADCQ R10, AX
ADCQ R11, AX
ADCQ R12, AX
// accumulate CF (twice, in case the first time overflows)
ADCQ $0x00, AX
ADCQ $0x00, AX
startCleanup:
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
foldAndReturn:
// add CF and fold
MOVL AX, CX
ADCQ $0x00, CX
SHRQ $0x20, AX
ADDQ CX, AX
MOVWQZX AX, CX
SHRQ $0x10, AX
ADDQ CX, AX
MOVW AX, CX
SHRQ $0x10, AX
ADDW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET

@@ -0,0 +1,15 @@
// This file contains IP checksum algorithms that are not specific to any
// architecture and don't use hardware acceleration.
//go:build !amd64
package tun
import "strconv"
func Checksum(data []byte, initial uint16) uint16 {
if strconv.IntSize < 64 {
return checksumGeneric32(data, initial)
}
return checksumGeneric64(data, initial)
}

vendor/github.com/tailscale/wireguard-go/tun/errors.go (generated, vendored)

@@ -0,0 +1,12 @@
package tun
import (
"errors"
)
var (
// ErrTooManySegments is returned by Device.Read() when segmentation
// overflows the length of supplied buffers. This error should not cause
// reads to cease.
ErrTooManySegments = errors.New("too many segments")
)

vendor/github.com/tailscale/wireguard-go/tun/offload.go (generated, vendored)

@@ -0,0 +1,220 @@
package tun
import (
"encoding/binary"
"fmt"
)
// GSOType represents the type of segmentation offload.
type GSOType int
const (
GSONone GSOType = iota
GSOTCPv4
GSOTCPv6
GSOUDPL4
)
func (g GSOType) String() string {
switch g {
case GSONone:
return "GSONone"
case GSOTCPv4:
return "GSOTCPv4"
case GSOTCPv6:
return "GSOTCPv6"
case GSOUDPL4:
return "GSOUDPL4"
default:
return "unknown"
}
}
// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO
// specification. It is a common representation of GSO metadata that can be
// applied to support packet GSO across tun.Device implementations.
type GSOOptions struct {
// GSOType represents the type of segmentation offload.
GSOType GSOType
// HdrLen is the sum of the layer 3 and 4 header lengths. This field may be
// zero when GSOType == GSONone.
HdrLen uint16
// CsumStart is the head byte index of the packet data to be checksummed,
// i.e. the start of the TCP or UDP header.
CsumStart uint16
// CsumOffset is the offset from CsumStart where the 2-byte checksum value
// should be placed.
CsumOffset uint16
// GSOSize is the size of each segment exclusive of HdrLen. The tail segment
// may be smaller than this value.
GSOSize uint16
// NeedsCsum may be set where GSOType == GSONone. When set, the checksum
// at CsumStart + CsumOffset must be a partial checksum, i.e. the
// pseudo-header sum.
NeedsCsum bool
}
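// Illustrative sketch (not part of the vendored file): typical options for a
// TCP/IPv4 super-packet with 20-byte IP and TCP headers, split into 1400-byte
// segments (the concrete values are assumptions for the example):
//
//	opts := GSOOptions{
//		GSOType:    GSOTCPv4,
//		HdrLen:     40, // 20-byte IPv4 header + 20-byte TCP header
//		CsumStart:  20, // TCP header starts after the IPv4 header
//		CsumOffset: 16, // checksum field offset within the TCP header
//		GSOSize:    1400,
//	}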
const (
ipv4SrcAddrOffset = 12
ipv6SrcAddrOffset = 8
)
const tcpFlagsOffset = 13
const (
tcpFlagFIN uint8 = 0x01
tcpFlagPSH uint8 = 0x08
tcpFlagACK uint8 = 0x10
)
const (
// defined here in order to avoid importation of any platform-specific pkgs
ipProtoTCP = 6
ipProtoUDP = 17
)
// GSOSplit splits packets from 'in' into outBufs[<index>][outOffset:], writing
// the size of each element into sizes. It returns the number of buffers
// populated, and/or an error. Callers may pass an 'in' slice that overlaps with
// the first element of outBufs, i.e. &in[0] may be equal to
// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the
// value of options.NeedsCsum. The length of each outBufs element must be
// greater than or equal to the length of 'in', otherwise output may be
// silently truncated.
func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) {
cSumAt := int(options.CsumStart) + int(options.CsumOffset)
if cSumAt+1 >= len(in) {
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
if len(in) < int(options.HdrLen) {
return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen)
}
// Handle the cases where we are copying a single element to outBufs.
payloadLen := len(in) - int(options.HdrLen)
if options.GSOType == GSONone || payloadLen < int(options.GSOSize) {
if len(in) > len(outBufs[0][outOffset:]) {
return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:]))
}
if options.NeedsCsum {
// The initial value at the checksum offset should be summed with
// the checksum we compute. This is typically the pseudo-header sum.
initial := binary.BigEndian.Uint16(in[cSumAt:])
in[cSumAt], in[cSumAt+1] = 0, 0
binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial))
}
sizes[0] = copy(outBufs[0][outOffset:], in)
return 1, nil
}
if options.HdrLen < options.CsumStart {
return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
}
if len(in) < 20 {
return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20)
}
case 6:
if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
}
if len(in) < 40 {
return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
iphLen := int(options.CsumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if ipVersion == 4 {
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
transportCsumAt := int(options.CsumStart + options.CsumOffset)
var firstTCPSeqNum uint32
var protocol uint8
if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 {
protocol = ipProtoTCP
if len(in) < int(options.CsumStart)+20 {
return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)",
len(in), options.CsumStart, 20)
}
firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:])
} else {
protocol = ipProtoUDP
}
nextSegmentDataAt := int(options.HdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
if i == len(outBufs) {
return i - 1, ErrTooManySegments
}
nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize)
if nextSegmentEnd > len(in) {
nextSegmentEnd = len(in)
}
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
totalLen := int(options.HdrLen) + segmentDataLen
sizes[i] = totalLen
out := outBufs[i][outOffset:]
copy(out, in[:iphLen])
if ipVersion == 4 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
if i > 0 {
id := binary.BigEndian.Uint16(out[4:])
id += uint16(i)
binary.BigEndian.PutUint16(out[4:], id)
}
out[10], out[11] = 0, 0 // clear ipv4 header checksum
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
ipv4CSum := ^Checksum(out[:iphLen], 0)
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
} else {
// For IPv6 we are responsible for updating the payload length field.
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// copy transport header
copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen])
if protocol == ipProtoTCP {
// set TCP seq and adjust TCP flags
tcpSeq := firstTCPSeqNum + uint32(options.GSOSize)*uint32(i) // 32-bit math; gsoSize*i may exceed 65535
binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
// FIN and PSH should only be set on last segment
clearFlags := tcpFlagFIN | tcpFlagPSH
out[options.CsumStart+tcpFlagsOffset] &^= clearFlags
}
} else {
// set UDP header len
binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart))
}
// payload
copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// transport checksum
out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
transportHeaderLen := int(options.HdrLen - options.CsumStart)
lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
transportCSum := PseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
transportCSum = ^Checksum(out[options.CsumStart:totalLen], transportCSum)
binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum)
nextSegmentDataAt += int(options.GSOSize)
}
return i, nil
}
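// Illustrative sketch (not part of the vendored file): splitting one oversized
// packet into per-segment buffers; pkt, opts, segs, sizes and send are assumed
// placeholders, and each segs element must be at least len(pkt) bytes:
//
//	n, err := GSOSplit(pkt, opts, segs, sizes, 0)
//	if err != nil {
//		// ErrTooManySegments: segs had too few elements for the split
//	}
//	for i := 0; i < n; i++ {
//		send(segs[i][:sizes[i]])
//	}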

@@ -0,0 +1,911 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"unsafe"
"github.com/tailscale/wireguard-go/conn"
"golang.org/x/sys/unix"
)
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
flags uint8
gsoType uint8
hdrLen uint16
gsoSize uint16
csumStart uint16
csumOffset uint16
}
func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) {
var gsoType GSOType
switch v.gsoType {
case unix.VIRTIO_NET_HDR_GSO_NONE:
gsoType = GSONone
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
gsoType = GSOTCPv4
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
gsoType = GSOTCPv6
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
gsoType = GSOUDPL4
default:
return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType)
}
return GSOOptions{
GSOType: gsoType,
HdrLen: v.hdrLen,
CsumStart: v.csumStart,
CsumOffset: v.csumOffset,
GSOSize: v.gsoSize,
NeedsCsum: v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0,
}, nil
}
func (v *virtioNetHdr) decode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
return nil
}
func (v *virtioNetHdr) encode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
return nil
}
const (
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
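// Illustrative sketch (not part of the vendored file): packets exchanged with
// a tun device that has vnet-hdr enabled carry this header as a prefix;
// decoding it and converting to the generic GSO representation looks like
// (buf is an assumed placeholder holding at least virtioNetHdrLen bytes):
//
//	var hdr virtioNetHdr
//	if err := hdr.decode(buf); err != nil {
//		return err // io.ErrShortBuffer
//	}
//	opts, err := hdr.toGSOOptions()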
// tcpFlowKey represents the key for a TCP flow.
type tcpFlowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
isV6 bool
}
// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
type tcpGROTable struct {
itemsByFlow map[tcpFlowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
}
for i := range t.itemsPool {
t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
}
return t
}
func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
key := tcpFlowKey{}
addrSize := dstAddrOffset - srcAddrOffset
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
key.isV6 = addrSize == 16
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
iphLen: uint8(tcphOffset),
tcphLen: uint8(tcphLen),
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
}
items, ok := t.itemsByFlow[key]
if !ok {
items = t.newItems()
}
items = append(items, item)
t.itemsByFlow[key] = items
}
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items, _ := t.itemsByFlow[item.key]
items[i] = item
}
func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
}
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key tcpFlowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
tcphLen uint8 // tcp header len
pshSet bool // psh flag is set
}
func (t *tcpGROTable) newItems() []tcpGROItem {
var items []tcpGROItem
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
return items
}
func (t *tcpGROTable) reset() {
for k, items := range t.itemsByFlow {
items = items[:0]
t.itemsPool = append(t.itemsPool, items)
delete(t.itemsByFlow, k)
}
}
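// Usage sketch (editor's note): a table is allocated once and recycled across
// Write calls; reset returns each per-flow slice to itemsPool so the next
// vector of packets allocates nothing:
//
//	table := newTCPGROTable()
//	for {
//		// ... tcpGRO(bufs, offset, i, table, isV6) for each candidate packet i ...
//		table.reset()
//	}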
// udpFlowKey represents the key for a UDP flow.
type udpFlowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
isV6 bool
}
// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
type udpGROTable struct {
itemsByFlow map[udpFlowKey][]udpGROItem
itemsPool [][]udpGROItem
}
func newUDPGROTable() *udpGROTable {
u := &udpGROTable{
itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
}
for i := range u.itemsPool {
u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
}
return u
}
func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
key := udpFlowKey{}
addrSize := dstAddrOffset - srcAddrOffset
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
key.isV6 = addrSize == 16
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
items, ok := u.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
item := udpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
iphLen: uint8(udphOffset),
cSumKnownInvalid: cSumKnownInvalid,
}
items, ok := u.itemsByFlow[key]
if !ok {
items = u.newItems()
}
items = append(items, item)
u.itemsByFlow[key] = items
}
func (u *udpGROTable) updateAt(item udpGROItem, i int) {
items, _ := u.itemsByFlow[item.key]
items[i] = item
}
// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type udpGROItem struct {
key udpFlowKey
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
}
func (u *udpGROTable) newItems() []udpGROItem {
var items []udpGROItem
items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
return items
}
func (u *udpGROTable) reset() {
for k, items := range u.itemsByFlow {
items = items[:0]
u.itemsPool = append(u.itemsPool, items)
delete(u.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
const (
coalescePrepend canCoalesce = -1
coalesceUnavailable canCoalesce = 0
coalesceAppend canCoalesce = 1
)
// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
// meet all requirements to be merged as part of a GRO operation, otherwise it
// returns false.
func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
if len(pktA) < 9 || len(pktB) < 9 {
return false
}
if pktA[0]>>4 == 6 {
if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
// cannot coalesce with unequal Traffic class values
return false
}
if pktA[7] != pktB[7] {
// cannot coalesce with unequal Hop limit values
return false
}
} else {
if pktA[1] != pktB[1] {
// cannot coalesce with unequal ToS values
return false
}
if pktA[6]>>5 != pktB[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return false
}
if pktA[8] != pktB[8] {
// cannot coalesce with unequal TTL values
return false
}
}
return true
}
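// Editor's note on the byte offsets above: for IPv6, byte 0 carries the
// version and the high nibble of Traffic Class, the high nibble of byte 1 the
// low Traffic Class bits, and byte 7 the Hop Limit. For IPv4, byte 1 is ToS,
// the top three bits of byte 6 are the flags (reserved, DF, MF), and byte 8
// is TTL.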
// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
// packets involved in the current GRO evaluation. bufsOffset is the offset at
// which packet data begins within bufs.
func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if !ipHeadersCanCoalesce(pkt, pktTarget) {
return coalesceUnavailable
}
if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
}
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if tcphLen != item.tcphLen {
// cannot coalesce with unequal tcp options len
return coalesceUnavailable
}
if tcphLen > 20 {
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:item.iphLen+tcphLen]) {
// cannot coalesce with unequal tcp options
return coalesceUnavailable
}
}
if !ipHeadersCanCoalesce(pkt, pktTarget) {
return coalesceUnavailable
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
if item.pshSet {
// We cannot append to a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
if pshSet {
// We cannot prepend with a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if gsoSize < item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
if gsoSize > item.gsoSize && item.numMerged > 0 {
// There's at least one previous merge, and we're larger than all
// previous. This would put multiple smaller packets on the end.
return coalesceUnavailable
}
return coalescePrepend
}
return coalesceUnavailable
}
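// Worked example (editor's note): if item has sentSeq=1000, gsoSize=1448, and
// numMerged=1, then lhsLen is 2*1448=2896, so a packet with seq 3896 may
// append, a packet whose seq+gsoSize equals 1000 may prepend, and any other
// sequence number yields coalesceUnavailable.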
func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
lenForPseudo := uint16(len(pkt) - int(iphLen))
cSum := PseudoHeaderChecksum(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
return ^Checksum(pkt[iphLen:], cSum) == 0
}
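// Editor's note: this is standard RFC 1071 verification. Summing the
// transport header and payload (embedded checksum field included) together
// with the pseudo header checksum yields 0xffff for an intact packet, so the
// final complement is 0. For an IPv4 UDP packet without IP options:
//
//	ok := checksumValid(pkt, 20, unix.IPPROTO_UDP, false)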
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = iota
coalescePSHEnding
coalesceItemInvalidCSum
coalescePktInvalidCSum
coalesceSuccess
)
// coalesceUDPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome.
func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
headersLen := item.iphLen + udphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
return coalesceItemInvalidCSum
}
}
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
return coalescePktInvalidCSum
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
item.numMerged++
return coalesceSuccess
}
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
var pktHead []byte // the packet that will end up at the front
headersLen := item.iphLen + item.tcphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
// Copy data
if mode == coalescePrepend {
pktHead = pkt
if cap(pkt)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if pshSet {
return coalescePSHEnding
}
if item.numMerged == 0 {
if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalesceItemInvalidCSum
}
}
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
extendBy := coalescedLen - len(pktHead)
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
// Flip the slice headers in bufs as part of prepend. The index of item
// is already being tracked for writing.
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
} else {
pktHead = bufs[item.bufsIndex][bufsOffset:]
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalesceItemInvalidCSum
}
}
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
// We are appending a segment with PSH set.
item.pshSet = pshSet
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
}
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
item.numMerged++
return coalesceSuccess
}
const (
ipv4FlagMoreFragments uint8 = 0x20
)
const (
maxUint16 = 1<<16 - 1
)
type groResult int
const (
groResultNoop groResult = iota
groResultTableInsert
groResultCoalesced
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return groResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return groResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return groResultNoop
}
}
if len(pkt) < iphLen {
return groResultNoop
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return groResultNoop
}
if len(pkt) < iphLen+tcphLen {
return groResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return groResultNoop
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
var pshSet bool
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return groResultNoop
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return groResultNoop
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return groResultTableInsert
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
// more efficient if there are multiple items for a given flow. This
// also enables a natural table.deleteAt() in the
// coalesceItemInvalidCSum case without the need for index tracking.
// This algorithm makes a best effort to coalesce in the event of
// unordered packets, where pkt may land anywhere in items from a
// sequence number perspective, however once an item is inserted into
// the table it is never compared across other items later.
item := items[i]
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
if can != coalesceUnavailable {
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return groResultCoalesced
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return groResultNoop
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return groResultTableInsert
}
// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + item.tcphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
if item.key.isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Calculate the pseudo header checksum and place it at the TCP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if item.key.isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + udphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 6,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
if item.key.isV6 {
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Recalculate the UDP len field value
binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
// Calculate the pseudo header checksum and place it at the UDP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the udp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if item.key.isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
type groCandidateType uint8
const (
notGROCandidate groCandidateType = iota
tcp4GROCandidate
tcp6GROCandidate
udp4GROCandidate
udp6GROCandidate
)
func packetIsGROCandidate(b []byte, gro groDisablementFlags) groCandidateType {
if len(b) < 28 {
return notGROCandidate
}
if b[0]>>4 == 4 {
if b[0]&0x0F != 5 {
// IPv4 packets w/IP options do not coalesce
return notGROCandidate
}
if b[9] == unix.IPPROTO_TCP && len(b) >= 40 && gro.canTCPGRO() {
return tcp4GROCandidate
}
if b[9] == unix.IPPROTO_UDP && gro.canUDPGRO() {
return udp4GROCandidate
}
} else if b[0]>>4 == 6 {
if b[6] == unix.IPPROTO_TCP && len(b) >= 60 && gro.canTCPGRO() {
return tcp6GROCandidate
}
if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && gro.canUDPGRO() {
return udp6GROCandidate
}
}
return notGROCandidate
}
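// Editor's note on the length guards above: 28 is the smallest frame worth
// considering (20-byte IPv4 header + 8-byte UDP header); 40 covers IPv4+TCP
// (20+20), 60 covers IPv6+TCP (40+20), and 48 covers IPv6+UDP (40+8).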
const (
udphLen = 8
)
// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return groResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return groResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return groResultNoop
}
}
if len(pkt) < iphLen {
return groResultNoop
}
if len(pkt) < iphLen+udphLen {
return groResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return groResultNoop
}
}
gsoSize := uint16(len(pkt) - udphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return groResultNoop
}
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
if !existing {
return groResultTableInsert
}
// With UDP we only check the last item, otherwise we could reorder packets
// for a given flow. We must also always insert a new item, or successfully
// coalesce with an existing item, for the same reason.
item := items[len(items)-1]
can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
var pktCSumKnownInvalid bool
if can == coalesceAppend {
result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, len(items)-1)
return groResultCoalesced
case coalesceItemInvalidCSum:
// If the existing item has an invalid csum we take no action. A new
// item will be stored after it, and the existing item will never be
// revisited as part of future coalescing candidacy checks.
case coalescePktInvalidCSum:
// We must insert a new item, but we also mark it as invalid csum
// to prevent a repeat checksum validation.
pktCSumKnownInvalid = true
default:
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
return groResultTableInsert
}
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets. gro indicates if TCP and UDP GRO
// are supported/enabled.
func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var result groResult
switch packetIsGROCandidate(bufs[i][offset:], gro) {
case tcp4GROCandidate:
result = tcpGRO(bufs, offset, i, tcpTable, false)
case tcp6GROCandidate:
result = tcpGRO(bufs, offset, i, tcpTable, true)
case udp4GROCandidate:
result = udpGRO(bufs, offset, i, udpTable, false)
case udp6GROCandidate:
result = udpGRO(bufs, offset, i, udpTable, true)
}
switch result {
case groResultNoop:
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
fallthrough
case groResultTableInsert:
*toWrite = append(*toWrite, i)
}
}
errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
return errors.Join(errTCP, errUDP)
}
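// Usage sketch (editor's note): the vnetHdr Write path drives handleGRO
// roughly as follows, with hypothetical tables and offset:
//
//	toWrite := make([]int, 0, len(bufs))
//	if err := handleGRO(bufs, offset, tcpTable, udpTable, gro, &toWrite); err != nil {
//		return 0, err
//	}
//	for _, i := range toWrite {
//		// write bufs[i][offset-virtioNetHdrLen:] to the device
//	}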


@@ -0,0 +1,24 @@
//go:build darwin || freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"fmt"
)
func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error())
return
}
err = sysconn.Control(fn)
if err != nil {
tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error())
}
}

vendor/github.com/tailscale/wireguard-go/tun/tun.go generated vendored Normal file

@@ -0,0 +1,76 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"os"
)
type Event int
const (
EventUp = 1 << iota
EventDown
EventMTUUpdate
)
type Device interface {
// File returns the file descriptor of the device.
File() *os.File
// Read one or more packets from the Device (without any additional headers).
// On a successful read it returns the number of packets read, and sets
// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
// A nonzero offset can be used to instruct the Device on where to begin
// reading into each element of the bufs slice.
Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
// Write one or more packets to the device (without any additional headers).
// On a successful write it returns the number of packets written. A nonzero
// offset can be used to instruct the Device on where to begin writing from
// each packet contained within the bufs slice.
Write(bufs [][]byte, offset int) (int, error)
// MTU returns the MTU of the Device.
MTU() (int, error)
// Name returns the current name of the Device.
Name() (string, error)
// Events returns a channel of type Event, which is fed Device events.
Events() <-chan Event
// Close stops the Device and closes the Event channel.
Close() error
// BatchSize returns the preferred/max number of packets that can be read or
// written in a single read/write call. BatchSize must not change over the
// lifetime of a Device.
BatchSize() int
}
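// Usage sketch (editor's note; dev, offset, and mtu are hypothetical): a
// batched read loop sized from BatchSize looks like
//
//	bufs := make([][]byte, dev.BatchSize())
//	sizes := make([]int, len(bufs))
//	for i := range bufs {
//		bufs[i] = make([]byte, offset+mtu)
//	}
//	n, err := dev.Read(bufs, sizes, offset)
//	for i := 0; i < n && err == nil; i++ {
//		pkt := bufs[i][offset : offset+sizes[i]]
//		_ = pkt // process packet
//	}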
// GRODevice is a Device extended with methods for disabling GRO. Certain OS
// versions may have offload bugs. Where these bugs negatively impact throughput
// or break connectivity entirely we can use these methods to disable the
// related offload.
//
// Linux has the following known, GRO bugs.
//
// torvalds/linux@e269d79c7d35aa3808b1f3c1737d63dab504ddc8 broke virtio_net
// TCP & UDP GRO causing GRO writes to return EINVAL. The bug was then
// resolved later in
// torvalds/linux@89add40066f9ed9abe5f7f886fe5789ff7e0c50e. The offending
// commit was pulled into various LTS releases.
//
// UDP GRO writes end up blackholing/dropping packets destined for a
// vxlan/geneve interface on kernel versions prior to 6.8.5.
type GRODevice interface {
Device
// DisableUDPGRO disables UDP GRO if it is enabled.
DisableUDPGRO()
// DisableTCPGRO disables TCP GRO if it is enabled.
DisableTCPGRO()
}
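// Usage sketch (editor's note; the error match is an assumption, not part of
// this package): a caller might switch an offload off after hitting one of
// the kernel bugs described above:
//
//	if gd, ok := dev.(GRODevice); ok && errors.Is(err, unix.EINVAL) {
//		gd.DisableUDPGRO()
//	}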


@@ -0,0 +1,336 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/unix"
)
const utunControlName = "com.apple.net.utun_control"
type NativeTun struct {
name string
tunFile *os.File
events chan Event
errors chan error
routeSocket int
closeOnce sync.Once
}
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
if err != nil && errors.Is(err, unix.ENOMEM) {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
return iface, err
}
return nil, err
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
statusMTU int
)
defer close(tun.events)
data := make([]byte, os.Getpagesize())
for {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
goto retry
}
tun.errors <- err
return
}
if n < 14 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
iface, err := retryInterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return
}
// Up / Down event
up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up {
tun.events <- EventUp
}
if up != statusUp && !up {
tun.events <- EventDown
}
statusUp = up
// MTU changes
if iface.MTU != statusMTU {
tun.events <- EventMTUUpdate
}
statusMTU = iface.MTU
}
}
func CreateTUN(name string, mtu int) (Device, error) {
ifIndex := -1
if name != "utun" {
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
if err != nil || ifIndex < 0 {
return nil, fmt.Errorf("Interface name must be utun[0-9]*")
}
}
fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
if err != nil {
return nil, err
}
ctlInfo := &unix.CtlInfo{}
copy(ctlInfo.Name[:], []byte(utunControlName))
err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil {
unix.Close(fd)
return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
}
sc := &unix.SockaddrCtl{
ID: ctlInfo.Id,
Unit: uint32(ifIndex) + 1,
}
err = unix.Connect(fd, sc)
if err != nil {
unix.Close(fd)
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
unix.Close(fd)
return nil, err
}
tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)
if err == nil && name == "utun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
}
}
return tun, err
}
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 10),
errors: make(chan error, 5),
}
name, err := tun.Name()
if err != nil {
tun.tunFile.Close()
return nil, err
}
tunIfindex, err := func() (int, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return -1, err
}
return iface.Index, nil
}()
if err != nil {
tun.tunFile.Close()
return nil, err
}
tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
}
go tun.routineRouteListener(tunIfindex)
if mtu > 0 {
err = tun.setMTU(mtu)
if err != nil {
tun.Close()
return nil, err
}
}
return tun, nil
}
func (tun *NativeTun) Name() (string, error) {
var err error
tun.operateOnFd(func(fd uintptr) {
tun.name, err = unix.GetsockoptString(
int(fd),
2, /* #define SYSPROTO_CONTROL 2 */
2, /* #define UTUN_OPT_IFNAME 2 */
)
})
if err != nil {
return "", fmt.Errorf("GetSockoptString: %w", err)
}
return tun.name, nil
}
func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
// TODO: the BSDs look very similar in Read() and Write(). They should be
// collapsed, with platform-specific files containing the varying parts of
// their implementations.
select {
case err := <-tun.errors:
return 0, err
default:
buf := bufs[0][offset-4:]
n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
sizes[0] = n - 4
return 1, err
}
}
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
if offset < 4 {
return 0, io.ErrShortBuffer
}
for i, buf := range bufs {
buf = buf[offset-4:]
buf[0] = 0x00
buf[1] = 0x00
buf[2] = 0x00
switch buf[4] >> 4 {
case 4:
buf[3] = unix.AF_INET
case 6:
buf[3] = unix.AF_INET6
default:
return i, unix.EAFNOSUPPORT
}
if _, err := tun.tunFile.Write(buf); err != nil {
return i, err
}
}
return len(bufs), nil
}
func (tun *NativeTun) Close() error {
var err1, err2 error
tun.closeOnce.Do(func() {
err1 = tun.tunFile.Close()
if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err2 = unix.Close(tun.routeSocket)
} else if tun.events != nil {
close(tun.events)
}
})
if err1 != nil {
return err1
}
return err2
}
func (tun *NativeTun) setMTU(n int) error {
fd, err := socketCloexec(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
var ifr unix.IfreqMTU
copy(ifr.Name[:], tun.name)
ifr.MTU = int32(n)
err = unix.IoctlSetIfreqMTU(fd, &ifr)
if err != nil {
return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
fd, err := socketCloexec(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name)
if err != nil {
return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err)
}
return int(ifr.MTU), nil
}
func (tun *NativeTun) BatchSize() int {
return 1
}
func socketCloexec(family, sotype, proto int) (fd int, err error) {
// See go/src/net/sys_cloexec.go for background.
syscall.ForkLock.RLock()
defer syscall.ForkLock.RUnlock()
fd, err = unix.Socket(family, sotype, proto)
if err == nil {
unix.CloseOnExec(fd)
}
return
}


@@ -0,0 +1,435 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
const (
_TUNSIFHEAD = 0x80047460
_TUNSIFMODE = 0x8004745e
_TUNGIFNAME = 0x4020745d
_TUNSIFPID = 0x2000745f
_SIOCGIFINFO_IN6 = 0xc048696c
_SIOCSIFINFO_IN6 = 0xc048696d
_ND6_IFF_AUTO_LINKLOCAL = 0x20
_ND6_IFF_NO_DAD = 0x100
)
// Iface requests with just the name
type ifreqName struct {
Name [unix.IFNAMSIZ]byte
_ [16]byte
}
// Iface requests with a pointer
type ifreqPtr struct {
Name [unix.IFNAMSIZ]byte
Data uintptr
_ [16 - unsafe.Sizeof(uintptr(0))]byte
}
// Iface requests with MTU
type ifreqMtu struct {
Name [unix.IFNAMSIZ]byte
MTU uint32
_ [12]byte
}
// ND6 flag manipulation
type nd6Req struct {
Name [unix.IFNAMSIZ]byte
Linkmtu uint32
Maxmtu uint32
Basereachable uint32
Reachable uint32
Retrans uint32
Flags uint32
Recalctm int
Chlim uint8
Initialized uint8
Randomseed0 [8]byte
Randomseed1 [8]byte
Randomid [8]byte
}
type NativeTun struct {
name string
tunFile *os.File
events chan Event
errors chan error
routeSocket int
closeOnce sync.Once
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
statusMTU int
)
defer close(tun.events)
data := make([]byte, os.Getpagesize())
for {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errors.Is(err, syscall.EINTR) {
goto retry
}
tun.errors <- err
return
}
if n < 14 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
iface, err := net.InterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return
}
// Up / Down event
up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up {
tun.events <- EventUp
}
if up != statusUp && !up {
tun.events <- EventDown
}
statusUp = up
// MTU changes
if iface.MTU != statusMTU {
tun.events <- EventMTUUpdate
}
statusMTU = iface.MTU
}
}
func tunName(fd uintptr) (string, error) {
var ifreq ifreqName
_, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq)))
if err != 0 {
return "", err
}
return unix.ByteSliceToString(ifreq.Name[:]), nil
}
// Destroy a named system interface
func tunDestroy(name string) error {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return err
}
defer unix.Close(fd)
var ifr [32]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0])))
if errno != 0 {
return fmt.Errorf("failed to destroy interface %s: %w", name, errno)
}
return nil
}
func CreateTUN(name string, mtu int) (Device, error) {
if len(name) > unix.IFNAMSIZ-1 {
return nil, errors.New("interface name too long")
}
// See if interface already exists
iface, _ := net.InterfaceByName(name)
if iface != nil {
return nil, fmt.Errorf("interface %s already exists", name)
}
tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
return nil, err
}
tun := NativeTun{tunFile: tunFile}
var assignedName string
tun.operateOnFd(func(fd uintptr) {
assignedName, err = tunName(fd)
})
if err != nil {
tunFile.Close()
return nil, err
}
// Enable ifhead mode, otherwise tun will complain if it gets a non-AF_INET packet
ifheadmode := 1
var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode)))
})
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno)
}
// Get out of PTP mode.
ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST
tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags)))
})
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno)
}
// Disable link-local v6, not just because WireGuard doesn't do that anyway, but
// also because there are serious races with attaching and detaching LLv6 addresses
// in relation to interface lifetime within the FreeBSD kernel.
confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd6)
var ndireq nd6Req
copy(ndireq.Name[:], assignedName)
_, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno)
}
ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL
ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD
_, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno)
}
if name != "" {
confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd)
var newnp [unix.IFNAMSIZ]byte
copy(newnp[:], name)
var ifr ifreqPtr
copy(ifr.Name[:], assignedName)
ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
_, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno)
}
}
return CreateTUNFromFile(tunFile, mtu)
}
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 10),
errors: make(chan error, 1),
}
var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0))
})
if errno != 0 {
tun.tunFile.Close()
return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno)
}
name, err := tun.Name()
if err != nil {
tun.tunFile.Close()
return nil, err
}
tunIfindex, err := func() (int, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return -1, err
}
return iface.Index, nil
}()
if err != nil {
tun.tunFile.Close()
return nil, err
}
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
}
go tun.routineRouteListener(tunIfindex)
err = tun.setMTU(mtu)
if err != nil {
tun.Close()
return nil, err
}
return tun, nil
}
func (tun *NativeTun) Name() (string, error) {
var name string
var err error
tun.operateOnFd(func(fd uintptr) {
name, err = tunName(fd)
})
if err != nil {
return "", err
}
tun.name = name
return name, nil
}
func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
buf := bufs[0][offset-4:]
n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
sizes[0] = n - 4
return 1, err
}
}
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
if offset < 4 {
return 0, io.ErrShortBuffer
}
for i, buf := range bufs {
buf = buf[offset-4:]
if len(buf) < 5 {
return i, io.ErrShortBuffer
}
buf[0] = 0x00
buf[1] = 0x00
buf[2] = 0x00
switch buf[4] >> 4 {
case 4:
buf[3] = unix.AF_INET
case 6:
buf[3] = unix.AF_INET6
default:
return i, unix.EAFNOSUPPORT
}
if _, err := tun.tunFile.Write(buf); err != nil {
return i, err
}
}
return len(bufs), nil
}
func (tun *NativeTun) Close() error {
var err1, err2, err3 error
tun.closeOnce.Do(func() {
err1 = tun.tunFile.Close()
err2 = tunDestroy(tun.name)
if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err3 = unix.Close(tun.routeSocket)
tun.routeSocket = -1
} else if tun.events != nil {
close(tun.events)
}
})
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
}
func (tun *NativeTun) setMTU(n int) error {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return err
}
defer unix.Close(fd)
var ifr ifreqMtu
copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return 0, err
}
defer unix.Close(fd)
var ifr ifreqMtu
copy(ifr.Name[:], tun.name)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno)
}
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
func (tun *NativeTun) BatchSize() int {
return 1
}


@@ -0,0 +1,660 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
/* Implementation of the TUN device interface for linux
*/
import (
"errors"
"fmt"
"os"
"sync"
"syscall"
"time"
"unsafe"
"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/rwcancel"
"golang.org/x/sys/unix"
)
const (
cloneDevicePath = "/dev/net/tun"
ifReqSize = unix.IFNAMSIZ + 64
)
type NativeTun struct {
tunFile *os.File
index int32 // interface index
errors chan error // async error handling
events chan Event // device related events
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
batchSize int
vnetHdr bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
readOpMu sync.Mutex // readOpMu guards readBuff
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr is set, every read() is prefixed by a virtioNetHdr
writeOpMu sync.Mutex // writeOpMu guards the following fields
toWrite []int
tcpGROTable *tcpGROTable
udpGROTable *udpGROTable
gro groDisablementFlags
}
type groDisablementFlags int
const (
tcpGRODisabled groDisablementFlags = 1 << iota
udpGRODisabled
)
func (g *groDisablementFlags) disableTCPGRO() {
*g |= tcpGRODisabled
}
func (g *groDisablementFlags) canTCPGRO() bool {
return (*g)&tcpGRODisabled == 0
}
func (g *groDisablementFlags) disableUDPGRO() {
*g |= udpGRODisabled
}
func (g *groDisablementFlags) canUDPGRO() bool {
return (*g)&udpGRODisabled == 0
}
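// Editor's note: the flags form a plain bitmask, so the two offloads toggle
// independently:
//
//	var g groDisablementFlags
//	g.disableUDPGRO()
//	// g.canTCPGRO() == true, g.canUDPGRO() == false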
func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
func (tun *NativeTun) routineHackListener() {
defer tun.hackListenerClosed.Unlock()
/* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch.
*/
last := 0
const (
up = 1
down = 2
)
for {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return
}
err2 := sysconn.Control(func(fd uintptr) {
_, err = unix.Write(int(fd), nil)
})
if err2 != nil {
return
}
switch err {
case unix.EINVAL:
if last != up {
// If the tunnel is up, it reports that write() is
// allowed but we provided invalid data.
tun.events <- EventUp
last = up
}
case unix.EIO:
if last != down {
// If the tunnel is down, it reports that no I/O
// is possible, without checking our provided data.
tun.events <- EventDown
last = down
}
default:
return
}
select {
case <-time.After(time.Second):
// nothing
case <-tun.statusListenersShutdown:
return
}
}
}
func createNetlinkSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {
return -1, err
}
return sock, nil
}
func (tun *NativeTun) routineNetlinkListener() {
defer func() {
unix.Close(tun.netlinkSock)
tun.hackListenerClosed.Lock()
close(tun.events)
tun.netlinkCancel.Close()
}()
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !tun.netlinkCancel.ReadyRead() {
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
return
}
}
if err != nil {
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
return
}
select {
case <-tun.statusListenersShutdown:
return
default:
}
wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if int(hdr.Len) > len(remain) {
break
}
switch hdr.Type {
case unix.NLMSG_DONE:
remain = []byte{}
case unix.RTM_NEWLINK:
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
remain = remain[hdr.Len:]
if info.Index != tun.index {
// not our interface
continue
}
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
// Don't emit EventDown before we've ever emitted EventUp.
// This avoids a startup race with HackListener, which
// might detect Up before we have finished reporting Down.
if wasEverUp {
tun.events <- EventDown
}
}
tun.events <- EventMTUUpdate
default:
remain = remain[hdr.Len:]
}
}
}
}
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFINDEX),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return 0, errno
}
return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
}
func (tun *NativeTun) setMTU(n int) error {
name, err := tun.Name()
if err != nil {
return err
}
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
// do ioctl call
var ifr [ifReqSize]byte
copy(ifr[:], name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return fmt.Errorf("failed to set MTU of TUN device: %w", errno)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
name, err := tun.Name()
if err != nil {
return 0, err
}
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
// do ioctl call
var ifr [ifReqSize]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno)
}
return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
}
func (tun *NativeTun) Name() (string, error) {
tun.nameOnce.Do(tun.initNameCache)
return tun.nameCache, tun.nameErr
}
func (tun *NativeTun) initNameCache() {
tun.nameCache, tun.nameErr = tun.nameSlow()
}
func (tun *NativeTun) nameSlow() (string, error) {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return "", err
}
var ifr [ifReqSize]byte
var errno syscall.Errno
err = sysconn.Control(func(fd uintptr) {
_, _, errno = unix.Syscall(
unix.SYS_IOCTL,
fd,
uintptr(unix.TUNGETIFF),
uintptr(unsafe.Pointer(&ifr[0])),
)
})
if err != nil {
return "", fmt.Errorf("failed to get name of TUN device: %w", err)
}
if errno != 0 {
return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
}
return unix.ByteSliceToString(ifr[:]), nil
}
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.writeOpMu.Lock()
defer func() {
tun.tcpGROTable.reset()
tun.udpGROTable.reset()
tun.writeOpMu.Unlock()
}()
var (
errs error
total int
)
tun.toWrite = tun.toWrite[:0]
if tun.vnetHdr {
err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite)
if err != nil {
return 0, err
}
offset -= virtioNetHdrLen
} else {
for i := range bufs {
tun.toWrite = append(tun.toWrite, i)
}
}
for _, bufsI := range tun.toWrite {
n, err := tun.tunFile.Write(bufs[bufsI][offset:])
if errors.Is(err, syscall.EBADFD) {
return total, os.ErrClosed
}
if err != nil {
errs = errors.Join(errs, err)
} else {
total += n
}
}
return total, errs
}
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
var hdr virtioNetHdr
err := hdr.decode(in)
if err != nil {
return 0, err
}
in = in[virtioNetHdrLen:]
options, err := hdr.toGSOOptions()
if err != nil {
return 0, err
}
// Don't trust HdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the transport header length and add it onto
// CsumStart, which is synonymous for IP header length.
if options.GSOType == GSOUDPL4 {
options.HdrLen = options.CsumStart + 8
} else if options.GSOType != GSONone {
if len(in) <= int(options.CsumStart+12) {
return 0, errors.New("packet is too short")
}
tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
// A TCP header must be between 20 and 60 bytes in length.
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
options.HdrLen = options.CsumStart + tcpHLen
}
return GSOSplit(in, options, bufs, sizes, offset)
}
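// Worked example (editor's note): for a TSO read of an IPv4/TCP superpacket
// with a 20-byte IP header and a 32-byte TCP header, the kernel reports
// CsumStart=20; the data-offset nibble at in[32] gives tcpHLen=32, so HdrLen
// becomes 52 and GSOSplit replicates those 52 header bytes into each output
// segment.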
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
tun.readOpMu.Lock()
defer tun.readOpMu.Unlock()
select {
case err := <-tun.errors:
return 0, err
default:
readInto := bufs[0][offset:]
if tun.vnetHdr {
readInto = tun.readBuff[:]
}
n, err := tun.tunFile.Read(readInto)
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
if err != nil {
return 0, err
}
if tun.vnetHdr {
return handleVirtioRead(readInto[:n], bufs, sizes, offset)
} else {
sizes[0] = n
return 1, nil
}
}
}
func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
var err1, err2 error
tun.closeOnce.Do(func() {
if tun.statusListenersShutdown != nil {
close(tun.statusListenersShutdown)
if tun.netlinkCancel != nil {
err1 = tun.netlinkCancel.Cancel()
}
} else if tun.events != nil {
close(tun.events)
}
err2 = tun.tunFile.Close()
})
if err1 != nil {
return err1
}
return err2
}
func (tun *NativeTun) BatchSize() int {
return tun.batchSize
}
// DisableUDPGRO disables UDP GRO if it is enabled. See the GRODevice interface
// for cases where it should be called.
func (tun *NativeTun) DisableUDPGRO() {
tun.writeOpMu.Lock()
tun.gro.disableUDPGRO()
tun.writeOpMu.Unlock()
}
// DisableTCPGRO disables TCP GRO if it is enabled. See the GRODevice interface
// for cases where it should be called.
func (tun *NativeTun) DisableTCPGRO() {
tun.writeOpMu.Lock()
tun.gro.disableTCPGRO()
tun.writeOpMu.Unlock()
}
const (
// TODO: support TSO with ECN bits
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
func (tun *NativeTun) initFromFlags(name string) error {
sc, err := tun.tunFile.SyscallConn()
if err != nil {
return err
}
if e := sc.Control(func(fd uintptr) {
var (
ifr *unix.Ifreq
)
ifr, err = unix.NewIfreq(name)
if err != nil {
return
}
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
if err != nil {
return
}
got := ifr.Uint16()
if got&unix.IFF_VNET_HDR != 0 {
// tunTCPOffloads were added in Linux v2.6. We require their support
// if IFF_VNET_HDR is set.
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
if err != nil {
return
}
tun.vnetHdr = true
tun.batchSize = conn.IdealBatchSize
// tunUDPOffloads were added in Linux v6.2. We do not return an
// error if they are unsupported at runtime.
if unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) != nil {
tun.gro.disableUDPGRO()
}
} else {
tun.batchSize = 1
}
}); e != nil {
return e
}
return err
}
// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
}
return nil, err
}
ifr, err := unix.NewIfreq(name)
if err != nil {
return nil, err
}
// IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
// where a null write will return EINVAL indicating the TUN is up.
ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
if err != nil {
return nil, err
}
err = unix.SetNonblock(nfd, true)
if err != nil {
unix.Close(nfd)
return nil, err
}
// Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
return CreateTUNFromFile(fd, mtu)
}
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
tcpGROTable: newTCPGROTable(),
udpGROTable: newUDPGROTable(),
toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, err
}
err = tun.initFromFlags(name)
if err != nil {
return nil, err
}
// start event listener
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
}
tun.netlinkSock, err = createNetlinkSocket()
if err != nil {
return nil, err
}
tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
if err != nil {
unix.Close(tun.netlinkSock)
return nil, err
}
tun.hackListenerClosed.Lock()
go tun.routineNetlinkListener()
go tun.routineHackListener() // cross namespace
err = tun.setMTU(mtu)
if err != nil {
unix.Close(tun.netlinkSock)
return nil, err
}
return tun, nil
}
// CreateUnmonitoredTUNFromFD creates a Device from the provided file
// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
return nil, "", err
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
tcpGROTable: newTCPGROTable(),
udpGROTable: newUDPGROTable(),
toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
err = tun.initFromFlags(name)
if err != nil {
return nil, "", err
}
return tun, name, err
}


@@ -0,0 +1,333 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// Structure for iface mtu get/set ioctls
type ifreq_mtu struct {
Name [unix.IFNAMSIZ]byte
MTU uint32
Pad0 [12]byte
}
const _TUNSIFMODE = 0x8004745d
type NativeTun struct {
name string
tunFile *os.File
events chan Event
errors chan error
routeSocket int
closeOnce sync.Once
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
statusMTU int
)
defer close(tun.events)
check := func() bool {
iface, err := net.InterfaceByIndex(tunIfindex)
if err != nil {
tun.errors <- err
return true
}
// Up / Down event
up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up {
tun.events <- EventUp
}
if up != statusUp && !up {
tun.events <- EventDown
}
statusUp = up
// MTU changes
if iface.MTU != statusMTU {
tun.events <- EventMTUUpdate
}
statusMTU = iface.MTU
return false
}
if check() {
return
}
data := make([]byte, os.Getpagesize())
for {
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
continue
}
tun.errors <- err
return
}
if n < 8 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
if check() {
return
}
}
}
func CreateTUN(name string, mtu int) (Device, error) {
ifIndex := -1
if name != "tun" {
_, err := fmt.Sscanf(name, "tun%d", &ifIndex)
if err != nil || ifIndex < 0 {
return nil, fmt.Errorf("Interface name must be tun[0-9]*")
}
}
var tunfile *os.File
var err error
if ifIndex != -1 {
tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
} else {
for ifIndex = 0; ifIndex < 256; ifIndex++ {
tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
if err == nil || !errors.Is(err, syscall.EBUSY) {
break
}
}
}
if err != nil {
return nil, err
}
tun, err := CreateTUNFromFile(tunfile, mtu)
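// If the caller asked for the generic name "tun", record which tunN device
// was actually claimed so external tooling can discover it.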
if err == nil && name == "tun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
}
}
return tun, err
}
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 10),
errors: make(chan error, 1),
}
name, err := tun.Name()
if err != nil {
tun.tunFile.Close()
return nil, err
}
tunIfindex, err := func() (int, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return -1, err
}
return iface.Index, nil
}()
if err != nil {
tun.tunFile.Close()
return nil, err
}
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
}
go tun.routineRouteListener(tunIfindex)
currentMTU, err := tun.MTU()
if err != nil || currentMTU != mtu {
err = tun.setMTU(mtu)
if err != nil {
tun.Close()
return nil, err
}
}
return tun, nil
}
func (tun *NativeTun) Name() (string, error) {
gostat, err := tun.tunFile.Stat()
if err != nil {
tun.name = ""
return "", err
}
stat := gostat.Sys().(*syscall.Stat_t)
tun.name = fmt.Sprintf("tun%d", stat.Rdev%256)
return tun.name, nil
}
func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
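// The OpenBSD tun driver prefixes every packet with a 4-byte
// address-family header; read into the caller's headroom so the header
// lands just before offset, and report only the payload length.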
buf := bufs[0][offset-4:]
n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
sizes[0] = n - 4
return 1, err
}
}
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
if offset < 4 {
return 0, io.ErrShortBuffer
}
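// Prepend the 4-byte address-family header the OpenBSD tun driver
// expects, reusing the caller's headroom: bytes 0-2 are zero and byte 3
// selects AF_INET or AF_INET6 from the packet's IP version nibble.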
for i, buf := range bufs {
buf = buf[offset-4:]
buf[0] = 0x00
buf[1] = 0x00
buf[2] = 0x00
switch buf[4] >> 4 {
case 4:
buf[3] = unix.AF_INET
case 6:
buf[3] = unix.AF_INET6
default:
return i, unix.EAFNOSUPPORT
}
if _, err := tun.tunFile.Write(buf); err != nil {
return i, err
}
}
return len(bufs), nil
}
func (tun *NativeTun) Close() error {
var err1, err2 error
tun.closeOnce.Do(func() {
err1 = tun.tunFile.Close()
if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err2 = unix.Close(tun.routeSocket)
tun.routeSocket = -1
} else if tun.events != nil {
close(tun.events)
}
})
if err1 != nil {
return err1
}
return err2
}
func (tun *NativeTun) setMTU(n int) error {
// open datagram socket
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
// do ioctl call
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 {
return fmt.Errorf("failed to set MTU on %s", tun.name)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
// do ioctl call
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
}
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
func (tun *NativeTun) BatchSize() int {
return 1
}

View File

@@ -0,0 +1,242 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
_ "unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wintun"
)
const (
rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
spinloopRateThreshold = 800000000 / 8 // 800mbps
spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // 12.5µs, about the time of one full-size packet at ~1Gbit/s
)
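// rateJuggler keeps a lock-free estimate of recent receive throughput;
// Read consults it to decide when spinning on the ring outperforms
// blocking on the kernel event.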
type rateJuggler struct {
current atomic.Uint64
nextByteCount atomic.Uint64
nextStartTime atomic.Int64
changing atomic.Bool
}
type NativeTun struct {
wt *wintun.Adapter
name string
handle windows.Handle
rate rateJuggler
session wintun.Session
readWait windows.Handle
events chan Event
running sync.WaitGroup
closeOnce sync.Once
close atomic.Bool
forcedMTU int
outSizes []int
}
var (
WintunTunnelType = "WireGuard"
WintunStaticRequestedGUID *windows.GUID
)
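// procyield and nanotime are linked against runtime internals to get a
// cheap CPU spin hint and a monotonic clock without extra allocation.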
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err)
}
forcedMTU := 1420 // default to the usual WireGuard MTU when mtu <= 0
if mtu > 0 {
forcedMTU = mtu
}
tun := &NativeTun{
wt: wt,
name: ifname,
handle: windows.InvalidHandle,
events: make(chan Event, 10),
forcedMTU: forcedMTU,
}
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
tun.wt.Close()
close(tun.events)
return nil, fmt.Errorf("Error starting session: %w", err)
}
tun.readWait = tun.session.ReadWaitEvent()
return tun, nil
}
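// A usage sketch (not part of the original source); the GUID literal is
// illustrative, fixed by the caller so the adapter and its Windows network
// profile remain stable across restarts:
//
//	guid, err := windows.GUIDFromString("{c7d3e0a2-0f5b-4a9e-9c1d-2b6f8e4a5d01}")
//	if err != nil {
//		log.Fatal(err)
//	}
//	dev, err := tun.CreateTUNWithRequestedGUID("wg0", &guid, 1420)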
func (tun *NativeTun) Name() (string, error) {
return tun.name, nil
}
func (tun *NativeTun) File() *os.File {
return nil
}
func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
var err error
tun.closeOnce.Do(func() {
tun.close.Store(true)
windows.SetEvent(tun.readWait)
tun.running.Wait()
tun.session.End()
if tun.wt != nil {
tun.wt.Close()
}
close(tun.events)
})
return err
}
func (tun *NativeTun) MTU() (int, error) {
return tun.forcedMTU, nil
}
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
if tun.close.Load() {
return
}
update := tun.forcedMTU != mtu
tun.forcedMTU = mtu
if update {
tun.events <- EventMTUUpdate
}
}
func (tun *NativeTun) BatchSize() int {
// TODO: implement batching with wintun
return 1
}
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
retry:
if tun.close.Load() {
return 0, os.ErrClosed
}
start := nanotime()
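// Spin briefly instead of blocking only while the measured receive rate
// is high and the rate estimate is fresh; otherwise fall through to a
// kernel wait on the read event.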
shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
if tun.close.Load() {
return 0, os.ErrClosed
}
packet, err := tun.session.ReceivePacket()
switch err {
case nil:
packetSize := len(packet)
copy(bufs[0][offset:], packet)
sizes[0] = packetSize
tun.session.ReleaseReceivePacket(packet)
tun.rate.update(uint64(packetSize))
return 1, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
goto retry
}
procyield(1)
continue
case windows.ERROR_HANDLE_EOF:
return 0, os.ErrClosed
case windows.ERROR_INVALID_DATA:
return 0, errors.New("Send ring corrupt")
}
return 0, fmt.Errorf("Read failed: %w", err)
}
}
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
if tun.close.Load() {
return 0, os.ErrClosed
}
for i, buf := range bufs {
packetSize := len(buf) - offset
tun.rate.update(uint64(packetSize))
packet, err := tun.session.AllocateSendPacket(packetSize)
switch err {
case nil:
// TODO: Explore options to eliminate this copy.
copy(packet, buf[offset:])
tun.session.SendPacket(packet)
continue
case windows.ERROR_HANDLE_EOF:
return i, os.ErrClosed
case windows.ERROR_BUFFER_OVERFLOW:
continue // Dropping when ring is full.
default:
return i, fmt.Errorf("Write failed: %w", err)
}
}
return len(bufs), nil
}
// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
tun.running.Add(1)
defer tun.running.Done()
if tun.close.Load() {
return 0
}
return tun.wt.LUID()
}
// RunningVersion returns the running version of the Wintun driver.
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
return wintun.RunningVersion()
}
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
total := rate.nextByteCount.Add(packetLen)
period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
if !rate.changing.CompareAndSwap(false, true) {
return
}
rate.nextStartTime.Store(now)
rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
rate.nextByteCount.Store(0)
rate.changing.Store(false)
}
}
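// Worked example (not in the original source): if 6 MB accumulate over a
// full 500 ms window, the stored rate is
// 6_000_000 * 1_000_000_000 / 500_000_000 = 12_000_000 bytes/s, well below
// spinloopRateThreshold (100_000_000 bytes/s), so Read would block on the
// read-wait event rather than spin.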