This commit is contained in:
2026-02-19 10:07:43 +00:00
parent 007438e372
commit 6e637ecf77
1763 changed files with 60820 additions and 279516 deletions

View File

@@ -7,6 +7,7 @@ package bakedroots
import (
"crypto/x509"
"fmt"
"sync"
"tailscale.com/util/testenv"
@@ -14,7 +15,7 @@ import (
// Get returns the baked-in roots.
//
// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 root.
// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 & X2 roots.
func Get() *x509.CertPool {
roots.once.Do(func() {
roots.parsePEM(append(
@@ -25,16 +26,9 @@ func Get() *x509.CertPool {
return roots.p
}
// testingTB is a subset of testing.TB needed
// to verify the caller isn't in a parallel test.
type testingTB interface {
// Setenv panics if it's in a parallel test.
Setenv(k, v string)
}
// ResetForTest resets the cached roots for testing,
// optionally setting them to caPEM if non-nil.
func ResetForTest(tb testingTB, caPEM []byte) {
func ResetForTest(tb testenv.TB, caPEM []byte) {
if !testenv.InTest() {
panic("not in test")
}
@@ -43,6 +37,10 @@ func ResetForTest(tb testingTB, caPEM []byte) {
roots = rootsOnce{}
if caPEM != nil {
roots.once.Do(func() { roots.parsePEM(caPEM) })
tb.Cleanup(func() {
// Reset the roots to real roots for any following test.
roots = rootsOnce{}
})
}
}
@@ -56,7 +54,7 @@ type rootsOnce struct {
func (r *rootsOnce) parsePEM(caPEM []byte) {
p := x509.NewCertPool()
if !p.AppendCertsFromPEM(caPEM) {
panic("bogus PEM")
panic(fmt.Sprintf("bogus PEM: %q", caPEM))
}
r.p = p
}

47
vendor/tailscale.com/net/batching/conn.go generated vendored Normal file
View File

@@ -0,0 +1,47 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package batching implements a socket optimized for increased throughput.
package batching
import (
"net/netip"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"tailscale.com/net/packet"
"tailscale.com/types/nettype"
)
var (
// This acts as a compile-time check for our usage of ipv6.Message in
// [Conn] for both IPv6 and IPv4 operations.
_ ipv6.Message = ipv4.Message{}
)
// Conn is a nettype.PacketConn that provides batched i/o using
// platform-specific optimizations, e.g. {recv,send}mmsg & UDP GSO/GRO.
//
// Conn originated from (and is still used by) magicsock where its API was
// strongly influenced by [wireguard-go/conn.Bind] constraints, namely
// wireguard-go's ownership of packet memory.
type Conn interface {
nettype.PacketConn
// ReadBatch reads messages from [Conn] into msgs. It returns the number of
// messages the caller should evaluate for nonzero len, as a zero len
// message may fall on either side of a nonzero.
//
// Each [ipv6.Message.OOB] must be sized to at least MinControlMessageSize().
ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
// WriteBatchTo writes buffs to addr.
//
// If geneve.VNI.IsSet(), then geneve is encoded into the space preceding
// offset, and offset must equal [packet.GeneveFixedHeaderLength]. If
// !geneve.VNI.IsSet() then the space preceding offset is ignored.
//
// len(buffs) must be <= batchSize supplied in TryUpgradeToConn().
//
// WriteBatchTo may return a [neterror.ErrUDPGSODisabled] error if UDP GSO
// was disabled as a result of a send error.
WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error
}

23
vendor/tailscale.com/net/batching/conn_default.go generated vendored Normal file
View File

@@ -0,0 +1,23 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux
package batching
import (
"tailscale.com/types/nettype"
)
// TryUpgradeToConn is no-op on all platforms except linux.
func TryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn {
return pconn
}
var controlMessageSize = 0
func MinControlMessageSize() int {
return controlMessageSize
}
const IdealBatchSize = 1

460
vendor/tailscale.com/net/batching/conn_linux.go generated vendored Normal file
View File

@@ -0,0 +1,460 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package batching
import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"tailscale.com/hostinfo"
"tailscale.com/net/neterror"
"tailscale.com/net/packet"
"tailscale.com/types/nettype"
)
// xnetBatchReaderWriter defines the batching i/o methods of
// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn).
// TODO(jwhited): This should eventually be replaced with the standard library
// implementation of https://github.com/golang/go/issues/45886
type xnetBatchReaderWriter interface {
xnetBatchReader
xnetBatchWriter
}
type xnetBatchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type xnetBatchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
var (
// [linuxBatchingConn] implements [Conn].
_ Conn = (*linuxBatchingConn)(nil)
)
// linuxBatchingConn is a UDP socket that provides batched i/o. It implements
// [Conn].
type linuxBatchingConn struct {
pc *net.UDPConn
xpc xnetBatchReaderWriter
rxOffload bool // supports UDP GRO or similar
txOffload atomic.Bool // supports UDP GSO or similar
setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing
getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing
sendBatchPool sync.Pool
}
func (c *linuxBatchingConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
if c.rxOffload {
// UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may
// receive a "monster datagram" from any read call. The ReadFrom() API
// does not support passing the GSO size and is unsafe to use in such a
// case. Other platforms may vary in behavior, but we go with the most
// conservative approach to prevent this from becoming a footgun in the
// future.
return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
}
return c.pc.ReadFromUDPAddrPort(p)
}
func (c *linuxBatchingConn) SetDeadline(t time.Time) error {
return c.pc.SetDeadline(t)
}
func (c *linuxBatchingConn) SetReadDeadline(t time.Time) error {
return c.pc.SetReadDeadline(t)
}
func (c *linuxBatchingConn) SetWriteDeadline(t time.Time) error {
return c.pc.SetWriteDeadline(t)
}
const (
// This was initially established for Linux, but may split out to
// GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the
// kernel's TX path, and UDP_GRO_CNT_MAX for RX.
udpSegmentMaxDatagrams = 64
)
const (
// Exceeding these values results in EMSGSIZE.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
)
// coalesceMessages iterates 'buffs', setting and coalescing them in 'msgs'
// where possible while maintaining datagram order.
//
// All msgs have their Addr field set to addr.
//
// All msgs[i].Buffers[0] are preceded by a Geneve header (geneve) if geneve.VNI.IsSet().
func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, geneve packet.GeneveHeader, buffs [][]byte, msgs []ipv6.Message, offset int) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of buffs
)
maxPayloadLen := maxIPv4PayloadLen
if addr.IP.To4() == nil {
maxPayloadLen = maxIPv6PayloadLen
}
vniIsSet := geneve.VNI.IsSet()
for i, buff := range buffs {
if vniIsSet {
geneve.Encode(buff)
} else {
buff = buff[offset:]
}
if i > 0 {
msgLen := len(buff)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...)
copy(msgs[base].Buffers[0][baseLenBefore:], buff)
if i == len(buffs)-1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buff)
msgs[base].OOB = msgs[base].OOB[:0]
msgs[base].Buffers[0] = buff
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type sendBatch struct {
msgs []ipv6.Message
ua *net.UDPAddr
}
func (c *linuxBatchingConn) getSendBatch() *sendBatch {
batch := c.sendBatchPool.Get().(*sendBatch)
return batch
}
func (c *linuxBatchingConn) putSendBatch(batch *sendBatch) {
for i := range batch.msgs {
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB}
}
c.sendBatchPool.Put(batch)
}
func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error {
batch := c.getSendBatch()
defer c.putSendBatch(batch)
if addr.Addr().Is6() {
as16 := addr.Addr().As16()
copy(batch.ua.IP, as16[:])
batch.ua.IP = batch.ua.IP[:16]
} else {
as4 := addr.Addr().As4()
copy(batch.ua.IP, as4[:])
batch.ua.IP = batch.ua.IP[:4]
}
batch.ua.Port = int(addr.Port())
var (
n int
retried bool
)
retry:
if c.txOffload.Load() {
n = c.coalesceMessages(batch.ua, geneve, buffs, batch.msgs, offset)
} else {
vniIsSet := geneve.VNI.IsSet()
if vniIsSet {
offset -= packet.GeneveFixedHeaderLength
}
for i := range buffs {
if vniIsSet {
geneve.Encode(buffs[i])
}
batch.msgs[i].Buffers[0] = buffs[i][offset:]
batch.msgs[i].Addr = batch.ua
batch.msgs[i].OOB = batch.msgs[i].OOB[:0]
}
n = len(buffs)
}
err := c.writeBatch(batch.msgs[:n])
if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) {
c.txOffload.Store(false)
retried = true
goto retry
}
if retried {
return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err}
}
return err
}
func (c *linuxBatchingConn) SyscallConn() (syscall.RawConn, error) {
return c.pc.SyscallConn()
}
func (c *linuxBatchingConn) writeBatch(msgs []ipv6.Message) error {
var head int
for {
n, err := c.xpc.WriteBatch(msgs[head:], 0)
if err != nil || n == len(msgs[head:]) {
// Returning the number of packets written would require
// unraveling individual msg len and gso size during a coalesced
// write. The top of the call stack disregards partial success,
// so keep this simple for now.
return err
}
head += n
}
}
// splitCoalescedMessages splits coalesced messages from the tail of dst
// beginning at index 'firstMsgAt' into the head of the same slice. It reports
// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An
// error is returned if a socket control message cannot be parsed or a split
// operation would overflow msgs.
func (c *linuxBatchingConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}
func (c *linuxBatchingConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
if !c.rxOffload || len(msgs) < 2 {
return c.xpc.ReadBatch(msgs, flags)
}
// Read into the tail of msgs, split into the head.
readAt := len(msgs) - 2
numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0)
if err != nil || numRead == 0 {
return 0, err
}
return c.splitCoalescedMessages(msgs, readAt)
}
func (c *linuxBatchingConn) LocalAddr() net.Addr {
return c.pc.LocalAddr().(*net.UDPAddr)
}
func (c *linuxBatchingConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
return c.pc.WriteToUDPAddrPort(b, addr)
}
func (c *linuxBatchingConn) Close() error {
return c.pc.Close()
}
// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn,
// and returns two booleans indicating TX and RX UDP offload support.
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
if c, ok := pconn.(*net.UDPConn); ok {
rc, err := c.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
hasTX = errSyscall == nil
errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
hasRX = errSyscall == nil
})
if err != nil {
return false, false
}
}
return hasTX, hasRX
}
// getGSOSizeFromControl returns the GSO size found in control. If no GSO size
// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0.
// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but
// its contents cannot be parsed as a socket control message.
func getGSOSizeFromControl(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 {
return int(binary.NativeEndian.Uint16(data[:2])), nil
}
}
return 0, nil
}
// setGSOSizeInControl sets a socket control message in control containing
// gsoSize. If len(control) < controlMessageSize control's len will be set to 0.
func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
*control = (*control)[:0]
if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
return
}
if cap(*control) < controlMessageSize {
return
}
*control = (*control)[:cap(*control)]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(2))
binary.NativeEndian.PutUint16((*control)[unix.SizeofCmsghdr:], gsoSize)
*control = (*control)[:unix.CmsgSpace(2)]
}
// TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades
// pconn to a [Conn] if appropriate. A batch size of [IdealBatchSize] is
// suggested for the best performance.
func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
if runtime.GOOS != "linux" {
// Exclude Android.
return pconn
}
if network != "udp4" && network != "udp6" {
return pconn
}
if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") {
// recvmmsg/sendmmsg were added in 2.6.33, but we support down to
// 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807.
// As a cheap heuristic: if the Linux kernel starts with "2", just
// consider it too old for mmsg. Nobody who cares about performance runs
// such ancient kernels. UDP offload was added much later, so no
// upgrades are available.
return pconn
}
uc, ok := pconn.(*net.UDPConn)
if !ok {
return pconn
}
b := &linuxBatchingConn{
pc: uc,
getGSOSizeFromControl: getGSOSizeFromControl,
setGSOSizeInControl: setGSOSizeInControl,
sendBatchPool: sync.Pool{
New: func() any {
ua := &net.UDPAddr{
IP: make([]byte, 16),
}
msgs := make([]ipv6.Message, batchSize)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
msgs[i].OOB = make([]byte, controlMessageSize)
}
return &sendBatch{
ua: ua,
msgs: msgs,
}
},
},
}
switch network {
case "udp4":
b.xpc = ipv4.NewPacketConn(uc)
case "udp6":
b.xpc = ipv6.NewPacketConn(uc)
default:
panic("bogus network")
}
var txOffload bool
txOffload, b.rxOffload = tryEnableUDPOffload(uc)
b.txOffload.Store(txOffload)
return b
}
var controlMessageSize = -1 // bomb if used for allocation before init
func init() {
// controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control
// message. These contain a single uint16 of data.
controlMessageSize = unix.CmsgSpace(2)
}
// MinControlMessageSize returns the minimum control message size required to
// support read batching via [Conn.ReadBatch].
func MinControlMessageSize() int {
return controlMessageSize
}
const IdealBatchSize = 128

View File

@@ -18,6 +18,7 @@ import (
"time"
"tailscale.com/net/netmon"
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/types/logger"
)
@@ -32,7 +33,7 @@ type Detector struct {
// currIfIndex is the index of the interface that is currently being used by the httpClient.
currIfIndex int
// mu guards currIfIndex.
mu sync.Mutex
mu syncs.Mutex
// logf is the logger used for logging messages. If it is nil, log.Printf is used.
logf logger.Logf
}

View File

@@ -1,222 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package connstats maintains statistics about connections
// flowing through a TUN device (which operate at the IP layer).
package connstats
import (
"context"
"net/netip"
"sync"
"time"
"golang.org/x/sync/errgroup"
"tailscale.com/net/packet"
"tailscale.com/net/tsaddr"
"tailscale.com/types/netlogtype"
)
// Statistics maintains counters for every connection.
// All methods are safe for concurrent use.
// The zero value is ready for use.
type Statistics struct {
maxConns int // immutable once set
mu sync.Mutex
connCnts
connCntsCh chan connCnts
shutdownCtx context.Context
shutdown context.CancelFunc
group errgroup.Group
}
type connCnts struct {
start time.Time
end time.Time
virtual map[netlogtype.Connection]netlogtype.Counts
physical map[netlogtype.Connection]netlogtype.Counts
}
// NewStatistics creates a data structure for tracking connection statistics
// that periodically dumps the virtual and physical connection counts
// depending on whether the maxPeriod or maxConns is exceeded.
// The dump function is called from a single goroutine.
// Shutdown must be called to cleanup resources.
func NewStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) *Statistics {
s := &Statistics{maxConns: maxConns}
s.connCntsCh = make(chan connCnts, 256)
s.shutdownCtx, s.shutdown = context.WithCancel(context.Background())
s.group.Go(func() error {
// TODO(joetsai): Using a ticker is problematic on mobile platforms
// where waking up a process every maxPeriod when there is no activity
// is a drain on battery life. Switch this instead to instead use
// a time.Timer that is triggered upon network activity.
ticker := new(time.Ticker)
if maxPeriod > 0 {
ticker = time.NewTicker(maxPeriod)
defer ticker.Stop()
}
for {
var cc connCnts
select {
case cc = <-s.connCntsCh:
case <-ticker.C:
cc = s.extract()
case <-s.shutdownCtx.Done():
cc = s.extract()
}
if len(cc.virtual)+len(cc.physical) > 0 && dump != nil {
dump(cc.start, cc.end, cc.virtual, cc.physical)
}
if s.shutdownCtx.Err() != nil {
return nil
}
}
})
return s
}
// UpdateTxVirtual updates the counters for a transmitted IP packet
// The source and destination of the packet directly correspond with
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateTxVirtual(b []byte) {
s.updateVirtual(b, false)
}
// UpdateRxVirtual updates the counters for a received IP packet.
// The source and destination of the packet are inverted with respect to
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateRxVirtual(b []byte) {
s.updateVirtual(b, true)
}
var (
tailscaleServiceIPv4 = tsaddr.TailscaleServiceIP()
tailscaleServiceIPv6 = tsaddr.TailscaleServiceIPv6()
)
func (s *Statistics) updateVirtual(b []byte, receive bool) {
var p packet.Parsed
p.Decode(b)
conn := netlogtype.Connection{Proto: p.IPProto, Src: p.Src, Dst: p.Dst}
if receive {
conn.Src, conn.Dst = conn.Dst, conn.Src
}
// Network logging is defined as traffic between two Tailscale nodes.
// Traffic with the internal Tailscale service is not with another node
// and should not be logged. It also happens to be a high volume
// amount of discrete traffic flows (e.g., DNS lookups).
switch conn.Dst.Addr() {
case tailscaleServiceIPv4, tailscaleServiceIPv6:
return
}
s.mu.Lock()
defer s.mu.Unlock()
cnts, found := s.virtual[conn]
if !found && !s.preInsertConn() {
return
}
if receive {
cnts.RxPackets++
cnts.RxBytes += uint64(len(b))
} else {
cnts.TxPackets++
cnts.TxBytes += uint64(len(b))
}
s.virtual[conn] = cnts
}
// UpdateTxPhysical updates the counters for zero or more transmitted wireguard packets.
// The src is always a Tailscale IP address, representing some remote peer.
// The dst is a remote IP address and port that corresponds
// with some physical peer backing the Tailscale IP address.
func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) {
s.updatePhysical(src, dst, packets, bytes, false)
}
// UpdateRxPhysical updates the counters for zero or more received wireguard packets.
// The src is always a Tailscale IP address, representing some remote peer.
// The dst is a remote IP address and port that corresponds
// with some physical peer backing the Tailscale IP address.
func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) {
s.updatePhysical(src, dst, packets, bytes, true)
}
func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int, receive bool) {
conn := netlogtype.Connection{Src: netip.AddrPortFrom(src, 0), Dst: dst}
s.mu.Lock()
defer s.mu.Unlock()
cnts, found := s.physical[conn]
if !found && !s.preInsertConn() {
return
}
if receive {
cnts.RxPackets += uint64(packets)
cnts.RxBytes += uint64(bytes)
} else {
cnts.TxPackets += uint64(packets)
cnts.TxBytes += uint64(bytes)
}
s.physical[conn] = cnts
}
// preInsertConn updates the maps to handle insertion of a new connection.
// It reports false if insertion is not allowed (i.e., after shutdown).
func (s *Statistics) preInsertConn() bool {
// Check whether insertion of a new connection will exceed maxConns.
if len(s.virtual)+len(s.physical) == s.maxConns && s.maxConns > 0 {
// Extract the current statistics and send it to the serializer.
// Avoid blocking the network packet handling path.
select {
case s.connCntsCh <- s.extractLocked():
default:
// TODO(joetsai): Log that we are dropping an entire connCounts.
}
}
// Initialize the maps if nil.
if s.virtual == nil && s.physical == nil {
s.start = time.Now().UTC()
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
}
return s.shutdownCtx.Err() == nil
}
func (s *Statistics) extract() connCnts {
s.mu.Lock()
defer s.mu.Unlock()
return s.extractLocked()
}
func (s *Statistics) extractLocked() connCnts {
if len(s.virtual)+len(s.physical) == 0 {
return connCnts{}
}
s.end = time.Now().UTC()
cc := s.connCnts
s.connCnts = connCnts{}
return cc
}
// TestExtract synchronously extracts the current network statistics map
// and resets the counters. This should only be used for testing purposes.
func (s *Statistics) TestExtract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
cc := s.extract()
return cc.virtual, cc.physical
}
// Shutdown performs a final flush of statistics.
// Statistics for any subsequent calls to Update will be dropped.
// It is safe to call Shutdown concurrently and repeatedly.
func (s *Statistics) Shutdown(context.Context) error {
s.shutdown()
return s.group.Wait()
}

View File

@@ -1,6 +1,8 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:generate go run tailscale.com/cmd/viewer --type=Config --clonefunc
// Package dns contains code to configure and manage DNS settings.
package dns
@@ -8,8 +10,12 @@ import (
"bufio"
"fmt"
"net/netip"
"reflect"
"slices"
"sort"
"tailscale.com/control/controlknobs"
"tailscale.com/envknob"
"tailscale.com/net/dns/publicdns"
"tailscale.com/net/dns/resolver"
"tailscale.com/net/tsaddr"
@@ -47,11 +53,28 @@ type Config struct {
OnlyIPv6 bool
}
func (c *Config) serviceIP() netip.Addr {
var magicDNSDualStack = envknob.RegisterBool("TS_DEBUG_MAGIC_DNS_DUAL_STACK")
// serviceIPs returns the list of service IPs where MagicDNS is reachable.
//
// The provided knobs may be nil.
func (c *Config) serviceIPs(knobs *controlknobs.Knobs) []netip.Addr {
if c.OnlyIPv6 {
return tsaddr.TailscaleServiceIPv6()
return []netip.Addr{tsaddr.TailscaleServiceIPv6()}
}
return tsaddr.TailscaleServiceIP()
// TODO(bradfitz,mikeodr,raggi): include IPv6 here too; tailscale/tailscale#15404
// And add a controlknobs knob to disable dual stack.
//
// For now, opt-in for testing.
if magicDNSDualStack() {
return []netip.Addr{
tsaddr.TailscaleServiceIP(),
tsaddr.TailscaleServiceIPv6(),
}
}
return []netip.Addr{tsaddr.TailscaleServiceIP()}
}
// WriteToBufioWriter write a debug version of c for logs to w, omitting
@@ -162,21 +185,16 @@ func sameResolverNames(a, b []*dnstype.Resolver) bool {
if a[i].Addr != b[i].Addr {
return false
}
if !sameIPs(a[i].BootstrapResolution, b[i].BootstrapResolution) {
if !slices.Equal(a[i].BootstrapResolution, b[i].BootstrapResolution) {
return false
}
}
return true
}
func sameIPs(a, b []netip.Addr) bool {
if len(a) != len(b) {
return false
func (c *Config) Equal(o *Config) bool {
if c == nil || o == nil {
return c == o
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
return reflect.DeepEqual(c, o)
}

59
vendor/tailscale.com/net/dns/dbus.go generated vendored Normal file
View File

@@ -0,0 +1,59 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux && !android && !ts_omit_dbus
package dns
import (
"context"
"time"
"github.com/godbus/dbus/v5"
)
func init() {
optDBusPing.Set(dbusPing)
optDBusReadString.Set(dbusReadString)
}
func dbusPing(name, objectPath string) error {
conn, err := dbus.SystemBus()
if err != nil {
// DBus probably not running.
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
obj := conn.Object(name, dbus.ObjectPath(objectPath))
call := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0)
return call.Err
}
// dbusReadString reads a string property from the provided name and object
// path. property must be in "interface.member" notation.
func dbusReadString(name, objectPath, iface, member string) (string, error) {
conn, err := dbus.SystemBus()
if err != nil {
// DBus probably not running.
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
obj := conn.Object(name, dbus.ObjectPath(objectPath))
var result dbus.Variant
err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, iface, member).Store(&result)
if err != nil {
return "", err
}
if s, ok := result.Value().(string); ok {
return s, nil
}
return result.String(), nil
}

View File

@@ -1,7 +1,7 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux || freebsd || openbsd
//go:build (linux && !android) || freebsd || openbsd
package dns

View File

@@ -1,6 +1,8 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !android && !ios
package dns
import (
@@ -19,13 +21,16 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"tailscale.com/feature"
"tailscale.com/health"
"tailscale.com/net/dns/resolvconffile"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/util/eventbus"
"tailscale.com/version/distro"
)
@@ -132,6 +137,11 @@ type directManager struct {
// but is better than having non-functioning DNS.
renameBroken bool
trampleCount atomic.Int64
trampleTimer *time.Timer
eventClient *eventbus.Client
trampleDNSPub *eventbus.Publisher[TrampleDNS]
ctx context.Context // valid until Close
ctxClose context.CancelFunc // closes ctx
@@ -142,11 +152,13 @@ type directManager struct {
}
//lint:ignore U1000 used in manager_{freebsd,openbsd}.go
func newDirectManager(logf logger.Logf, health *health.Tracker) *directManager {
return newDirectManagerOnFS(logf, health, directFS{})
func newDirectManager(logf logger.Logf, health *health.Tracker, bus *eventbus.Bus) *directManager {
return newDirectManagerOnFS(logf, health, bus, directFS{})
}
func newDirectManagerOnFS(logf logger.Logf, health *health.Tracker, fs wholeFileFS) *directManager {
var trampleWatchDuration = 5 * time.Second
func newDirectManagerOnFS(logf logger.Logf, health *health.Tracker, bus *eventbus.Bus, fs wholeFileFS) *directManager {
ctx, cancel := context.WithCancel(context.Background())
m := &directManager{
logf: logf,
@@ -155,6 +167,13 @@ func newDirectManagerOnFS(logf logger.Logf, health *health.Tracker, fs wholeFile
ctx: ctx,
ctxClose: cancel,
}
if bus != nil {
m.eventClient = bus.Client("dns.directManager")
m.trampleDNSPub = eventbus.Publish[TrampleDNS](m.eventClient)
}
m.trampleTimer = time.AfterFunc(trampleWatchDuration, func() {
m.trampleCount.Store(0)
})
go m.runFileWatcher()
return m
}
@@ -413,8 +432,91 @@ func (m *directManager) GetBaseConfig() (OSConfig, error) {
return oscfg, nil
}
// HookWatchFile is a hook for watching file changes, for platforms that support it.
// The function is called with a directory and filename to watch, and a callback
// to call when the file changes. It returns an error if the watch could not be set up.
var HookWatchFile feature.Hook[func(ctx context.Context, dir, filename string, cb func()) error]
func (m *directManager) runFileWatcher() {
watchFile, ok := HookWatchFile.GetOk()
if !ok {
return
}
if err := watchFile(m.ctx, "/etc/", resolvConf, m.checkForFileTrample); err != nil {
// This is all best effort for now, so surface warnings to users.
m.logf("dns: inotify: %s", err)
}
}
var resolvTrampleWarnable = health.Register(&health.Warnable{
Code: "resolv-conf-overwritten",
Severity: health.SeverityMedium,
Title: "DNS configuration issue",
Text: health.StaticMessage("System DNS config not ideal. /etc/resolv.conf overwritten. See https://tailscale.com/s/dns-fight"),
})
// checkForFileTrample checks whether /etc/resolv.conf has been trampled
// by another program on the system. (e.g. a DHCP client)
func (m *directManager) checkForFileTrample() {
m.mu.Lock()
want := m.wantResolvConf
lastWarn := m.lastWarnContents
m.mu.Unlock()
if want == nil {
return
}
cur, err := m.fs.ReadFile(resolvConf)
if err != nil {
m.logf("trample: read error: %v", err)
return
}
if bytes.Equal(cur, want) {
m.health.SetHealthy(resolvTrampleWarnable)
if lastWarn != nil {
m.mu.Lock()
m.lastWarnContents = nil
m.mu.Unlock()
m.logf("trample: resolv.conf again matches expected content")
}
return
}
if bytes.Equal(cur, lastWarn) {
// We already logged about this, so not worth doing it again.
return
}
m.mu.Lock()
m.lastWarnContents = cur
m.mu.Unlock()
show := cur
if len(show) > 1024 {
show = show[:1024]
}
m.logf("trample: resolv.conf changed from what we expected. did some other program interfere? current contents: %q", show)
m.health.SetUnhealthy(resolvTrampleWarnable, nil)
if m.trampleDNSPub != nil {
n := m.trampleCount.Add(1)
if n < 10 {
m.trampleDNSPub.Publish(TrampleDNS{
LastTrample: time.Now(),
TramplesInTimeout: n,
})
m.trampleTimer.Reset(trampleWatchDuration)
} else {
m.logf("trample: resolv.conf overwritten %d times, no longer attempting to replace it.", n)
}
}
}
func (m *directManager) Close() error {
m.ctxClose()
if m.eventClient != nil {
m.eventClient.Close()
}
// We used to keep a file for the tailscale config and symlinked
// to it, but then we stopped because /etc/resolv.conf being a

View File

@@ -1,102 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package dns
import (
"bytes"
"context"
"fmt"
"github.com/illarion/gonotify/v3"
"tailscale.com/health"
)
func (m *directManager) runFileWatcher() {
if err := watchFile(m.ctx, "/etc/", resolvConf, m.checkForFileTrample); err != nil {
// This is all best effort for now, so surface warnings to users.
m.logf("dns: inotify: %s", err)
}
}
// watchFile sets up an inotify watch for a given directory and
// calls the callback function every time a particular file is changed.
// The filename should be located in the provided directory.
func watchFile(ctx context.Context, dir, filename string, cb func()) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
const events = gonotify.IN_ATTRIB |
gonotify.IN_CLOSE_WRITE |
gonotify.IN_CREATE |
gonotify.IN_DELETE |
gonotify.IN_MODIFY |
gonotify.IN_MOVE
watcher, err := gonotify.NewDirWatcher(ctx, events, dir)
if err != nil {
return fmt.Errorf("NewDirWatcher: %w", err)
}
for {
select {
case event := <-watcher.C:
if event.Name == filename {
cb()
}
case <-ctx.Done():
return ctx.Err()
}
}
}
var resolvTrampleWarnable = health.Register(&health.Warnable{
Code: "resolv-conf-overwritten",
Severity: health.SeverityMedium,
Title: "Linux DNS configuration issue",
Text: health.StaticMessage("Linux DNS config not ideal. /etc/resolv.conf overwritten. See https://tailscale.com/s/dns-fight"),
})
// checkForFileTrample checks whether /etc/resolv.conf has been trampled
// by another program on the system. (e.g. a DHCP client)
func (m *directManager) checkForFileTrample() {
m.mu.Lock()
want := m.wantResolvConf
lastWarn := m.lastWarnContents
m.mu.Unlock()
if want == nil {
return
}
cur, err := m.fs.ReadFile(resolvConf)
if err != nil {
m.logf("trample: read error: %v", err)
return
}
if bytes.Equal(cur, want) {
m.health.SetHealthy(resolvTrampleWarnable)
if lastWarn != nil {
m.mu.Lock()
m.lastWarnContents = nil
m.mu.Unlock()
m.logf("trample: resolv.conf again matches expected content")
}
return
}
if bytes.Equal(cur, lastWarn) {
// We already logged about this, so not worth doing it again.
return
}
m.mu.Lock()
m.lastWarnContents = cur
m.mu.Unlock()
show := cur
if len(show) > 1024 {
show = show[:1024]
}
m.logf("trample: resolv.conf changed from what we expected. did some other program interfere? current contents: %q", show)
m.health.SetUnhealthy(resolvTrampleWarnable, nil)
}

View File

@@ -1,10 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux
package dns
func (m *directManager) runFileWatcher() {
// Not implemented on other platforms. Maybe it could resort to polling.
}

74
vendor/tailscale.com/net/dns/dns_clone.go generated vendored Normal file
View File

@@ -0,0 +1,74 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Code generated by tailscale.com/cmd/cloner; DO NOT EDIT.
package dns
import (
"net/netip"
"tailscale.com/types/dnstype"
"tailscale.com/util/dnsname"
)
// Clone makes a deep copy of Config.
// The result aliases no memory with the original.
func (src *Config) Clone() *Config {
if src == nil {
return nil
}
dst := new(Config)
*dst = *src
if src.DefaultResolvers != nil {
dst.DefaultResolvers = make([]*dnstype.Resolver, len(src.DefaultResolvers))
for i := range dst.DefaultResolvers {
if src.DefaultResolvers[i] == nil {
dst.DefaultResolvers[i] = nil
} else {
dst.DefaultResolvers[i] = src.DefaultResolvers[i].Clone()
}
}
}
if dst.Routes != nil {
dst.Routes = map[dnsname.FQDN][]*dnstype.Resolver{}
for k := range src.Routes {
dst.Routes[k] = append([]*dnstype.Resolver{}, src.Routes[k]...)
}
}
dst.SearchDomains = append(src.SearchDomains[:0:0], src.SearchDomains...)
if dst.Hosts != nil {
dst.Hosts = map[dnsname.FQDN][]netip.Addr{}
for k := range src.Hosts {
dst.Hosts[k] = append([]netip.Addr{}, src.Hosts[k]...)
}
}
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _ConfigCloneNeedsRegeneration = Config(struct {
DefaultResolvers []*dnstype.Resolver
Routes map[dnsname.FQDN][]*dnstype.Resolver
SearchDomains []dnsname.FQDN
Hosts map[dnsname.FQDN][]netip.Addr
OnlyIPv6 bool
}{})
// Clone duplicates src into dst and reports whether it succeeded.
// To succeed, <src, dst> must be of types <*T, *T> or <*T, **T>,
// where T is one of Config.
func Clone(dst, src any) bool {
switch src := src.(type) {
case *Config:
switch dst := dst.(type) {
case *Config:
*dst = *src.Clone()
return true
case **Config:
*dst = src.Clone()
return true
}
}
return false
}

138
vendor/tailscale.com/net/dns/dns_view.go generated vendored Normal file
View File

@@ -0,0 +1,138 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Code generated by tailscale/cmd/viewer; DO NOT EDIT.
package dns
import (
jsonv1 "encoding/json"
"errors"
"net/netip"
jsonv2 "github.com/go-json-experiment/json"
"github.com/go-json-experiment/json/jsontext"
"tailscale.com/types/dnstype"
"tailscale.com/types/views"
"tailscale.com/util/dnsname"
)
//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=Config
// View returns a read-only view of Config.
func (p *Config) View() ConfigView {
return ConfigView{ж: p}
}
// ConfigView provides a read-only view over Config.
//
// Its methods should only be called if `Valid()` returns true.
type ConfigView struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *Config
}
// Valid reports whether v's underlying value is non-nil.
func (v ConfigView) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v ConfigView) AsStruct() *Config {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
// MarshalJSON implements [jsonv1.Marshaler].
func (v ConfigView) MarshalJSON() ([]byte, error) {
return jsonv1.Marshal(v.ж)
}
// MarshalJSONTo implements [jsonv2.MarshalerTo].
func (v ConfigView) MarshalJSONTo(enc *jsontext.Encoder) error {
return jsonv2.MarshalEncode(enc, v.ж)
}
// UnmarshalJSON implements [jsonv1.Unmarshaler].
func (v *ConfigView) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x Config
if err := jsonv1.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom].
func (v *ConfigView) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
if v.ж != nil {
return errors.New("already initialized")
}
var x Config
if err := jsonv2.UnmarshalDecode(dec, &x); err != nil {
return err
}
v.ж = &x
return nil
}
// DefaultResolvers are the DNS resolvers to use for DNS names
// which aren't covered by more specific per-domain routes below.
// If empty, the OS's default resolvers (the ones that predate
// Tailscale altering the configuration) are used.
func (v ConfigView) DefaultResolvers() views.SliceView[*dnstype.Resolver, dnstype.ResolverView] {
return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](v.ж.DefaultResolvers)
}
// Routes maps a DNS suffix to the resolvers that should be used
// for queries that fall within that suffix.
// If a query doesn't match any entry in Routes, the
// DefaultResolvers are used.
// A Routes entry with no resolvers means the route should be
// authoritatively answered using the contents of Hosts.
func (v ConfigView) Routes() views.MapFn[dnsname.FQDN, []*dnstype.Resolver, views.SliceView[*dnstype.Resolver, dnstype.ResolverView]] {
return views.MapFnOf(v.ж.Routes, func(t []*dnstype.Resolver) views.SliceView[*dnstype.Resolver, dnstype.ResolverView] {
return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](t)
})
}
// SearchDomains are DNS suffixes to try when expanding
// single-label queries.
func (v ConfigView) SearchDomains() views.Slice[dnsname.FQDN] {
return views.SliceOf(v.ж.SearchDomains)
}
// Hosts maps DNS FQDNs to their IPs, which can be a mix of IPv4
// and IPv6.
// Queries matching entries in Hosts are resolved locally by
// 100.100.100.100 without leaving the machine.
// Adding an entry to Hosts merely creates the record. If you want
// it to resolve, you also need to add appropriate routes to
// Routes.
func (v ConfigView) Hosts() views.MapSlice[dnsname.FQDN, netip.Addr] {
return views.MapSliceOf(v.ж.Hosts)
}
// OnlyIPv6, if true, uses the IPv6 service IP (for MagicDNS)
// instead of the IPv4 version (100.100.100.100).
func (v ConfigView) OnlyIPv6() bool { return v.ж.OnlyIPv6 }
func (v ConfigView) Equal(v2 ConfigView) bool { return v.ж.Equal(v2.ж) }
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _ConfigViewNeedsRegeneration = Config(struct {
DefaultResolvers []*dnstype.Resolver
Routes map[dnsname.FQDN][]*dnstype.Resolver
SearchDomains []dnsname.FQDN
Hosts map[dnsname.FQDN][]netip.Addr
OnlyIPv6 bool
}{})

View File

@@ -20,17 +20,19 @@ import (
"time"
"tailscale.com/control/controlknobs"
"tailscale.com/feature/buildfeatures"
"tailscale.com/health"
"tailscale.com/net/dns/resolver"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/syncs"
"tailscale.com/tstime/rate"
"tailscale.com/types/dnstype"
"tailscale.com/types/logger"
"tailscale.com/util/clientmetric"
"tailscale.com/util/dnsname"
"tailscale.com/util/eventbus"
"tailscale.com/util/slicesx"
"tailscale.com/util/syspolicy/policyclient"
)
var (
@@ -53,6 +55,8 @@ type Manager struct {
logf logger.Logf
health *health.Tracker
eventClient *eventbus.Client
activeQueriesAtomic int32
ctx context.Context // good until Down
@@ -63,16 +67,17 @@ type Manager struct {
knobs *controlknobs.Knobs // or nil
goos string // if empty, gets set to runtime.GOOS
mu sync.Mutex // guards following
// config is the last configuration we successfully compiled or nil if there
// was any failure applying the last configuration.
config *Config
mu sync.Mutex // guards following
config *Config // Tracks the last viable DNS configuration set by Set. nil on failures other than compilation failures or if set has never been called.
}
// NewManagers created a new manager from the given config.
// NewManager created a new manager from the given config.
//
// knobs may be nil.
func NewManager(logf logger.Logf, oscfg OSConfigurator, health *health.Tracker, dialer *tsdial.Dialer, linkSel resolver.ForwardLinkSelector, knobs *controlknobs.Knobs, goos string) *Manager {
func NewManager(logf logger.Logf, oscfg OSConfigurator, health *health.Tracker, dialer *tsdial.Dialer, linkSel resolver.ForwardLinkSelector, knobs *controlknobs.Knobs, goos string, bus *eventbus.Bus) *Manager {
if !buildfeatures.HasDNS {
return nil
}
if dialer == nil {
panic("nil Dialer")
}
@@ -93,19 +98,17 @@ func NewManager(logf logger.Logf, oscfg OSConfigurator, health *health.Tracker,
goos: goos,
}
// Rate limit our attempts to correct our DNS configuration.
// This is done on incoming queries, we don't want to spam it.
limiter := rate.NewLimiter(1.0/5.0, 1)
// This will recompile the DNS config, which in turn will requery the system
// DNS settings. The recovery func should triggered only when we are missing
// upstream nameservers and require them to forward a query.
m.resolver.SetMissingUpstreamRecovery(func() {
if limiter.Allow() {
m.logf("resolution failed due to missing upstream nameservers. Recompiling DNS configuration.")
if err := m.RecompileDNSConfig(); err != nil {
m.logf("config recompilation failed: %v", err)
}
m.eventClient = bus.Client("dns.Manager")
eventbus.SubscribeFunc(m.eventClient, func(trample TrampleDNS) {
m.mu.Lock()
defer m.mu.Unlock()
if m.config == nil {
m.logf("resolve.conf was trampled, but there is no DNS config")
return
}
m.logf("resolve.conf was trampled, setting existing config again")
if err := m.setLocked(*m.config); err != nil {
m.logf("error setting DNS config: %s", err)
}
})
@@ -115,9 +118,14 @@ func NewManager(logf logger.Logf, oscfg OSConfigurator, health *health.Tracker,
}
// Resolver returns the Manager's DNS Resolver.
func (m *Manager) Resolver() *resolver.Resolver { return m.resolver }
func (m *Manager) Resolver() *resolver.Resolver {
if !buildfeatures.HasDNS {
return nil
}
return m.resolver
}
// RecompileDNSConfig sets the DNS config to the current value, which has
// RecompileDNSConfig recompiles the last attempted DNS configuration, which has
// the side effect of re-querying the OS's interface nameservers. This should be used
// on platforms where the interface nameservers can change. Darwin, for example,
// where the nameservers aren't always available when we process a major interface
@@ -127,17 +135,23 @@ func (m *Manager) Resolver() *resolver.Resolver { return m.resolver }
// give a better or different result than when [Manager.Set] was last called. The
// logic for making that determination is up to the caller.
//
// It returns [ErrNoDNSConfig] if the [Manager] has no existing DNS configuration.
// It returns [ErrNoDNSConfig] if [Manager.Set] has never been called.
func (m *Manager) RecompileDNSConfig() error {
if !buildfeatures.HasDNS {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
if m.config == nil {
return ErrNoDNSConfig
if m.config != nil {
return m.setLocked(*m.config)
}
return m.setLocked(*m.config)
return ErrNoDNSConfig
}
func (m *Manager) Set(cfg Config) error {
if !buildfeatures.HasDNS {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return m.setLocked(cfg)
@@ -145,6 +159,9 @@ func (m *Manager) Set(cfg Config) error {
// GetBaseConfig returns the current base OS DNS configuration as provided by the OSConfigurator.
func (m *Manager) GetBaseConfig() (OSConfig, error) {
if !buildfeatures.HasDNS {
panic("unreachable")
}
return m.os.GetBaseConfig()
}
@@ -154,15 +171,15 @@ func (m *Manager) GetBaseConfig() (OSConfig, error) {
func (m *Manager) setLocked(cfg Config) error {
syncs.AssertLocked(&m.mu)
// On errors, the 'set' config is cleared.
m.config = nil
m.logf("Set: %v", logger.ArgWriter(func(w *bufio.Writer) {
cfg.WriteToBufioWriter(w)
}))
rcfg, ocfg, err := m.compileConfig(cfg)
if err != nil {
// On a compilation failure, set m.config set for later reuse by
// [Manager.RecompileDNSConfig] and return the error.
m.config = &cfg
return err
}
@@ -174,10 +191,10 @@ func (m *Manager) setLocked(cfg Config) error {
}))
if err := m.resolver.SetConfig(rcfg); err != nil {
m.config = nil
return err
}
if err := m.os.SetDNS(ocfg); err != nil {
m.health.SetUnhealthy(osConfigurationSetWarnable, health.Args{health.ArgError: err.Error()})
if err := m.setDNSLocked(ocfg); err != nil {
return err
}
@@ -187,6 +204,15 @@ func (m *Manager) setLocked(cfg Config) error {
return nil
}
func (m *Manager) setDNSLocked(ocfg OSConfig) error {
if err := m.os.SetDNS(ocfg); err != nil {
m.config = nil
m.health.SetUnhealthy(osConfigurationSetWarnable, health.Args{health.ArgError: err.Error()})
return err
}
return nil
}
// compileHostEntries creates a list of single-label resolutions possible
// from the configured hosts and search domains.
// The entries are compiled in the order of the search domains, then the hosts.
@@ -284,7 +310,7 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig
// Deal with trivial configs first.
switch {
case !cfg.needsOSResolver():
case !cfg.needsOSResolver() || runtime.GOOS == "plan9":
// Set search domains, but nothing else. This also covers the
// case where cfg is entirely zero, in which case these
// configs clear all Tailscale DNS settings.
@@ -307,7 +333,7 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig
// through quad-100.
rcfg.Routes = routes
rcfg.Routes["."] = cfg.DefaultResolvers
ocfg.Nameservers = []netip.Addr{cfg.serviceIP()}
ocfg.Nameservers = cfg.serviceIPs(m.knobs)
return rcfg, ocfg, nil
}
@@ -345,7 +371,7 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig
// or routes + MagicDNS, or just MagicDNS, or on an OS that cannot
// split-DNS. Install a split config pointing at quad-100.
rcfg.Routes = routes
ocfg.Nameservers = []netip.Addr{cfg.serviceIP()}
ocfg.Nameservers = cfg.serviceIPs(m.knobs)
var baseCfg *OSConfig // base config; non-nil if/when known
@@ -355,7 +381,10 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig
// that as the forwarder for all DNS traffic that quad-100 doesn't handle.
if isApple || !m.os.SupportsSplitDNS() {
// If the OS can't do native split-dns, read out the underlying
// resolver config and blend it into our config.
// resolver config and blend it into our config. On apple platforms, [OSConfigurator.GetBaseConfig]
// has a tendency to temporarily fail if called immediately following
// an interface change. These failures should be retried if/when the OS
// indicates that the DNS configuration has changed via [RecompileDNSConfig].
cfg, err := m.os.GetBaseConfig()
if err == nil {
baseCfg = &cfg
@@ -451,6 +480,13 @@ const (
maxReqSizeTCP = 4096
)
// TrampleDNS is an an event indicating we detected that DNS config was
// overwritten by another process.
type TrampleDNS struct {
LastTrample time.Time
TramplesInTimeout int64
}
// dnsTCPSession services DNS requests sent over TCP.
type dnsTCPSession struct {
m *Manager
@@ -572,15 +608,22 @@ func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netip.AddrPort) {
}
func (m *Manager) Down() error {
if !buildfeatures.HasDNS {
return nil
}
m.ctxCancel()
if err := m.os.Close(); err != nil {
return err
}
m.eventClient.Close()
m.resolver.Close()
return nil
}
func (m *Manager) FlushCaches() error {
if !buildfeatures.HasDNS {
return nil
}
return flushCaches()
}
@@ -589,20 +632,22 @@ func (m *Manager) FlushCaches() error {
// No other state needs to be instantiated before this runs.
//
// health must not be nil
func CleanUp(logf logger.Logf, netMon *netmon.Monitor, health *health.Tracker, interfaceName string) {
oscfg, err := NewOSConfigurator(logf, nil, nil, interfaceName)
func CleanUp(logf logger.Logf, netMon *netmon.Monitor, bus *eventbus.Bus, health *health.Tracker, interfaceName string) {
if !buildfeatures.HasDNS {
return
}
oscfg, err := NewOSConfigurator(logf, health, bus, policyclient.Get(), nil, interfaceName)
if err != nil {
logf("creating dns cleanup: %v", err)
return
}
d := &tsdial.Dialer{Logf: logf}
d.SetNetMon(netMon)
dns := NewManager(logf, oscfg, health, d, nil, nil, runtime.GOOS)
d.SetBus(bus)
dns := NewManager(logf, oscfg, health, d, nil, nil, runtime.GOOS, bus)
if err := dns.Down(); err != nil {
logf("dns down: %v", err)
}
}
var (
metricDNSQueryErrorQueue = clientmetric.NewCounter("dns_query_local_error_queue")
)
var metricDNSQueryErrorQueue = clientmetric.NewCounter("dns_query_local_error_queue")

View File

@@ -13,13 +13,15 @@ import (
"tailscale.com/net/dns/resolvconffile"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
"tailscale.com/util/mak"
"tailscale.com/util/syspolicy/policyclient"
)
// NewOSConfigurator creates a new OS configurator.
//
// The health tracker and the knobs may be nil and are ignored on this platform.
func NewOSConfigurator(logf logger.Logf, _ *health.Tracker, _ *controlknobs.Knobs, ifName string) (OSConfigurator, error) {
// The health tracker, bus and the knobs may be nil and are ignored on this platform.
func NewOSConfigurator(logf logger.Logf, _ *health.Tracker, _ *eventbus.Bus, _ policyclient.Client, _ *controlknobs.Knobs, ifName string) (OSConfigurator, error) {
return &darwinConfigurator{logf: logf, ifName: ifName}, nil
}

View File

@@ -1,7 +1,7 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux && !freebsd && !openbsd && !windows && !darwin && !illumos && !solaris
//go:build (!linux || android) && !freebsd && !openbsd && !windows && !darwin && !illumos && !solaris && !plan9
package dns
@@ -9,11 +9,13 @@ import (
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
"tailscale.com/util/syspolicy/policyclient"
)
// NewOSConfigurator creates a new OS configurator.
//
// The health tracker and the knobs may be nil and are ignored on this platform.
func NewOSConfigurator(logger.Logf, *health.Tracker, *controlknobs.Knobs, string) (OSConfigurator, error) {
func NewOSConfigurator(logger.Logf, *health.Tracker, *eventbus.Bus, policyclient.Client, *controlknobs.Knobs, string) (OSConfigurator, error) {
return NewNoopManager()
}

View File

@@ -10,15 +10,17 @@ import (
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
"tailscale.com/util/syspolicy/policyclient"
)
// NewOSConfigurator creates a new OS configurator.
//
// The health tracker may be nil; the knobs may be nil and are ignored on this platform.
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs.Knobs, _ string) (OSConfigurator, error) {
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, bus *eventbus.Bus, _ policyclient.Client, _ *controlknobs.Knobs, _ string) (OSConfigurator, error) {
bs, err := os.ReadFile("/etc/resolv.conf")
if os.IsNotExist(err) {
return newDirectManager(logf, health), nil
return newDirectManager(logf, health, bus), nil
}
if err != nil {
return nil, fmt.Errorf("reading /etc/resolv.conf: %w", err)
@@ -28,16 +30,16 @@ func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs
case "resolvconf":
switch resolvconfStyle() {
case "":
return newDirectManager(logf, health), nil
return newDirectManager(logf, health, bus), nil
case "debian":
return newDebianResolvconfManager(logf)
case "openresolv":
return newOpenresolvManager(logf)
default:
logf("[unexpected] got unknown flavor of resolvconf %q, falling back to direct manager", resolvconfStyle())
return newDirectManager(logf, health), nil
return newDirectManager(logf, health, bus), nil
}
default:
return newDirectManager(logf, health), nil
return newDirectManager(logf, health, bus), nil
}
}

View File

@@ -1,11 +1,12 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux && !android
package dns
import (
"bytes"
"context"
"errors"
"fmt"
"os"
@@ -13,13 +14,16 @@ import (
"sync"
"time"
"github.com/godbus/dbus/v5"
"tailscale.com/control/controlknobs"
"tailscale.com/feature"
"tailscale.com/feature/buildfeatures"
"tailscale.com/health"
"tailscale.com/net/netaddr"
"tailscale.com/types/logger"
"tailscale.com/util/clientmetric"
"tailscale.com/util/cmpver"
"tailscale.com/util/eventbus"
"tailscale.com/util/syspolicy/policyclient"
"tailscale.com/version/distro"
)
type kv struct {
@@ -32,18 +36,59 @@ func (kv kv) String() string {
var publishOnce sync.Once
// reconfigTimeout is the time interval within which Manager.{Up,Down} should complete.
//
// This is particularly useful because certain conditions can cause indefinite hangs
// (such as improper dbus auth followed by contextless dbus.Object.Call).
// Such operations should be wrapped in a timeout context.
const reconfigTimeout = time.Second
// Set unless ts_omit_networkmanager
var (
optNewNMManager feature.Hook[func(ifName string) (OSConfigurator, error)]
optNMIsUsingResolved feature.Hook[func() error]
optNMVersionBetween feature.Hook[func(v1, v2 string) (bool, error)]
)
// Set unless ts_omit_resolved
var (
optNewResolvedManager feature.Hook[func(logf logger.Logf, health *health.Tracker, interfaceName string) (OSConfigurator, error)]
)
// Set unless ts_omit_dbus
var (
optDBusPing feature.Hook[func(name, objectPath string) error]
optDBusReadString feature.Hook[func(name, objectPath, iface, member string) (string, error)]
)
// NewOSConfigurator created a new OS configurator.
//
// The health tracker may be nil; the knobs may be nil and are ignored on this platform.
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs.Knobs, interfaceName string) (ret OSConfigurator, err error) {
env := newOSConfigEnv{
fs: directFS{},
dbusPing: dbusPing,
dbusReadString: dbusReadString,
nmIsUsingResolved: nmIsUsingResolved,
nmVersionBetween: nmVersionBetween,
resolvconfStyle: resolvconfStyle,
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, bus *eventbus.Bus, _ policyclient.Client, _ *controlknobs.Knobs, interfaceName string) (ret OSConfigurator, err error) {
if !buildfeatures.HasDNS || distro.Get() == distro.JetKVM {
return NewNoopManager()
}
env := newOSConfigEnv{
fs: directFS{},
resolvconfStyle: resolvconfStyle,
}
if f, ok := optDBusPing.GetOk(); ok {
env.dbusPing = f
} else {
env.dbusPing = func(_, _ string) error { return errors.ErrUnsupported }
}
if f, ok := optDBusReadString.GetOk(); ok {
env.dbusReadString = f
} else {
env.dbusReadString = func(_, _, _, _ string) (string, error) { return "", errors.ErrUnsupported }
}
if f, ok := optNMIsUsingResolved.GetOk(); ok {
env.nmIsUsingResolved = f
} else {
env.nmIsUsingResolved = func() error { return errors.ErrUnsupported }
}
env.nmVersionBetween, _ = optNMVersionBetween.GetOk() // GetOk to not panic if nil; unused if optNMIsUsingResolved returns an error
mode, err := dnsMode(logf, health, env)
if err != nil {
return nil, err
@@ -56,19 +101,26 @@ func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs
logf("dns: using %q mode", mode)
switch mode {
case "direct":
return newDirectManagerOnFS(logf, health, env.fs), nil
return newDirectManagerOnFS(logf, health, bus, env.fs), nil
case "systemd-resolved":
return newResolvedManager(logf, health, interfaceName)
if f, ok := optNewResolvedManager.GetOk(); ok {
return f(logf, health, interfaceName)
}
return nil, fmt.Errorf("tailscaled was built without DNS %q support", mode)
case "network-manager":
return newNMManager(interfaceName)
if f, ok := optNewNMManager.GetOk(); ok {
return f(interfaceName)
}
return nil, fmt.Errorf("tailscaled was built without DNS %q support", mode)
case "debian-resolvconf":
return newDebianResolvconfManager(logf)
case "openresolv":
return newOpenresolvManager(logf)
default:
logf("[unexpected] detected unknown DNS mode %q, using direct manager as last resort", mode)
return newDirectManagerOnFS(logf, health, env.fs), nil
}
return newDirectManagerOnFS(logf, health, bus, env.fs), nil
}
// newOSConfigEnv are the funcs newOSConfigurator needs, pulled out for testing.
@@ -284,50 +336,6 @@ func dnsMode(logf logger.Logf, health *health.Tracker, env newOSConfigEnv) (ret
}
}
func nmVersionBetween(first, last string) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
// DBus probably not running.
return false, err
}
nm := conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath("/org/freedesktop/NetworkManager"))
v, err := nm.GetProperty("org.freedesktop.NetworkManager.Version")
if err != nil {
return false, err
}
version, ok := v.Value().(string)
if !ok {
return false, fmt.Errorf("unexpected type %T for NM version", v.Value())
}
outside := cmpver.Compare(version, first) < 0 || cmpver.Compare(version, last) > 0
return !outside, nil
}
func nmIsUsingResolved() error {
conn, err := dbus.SystemBus()
if err != nil {
// DBus probably not running.
return err
}
nm := conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath("/org/freedesktop/NetworkManager/DnsManager"))
v, err := nm.GetProperty("org.freedesktop.NetworkManager.DnsManager.Mode")
if err != nil {
return fmt.Errorf("getting NM mode: %w", err)
}
mode, ok := v.Value().(string)
if !ok {
return fmt.Errorf("unexpected type %T for NM DNS mode", v.Value())
}
if mode != "systemd-resolved" {
return errors.New("NetworkManager is not using systemd-resolved for DNS")
}
return nil
}
// resolvedIsActuallyResolver reports whether the system is using
// systemd-resolved as the resolver. There are two different ways to
// use systemd-resolved:
@@ -388,44 +396,3 @@ func isLibnssResolveUsed(env newOSConfigEnv) error {
}
return fmt.Errorf("libnss_resolve not used")
}
func dbusPing(name, objectPath string) error {
conn, err := dbus.SystemBus()
if err != nil {
// DBus probably not running.
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
obj := conn.Object(name, dbus.ObjectPath(objectPath))
call := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0)
return call.Err
}
// dbusReadString reads a string property from the provided name and object
// path. property must be in "interface.member" notation.
func dbusReadString(name, objectPath, iface, member string) (string, error) {
conn, err := dbus.SystemBus()
if err != nil {
// DBus probably not running.
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
obj := conn.Object(name, dbus.ObjectPath(objectPath))
var result dbus.Variant
err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, iface, member).Store(&result)
if err != nil {
return "", err
}
if s, ok := result.Value().(string); ok {
return s, nil
}
return result.String(), nil
}

View File

@@ -11,6 +11,8 @@ import (
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
"tailscale.com/util/syspolicy/policyclient"
)
type kv struct {
@@ -24,8 +26,8 @@ func (kv kv) String() string {
// NewOSConfigurator created a new OS configurator.
//
// The health tracker may be nil; the knobs may be nil and are ignored on this platform.
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) {
return newOSConfigurator(logf, health, interfaceName,
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, bus *eventbus.Bus, _ policyclient.Client, _ *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) {
return newOSConfigurator(logf, health, bus, interfaceName,
newOSConfigEnv{
rcIsResolvd: rcIsResolvd,
fs: directFS{},
@@ -38,7 +40,7 @@ type newOSConfigEnv struct {
rcIsResolvd func(resolvConfContents []byte) bool
}
func newOSConfigurator(logf logger.Logf, health *health.Tracker, interfaceName string, env newOSConfigEnv) (ret OSConfigurator, err error) {
func newOSConfigurator(logf logger.Logf, health *health.Tracker, bus *eventbus.Bus, interfaceName string, env newOSConfigEnv) (ret OSConfigurator, err error) {
var debug []kv
dbg := func(k, v string) {
debug = append(debug, kv{k, v})
@@ -53,7 +55,7 @@ func newOSConfigurator(logf logger.Logf, health *health.Tracker, interfaceName s
bs, err := env.fs.ReadFile(resolvConf)
if os.IsNotExist(err) {
dbg("rc", "missing")
return newDirectManager(logf, health), nil
return newDirectManager(logf, health, bus), nil
}
if err != nil {
return nil, fmt.Errorf("reading /etc/resolv.conf: %w", err)
@@ -65,7 +67,7 @@ func newOSConfigurator(logf logger.Logf, health *health.Tracker, interfaceName s
}
dbg("resolvd", "missing")
return newDirectManager(logf, health), nil
return newDirectManager(logf, health, bus), nil
}
func rcIsResolvd(resolvConfContents []byte) bool {

183
vendor/tailscale.com/net/dns/manager_plan9.go generated vendored Normal file
View File

@@ -0,0 +1,183 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// TODO: man 6 ndb | grep -e 'suffix.*same line'
// to detect Russ's https://9fans.topicbox.com/groups/9fans/T9c9d81b5801a0820/ndb-suffix-specific-dns-changes
package dns
import (
"bufio"
"bytes"
"fmt"
"io"
"net/netip"
"os"
"regexp"
"strings"
"unicode"
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/policyclient"
)
func NewOSConfigurator(logf logger.Logf, ht *health.Tracker, _ *eventbus.Bus, _ policyclient.Client, knobs *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) {
return &plan9DNSManager{
logf: logf,
ht: ht,
knobs: knobs,
}, nil
}
type plan9DNSManager struct {
logf logger.Logf
ht *health.Tracker
knobs *controlknobs.Knobs
}
// netNDBBytesWithoutTailscale returns raw (the contents of /net/ndb) with any
// Tailscale bits removed.
func netNDBBytesWithoutTailscale(raw []byte) ([]byte, error) {
var ret bytes.Buffer
bs := bufio.NewScanner(bytes.NewReader(raw))
removeLine := set.Set[string]{}
for bs.Scan() {
t := bs.Text()
if rest, ok := strings.CutPrefix(t, "#tailscaled-added-line:"); ok {
removeLine.Add(strings.TrimSpace(rest))
continue
}
trimmed := strings.TrimSpace(t)
if removeLine.Contains(trimmed) {
removeLine.Delete(trimmed)
continue
}
// Also remove any DNS line referencing *.ts.net. This is
// Tailscale-specific (and won't work with, say, Headscale), but
// the Headscale case will be covered by the #tailscaled-added-line
// logic above, assuming the user didn't delete those comments.
if (strings.HasPrefix(trimmed, "dns=") || strings.Contains(trimmed, "dnsdomain=")) &&
strings.HasSuffix(trimmed, ".ts.net") {
continue
}
ret.WriteString(t)
ret.WriteByte('\n')
}
return ret.Bytes(), bs.Err()
}
// setNDBSuffix adds lines to tsFree (the contents of /net/ndb already cleaned
// of Tailscale-added lines) to add the optional DNS search domain (e.g.
// "foo.ts.net") and DNS server to it.
func setNDBSuffix(tsFree []byte, suffix string) []byte {
suffix = strings.TrimSuffix(suffix, ".")
if suffix == "" {
return tsFree
}
var buf bytes.Buffer
bs := bufio.NewScanner(bytes.NewReader(tsFree))
var added []string
addLine := func(s string) {
added = append(added, strings.TrimSpace(s))
buf.WriteString(s)
}
for bs.Scan() {
buf.Write(bs.Bytes())
buf.WriteByte('\n')
t := bs.Text()
if suffix != "" && len(added) == 0 && strings.HasPrefix(t, "\tdns=") {
addLine(fmt.Sprintf("\tdns=100.100.100.100 suffix=%s\n", suffix))
addLine(fmt.Sprintf("\tdnsdomain=%s\n", suffix))
}
}
bufTrim := bytes.TrimLeftFunc(buf.Bytes(), unicode.IsSpace)
if len(added) == 0 {
return bufTrim
}
var ret bytes.Buffer
for _, s := range added {
ret.WriteString("#tailscaled-added-line: ")
ret.WriteString(s)
ret.WriteString("\n")
}
ret.WriteString("\n")
ret.Write(bufTrim)
return ret.Bytes()
}
func (m *plan9DNSManager) SetDNS(c OSConfig) error {
ndbOnDisk, err := os.ReadFile("/net/ndb")
if err != nil {
return err
}
tsFree, err := netNDBBytesWithoutTailscale(ndbOnDisk)
if err != nil {
return err
}
var suffix string
if len(c.SearchDomains) > 0 {
suffix = string(c.SearchDomains[0])
}
newBuf := setNDBSuffix(tsFree, suffix)
if !bytes.Equal(newBuf, ndbOnDisk) {
if err := os.WriteFile("/net/ndb", newBuf, 0644); err != nil {
return fmt.Errorf("writing /net/ndb: %w", err)
}
if f, err := os.OpenFile("/net/dns", os.O_RDWR, 0); err == nil {
if _, err := io.WriteString(f, "refresh\n"); err != nil {
f.Close()
return fmt.Errorf("/net/dns refresh write: %w", err)
}
if err := f.Close(); err != nil {
return fmt.Errorf("/net/dns refresh close: %w", err)
}
}
}
return nil
}
func (m *plan9DNSManager) SupportsSplitDNS() bool { return false }
func (m *plan9DNSManager) Close() error {
// TODO(bradfitz): remove the Tailscale bits from /net/ndb ideally
return nil
}
var dnsRegex = regexp.MustCompile(`\bdns=(\d+\.\d+\.\d+\.\d+)\b`)
func (m *plan9DNSManager) GetBaseConfig() (OSConfig, error) {
var oc OSConfig
f, err := os.Open("/net/ndb")
if err != nil {
return oc, err
}
defer f.Close()
bs := bufio.NewScanner(f)
for bs.Scan() {
m := dnsRegex.FindSubmatch(bs.Bytes())
if m == nil {
continue
}
addr, err := netip.ParseAddr(string(m[1]))
if err != nil {
continue
}
oc.Nameservers = append(oc.Nameservers, addr)
}
if err := bs.Err(); err != nil {
return oc, err
}
return oc, nil
}

View File

@@ -7,8 +7,10 @@ import (
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
"tailscale.com/util/syspolicy/policyclient"
)
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs.Knobs, iface string) (OSConfigurator, error) {
return newDirectManager(logf, health), nil
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, bus *eventbus.Bus, _ policyclient.Client, _ *controlknobs.Knobs, iface string) (OSConfigurator, error) {
return newDirectManager(logf, health, bus), nil
}

View File

@@ -16,7 +16,6 @@ import (
"slices"
"sort"
"strings"
"sync"
"syscall"
"time"
@@ -27,8 +26,13 @@ import (
"tailscale.com/control/controlknobs"
"tailscale.com/envknob"
"tailscale.com/health"
"tailscale.com/syncs"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/util/eventbus"
"tailscale.com/util/syspolicy/pkey"
"tailscale.com/util/syspolicy/policyclient"
"tailscale.com/util/syspolicy/ptype"
"tailscale.com/util/winutil"
)
@@ -44,19 +48,26 @@ type windowsManager struct {
knobs *controlknobs.Knobs // or nil
nrptDB *nrptRuleDatabase
wslManager *wslManager
polc policyclient.Client
mu sync.Mutex
unregisterPolicyChangeCb func() // called when the manager is closing
mu syncs.Mutex
closing bool
}
// NewOSConfigurator created a new OS configurator.
//
// The health tracker and the knobs may be nil.
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, knobs *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) {
// The health tracker, eventbus and the knobs may be nil.
func NewOSConfigurator(logf logger.Logf, health *health.Tracker, bus *eventbus.Bus, polc policyclient.Client, knobs *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) {
if polc == nil {
panic("nil policyclient.Client")
}
ret := &windowsManager{
logf: logf,
guid: interfaceName,
knobs: knobs,
polc: polc,
wslManager: newWSLManager(logf, health),
}
@@ -64,6 +75,11 @@ func NewOSConfigurator(logf logger.Logf, health *health.Tracker, knobs *controlk
ret.nrptDB = newNRPTRuleDatabase(logf)
}
var err error
if ret.unregisterPolicyChangeCb, err = polc.RegisterChangeCallback(ret.sysPolicyChanged); err != nil {
logf("error registering policy change callback: %v", err) // non-fatal
}
go func() {
// Log WSL status once at startup.
if distros, err := wslDistros(); err != nil {
@@ -148,7 +164,7 @@ func setTailscaleHosts(logf logger.Logf, prevHostsFile []byte, hosts []*HostEntr
header = "# TailscaleHostsSectionStart"
footer = "# TailscaleHostsSectionEnd"
)
var comments = []string{
comments := []string{
"# This section contains MagicDNS entries for Tailscale.",
"# Do not edit this section manually.",
}
@@ -362,11 +378,9 @@ func (m *windowsManager) SetDNS(cfg OSConfig) error {
// configuration only, routing one set of things to the "split"
// resolver and the rest to the primary.
// Unconditionally disable dynamic DNS updates and NetBIOS on our
// interfaces.
if err := m.disableDynamicUpdates(); err != nil {
m.logf("disableDynamicUpdates error: %v\n", err)
}
// Reconfigure DNS registration according to the [syspolicy.DNSRegistration]
// policy setting, and unconditionally disable NetBIOS on our interfaces.
m.reconfigureDNSRegistration()
if err := m.disableNetBIOS(); err != nil {
m.logf("disableNetBIOS error: %v\n", err)
}
@@ -485,6 +499,10 @@ func (m *windowsManager) Close() error {
m.closing = true
m.mu.Unlock()
if m.unregisterPolicyChangeCb != nil {
m.unregisterPolicyChangeCb()
}
err := m.SetDNS(OSConfig{})
if m.nrptDB != nil {
m.nrptDB.Close()
@@ -493,15 +511,62 @@ func (m *windowsManager) Close() error {
return err
}
// disableDynamicUpdates sets the appropriate registry values to prevent the
// Windows DHCP client from sending dynamic DNS updates for our interface to
// AD domain controllers.
func (m *windowsManager) disableDynamicUpdates() error {
// sysPolicyChanged is a callback triggered by [syspolicy] when it detects
// a change in one or more syspolicy settings.
func (m *windowsManager) sysPolicyChanged(policy policyclient.PolicyChange) {
if policy.HasChanged(pkey.EnableDNSRegistration) {
m.reconfigureDNSRegistration()
}
}
// reconfigureDNSRegistration configures the DNS registration settings
// using the [syspolicy.DNSRegistration] policy setting, if it is set.
// If the policy is not configured, it disables DNS registration.
func (m *windowsManager) reconfigureDNSRegistration() {
// Disable DNS registration by default (if the policy setting is not configured).
// This is primarily for historical reasons and to avoid breaking existing
// setups that rely on this behavior.
enableDNSRegistration, err := m.polc.GetPreferenceOption(pkey.EnableDNSRegistration, ptype.NeverByPolicy)
if err != nil {
m.logf("error getting DNSRegistration policy setting: %v", err) // non-fatal; we'll use the default
}
if enableDNSRegistration.Show() {
// "Show" reports whether the policy setting is configured as "user-decides".
// The name is a bit unfortunate in this context, as we don't actually "show" anything.
// Still, if the admin configured the policy as "user-decides", we shouldn't modify
// the adapter's settings and should leave them up to the user (admin rights required)
// or the system defaults.
return
}
// Otherwise, if the policy setting is configured as "always" or "never",
// we should configure the adapter accordingly.
if err := m.configureDNSRegistration(enableDNSRegistration.IsAlways()); err != nil {
m.logf("error configuring DNS registration: %v", err)
}
}
// configureDNSRegistration sets the appropriate registry values to allow or prevent
// the Windows DHCP client from registering Tailscale IP addresses with DNS
// and sending dynamic updates for our interface to AD domain controllers.
func (m *windowsManager) configureDNSRegistration(enabled bool) error {
prefixen := []winutil.RegistryPathPrefix{
winutil.IPv4TCPIPInterfacePrefix,
winutil.IPv6TCPIPInterfacePrefix,
}
var (
registrationEnabled = uint32(0)
disableDynamicUpdate = uint32(1)
maxNumberOfAddressesToRegister = uint32(0)
)
if enabled {
registrationEnabled = 1
disableDynamicUpdate = 0
maxNumberOfAddressesToRegister = 1
}
for _, prefix := range prefixen {
k, err := m.openInterfaceKey(prefix)
if err != nil {
@@ -509,13 +574,13 @@ func (m *windowsManager) disableDynamicUpdates() error {
}
defer k.Close()
if err := k.SetDWordValue("RegistrationEnabled", 0); err != nil {
if err := k.SetDWordValue("RegistrationEnabled", registrationEnabled); err != nil {
return err
}
if err := k.SetDWordValue("DisableDynamicUpdate", 1); err != nil {
if err := k.SetDWordValue("DisableDynamicUpdate", disableDynamicUpdate); err != nil {
return err
}
if err := k.SetDWordValue("MaxNumberOfAddressesToRegister", 0); err != nil {
if err := k.SetDWordValue("MaxNumberOfAddressesToRegister", maxNumberOfAddressesToRegister); err != nil {
return err
}
}

63
vendor/tailscale.com/net/dns/nm.go generated vendored
View File

@@ -1,13 +1,14 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
//go:build linux && !android && !ts_omit_networkmanager
package dns
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
@@ -16,6 +17,7 @@ import (
"github.com/godbus/dbus/v5"
"tailscale.com/net/tsaddr"
"tailscale.com/util/cmpver"
"tailscale.com/util/dnsname"
)
@@ -25,13 +27,6 @@ const (
lowerPriority = int32(200) // lower than all builtin auto priorities
)
// reconfigTimeout is the time interval within which Manager.{Up,Down} should complete.
//
// This is particularly useful because certain conditions can cause indefinite hangs
// (such as improper dbus auth followed by contextless dbus.Object.Call).
// Such operations should be wrapped in a timeout context.
const reconfigTimeout = time.Second
// nmManager uses the NetworkManager DBus API.
type nmManager struct {
interfaceName string
@@ -39,7 +34,13 @@ type nmManager struct {
dnsManager dbus.BusObject
}
func newNMManager(interfaceName string) (*nmManager, error) {
func init() {
optNewNMManager.Set(newNMManager)
optNMIsUsingResolved.Set(nmIsUsingResolved)
optNMVersionBetween.Set(nmVersionBetween)
}
func newNMManager(interfaceName string) (OSConfigurator, error) {
conn, err := dbus.SystemBus()
if err != nil {
return nil, err
@@ -389,3 +390,47 @@ func (m *nmManager) Close() error {
// settings when the tailscale interface goes away.
return nil
}
func nmVersionBetween(first, last string) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
// DBus probably not running.
return false, err
}
nm := conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath("/org/freedesktop/NetworkManager"))
v, err := nm.GetProperty("org.freedesktop.NetworkManager.Version")
if err != nil {
return false, err
}
version, ok := v.Value().(string)
if !ok {
return false, fmt.Errorf("unexpected type %T for NM version", v.Value())
}
outside := cmpver.Compare(version, first) < 0 || cmpver.Compare(version, last) > 0
return !outside, nil
}
func nmIsUsingResolved() error {
conn, err := dbus.SystemBus()
if err != nil {
// DBus probably not running.
return err
}
nm := conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath("/org/freedesktop/NetworkManager/DnsManager"))
v, err := nm.GetProperty("org.freedesktop.NetworkManager.DnsManager.Mode")
if err != nil {
return fmt.Errorf("getting NM mode: %w", err)
}
mode, ok := v.Value().(string)
if !ok {
return fmt.Errorf("unexpected type %T for NM DNS mode", v.Value())
}
if mode != "systemd-resolved" {
return errors.New("NetworkManager is not using systemd-resolved for DNS")
}
return nil
}

View File

@@ -1,7 +1,7 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux || freebsd || openbsd
//go:build (linux && !android) || freebsd || openbsd
package dns

View File

@@ -11,6 +11,7 @@ import (
"slices"
"strings"
"tailscale.com/feature/buildfeatures"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
)
@@ -158,6 +159,10 @@ func (a OSConfig) Equal(b OSConfig) bool {
// Fixes https://github.com/tailscale/tailscale/issues/5669
func (a OSConfig) Format(f fmt.State, verb rune) {
logger.ArgWriter(func(w *bufio.Writer) {
if !buildfeatures.HasDNS {
w.WriteString(`{DNS-unlinked}`)
return
}
w.WriteString(`{Nameservers:[`)
for i, ns := range a.Nameservers {
if i != 0 {

View File

@@ -17,6 +17,8 @@ import (
"strconv"
"strings"
"sync"
"tailscale.com/feature/buildfeatures"
)
// dohOfIP maps from public DNS IPs to their DoH base URL.
@@ -163,6 +165,9 @@ const (
// populate is called once to initialize the knownDoH and dohIPsOfBase maps.
func populate() {
if !buildfeatures.HasDNS {
return
}
// Cloudflare
// https://developers.cloudflare.com/1.1.1.1/ip-addresses/
addDoH("1.1.1.1", "https://cloudflare-dns.com/dns-query")

View File

@@ -1,621 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package recursive implements a simple recursive DNS resolver.
package recursive
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"time"
"github.com/miekg/dns"
"tailscale.com/envknob"
"tailscale.com/net/netns"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/util/mak"
"tailscale.com/util/multierr"
"tailscale.com/util/slicesx"
)
const (
// maxDepth is how deep from the root nameservers we'll recurse when
// resolving; passing this limit will instead return an error.
//
// maxDepth must be at least 20 to resolve "console.aws.amazon.com",
// which is a domain with a moderately complicated DNS setup. The
// current value of 30 was chosen semi-arbitrarily to ensure that we
// have about 50% headroom.
maxDepth = 30
// numStartingServers is the number of root nameservers that we use as
// initial candidates for our recursion.
numStartingServers = 3
// udpQueryTimeout is the amount of time we wait for a UDP response
// from a nameserver before falling back to a TCP connection.
udpQueryTimeout = 5 * time.Second
// These constants aren't typed in the DNS package, so we create typed
// versions here to avoid having to do repeated type casts.
qtypeA dns.Type = dns.Type(dns.TypeA)
qtypeAAAA dns.Type = dns.Type(dns.TypeAAAA)
)
var (
// ErrMaxDepth is returned when recursive resolving exceeds the maximum
// depth limit for this package.
ErrMaxDepth = fmt.Errorf("exceeded max depth %d when resolving", maxDepth)
// ErrAuthoritativeNoResponses is the error returned when an
// authoritative nameserver indicates that there are no responses to
// the given query.
ErrAuthoritativeNoResponses = errors.New("authoritative server returned no responses")
// ErrNoResponses is returned when our resolution process completes
// with no valid responses from any nameserver, but no authoritative
// server explicitly returned NXDOMAIN.
ErrNoResponses = errors.New("no responses to query")
)
var rootServersV4 = []netip.Addr{
netip.MustParseAddr("198.41.0.4"), // a.root-servers.net
netip.MustParseAddr("170.247.170.2"), // b.root-servers.net
netip.MustParseAddr("192.33.4.12"), // c.root-servers.net
netip.MustParseAddr("199.7.91.13"), // d.root-servers.net
netip.MustParseAddr("192.203.230.10"), // e.root-servers.net
netip.MustParseAddr("192.5.5.241"), // f.root-servers.net
netip.MustParseAddr("192.112.36.4"), // g.root-servers.net
netip.MustParseAddr("198.97.190.53"), // h.root-servers.net
netip.MustParseAddr("192.36.148.17"), // i.root-servers.net
netip.MustParseAddr("192.58.128.30"), // j.root-servers.net
netip.MustParseAddr("193.0.14.129"), // k.root-servers.net
netip.MustParseAddr("199.7.83.42"), // l.root-servers.net
netip.MustParseAddr("202.12.27.33"), // m.root-servers.net
}
var rootServersV6 = []netip.Addr{
netip.MustParseAddr("2001:503:ba3e::2:30"), // a.root-servers.net
netip.MustParseAddr("2801:1b8:10::b"), // b.root-servers.net
netip.MustParseAddr("2001:500:2::c"), // c.root-servers.net
netip.MustParseAddr("2001:500:2d::d"), // d.root-servers.net
netip.MustParseAddr("2001:500:a8::e"), // e.root-servers.net
netip.MustParseAddr("2001:500:2f::f"), // f.root-servers.net
netip.MustParseAddr("2001:500:12::d0d"), // g.root-servers.net
netip.MustParseAddr("2001:500:1::53"), // h.root-servers.net
netip.MustParseAddr("2001:7fe::53"), // i.root-servers.net
netip.MustParseAddr("2001:503:c27::2:30"), // j.root-servers.net
netip.MustParseAddr("2001:7fd::1"), // k.root-servers.net
netip.MustParseAddr("2001:500:9f::42"), // l.root-servers.net
netip.MustParseAddr("2001:dc3::35"), // m.root-servers.net
}
var debug = envknob.RegisterBool("TS_DEBUG_RECURSIVE_DNS")
// Resolver is a recursive DNS resolver that is designed for looking up A and AAAA records.
type Resolver struct {
// Dialer is used to create outbound connections. If nil, a zero
// net.Dialer will be used instead.
Dialer netns.Dialer
// Logf is the logging function to use; if none is specified, then logs
// will be dropped.
Logf logger.Logf
// NoIPv6, if set, will prevent this package from querying for AAAA
// records and will avoid contacting nameservers over IPv6.
NoIPv6 bool
// Test mocks
testQueryHook func(name dnsname.FQDN, nameserver netip.Addr, protocol string, qtype dns.Type) (*dns.Msg, error)
testExchangeHook func(nameserver netip.Addr, network string, msg *dns.Msg) (*dns.Msg, error)
rootServers []netip.Addr
timeNow func() time.Time
// Caching
// NOTE(andrew): if we make resolution parallel, this needs a mutex
queryCache map[dnsQuery]dnsMsgWithExpiry
// Possible future additions:
// - Additional nameservers? From the system maybe?
// - NoIPv4 for IPv4
// - DNS-over-HTTPS or DNS-over-TLS support
}
// queryState stores all state during the course of a single query
type queryState struct {
// rootServers are the root nameservers to start from
rootServers []netip.Addr
// TODO: metrics?
}
type dnsQuery struct {
nameserver netip.Addr
name dnsname.FQDN
qtype dns.Type
}
func (q dnsQuery) String() string {
return fmt.Sprintf("dnsQuery{nameserver:%q,name:%q,qtype:%v}", q.nameserver.String(), q.name, q.qtype)
}
type dnsMsgWithExpiry struct {
*dns.Msg
expiresAt time.Time
}
func (r *Resolver) now() time.Time {
if r.timeNow != nil {
return r.timeNow()
}
return time.Now()
}
func (r *Resolver) logf(format string, args ...any) {
if r.Logf == nil {
return
}
r.Logf(format, args...)
}
func (r *Resolver) depthlogf(depth int, format string, args ...any) {
if r.Logf == nil || !debug() {
return
}
prefix := fmt.Sprintf("[%d] %s", depth, strings.Repeat(" ", depth))
r.Logf(prefix+format, args...)
}
var defaultDialer net.Dialer
func (r *Resolver) dialer() netns.Dialer {
if r.Dialer != nil {
return r.Dialer
}
return &defaultDialer
}
func (r *Resolver) newState() *queryState {
var rootServers []netip.Addr
if len(r.rootServers) > 0 {
rootServers = r.rootServers
} else {
// Select a random subset of root nameservers to start from, since if
// we don't get responses from those, something else has probably gone
// horribly wrong.
roots4 := slices.Clone(rootServersV4)
slicesx.Shuffle(roots4)
roots4 = roots4[:numStartingServers]
var roots6 []netip.Addr
if !r.NoIPv6 {
roots6 = slices.Clone(rootServersV6)
slicesx.Shuffle(roots6)
roots6 = roots6[:numStartingServers]
}
// Interleave the root servers so that we try to contact them over
// IPv4, then IPv6, IPv4, IPv6, etc.
rootServers = slicesx.Interleave(roots4, roots6)
}
return &queryState{
rootServers: rootServers,
}
}
// Resolve will perform a recursive DNS resolution for the provided name,
// starting at a randomly-chosen root DNS server, and return the A and AAAA
// responses as a slice of netip.Addrs along with the minimum TTL for the
// returned records.
func (r *Resolver) Resolve(ctx context.Context, name string) (addrs []netip.Addr, minTTL time.Duration, err error) {
dnsName, err := dnsname.ToFQDN(name)
if err != nil {
return nil, 0, err
}
qstate := r.newState()
r.logf("querying IPv4 addresses for: %q", name)
addrs4, minTTL4, err4 := r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeA)
var (
addrs6 []netip.Addr
minTTL6 time.Duration
err6 error
)
if !r.NoIPv6 {
r.logf("querying IPv6 addresses for: %q", name)
addrs6, minTTL6, err6 = r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeAAAA)
}
if err4 != nil && err6 != nil {
if err4 == err6 {
return nil, 0, err4
}
return nil, 0, multierr.New(err4, err6)
}
if err4 != nil {
return addrs6, minTTL6, nil
} else if err6 != nil {
return addrs4, minTTL4, nil
}
minTTL = minTTL4
if minTTL6 < minTTL {
minTTL = minTTL6
}
addrs = append(addrs4, addrs6...)
if len(addrs) == 0 {
return nil, 0, ErrNoResponses
}
slicesx.Shuffle(addrs)
return addrs, minTTL, nil
}
func (r *Resolver) resolveRecursiveFromRoot(
ctx context.Context,
qstate *queryState,
depth int,
name dnsname.FQDN, // what we're querying
qtype dns.Type,
) ([]netip.Addr, time.Duration, error) {
r.depthlogf(depth, "resolving %q from root (type: %v)", name, qtype)
var depthError bool
for _, server := range qstate.rootServers {
addrs, minTTL, err := r.resolveRecursive(ctx, qstate, depth, name, server, qtype)
if err == nil {
return addrs, minTTL, err
} else if errors.Is(err, ErrAuthoritativeNoResponses) {
return nil, 0, ErrAuthoritativeNoResponses
} else if errors.Is(err, ErrMaxDepth) {
depthError = true
}
}
if depthError {
return nil, 0, ErrMaxDepth
}
return nil, 0, ErrNoResponses
}
func (r *Resolver) resolveRecursive(
ctx context.Context,
qstate *queryState,
depth int,
name dnsname.FQDN, // what we're querying
nameserver netip.Addr,
qtype dns.Type,
) ([]netip.Addr, time.Duration, error) {
if depth == maxDepth {
r.depthlogf(depth, "not recursing past maximum depth")
return nil, 0, ErrMaxDepth
}
// Ask this nameserver for an answer.
resp, err := r.queryNameserver(ctx, depth, name, nameserver, qtype)
if err != nil {
return nil, 0, err
}
// If we get an actual answer from the nameserver, then return it.
var (
answers []netip.Addr
cnames []dnsname.FQDN
minTTL = 24 * 60 * 60 // 24 hours in seconds
)
for _, answer := range resp.Answer {
if crec, ok := answer.(*dns.CNAME); ok {
cnameFQDN, err := dnsname.ToFQDN(crec.Target)
if err != nil {
r.logf("bad CNAME %q returned: %v", crec.Target, err)
continue
}
cnames = append(cnames, cnameFQDN)
continue
}
addr := addrFromRecord(answer)
if !addr.IsValid() {
r.logf("[unexpected] invalid record in %T answer", answer)
} else if addr.Is4() && qtype != qtypeA {
r.logf("[unexpected] got IPv4 answer but qtype=%v", qtype)
} else if addr.Is6() && qtype != qtypeAAAA {
r.logf("[unexpected] got IPv6 answer but qtype=%v", qtype)
} else {
answers = append(answers, addr)
minTTL = min(minTTL, int(answer.Header().Ttl))
}
}
if len(answers) > 0 {
r.depthlogf(depth, "got answers for %q: %v", name, answers)
return answers, time.Duration(minTTL) * time.Second, nil
}
r.depthlogf(depth, "no answers for %q", name)
// If we have a non-zero number of CNAMEs, then try resolving those
// (from the root again) and return the first one that succeeds.
//
// TODO: return the union of all responses?
// TODO: parallelism?
if len(cnames) > 0 {
r.depthlogf(depth, "got CNAME responses for %q: %v", name, cnames)
}
var cnameDepthError bool
for _, cname := range cnames {
answers, minTTL, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, cname, qtype)
if err == nil {
return answers, minTTL, nil
} else if errors.Is(err, ErrAuthoritativeNoResponses) {
return nil, 0, ErrAuthoritativeNoResponses
} else if errors.Is(err, ErrMaxDepth) {
cnameDepthError = true
}
}
// If this is an authoritative response, then we know that continuing
// to look further is not going to result in any answers and we should
// bail out.
if resp.MsgHdr.Authoritative {
// If we failed to recurse into a CNAME due to a depth limit,
// propagate that here.
if cnameDepthError {
return nil, 0, ErrMaxDepth
}
r.depthlogf(depth, "got authoritative response with no answers; stopping")
return nil, 0, ErrAuthoritativeNoResponses
}
r.depthlogf(depth, "got %d NS responses and %d ADDITIONAL responses for %q", len(resp.Ns), len(resp.Extra), name)
// No CNAMEs and no answers; see if we got any AUTHORITY responses,
// which indicate which nameservers to query next.
var authorities []dnsname.FQDN
for _, rr := range resp.Ns {
ns, ok := rr.(*dns.NS)
if !ok {
continue
}
nsName, err := dnsname.ToFQDN(ns.Ns)
if err != nil {
r.logf("unexpected bad NS name %q: %v", ns.Ns, err)
continue
}
authorities = append(authorities, nsName)
}
// Also check for "glue" records, which are IP addresses provided by
// the DNS server for authority responses; these are required when the
// authority server is a subdomain of what's being resolved.
glueRecords := make(map[dnsname.FQDN][]netip.Addr)
for _, rr := range resp.Extra {
name, err := dnsname.ToFQDN(rr.Header().Name)
if err != nil {
r.logf("unexpected bad Name %q in Extra addr: %v", rr.Header().Name, err)
continue
}
if addr := addrFromRecord(rr); addr.IsValid() {
glueRecords[name] = append(glueRecords[name], addr)
} else {
r.logf("unexpected bad Extra %T addr", rr)
}
}
// Try authorities with glue records first, to minimize the number of
// additional DNS queries that we need to make.
authoritiesGlue, authoritiesNoGlue := slicesx.Partition(authorities, func(aa dnsname.FQDN) bool {
return len(glueRecords[aa]) > 0
})
authorityDepthError := false
r.depthlogf(depth, "authorities with glue records for recursion: %v", authoritiesGlue)
for _, authority := range authoritiesGlue {
for _, nameserver := range glueRecords[authority] {
answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype)
if err == nil {
return answers, minTTL, nil
} else if errors.Is(err, ErrAuthoritativeNoResponses) {
return nil, 0, ErrAuthoritativeNoResponses
} else if errors.Is(err, ErrMaxDepth) {
authorityDepthError = true
}
}
}
r.depthlogf(depth, "authorities with no glue records for recursion: %v", authoritiesNoGlue)
for _, authority := range authoritiesNoGlue {
// First, resolve the IP for the authority server from the
// root, querying for both IPv4 and IPv6 addresses regardless
// of what the current question type is.
//
// TODO: check for infinite recursion; it'll get caught by our
// recursion depth, but we want to bail early.
for _, authorityQtype := range []dns.Type{qtypeAAAA, qtypeA} {
answers, _, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, authority, authorityQtype)
if err != nil {
r.depthlogf(depth, "error querying authority %q: %v", authority, err)
continue
}
r.depthlogf(depth, "resolved authority %q (type %v) to: %v", authority, authorityQtype, answers)
// Now, query this authority for the final address.
for _, nameserver := range answers {
answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype)
if err == nil {
return answers, minTTL, nil
} else if errors.Is(err, ErrAuthoritativeNoResponses) {
return nil, 0, ErrAuthoritativeNoResponses
} else if errors.Is(err, ErrMaxDepth) {
authorityDepthError = true
}
}
}
}
if authorityDepthError {
return nil, 0, ErrMaxDepth
}
return nil, 0, ErrNoResponses
}
// queryNameserver sends a query for "name" to the nameserver "nameserver" for
// records of type "qtype", trying both UDP and TCP connections as
// appropriate.
func (r *Resolver) queryNameserver(
ctx context.Context,
depth int,
name dnsname.FQDN, // what we're querying
nameserver netip.Addr, // destination of query
qtype dns.Type,
) (*dns.Msg, error) {
// TODO(andrew): we should QNAME minimisation here to avoid sending the
// full name to intermediate/root nameservers. See:
// https://www.rfc-editor.org/rfc/rfc7816
// Handle the case where UDP is blocked by adding an explicit timeout
// for the UDP portion of this query.
udpCtx, udpCtxCancel := context.WithTimeout(ctx, udpQueryTimeout)
defer udpCtxCancel()
msg, err := r.queryNameserverProto(udpCtx, depth, name, nameserver, "udp", qtype)
if err == nil {
return msg, nil
}
msg, err2 := r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype)
if err2 == nil {
return msg, nil
}
return nil, multierr.New(err, err2)
}
// queryNameserverProto sends a query for "name" to the nameserver "nameserver"
// for records of type "qtype" over the provided protocol (either "udp"
// or "tcp"), and returns the DNS response or an error.
func (r *Resolver) queryNameserverProto(
ctx context.Context,
depth int,
name dnsname.FQDN, // what we're querying
nameserver netip.Addr, // destination of query
protocol string,
qtype dns.Type,
) (resp *dns.Msg, err error) {
if r.testQueryHook != nil {
return r.testQueryHook(name, nameserver, protocol, qtype)
}
now := r.now()
nameserverStr := nameserver.String()
cacheKey := dnsQuery{
nameserver: nameserver,
name: name,
qtype: qtype,
}
cacheEntry, ok := r.queryCache[cacheKey]
if ok && cacheEntry.expiresAt.Before(now) {
r.depthlogf(depth, "using cached response from %s about %q (type: %v)", nameserverStr, name, qtype)
return cacheEntry.Msg, nil
}
var network string
if nameserver.Is4() {
network = protocol + "4"
} else {
network = protocol + "6"
}
// Prepare a message asking for an appropriately-typed record
// for the name we're querying.
m := new(dns.Msg)
m.SetQuestion(name.WithTrailingDot(), uint16(qtype))
// Allow mocking out the network components with our exchange hook.
if r.testExchangeHook != nil {
resp, err = r.testExchangeHook(nameserver, network, m)
} else {
// Dial the current nameserver using our dialer.
var nconn net.Conn
nconn, err = r.dialer().DialContext(ctx, network, net.JoinHostPort(nameserverStr, "53"))
if err != nil {
return nil, err
}
var c dns.Client // TODO: share?
conn := &dns.Conn{
Conn: nconn,
UDPSize: c.UDPSize,
}
// Send the DNS request to the current nameserver.
r.depthlogf(depth, "asking %s over %s about %q (type: %v)", nameserverStr, protocol, name, qtype)
resp, _, err = c.ExchangeWithConnContext(ctx, m, conn)
}
if err != nil {
return nil, err
}
// If the message was truncated and we're using UDP, re-run with TCP.
if resp.MsgHdr.Truncated && protocol == "udp" {
r.depthlogf(depth, "response message truncated; re-running query with TCP")
resp, err = r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype)
if err != nil {
return nil, err
}
}
// Find minimum expiry for all records in this message.
var minTTL int
for _, rr := range resp.Answer {
minTTL = min(minTTL, int(rr.Header().Ttl))
}
for _, rr := range resp.Ns {
minTTL = min(minTTL, int(rr.Header().Ttl))
}
for _, rr := range resp.Extra {
minTTL = min(minTTL, int(rr.Header().Ttl))
}
mak.Set(&r.queryCache, cacheKey, dnsMsgWithExpiry{
Msg: resp,
expiresAt: now.Add(time.Duration(minTTL) * time.Second),
})
return resp, nil
}
func addrFromRecord(rr dns.RR) netip.Addr {
switch v := rr.(type) {
case *dns.A:
ip, ok := netip.AddrFromSlice(v.A)
if !ok || !ip.Is4() {
return netip.Addr{}
}
return ip
case *dns.AAAA:
ip, ok := netip.AddrFromSlice(v.AAAA)
if !ok || !ip.Is6() {
return netip.Addr{}
}
return ip
}
return netip.Addr{}
}

View File

@@ -1,7 +1,7 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
//go:build linux && !android && !ts_omit_resolved
package dns
@@ -15,8 +15,8 @@ import (
"github.com/godbus/dbus/v5"
"golang.org/x/sys/unix"
"tailscale.com/health"
"tailscale.com/logtail/backoff"
"tailscale.com/types/logger"
"tailscale.com/util/backoff"
"tailscale.com/util/dnsname"
)
@@ -70,7 +70,11 @@ type resolvedManager struct {
configCR chan changeRequest // tracks OSConfigs changes and error responses
}
func newResolvedManager(logf logger.Logf, health *health.Tracker, interfaceName string) (*resolvedManager, error) {
func init() {
optNewResolvedManager.Set(newResolvedManager)
}
func newResolvedManager(logf logger.Logf, health *health.Tracker, interfaceName string) (OSConfigurator, error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return nil, err

View File

@@ -8,14 +8,18 @@ import (
"html"
"net/http"
"strconv"
"sync"
"sync/atomic"
"time"
"tailscale.com/feature/buildfeatures"
"tailscale.com/health"
"tailscale.com/syncs"
)
func init() {
if !buildfeatures.HasDNS {
return
}
health.RegisterDebugHandler("dnsfwd", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n, _ := strconv.Atoi(r.FormValue("n"))
if n <= 0 {
@@ -35,7 +39,7 @@ func init() {
var fwdLogAtomic atomic.Pointer[fwdLog]
type fwdLog struct {
mu sync.Mutex
mu syncs.Mutex
pos int // ent[pos] is next entry
ent []fwdLogEntry
}

View File

@@ -17,6 +17,7 @@ import (
"net/http"
"net/netip"
"net/url"
"runtime"
"sort"
"strings"
"sync"
@@ -26,13 +27,17 @@ import (
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/control/controlknobs"
"tailscale.com/envknob"
"tailscale.com/feature"
"tailscale.com/feature/buildfeatures"
"tailscale.com/health"
"tailscale.com/net/dns/publicdns"
"tailscale.com/net/dnscache"
"tailscale.com/net/neterror"
"tailscale.com/net/netmon"
"tailscale.com/net/netx"
"tailscale.com/net/sockstats"
"tailscale.com/net/tsdial"
"tailscale.com/syncs"
"tailscale.com/types/dnstype"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
@@ -215,18 +220,19 @@ type resolverAndDelay struct {
// forwarder forwards DNS packets to a number of upstream nameservers.
type forwarder struct {
logf logger.Logf
netMon *netmon.Monitor // always non-nil
linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
dialer *tsdial.Dialer
health *health.Tracker // always non-nil
logf logger.Logf
netMon *netmon.Monitor // always non-nil
linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
dialer *tsdial.Dialer
health *health.Tracker // always non-nil
verboseFwd bool // if true, log all DNS forwarding
controlKnobs *controlknobs.Knobs // or nil
ctx context.Context // good until Close
ctxCancel context.CancelFunc // closes ctx
mu sync.Mutex // guards following
mu syncs.Mutex // guards following
dohClient map[string]*http.Client // urlBase -> client
@@ -243,26 +249,23 @@ type forwarder struct {
// /etc/resolv.conf is missing/corrupt, and the peerapi ExitDNS stub
// resolver lookup.
cloudHostFallback []resolverAndDelay
// missingUpstreamRecovery, if non-nil, is set called when a SERVFAIL is
// returned due to missing upstream resolvers.
//
// This should attempt to properly (re)set the upstream resolvers.
missingUpstreamRecovery func()
}
func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, health *health.Tracker, knobs *controlknobs.Knobs) *forwarder {
if !buildfeatures.HasDNS {
return nil
}
if netMon == nil {
panic("nil netMon")
}
f := &forwarder{
logf: logger.WithPrefix(logf, "forward: "),
netMon: netMon,
linkSel: linkSel,
dialer: dialer,
health: health,
controlKnobs: knobs,
missingUpstreamRecovery: func() {},
logf: logger.WithPrefix(logf, "forward: "),
netMon: netMon,
linkSel: linkSel,
dialer: dialer,
health: health,
controlKnobs: knobs,
verboseFwd: verboseDNSForward(),
}
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
return f
@@ -520,15 +523,18 @@ var (
//
// send expects the reply to have the same txid as txidOut.
func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
if verboseDNSForward() {
if f.verboseFwd {
id := forwarderCount.Add(1)
domain, typ, _ := nameFromQuery(fq.packet)
f.logf("forwarder.send(%q, %d, %v, %d) [%d] ...", rr.name.Addr, fq.txid, typ, len(domain), id)
f.logf("forwarder.send(%q, %d, %v, %d) from %v [%d] ...", rr.name.Addr, fq.txid, typ, len(domain), fq.src, id)
defer func() {
f.logf("forwarder.send(%q, %d, %v, %d) [%d] = %v, %v", rr.name.Addr, fq.txid, typ, len(domain), id, len(ret), err)
f.logf("forwarder.send(%q, %d, %v, %d) from %v [%d] = %v, %v", rr.name.Addr, fq.txid, typ, len(domain), fq.src, id, len(ret), err)
}()
}
if strings.HasPrefix(rr.name.Addr, "http://") {
if !buildfeatures.HasPeerAPIClient {
return nil, feature.ErrUnavailable
}
return f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet)
}
if strings.HasPrefix(rr.name.Addr, "https://") {
@@ -739,18 +745,38 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
return out, nil
}
func (f *forwarder) getDialerType() dnscache.DialContextFunc {
if f.controlKnobs != nil && f.controlKnobs.UserDialUseRoutes.Load() {
// It is safe to use UserDial as it dials external servers without going through Tailscale
// and closes connections on interface change in the same way as SystemDial does,
// thus preventing DNS resolution issues when switching between WiFi and cellular,
// but can also dial an internal DNS server on the Tailnet or via a subnet router.
//
// TODO(nickkhyl): Update tsdial.Dialer to reuse the bart.Table we create in net/tstun.Wrapper
// to avoid having two bart tables in memory, especially on iOS. Once that's done,
// we can get rid of the nodeAttr/control knob and always use UserDial for DNS.
//
// See https://github.com/tailscale/tailscale/issues/12027.
var optDNSForwardUseRoutes = envknob.RegisterOptBool("TS_DEBUG_DNS_FORWARD_USE_ROUTES")
// ShouldUseRoutes reports whether the DNS resolver should consider routes when dialing
// upstream nameservers via TCP.
//
// If true, routes should be considered ([tsdial.Dialer.UserDial]), otherwise defer
// to the system routes ([tsdial.Dialer.SystemDial]).
//
// TODO(nickkhyl): Update [tsdial.Dialer] to reuse the bart.Table we create in net/tstun.Wrapper
// to avoid having two bart tables in memory, especially on iOS. Once that's done,
// we can get rid of the nodeAttr/control knob and always use UserDial for DNS.
//
// See tailscale/tailscale#12027.
func ShouldUseRoutes(knobs *controlknobs.Knobs) bool {
if !buildfeatures.HasDNS {
return false
}
switch runtime.GOOS {
case "android", "ios":
// On mobile platforms with lower memory limits (e.g., 50MB on iOS),
// this behavior is still gated by the "user-dial-routes" nodeAttr.
return knobs != nil && knobs.UserDialUseRoutes.Load()
default:
// On all other platforms, it is the default behavior,
// but it can be overridden with the "TS_DEBUG_DNS_FORWARD_USE_ROUTES" env var.
doNotUseRoutes := optDNSForwardUseRoutes().EqualBool(false)
return !doNotUseRoutes
}
}
func (f *forwarder) getDialerType() netx.DialFunc {
if ShouldUseRoutes(f.controlKnobs) {
return f.dialer.UserDial
}
return f.dialer.SystemDial
@@ -878,6 +904,7 @@ type forwardQuery struct {
txid txid
packet []byte
family string // "tcp" or "udp"
src netip.AddrPort
// closeOnCtxDone lets send register values to Close if the
// caller's ctx expires. This avoids send from allocating its
@@ -943,13 +970,6 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: ""})
f.logf("no upstream resolvers set, returning SERVFAIL")
// Attempt to recompile the DNS configuration
// If we are being asked to forward queries and we have no
// nameservers, the network is in a bad state.
if f.missingUpstreamRecovery != nil {
f.missingUpstreamRecovery()
}
res, err := servfailResponse(query)
if err != nil {
return err
@@ -969,11 +989,12 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
txid: getTxID(query.bs),
packet: query.bs,
family: query.family,
src: query.addr,
closeOnCtxDone: new(closePool),
}
defer fq.closeOnCtxDone.Close()
if verboseDNSForward() {
if f.verboseFwd {
domainSha256 := sha256.Sum256([]byte(domain))
domainSig := base64.RawStdEncoding.EncodeToString(domainSha256[:3])
f.logf("request(%d, %v, %d, %s) %d...", fq.txid, typ, len(domain), domainSig, len(fq.packet))
@@ -1018,7 +1039,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
metricDNSFwdErrorContext.Add(1)
return fmt.Errorf("waiting to send response: %w", ctx.Err())
case responseChan <- packet{v, query.family, query.addr}:
if verboseDNSForward() {
if f.verboseFwd {
f.logf("response(%d, %v, %d) = %d, nil", fq.txid, typ, len(domain), len(v))
}
metricDNSFwdSuccess.Add(1)
@@ -1048,7 +1069,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
}
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
case responseChan <- res:
if verboseDNSForward() {
if f.verboseFwd {
f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
}
return nil

View File

@@ -25,6 +25,8 @@ import (
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/control/controlknobs"
"tailscale.com/envknob"
"tailscale.com/feature"
"tailscale.com/feature/buildfeatures"
"tailscale.com/health"
"tailscale.com/net/dns/resolvconffile"
"tailscale.com/net/netaddr"
@@ -212,7 +214,7 @@ type Resolver struct {
closed chan struct{}
// mu guards the following fields from being updated while used.
mu sync.Mutex
mu syncs.Mutex
localDomains []dnsname.FQDN
hostToIP map[dnsname.FQDN][]netip.Addr
ipToHost map[netip.Addr]dnsname.FQDN
@@ -251,18 +253,12 @@ func New(logf logger.Logf, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, h
return r
}
// SetMissingUpstreamRecovery sets a callback to be called upon encountering
// a SERVFAIL due to missing upstream resolvers.
//
// This call should only happen before the resolver is used. It is not safe
// for concurrent use.
func (r *Resolver) SetMissingUpstreamRecovery(f func()) {
r.forwarder.missingUpstreamRecovery = f
}
func (r *Resolver) TestOnlySetHook(hook func(Config)) { r.saveConfigForTests = hook }
func (r *Resolver) SetConfig(cfg Config) error {
if !buildfeatures.HasDNS {
return nil
}
if r.saveConfigForTests != nil {
r.saveConfigForTests(cfg)
}
@@ -288,6 +284,9 @@ func (r *Resolver) SetConfig(cfg Config) error {
// Close shuts down the resolver and ensures poll goroutines have exited.
// The Resolver cannot be used again after Close is called.
func (r *Resolver) Close() {
if !buildfeatures.HasDNS {
return
}
select {
case <-r.closed:
return
@@ -305,6 +304,9 @@ func (r *Resolver) Close() {
const dnsQueryTimeout = 10 * time.Second
func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from netip.AddrPort) ([]byte, error) {
if !buildfeatures.HasDNS {
return nil, feature.ErrUnavailable
}
metricDNSQueryLocal.Add(1)
select {
case <-r.closed:
@@ -332,6 +334,9 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net
// GetUpstreamResolvers returns the resolvers that would be used to resolve
// the given FQDN.
func (r *Resolver) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver {
if !buildfeatures.HasDNS {
return nil
}
return r.forwarder.GetUpstreamResolvers(name)
}
@@ -360,6 +365,9 @@ func parseExitNodeQuery(q []byte) *response {
// and a nil error.
// TODO: figure out if we even need an error result.
func (r *Resolver) HandlePeerDNSQuery(ctx context.Context, q []byte, from netip.AddrPort, allowName func(name string) bool) (res []byte, err error) {
if !buildfeatures.HasDNS {
return nil, feature.ErrUnavailable
}
metricDNSExitProxyQuery.Add(1)
ch := make(chan packet, 1)
@@ -436,6 +444,9 @@ var debugExitNodeDNSNetPkg = envknob.RegisterBool("TS_DEBUG_EXIT_NODE_DNS_NET_PK
// response contains the pre-serialized response, which notably
// includes the original question and its header.
func handleExitNodeDNSQueryWithNetPkg(ctx context.Context, logf logger.Logf, resolver *net.Resolver, resp *response) (res []byte, err error) {
if !buildfeatures.HasDNS {
return nil, feature.ErrUnavailable
}
logf = logger.WithPrefix(logf, "exitNodeDNSQueryWithNetPkg: ")
if resp.Question.Class != dns.ClassINET {
return nil, errors.New("unsupported class")
@@ -1256,6 +1267,9 @@ func (r *Resolver) respondReverse(query []byte, name dnsname.FQDN, resp *respons
// respond returns a DNS response to query if it can be resolved locally.
// Otherwise, it returns errNotOurName.
func (r *Resolver) respond(query []byte) ([]byte, error) {
if !buildfeatures.HasDNS {
return nil, feature.ErrUnavailable
}
parser := dnsParserPool.Get().(*dnsParser)
defer dnsParserPool.Put(parser)

View File

@@ -76,7 +76,7 @@ func (wm *wslManager) SetDNS(cfg OSConfig) error {
}
managers := make(map[string]*directManager)
for _, distro := range distros {
managers[distro] = newDirectManagerOnFS(wm.logf, wm.health, wslFS{
managers[distro] = newDirectManagerOnFS(wm.logf, wm.health, nil, wslFS{
user: "root",
distro: distro,
})

View File

@@ -19,10 +19,13 @@ import (
"time"
"tailscale.com/envknob"
"tailscale.com/net/netx"
"tailscale.com/syncs"
"tailscale.com/types/logger"
"tailscale.com/util/cloudenv"
"tailscale.com/util/singleflight"
"tailscale.com/util/slicesx"
"tailscale.com/util/testenv"
)
var zaddr netip.Addr
@@ -62,6 +65,10 @@ type Resolver struct {
// If nil, net.DefaultResolver is used.
Forward *net.Resolver
// LookupIPForTest, if non-nil and in tests, handles requests instead
// of the usual mechanisms.
LookupIPForTest func(ctx context.Context, host string) ([]netip.Addr, error)
// LookupIPFallback optionally provides a backup DNS mechanism
// to use if Forward returns an error or no results.
LookupIPFallback func(ctx context.Context, host string) ([]netip.Addr, error)
@@ -91,7 +98,7 @@ type Resolver struct {
sf singleflight.Group[string, ipRes]
mu sync.Mutex
mu syncs.Mutex
ipCache map[string]ipCacheEntry
}
@@ -199,6 +206,9 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr
}
allIPs = append(allIPs, naIP)
}
if !ip.IsValid() && v6.IsValid() {
ip = v6
}
r.dlogf("returning %d static results", len(allIPs))
return
}
@@ -283,7 +293,13 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (ip, ip6 netip.Add
lookupCtx, lookupCancel := context.WithTimeout(ctx, r.lookupTimeoutForHost(host))
defer lookupCancel()
ips, err := r.fwd().LookupNetIP(lookupCtx, "ip", host)
var ips []netip.Addr
if r.LookupIPForTest != nil && testenv.InTest() {
ips, err = r.LookupIPForTest(ctx, host)
} else {
ips, err = r.fwd().LookupNetIP(lookupCtx, "ip", host)
}
if err != nil || len(ips) == 0 {
if resolver, ok := r.cloudHostResolver(); ok {
r.dlogf("resolving %q via cloud resolver", host)
@@ -355,10 +371,8 @@ func (r *Resolver) addIPCache(host string, ip, ip6 netip.Addr, allIPs []netip.Ad
}
}
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
func Dialer(fwd netx.DialFunc, dnsCache *Resolver) netx.DialFunc {
d := &dialer{
fwd: fwd,
dnsCache: dnsCache,
@@ -369,7 +383,7 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
// dialer is the config and accumulated state for a dial func returned by Dialer.
type dialer struct {
fwd DialContextFunc
fwd netx.DialFunc
dnsCache *Resolver
mu sync.Mutex
@@ -461,7 +475,7 @@ type dialCall struct {
d *dialer
network, address, host, port string
mu sync.Mutex // lock ordering: dialer.mu, then dialCall.mu
mu syncs.Mutex // lock ordering: dialer.mu, then dialCall.mu
fails map[netip.Addr]error // set of IPs that failed to dial thus far
}
@@ -653,7 +667,7 @@ func v6addrs(aa []netip.Addr) (ret []netip.Addr) {
// TLSDialer is like Dialer but returns a func suitable for using with net/http.Transport.DialTLSContext.
// It returns a *tls.Conn type on success.
// On TLS cert validation failure, it can invoke a backup DNS resolution strategy.
func TLSDialer(fwd DialContextFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) DialContextFunc {
func TLSDialer(fwd netx.DialFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) netx.DialFunc {
tcpDialer := Dialer(fwd, dnsCache)
return func(ctx context.Context, network, address string) (net.Conn, error) {
host, _, err := net.SplitHostPort(address)

View File

@@ -22,35 +22,20 @@ import (
"net/url"
"os"
"reflect"
"slices"
"sync/atomic"
"time"
"tailscale.com/atomicfile"
"tailscale.com/envknob"
"tailscale.com/feature"
"tailscale.com/health"
"tailscale.com/net/dns/recursive"
"tailscale.com/net/netmon"
"tailscale.com/net/netns"
"tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy"
"tailscale.com/tailcfg"
"tailscale.com/types/logger"
"tailscale.com/util/clientmetric"
"tailscale.com/util/singleflight"
"tailscale.com/util/slicesx"
)
var (
optRecursiveResolver = envknob.RegisterOptBool("TS_DNSFALLBACK_RECURSIVE_RESOLVER")
disableRecursiveResolver = envknob.RegisterBool("TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER") // legacy pre-1.52 env knob name
)
type resolveResult struct {
addrs []netip.Addr
minTTL time.Duration
}
// MakeLookupFunc creates a function that can be used to resolve hostnames
// (e.g. as a LookupIPFallback from dnscache.Resolver).
// The netMon parameter is optional; if non-nil it's used to do faster interface lookups.
@@ -68,145 +53,13 @@ type fallbackResolver struct {
logf logger.Logf
netMon *netmon.Monitor // or nil
healthTracker *health.Tracker // or nil
sf singleflight.Group[string, resolveResult]
// for tests
waitForCompare bool
}
func (fr *fallbackResolver) Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
// If they've explicitly disabled the recursive resolver with the legacy
// TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER envknob or not set the
// newer TS_DNSFALLBACK_RECURSIVE_RESOLVER to true, then don't use the
// recursive resolver. (tailscale/corp#15261) In the future, we might
// change the default (the opt.Bool being unset) to mean enabled.
if disableRecursiveResolver() || !optRecursiveResolver().EqualBool(true) {
return lookup(ctx, host, fr.logf, fr.healthTracker, fr.netMon)
}
addrsCh := make(chan []netip.Addr, 1)
// Run the recursive resolver in the background so we can
// compare the results. For tests, we also allow waiting for the
// comparison to complete; normally, we do this entirely asynchronously
// so as not to block the caller.
var done chan struct{}
if fr.waitForCompare {
done = make(chan struct{})
go func() {
defer close(done)
fr.compareWithRecursive(ctx, addrsCh, host)
}()
} else {
go fr.compareWithRecursive(ctx, addrsCh, host)
}
addrs, err := lookup(ctx, host, fr.logf, fr.healthTracker, fr.netMon)
if err != nil {
addrsCh <- nil
return nil, err
}
addrsCh <- slices.Clone(addrs)
if fr.waitForCompare {
select {
case <-done:
case <-ctx.Done():
}
}
return addrs, nil
}
// compareWithRecursive is responsible for comparing the DNS resolution
// performed via the "normal" path (bootstrap DNS requests to the DERP servers)
// with DNS resolution performed with our in-process recursive DNS resolver.
//
// It will select on addrsCh to read exactly one set of addrs (returned by the
// "normal" path) and compare against the results returned by the recursive
// resolver. If ctx is canceled, then it will abort.
func (fr *fallbackResolver) compareWithRecursive(
ctx context.Context,
addrsCh <-chan []netip.Addr,
host string,
) {
logf := logger.WithPrefix(fr.logf, "recursive: ")
// Ensure that we catch panics while we're testing this
// code path; this should never panic, but we don't
// want to take down the process by having the panic
// propagate to the top of the goroutine's stack and
// then terminate.
defer func() {
if r := recover(); r != nil {
logf("bootstrap DNS: recovered panic: %v", r)
metricRecursiveErrors.Add(1)
}
}()
// Don't resolve the same host multiple times
// concurrently; if we end up in a tight loop, this can
// take up a lot of CPU.
var didRun bool
result, err, _ := fr.sf.Do(host, func() (resolveResult, error) {
didRun = true
resolver := &recursive.Resolver{
Dialer: netns.NewDialer(logf, fr.netMon),
Logf: logf,
}
addrs, minTTL, err := resolver.Resolve(ctx, host)
if err != nil {
logf("error using recursive resolver: %v", err)
metricRecursiveErrors.Add(1)
return resolveResult{}, err
}
return resolveResult{addrs, minTTL}, nil
})
// The singleflight function handled errors; return if
// there was one. Additionally, don't bother doing the
// comparison if we waited on another singleflight
// caller; the results are likely to be the same, so
// rather than spam the logs we can just exit and let
// the singleflight call that did execute do the
// comparison.
//
// Returning here is safe because the addrsCh channel
// is buffered, so the main function won't block even
// if we never read from it.
if err != nil || !didRun {
return
}
addrs, minTTL := result.addrs, result.minTTL
compareAddr := func(a, b netip.Addr) int { return a.Compare(b) }
slices.SortFunc(addrs, compareAddr)
// Wait for a response from the main function; try this once before we
// check whether the context is canceled since selects are
// nondeterministic.
var oldAddrs []netip.Addr
select {
case oldAddrs = <-addrsCh:
// All good; continue
default:
// Now block.
select {
case oldAddrs = <-addrsCh:
case <-ctx.Done():
return
}
}
slices.SortFunc(oldAddrs, compareAddr)
matches := slices.Equal(addrs, oldAddrs)
logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL)
if matches {
metricRecursiveMatches.Add(1)
} else {
metricRecursiveMismatches.Add(1)
}
return lookup(ctx, host, fr.logf, fr.healthTracker, fr.netMon)
}
func lookup(ctx context.Context, host string, logf logger.Logf, ht *health.Tracker, netMon *netmon.Monitor) ([]netip.Addr, error) {
@@ -282,11 +135,11 @@ func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr
dialer := netns.NewDialer(logf, netMon)
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true // This transport is meant to be used once.
tr.Proxy = tshttpproxy.ProxyFromEnvironment
tr.Proxy = feature.HookProxyFromEnvironment.GetOrNil()
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443"))
}
tr.TLSClientConfig = tlsdial.Config(serverName, ht, tr.TLSClientConfig)
tr.TLSClientConfig = tlsdial.Config(ht, tr.TLSClientConfig)
c := &http.Client{Transport: tr}
req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+"/bootstrap-dns?q="+url.QueryEscape(queryName), nil)
if err != nil {
@@ -428,9 +281,3 @@ func SetCachePath(path string, logf logger.Logf) {
cachedDERPMap.Store(dm)
logf("[v2] dnsfallback: SetCachePath loaded cached DERP map")
}
var (
metricRecursiveMatches = clientmetric.NewCounter("dnsfallback_recursive_matches")
metricRecursiveMismatches = clientmetric.NewCounter("dnsfallback_recursive_mismatches")
metricRecursiveErrors = clientmetric.NewCounter("dnsfallback_recursive_errors")
)

View File

@@ -22,6 +22,7 @@ type Listener struct {
ch chan Conn
closeOnce sync.Once
closed chan struct{}
onClose func() // or nil
// NewConn, if non-nil, is called to create a new pair of connections
// when dialing. If nil, NewConn is used.
@@ -38,24 +39,29 @@ func Listen(addr string) *Listener {
}
// Addr implements net.Listener.Addr.
func (l *Listener) Addr() net.Addr {
return l.addr
func (ln *Listener) Addr() net.Addr {
return ln.addr
}
// Close closes the pipe listener.
func (l *Listener) Close() error {
l.closeOnce.Do(func() {
close(l.closed)
func (ln *Listener) Close() error {
var cleanup func()
ln.closeOnce.Do(func() {
cleanup = ln.onClose
close(ln.closed)
})
if cleanup != nil {
cleanup()
}
return nil
}
// Accept blocks until a new connection is available or the listener is closed.
func (l *Listener) Accept() (net.Conn, error) {
func (ln *Listener) Accept() (net.Conn, error) {
select {
case c := <-l.ch:
case c := <-ln.ch:
return c, nil
case <-l.closed:
case <-ln.closed:
return nil, net.ErrClosed
}
}
@@ -64,18 +70,18 @@ func (l *Listener) Accept() (net.Conn, error) {
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected
// any expiration of the context will not affect the connection.
func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) {
func (ln *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) {
if !strings.HasSuffix(network, "tcp") {
return nil, net.UnknownNetworkError(network)
}
if connAddr(addr) != l.addr {
if connAddr(addr) != ln.addr {
return nil, &net.AddrError{
Err: "invalid address",
Addr: addr,
}
}
newConn := l.NewConn
newConn := ln.NewConn
if newConn == nil {
newConn = func(network, addr string, maxBuf int) (Conn, Conn) {
return NewConn(addr, maxBuf)
@@ -92,9 +98,9 @@ func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn,
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-l.closed:
case <-ln.closed:
return nil, net.ErrClosed
case l.ch <- s:
case ln.ch <- s:
return c, nil
}
}

View File

@@ -6,3 +6,87 @@
// in tests and other situations where you don't want to use the
// network.
package memnet
import (
"context"
"fmt"
"net"
"net/netip"
"tailscale.com/net/netx"
"tailscale.com/syncs"
)
var _ netx.Network = (*Network)(nil)
// Network implements [Network] using an in-memory network, usually
// used for testing.
//
// As of 2025-04-08, it only supports TCP.
//
// Its zero value is a valid [netx.Network] implementation.
type Network struct {
mu syncs.Mutex
lns map[string]*Listener // address -> listener
}
func (m *Network) Listen(network, address string) (net.Listener, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
}
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.lns == nil {
m.lns = make(map[string]*Listener)
}
port := ap.Port()
for {
if port == 0 {
port = 33000
}
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
_, ok := m.lns[key]
if ok {
if ap.Port() != 0 {
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
}
port++
continue
}
ln := Listen(key)
m.lns[key] = ln
ln.onClose = func() {
m.mu.Lock()
delete(m.lns, key)
m.mu.Unlock()
}
return ln, nil
}
}
func (m *Network) NewLocalTCPListener() net.Listener {
ln, err := m.Listen("tcp", "127.0.0.1:0")
if err != nil {
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
}
return ln
}
func (m *Network) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
}
m.mu.Lock()
ln, ok := m.lns[address]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
}
return ln.Dial(ctx, network, address)
}

View File

@@ -34,7 +34,7 @@ func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) {
}
ip = ip.Unmap()
if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len {
if ln := len(std.Mask); ln != net.IPv4len && ln != net.IPv6len {
// Invalid mask.
return netip.Prefix{}, false
}

55
vendor/tailscale.com/net/netcheck/captiveportal.go generated vendored Normal file
View File

@@ -0,0 +1,55 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !ts_omit_captiveportal
package netcheck
import (
"context"
"time"
"tailscale.com/net/captivedetection"
"tailscale.com/tailcfg"
)
func init() {
hookStartCaptivePortalDetection.Set(startCaptivePortalDetection)
}
func startCaptivePortalDetection(ctx context.Context, rs *reportState, dm *tailcfg.DERPMap, preferredDERP int) (done <-chan struct{}, stop func()) {
c := rs.c
// NOTE(andrew): we can't simply add this goroutine to the
// `NewWaitGroupChan` below, since we don't wait for that
// waitgroup to finish when exiting this function and thus get
// a data race.
ch := make(chan struct{})
tmr := time.AfterFunc(c.captivePortalDelay(), func() {
defer close(ch)
d := captivedetection.NewDetector(c.logf)
found := d.Detect(ctx, c.NetMon, dm, preferredDERP)
rs.report.CaptivePortal.Set(found)
})
stop = func() {
// Don't cancel our captive portal check if we're
// explicitly doing a verbose netcheck.
if c.Verbose {
return
}
if tmr.Stop() {
// Stopped successfully; need to close the
// signal channel ourselves.
close(ch)
return
}
// Did not stop; do nothing and it'll finish by itself
// and close the signal channel.
}
return ch, stop
}

View File

@@ -23,15 +23,18 @@ import (
"syscall"
"time"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/envknob"
"tailscale.com/net/captivedetection"
"tailscale.com/feature"
"tailscale.com/feature/buildfeatures"
"tailscale.com/hostinfo"
"tailscale.com/net/dnscache"
"tailscale.com/net/neterror"
"tailscale.com/net/netmon"
"tailscale.com/net/netns"
"tailscale.com/net/ping"
"tailscale.com/net/portmapper"
"tailscale.com/net/portmapper/portmappertype"
"tailscale.com/net/sockstats"
"tailscale.com/net/stun"
"tailscale.com/syncs"
@@ -213,7 +216,7 @@ type Client struct {
// PortMapper, if non-nil, is used for portmap queries.
// If nil, portmap discovery is not done.
PortMapper *portmapper.Client // lazily initialized on first use
PortMapper portmappertype.Client
// UseDNSCache controls whether this client should use a
// *dnscache.Resolver to resolve DERP hostnames, when no IP address is
@@ -232,7 +235,7 @@ type Client struct {
testEnoughRegions int
testCaptivePortalDelay time.Duration
mu sync.Mutex // guards following
mu syncs.Mutex // guards following
nextFull bool // do a full region scan, even if last != nil
prev map[time.Time]*Report // some previous reports
last *Report // most recent report
@@ -448,7 +451,7 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, pre
// restoration back to the home DERP on the next full netcheck ~5 minutes later
// - which is highly disruptive when it causes shifts in geo routed subnet
// routers. By always including the home DERP in the incremental netcheck, we
// ensure that the home DERP is always probed, even if it observed a recenet
// ensure that the home DERP is always probed, even if it observed a recent
// poor latency sample. This inclusion enables the latency history checks in
// home DERP selection to still take effect.
// planContainsHome indicates whether the home DERP has been added to the probePlan,
@@ -594,7 +597,7 @@ type reportState struct {
stopProbeCh chan struct{}
waitPortMap sync.WaitGroup
mu sync.Mutex
mu syncs.Mutex
report *Report // to be returned by GetReport
inFlight map[stun.TxID]func(netip.AddrPort) // called without c.mu held
gotEP4 netip.AddrPort
@@ -728,7 +731,7 @@ func (rs *reportState) probePortMapServices() {
res, err := rs.c.PortMapper.Probe(context.Background())
if err != nil {
if !errors.Is(err, portmapper.ErrGatewayRange) {
if !errors.Is(err, portmappertype.ErrGatewayRange) {
// "skipping portmap; gateway range likely lacks support"
// is not very useful, and too spammy on cloud systems.
// If there are other errors, we want to log those.
@@ -752,6 +755,7 @@ func newReport() *Report {
// GetReportOpts contains options that can be passed to GetReport. Unless
// specified, all fields are optional and can be left as their zero value.
// At most one of OnlyTCP443 or OnlySTUN may be set.
type GetReportOpts struct {
// GetLastDERPActivity is a callback that, if provided, should return
// the absolute time that the calling code last communicated with a
@@ -764,6 +768,8 @@ type GetReportOpts struct {
// OnlyTCP443 constrains netcheck reporting to measurements over TCP port
// 443.
OnlyTCP443 bool
// OnlySTUN constrains netcheck reporting to STUN measurements over UDP.
OnlySTUN bool
}
// getLastDERPActivity calls o.GetLastDERPActivity if both o and
@@ -781,6 +787,8 @@ func (c *Client) SetForcePreferredDERP(region int) {
c.ForcePreferredDERP = region
}
var hookStartCaptivePortalDetection feature.Hook[func(ctx context.Context, rs *reportState, dm *tailcfg.DERPMap, preferredDERP int) (<-chan struct{}, func())]
// GetReport gets a report. The 'opts' argument is optional and can be nil.
// Callers are discouraged from passing a ctx with an arbitrary deadline as this
// may cause GetReport to return prematurely before all reporting methods have
@@ -789,6 +797,13 @@ func (c *Client) SetForcePreferredDERP(region int) {
//
// It may not be called concurrently with itself.
func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetReportOpts) (_ *Report, reterr error) {
onlySTUN := false
if opts != nil && opts.OnlySTUN {
if opts.OnlyTCP443 {
return nil, errors.New("netcheck: only one of OnlySTUN or OnlyTCP443 may be set in opts")
}
onlySTUN = true
}
defer func() {
if reterr != nil {
metricNumGetReportError.Add(1)
@@ -863,7 +878,10 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
c.curState = nil
}()
if runtime.GOOS == "js" || runtime.GOOS == "tamago" {
if runtime.GOOS == "js" || runtime.GOOS == "tamago" || (runtime.GOOS == "plan9" && hostinfo.IsInVM86()) {
if onlySTUN {
return nil, errors.New("platform is restricted to HTTP, but OnlySTUN is set in opts")
}
if err := c.runHTTPOnlyChecks(ctx, last, rs, dm); err != nil {
return nil, err
}
@@ -895,38 +913,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
// it's unnecessary.
captivePortalDone := syncs.ClosedChan()
captivePortalStop := func() {}
if !rs.incremental {
// NOTE(andrew): we can't simply add this goroutine to the
// `NewWaitGroupChan` below, since we don't wait for that
// waitgroup to finish when exiting this function and thus get
// a data race.
ch := make(chan struct{})
captivePortalDone = ch
tmr := time.AfterFunc(c.captivePortalDelay(), func() {
defer close(ch)
d := captivedetection.NewDetector(c.logf)
found := d.Detect(ctx, c.NetMon, dm, preferredDERP)
rs.report.CaptivePortal.Set(found)
})
captivePortalStop = func() {
// Don't cancel our captive portal check if we're
// explicitly doing a verbose netcheck.
if c.Verbose {
return
}
if tmr.Stop() {
// Stopped successfully; need to close the
// signal channel ourselves.
close(ch)
return
}
// Did not stop; do nothing and it'll finish by itself
// and close the signal channel.
}
if buildfeatures.HasCaptivePortal && !rs.incremental && !onlySTUN {
start := hookStartCaptivePortalDetection.Get()
captivePortalDone, captivePortalStop = start(ctx, rs, dm, preferredDERP)
}
wg := syncs.NewWaitGroupChan()
@@ -969,13 +958,13 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
rs.stopTimers()
// Try HTTPS and ICMP latency check if all STUN probes failed due to
// UDP presumably being blocked.
// UDP presumably being blocked, and we are not constrained to only STUN.
// TODO: this should be moved into the probePlan, using probeProto probeHTTPS.
if !rs.anyUDP() && ctx.Err() == nil {
if !rs.anyUDP() && ctx.Err() == nil && !onlySTUN {
var wg sync.WaitGroup
var need []*tailcfg.DERPRegion
for rid, reg := range dm.Regions {
if !rs.haveRegionLatency(rid) && regionHasDERPNode(reg) {
if !rs.haveRegionLatency(rid) && regionHasDERPNode(reg) && !reg.Avoid && !reg.NoMeasureNoHome {
need = append(need, reg)
}
}
@@ -1004,9 +993,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
c.logf("[v1] netcheck: measuring HTTPS latency of %v (%d): %v", reg.RegionCode, reg.RegionID, err)
} else {
rs.mu.Lock()
if l, ok := rs.report.RegionLatency[reg.RegionID]; !ok {
if latency, ok := rs.report.RegionLatency[reg.RegionID]; !ok {
mak.Set(&rs.report.RegionLatency, reg.RegionID, d)
} else if l >= d {
} else if latency >= d {
rs.report.RegionLatency[reg.RegionID] = d
}
// We set these IPv4 and IPv6 but they're not really used
@@ -1045,7 +1034,7 @@ func (c *Client) finishAndStoreReport(rs *reportState, dm *tailcfg.DERPMap) *Rep
}
// runHTTPOnlyChecks is the netcheck done by environments that can
// only do HTTP requests, such as ws/wasm.
// only do HTTP requests, such as js/wasm.
func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *reportState, dm *tailcfg.DERPMap) error {
var regions []*tailcfg.DERPRegion
if rs.incremental && last != nil {
@@ -1057,9 +1046,25 @@ func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *report
}
if len(regions) == 0 {
for _, dr := range dm.Regions {
if dr.NoMeasureNoHome {
continue
}
regions = append(regions, dr)
}
}
if len(regions) == 1 && hostinfo.IsInVM86() {
// If we only have 1 region that's probably and we're in a
// network-limited v86 environment, don't actually probe it. Just fake
// some results.
rg := regions[0]
if len(rg.Nodes) > 0 {
node := rg.Nodes[0]
rs.addNodeLatency(node, netip.AddrPort{}, 999*time.Millisecond)
return nil
}
}
c.logf("running HTTP-only netcheck against %v regions", len(regions))
var wg sync.WaitGroup
@@ -1068,7 +1073,6 @@ func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *report
continue
}
wg.Add(1)
rg := rg
go func() {
defer wg.Done()
node := rg.Nodes[0]
@@ -1188,6 +1192,10 @@ func (c *Client) measureAllICMPLatency(ctx context.Context, rs *reportState, nee
if len(need) == 0 {
return nil
}
if runtime.GOOS == "plan9" {
// ICMP isn't implemented.
return nil
}
ctx, done := context.WithTimeout(ctx, icmpProbeTimeout)
defer done()
@@ -1206,9 +1214,9 @@ func (c *Client) measureAllICMPLatency(ctx context.Context, rs *reportState, nee
} else if ok {
c.logf("[v1] ICMP latency of %v (%d): %v", reg.RegionCode, reg.RegionID, d)
rs.mu.Lock()
if l, ok := rs.report.RegionLatency[reg.RegionID]; !ok {
if latency, ok := rs.report.RegionLatency[reg.RegionID]; !ok {
mak.Set(&rs.report.RegionLatency, reg.RegionID, d)
} else if l >= d {
} else if latency >= d {
rs.report.RegionLatency[reg.RegionID] = d
}
@@ -1337,6 +1345,15 @@ const (
// even without receiving a STUN response.
// Note: must remain higher than the derp package frameReceiveRecordRate
PreferredDERPFrameTime = 8 * time.Second
// PreferredDERPKeepAliveTimeout is 2x the DERP Keep Alive timeout. If there
// is no latency data to make judgements from, but we have heard from our
// current DERP region inside of 2x the KeepAlive window, don't switch DERP
// regions yet, keep the current region. This prevents region flapping /
// home DERP removal during short periods of packet loss where the DERP TCP
// connection may itself naturally recover.
// TODO(raggi): expose shared time bounds from the DERP package rather than
// duplicating them here.
PreferredDERPKeepAliveTimeout = 2 * derp.KeepAlive
)
// addReportHistoryAndSetPreferredDERP adds r to the set of recent Reports
@@ -1421,13 +1438,10 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report,
// the STUN probe) since we started the netcheck, or in the past 2s, as
// another signal for "this region is still working".
heardFromOldRegionRecently := false
prevRegionLastHeard := rs.opts.getLastDERPActivity(prevDERP)
if changingPreferred {
if lastHeard := rs.opts.getLastDERPActivity(prevDERP); !lastHeard.IsZero() {
now := c.timeNow()
heardFromOldRegionRecently = lastHeard.After(rs.start)
heardFromOldRegionRecently = heardFromOldRegionRecently || lastHeard.After(now.Add(-PreferredDERPFrameTime))
}
heardFromOldRegionRecently = prevRegionLastHeard.After(rs.start)
heardFromOldRegionRecently = heardFromOldRegionRecently || prevRegionLastHeard.After(now.Add(-PreferredDERPFrameTime))
}
// The old region is accessible if we've heard from it via a non-STUN
@@ -1454,17 +1468,20 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report,
// If the forced DERP region probed successfully, or has recent traffic,
// use it.
_, haveLatencySample := r.RegionLatency[c.ForcePreferredDERP]
var recentActivity bool
if lastHeard := rs.opts.getLastDERPActivity(c.ForcePreferredDERP); !lastHeard.IsZero() {
now := c.timeNow()
recentActivity = lastHeard.After(rs.start)
recentActivity = recentActivity || lastHeard.After(now.Add(-PreferredDERPFrameTime))
}
lastHeard := rs.opts.getLastDERPActivity(c.ForcePreferredDERP)
recentActivity := lastHeard.After(rs.start)
recentActivity = recentActivity || lastHeard.After(now.Add(-PreferredDERPFrameTime))
if haveLatencySample || recentActivity {
r.PreferredDERP = c.ForcePreferredDERP
}
}
// If there was no latency data to make judgements on, but there is an
// active DERP connection that has at least been doing KeepAlive recently,
// keep it, rather than dropping it.
if r.PreferredDERP == 0 && prevRegionLastHeard.After(now.Add(-PreferredDERPKeepAliveTimeout)) {
r.PreferredDERP = prevDERP
}
}
func updateLatency(m map[int]time.Duration, regionID int, d time.Duration) {

View File

@@ -13,7 +13,6 @@ import (
"tailscale.com/net/stun"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
"tailscale.com/util/multierr"
)
// Standalone creates the necessary UDP sockets on the given bindAddr and starts
@@ -62,7 +61,7 @@ func (c *Client) Standalone(ctx context.Context, bindAddr string) error {
// If both v4 and v6 failed, report an error, otherwise let one succeed.
if len(errs) == 2 {
return multierr.New(errs...)
return errors.Join(errs...)
}
return nil
}

View File

@@ -1,7 +1,7 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux
//go:build !linux || android
package netkernelconf

View File

@@ -1,6 +1,8 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux && !android
package netkernelconf
import (

View File

@@ -6,6 +6,8 @@
package netmon
import (
"errors"
"fmt"
"log"
"net"
@@ -16,14 +18,26 @@ var (
lastKnownDefaultRouteIfName syncs.AtomicValue[string]
)
// UpdateLastKnownDefaultRouteInterface is called by ipn-go-bridge in the iOS app when
// UpdateLastKnownDefaultRouteInterface is called by ipn-go-bridge from apple network extensions when
// our NWPathMonitor instance detects a network path transition.
func UpdateLastKnownDefaultRouteInterface(ifName string) {
if ifName == "" {
return
}
if old := lastKnownDefaultRouteIfName.Swap(ifName); old != ifName {
log.Printf("defaultroute_darwin: update from Swift, ifName = %s (was %s)", ifName, old)
interfaces, err := netInterfaces()
if err != nil {
log.Printf("defaultroute_darwin: UpdateLastKnownDefaultRouteInterface could not get interfaces: %v", err)
return
}
netif, err := getInterfaceByName(ifName, interfaces)
if err != nil {
log.Printf("defaultroute_darwin: UpdateLastKnownDefaultRouteInterface could not find interface index for %s: %v", ifName, err)
return
}
log.Printf("defaultroute_darwin: updated last known default if from OS, ifName = %s index: %d (was %s)", ifName, netif.Index, old)
}
}
@@ -40,45 +54,12 @@ func defaultRoute() (d DefaultRouteDetails, err error) {
//
// If for any reason the Swift machinery didn't work and we don't get any updates, we will
// fallback to the BSD logic.
// Start by getting all available interfaces.
interfaces, err := netInterfaces()
if err != nil {
log.Printf("defaultroute_darwin: could not get interfaces: %v", err)
return d, ErrNoGatewayIndexFound
}
getInterfaceByName := func(name string) *Interface {
for _, ifc := range interfaces {
if ifc.Name != name {
continue
}
if !ifc.IsUp() {
log.Printf("defaultroute_darwin: %s is down", name)
return nil
}
addrs, _ := ifc.Addrs()
if len(addrs) == 0 {
log.Printf("defaultroute_darwin: %s has no addresses", name)
return nil
}
return &ifc
}
return nil
}
// Did Swift set lastKnownDefaultRouteInterface? If so, we should use it and don't bother
// with anything else. However, for sanity, do check whether Swift gave us with an interface
// that exists, is up, and has an address.
if swiftIfName := lastKnownDefaultRouteIfName.Load(); swiftIfName != "" {
ifc := getInterfaceByName(swiftIfName)
if ifc != nil {
d.InterfaceName = ifc.Name
d.InterfaceIndex = ifc.Index
return d, nil
}
osRoute, osRouteErr := OSDefaultRoute()
if osRouteErr == nil {
// If we got a valid interface from the OS, use it.
d.InterfaceName = osRoute.InterfaceName
d.InterfaceIndex = osRoute.InterfaceIndex
return d, nil
}
// Fallback to the BSD logic
@@ -94,3 +75,48 @@ func defaultRoute() (d DefaultRouteDetails, err error) {
d.InterfaceIndex = idx
return d, nil
}
// OSDefaultRoute returns the DefaultRouteDetails for the default interface as provided by the OS
// via UpdateLastKnownDefaultRouteInterface. If UpdateLastKnownDefaultRouteInterface has not been called,
// the interface name is not valid, or we cannot find its index, an error is returned.
func OSDefaultRoute() (d DefaultRouteDetails, err error) {
// Did Swift set lastKnownDefaultRouteInterface? If so, we should use it and don't bother
// with anything else. However, for sanity, do check whether Swift gave us with an interface
// that exists, is up, and has an address and is not the tunnel itself.
if swiftIfName := lastKnownDefaultRouteIfName.Load(); swiftIfName != "" {
// Start by getting all available interfaces.
interfaces, err := netInterfaces()
if err != nil {
log.Printf("defaultroute_darwin: could not get interfaces: %v", err)
return d, err
}
if ifc, err := getInterfaceByName(swiftIfName, interfaces); err == nil {
d.InterfaceName = ifc.Name
d.InterfaceIndex = ifc.Index
return d, nil
}
}
err = errors.New("no os provided default route interface found")
return d, err
}
func getInterfaceByName(name string, interfaces []Interface) (*Interface, error) {
for _, ifc := range interfaces {
if ifc.Name != name {
continue
}
if !ifc.IsUp() {
return nil, fmt.Errorf("defaultroute_darwin: %s is down", name)
}
addrs, _ := ifc.Addrs()
if len(addrs) == 0 {
return nil, fmt.Errorf("defaultroute_darwin: %s has no addresses", name)
}
return &ifc, nil
}
return nil, errors.New("no interfaces found")
}

103
vendor/tailscale.com/net/netmon/interfaces.go generated vendored Normal file
View File

@@ -0,0 +1,103 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package netmon
import (
"errors"
"net"
"tailscale.com/syncs"
)
type ifProps struct {
mu syncs.Mutex
name string // interface name, if known/set
index int // interface index, if known/set
}
// tsIfProps tracks the properties (name and index) of the tailscale interface.
// There is only one tailscale interface per tailscaled instance.
var tsIfProps ifProps
func (p *ifProps) tsIfName() string {
p.mu.Lock()
defer p.mu.Unlock()
return p.name
}
func (p *ifProps) tsIfIndex() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.index
}
func (p *ifProps) set(ifName string, ifIndex int) {
p.mu.Lock()
defer p.mu.Unlock()
p.name = ifName
p.index = ifIndex
}
// TODO (barnstar): This doesn't need the Monitor receiver anymore but we're
// keeping it for API compatibility to avoid a breaking change.  This can be
// removed when the various clients have switched to SetTailscaleInterfaceProps
func (m *Monitor) SetTailscaleInterfaceName(ifName string) {
SetTailscaleInterfaceProps(ifName, 0)
}
// SetTailscaleInterfaceProps sets the name of the Tailscale interface and
// its index for use by various listeners/dialers. If the index is zero,
// an attempt will be made to look it up by name. This makes no attempt
// to validate that the interface exists at the time of calling.
//
// If this method is called, it is the responsibility of the caller to
// update the interface name and index if they change.
//
// This should be called as early as possible during tailscaled startup.
func SetTailscaleInterfaceProps(ifName string, ifIndex int) {
if ifIndex != 0 {
tsIfProps.set(ifName, ifIndex)
return
}
ifaces, err := net.Interfaces()
if err != nil {
return
}
for _, iface := range ifaces {
if iface.Name == ifName {
ifIndex = iface.Index
break
}
}
tsIfProps.set(ifName, ifIndex)
}
// TailscaleInterfaceName returns the name of the Tailscale interface.
// For example, "tailscale0", "tun0", "utun3", etc or an error if unset.
//
// Callers must handle errors, as the Tailscale interface
// name may not be set in some environments.
func TailscaleInterfaceName() (string, error) {
name := tsIfProps.tsIfName()
if name == "" {
return "", errors.New("Tailscale interface name not set")
}
return name, nil
}
// TailscaleInterfaceIndex returns the index of the Tailscale interface or
// an error if unset.
//
// Callers must handle errors, as the Tailscale interface
// index may not be set in some environments.
func TailscaleInterfaceIndex() (int, error) {
index := tsIfProps.tsIfIndex()
if index == 0 {
return 0, errors.New("Tailscale interface index not set")
}
return index, nil
}

View File

@@ -7,12 +7,12 @@ import (
"fmt"
"net"
"strings"
"sync"
"syscall"
"unsafe"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"tailscale.com/syncs"
"tailscale.com/util/mak"
)
@@ -26,7 +26,7 @@ func parseRoutingTable(rib []byte) ([]route.Message, error) {
}
var ifNames struct {
sync.Mutex
syncs.Mutex
m map[int]string // ifindex => name
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/mdlayher/netlink"
"go4.org/mem"
"golang.org/x/sys/unix"
"tailscale.com/feature/buildfeatures"
"tailscale.com/net/netaddr"
"tailscale.com/util/lineiter"
)
@@ -41,6 +42,9 @@ ens18 00000000 0100000A 0003 0 0 0 00000000
ens18 0000000A 00000000 0001 0 0 0 0000FFFF 0 0 0
*/
func likelyHomeRouterIPLinux() (ret netip.Addr, myIP netip.Addr, ok bool) {
if !buildfeatures.HasPortMapper {
return
}
if procNetRouteErr.Load() {
// If we failed to read /proc/net/route previously, don't keep trying.
return ret, myIP, false

View File

@@ -13,6 +13,7 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/feature/buildfeatures"
"tailscale.com/tsconst"
)
@@ -22,7 +23,9 @@ const (
func init() {
likelyHomeRouterIP = likelyHomeRouterIPWindows
getPAC = getPACWindows
if buildfeatures.HasUseProxy {
getPAC = getPACWindows
}
}
func likelyHomeRouterIPWindows() (ret netip.Addr, _ netip.Addr, ok bool) {
@@ -244,6 +247,9 @@ const (
)
func getPACWindows() string {
if !buildfeatures.HasUseProxy {
return ""
}
var res *uint16
r, _, e := detectAutoProxyConfigURL.Call(
winHTTP_AUTO_DETECT_TYPE_DHCP|winHTTP_AUTO_DETECT_TYPE_DNS_A,

View File

@@ -4,39 +4,46 @@
package netmon
import (
"context"
"sync"
"time"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
const cooldownSeconds = 300
// LinkChangeLogLimiter returns a new [logger.Logf] that logs each unique
// format string to the underlying logger only once per major LinkChange event.
// format string to the underlying logger only once per major LinkChange event
// with a cooldownSeconds second cooldown.
//
// The returned function should be called when the logger is no longer needed,
// to release resources from the Monitor.
func LinkChangeLogLimiter(logf logger.Logf, nm *Monitor) (_ logger.Logf, unregister func()) {
var formatSeen sync.Map // map[string]bool
unregister = nm.RegisterChangeCallback(func(cd *ChangeDelta) {
// If we're in a major change or a time jump, clear the seen map.
if cd.Major || cd.TimeJumped {
formatSeen.Clear()
// The logger stops tracking seen format strings when the provided context is
// done.
func LinkChangeLogLimiter(ctx context.Context, logf logger.Logf, nm *Monitor) logger.Logf {
var formatLastSeen sync.Map // map[string]int64
sub := eventbus.SubscribeFunc(nm.b, func(cd *ChangeDelta) {
// Any link changes that are flagged as likely require a rebind are
// interesting enough that we should log them.
if cd.RebindLikelyRequired {
formatLastSeen.Clear()
}
})
context.AfterFunc(ctx, sub.Close)
return func(format string, args ...any) {
// We only store 'true' in the map, so if it's present then it
// means we've already logged this format string.
_, loaded := formatSeen.LoadOrStore(format, true)
if loaded {
// TODO(andrew-d): we may still want to log this
// message every N minutes (1x/hour?) even if it's been
// seen, so that debugging doesn't require searching
// back in the logs for an unbounded amount of time.
//
// See: https://github.com/tailscale/tailscale/issues/13145
return
// get the current timestamp
now := time.Now().Unix()
lastSeen, ok := formatLastSeen.Load(format)
if ok {
// if we've seen this format string within the last cooldownSeconds, skip logging
if now-lastSeen.(int64) < cooldownSeconds {
return
}
}
// update the last seen timestamp for this format string
formatLastSeen.Store(format, now)
logf(format, args...)
}, unregister
}
}

View File

@@ -7,15 +7,20 @@
package netmon
import (
"encoding/json"
"errors"
"fmt"
"log"
"net/netip"
"runtime"
"slices"
"sync"
"time"
"tailscale.com/feature/buildfeatures"
"tailscale.com/syncs"
"tailscale.com/types/logger"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
"tailscale.com/util/set"
)
@@ -42,27 +47,28 @@ type osMon interface {
// until the osMon is closed. After a Close, the returned
// error is ignored.
Receive() (message, error)
// IsInterestingInterface reports whether the provided interface should
// be considered for network change events.
IsInterestingInterface(iface string) bool
}
// IsInterestingInterface is the function used to determine whether
// a given interface name is interesting enough to pay attention to
// for network change monitoring purposes.
//
// If nil, all interfaces are considered interesting.
var IsInterestingInterface func(Interface, []netip.Prefix) bool
// Monitor represents a monitoring instance.
type Monitor struct {
logf logger.Logf
logf logger.Logf
b *eventbus.Client
changed *eventbus.Publisher[ChangeDelta]
om osMon // nil means not supported on this platform
change chan bool // send false to wake poller, true to also force ChangeDeltas be sent
stop chan struct{} // closed on Stop
static bool // static Monitor that doesn't actually monitor
// Things that must be set early, before use,
// and not change at runtime.
tsIfName string // tailscale interface name, if known/set ("tailscale0", "utun3", ...)
mu sync.Mutex // guards all following fields
mu syncs.Mutex // guards all following fields
cbs set.HandleSet[ChangeFunc]
ruleDelCB set.HandleSet[RuleDeleteCallback]
ifState *State
gwValid bool // whether gw and gwSelfIP are valid
gw netip.Addr // our gateway's IP
@@ -80,55 +86,246 @@ type Monitor struct {
type ChangeFunc func(*ChangeDelta)
// ChangeDelta describes the difference between two network states.
//
// Use NewChangeDelta to construct a delta and compute the cached fields.
type ChangeDelta struct {
// Monitor is the network monitor that sent this delta.
Monitor *Monitor
// Old is the old interface state, if known.
// old is the old interface state, if known.
// It's nil if the old state is unknown.
// Do not mutate it.
Old *State
old *State
// New is the new network state.
// It is always non-nil.
// Do not mutate it.
New *State
// Major is our legacy boolean of whether the network changed in some major
// way.
//
// Deprecated: do not remove. As of 2023-08-23 we're in a renewed effort to
// remove it and ask specific qustions of ChangeDelta instead. Look at Old
// and New (or add methods to ChangeDelta) instead of using Major.
Major bool
new *State
// TimeJumped is whether there was a big jump in wall time since the last
// time we checked. This is a hint that a mobile sleeping device might have
// time we checked. This is a hint that a sleeping device might have
// come out of sleep.
TimeJumped bool
// TODO(bradfitz): add some lazy cached fields here as needed with methods
// on *ChangeDelta to let callers ask specific questions
DefaultRouteInterface string
// Computed Fields
DefaultInterfaceChanged bool // whether default route interface changed
IsLessExpensive bool // whether new state's default interface is less expensive than old.
HasPACOrProxyConfigChanged bool // whether PAC/HTTP proxy config changed
InterfaceIPsChanged bool // whether any interface IPs changed in a meaningful way
AvailableProtocolsChanged bool // whether we have seen a change in available IPv4/IPv6
DefaultInterfaceMaybeViable bool // whether the default interface is potentially viable (has usable IPs, is up and is not the tunnel itself)
IsInitialState bool // whether this is the initial state (old == nil, new != nil)
// RebindLikelyRequired combines the various fields above to report whether this change likely requires us
// to rebind sockets. This is a very conservative estimate and covers a number ofcases where a rebind
// may not be strictly necessary. Consumers of the ChangeDelta should consider checking the individual fields
// above or the state of their sockets.
RebindLikelyRequired bool
}
// CurrentState returns the current (new) state after the change.
func (cd *ChangeDelta) CurrentState() *State {
return cd.new
}
// NewChangeDelta builds a ChangeDelta and eagerly computes the cached fields.
// forceViability, if true, forces DefaultInterfaceMaybeViable to be true regardless of the
// actual state of the default interface. This is useful in testing.
func NewChangeDelta(old, new *State, timeJumped bool, forceViability bool) (*ChangeDelta, error) {
cd := ChangeDelta{
old: old,
new: new,
TimeJumped: timeJumped,
}
if cd.new == nil {
log.Printf("[unexpected] NewChangeDelta called with nil new state")
return nil, errors.New("new state cannot be nil")
} else if cd.old == nil && cd.new != nil {
cd.DefaultInterfaceChanged = cd.new.DefaultRouteInterface != ""
cd.IsLessExpensive = false
cd.HasPACOrProxyConfigChanged = true
cd.InterfaceIPsChanged = true
cd.IsInitialState = true
} else {
cd.AvailableProtocolsChanged = (cd.old.HaveV4 != cd.new.HaveV4) || (cd.old.HaveV6 != cd.new.HaveV6)
cd.DefaultInterfaceChanged = cd.old.DefaultRouteInterface != cd.new.DefaultRouteInterface
cd.IsLessExpensive = cd.old.IsExpensive && !cd.new.IsExpensive
cd.HasPACOrProxyConfigChanged = (cd.old.PAC != cd.new.PAC) || (cd.old.HTTPProxy != cd.new.HTTPProxy)
cd.InterfaceIPsChanged = cd.isInterestingInterfaceChange()
}
cd.DefaultRouteInterface = new.DefaultRouteInterface
defIf := new.Interface[cd.DefaultRouteInterface]
tsIfName, err := TailscaleInterfaceName()
// The default interface is not viable if it is down or it is the Tailscale interface itself.
if !forceViability && (!defIf.IsUp() || (err == nil && cd.DefaultRouteInterface == tsIfName)) {
cd.DefaultInterfaceMaybeViable = false
} else {
cd.DefaultInterfaceMaybeViable = true
}
// Compute rebind requirement. The default interface needs to be viable and
// one of the other conditions needs to be true.
cd.RebindLikelyRequired = (cd.old == nil ||
cd.TimeJumped ||
cd.DefaultInterfaceChanged ||
cd.InterfaceIPsChanged ||
cd.IsLessExpensive ||
cd.HasPACOrProxyConfigChanged ||
cd.AvailableProtocolsChanged) &&
cd.DefaultInterfaceMaybeViable
return &cd, nil
}
// StateDesc returns a description of the old and new states for logging.
func (cd *ChangeDelta) StateDesc() string {
return fmt.Sprintf("old: %v new: %v", cd.old, cd.new)
}
// InterfaceIPDisappeared reports whether the given IP address exists on any interface
// in the old state, but not in the new state.
func (cd *ChangeDelta) InterfaceIPDisappeared(ip netip.Addr) bool {
if cd.old == nil {
return false
}
if cd.new == nil && cd.old.HasIP(ip) {
return true
}
return cd.new.HasIP(ip) && !cd.old.HasIP(ip)
}
// AnyInterfaceUp reports whether any interfaces are up in the new state.
func (cd *ChangeDelta) AnyInterfaceUp() bool {
if cd.new == nil {
return false
}
for _, ifi := range cd.new.Interface {
if ifi.IsUp() {
return true
}
}
return false
}
// isInterestingInterfaceChange reports whether any interfaces have changed in a meaningful way.
// This excludes interfaces that are not interesting per IsInterestingInterface and
// filters out changes to interface IPs that that are uninteresting (e.g. link-local addresses).
func (cd *ChangeDelta) isInterestingInterfaceChange() bool {
// If there is no old state, everything is considered changed.
if cd.old == nil {
return true
}
// Compare interfaces in both directions. Old to new and new to old.
tsIfName, ifNameErr := TailscaleInterfaceName()
for iname, oldInterface := range cd.old.Interface {
if ifNameErr == nil && iname == tsIfName {
// Ignore changes in the Tailscale interface itself
continue
}
oldIps := filterRoutableIPs(cd.old.InterfaceIPs[iname])
if IsInterestingInterface != nil && !IsInterestingInterface(oldInterface, oldIps) {
continue
}
// Old interfaces with no routable addresses are not interesting
if len(oldIps) == 0 {
continue
}
// The old interface doesn't exist in the new interface set and it has
// a global unicast IP. That's considered a change from the perspective
// of anything that may have been bound to it. If it didn't have a global
// unicast IP, it's not interesting.
newInterface, ok := cd.new.Interface[iname]
if !ok {
return true
}
newIps, ok := cd.new.InterfaceIPs[iname]
if !ok {
return true
}
newIps = filterRoutableIPs(newIps)
if !oldInterface.Equal(newInterface) || !prefixesEqual(oldIps, newIps) {
return true
}
}
for iname, newInterface := range cd.new.Interface {
if ifNameErr == nil && iname == tsIfName {
// Ignore changes in the Tailscale interface itself
continue
}
newIps := filterRoutableIPs(cd.new.InterfaceIPs[iname])
if IsInterestingInterface != nil && !IsInterestingInterface(newInterface, newIps) {
continue
}
// New interfaces with no routable addresses are not interesting
if len(newIps) == 0 {
continue
}
oldInterface, ok := cd.old.Interface[iname]
if !ok {
return true
}
oldIps, ok := cd.old.InterfaceIPs[iname]
if !ok {
// Redundant but we can't dig up the "old" IPs for this interface.
return true
}
oldIps = filterRoutableIPs(oldIps)
// The interface's IPs, Name, MTU, etc have changed. This is definitely interesting.
if !newInterface.Equal(oldInterface) || !prefixesEqual(oldIps, newIps) {
return true
}
}
return false
}
func filterRoutableIPs(addrs []netip.Prefix) []netip.Prefix {
var filtered []netip.Prefix
for _, pfx := range addrs {
a := pfx.Addr()
// Skip link-local multicast addresses.
if a.IsLinkLocalMulticast() {
continue
}
if isUsableV4(a) || isUsableV6(a) {
filtered = append(filtered, pfx)
}
}
return filtered
}
// New instantiates and starts a monitoring instance.
// The returned monitor is inactive until it's started by the Start method.
// Use RegisterChangeCallback to get notified of network changes.
func New(logf logger.Logf) (*Monitor, error) {
func New(bus *eventbus.Bus, logf logger.Logf) (*Monitor, error) {
logf = logger.WithPrefix(logf, "monitor: ")
m := &Monitor{
logf: logf,
b: bus.Client("netmon"),
change: make(chan bool, 1),
stop: make(chan struct{}),
lastWall: wallTime(),
}
m.changed = eventbus.Publish[ChangeDelta](m.b)
st, err := m.interfaceStateUncached()
if err != nil {
return nil, err
}
m.ifState = st
m.om, err = newOSMon(logf, m)
m.om, err = newOSMon(bus, logf, m)
if err != nil {
return nil, err
}
@@ -161,16 +358,7 @@ func (m *Monitor) InterfaceState() *State {
}
func (m *Monitor) interfaceStateUncached() (*State, error) {
return getState(m.tsIfName)
}
// SetTailscaleInterfaceName sets the name of the Tailscale interface. For
// example, "tailscale0", "tun0", "utun3", etc.
//
// This must be called only early in tailscaled startup before the monitor is
// used.
func (m *Monitor) SetTailscaleInterfaceName(ifName string) {
m.tsIfName = ifName
return getState(tsIfProps.tsIfName())
}
// GatewayAndSelfIP returns the current network's default gateway, and
@@ -179,6 +367,9 @@ func (m *Monitor) SetTailscaleInterfaceName(ifName string) {
// It's the same as interfaces.LikelyHomeRouterIP, but it caches the
// result until the monitor detects a network change.
func (m *Monitor) GatewayAndSelfIP() (gw, myIP netip.Addr, ok bool) {
if !buildfeatures.HasPortMapper {
return
}
if m.static {
return
}
@@ -218,29 +409,6 @@ func (m *Monitor) RegisterChangeCallback(callback ChangeFunc) (unregister func()
}
}
// RuleDeleteCallback is a callback when a Linux IP policy routing
// rule is deleted. The table is the table number (52, 253, 354) and
// priority is the priority order number (for Tailscale rules
// currently: 5210, 5230, 5250, 5270)
type RuleDeleteCallback func(table uint8, priority uint32)
// RegisterRuleDeleteCallback adds callback to the set of parties to be
// notified (in their own goroutine) when a Linux ip rule is deleted.
// To remove this callback, call unregister (or close the monitor).
func (m *Monitor) RegisterRuleDeleteCallback(callback RuleDeleteCallback) (unregister func()) {
if m.static {
return func() {}
}
m.mu.Lock()
defer m.mu.Unlock()
handle := m.ruleDelCB.Add(callback)
return func() {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.ruleDelCB, handle)
}
}
// Start starts the monitor.
// A monitor can only be started & closed once.
func (m *Monitor) Start() {
@@ -353,10 +521,6 @@ func (m *Monitor) pump() {
time.Sleep(time.Second)
continue
}
if rdm, ok := msg.(ipRuleDeletedMessage); ok {
m.notifyRuleDeleted(rdm)
continue
}
if msg.ignore() {
continue
}
@@ -364,25 +528,6 @@ func (m *Monitor) pump() {
}
}
func (m *Monitor) notifyRuleDeleted(rdm ipRuleDeletedMessage) {
m.mu.Lock()
defer m.mu.Unlock()
for _, cb := range m.ruleDelCB {
go cb(rdm.table, rdm.priority)
}
}
// isInterestingInterface reports whether the provided interface should be
// considered when checking for network state changes.
// The ips parameter should be the IPs of the provided interface.
func (m *Monitor) isInterestingInterface(i Interface, ips []netip.Prefix) bool {
if !m.om.IsInterestingInterface(i.Name) {
return false
}
return true
}
// debounce calls the callback function with a delay between events
// and exits when a stop is issued.
func (m *Monitor) debounce() {
@@ -404,7 +549,10 @@ func (m *Monitor) debounce() {
select {
case <-m.stop:
return
case <-time.After(250 * time.Millisecond):
// 1s is reasonable debounce time for network changes. Events such as undocking a laptop
// or roaming onto wifi will often generate multiple events in quick succession as interfaces
// flap. We want to avoid spamming consumers of these events.
case <-time.After(1000 * time.Millisecond):
}
}
}
@@ -431,146 +579,51 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) {
return
}
delta := &ChangeDelta{
Monitor: m,
Old: oldState,
New: newState,
TimeJumped: timeJumped,
delta, err := NewChangeDelta(oldState, newState, timeJumped, false)
if err != nil {
m.logf("[unexpected] error creating ChangeDelta: %v", err)
return
}
delta.Major = m.IsMajorChangeFrom(oldState, newState)
if delta.Major {
if delta.RebindLikelyRequired {
m.gwValid = false
m.ifState = newState
if s1, s2 := oldState.String(), delta.New.String(); s1 == s2 {
m.logf("[unexpected] network state changed, but stringification didn't: %v", s1)
m.logf("[unexpected] old: %s", jsonSummary(oldState))
m.logf("[unexpected] new: %s", jsonSummary(newState))
}
}
m.ifState = newState
// See if we have a queued or new time jump signal.
if timeJumped {
m.resetTimeJumpedLocked()
if !delta.Major {
// Only log if it wasn't an interesting change.
m.logf("time jumped (probably wake from sleep); synthesizing major change event")
delta.Major = true
}
}
metricChange.Add(1)
if delta.Major {
if delta.RebindLikelyRequired {
metricChangeMajor.Add(1)
}
if delta.TimeJumped {
metricChangeTimeJump.Add(1)
}
m.changed.Publish(*delta)
for _, cb := range m.cbs {
go cb(delta)
}
}
// IsMajorChangeFrom reports whether the transition from s1 to s2 is
// a "major" change, where major roughly means it's worth tearing down
// a bunch of connections and rebinding.
//
// TODO(bradiftz): tigten this definition.
func (m *Monitor) IsMajorChangeFrom(s1, s2 *State) bool {
if s1 == nil && s2 == nil {
// reports whether a and b contain the same set of prefixes regardless of order.
func prefixesEqual(a, b []netip.Prefix) bool {
if len(a) != len(b) {
return false
}
if s1 == nil || s2 == nil {
return true
}
if s1.HaveV6 != s2.HaveV6 ||
s1.HaveV4 != s2.HaveV4 ||
s1.IsExpensive != s2.IsExpensive ||
s1.DefaultRouteInterface != s2.DefaultRouteInterface ||
s1.HTTPProxy != s2.HTTPProxy ||
s1.PAC != s2.PAC {
return true
}
for iname, i := range s1.Interface {
if iname == m.tsIfName {
// Ignore changes in the Tailscale interface itself.
continue
}
ips := s1.InterfaceIPs[iname]
if !m.isInterestingInterface(i, ips) {
continue
}
i2, ok := s2.Interface[iname]
if !ok {
return true
}
ips2, ok := s2.InterfaceIPs[iname]
if !ok {
return true
}
if !i.Equal(i2) || !prefixesMajorEqual(ips, ips2) {
return true
}
}
// Iterate over s2 in case there is a field in s2 that doesn't exist in s1
for iname, i := range s2.Interface {
if iname == m.tsIfName {
// Ignore changes in the Tailscale interface itself.
continue
}
ips := s2.InterfaceIPs[iname]
if !m.isInterestingInterface(i, ips) {
continue
}
i1, ok := s1.Interface[iname]
if !ok {
return true
}
ips1, ok := s1.InterfaceIPs[iname]
if !ok {
return true
}
if !i.Equal(i1) || !prefixesMajorEqual(ips, ips1) {
return true
}
}
return false
}
// prefixesMajorEqual reports whether a and b are equal after ignoring
// boring things like link-local, loopback, and multicast addresses.
func prefixesMajorEqual(a, b []netip.Prefix) bool {
// trim returns a subslice of p with link local unicast,
// loopback, and multicast prefixes removed from the front.
trim := func(p []netip.Prefix) []netip.Prefix {
for len(p) > 0 {
a := p[0].Addr()
if a.IsLinkLocalUnicast() || a.IsLoopback() || a.IsMulticast() {
p = p[1:]
continue
}
break
}
return p
}
for {
a = trim(a)
b = trim(b)
if len(a) == 0 || len(b) == 0 {
return len(a) == 0 && len(b) == 0
}
if a[0] != b[0] {
return false
}
a, b = a[1:], b[1:]
}
}
aa := make([]netip.Prefix, len(a))
bb := make([]netip.Prefix, len(b))
copy(aa, a)
copy(bb, b)
func jsonSummary(x any) any {
j, err := json.Marshal(x)
if err != nil {
return err
less := func(x, y netip.Prefix) int {
return x.Addr().Compare(y.Addr())
}
return j
slices.SortFunc(aa, less)
slices.SortFunc(bb, less)
return slices.Equal(aa, bb)
}
func wallTime() time.Time {
@@ -596,7 +649,7 @@ func (m *Monitor) pollWallTime() {
//
// We don't do this on mobile platforms for battery reasons, and because these
// platforms don't really sleep in the same way.
const shouldMonitorTimeJump = runtime.GOOS != "android" && runtime.GOOS != "ios"
const shouldMonitorTimeJump = runtime.GOOS != "android" && runtime.GOOS != "ios" && runtime.GOOS != "plan9"
// checkWallTimeAdvanceLocked reports whether wall time jumped more than 150% of
// pollWallTimeInterval, indicating we probably just came out of sleep. Once a
@@ -617,10 +670,3 @@ func (m *Monitor) checkWallTimeAdvanceLocked() bool {
func (m *Monitor) resetTimeJumpedLocked() {
m.timeJumped = false
}
type ipRuleDeletedMessage struct {
table uint8
priority uint32
}
func (ipRuleDeletedMessage) ignore() bool { return true }

View File

@@ -13,8 +13,15 @@ import (
"golang.org/x/sys/unix"
"tailscale.com/net/netaddr"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
func init() {
IsInterestingInterface = func(iface Interface, prefixes []netip.Prefix) bool {
return isInterestingInterface(iface.Name)
}
}
const debugRouteMessages = false
// unspecifiedMessage is a minimal message implementation that should not
@@ -24,7 +31,7 @@ type unspecifiedMessage struct{}
func (unspecifiedMessage) ignore() bool { return false }
func newOSMon(logf logger.Logf, _ *Monitor) (osMon, error) {
func newOSMon(_ *eventbus.Bus, logf logger.Logf, _ *Monitor) (osMon, error) {
fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
if err != nil {
return nil, err
@@ -124,11 +131,10 @@ func addrType(addrs []route.Addr, rtaxType int) route.Addr {
return nil
}
func (m *darwinRouteMon) IsInterestingInterface(iface string) bool {
func isInterestingInterface(iface string) bool {
baseName := strings.TrimRight(iface, "0123456789")
switch baseName {
// TODO(maisem): figure out what this list should actually be.
case "llw", "awdl", "ipsec":
case "llw", "awdl", "ipsec", "gif", "XHC", "anpi", "lo", "utun":
return false
}
return true
@@ -136,7 +142,7 @@ func (m *darwinRouteMon) IsInterestingInterface(iface string) bool {
func (m *darwinRouteMon) skipInterfaceAddrMessage(msg *route.InterfaceAddrMessage) bool {
if la, ok := addrType(msg.Addrs, unix.RTAX_IFP).(*route.LinkAddr); ok {
if !m.IsInterestingInterface(la.Name) {
if !isInterestingInterface(la.Name) {
return true
}
}
@@ -149,6 +155,14 @@ func (m *darwinRouteMon) skipRouteMessage(msg *route.RouteMessage) bool {
// dst = fe80::b476:66ff:fe30:c8f6%15
return true
}
// We can skip route messages from uninteresting interfaces. We do this upstream
// against the InterfaceMonitor, but skipping them here avoids unnecessary work.
if la, ok := addrType(msg.Addrs, unix.RTAX_IFP).(*route.LinkAddr); ok {
if !isInterestingInterface(la.Name) {
return true
}
}
return false
}

View File

@@ -10,6 +10,7 @@ import (
"strings"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
// unspecifiedMessage is a minimal message implementation that should not
@@ -24,7 +25,7 @@ type devdConn struct {
conn net.Conn
}
func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
func newOSMon(_ *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) {
conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe")
if err != nil {
logf("devd dial error: %v, falling back to polling method", err)
@@ -33,8 +34,6 @@ func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
return &devdConn{conn}, nil
}
func (c *devdConn) IsInterestingInterface(iface string) bool { return true }
func (c *devdConn) Close() error {
return c.conn.Close()
}

View File

@@ -16,6 +16,7 @@ import (
"tailscale.com/envknob"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK")
@@ -27,15 +28,26 @@ type unspecifiedMessage struct{}
func (unspecifiedMessage) ignore() bool { return false }
// RuleDeleted reports that one of Tailscale's policy routing rules
// was deleted.
type RuleDeleted struct {
// Table is the table number that the deleted rule referenced.
Table uint8
// Priority is the lookup priority of the deleted rule.
Priority uint32
}
// nlConn wraps a *netlink.Conn and returns a monitor.Message
// instead of a netlink.Message. Currently, messages are discarded,
// but down the line, when messages trigger different logic depending
// on the type of event, this provides the capability of handling
// each architecture-specific message in a generic fashion.
type nlConn struct {
logf logger.Logf
conn *netlink.Conn
buffered []netlink.Message
busClient *eventbus.Client
rulesDeleted *eventbus.Publisher[RuleDeleted]
logf logger.Logf
conn *netlink.Conn
buffered []netlink.Message
// addrCache maps interface indices to a set of addresses, and is
// used to suppress duplicate RTM_NEWADDR messages. It is populated
@@ -44,7 +56,7 @@ type nlConn struct {
addrCache map[uint32]map[netip.Addr]bool
}
func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
func newOSMon(bus *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) {
conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
// Routes get us most of the events of interest, but we need
// address as well to cover things like DHCP deciding to give
@@ -59,12 +71,20 @@ func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling")
return newPollingMon(logf, m)
}
return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil
client := bus.Client("netmon-iprules")
return &nlConn{
busClient: client,
rulesDeleted: eventbus.Publish[RuleDeleted](client),
logf: logf,
conn: conn,
addrCache: make(map[uint32]map[netip.Addr]bool),
}, nil
}
func (c *nlConn) IsInterestingInterface(iface string) bool { return true }
func (c *nlConn) Close() error { return c.conn.Close() }
func (c *nlConn) Close() error {
c.busClient.Close()
return c.conn.Close()
}
func (c *nlConn) Receive() (message, error) {
if len(c.buffered) == 0 {
@@ -219,14 +239,15 @@ func (c *nlConn) Receive() (message, error) {
// On `ip -4 rule del pref 5210 table main`, logs:
// monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst:<nil> Src:<nil> Gateway:<nil> OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires:<nil> Metrics:<nil> Multipath:[]}}
}
rdm := ipRuleDeletedMessage{
table: rmsg.Table,
priority: rmsg.Attributes.Priority,
rd := RuleDeleted{
Table: rmsg.Table,
Priority: rmsg.Attributes.Priority,
}
c.rulesDeleted.Publish(rd)
if debugNetlinkMessages() {
c.logf("%+v", rdm)
c.logf("%+v", rd)
}
return rdm, nil
return ignoreMessage{}, nil
case unix.RTM_NEWLINK, unix.RTM_DELLINK:
// This is an unhandled message, but don't print an error.
// See https://github.com/tailscale/tailscale/issues/6806

View File

@@ -7,9 +7,10 @@ package netmon
import (
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
func newOSMon(_ *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) {
return newPollingMon(logf, m)
}

View File

@@ -13,6 +13,7 @@ import (
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
var (
@@ -45,7 +46,7 @@ type winMon struct {
noDeadlockTicker *time.Ticker
}
func newOSMon(logf logger.Logf, pm *Monitor) (osMon, error) {
func newOSMon(_ *eventbus.Bus, logf logger.Logf, pm *Monitor) (osMon, error) {
m := &winMon{
logf: logf,
isActive: pm.isActive,
@@ -73,8 +74,6 @@ func newOSMon(logf logger.Logf, pm *Monitor) (osMon, error) {
return m, nil
}
func (m *winMon) IsInterestingInterface(iface string) bool { return true }
func (m *winMon) Close() (ret error) {
m.cancel()
m.noDeadlockTicker.Stop()

View File

@@ -35,10 +35,6 @@ type pollingMon struct {
stop chan struct{}
}
func (pm *pollingMon) IsInterestingInterface(iface string) bool {
return true
}
func (pm *pollingMon) Close() error {
pm.closeOnce.Do(func() {
close(pm.stop)

View File

@@ -15,10 +15,11 @@ import (
"strings"
"tailscale.com/envknob"
"tailscale.com/feature"
"tailscale.com/feature/buildfeatures"
"tailscale.com/hostinfo"
"tailscale.com/net/netaddr"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tshttpproxy"
"tailscale.com/util/mak"
)
@@ -148,12 +149,28 @@ type Interface struct {
Desc string // extra description (used on Windows)
}
func (i Interface) IsLoopback() bool { return isLoopback(i.Interface) }
func (i Interface) IsUp() bool { return isUp(i.Interface) }
func (i Interface) IsLoopback() bool {
if i.Interface == nil {
return false
}
return isLoopback(i.Interface)
}
func (i Interface) IsUp() bool {
if i.Interface == nil {
return false
}
return isUp(i.Interface)
}
func (i Interface) Addrs() ([]net.Addr, error) {
if i.AltAddrs != nil {
return i.AltAddrs, nil
}
if i.Interface == nil {
return nil, nil
}
return i.Interface.Addrs()
}
@@ -182,6 +199,10 @@ func (ifaces InterfaceList) ForeachInterfaceAddress(fn func(Interface, netip.Pre
if pfx, ok := netaddr.FromStdIPNet(v); ok {
fn(iface, pfx)
}
case *net.IPAddr:
if ip, ok := netip.AddrFromSlice(v.IP); ok {
fn(iface, netip.PrefixFrom(ip, ip.BitLen()))
}
}
}
}
@@ -214,6 +235,10 @@ func (ifaces InterfaceList) ForeachInterface(fn func(Interface, []netip.Prefix))
if pfx, ok := netaddr.FromStdIPNet(v); ok {
pfxs = append(pfxs, pfx)
}
case *net.IPAddr:
if ip, ok := netip.AddrFromSlice(v.IP); ok {
pfxs = append(pfxs, netip.PrefixFrom(ip, ip.BitLen()))
}
}
}
sort.Slice(pfxs, func(i, j int) bool {
@@ -445,15 +470,22 @@ func hasTailscaleIP(pfxs []netip.Prefix) bool {
}
func isTailscaleInterface(name string, ips []netip.Prefix) bool {
// Sandboxed macOS and Plan9 (and anything else that explicitly calls SetTailscaleInterfaceProps).
tsIfName, err := TailscaleInterfaceName()
if err == nil {
// If we've been told the Tailscale interface name, use that.
return name == tsIfName
}
// The sandboxed app should (as of 1.92) set the tun interface name via SetTailscaleInterfaceProps
// early in the startup process. The non-sandboxed app does not.
// TODO (barnstar): If Wireguard created the tun device on darwin, it should know the name and it should
// be explicitly set instead checking addresses here.
if runtime.GOOS == "darwin" && strings.HasPrefix(name, "utun") && hasTailscaleIP(ips) {
// On macOS in the sandboxed app (at least as of
// 2021-02-25), we often see two utun devices
// (e.g. utun4 and utun7) with the same IPv4 and IPv6
// addresses. Just remove all utun devices with
// Tailscale IPs until we know what's happening with
// macOS NetworkExtensions and utun devices.
return true
}
// Windows, Linux...
return name == "Tailscale" || // as it is on Windows
strings.HasPrefix(name, "tailscale") // TODO: use --tun flag value, etc; see TODO in method doc
}
@@ -476,9 +508,16 @@ func getState(optTSInterfaceName string) (*State, error) {
ifUp := ni.IsUp()
s.Interface[ni.Name] = ni
s.InterfaceIPs[ni.Name] = append(s.InterfaceIPs[ni.Name], pfxs...)
// Skip uninteresting interfaces
if IsInterestingInterface != nil && !IsInterestingInterface(ni, pfxs) {
return
}
if !ifUp || isTSInterfaceName || isTailscaleInterface(ni.Name, pfxs) {
return
}
for _, pfx := range pfxs {
if pfx.Addr().IsLoopback() {
continue
@@ -501,13 +540,15 @@ func getState(optTSInterfaceName string) (*State, error) {
}
}
if s.AnyInterfaceUp() {
if buildfeatures.HasUseProxy && s.AnyInterfaceUp() {
req, err := http.NewRequest("GET", LoginEndpointForProxyDetermination, nil)
if err != nil {
return nil, err
}
if u, err := tshttpproxy.ProxyFromEnvironment(req); err == nil && u != nil {
s.HTTPProxy = u.String()
if proxyFromEnv, ok := feature.HookProxyFromEnvironment.GetOk(); ok {
if u, err := proxyFromEnv(req); err == nil && u != nil {
s.HTTPProxy = u.String()
}
}
if getPAC != nil {
s.PAC = getPAC()
@@ -570,6 +611,9 @@ var disableLikelyHomeRouterIPSelf = envknob.RegisterBool("TS_DEBUG_DISABLE_LIKEL
// the LAN using that gateway.
// This is used as the destination for UPnP, NAT-PMP, PCP, etc queries.
func LikelyHomeRouterIP() (gateway, myIP netip.Addr, ok bool) {
if !buildfeatures.HasPortMapper {
return
}
// If we don't have a way to get the home router IP, then we can't do
// anything; just return.
if likelyHomeRouterIP == nil {
@@ -760,8 +804,7 @@ func (m *Monitor) HasCGNATInterface() (bool, error) {
hasCGNATInterface := false
cgnatRange := tsaddr.CGNATRange()
err := ForeachInterface(func(i Interface, pfxs []netip.Prefix) {
isTSInterfaceName := m.tsIfName != "" && i.Name == m.tsIfName
if hasCGNATInterface || !i.IsUp() || isTSInterfaceName || isTailscaleInterface(i.Name, pfxs) {
if hasCGNATInterface || !i.IsUp() || isTailscaleInterface(i.Name, pfxs) {
return
}
for _, pfx := range pfxs {

View File

@@ -17,6 +17,7 @@ import (
"context"
"net"
"net/netip"
"runtime"
"sync/atomic"
"tailscale.com/net/netknob"
@@ -39,18 +40,36 @@ var bindToInterfaceByRoute atomic.Bool
// setting the TS_BIND_TO_INTERFACE_BY_ROUTE.
//
// Currently, this only changes the behaviour on macOS and Windows.
func SetBindToInterfaceByRoute(v bool) {
bindToInterfaceByRoute.Store(v)
func SetBindToInterfaceByRoute(logf logger.Logf, v bool) {
if bindToInterfaceByRoute.Swap(v) != v {
logf("netns: bindToInterfaceByRoute changed to %v", v)
}
}
var disableBindConnToInterface atomic.Bool
// SetDisableBindConnToInterface disables the (normal) behavior of binding
// connections to the default network interface.
// connections to the default network interface on Darwin nodes.
//
// Currently, this only has an effect on Darwin.
func SetDisableBindConnToInterface(v bool) {
disableBindConnToInterface.Store(v)
// Unless you intended to disable this for tailscaled on macos (which is likely
// to break things), you probably wanted to set
// SetDisableBindConnToInterfaceAppleExt which will disable explicit interface
// binding only when tailscaled is running inside a network extension process.
func SetDisableBindConnToInterface(logf logger.Logf, v bool) {
if disableBindConnToInterface.Swap(v) != v {
logf("netns: disableBindConnToInterface changed to %v", v)
}
}
var disableBindConnToInterfaceAppleExt atomic.Bool
// SetDisableBindConnToInterfaceAppleExt disables the (normal) behavior of binding
// connections to the default network interface but only on Apple clients where
// tailscaled is running inside a network extension.
func SetDisableBindConnToInterfaceAppleExt(logf logger.Logf, v bool) {
if runtime.GOOS == "darwin" && disableBindConnToInterfaceAppleExt.Swap(v) != v {
logf("netns: disableBindConnToInterfaceAppleExt changed to %v", v)
}
}
// Listener returns a new net.Listener with its Control hook func

View File

@@ -21,6 +21,7 @@ import (
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/version"
)
func control(logf logger.Logf, netMon *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
@@ -33,18 +34,14 @@ var bindToInterfaceByRouteEnv = envknob.RegisterBool("TS_BIND_TO_INTERFACE_BY_RO
var errInterfaceStateInvalid = errors.New("interface state invalid")
// controlLogf marks c as necessary to dial in a separate network namespace.
//
// It's intentionally the same signature as net.Dialer.Control
// and net.ListenConfig.Control.
// controlLogf binds c to a particular interface as necessary to dial the
// provided (network, address).
func controlLogf(logf logger.Logf, netMon *netmon.Monitor, network, address string, c syscall.RawConn) error {
if isLocalhost(address) {
// Don't bind to an interface for localhost connections.
if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) {
return nil
}
if disableBindConnToInterface.Load() {
logf("netns_darwin: binding connection to interfaces disabled")
if isLocalhost(address) {
return nil
}
@@ -78,10 +75,38 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
return -1, errInterfaceStateInvalid
}
if iface, ok := state.Interface[state.DefaultRouteInterface]; ok {
return iface.Index, nil
// Netmon's cached view of the default inteface
cachedIdx, ok := state.Interface[state.DefaultRouteInterface]
// OSes view (if available) of the default interface
osIf, osIferr := netmon.OSDefaultRoute()
idx := -1
errOut := errInterfaceStateInvalid
// Preferentially choose the OS's view of the default if index. Due to the way darwin sets the delegated
// interface on tunnel creation only, it is possible for netmon to have a stale view of the default and
// netmon's view is often temporarily wrong during network transitions, or for us to not have the
// the the oses view of the defaultIf yet.
if osIferr == nil {
idx = osIf.InterfaceIndex
errOut = nil
} else if ok {
idx = cachedIdx.Index
errOut = nil
}
return -1, errInterfaceStateInvalid
if osIferr == nil && ok && (osIf.InterfaceIndex != cachedIdx.Index) {
logf("netns: [unexpected] os default if %q (%d) != netmon cached if %q (%d)", osIf.InterfaceName, osIf.InterfaceIndex, cachedIdx.Name, cachedIdx.Index)
}
// Sanity check to make sure we didn't pick the tailscale interface
if tsif, err2 := tailscaleInterface(); tsif != nil && err2 == nil && errOut == nil {
if tsif.Index == idx {
idx = -1
errOut = errInterfaceStateInvalid
}
}
return idx, errOut
}
useRoute := bindToInterfaceByRoute.Load() || bindToInterfaceByRouteEnv()
@@ -100,7 +125,7 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
idx, err := interfaceIndexFor(addr, true /* canRecurse */)
if err != nil {
logf("netns: error in interfaceIndexFor: %v", err)
logf("netns: error getting interface index for %q: %v", address, err)
return defaultIdx()
}
@@ -108,10 +133,13 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
// if so, we fall back to binding from the default.
tsif, err2 := tailscaleInterface()
if err2 == nil && tsif != nil && tsif.Index == idx {
logf("[unexpected] netns: interfaceIndexFor returned Tailscale interface")
// note: with an exit node enabled, this is almost always true. defaultIdx() is the
// right thing to do here.
return defaultIdx()
}
logf("netns: completed success interfaceIndexFor(%s) = %d", address, idx)
return idx, err
}

View File

@@ -20,3 +20,7 @@ func control(logger.Logf, *netmon.Monitor) func(network, address string, c sysca
func controlC(network, address string, c syscall.RawConn) error {
return nil
}
func UseSocketMark() bool {
return false
}

View File

@@ -25,3 +25,7 @@ func parseAddress(address string) (addr netip.Addr, err error) {
return netip.ParseAddr(host)
}
func UseSocketMark() bool {
return false
}

View File

@@ -15,8 +15,8 @@ import (
"golang.org/x/sys/unix"
"tailscale.com/envknob"
"tailscale.com/net/netmon"
"tailscale.com/tsconst"
"tailscale.com/types/logger"
"tailscale.com/util/linuxfw"
)
// socketMarkWorksOnce is the sync.Once & cached value for useSocketMark.
@@ -111,7 +111,7 @@ func controlC(network, address string, c syscall.RawConn) error {
}
func setBypassMark(fd uintptr) error {
if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, linuxfw.TailscaleBypassMarkNum); err != nil {
if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, tsconst.LinuxBypassMarkNum); err != nil {
return fmt.Errorf("setting SO_MARK bypass: %w", err)
}
return nil

View File

@@ -1,7 +1,7 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !ios && !js
//go:build !ios && !js && !android && !ts_omit_useproxy
package netns

View File

@@ -45,7 +45,7 @@ var (
)
func getBestInterfaceEx(sockaddr *winipcfg.RawSockaddrInet, bestIfaceIndex *uint32) (ret error) {
r0, _, _ := syscall.Syscall(procGetBestInterfaceEx.Addr(), 2, uintptr(unsafe.Pointer(sockaddr)), uintptr(unsafe.Pointer(bestIfaceIndex)), 0)
r0, _, _ := syscall.SyscallN(procGetBestInterfaceEx.Addr(), uintptr(unsafe.Pointer(sockaddr)), uintptr(unsafe.Pointer(bestIfaceIndex)))
if r0 != 0 {
ret = syscall.Errno(r0)
}

View File

@@ -1,35 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package netstat returns the local machine's network connection table.
package netstat
import (
"errors"
"net/netip"
"runtime"
)
var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS)
type Entry struct {
Local, Remote netip.AddrPort
Pid int
State string // TODO: type?
OSMetadata OSMetadata
}
// Table contains local machine's TCP connection entries.
//
// Currently only TCP (IPv4 and IPv6) are included.
type Table struct {
Entries []Entry
}
// Get returns the connection table.
//
// It returns ErrNotImplemented if the table is not available for the
// current operating system.
func Get() (*Table, error) {
return get()
}

View File

@@ -1,14 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !windows
package netstat
// OSMetadata includes any additional OS-specific information that may be
// obtained during the retrieval of a given Entry.
type OSMetadata struct{}
func get() (*Table, error) {
return nil, ErrNotImplemented
}

View File

@@ -1,287 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package netstat
import (
"errors"
"fmt"
"math/bits"
"net/netip"
"unsafe"
"golang.org/x/sys/cpu"
"golang.org/x/sys/windows"
"tailscale.com/net/netaddr"
)
// OSMetadata includes any additional OS-specific information that may be
// obtained during the retrieval of a given Entry.
type OSMetadata interface {
// GetModule returns the entry's module name.
//
// It returns ("", nil) if no entry is found. As of 2023-01-27, any returned
// error is silently discarded by its sole caller in portlist_windows.go and
// treated equivalently as returning ("", nil), but this may change in the
// future. An error should only be returned in casees that are worthy of
// being logged at least.
GetModule() (string, error)
}
// See https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
// TCP_TABLE_OWNER_MODULE_ALL means to include the PID and module. The table type
// we get back from Windows depends on AF_INET vs AF_INET6:
// MIB_TCPTABLE_OWNER_MODULE for v4 or MIB_TCP6TABLE_OWNER_MODULE for v6.
const tcpTableOwnerModuleAll = 8
// TCPIP_OWNER_MODULE_BASIC_INFO means to request "basic information" about the
// owner module.
const tcpipOwnerModuleBasicInfo = 0
var (
iphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
getTCPTable = iphlpapi.NewProc("GetExtendedTcpTable")
getOwnerModuleFromTcpEntry = iphlpapi.NewProc("GetOwnerModuleFromTcpEntry")
getOwnerModuleFromTcp6Entry = iphlpapi.NewProc("GetOwnerModuleFromTcp6Entry")
// TODO: GetExtendedUdpTable also? if/when needed.
)
// See https://web.archive.org/web/20221219211913/https://learn.microsoft.com/en-us/windows/win32/api/tcpmib/ns-tcpmib-mib_tcprow_owner_module
type _MIB_TCPROW_OWNER_MODULE struct {
state uint32
localAddr uint32
localPort uint32
remoteAddr uint32
remotePort uint32
pid uint32
createTimestamp int64
owningModuleInfo [16]uint64
}
func (row *_MIB_TCPROW_OWNER_MODULE) asEntry() Entry {
return Entry{
Local: ipport4(row.localAddr, port(&row.localPort)),
Remote: ipport4(row.remoteAddr, port(&row.remotePort)),
Pid: int(row.pid),
State: state(row.state),
OSMetadata: row,
}
}
type _MIB_TCPTABLE_OWNER_MODULE struct {
numEntries uint32
table _MIB_TCPROW_OWNER_MODULE
}
func (m *_MIB_TCPTABLE_OWNER_MODULE) getRows() []_MIB_TCPROW_OWNER_MODULE {
return unsafe.Slice(&m.table, m.numEntries)
}
// See https://web.archive.org/web/20221219212442/https://learn.microsoft.com/en-us/windows/win32/api/tcpmib/ns-tcpmib-mib_tcp6row_owner_module
type _MIB_TCP6ROW_OWNER_MODULE struct {
localAddr [16]byte
localScope uint32
localPort uint32
remoteAddr [16]byte
remoteScope uint32
remotePort uint32
state uint32
pid uint32
createTimestamp int64
owningModuleInfo [16]uint64
}
func (row *_MIB_TCP6ROW_OWNER_MODULE) asEntry() Entry {
return Entry{
Local: ipport6(row.localAddr, row.localScope, port(&row.localPort)),
Remote: ipport6(row.remoteAddr, row.remoteScope, port(&row.remotePort)),
Pid: int(row.pid),
State: state(row.state),
OSMetadata: row,
}
}
type _MIB_TCP6TABLE_OWNER_MODULE struct {
numEntries uint32
table _MIB_TCP6ROW_OWNER_MODULE
}
func (m *_MIB_TCP6TABLE_OWNER_MODULE) getRows() []_MIB_TCP6ROW_OWNER_MODULE {
return unsafe.Slice(&m.table, m.numEntries)
}
// See https://web.archive.org/web/20221219213143/https://learn.microsoft.com/en-us/windows/win32/api/iprtrmib/ns-iprtrmib-tcpip_owner_module_basic_info
type _TCPIP_OWNER_MODULE_BASIC_INFO struct {
moduleName *uint16
modulePath *uint16
}
func get() (*Table, error) {
t := new(Table)
if err := t.addEntries(windows.AF_INET); err != nil {
return nil, fmt.Errorf("failed to get IPv4 entries: %w", err)
}
if err := t.addEntries(windows.AF_INET6); err != nil {
return nil, fmt.Errorf("failed to get IPv6 entries: %w", err)
}
return t, nil
}
func (t *Table) addEntries(fam int) error {
var size uint32
var addr unsafe.Pointer
var buf []byte
for {
err, _, _ := getTCPTable.Call(
uintptr(addr),
uintptr(unsafe.Pointer(&size)),
1, // sorted
uintptr(fam),
tcpTableOwnerModuleAll,
0, // reserved; "must be zero"
)
if err == 0 {
break
}
if err == uintptr(windows.ERROR_INSUFFICIENT_BUFFER) {
const maxSize = 10 << 20
if size > maxSize || size < 4 {
return fmt.Errorf("unreasonable kernel-reported size %d", size)
}
buf = make([]byte, size)
addr = unsafe.Pointer(&buf[0])
continue
}
return windows.Errno(err)
}
if len(buf) < int(size) {
return errors.New("unexpected size growth from system call")
}
buf = buf[:size]
switch fam {
case windows.AF_INET:
info := (*_MIB_TCPTABLE_OWNER_MODULE)(unsafe.Pointer(&buf[0]))
rows := info.getRows()
for _, row := range rows {
t.Entries = append(t.Entries, row.asEntry())
}
case windows.AF_INET6:
info := (*_MIB_TCP6TABLE_OWNER_MODULE)(unsafe.Pointer(&buf[0]))
rows := info.getRows()
for _, row := range rows {
t.Entries = append(t.Entries, row.asEntry())
}
}
return nil
}
var states = []string{
"",
"CLOSED",
"LISTEN",
"SYN-SENT",
"SYN-RECEIVED",
"ESTABLISHED",
"FIN-WAIT-1",
"FIN-WAIT-2",
"CLOSE-WAIT",
"CLOSING",
"LAST-ACK",
"DELETE-TCB",
}
func state(v uint32) string {
if v < uint32(len(states)) {
return states[v]
}
return fmt.Sprintf("unknown-state-%d", v)
}
func ipport4(addr uint32, port uint16) netip.AddrPort {
if !cpu.IsBigEndian {
addr = bits.ReverseBytes32(addr)
}
return netip.AddrPortFrom(
netaddr.IPv4(byte(addr>>24), byte(addr>>16), byte(addr>>8), byte(addr)),
port)
}
func ipport6(addr [16]byte, scope uint32, port uint16) netip.AddrPort {
ip := netip.AddrFrom16(addr).Unmap()
if scope != 0 {
// TODO: something better here?
ip = ip.WithZone(fmt.Sprint(scope))
}
return netip.AddrPortFrom(ip, port)
}
func port(v *uint32) uint16 {
if !cpu.IsBigEndian {
return uint16(bits.ReverseBytes32(*v) >> 16)
}
return uint16(*v >> 16)
}
type moduleInfoConstraint interface {
_MIB_TCPROW_OWNER_MODULE | _MIB_TCP6ROW_OWNER_MODULE
}
// moduleInfo implements OSMetadata.GetModule. It calls
// getOwnerModuleFromTcpEntry or getOwnerModuleFromTcp6Entry.
//
// See
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getownermodulefromtcpentry
//
// It may return "", nil indicating a successful call but with empty data.
func moduleInfo[entryType moduleInfoConstraint](entry *entryType, proc *windows.LazyProc) (string, error) {
var buf []byte
var desiredLen uint32
var addr unsafe.Pointer
for {
e, _, _ := proc.Call(
uintptr(unsafe.Pointer(entry)),
uintptr(tcpipOwnerModuleBasicInfo),
uintptr(addr),
uintptr(unsafe.Pointer(&desiredLen)),
)
err := windows.Errno(e)
if err == windows.ERROR_SUCCESS {
break
}
if err == windows.ERROR_NOT_FOUND {
return "", nil
}
if err != windows.ERROR_INSUFFICIENT_BUFFER {
return "", err
}
if desiredLen > 1<<20 {
// Sanity check before allocating too much.
return "", nil
}
buf = make([]byte, desiredLen)
addr = unsafe.Pointer(&buf[0])
}
if addr == nil {
// GetOwnerModuleFromTcp*Entry can apparently return ERROR_SUCCESS
// (NO_ERROR) on the first call without the usual first
// ERROR_INSUFFICIENT_BUFFER result. Windows said success, so interpret
// that was sucessfully not having data.
return "", nil
}
basicInfo := (*_TCPIP_OWNER_MODULE_BASIC_INFO)(addr)
return windows.UTF16PtrToString(basicInfo.moduleName), nil
}
// GetModule implements OSMetadata.
func (m *_MIB_TCPROW_OWNER_MODULE) GetModule() (string, error) {
return moduleInfo(m, getOwnerModuleFromTcpEntry)
}
// GetModule implements OSMetadata.
func (m *_MIB_TCP6ROW_OWNER_MODULE) GetModule() (string, error) {
return moduleInfo(m, getOwnerModuleFromTcp6Entry)
}

View File

@@ -8,7 +8,8 @@ import (
"bufio"
"io"
"net"
"sync"
"tailscale.com/syncs"
)
// NewOneConnListener returns a net.Listener that returns c on its
@@ -29,7 +30,7 @@ func NewOneConnListener(c net.Conn, addr net.Addr) net.Listener {
type oneConnListener struct {
addr net.Addr
mu sync.Mutex
mu syncs.Mutex
conn net.Conn
}

53
vendor/tailscale.com/net/netx/netx.go generated vendored Normal file
View File

@@ -0,0 +1,53 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package netx contains types to describe and abstract over how dialing and
// listening are performed.
package netx
import (
"context"
"fmt"
"net"
)
// DialFunc is a function that dials a network address.
//
// It's the type implemented by net.Dialer.DialContext or required
// by net/http.Transport.DialContext, etc.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
// Network describes a network that can listen and dial. The two common
// implementations are [RealNetwork], using the net package to use the real
// network, or [memnet.Network], using an in-memory network (typically for testing)
type Network interface {
NewLocalTCPListener() net.Listener
Listen(network, address string) (net.Listener, error)
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
// RealNetwork returns a Network implementation that uses the real
// net package.
func RealNetwork() Network { return realNetwork{} }
// realNetwork implements [Network] using the real net package.
type realNetwork struct{}
func (realNetwork) Listen(network, address string) (net.Listener, error) {
return net.Listen(network, address)
}
func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, address)
}
func (realNetwork) NewLocalTCPListener() net.Listener {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
panic(fmt.Sprintf("failed to listen on either IPv4 or IPv6 localhost port: %v", err))
}
}
return ln
}

View File

@@ -8,8 +8,6 @@ import (
"encoding/binary"
"net/netip"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"tailscale.com/net/packet"
"tailscale.com/types/ipproto"
)
@@ -88,13 +86,13 @@ func updateV4PacketChecksums(p *packet.Parsed, old, new netip.Addr) {
tr := p.Transport()
switch p.IPProto {
case ipproto.UDP, ipproto.DCCP:
if len(tr) < header.UDPMinimumSize {
if len(tr) < minUDPSize {
// Not enough space for a UDP header.
return
}
updateV4Checksum(tr[6:8], o4[:], n4[:])
case ipproto.TCP:
if len(tr) < header.TCPMinimumSize {
if len(tr) < minTCPSize {
// Not enough space for a TCP header.
return
}
@@ -112,34 +110,60 @@ func updateV4PacketChecksums(p *packet.Parsed, old, new netip.Addr) {
}
}
const (
minUDPSize = 8
minTCPSize = 20
minICMPv6Size = 8
minIPv6Header = 40
offsetICMPv6Checksum = 2
offsetUDPChecksum = 6
offsetTCPChecksum = 16
)
// updateV6PacketChecksums updates the checksums in the packet buffer.
// p is modified in place.
// If p.IPProto is unknown, no checksums are updated.
func updateV6PacketChecksums(p *packet.Parsed, old, new netip.Addr) {
if len(p.Buffer()) < 40 {
if len(p.Buffer()) < minIPv6Header {
// Not enough space for an IPv6 header.
return
}
o6, n6 := tcpip.AddrFrom16Slice(old.AsSlice()), tcpip.AddrFrom16Slice(new.AsSlice())
o6, n6 := old.As16(), new.As16()
// Now update the transport layer checksums, where applicable.
tr := p.Transport()
switch p.IPProto {
case ipproto.ICMPv6:
if len(tr) < header.ICMPv6MinimumSize {
if len(tr) < minICMPv6Size {
return
}
header.ICMPv6(tr).UpdateChecksumPseudoHeaderAddress(o6, n6)
ss := tr[offsetICMPv6Checksum:]
xsum := binary.BigEndian.Uint16(ss)
binary.BigEndian.PutUint16(ss,
^checksumUpdate2ByteAlignedAddress(^xsum, o6, n6))
case ipproto.UDP, ipproto.DCCP:
if len(tr) < header.UDPMinimumSize {
if len(tr) < minUDPSize {
return
}
header.UDP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true)
ss := tr[offsetUDPChecksum:]
xsum := binary.BigEndian.Uint16(ss)
xsum = ^xsum
xsum = checksumUpdate2ByteAlignedAddress(xsum, o6, n6)
xsum = ^xsum
binary.BigEndian.PutUint16(ss, xsum)
case ipproto.TCP:
if len(tr) < header.TCPMinimumSize {
if len(tr) < minTCPSize {
return
}
header.TCP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true)
ss := tr[offsetTCPChecksum:]
xsum := binary.BigEndian.Uint16(ss)
xsum = ^xsum
xsum = checksumUpdate2ByteAlignedAddress(xsum, o6, n6)
xsum = ^xsum
binary.BigEndian.PutUint16(ss, xsum)
case ipproto.SCTP:
// No transport layer update required.
}
@@ -195,3 +219,77 @@ func updateV4Checksum(oldSum, old, new []byte) {
hcPrime := ^uint16(cPrime)
binary.BigEndian.PutUint16(oldSum, hcPrime)
}
// checksumUpdate2ByteAlignedAddress updates an address in a calculated
// checksum.
//
// The addresses must have the same length and must contain an even number
// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
//
// This implementation is copied from gVisor, but updated to use [16]byte.
func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new [16]byte) uint16 {
const uint16Bytes = 2
oldAddr := old[:]
newAddr := new[:]
// As per RFC 1071 page 4,
// (4) Incremental Update
//
// ...
//
// To update the checksum, simply add the differences of the
// sixteen bit integers that have been changed. To see why this
// works, observe that every 16-bit integer has an additive inverse
// and that addition is associative. From this it follows that
// given the original value m, the new value m', and the old
// checksum C, the new checksum C' is:
//
// C' = C + (-m) + m' = C + (m' - m)
for len(oldAddr) != 0 {
// Convert the 2 byte sequences to uint16 values then apply the increment
// update.
xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(oldAddr[0])<<8)+uint16(oldAddr[1]), (uint16(newAddr[0])<<8)+uint16(newAddr[1]))
oldAddr = oldAddr[uint16Bytes:]
newAddr = newAddr[uint16Bytes:]
}
return xsum
}
// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
// checksum.
//
// The value MUST begin at a 2-byte boundary in the original buffer.
//
// This implementation is copied from gVisor.
func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
// As per RFC 1071 page 4,
// (4) Incremental Update
//
// ...
//
// To update the checksum, simply add the differences of the
// sixteen bit integers that have been changed. To see why this
// works, observe that every 16-bit integer has an additive inverse
// and that addition is associative. From this it follows that
// given the original value m, the new value m', and the old
// checksum C, the new checksum C' is:
//
// C' = C + (-m) + m' = C + (m' - m)
if old == new {
return xsum
}
return checksumCombine(xsum, checksumCombine(new, ^old))
}
// checksumCombine combines the two uint16 to form their checksum. This is done
// by adding them and the carry.
//
// Note that checksum a must have been computed on an even number of bytes.
//
// This implementation is copied from gVisor.
func checksumCombine(a, b uint16) uint16 {
v := uint32(a) + uint32(b)
return uint16(v + v>>16)
}

View File

@@ -24,6 +24,33 @@ const (
GeneveProtocolWireGuard uint16 = 0x7A12
)
// VirtualNetworkID is a Geneve header (RFC8926) 3-byte virtual network
// identifier. Its methods are NOT thread-safe.
type VirtualNetworkID struct {
_vni uint32
}
const (
vniSetMask uint32 = 0xFF000000
vniGetMask uint32 = ^vniSetMask
)
// IsSet returns true if Set() had been called previously, otherwise false.
func (v *VirtualNetworkID) IsSet() bool {
return v._vni&vniSetMask != 0
}
// Set sets the provided VNI. If VNI exceeds the 3-byte storage it will be
// clamped.
func (v *VirtualNetworkID) Set(vni uint32) {
v._vni = vni | vniSetMask
}
// Get returns the VNI value.
func (v *VirtualNetworkID) Get() uint32 {
return v._vni & vniGetMask
}
// GeneveHeader represents the fixed size Geneve header from RFC8926.
// TLVs/options are not implemented/supported.
//
@@ -51,7 +78,7 @@ type GeneveHeader struct {
// decisions or MAY be used as a mechanism to distinguish between
// overlapping address spaces contained in the encapsulated packet when load
// balancing across CPUs.
VNI uint32
VNI VirtualNetworkID
// O (1 bit): Control packet. This packet contains a control message.
// Control messages are sent between tunnel endpoints. Tunnel endpoints MUST
@@ -65,12 +92,18 @@ type GeneveHeader struct {
Control bool
}
// Encode encodes GeneveHeader into b. If len(b) < GeneveFixedHeaderLength an
// io.ErrShortBuffer error is returned.
var ErrGeneveVNIUnset = errors.New("VNI is unset")
// Encode encodes GeneveHeader into b. If len(b) < [GeneveFixedHeaderLength] an
// [io.ErrShortBuffer] error is returned. If !h.VNI.IsSet() then an
// [ErrGeneveVNIUnset] error is returned.
func (h *GeneveHeader) Encode(b []byte) error {
if len(b) < GeneveFixedHeaderLength {
return io.ErrShortBuffer
}
if !h.VNI.IsSet() {
return ErrGeneveVNIUnset
}
if h.Version > 3 {
return errors.New("version must be <= 3")
}
@@ -81,15 +114,12 @@ func (h *GeneveHeader) Encode(b []byte) error {
b[1] |= 0x80
}
binary.BigEndian.PutUint16(b[2:], h.Protocol)
if h.VNI > 1<<24-1 {
return errors.New("VNI must be <= 2^24-1")
}
binary.BigEndian.PutUint32(b[4:], h.VNI<<8)
binary.BigEndian.PutUint32(b[4:], h.VNI.Get()<<8)
return nil
}
// Decode decodes GeneveHeader from b. If len(b) < GeneveFixedHeaderLength an
// io.ErrShortBuffer error is returned.
// Decode decodes GeneveHeader from b. If len(b) < [GeneveFixedHeaderLength] an
// [io.ErrShortBuffer] error is returned.
func (h *GeneveHeader) Decode(b []byte) error {
if len(b) < GeneveFixedHeaderLength {
return io.ErrShortBuffer
@@ -99,6 +129,6 @@ func (h *GeneveHeader) Decode(b []byte) error {
h.Control = true
}
h.Protocol = binary.BigEndian.Uint16(b[2:])
h.VNI = binary.BigEndian.Uint32(b[4:]) >> 8
h.VNI.Set(binary.BigEndian.Uint32(b[4:]) >> 8)
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"math"
)
const igmpHeaderLength = 8
const tcpHeaderLength = 20
const sctpHeaderLength = 12

View File

@@ -51,10 +51,11 @@ type Parsed struct {
IPVersion uint8
// IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0.
IPProto ipproto.Proto
// SrcIP4 is the source address. Family matches IPVersion. Port is
// valid iff IPProto == TCP || IPProto == UDP.
// Src is the source address. Family matches IPVersion. Port is
// valid iff IPProto == TCP || IPProto == UDP || IPProto == SCTP.
Src netip.AddrPort
// DstIP4 is the destination address. Family matches IPVersion.
// Dst is the destination address. Family matches IPVersion. Port is
// valid iff IPProto == TCP || IPProto == UDP || IPProto == SCTP.
Dst netip.AddrPort
// TCPFlags is the packet's TCP flag bits. Valid iff IPProto == TCP.
TCPFlags TCPFlag
@@ -160,14 +161,8 @@ func (q *Parsed) decode4(b []byte) {
if fragOfs == 0 {
// This is the first fragment
if moreFrags && len(sub) < minFragBlks {
// Suspiciously short first fragment, dump it.
q.IPProto = unknown
return
}
// otherwise, this is either non-fragmented (the usual case)
// or a big enough initial fragment that we can read the
// whole subprotocol header.
// Every protocol below MUST check that it has at least one entire
// transport header in order to protect against fragment confusion.
switch q.IPProto {
case ipproto.ICMPv4:
if len(sub) < icmp4HeaderLength {
@@ -179,6 +174,10 @@ func (q *Parsed) decode4(b []byte) {
q.dataofs = q.subofs + icmp4HeaderLength
return
case ipproto.IGMP:
if len(sub) < igmpHeaderLength {
q.IPProto = unknown
return
}
// Keep IPProto, but don't parse anything else
// out.
return
@@ -211,6 +210,15 @@ func (q *Parsed) decode4(b []byte) {
q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4]))
return
case ipproto.TSMP:
// Strictly disallow fragmented TSMP
if moreFrags {
q.IPProto = unknown
return
}
if len(sub) < minTSMPSize {
q.IPProto = unknown
return
}
// Inter-tailscale messages.
q.dataofs = q.subofs
return
@@ -223,8 +231,11 @@ func (q *Parsed) decode4(b []byte) {
} else {
// This is a fragment other than the first one.
if fragOfs < minFragBlks {
// First frag was suspiciously short, so we can't
// trust the followup either.
// disallow fragment offsets that are potentially inside of a
// transport header. This is notably asymmetric with the
// first-packet limit, that may allow a first-packet that requires a
// shorter offset than this limit, but without state to tie this
// to the first fragment we can not allow shorter packets.
q.IPProto = unknown
return
}
@@ -314,6 +325,10 @@ func (q *Parsed) decode6(b []byte) {
q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4]))
return
case ipproto.TSMP:
if len(sub) < minTSMPSize {
q.IPProto = unknown
return
}
// Inter-tailscale messages.
q.dataofs = q.subofs
return

View File

@@ -15,10 +15,13 @@ import (
"fmt"
"net/netip"
"tailscale.com/net/flowtrack"
"go4.org/mem"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
)
const minTSMPSize = 7 // the rejected body is 7 bytes
// TailscaleRejectedHeader is a TSMP message that says that one
// Tailscale node has rejected the connection from another. Unlike a
// TCP RST, this includes a reason.
@@ -56,10 +59,6 @@ type TailscaleRejectedHeader struct {
const rejectFlagBitMaybeBroken = 0x1
func (rh TailscaleRejectedHeader) Flow() flowtrack.Tuple {
return flowtrack.MakeTuple(rh.Proto, rh.Src, rh.Dst)
}
func (rh TailscaleRejectedHeader) String() string {
return fmt.Sprintf("TSMP-reject-flow{%s %s > %s}: %s", rh.Proto, rh.Src, rh.Dst, rh.Reason)
}
@@ -75,6 +74,9 @@ const (
// TSMPTypePong is the type byte for a TailscalePongResponse.
TSMPTypePong TSMPType = 'o'
// TSPMTypeDiscoAdvertisement is the type byte for sending disco keys
TSMPTypeDiscoAdvertisement TSMPType = 'a'
)
type TailscaleRejectReason byte
@@ -262,3 +264,54 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
binary.BigEndian.PutUint16(buf[9:11], h.PeerAPIPort)
return nil
}
// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys.
//
// On the wire, after the IP header, it's currently 33 bytes:
// - 'a' (TSMPTypeDiscoAdvertisement)
// - 32 disco key bytes
type TSMPDiscoKeyAdvertisement struct {
Src, Dst netip.Addr // Src and Dst are set from the parent IP Header when parsing.
Key key.DiscoPublic
}
func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) {
var iph Header
if ka.Src.Is4() {
iph = IP4Header{
IPProto: ipproto.TSMP,
Src: ka.Src,
Dst: ka.Dst,
}
} else {
iph = IP6Header{
IPProto: ipproto.TSMP,
Src: ka.Src,
Dst: ka.Dst,
}
}
payload := make([]byte, 0, 33)
payload = append(payload, byte(TSMPTypeDiscoAdvertisement))
payload = ka.Key.AppendTo(payload)
if len(payload) != 33 {
// Mostly to safeguard against ourselves changing this in the future.
return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload))
}
return Generate(iph, payload[:]), nil
}
func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) {
if pp.IPProto != ipproto.TSMP {
return
}
p := pp.Payload()
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) {
return
}
tka.Src = pp.Src.Addr()
tka.Dst = pp.Dst.Addr()
tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33]))
return tka, true
}

View File

@@ -10,6 +10,7 @@ import (
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
@@ -22,9 +23,9 @@ import (
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"tailscale.com/syncs"
"tailscale.com/types/logger"
"tailscale.com/util/mak"
"tailscale.com/util/multierr"
)
const (
@@ -64,7 +65,7 @@ type Pinger struct {
wg sync.WaitGroup
// Following fields protected by mu
mu sync.Mutex
mu syncs.Mutex
// conns is a map of "type" to net.PacketConn, type is either
// "ip4:icmp" or "ip6:icmp"
conns map[string]net.PacketConn
@@ -157,17 +158,17 @@ func (p *Pinger) Close() error {
p.conns = nil
p.mu.Unlock()
var errors []error
var errs []error
for _, c := range conns {
if err := c.Close(); err != nil {
errors = append(errors, err)
errs = append(errs, err)
}
}
p.wg.Wait()
p.cleanupOutstanding()
return multierr.New(errors...)
return errors.Join(errs...)
}
func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) {

View File

@@ -10,8 +10,8 @@ package portmapper
import (
"context"
"github.com/tailscale/goupnp"
"github.com/tailscale/goupnp/soap"
"github.com/huin/goupnp"
"github.com/huin/goupnp/soap"
)
const (
@@ -32,8 +32,8 @@ type legacyWANPPPConnection1 struct {
goupnp.ServiceClient
}
// AddPortMapping implements upnpClient
func (client *legacyWANPPPConnection1) AddPortMapping(
// AddPortMappingCtx implements upnpClient
func (client *legacyWANPPPConnection1) AddPortMappingCtx(
ctx context.Context,
NewRemoteHost string,
NewExternalPort uint16,
@@ -85,11 +85,11 @@ func (client *legacyWANPPPConnection1) AddPortMapping(
response := any(nil)
// Perform the SOAP call.
return client.SOAPClient.PerformAction(ctx, urn_LegacyWANPPPConnection_1, "AddPortMapping", request, response)
return client.SOAPClient.PerformActionCtx(ctx, urn_LegacyWANPPPConnection_1, "AddPortMapping", request, response)
}
// DeletePortMapping implements upnpClient
func (client *legacyWANPPPConnection1) DeletePortMapping(ctx context.Context, NewRemoteHost string, NewExternalPort uint16, NewProtocol string) (err error) {
// DeletePortMappingCtx implements upnpClient
func (client *legacyWANPPPConnection1) DeletePortMappingCtx(ctx context.Context, NewRemoteHost string, NewExternalPort uint16, NewProtocol string) (err error) {
// Request structure.
request := &struct {
NewRemoteHost string
@@ -110,11 +110,11 @@ func (client *legacyWANPPPConnection1) DeletePortMapping(ctx context.Context, Ne
response := any(nil)
// Perform the SOAP call.
return client.SOAPClient.PerformAction(ctx, urn_LegacyWANPPPConnection_1, "DeletePortMapping", request, response)
return client.SOAPClient.PerformActionCtx(ctx, urn_LegacyWANPPPConnection_1, "DeletePortMapping", request, response)
}
// GetExternalIPAddress implements upnpClient
func (client *legacyWANPPPConnection1) GetExternalIPAddress(ctx context.Context) (NewExternalIPAddress string, err error) {
// GetExternalIPAddressCtx implements upnpClient
func (client *legacyWANPPPConnection1) GetExternalIPAddressCtx(ctx context.Context) (NewExternalIPAddress string, err error) {
// Request structure.
request := any(nil)
@@ -124,7 +124,7 @@ func (client *legacyWANPPPConnection1) GetExternalIPAddress(ctx context.Context)
}{}
// Perform the SOAP call.
if err = client.SOAPClient.PerformAction(ctx, urn_LegacyWANPPPConnection_1, "GetExternalIPAddress", request, response); err != nil {
if err = client.SOAPClient.PerformActionCtx(ctx, urn_LegacyWANPPPConnection_1, "GetExternalIPAddress", request, response); err != nil {
return
}
@@ -134,8 +134,8 @@ func (client *legacyWANPPPConnection1) GetExternalIPAddress(ctx context.Context)
return
}
// GetStatusInfo implements upnpClient
func (client *legacyWANPPPConnection1) GetStatusInfo(ctx context.Context) (NewConnectionStatus string, NewLastConnectionError string, NewUptime uint32, err error) {
// GetStatusInfoCtx implements upnpClient
func (client *legacyWANPPPConnection1) GetStatusInfoCtx(ctx context.Context) (NewConnectionStatus string, NewLastConnectionError string, NewUptime uint32, err error) {
// Request structure.
request := any(nil)
@@ -147,7 +147,7 @@ func (client *legacyWANPPPConnection1) GetStatusInfo(ctx context.Context) (NewCo
}{}
// Perform the SOAP call.
if err = client.SOAPClient.PerformAction(ctx, urn_LegacyWANPPPConnection_1, "GetStatusInfo", request, response); err != nil {
if err = client.SOAPClient.PerformActionCtx(ctx, urn_LegacyWANPPPConnection_1, "GetStatusInfo", request, response); err != nil {
return
}
@@ -171,8 +171,8 @@ type legacyWANIPConnection1 struct {
goupnp.ServiceClient
}
// AddPortMapping implements upnpClient
func (client *legacyWANIPConnection1) AddPortMapping(
// AddPortMappingCtx implements upnpClient
func (client *legacyWANIPConnection1) AddPortMappingCtx(
ctx context.Context,
NewRemoteHost string,
NewExternalPort uint16,
@@ -224,11 +224,11 @@ func (client *legacyWANIPConnection1) AddPortMapping(
response := any(nil)
// Perform the SOAP call.
return client.SOAPClient.PerformAction(ctx, urn_LegacyWANIPConnection_1, "AddPortMapping", request, response)
return client.SOAPClient.PerformActionCtx(ctx, urn_LegacyWANIPConnection_1, "AddPortMapping", request, response)
}
// DeletePortMapping implements upnpClient
func (client *legacyWANIPConnection1) DeletePortMapping(ctx context.Context, NewRemoteHost string, NewExternalPort uint16, NewProtocol string) (err error) {
// DeletePortMappingCtx implements upnpClient
func (client *legacyWANIPConnection1) DeletePortMappingCtx(ctx context.Context, NewRemoteHost string, NewExternalPort uint16, NewProtocol string) (err error) {
// Request structure.
request := &struct {
NewRemoteHost string
@@ -249,11 +249,11 @@ func (client *legacyWANIPConnection1) DeletePortMapping(ctx context.Context, New
response := any(nil)
// Perform the SOAP call.
return client.SOAPClient.PerformAction(ctx, urn_LegacyWANIPConnection_1, "DeletePortMapping", request, response)
return client.SOAPClient.PerformActionCtx(ctx, urn_LegacyWANIPConnection_1, "DeletePortMapping", request, response)
}
// GetExternalIPAddress implements upnpClient
func (client *legacyWANIPConnection1) GetExternalIPAddress(ctx context.Context) (NewExternalIPAddress string, err error) {
// GetExternalIPAddressCtx implements upnpClient
func (client *legacyWANIPConnection1) GetExternalIPAddressCtx(ctx context.Context) (NewExternalIPAddress string, err error) {
// Request structure.
request := any(nil)
@@ -263,7 +263,7 @@ func (client *legacyWANIPConnection1) GetExternalIPAddress(ctx context.Context)
}{}
// Perform the SOAP call.
if err = client.SOAPClient.PerformAction(ctx, urn_LegacyWANIPConnection_1, "GetExternalIPAddress", request, response); err != nil {
if err = client.SOAPClient.PerformActionCtx(ctx, urn_LegacyWANIPConnection_1, "GetExternalIPAddress", request, response); err != nil {
return
}
@@ -273,8 +273,8 @@ func (client *legacyWANIPConnection1) GetExternalIPAddress(ctx context.Context)
return
}
// GetStatusInfo implements upnpClient
func (client *legacyWANIPConnection1) GetStatusInfo(ctx context.Context) (NewConnectionStatus string, NewLastConnectionError string, NewUptime uint32, err error) {
// GetStatusInfoCtx implements upnpClient
func (client *legacyWANIPConnection1) GetStatusInfoCtx(ctx context.Context) (NewConnectionStatus string, NewLastConnectionError string, NewUptime uint32, err error) {
// Request structure.
request := any(nil)
@@ -286,7 +286,7 @@ func (client *legacyWANIPConnection1) GetStatusInfo(ctx context.Context) (NewCon
}{}
// Perform the SOAP call.
if err = client.SOAPClient.PerformAction(ctx, urn_LegacyWANIPConnection_1, "GetStatusInfo", request, response); err != nil {
if err = client.SOAPClient.PerformActionCtx(ctx, urn_LegacyWANIPConnection_1, "GetStatusInfo", request, response); err != nil {
return
}

View File

@@ -24,8 +24,9 @@ const _pmpResultCode_name = "OKUnsupportedVersionNotAuthorizedNetworkFailureOutO
var _pmpResultCode_index = [...]uint8{0, 2, 20, 33, 47, 61, 78}
func (i pmpResultCode) String() string {
if i >= pmpResultCode(len(_pmpResultCode_index)-1) {
idx := int(i) - 0
if i < 0 || idx >= len(_pmpResultCode_index)-1 {
return "pmpResultCode(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _pmpResultCode_name[_pmpResultCode_index[i]:_pmpResultCode_index[i+1]]
return _pmpResultCode_name[_pmpResultCode_index[idx]:_pmpResultCode_index[idx+1]]
}

View File

@@ -8,29 +8,36 @@ package portmapper
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
"go4.org/mem"
"tailscale.com/control/controlknobs"
"tailscale.com/envknob"
"tailscale.com/feature/buildfeatures"
"tailscale.com/net/netaddr"
"tailscale.com/net/neterror"
"tailscale.com/net/netmon"
"tailscale.com/net/netns"
"tailscale.com/net/portmapper/portmappertype"
"tailscale.com/net/sockstats"
"tailscale.com/syncs"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
)
var (
ErrNoPortMappingServices = portmappertype.ErrNoPortMappingServices
ErrGatewayRange = portmappertype.ErrGatewayRange
ErrGatewayIPv6 = portmappertype.ErrGatewayIPv6
ErrPortMappingDisabled = portmappertype.ErrPortMappingDisabled
)
var disablePortMapperEnv = envknob.RegisterBool("TS_DISABLE_PORTMAPPER")
@@ -48,15 +55,33 @@ type DebugKnobs struct {
LogHTTP bool
// Disable* disables a specific service from mapping.
DisableUPnP bool
DisablePMP bool
DisablePCP bool
// If the funcs are nil or return false, the service is not disabled.
// Use the corresponding accessor methods without the "Func" suffix
// to check whether a service is disabled.
DisableUPnPFunc func() bool
DisablePMPFunc func() bool
DisablePCPFunc func() bool
// DisableAll, if non-nil, is a func that reports whether all port
// mapping attempts should be disabled.
DisableAll func() bool
}
// DisableUPnP reports whether UPnP is disabled.
func (k *DebugKnobs) DisableUPnP() bool {
return k != nil && k.DisableUPnPFunc != nil && k.DisableUPnPFunc()
}
// DisablePMP reports whether NAT-PMP is disabled.
func (k *DebugKnobs) DisablePMP() bool {
return k != nil && k.DisablePMPFunc != nil && k.DisablePMPFunc()
}
// DisablePCP reports whether PCP is disabled.
func (k *DebugKnobs) DisablePCP() bool {
return k != nil && k.DisablePCPFunc != nil && k.DisablePCPFunc()
}
func (k *DebugKnobs) disableAll() bool {
if disablePortMapperEnv() {
return true
@@ -84,16 +109,20 @@ const trustServiceStillAvailableDuration = 10 * time.Minute
// Client is a port mapping client.
type Client struct {
// The following two fields must both be non-nil.
// Both are immutable after construction.
pubClient *eventbus.Client
updates *eventbus.Publisher[portmappertype.Mapping]
logf logger.Logf
netMon *netmon.Monitor // optional; nil means interfaces will be looked up on-demand
controlKnobs *controlknobs.Knobs
ipAndGateway func() (gw, ip netip.Addr, ok bool)
onChange func() // or nil
debug DebugKnobs
testPxPPort uint16 // if non-zero, pxpPort to use for tests
testUPnPPort uint16 // if non-zero, uPnPPort to use for tests
mu sync.Mutex // guards following, and all fields thereof
mu syncs.Mutex // guards following, and all fields thereof
// runningCreate is whether we're currently working on creating
// a port mapping (whether GetCachedMappingOrStartCreatingOne kicked
@@ -124,6 +153,8 @@ type Client struct {
mapping mapping // non-nil if we have a mapping
}
var _ portmappertype.Client = (*Client)(nil)
func (c *Client) vlogf(format string, args ...any) {
if c.debug.VerboseLogs {
c.logf(format, args...)
@@ -153,7 +184,6 @@ type mapping interface {
MappingDebug() string
}
// HaveMapping reports whether we have a current valid mapping.
func (c *Client) HaveMapping() bool {
c.mu.Lock()
defer c.mu.Unlock()
@@ -201,32 +231,52 @@ func (m *pmpMapping) Release(ctx context.Context) {
uc.WriteToUDPAddrPort(pkt, m.gw)
}
// NewClient returns a new portmapping client.
//
// The netMon parameter is required.
//
// The debug argument allows configuring the behaviour of the portmapper for
// debugging; if nil, a sensible set of defaults will be used.
//
// The controlKnobs, if non-nil, specifies the control knobs from the control
// plane that might disable portmapping.
//
// The optional onChange argument specifies a func to run in a new goroutine
// whenever the port mapping status has changed. If nil, it doesn't make a
// callback.
func NewClient(logf logger.Logf, netMon *netmon.Monitor, debug *DebugKnobs, controlKnobs *controlknobs.Knobs, onChange func()) *Client {
if netMon == nil {
panic("nil netMon")
// Config carries the settings for a [Client].
type Config struct {
// EventBus, which must be non-nil, is used for event publication and
// subscription by portmapper clients created from this config.
EventBus *eventbus.Bus
// Logf is called to generate text logs for the client. If nil, logger.Discard is used.
Logf logger.Logf
// NetMon is the network monitor used by the client. It must be non-nil.
NetMon *netmon.Monitor
// DebugKnobs, if non-nil, configure the behaviour of the portmapper for
// debugging. If nil, a sensible set of defaults will be used.
DebugKnobs *DebugKnobs
// OnChange is called to run in a new goroutine whenever the port mapping
// status has changed. If nil, no callback is issued.
OnChange func()
}
// NewClient constructs a new portmapping [Client] from c. It will panic if any
// required parameters are omitted.
func NewClient(c Config) *Client {
switch {
case c.NetMon == nil:
panic("nil NetMon")
case c.EventBus == nil:
panic("nil EventBus")
}
ret := &Client{
logf: logf,
netMon: netMon,
ipAndGateway: netmon.LikelyHomeRouterIP, // TODO(bradfitz): move this to method on netMon
onChange: onChange,
controlKnobs: controlKnobs,
logf: c.Logf,
netMon: c.NetMon,
onChange: c.OnChange,
}
if debug != nil {
ret.debug = *debug
if buildfeatures.HasPortMapper {
// TODO(bradfitz): move this to method on netMon
ret.ipAndGateway = netmon.LikelyHomeRouterIP
}
ret.pubClient = c.EventBus.Client("portmapper")
ret.updates = eventbus.Publish[portmappertype.Mapping](ret.pubClient)
if ret.logf == nil {
ret.logf = logger.Discard
}
if c.DebugKnobs != nil {
ret.debug = *c.DebugKnobs
}
return ret
}
@@ -256,6 +306,9 @@ func (c *Client) Close() error {
}
c.closed = true
c.invalidateMappingsLocked(true)
c.updates.Close()
c.pubClient.Close()
// TODO: close some future ever-listening UDP socket(s),
// waiting for multicast announcements from router.
return nil
@@ -417,13 +470,6 @@ func IsNoMappingError(err error) bool {
return ok
}
var (
ErrNoPortMappingServices = errors.New("no port mapping services were found")
ErrGatewayRange = errors.New("skipping portmap; gateway range likely lacks support")
ErrGatewayIPv6 = errors.New("skipping portmap; no IPv6 support for portmapping")
ErrPortMappingDisabled = errors.New("port mapping is disabled")
)
// GetCachedMappingOrStartCreatingOne quickly returns with our current cached portmapping, if any.
// If there's not one, it starts up a background goroutine to create one.
// If the background goroutine ends up creating one, the onChange hook registered with the
@@ -467,10 +513,29 @@ func (c *Client) createMapping() {
c.runningCreate = false
}()
if _, err := c.createOrGetMapping(ctx); err == nil && c.onChange != nil {
mapping, _, err := c.createOrGetMapping(ctx)
if err != nil {
if !IsNoMappingError(err) {
c.logf("createOrGetMapping: %v", err)
}
return
} else if mapping == nil {
return
// TODO(creachadair): This was already logged in createOrGetMapping.
// It really should not happen at all, but we will need to untangle
// the control flow to eliminate that possibility. Meanwhile, this
// mitigates a panic downstream, cf. #16662.
}
c.updates.Publish(portmappertype.Mapping{
External: mapping.External(),
Type: mapping.MappingType(),
GoodUntil: mapping.GoodUntil(),
})
// TODO(creachadair): Remove this entirely once there are no longer any
// places where the callback is set.
if c.onChange != nil {
go c.onChange()
} else if err != nil && !IsNoMappingError(err) {
c.logf("createOrGetMapping: %v", err)
}
}
@@ -482,19 +547,19 @@ var wildcardIP = netip.MustParseAddr("0.0.0.0")
//
// If no mapping is available, the error will be of type
// NoMappingError; see IsNoMappingError.
func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPort, err error) {
func (c *Client) createOrGetMapping(ctx context.Context) (mapping mapping, external netip.AddrPort, err error) {
if c.debug.disableAll() {
return netip.AddrPort{}, NoMappingError{ErrPortMappingDisabled}
return nil, netip.AddrPort{}, NoMappingError{ErrPortMappingDisabled}
}
if c.debug.DisableUPnP && c.debug.DisablePCP && c.debug.DisablePMP {
return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
if c.debug.DisableUPnP() && c.debug.DisablePCP() && c.debug.DisablePMP() {
return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
}
gw, myIP, ok := c.gatewayAndSelfIP()
if !ok {
return netip.AddrPort{}, NoMappingError{ErrGatewayRange}
return nil, netip.AddrPort{}, NoMappingError{ErrGatewayRange}
}
if gw.Is6() {
return netip.AddrPort{}, NoMappingError{ErrGatewayIPv6}
return nil, netip.AddrPort{}, NoMappingError{ErrGatewayIPv6}
}
now := time.Now()
@@ -523,6 +588,17 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
return
}
// TODO(creachadair): This is more subtle than it should be. Ideally we
// would just return the mapping directly, but there are many different
// paths through the function with carefully-balanced locks, and not all
// the paths have a mapping to return. As a workaround, while we're here
// doing cleanup under the lock, grab the final mapping value and return
// it, so the caller does not need to grab the lock again and potentially
// race with a later update. The mapping itself is concurrency-safe.
//
// We should restructure this code so the locks are properly scoped.
mapping = c.mapping
// Print the internal details of each mapping if we're being verbose.
if c.debug.VerboseLogs {
c.logf("successfully obtained mapping: now=%d external=%v type=%s mapping=%s",
@@ -548,19 +624,19 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
if now.Before(m.RenewAfter()) {
defer c.mu.Unlock()
reusedExisting = true
return m.External(), nil
return nil, m.External(), nil
}
// The mapping might still be valid, so just try to renew it.
prevPort = m.External().Port()
}
if c.debug.DisablePCP && c.debug.DisablePMP {
if c.debug.DisablePCP() && c.debug.DisablePMP() {
c.mu.Unlock()
if external, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok {
return external, nil
return nil, external, nil
}
c.vlogf("fallback to UPnP due to PCP and PMP being disabled failed")
return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
}
// If we just did a Probe (e.g. via netchecker) but didn't
@@ -587,16 +663,16 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
c.mu.Unlock()
// fallback to UPnP portmapping
if external, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok {
return external, nil
return nil, external, nil
}
c.vlogf("fallback to UPnP due to no PCP and PMP failed")
return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
}
c.mu.Unlock()
uc, err := c.listenPacket(ctx, "udp4", ":0")
if err != nil {
return netip.AddrPort{}, err
return nil, netip.AddrPort{}, err
}
defer uc.Close()
@@ -605,7 +681,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
pxpAddr := netip.AddrPortFrom(gw, c.pxpPort())
preferPCP := !c.debug.DisablePCP && (c.debug.DisablePMP || (!haveRecentPMP && haveRecentPCP))
preferPCP := !c.debug.DisablePCP() && (c.debug.DisablePMP() || (!haveRecentPMP && haveRecentPCP))
// Create a mapping, defaulting to PMP unless only PCP was seen recently.
if preferPCP {
@@ -616,7 +692,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
if neterror.TreatAsLostUDP(err) {
err = NoMappingError{ErrNoPortMappingServices}
}
return netip.AddrPort{}, err
return nil, netip.AddrPort{}, err
}
} else {
// Ask for our external address if needed.
@@ -625,7 +701,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
if neterror.TreatAsLostUDP(err) {
err = NoMappingError{ErrNoPortMappingServices}
}
return netip.AddrPort{}, err
return nil, netip.AddrPort{}, err
}
}
@@ -634,7 +710,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
if neterror.TreatAsLostUDP(err) {
err = NoMappingError{ErrNoPortMappingServices}
}
return netip.AddrPort{}, err
return nil, netip.AddrPort{}, err
}
}
@@ -643,13 +719,13 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
n, src, err := uc.ReadFromUDPAddrPort(res)
if err != nil {
if ctx.Err() == context.Canceled {
return netip.AddrPort{}, err
return nil, netip.AddrPort{}, err
}
// fallback to UPnP portmapping
if mapping, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok {
return mapping, nil
return nil, mapping, nil
}
return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
}
src = netaddr.Unmap(src)
if !src.IsValid() {
@@ -665,7 +741,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
continue
}
if pres.ResultCode != 0 {
return netip.AddrPort{}, NoMappingError{fmt.Errorf("PMP response Op=0x%x,Res=0x%x", pres.OpCode, pres.ResultCode)}
return nil, netip.AddrPort{}, NoMappingError{fmt.Errorf("PMP response Op=0x%x,Res=0x%x", pres.OpCode, pres.ResultCode)}
}
if pres.OpCode == pmpOpReply|pmpOpMapPublicAddr {
m.external = netip.AddrPortFrom(pres.PublicAddr, m.external.Port())
@@ -683,7 +759,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
if err != nil {
c.logf("failed to get PCP mapping: %v", err)
// PCP should only have a single packet response
return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
}
pcpMapping.c = c
pcpMapping.internal = m.internal
@@ -691,10 +767,10 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
c.mu.Lock()
defer c.mu.Unlock()
c.mapping = pcpMapping
return pcpMapping.external, nil
return pcpMapping, pcpMapping.external, nil
default:
c.logf("unknown PMP/PCP version number: %d %v", version, res[:n])
return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
}
}
@@ -702,7 +778,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
c.mu.Lock()
defer c.mu.Unlock()
c.mapping = m
return m.external, nil
return nil, m.external, nil
}
}
}
@@ -790,19 +866,13 @@ func parsePMPResponse(pkt []byte) (res pmpResponse, ok bool) {
return res, true
}
type ProbeResult struct {
PCP bool
PMP bool
UPnP bool
}
// Probe returns a summary of which port mapping services are
// available on the network.
//
// If a probe has run recently and there haven't been any network changes since,
// the returned result might be server from the Client's cache, without
// sending any network traffic.
func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
func (c *Client) Probe(ctx context.Context) (res portmappertype.ProbeResult, err error) {
if c.debug.disableAll() {
return res, ErrPortMappingDisabled
}
@@ -837,19 +907,19 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
// https://github.com/tailscale/tailscale/issues/1001
if c.sawPMPRecently() {
res.PMP = true
} else if !c.debug.DisablePMP {
} else if !c.debug.DisablePMP() {
metricPMPSent.Add(1)
uc.WriteToUDPAddrPort(pmpReqExternalAddrPacket, pxpAddr)
}
if c.sawPCPRecently() {
res.PCP = true
} else if !c.debug.DisablePCP {
} else if !c.debug.DisablePCP() {
metricPCPSent.Add(1)
uc.WriteToUDPAddrPort(pcpAnnounceRequest(myIP), pxpAddr)
}
if c.sawUPnPRecently() {
res.UPnP = true
} else if !c.debug.DisableUPnP {
} else if !c.debug.DisableUPnP() {
// Strictly speaking, you discover UPnP services by sending an
// SSDP query (which uPnPPacket is) to udp/1900 on the SSDP
// multicast address, and then get a flood of responses back

View File

@@ -0,0 +1,88 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package portmappertype defines the net/portmapper interface, which may or may not be
// linked into the binary.
package portmappertype
import (
"context"
"errors"
"net/netip"
"time"
"tailscale.com/feature"
"tailscale.com/net/netmon"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
// HookNewPortMapper is a hook to install the portmapper creation function.
// It must be set by an init function when buildfeatures.HasPortmapper is true.
var HookNewPortMapper feature.Hook[func(logf logger.Logf,
bus *eventbus.Bus,
netMon *netmon.Monitor,
disableUPnPOrNil,
onlyTCP443OrNil func() bool) Client]
var (
ErrNoPortMappingServices = errors.New("no port mapping services were found")
ErrGatewayRange = errors.New("skipping portmap; gateway range likely lacks support")
ErrGatewayIPv6 = errors.New("skipping portmap; no IPv6 support for portmapping")
ErrPortMappingDisabled = errors.New("port mapping is disabled")
)
// ProbeResult is the result of a portmapper probe, saying
// which port mapping protocols were discovered.
type ProbeResult struct {
PCP bool
PMP bool
UPnP bool
}
// Client is the interface implemented by a portmapper client.
type Client interface {
// Probe returns a summary of which port mapping services are available on
// the network.
//
// If a probe has run recently and there haven't been any network changes
// since, the returned result might be server from the Client's cache,
// without sending any network traffic.
Probe(context.Context) (ProbeResult, error)
// HaveMapping reports whether we have a current valid mapping.
HaveMapping() bool
// SetGatewayLookupFunc set the func that returns the machine's default
// gateway IP, and the primary IP address for that gateway. It must be
// called before the client is used. If not called,
// interfaces.LikelyHomeRouterIP is used.
SetGatewayLookupFunc(f func() (gw, myIP netip.Addr, ok bool))
// NoteNetworkDown should be called when the network has transitioned to a down state.
// It's too late to release port mappings at this point (the user might've just turned off
// their wifi), but we can make sure we invalidate mappings for later when the network
// comes back.
NoteNetworkDown()
// GetCachedMappingOrStartCreatingOne quickly returns with our current cached portmapping, if any.
// If there's not one, it starts up a background goroutine to create one.
// If the background goroutine ends up creating one, the onChange hook registered with the
// NewClient constructor (if any) will fire.
GetCachedMappingOrStartCreatingOne() (external netip.AddrPort, ok bool)
// SetLocalPort updates the local port number to which we want to port
// map UDP traffic
SetLocalPort(localPort uint16)
Close() error
}
// Mapping is an event recording the allocation of a port mapping.
type Mapping struct {
External netip.AddrPort
Type string
GoodUntil time.Time
// TODO(creachadair): Record whether we reused an existing mapping?
}

View File

@@ -25,15 +25,47 @@ import (
"sync/atomic"
"time"
"github.com/tailscale/goupnp"
"github.com/tailscale/goupnp/dcps/internetgateway2"
"github.com/tailscale/goupnp/soap"
"github.com/huin/goupnp"
"github.com/huin/goupnp/dcps/internetgateway2"
"github.com/huin/goupnp/soap"
"tailscale.com/envknob"
"tailscale.com/net/netns"
"tailscale.com/types/logger"
"tailscale.com/util/ctxkey"
"tailscale.com/util/mak"
)
// upnpHTTPClientKey is a context key for storing an HTTP client to use
// for UPnP requests. This allows us to use a custom HTTP client (with custom
// dialer, timeouts, etc.) while using the upstream goupnp library which only
// supports a global HTTPClientDefault.
var upnpHTTPClientKey = ctxkey.New[*http.Client]("portmapper.upnpHTTPClient", nil)
// delegatingRoundTripper implements http.RoundTripper by delegating to
// the HTTP client stored in the request's context. This allows us to use
// per-request HTTP client configuration with the upstream goupnp library.
type delegatingRoundTripper struct {
inner *http.Client
}
func (d delegatingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if c := upnpHTTPClientKey.Value(req.Context()); c != nil {
return c.Transport.RoundTrip(req)
}
return d.inner.Do(req)
}
func init() {
// The upstream goupnp library uses a global HTTP client for all
// requests, while we want to be able to use a per-Client
// [http.Client]. We replace its global HTTP client with one that
// delegates to the HTTP client stored in the request's context.
old := goupnp.HTTPClientDefault
goupnp.HTTPClientDefault = &http.Client{
Transport: delegatingRoundTripper{old},
}
}
// References:
//
// WANIP Connection v2: http://upnp.org/specs/gw/UPnP-gw-WANIPConnection-v2-Service.pdf
@@ -79,14 +111,17 @@ func (u *upnpMapping) MappingDebug() string {
u.loc)
}
func (u *upnpMapping) Release(ctx context.Context) {
u.client.DeletePortMapping(ctx, "", u.external.Port(), upnpProtocolUDP)
u.client.DeletePortMappingCtx(ctx, "", u.external.Port(), upnpProtocolUDP)
}
// upnpClient is an interface over the multiple different clients exported by goupnp,
// exposing the functions we need for portmapping. Those clients are auto-generated from XML-specs,
// which is why they're not very idiomatic.
//
// The method names use the *Ctx suffix to match the upstream goupnp library's convention
// for context-aware methods.
type upnpClient interface {
AddPortMapping(
AddPortMappingCtx(
ctx context.Context,
// remoteHost is the remote device sending packets to this device, in the format of x.x.x.x.
@@ -119,9 +154,9 @@ type upnpClient interface {
leaseDurationSec uint32,
) error
DeletePortMapping(ctx context.Context, remoteHost string, externalPort uint16, protocol string) error
GetExternalIPAddress(ctx context.Context) (externalIPAddress string, err error)
GetStatusInfo(ctx context.Context) (status string, lastConnError string, uptime uint32, err error)
DeletePortMappingCtx(ctx context.Context, remoteHost string, externalPort uint16, protocol string) error
GetExternalIPAddressCtx(ctx context.Context) (externalIPAddress string, err error)
GetStatusInfoCtx(ctx context.Context) (status string, lastConnError string, uptime uint32, err error)
}
// tsPortMappingDesc gets sent to UPnP clients as a human-readable label for the portmapping.
@@ -171,7 +206,7 @@ func addAnyPortMapping(
// First off, try using AddAnyPortMapping; if there's a conflict, the
// router will pick another port and return it.
if upnp, ok := upnp.(*internetgateway2.WANIPConnection2); ok {
return upnp.AddAnyPortMapping(
return upnp.AddAnyPortMappingCtx(
ctx,
"",
externalPort,
@@ -186,7 +221,7 @@ func addAnyPortMapping(
// Fall back to using AddPortMapping, which requests a mapping to/from
// a specific external port.
err = upnp.AddPortMapping(
err = upnp.AddPortMappingCtx(
ctx,
"",
externalPort,
@@ -209,7 +244,7 @@ func addAnyPortMapping(
// The meta is the most recently parsed UDP discovery packet response
// from the Internet Gateway Device.
func getUPnPRootDevice(ctx context.Context, logf logger.Logf, debug DebugKnobs, gw netip.Addr, meta uPnPDiscoResponse) (rootDev *goupnp.RootDevice, loc *url.URL, err error) {
if debug.DisableUPnP {
if debug.DisableUPnP() {
return nil, nil, nil
}
@@ -244,7 +279,7 @@ func getUPnPRootDevice(ctx context.Context, logf logger.Logf, debug DebugKnobs,
defer cancel()
// This part does a network fetch.
root, err := goupnp.DeviceByURL(ctx, u)
root, err := goupnp.DeviceByURLCtx(ctx, u)
if err != nil {
return nil, nil, err
}
@@ -257,8 +292,7 @@ func getUPnPRootDevice(ctx context.Context, logf logger.Logf, debug DebugKnobs,
//
// loc is the parsed location that was used to fetch the given RootDevice.
//
// The provided ctx is not retained in the returned upnpClient, but
// its associated HTTP client is (if set via goupnp.WithHTTPClient).
// The provided ctx is not retained in the returned upnpClient.
func selectBestService(ctx context.Context, logf logger.Logf, root *goupnp.RootDevice, loc *url.URL) (client upnpClient, err error) {
method := "none"
defer func() {
@@ -274,9 +308,9 @@ func selectBestService(ctx context.Context, logf logger.Logf, root *goupnp.RootD
// First, get all available clients from the device, and append to our
// list of possible clients. Order matters here; we want to prefer
// WANIPConnection2 over WANIPConnection1 or WANPPPConnection.
wanIP2, _ := internetgateway2.NewWANIPConnection2ClientsFromRootDevice(ctx, root, loc)
wanIP1, _ := internetgateway2.NewWANIPConnection1ClientsFromRootDevice(ctx, root, loc)
wanPPP, _ := internetgateway2.NewWANPPPConnection1ClientsFromRootDevice(ctx, root, loc)
wanIP2, _ := internetgateway2.NewWANIPConnection2ClientsFromRootDevice(root, loc)
wanIP1, _ := internetgateway2.NewWANIPConnection1ClientsFromRootDevice(root, loc)
wanPPP, _ := internetgateway2.NewWANPPPConnection1ClientsFromRootDevice(root, loc)
var clients []upnpClient
for _, v := range wanIP2 {
@@ -291,12 +325,12 @@ func selectBestService(ctx context.Context, logf logger.Logf, root *goupnp.RootD
// These are legacy services that were deprecated in 2015, but are
// still in use by older devices; try them just in case.
legacyClients, _ := goupnp.NewServiceClientsFromRootDevice(ctx, root, loc, urn_LegacyWANPPPConnection_1)
legacyClients, _ := goupnp.NewServiceClientsFromRootDevice(root, loc, urn_LegacyWANPPPConnection_1)
metricUPnPSelectLegacy.Add(int64(len(legacyClients)))
for _, client := range legacyClients {
clients = append(clients, &legacyWANPPPConnection1{client})
}
legacyClients, _ = goupnp.NewServiceClientsFromRootDevice(ctx, root, loc, urn_LegacyWANIPConnection_1)
legacyClients, _ = goupnp.NewServiceClientsFromRootDevice(root, loc, urn_LegacyWANIPConnection_1)
metricUPnPSelectLegacy.Add(int64(len(legacyClients)))
for _, client := range legacyClients {
clients = append(clients, &legacyWANIPConnection1{client})
@@ -346,7 +380,7 @@ func selectBestService(ctx context.Context, logf logger.Logf, root *goupnp.RootD
}
// Check if the device has an external IP address.
extIP, err := svc.GetExternalIPAddress(ctx)
extIP, err := svc.GetExternalIPAddressCtx(ctx)
if err != nil {
continue
}
@@ -399,7 +433,7 @@ func selectBestService(ctx context.Context, logf logger.Logf, root *goupnp.RootD
// serviceIsConnected returns whether a given UPnP service is connected, based
// on the NewConnectionStatus field returned from GetStatusInfo.
func serviceIsConnected(ctx context.Context, logf logger.Logf, svc upnpClient) bool {
status, _ /* NewLastConnectionError */, _ /* NewUptime */, err := svc.GetStatusInfo(ctx)
status, _ /* NewLastConnectionError */, _ /* NewUptime */, err := svc.GetStatusInfoCtx(ctx)
if err != nil {
return false
}
@@ -434,7 +468,7 @@ func (c *Client) getUPnPPortMapping(
internal netip.AddrPort,
prevPort uint16,
) (external netip.AddrPort, ok bool) {
if disableUPnpEnv() || c.debug.DisableUPnP || (c.controlKnobs != nil && c.controlKnobs.DisableUPnP.Load()) {
if disableUPnpEnv() || c.debug.DisableUPnP() {
return netip.AddrPort{}, false
}
@@ -454,7 +488,7 @@ func (c *Client) getUPnPPortMapping(
c.mu.Lock()
oldMapping, ok := c.mapping.(*upnpMapping)
metas := c.uPnPMetas
ctx = goupnp.WithHTTPClient(ctx, c.upnpHTTPClientLocked())
ctx = upnpHTTPClientKey.WithValue(ctx, c.upnpHTTPClientLocked())
c.mu.Unlock()
// Wrapper for a uPnPDiscoResponse with an optional existing root
@@ -629,7 +663,7 @@ func (c *Client) tryUPnPPortmapWithDevice(
}
// TODO cache this ip somewhere?
extIP, err := client.GetExternalIPAddress(ctx)
extIP, err := client.GetExternalIPAddressCtx(ctx)
c.vlogf("client.GetExternalIPAddress: %v, %v", extIP, err)
if err != nil {
return netip.AddrPort{}, nil, err

View File

@@ -1,152 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package routetable provides functions that operate on the system's route
// table.
package routetable
import (
"bufio"
"fmt"
"net/netip"
"strconv"
"tailscale.com/types/logger"
)
var (
//lint:ignore U1000 used in routetable_linux_test.go and routetable_bsd_test.go
defaultRouteIPv4 = RouteDestination{Prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
//lint:ignore U1000 used in routetable_bsd_test.go
defaultRouteIPv6 = RouteDestination{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
)
// RouteEntry contains common cross-platform fields describing an entry in the
// system route table.
type RouteEntry struct {
// Family is the IP family of the route; it will be either 4 or 6.
Family int
// Type is the type of this route.
Type RouteType
// Dst is the destination of the route.
Dst RouteDestination
// Gatewayis the gateway address specified for this route.
// This value will be invalid (where !r.Gateway.IsValid()) in cases
// where there is no gateway address for this route.
Gateway netip.Addr
// Interface is the name of the network interface to use when sending
// packets that match this route. This field can be empty.
Interface string
// Sys contains platform-specific information about this route.
Sys any
}
// Format implements the fmt.Formatter interface.
func (r RouteEntry) Format(f fmt.State, verb rune) {
logger.ArgWriter(func(w *bufio.Writer) {
switch r.Family {
case 4:
fmt.Fprintf(w, "{Family: IPv4")
case 6:
fmt.Fprintf(w, "{Family: IPv6")
default:
fmt.Fprintf(w, "{Family: unknown(%d)", r.Family)
}
// Match 'ip route' and other tools by not printing the route
// type if it's a unicast route.
if r.Type != RouteTypeUnicast {
fmt.Fprintf(w, ", Type: %s", r.Type)
}
if r.Dst.IsValid() {
fmt.Fprintf(w, ", Dst: %s", r.Dst)
} else {
w.WriteString(", Dst: invalid")
}
if r.Gateway.IsValid() {
fmt.Fprintf(w, ", Gateway: %s", r.Gateway)
}
if r.Interface != "" {
fmt.Fprintf(w, ", Interface: %s", r.Interface)
}
if r.Sys != nil {
var formatVerb string
switch {
case f.Flag('#'):
formatVerb = "%#v"
case f.Flag('+'):
formatVerb = "%+v"
default:
formatVerb = "%v"
}
fmt.Fprintf(w, ", Sys: "+formatVerb, r.Sys)
}
w.WriteString("}")
}).Format(f, verb)
}
// RouteDestination is the destination of a route.
//
// This is similar to net/netip.Prefix, but also contains an optional IPv6
// zone.
type RouteDestination struct {
netip.Prefix
Zone string
}
func (r RouteDestination) String() string {
ip := r.Prefix.Addr()
if r.Zone != "" {
ip = ip.WithZone(r.Zone)
}
return ip.String() + "/" + strconv.Itoa(r.Prefix.Bits())
}
// RouteType describes the type of a route.
type RouteType int
const (
// RouteTypeUnspecified is the unspecified route type.
RouteTypeUnspecified RouteType = iota
// RouteTypeLocal indicates that the destination of this route is an
// address that belongs to this system.
RouteTypeLocal
// RouteTypeUnicast indicates that the destination of this route is a
// "regular" address--one that neither belongs to this host, nor is a
// broadcast/multicast/etc. address.
RouteTypeUnicast
// RouteTypeBroadcast indicates that the destination of this route is a
// broadcast address.
RouteTypeBroadcast
// RouteTypeMulticast indicates that the destination of this route is a
// multicast address.
RouteTypeMulticast
// RouteTypeOther indicates that the route is of some other valid type;
// see the Sys field for the OS-provided route information to determine
// the exact type.
RouteTypeOther
)
func (r RouteType) String() string {
switch r {
case RouteTypeUnspecified:
return "unspecified"
case RouteTypeLocal:
return "local"
case RouteTypeUnicast:
return "unicast"
case RouteTypeBroadcast:
return "broadcast"
case RouteTypeMulticast:
return "multicast"
case RouteTypeOther:
return "other"
default:
return "invalid"
}
}

View File

@@ -1,293 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build darwin || freebsd
package routetable
import (
"bufio"
"fmt"
"net"
"net/netip"
"runtime"
"sort"
"strings"
"syscall"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"tailscale.com/net/netmon"
"tailscale.com/types/logger"
)
type RouteEntryBSD struct {
// GatewayInterface is the name of the interface specified as a gateway
// for this route, if any.
GatewayInterface string
// GatewayIdx is the index of the interface specified as a gateway for
// this route, if any.
GatewayIdx int
// GatewayAddr is the link-layer address of the gateway for this route,
// if any.
GatewayAddr string
// Flags contains a string representation of common flags for this
// route.
Flags []string
// RawFlags contains the raw flags that were returned by the operating
// system for this route.
RawFlags int
}
// Format implements the fmt.Formatter interface.
func (r RouteEntryBSD) Format(f fmt.State, verb rune) {
logger.ArgWriter(func(w *bufio.Writer) {
var pstart bool
pr := func(format string, args ...any) {
if pstart {
fmt.Fprintf(w, ", "+format, args...)
} else {
fmt.Fprintf(w, format, args...)
pstart = true
}
}
w.WriteString("{")
if r.GatewayInterface != "" {
pr("GatewayInterface: %s", r.GatewayInterface)
}
if r.GatewayIdx > 0 {
pr("GatewayIdx: %d", r.GatewayIdx)
}
if r.GatewayAddr != "" {
pr("GatewayAddr: %s", r.GatewayAddr)
}
pr("Flags: %v", r.Flags)
unknownFlags := r.RawFlags
for fv := range flags {
if r.RawFlags&fv == fv {
unknownFlags &= ^fv
}
}
if unknownFlags != 0 {
pr("UnknownFlags: %x ", unknownFlags)
}
w.WriteString("}")
}).Format(f, verb)
}
// ipFromRMAddr returns a netip.Addr converted from one of the
// route.Inet{4,6}Addr types.
func ipFromRMAddr(ifs map[int]netmon.Interface, addr any) netip.Addr {
switch v := addr.(type) {
case *route.Inet4Addr:
return netip.AddrFrom4(v.IP)
case *route.Inet6Addr:
ip := netip.AddrFrom16(v.IP)
if v.ZoneID != 0 {
if iif, ok := ifs[v.ZoneID]; ok {
ip = ip.WithZone(iif.Name)
} else {
ip = ip.WithZone(fmt.Sprint(v.ZoneID))
}
}
return ip
}
return netip.Addr{}
}
// populateGateway populates gateway fields on a RouteEntry/RouteEntryBSD.
func populateGateway(re *RouteEntry, reSys *RouteEntryBSD, ifs map[int]netmon.Interface, addr any) {
// If the address type has a valid IP, use that.
if ip := ipFromRMAddr(ifs, addr); ip.IsValid() {
re.Gateway = ip
return
}
switch v := addr.(type) {
case *route.LinkAddr:
reSys.GatewayIdx = v.Index
if iif, ok := ifs[v.Index]; ok {
reSys.GatewayInterface = iif.Name
}
var sb strings.Builder
for i, x := range v.Addr {
if i != 0 {
sb.WriteByte(':')
}
fmt.Fprintf(&sb, "%02x", x)
}
reSys.GatewayAddr = sb.String()
}
}
// populateDestination populates the 'Dst' field on a RouteEntry based on the
// RouteMessage's destination and netmask fields.
func populateDestination(re *RouteEntry, ifs map[int]netmon.Interface, rm *route.RouteMessage) {
dst := rm.Addrs[unix.RTAX_DST]
if dst == nil {
return
}
ip := ipFromRMAddr(ifs, dst)
if !ip.IsValid() {
return
}
if ip.Is4() {
re.Family = 4
} else {
re.Family = 6
}
re.Dst = RouteDestination{
Prefix: netip.PrefixFrom(ip, 32), // default if nothing more specific
}
// If the RTF_HOST flag is set, then this is a host route and there's
// no netmask in this RouteMessage.
if rm.Flags&unix.RTF_HOST != 0 {
return
}
// As above if there's no netmask in the list of addrs
if len(rm.Addrs) < unix.RTAX_NETMASK || rm.Addrs[unix.RTAX_NETMASK] == nil {
return
}
nm := ipFromRMAddr(ifs, rm.Addrs[unix.RTAX_NETMASK])
if !ip.IsValid() {
return
}
// Count the number of bits in the netmask IP and use that to make our prefix.
ones, _ /* bits */ := net.IPMask(nm.AsSlice()).Size()
// Print this ourselves instead of using netip.Prefix so that we don't
// lose the zone (since netip.Prefix strips that).
//
// NOTE(andrew): this doesn't print the same values as the 'netstat' tool
// for some addresses on macOS, and I have no idea why. Specifically,
// 'netstat -rn' will show something like:
// ff00::/8 ::1 UmCI lo0
//
// But we will get:
// destination=ff00::/40 [...]
//
// The netmask that we get back from FetchRIB has 32 more bits in it
// than netstat prints, but only for multicast routes.
//
// For consistency's sake, we're going to do the same here so that we
// get the same values as netstat returns.
if runtime.GOOS == "darwin" && ip.Is6() && ip.IsMulticast() && ones > 32 {
ones -= 32
}
re.Dst = RouteDestination{
Prefix: netip.PrefixFrom(ip, ones),
Zone: ip.Zone(),
}
}
// routeEntryFromMsg returns a RouteEntry from a single route.Message
// returned by the operating system.
func routeEntryFromMsg(ifsByIdx map[int]netmon.Interface, msg route.Message) (RouteEntry, bool) {
rm, ok := msg.(*route.RouteMessage)
if !ok {
return RouteEntry{}, false
}
// Ignore things that we don't understand
if rm.Version < 3 || rm.Version > 5 {
return RouteEntry{}, false
}
if rm.Type != rmExpectedType {
return RouteEntry{}, false
}
if len(rm.Addrs) < unix.RTAX_GATEWAY {
return RouteEntry{}, false
}
if rm.Flags&skipFlags != 0 {
return RouteEntry{}, false
}
reSys := RouteEntryBSD{
RawFlags: rm.Flags,
}
for fv, fs := range flags {
if rm.Flags&fv == fv {
reSys.Flags = append(reSys.Flags, fs)
}
}
sort.Strings(reSys.Flags)
re := RouteEntry{}
hasFlag := func(f int) bool { return rm.Flags&f != 0 }
switch {
case hasFlag(unix.RTF_LOCAL):
re.Type = RouteTypeLocal
case hasFlag(unix.RTF_BROADCAST):
re.Type = RouteTypeBroadcast
case hasFlag(unix.RTF_MULTICAST):
re.Type = RouteTypeMulticast
// From the manpage: "host entry (net otherwise)"
case !hasFlag(unix.RTF_HOST):
re.Type = RouteTypeUnicast
default:
re.Type = RouteTypeOther
}
populateDestination(&re, ifsByIdx, rm)
if unix.RTAX_GATEWAY < len(rm.Addrs) {
populateGateway(&re, &reSys, ifsByIdx, rm.Addrs[unix.RTAX_GATEWAY])
}
if outif, ok := ifsByIdx[rm.Index]; ok {
re.Interface = outif.Name
}
re.Sys = reSys
return re, true
}
// Get returns route entries from the system route table, limited to at most
// 'max' results.
func Get(max int) ([]RouteEntry, error) {
// Fetching the list of interfaces can race with fetching our route
// table, but we do it anyway since it's helpful for debugging.
ifs, err := netmon.GetInterfaceList()
if err != nil {
return nil, err
}
ifsByIdx := make(map[int]netmon.Interface)
for _, iif := range ifs {
ifsByIdx[iif.Index] = iif
}
rib, err := route.FetchRIB(syscall.AF_UNSPEC, ribType, 0)
if err != nil {
return nil, err
}
msgs, err := route.ParseRIB(parseType, rib)
if err != nil {
return nil, err
}
var ret []RouteEntry
for _, m := range msgs {
re, ok := routeEntryFromMsg(ifsByIdx, m)
if ok {
ret = append(ret, re)
if len(ret) == max {
break
}
}
}
return ret, nil
}

View File

@@ -1,36 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build darwin
package routetable
import "golang.org/x/sys/unix"
const (
ribType = unix.NET_RT_DUMP2
parseType = unix.NET_RT_IFLIST2
rmExpectedType = unix.RTM_GET2
// Skip routes that were cloned from a parent
skipFlags = unix.RTF_WASCLONED
)
var flags = map[int]string{
unix.RTF_BLACKHOLE: "blackhole",
unix.RTF_BROADCAST: "broadcast",
unix.RTF_GATEWAY: "gateway",
unix.RTF_GLOBAL: "global",
unix.RTF_HOST: "host",
unix.RTF_IFSCOPE: "ifscope",
unix.RTF_LOCAL: "local",
unix.RTF_MULTICAST: "multicast",
unix.RTF_REJECT: "reject",
unix.RTF_ROUTER: "router",
unix.RTF_STATIC: "static",
unix.RTF_UP: "up",
// More obscure flags, just to have full coverage.
unix.RTF_LLINFO: "{RTF_LLINFO}",
unix.RTF_PRCLONING: "{RTF_PRCLONING}",
unix.RTF_CLONING: "{RTF_CLONING}",
}

View File

@@ -1,28 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build freebsd
package routetable
import "golang.org/x/sys/unix"
const (
ribType = unix.NET_RT_DUMP
parseType = unix.NET_RT_IFLIST
rmExpectedType = unix.RTM_GET
// Nothing to skip
skipFlags = 0
)
var flags = map[int]string{
unix.RTF_BLACKHOLE: "blackhole",
unix.RTF_BROADCAST: "broadcast",
unix.RTF_GATEWAY: "gateway",
unix.RTF_HOST: "host",
unix.RTF_MULTICAST: "multicast",
unix.RTF_REJECT: "reject",
unix.RTF_STATIC: "static",
unix.RTF_UP: "up",
}

View File

@@ -1,229 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package routetable
import (
"bufio"
"fmt"
"net/netip"
"strconv"
"github.com/tailscale/netlink"
"golang.org/x/sys/unix"
"tailscale.com/net/netaddr"
"tailscale.com/net/netmon"
"tailscale.com/types/logger"
)
// RouteEntryLinux is the structure that makes up the Sys field of the
// RouteEntry structure.
type RouteEntryLinux struct {
// Type is the raw type of the route.
Type int
// Table is the routing table index of this route.
Table int
// Src is the source of the route (if any).
Src netip.Addr
// Proto describes the source of the route--i.e. what caused this route
// to be added to the route table.
Proto netlink.RouteProtocol
// Priority is the route's priority.
Priority int
// Scope is the route's scope.
Scope int
// InputInterfaceIdx is the input interface index.
InputInterfaceIdx int
// InputInterfaceName is the input interface name (if available).
InputInterfaceName string
}
// Format implements the fmt.Formatter interface.
func (r RouteEntryLinux) Format(f fmt.State, verb rune) {
logger.ArgWriter(func(w *bufio.Writer) {
// TODO(andrew): should we skip printing anything if type is unicast?
fmt.Fprintf(w, "{Type: %s", r.TypeName())
// Match 'ip route' behaviour when printing these fields
if r.Table != unix.RT_TABLE_MAIN {
fmt.Fprintf(w, ", Table: %s", r.TableName())
}
if r.Proto != unix.RTPROT_BOOT {
fmt.Fprintf(w, ", Proto: %s", r.Proto)
}
if r.Src.IsValid() {
fmt.Fprintf(w, ", Src: %s", r.Src)
}
if r.Priority != 0 {
fmt.Fprintf(w, ", Priority: %d", r.Priority)
}
if r.Scope != unix.RT_SCOPE_UNIVERSE {
fmt.Fprintf(w, ", Scope: %s", r.ScopeName())
}
if r.InputInterfaceName != "" {
fmt.Fprintf(w, ", InputInterfaceName: %s", r.InputInterfaceName)
} else if r.InputInterfaceIdx != 0 {
fmt.Fprintf(w, ", InputInterfaceIdx: %d", r.InputInterfaceIdx)
}
w.WriteString("}")
}).Format(f, verb)
}
// TypeName returns the string representation of this route's Type.
func (r RouteEntryLinux) TypeName() string {
switch r.Type {
case unix.RTN_UNSPEC:
return "none"
case unix.RTN_UNICAST:
return "unicast"
case unix.RTN_LOCAL:
return "local"
case unix.RTN_BROADCAST:
return "broadcast"
case unix.RTN_ANYCAST:
return "anycast"
case unix.RTN_MULTICAST:
return "multicast"
case unix.RTN_BLACKHOLE:
return "blackhole"
case unix.RTN_UNREACHABLE:
return "unreachable"
case unix.RTN_PROHIBIT:
return "prohibit"
case unix.RTN_THROW:
return "throw"
case unix.RTN_NAT:
return "nat"
case unix.RTN_XRESOLVE:
return "xresolve"
default:
return strconv.Itoa(r.Type)
}
}
// TableName returns the string representation of this route's Table.
func (r RouteEntryLinux) TableName() string {
switch r.Table {
case unix.RT_TABLE_DEFAULT:
return "default"
case unix.RT_TABLE_MAIN:
return "main"
case unix.RT_TABLE_LOCAL:
return "local"
default:
return strconv.Itoa(r.Table)
}
}
// ScopeName returns the string representation of this route's Scope.
func (r RouteEntryLinux) ScopeName() string {
switch r.Scope {
case unix.RT_SCOPE_UNIVERSE:
return "global"
case unix.RT_SCOPE_NOWHERE:
return "nowhere"
case unix.RT_SCOPE_HOST:
return "host"
case unix.RT_SCOPE_LINK:
return "link"
case unix.RT_SCOPE_SITE:
return "site"
default:
return strconv.Itoa(r.Scope)
}
}
// Get returns route entries from the system route table, limited to at most
// max results.
func Get(max int) ([]RouteEntry, error) {
// Fetching the list of interfaces can race with fetching our route
// table, but we do it anyway since it's helpful for debugging.
ifs, err := netmon.GetInterfaceList()
if err != nil {
return nil, err
}
ifsByIdx := make(map[int]netmon.Interface)
for _, iif := range ifs {
ifsByIdx[iif.Index] = iif
}
filter := &netlink.Route{}
routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, filter, netlink.RT_FILTER_TABLE)
if err != nil {
return nil, err
}
var ret []RouteEntry
for _, route := range routes {
if route.Family != netlink.FAMILY_V4 && route.Family != netlink.FAMILY_V6 {
continue
}
re := RouteEntry{}
if route.Family == netlink.FAMILY_V4 {
re.Family = 4
} else {
re.Family = 6
}
switch route.Type {
case unix.RTN_UNSPEC:
re.Type = RouteTypeUnspecified
case unix.RTN_UNICAST:
re.Type = RouteTypeUnicast
case unix.RTN_LOCAL:
re.Type = RouteTypeLocal
case unix.RTN_BROADCAST:
re.Type = RouteTypeBroadcast
case unix.RTN_MULTICAST:
re.Type = RouteTypeMulticast
default:
re.Type = RouteTypeOther
}
if route.Dst != nil {
if d, ok := netaddr.FromStdIPNet(route.Dst); ok {
re.Dst = RouteDestination{Prefix: d}
}
} else if route.Family == netlink.FAMILY_V4 {
re.Dst = RouteDestination{Prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
} else {
re.Dst = RouteDestination{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
if gw := route.Gw; gw != nil {
if gwa, ok := netip.AddrFromSlice(gw); ok {
re.Gateway = gwa
}
}
if outif, ok := ifsByIdx[route.LinkIndex]; ok {
re.Interface = outif.Name
} else if route.LinkIndex > 0 {
re.Interface = fmt.Sprintf("link#%d", route.LinkIndex)
}
reSys := RouteEntryLinux{
Type: route.Type,
Table: route.Table,
Proto: route.Protocol,
Priority: route.Priority,
Scope: int(route.Scope),
InputInterfaceIdx: route.ILinkIndex,
}
if src, ok := netip.AddrFromSlice(route.Src); ok {
reSys.Src = src
}
if iif, ok := ifsByIdx[route.ILinkIndex]; ok {
reSys.InputInterfaceName = iif.Name
}
re.Sys = reSys
ret = append(ret, re)
// Stop after we've reached the maximum number of routes
if len(ret) == max {
break
}
}
return ret, nil
}

View File

@@ -1,17 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux && !darwin && !freebsd
package routetable
import (
"errors"
"runtime"
)
var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS)
func Get(max int) ([]RouteEntry, error) {
return nil, errUnsupported
}

37
vendor/tailscale.com/net/sockopts/sockopts.go generated vendored Normal file
View File

@@ -0,0 +1,37 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package sockopts contains logic for applying socket options.
package sockopts
import (
"net"
"runtime"
"tailscale.com/types/nettype"
)
// BufferDirection represents either the read/receive or write/send direction
// of a socket buffer.
type BufferDirection string
const (
ReadDirection BufferDirection = "read"
WriteDirection BufferDirection = "write"
)
func portableSetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) error {
if runtime.GOOS == "plan9" {
// Not supported. Don't try. Avoid logspam.
return nil
}
var err error
if c, ok := pconn.(*net.UDPConn); ok {
if direction == WriteDirection {
err = c.SetWriteBuffer(size)
} else {
err = c.SetReadBuffer(size)
}
}
return err
}

21
vendor/tailscale.com/net/sockopts/sockopts_default.go generated vendored Normal file
View File

@@ -0,0 +1,21 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux
package sockopts
import (
"tailscale.com/types/nettype"
)
// SetBufferSize sets pconn's buffer to size for direction. size may be silently
// capped depending on platform.
//
// errForce is only relevant for Linux, and will always be nil otherwise,
// but we maintain a consistent cross-platform API.
//
// If pconn is not a [*net.UDPConn], then SetBufferSize is no-op.
func SetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) (errForce error, errPortable error) {
return nil, portableSetBufferSize(pconn, direction, size)
}

40
vendor/tailscale.com/net/sockopts/sockopts_linux.go generated vendored Normal file
View File

@@ -0,0 +1,40 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package sockopts
import (
"net"
"syscall"
"tailscale.com/types/nettype"
)
// SetBufferSize sets pconn's buffer to size for direction. It attempts
// (errForce) to set SO_SNDBUFFORCE or SO_RECVBUFFORCE which can overcome the
// limit of net.core.{r,w}mem_max, but require CAP_NET_ADMIN. It falls back to
// the portable implementation (errPortable) if that fails, which may be
// silently capped to net.core.{r,w}mem_max.
//
// If pconn is not a [*net.UDPConn], then SetBufferSize is no-op.
func SetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) (errForce error, errPortable error) {
opt := syscall.SO_RCVBUFFORCE
if direction == WriteDirection {
opt = syscall.SO_SNDBUFFORCE
}
if c, ok := pconn.(*net.UDPConn); ok {
var rc syscall.RawConn
rc, errForce = c.SyscallConn()
if errForce == nil {
rc.Control(func(fd uintptr) {
errForce = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, opt, size)
})
}
if errForce != nil {
errPortable = portableSetBufferSize(pconn, direction, size)
}
}
return errForce, errPortable
}

View File

@@ -0,0 +1,15 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !windows
package sockopts
import (
"tailscale.com/types/nettype"
)
// SetICMPErrImmunity is no-op on non-Windows.
func SetICMPErrImmunity(pconn nettype.PacketConn) error {
return nil
}

62
vendor/tailscale.com/net/sockopts/sockopts_windows.go generated vendored Normal file
View File

@@ -0,0 +1,62 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build windows
package sockopts
import (
"fmt"
"net"
"unsafe"
"golang.org/x/sys/windows"
"tailscale.com/types/nettype"
)
// SetICMPErrImmunity sets socket options on pconn to prevent ICMP reception,
// e.g. ICMP Port Unreachable, from surfacing as a syscall error.
//
// If pconn is not a [*net.UDPConn], then SetICMPErrImmunity is no-op.
func SetICMPErrImmunity(pconn nettype.PacketConn) error {
c, ok := pconn.(*net.UDPConn)
if !ok {
// not a UDP connection; nothing to do
return nil
}
sysConn, err := c.SyscallConn()
if err != nil {
return fmt.Errorf("SetICMPErrImmunity: getting SyscallConn failed: %v", err)
}
// Similar to https://github.com/golang/go/issues/5834 (which involved
// WSAECONNRESET), Windows can return a WSAENETRESET error, even on UDP
// reads. Disable this.
const SIO_UDP_NETRESET = windows.IOC_IN | windows.IOC_VENDOR | 15
var ioctlErr error
err = sysConn.Control(func(fd uintptr) {
ret := uint32(0)
flag := uint32(0)
size := uint32(unsafe.Sizeof(flag))
ioctlErr = windows.WSAIoctl(
windows.Handle(fd),
SIO_UDP_NETRESET, // iocc
(*byte)(unsafe.Pointer(&flag)), // inbuf
size, // cbif
nil, // outbuf
0, // cbob
&ret, // cbbr
nil, // overlapped
0, // completionRoutine
)
})
if ioctlErr != nil {
return fmt.Errorf("SetICMPErrImmunity: could not set SIO_UDP_NETRESET: %v", ioctlErr)
}
if err != nil {
return fmt.Errorf("SetICMPErrImmunity: SyscallConn.Control failed: %v", err)
}
return nil
}

View File

@@ -120,10 +120,10 @@ func (s *Server) logf(format string, args ...any) {
}
// Serve accepts and handles incoming connections on the given listener.
func (s *Server) Serve(l net.Listener) error {
defer l.Close()
func (s *Server) Serve(ln net.Listener) error {
defer ln.Close()
for {
c, err := l.Accept()
c, err := ln.Accept()
if err != nil {
return err
}

View File

@@ -28,8 +28,9 @@ const _Label_name = "ControlClientAutoControlClientDialerDERPHTTPClientLogtailLo
var _Label_index = [...]uint8{0, 17, 36, 50, 63, 78, 93, 107, 123, 140, 157, 169, 186, 201}
func (i Label) String() string {
if i >= Label(len(_Label_index)-1) {
idx := int(i) - 0
if i < 0 || idx >= len(_Label_index)-1 {
return "Label(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Label_name[_Label_index[i]:_Label_index[i+1]]
return _Label_name[_Label_index[idx]:_Label_index[idx+1]]
}

View File

@@ -10,12 +10,12 @@ import (
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"tailscale.com/net/netmon"
"tailscale.com/syncs"
"tailscale.com/types/logger"
"tailscale.com/util/clientmetric"
"tailscale.com/version"
@@ -40,7 +40,7 @@ var sockStats = struct {
// mu protects fields in this group (but not the fields within
// sockStatCounters). It should not be held in the per-read/write
// callbacks.
mu sync.Mutex
mu syncs.Mutex
countersByLabel map[Label]*sockStatCounters
knownInterfaces map[int]string // interface index -> name
usedInterfaces map[int]int // set of interface indexes
@@ -271,10 +271,10 @@ func setNetMon(netMon *netmon.Monitor) {
}
netMon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
if !delta.Major {
if !delta.RebindLikelyRequired {
return
}
state := delta.New
state := delta.CurrentState()
ifName := state.DefaultRouteInterface
if ifName == "" {
return

View File

@@ -1,51 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package tcpinfo provides platform-agnostic accessors to information about a
// TCP connection (e.g. RTT, MSS, etc.).
package tcpinfo
import (
"errors"
"net"
"time"
)
var (
ErrNotTCP = errors.New("tcpinfo: not a TCP conn")
ErrUnimplemented = errors.New("tcpinfo: unimplemented")
)
// RTT returns the RTT for the given net.Conn.
//
// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then
// ErrNotTCP will be returned. If retrieving the RTT is not supported on the
// current platform, ErrUnimplemented will be returned.
func RTT(conn net.Conn) (time.Duration, error) {
tcpConn, err := unwrap(conn)
if err != nil {
return 0, err
}
return rttImpl(tcpConn)
}
// netConner is implemented by crypto/tls.Conn to unwrap into an underlying
// net.Conn.
type netConner interface {
NetConn() net.Conn
}
// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn
func unwrap(nc net.Conn) (*net.TCPConn, error) {
for {
switch v := nc.(type) {
case *net.TCPConn:
return v, nil
case netConner:
nc = v.NetConn()
default:
return nil, ErrNotTCP
}
}
}

View File

@@ -1,33 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package tcpinfo
import (
"net"
"time"
"golang.org/x/sys/unix"
)
func rttImpl(conn *net.TCPConn) (time.Duration, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, err
}
var (
tcpInfo *unix.TCPConnectionInfo
sysErr error
)
err = rawConn.Control(func(fd uintptr) {
tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO)
})
if err != nil {
return 0, err
} else if sysErr != nil {
return 0, sysErr
}
return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil
}

View File

@@ -1,33 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package tcpinfo
import (
"net"
"time"
"golang.org/x/sys/unix"
)
func rttImpl(conn *net.TCPConn) (time.Duration, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, err
}
var (
tcpInfo *unix.TCPInfo
sysErr error
)
err = rawConn.Control(func(fd uintptr) {
tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO)
})
if err != nil {
return 0, err
} else if sysErr != nil {
return 0, sysErr
}
return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil
}

View File

@@ -1,15 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux && !darwin
package tcpinfo
import (
"net"
"time"
)
func rttImpl(conn *net.TCPConn) (time.Duration, error) {
return 0, ErrUnimplemented
}

View File

@@ -9,13 +9,19 @@ package blockblame
import (
"crypto/x509"
"strings"
"sync"
"tailscale.com/feature/buildfeatures"
)
// VerifyCertificate checks if the given certificate c is issued by a firewall manufacturer
// that is known to block Tailscale connections. It returns true and the Manufacturer of
// the equipment if it is, or false and nil if it is not.
func VerifyCertificate(c *x509.Certificate) (m *Manufacturer, ok bool) {
for _, m := range Manufacturers {
if !buildfeatures.HasDebug {
return nil, false
}
for _, m := range manufacturers() {
if m.match != nil && m.match(c) {
return m, true
}
@@ -33,46 +39,56 @@ type Manufacturer struct {
match matchFunc
}
var Manufacturers = []*Manufacturer{
{
Name: "Aruba Networks",
match: issuerContains("Aruba"),
},
{
Name: "Cisco",
match: issuerContains("Cisco"),
},
{
Name: "Fortinet",
match: matchAny(
issuerContains("Fortinet"),
certEmail("support@fortinet.com"),
),
},
{
Name: "Huawei",
match: certEmail("mobile@huawei.com"),
},
{
Name: "Palo Alto Networks",
match: matchAny(
issuerContains("Palo Alto Networks"),
issuerContains("PAN-FW"),
),
},
{
Name: "Sophos",
match: issuerContains("Sophos"),
},
{
Name: "Ubiquiti",
match: matchAny(
issuerContains("UniFi"),
issuerContains("Ubiquiti"),
),
},
func manufacturers() []*Manufacturer {
manufacturersOnce.Do(func() {
manufacturersList = []*Manufacturer{
{
Name: "Aruba Networks",
match: issuerContains("Aruba"),
},
{
Name: "Cisco",
match: issuerContains("Cisco"),
},
{
Name: "Fortinet",
match: matchAny(
issuerContains("Fortinet"),
certEmail("support@fortinet.com"),
),
},
{
Name: "Huawei",
match: certEmail("mobile@huawei.com"),
},
{
Name: "Palo Alto Networks",
match: matchAny(
issuerContains("Palo Alto Networks"),
issuerContains("PAN-FW"),
),
},
{
Name: "Sophos",
match: issuerContains("Sophos"),
},
{
Name: "Ubiquiti",
match: matchAny(
issuerContains("UniFi"),
issuerContains("Ubiquiti"),
),
},
}
})
return manufacturersList
}
var (
manufacturersOnce sync.Once
manufacturersList []*Manufacturer
)
type matchFunc func(*x509.Certificate) bool
func issuerContains(s string) matchFunc {

View File

@@ -21,11 +21,14 @@ import (
"net"
"net/http"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"tailscale.com/derp/derpconst"
"tailscale.com/envknob"
"tailscale.com/feature/buildfeatures"
"tailscale.com/health"
"tailscale.com/hostinfo"
"tailscale.com/net/bakedroots"
@@ -34,12 +37,6 @@ import (
var counterFallbackOK int32 // atomic
// If SSLKEYLOGFILE is set, it's a file to which we write our TLS private keys
// in a way that WireShark can read.
//
// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
var sslKeyLogFile = os.Getenv("SSLKEYLOGFILE")
var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL")
// tlsdialWarningPrinted tracks whether we've printed a warning about a given
@@ -57,26 +54,40 @@ var mitmBlockWarnable = health.Register(&health.Warnable{
ImpactsConnectivity: true,
})
// Config returns a tls.Config for connecting to a server.
// Config returns a tls.Config for connecting to a server that
// uses system roots for validation but, if those fail, also tries
// the baked-in LetsEncrypt roots as a fallback validation method.
//
// If base is non-nil, it's cloned as the base config before
// being configured and returned.
// If ht is non-nil, it's used to report health errors.
func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
func Config(ht *health.Tracker, base *tls.Config) *tls.Config {
var conf *tls.Config
if base == nil {
conf = new(tls.Config)
} else {
conf = base.Clone()
}
conf.ServerName = host
if n := sslKeyLogFile; n != "" {
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
log.Fatal(err)
// Note: we do NOT set conf.ServerName here (as we accidentally did
// previously), as this path is also used when dialing an HTTPS proxy server
// (through which we'll send a CONNECT request to get a TCP connection to do
// the real TCP connection) because host is the ultimate hostname, but this
// tls.Config is used for both the proxy and the ultimate target.
if buildfeatures.HasDebug {
// If SSLKEYLOGFILE is set, it's a file to which we write our TLS private keys
// in a way that WireShark can read.
//
// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
if n := os.Getenv("SSLKEYLOGFILE"); n != "" {
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
log.Fatal(err)
}
log.Printf("WARNING: writing to SSLKEYLOGFILE %v", n)
conf.KeyLogWriter = f
}
log.Printf("WARNING: writing to SSLKEYLOGFILE %v", n)
conf.KeyLogWriter = f
}
if conf.InsecureSkipVerify {
@@ -91,7 +102,9 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// (with the baked-in fallback root) in the VerifyConnection hook.
conf.InsecureSkipVerify = true
conf.VerifyConnection = func(cs tls.ConnectionState) (retErr error) {
if host == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() {
dialedHost := cs.ServerName
if dialedHost == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() {
// Allow log.tailscale.com TLS MITM for integration tests when
// the client's running within a NATLab VM.
return nil
@@ -114,7 +127,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// Show a dedicated warning.
m, ok := blockblame.VerifyCertificate(cert)
if ok {
log.Printf("tlsdial: server cert for %q looks like %q equipment (could be blocking Tailscale)", host, m.Name)
log.Printf("tlsdial: server cert seen while dialing %q looks like %q equipment (could be blocking Tailscale)", dialedHost, m.Name)
ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name})
} else {
ht.SetHealthy(mitmBlockWarnable)
@@ -133,7 +146,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
ht.SetTLSConnectionError(cs.ServerName, nil)
if selfSignedIssuer != "" {
// Log the self-signed issuer, but don't treat it as an error.
log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", host, selfSignedIssuer)
log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", dialedHost, selfSignedIssuer)
}
}
}()
@@ -142,7 +155,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// First try doing x509 verification with the system's
// root CA pool.
opts := x509.VerifyOptions{
DNSName: cs.ServerName,
DNSName: dialedHost,
Intermediates: x509.NewCertPool(),
}
for _, cert := range cs.PeerCertificates[1:] {
@@ -150,22 +163,22 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
}
_, errSys := cs.PeerCertificates[0].Verify(opts)
if debug() {
log.Printf("tlsdial(sys %q): %v", host, errSys)
log.Printf("tlsdial(sys %q): %v", dialedHost, errSys)
}
if !buildfeatures.HasBakedRoots || (errSys == nil && !debug()) {
return errSys
}
// Always verify with our baked-in Let's Encrypt certificate,
// so we can log an informational message. This is useful for
// detecting SSL MiTM.
// If we have baked-in LetsEncrypt roots and we either failed above, or
// debug logging is enabled, also verify with LetsEncrypt.
opts.Roots = bakedroots.Get()
_, bakedErr := cs.PeerCertificates[0].Verify(opts)
if debug() {
log.Printf("tlsdial(bake %q): %v", host, bakedErr)
log.Printf("tlsdial(bake %q): %v", dialedHost, bakedErr)
} else if bakedErr != nil {
if _, loaded := tlsdialWarningPrinted.LoadOrStore(host, true); !loaded {
if errSys == nil {
log.Printf("tlsdial: warning: server cert for %q is not a Let's Encrypt cert", host)
} else {
log.Printf("tlsdial: error: server cert for %q failed to verify and is not a Let's Encrypt cert", host)
if _, loaded := tlsdialWarningPrinted.LoadOrStore(dialedHost, true); !loaded {
if errSys != nil {
log.Printf("tlsdial: error: server cert for %q failed both system roots & Let's Encrypt root validation", dialedHost)
}
}
}
@@ -200,9 +213,6 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) {
c.ServerName = certDNSName
return
}
if c.VerifyPeerCertificate != nil {
panic("refusing to override tls.Config.VerifyPeerCertificate")
}
// Set InsecureSkipVerify to prevent crypto/tls from doing its
// own cert verification, but do the same work that it'd do
// (but using certDNSName) in the VerifyPeerCertificate hook.
@@ -232,8 +242,8 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) {
if debug() {
log.Printf("tlsdial(sys %q/%q): %v", c.ServerName, certDNSName, errSys)
}
if errSys == nil {
return nil
if !buildfeatures.HasBakedRoots || errSys == nil {
return errSys
}
opts.Roots = bakedroots.Get()
_, err := certs[0].Verify(opts)
@@ -247,41 +257,50 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) {
}
}
// SetConfigExpectedCertHash configures c's VerifyPeerCertificate function
// to require that exactly 1 cert is presented, and that the hex of its SHA256 hash
// is equal to wantFullCertSHA256Hex and that it's a valid cert for c.ServerName.
// SetConfigExpectedCertHash configures c's VerifyPeerCertificate function to
// require that exactly 1 cert is presented (not counting any present MetaCert),
// and that the hex of its SHA256 hash is equal to wantFullCertSHA256Hex and
// that it's a valid cert for c.ServerName.
func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) {
if c.VerifyPeerCertificate != nil {
panic("refusing to override tls.Config.VerifyPeerCertificate")
}
// Set InsecureSkipVerify to prevent crypto/tls from doing its
// own cert verification, but do the same work that it'd do
// (but using certDNSName) in the VerifyPeerCertificate hook.
// (but using certDNSName) in the VerifyConnection hook.
c.InsecureSkipVerify = true
c.VerifyConnection = nil
c.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return errors.New("no certs presented")
c.VerifyConnection = func(cs tls.ConnectionState) error {
dialedHost := cs.ServerName
var sawGoodCert bool
for _, cert := range cs.PeerCertificates {
if strings.HasPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix) {
continue
}
if sawGoodCert {
return errors.New("unexpected multiple certs presented")
}
if fmt.Sprintf("%02x", sha256.Sum256(cert.Raw)) != wantFullCertSHA256Hex {
return fmt.Errorf("cert hash does not match expected cert hash")
}
if dialedHost != "" { // it's empty when dialing a derper by IP with no hostname
if err := cert.VerifyHostname(dialedHost); err != nil {
return fmt.Errorf("cert does not match server name %q: %w", dialedHost, err)
}
}
now := time.Now()
if now.After(cert.NotAfter) {
return fmt.Errorf("cert expired %v", cert.NotAfter)
}
if now.Before(cert.NotBefore) {
return fmt.Errorf("cert not yet valid until %v; is your clock correct?", cert.NotBefore)
}
sawGoodCert = true
}
if len(rawCerts) > 1 {
return errors.New("unexpected multiple certs presented")
}
if fmt.Sprintf("%02x", sha256.Sum256(rawCerts[0])) != wantFullCertSHA256Hex {
return fmt.Errorf("cert hash does not match expected cert hash")
}
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return fmt.Errorf("ParseCertificate: %w", err)
}
if err := cert.VerifyHostname(c.ServerName); err != nil {
return fmt.Errorf("cert does not match server name %q: %w", c.ServerName, err)
}
now := time.Now()
if now.After(cert.NotAfter) {
return fmt.Errorf("cert expired %v", cert.NotAfter)
}
if now.Before(cert.NotBefore) {
return fmt.Errorf("cert not yet valid until %v; is your clock correct?", cert.NotBefore)
if !sawGoodCert {
return errors.New("expected cert not presented")
}
return nil
}
@@ -292,12 +311,8 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) {
func NewTransport() *http.Transport {
return &http.Transport{
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
var d tls.Dialer
d.Config = Config(host, nil, nil)
d.Config = Config(nil, nil)
return d.DialContext(ctx, network, addr)
},
}

Some files were not shown because too many files have changed in this diff Show More