This commit is contained in:
2026-02-19 10:07:43 +00:00
parent 007438e372
commit 6e637ecf77
1763 changed files with 60820 additions and 279516 deletions

View File

@@ -18,6 +18,7 @@ import (
"golang.org/x/crypto/blake2s"
chp "golang.org/x/crypto/chacha20poly1305"
"tailscale.com/syncs"
"tailscale.com/types/key"
)
@@ -48,7 +49,7 @@ type Conn struct {
// rxState is all the Conn state that Read uses.
type rxState struct {
sync.Mutex
syncs.Mutex
cipher cipher.AEAD
nonce nonce
buf *maxMsgBuffer // or nil when reads exhausted

View File

@@ -12,7 +12,6 @@ import (
"sync/atomic"
"time"
"tailscale.com/logtail/backoff"
"tailscale.com/net/sockstats"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
@@ -21,8 +20,10 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/types/persist"
"tailscale.com/types/structs"
"tailscale.com/util/backoff"
"tailscale.com/util/clientmetric"
"tailscale.com/util/execqueue"
"tailscale.com/util/testenv"
)
type LoginGoal struct {
@@ -117,14 +118,13 @@ type Auto struct {
logf logger.Logf
closed bool
updateCh chan struct{} // readable when we should inform the server of a change
observer Observer // called to update Client status; always non-nil
observer Observer // if non-nil, called to update Client status
observerQueue execqueue.ExecQueue
shutdownFn func() // to be called prior to shutdown or nil
unregisterHealthWatch func()
mu sync.Mutex // mutex guards the following fields
started bool // whether [Auto.Start] has been called
wantLoggedIn bool // whether the user wants to be logged in per last method call
urlToVisit string // the last url we were told to visit
expiry time.Time
@@ -140,7 +140,6 @@ type Auto struct {
loggedIn bool // true if currently logged in
loginGoal *LoginGoal // non-nil if some login activity is desired
inMapPoll bool // true once we get the first MapResponse in a stream; false when HTTP response ends
state State // TODO(bradfitz): delete this, make it computed by method from other state
authCtx context.Context // context used for auth requests
mapCtx context.Context // context used for netmap and update requests
@@ -153,15 +152,21 @@ type Auto struct {
// New creates and starts a new Auto.
func New(opts Options) (*Auto, error) {
c, err := NewNoStart(opts)
if c != nil {
c.Start()
c, err := newNoStart(opts)
if err != nil {
return nil, err
}
if opts.StartPaused {
c.SetPaused(true)
}
if !opts.SkipStartForTests {
c.start()
}
return c, err
}
// NewNoStart creates a new Auto, but without calling Start on it.
func NewNoStart(opts Options) (_ *Auto, err error) {
// newNoStart creates a new Auto, but without calling Start on it.
func newNoStart(opts Options) (_ *Auto, err error) {
direct, err := NewDirect(opts)
if err != nil {
return nil, err
@@ -172,9 +177,6 @@ func NewNoStart(opts Options) (_ *Auto, err error) {
}
}()
if opts.Observer == nil {
return nil, errors.New("missing required Options.Observer")
}
if opts.Logf == nil {
opts.Logf = func(fmt string, args ...any) {}
}
@@ -192,15 +194,14 @@ func NewNoStart(opts Options) (_ *Auto, err error) {
observer: opts.Observer,
shutdownFn: opts.Shutdown,
}
c.authCtx, c.authCancel = context.WithCancel(context.Background())
c.authCtx = sockstats.WithSockStats(c.authCtx, sockstats.LabelControlClientAuto, opts.Logf)
c.mapCtx, c.mapCancel = context.WithCancel(context.Background())
c.mapCtx = sockstats.WithSockStats(c.mapCtx, sockstats.LabelControlClientAuto, opts.Logf)
c.unregisterHealthWatch = opts.HealthTracker.RegisterWatcher(direct.ReportHealthChange)
return c, nil
}
// SetPaused controls whether HTTP activity should be paused.
@@ -225,10 +226,21 @@ func (c *Auto) SetPaused(paused bool) {
c.unpauseWaiters = nil
}
// Start starts the client's goroutines.
// StartForTest starts the client's goroutines.
//
// It should only be called for clients created by NewNoStart.
func (c *Auto) Start() {
// It should only be called for clients created with [Options.SkipStartForTests].
func (c *Auto) StartForTest() {
testenv.AssertInTest()
c.start()
}
func (c *Auto) start() {
c.mu.Lock()
defer c.mu.Unlock()
if c.started {
return
}
c.started = true
go c.authRoutine()
go c.mapRoutine()
go c.updateRoutine()
@@ -302,10 +314,11 @@ func (c *Auto) authRoutine() {
c.mu.Lock()
goal := c.loginGoal
ctx := c.authCtx
loggedIn := c.loggedIn
if goal != nil {
c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, true)
c.logf("[v1] authRoutine: loggedIn=%v; wantLoggedIn=%v", loggedIn, true)
} else {
c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused)
c.logf("[v1] authRoutine: loggedIn=%v; goal=nil paused=%v", loggedIn, c.paused)
}
c.mu.Unlock()
@@ -328,11 +341,6 @@ func (c *Auto) authRoutine() {
c.mu.Lock()
c.urlToVisit = goal.url
if goal.url != "" {
c.state = StateURLVisitRequired
} else {
c.state = StateAuthenticating
}
c.mu.Unlock()
var url string
@@ -366,7 +374,6 @@ func (c *Auto) authRoutine() {
flags: LoginDefault,
url: url,
}
c.state = StateURLVisitRequired
c.mu.Unlock()
c.sendStatus("authRoutine-url", err, url, nil)
@@ -386,7 +393,6 @@ func (c *Auto) authRoutine() {
c.urlToVisit = ""
c.loggedIn = true
c.loginGoal = nil
c.state = StateAuthenticated
c.mu.Unlock()
c.sendStatus("authRoutine-success", nil, "", nil)
@@ -419,6 +425,11 @@ func (c *Auto) unpausedChanLocked() <-chan bool {
return unpaused
}
// ClientID returns the ClientID of the direct controlClient
func (c *Auto) ClientID() int64 {
return c.direct.ClientID()
}
// mapRoutineState is the state of Auto.mapRoutine while it's running.
type mapRoutineState struct {
c *Auto
@@ -431,21 +442,17 @@ func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) {
c := mrs.c
c.mu.Lock()
ctx := c.mapCtx
c.inMapPoll = true
if c.loggedIn {
c.state = StateSynchronized
}
c.expiry = nm.Expiry
c.expiry = nm.SelfKeyExpiry()
stillAuthed := c.loggedIn
c.logf("[v1] mapRoutine: netmap received: %s", c.state)
c.logf("[v1] mapRoutine: netmap received: loggedIn=%v inMapPoll=true", stillAuthed)
c.mu.Unlock()
if stillAuthed {
c.sendStatus("mapRoutine-got-netmap", nil, "", nm)
}
// Reset the backoff timer if we got a netmap.
mrs.bo.BackOff(ctx, nil)
mrs.bo.Reset()
}
func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool {
@@ -486,8 +493,8 @@ func (c *Auto) mapRoutine() {
}
c.mu.Lock()
c.logf("[v1] mapRoutine: %s", c.state)
loggedIn := c.loggedIn
c.logf("[v1] mapRoutine: loggedIn=%v", loggedIn)
ctx := c.mapCtx
c.mu.Unlock()
@@ -518,9 +525,6 @@ func (c *Auto) mapRoutine() {
c.direct.health.SetOutOfPollNetMap()
c.mu.Lock()
c.inMapPoll = false
if c.state == StateSynchronized {
c.state = StateAuthenticated
}
paused := c.paused
c.mu.Unlock()
@@ -586,12 +590,12 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM
c.mu.Unlock()
return
}
state := c.state
loggedIn := c.loggedIn
inMapPoll := c.inMapPoll
loginGoal := c.loginGoal
c.mu.Unlock()
c.logf("[v1] sendStatus: %s: %v", who, state)
c.logf("[v1] sendStatus: %s: loggedIn=%v inMapPoll=%v", who, loggedIn, inMapPoll)
var p persist.PersistView
if nm != nil && loggedIn && inMapPoll {
@@ -602,18 +606,31 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM
nm = nil
}
newSt := &Status{
URL: url,
Persist: p,
NetMap: nm,
Err: err,
state: state,
URL: url,
Persist: p,
NetMap: nm,
Err: err,
LoggedIn: loggedIn && loginGoal == nil,
InMapPoll: inMapPoll,
}
if c.observer == nil {
return
}
c.lastStatus.Store(newSt)
// Launch a new goroutine to avoid blocking the caller while the observer
// does its thing, which may result in a call back into the client.
metricQueued.Add(1)
c.observerQueue.Add(func() {
c.mu.Lock()
closed := c.closed
c.mu.Unlock()
if closed {
return
}
if canSkipStatus(newSt, c.lastStatus.Load()) {
metricSkippable.Add(1)
if !c.direct.controlKnobs.DisableSkipStatusQueue.Load() {
@@ -657,14 +674,15 @@ func canSkipStatus(s1, s2 *Status) bool {
// we can't skip it.
return false
}
if s1.Err != nil || s1.URL != "" {
// If s1 has an error or a URL, we shouldn't skip it, lest the error go
// away in s2 or in-between. We want to make sure all the subsystems see
// it. Plus there aren't many of these, so not worth skipping.
if s1.Err != nil || s1.URL != "" || s1.LoggedIn {
// If s1 has an error, a URL, or LoginFinished set, we shouldn't skip it,
// lest the error go away in s2 or in-between. We want to make sure all
// the subsystems see it. Plus there aren't many of these, so not worth
// skipping.
return false
}
if !s1.Persist.Equals(s2.Persist) || s1.state != s2.state {
// If s1 has a different Persist or state than s2,
if !s1.Persist.Equals(s2.Persist) || s1.LoggedIn != s2.LoggedIn || s1.InMapPoll != s2.InMapPoll || s1.URL != s2.URL {
// If s1 has a different Persist, LoginFinished, Synced, or URL than s2,
// don't skip it. We only care about skipping the typical
// entries where the only difference is the NetMap.
return false
@@ -726,7 +744,6 @@ func (c *Auto) Logout(ctx context.Context) error {
}
c.mu.Lock()
c.loggedIn = false
c.state = StateNotAuthenticated
c.cancelAuthCtxLocked()
c.cancelMapCtxLocked()
c.mu.Unlock()
@@ -750,6 +767,13 @@ func (c *Auto) UpdateEndpoints(endpoints []tailcfg.Endpoint) {
}
}
// SetDiscoPublicKey sets the client's Disco public to key and sends the change
// to the control server.
func (c *Auto) SetDiscoPublicKey(key key.DiscoPublic) {
c.direct.SetDiscoPublicKey(key)
c.updateControl()
}
func (c *Auto) Shutdown() {
c.mu.Lock()
if c.closed {
@@ -774,7 +798,6 @@ func (c *Auto) Shutdown() {
shutdownFn()
}
c.unregisterHealthWatch()
<-c.authDone
<-c.mapDone
<-c.updateDone
@@ -813,13 +836,3 @@ func (c *Auto) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) error {
func (c *Auto) DoNoiseRequest(req *http.Request) (*http.Response, error) {
return c.direct.DoNoiseRequest(req)
}
// GetSingleUseNoiseRoundTripper returns a RoundTripper that can be only be used
// once (and must be used once) to make a single HTTP request over the noise
// channel to the coordination server.
//
// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise
// payload, if any.
func (c *Auto) GetSingleUseNoiseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) {
return c.direct.GetSingleUseNoiseRoundTripper(ctx)
}

View File

@@ -12,6 +12,7 @@ import (
"context"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// LoginFlags is a bitmask of options to change the behavior of Client.Login
@@ -80,7 +81,15 @@ type Client interface {
// TODO: a server-side change would let us simply upload this
// in a separate http request. It has nothing to do with the rest of
// the state machine.
// Note: the auto client uploads the new endpoints to control immediately.
UpdateEndpoints(endpoints []tailcfg.Endpoint)
// SetDiscoPublicKey updates the disco public key that will be sent in
// future map requests. This should be called after rotating the discovery key.
// Note: the auto client uploads the new key to control immediately.
SetDiscoPublicKey(key.DiscoPublic)
// ClientID returns the ClientID of a client. This ID is meant to
// distinguish one client from another.
ClientID() int64
}
// UserVisibleError is an error that should be shown to users.

View File

@@ -4,9 +4,11 @@
package controlclient
import (
"bufio"
"bytes"
"cmp"
"context"
"crypto"
"crypto/sha256"
"encoding/binary"
"encoding/json"
"errors"
@@ -16,19 +18,20 @@ import (
"net"
"net/http"
"net/netip"
"net/url"
"os"
"reflect"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"go4.org/mem"
"tailscale.com/control/controlknobs"
"tailscale.com/control/ts2021"
"tailscale.com/envknob"
"tailscale.com/feature"
"tailscale.com/feature/buildfeatures"
"tailscale.com/health"
"tailscale.com/hostinfo"
"tailscale.com/ipn/ipnstate"
@@ -37,11 +40,11 @@ import (
"tailscale.com/net/dnsfallback"
"tailscale.com/net/netmon"
"tailscale.com/net/netutil"
"tailscale.com/net/netx"
"tailscale.com/net/tlsdial"
"tailscale.com/net/tsdial"
"tailscale.com/net/tshttpproxy"
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tempfork/httprec"
"tailscale.com/tka"
"tailscale.com/tstime"
"tailscale.com/types/key"
@@ -51,62 +54,70 @@ import (
"tailscale.com/types/ptr"
"tailscale.com/types/tkatype"
"tailscale.com/util/clientmetric"
"tailscale.com/util/multierr"
"tailscale.com/util/eventbus"
"tailscale.com/util/singleflight"
"tailscale.com/util/syspolicy"
"tailscale.com/util/systemd"
"tailscale.com/util/syspolicy/pkey"
"tailscale.com/util/syspolicy/policyclient"
"tailscale.com/util/testenv"
"tailscale.com/util/zstdframe"
)
// Direct is the client that connects to a tailcontrol server for a node.
type Direct struct {
httpc *http.Client // HTTP client used to talk to tailcontrol
interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial
dialer *tsdial.Dialer
dnsCache *dnscache.Resolver
controlKnobs *controlknobs.Knobs // always non-nil
serverURL string // URL of the tailcontrol server
clock tstime.Clock
logf logger.Logf
netMon *netmon.Monitor // non-nil
health *health.Tracker
discoPubKey key.DiscoPublic
getMachinePrivKey func() (key.MachinePrivate, error)
debugFlags []string
skipIPForwardingCheck bool
pinger Pinger
popBrowser func(url string) // or nil
c2nHandler http.Handler // or nil
onClientVersion func(*tailcfg.ClientVersion) // or nil
onControlTime func(time.Time) // or nil
onTailnetDefaultAutoUpdate func(bool) // or nil
panicOnUse bool // if true, panic if client is used (for testing)
closedCtx context.Context // alive until Direct.Close is called
closeCtx context.CancelFunc // cancels closedCtx
httpc *http.Client // HTTP client used to do TLS requests to control (just https://controlplane.tailscale.com/key?v=123)
interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial
dialer *tsdial.Dialer
dnsCache *dnscache.Resolver
controlKnobs *controlknobs.Knobs // always non-nil
serverURL string // URL of the tailcontrol server
clock tstime.Clock
logf logger.Logf
netMon *netmon.Monitor // non-nil
health *health.Tracker
busClient *eventbus.Client
clientVersionPub *eventbus.Publisher[tailcfg.ClientVersion]
autoUpdatePub *eventbus.Publisher[AutoUpdate]
controlTimePub *eventbus.Publisher[ControlTime]
getMachinePrivKey func() (key.MachinePrivate, error)
debugFlags []string
skipIPForwardingCheck bool
pinger Pinger
popBrowser func(url string) // or nil
polc policyclient.Client // always non-nil
c2nHandler http.Handler // or nil
panicOnUse bool // if true, panic if client is used (for testing)
closedCtx context.Context // alive until Direct.Close is called
closeCtx context.CancelFunc // cancels closedCtx
dialPlan ControlDialPlanner // can be nil
mu sync.Mutex // mutex guards the following fields
mu syncs.Mutex // mutex guards the following fields
serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now
serverNoiseKey key.MachinePublic
discoPubKey key.DiscoPublic // protected by mu; can be updated via [SetDiscoPublicKey]
sfGroup singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation.
noiseClient *NoiseClient
sfGroup singleflight.Group[struct{}, *ts2021.Client] // protects noiseClient creation.
noiseClient *ts2021.Client // also protected by mu
persist persist.PersistView
authKey string
tryingNewKey key.NodePrivate
expiry time.Time // or zero value if none/unknown
hostinfo *tailcfg.Hostinfo // always non-nil
netinfo *tailcfg.NetInfo
endpoints []tailcfg.Endpoint
tkaHead string
lastPingURL string // last PingRequest.URL received, for dup suppression
persist persist.PersistView
authKey string
tryingNewKey key.NodePrivate
expiry time.Time // or zero value if none/unknown
hostinfo *tailcfg.Hostinfo // always non-nil
netinfo *tailcfg.NetInfo
endpoints []tailcfg.Endpoint
tkaHead string
lastPingURL string // last PingRequest.URL received, for dup suppression
connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest
controlClientID int64 // Random ID used to differentiate clients for consumers of messages.
}
// Observer is implemented by users of the control client (such as LocalBackend)
// to get notified of changes in the control client's status.
//
// If an implementation of Observer also implements [NetmapDeltaUpdater], they get
// delta updates as well as full netmap updates.
type Observer interface {
// SetControlClientStatus is called when the client has a new status to
// report. The Client is provided to allow the Observer to track which
@@ -116,28 +127,36 @@ type Observer interface {
}
type Options struct {
Persist persist.Persist // initial persistent data
GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use
ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration
Clock tstime.Clock
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
DiscoPublicKey key.DiscoPublic
Logf logger.Logf
HTTPTestClient *http.Client // optional HTTP client to use (for tests only)
NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only)
DebugFlags []string // debug settings to send to control
HealthTracker *health.Tracker
PopBrowserURL func(url string) // optional func to open browser
OnClientVersion func(*tailcfg.ClientVersion) // optional func to inform GUI of client version status
OnControlTime func(time.Time) // optional func to notify callers of new time from control
OnTailnetDefaultAutoUpdate func(bool) // optional func to inform GUI of default auto-update setting for the tailnet
Dialer *tsdial.Dialer // non-nil
C2NHandler http.Handler // or nil
ControlKnobs *controlknobs.Knobs // or nil to ignore
Persist persist.Persist // initial persistent data
GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use
ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration
Clock tstime.Clock
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
DiscoPublicKey key.DiscoPublic
PolicyClient policyclient.Client // or nil for none
Logf logger.Logf
HTTPTestClient *http.Client // optional HTTP client to use (for tests only)
NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only)
DebugFlags []string // debug settings to send to control
HealthTracker *health.Tracker
PopBrowserURL func(url string) // optional func to open browser
Dialer *tsdial.Dialer // non-nil
C2NHandler http.Handler // or nil
ControlKnobs *controlknobs.Knobs // or nil to ignore
Bus *eventbus.Bus // non-nil, for setting up publishers
SkipStartForTests bool // if true, don't call [Auto.Start] to avoid any background goroutines (for tests only)
// StartPaused indicates whether the client should start in a paused state
// where it doesn't do network requests. This primarily exists for testing
// but not necessarily "go test" tests, so it isn't restricted to only
// being used in tests.
StartPaused bool
// Observer is called when there's a change in status to report
// from the control client.
// If nil, no status updates are reported.
Observer Observer
// SkipIPForwardingCheck declares that the host's IP
@@ -213,6 +232,8 @@ type NetmapDeltaUpdater interface {
UpdateNetmapDelta([]netmap.NodeMutation) (ok bool)
}
var nextControlClientID atomic.Int64
// NewDirect returns a new Direct client.
func NewDirect(opts Options) (*Direct, error) {
if opts.ServerURL == "" {
@@ -238,10 +259,6 @@ func NewDirect(opts Options) (*Direct, error) {
opts.ControlKnobs = &controlknobs.Knobs{}
}
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
serverURL, err := url.Parse(opts.ServerURL)
if err != nil {
return nil, err
}
if opts.Clock == nil {
opts.Clock = tstime.StdClock{}
}
@@ -269,10 +286,14 @@ func NewDirect(opts Options) (*Direct, error) {
var interceptedDial *atomic.Bool
if httpc == nil {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), opts.HealthTracker, tr.TLSClientConfig)
var dialFunc dialFunc
if buildfeatures.HasUseProxy {
tr.Proxy = feature.HookProxyFromEnvironment.GetOrNil()
if f, ok := feature.HookProxySetTransportGetProxyConnectHeader.GetOk(); ok {
f(tr)
}
}
tr.TLSClientConfig = tlsdial.Config(opts.HealthTracker, tr.TLSClientConfig)
var dialFunc netx.DialFunc
dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial)
tr.DialContext = dnscache.Dialer(dialFunc, dnsCache)
tr.DialTLSContext = dnscache.TLSDialer(dialFunc, dnsCache, tr.TLSClientConfig)
@@ -286,32 +307,32 @@ func NewDirect(opts Options) (*Direct, error) {
}
c := &Direct{
httpc: httpc,
interceptedDial: interceptedDial,
controlKnobs: opts.ControlKnobs,
getMachinePrivKey: opts.GetMachinePrivateKey,
serverURL: opts.ServerURL,
clock: opts.Clock,
logf: opts.Logf,
persist: opts.Persist.View(),
authKey: opts.AuthKey,
discoPubKey: opts.DiscoPublicKey,
debugFlags: opts.DebugFlags,
netMon: netMon,
health: opts.HealthTracker,
skipIPForwardingCheck: opts.SkipIPForwardingCheck,
pinger: opts.Pinger,
popBrowser: opts.PopBrowserURL,
onClientVersion: opts.OnClientVersion,
onTailnetDefaultAutoUpdate: opts.OnTailnetDefaultAutoUpdate,
onControlTime: opts.OnControlTime,
c2nHandler: opts.C2NHandler,
dialer: opts.Dialer,
dnsCache: dnsCache,
dialPlan: opts.DialPlan,
httpc: httpc,
interceptedDial: interceptedDial,
controlKnobs: opts.ControlKnobs,
getMachinePrivKey: opts.GetMachinePrivateKey,
serverURL: opts.ServerURL,
clock: opts.Clock,
logf: opts.Logf,
persist: opts.Persist.View(),
authKey: opts.AuthKey,
debugFlags: opts.DebugFlags,
netMon: netMon,
health: opts.HealthTracker,
skipIPForwardingCheck: opts.SkipIPForwardingCheck,
pinger: opts.Pinger,
polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})),
popBrowser: opts.PopBrowserURL,
c2nHandler: opts.C2NHandler,
dialer: opts.Dialer,
dnsCache: dnsCache,
dialPlan: opts.DialPlan,
}
c.discoPubKey = opts.DiscoPublicKey
c.closedCtx, c.closeCtx = context.WithCancel(context.Background())
c.controlClientID = nextControlClientID.Add(1)
if opts.Hostinfo == nil {
c.SetHostinfo(hostinfo.New())
} else {
@@ -321,7 +342,7 @@ func NewDirect(opts Options) (*Direct, error) {
}
}
if opts.NoiseTestClient != nil {
c.noiseClient = &NoiseClient{
c.noiseClient = &ts2021.Client{
Client: opts.NoiseTestClient,
}
c.serverNoiseKey = key.NewMachine().Public() // prevent early error before hitting test client
@@ -329,6 +350,12 @@ func NewDirect(opts Options) (*Direct, error) {
if strings.Contains(opts.ServerURL, "controlplane.tailscale.com") && envknob.Bool("TS_PANIC_IF_HIT_MAIN_CONTROL") {
c.panicOnUse = true
}
c.busClient = opts.Bus.Client("controlClient.direct")
c.clientVersionPub = eventbus.Publish[tailcfg.ClientVersion](c.busClient)
c.autoUpdatePub = eventbus.Publish[AutoUpdate](c.busClient)
c.controlTimePub = eventbus.Publish[ControlTime](c.busClient)
return c, nil
}
@@ -338,15 +365,14 @@ func (c *Direct) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.busClient.Close()
if c.noiseClient != nil {
if err := c.noiseClient.Close(); err != nil {
return err
}
}
c.noiseClient = nil
if tr, ok := c.httpc.Transport.(*http.Transport); ok {
tr.CloseIdleConnections()
}
c.httpc.CloseIdleConnections()
return nil
}
@@ -387,7 +413,7 @@ func (c *Direct) SetNetInfo(ni *tailcfg.NetInfo) bool {
return true
}
// SetNetInfo stores a new TKA head value for next update.
// SetTKAHead stores a new TKA head value for next update.
// It reports whether the TKA head changed.
func (c *Direct) SetTKAHead(tkaHead string) bool {
c.mu.Lock()
@@ -402,6 +428,14 @@ func (c *Direct) SetTKAHead(tkaHead string) bool {
return true
}
// SetConnectionHandleForTest stores a new MapRequest.ConnectionHandleForTest
// value for the next update.
func (c *Direct) SetConnectionHandleForTest(handle string) {
c.mu.Lock()
defer c.mu.Unlock()
c.connectionHandleForTest = handle
}
func (c *Direct) GetPersist() persist.PersistView {
c.mu.Lock()
defer c.mu.Unlock()
@@ -523,7 +557,9 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
} else {
if expired {
c.logf("Old key expired -> regen=true")
systemd.Status("key expired; run 'tailscale up' to authenticate")
if f, ok := feature.HookSystemdStatus.GetOk(); ok {
f("key expired; run 'tailscale up' to authenticate")
}
regen = true
}
if (opt.Flags & LoginInteractive) != 0 {
@@ -582,6 +618,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
if persist.NetworkLockKey.IsZero() {
persist.NetworkLockKey = key.NewNLPrivate()
}
nlPub := persist.NetworkLockKey.Public()
if tryingNewKey.IsZero() {
@@ -611,7 +648,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
return regen, opt.URL, nil, err
}
tailnet, err := syspolicy.GetString(syspolicy.Tailnet, "")
tailnet, err := c.polc.GetString(pkey.Tailnet, "")
if err != nil {
c.logf("unable to provide Tailnet field in register request. err: %v", err)
}
@@ -641,7 +678,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
AuthKey: authKey,
}
}
err = signRegisterRequest(&request, c.serverURL, c.serverLegacyKey, machinePrivKey.Public())
err = signRegisterRequest(c.polc, &request, c.serverURL, c.serverLegacyKey, machinePrivKey.Public())
if err != nil {
// If signing failed, clear all related fields
request.SignatureType = tailcfg.SignatureNone
@@ -678,8 +715,8 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
if err != nil {
return regen, opt.URL, nil, err
}
addLBHeader(req, request.OldNodeKey)
addLBHeader(req, request.NodeKey)
ts2021.AddLBHeader(req, request.OldNodeKey)
ts2021.AddLBHeader(req, request.NodeKey)
res, err := httpc.Do(req)
if err != nil {
@@ -816,6 +853,31 @@ func (c *Direct) SendUpdate(ctx context.Context) error {
return c.sendMapRequest(ctx, false, nil)
}
// SetDiscoPublicKey updates the disco public key in local state.
// It does not implicitly trigger [SendUpdate]; callers should arrange for that.
func (c *Direct) SetDiscoPublicKey(key key.DiscoPublic) {
c.mu.Lock()
defer c.mu.Unlock()
c.discoPubKey = key
}
// ClientID returns the controlClientID of the controlClient.
func (c *Direct) ClientID() int64 {
return c.controlClientID
}
// AutoUpdate is an eventbus value, reporting the value of tailcfg.MapResponse.DefaultAutoUpdate.
type AutoUpdate struct {
ClientID int64 // The ID field is used for consumers to differentiate instances of Direct.
Value bool // The Value represents DefaultAutoUpdate from [tailcfg.MapResponse].
}
// ControlTime is an eventbus value, reporting the value of tailcfg.MapResponse.ControlTime.
type ControlTime struct {
ClientID int64 // The ID field is used for consumers to differentiate instances of Direct.
Value time.Time // The Value represents ControlTime from [tailcfg.MapResponse].
}
// If we go more than watchdogTimeout without hearing from the server,
// end the long poll. We should be receiving a keep alive ping
// every minute.
@@ -848,8 +910,11 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
persist := c.persist
serverURL := c.serverURL
serverNoiseKey := c.serverNoiseKey
discoKey := c.discoPubKey
hi := c.hostInfoLocked()
backendLogID := hi.BackendLogID
connectionHandleForTest := c.connectionHandleForTest
tkaHead := c.tkaHead
var epStrs []string
var eps []netip.AddrPort
var epTypes []tailcfg.EndpointType
@@ -889,21 +954,44 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
}
nodeKey := persist.PublicNodeKey()
request := &tailcfg.MapRequest{
Version: tailcfg.CurrentCapabilityVersion,
KeepAlive: true,
NodeKey: nodeKey,
DiscoKey: c.discoPubKey,
Endpoints: eps,
EndpointTypes: epTypes,
Stream: isStreaming,
Hostinfo: hi,
DebugFlags: c.debugFlags,
OmitPeers: nu == nil,
TKAHead: c.tkaHead,
Version: tailcfg.CurrentCapabilityVersion,
KeepAlive: true,
NodeKey: nodeKey,
DiscoKey: discoKey,
Endpoints: eps,
EndpointTypes: epTypes,
Stream: isStreaming,
Hostinfo: hi,
DebugFlags: c.debugFlags,
OmitPeers: nu == nil,
TKAHead: tkaHead,
ConnectionHandleForTest: connectionHandleForTest,
}
// If we have a hardware attestation key, sign the node key with it and send
// the key & signature in the map request.
if buildfeatures.HasTPM {
if k := persist.AsStruct().AttestationKey; k != nil && !k.IsZero() {
hwPub := key.HardwareAttestationPublicFromPlatformKey(k)
request.HardwareAttestationKey = hwPub
t := c.clock.Now()
msg := fmt.Sprintf("%d|%s", t.Unix(), nodeKey.String())
digest := sha256.Sum256([]byte(msg))
sig, err := k.Sign(nil, digest[:], crypto.SHA256)
if err != nil {
c.logf("failed to sign node key with hardware attestation key: %v", err)
} else {
request.HardwareAttestationKeySignature = sig
request.HardwareAttestationKeySignatureTimestamp = t
}
}
}
var extraDebugFlags []string
if hi != nil && c.netMon != nil && !c.skipIPForwardingCheck &&
if buildfeatures.HasAdvertiseRoutes && hi != nil && c.netMon != nil && !c.skipIPForwardingCheck &&
ipForwardingBroken(hi.RoutableIPs, c.netMon.InterfaceState()) {
extraDebugFlags = append(extraDebugFlags, "warn-ip-forwarding-off")
}
@@ -967,7 +1055,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
if err != nil {
return err
}
addLBHeader(req, nodeKey)
ts2021.AddLBHeader(req, nodeKey)
res, err := httpc.Do(req)
if err != nil {
@@ -1015,7 +1103,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
c.persist = newPersist.View()
persist = c.persist
}
c.expiry = nm.Expiry
c.expiry = nm.SelfKeyExpiry()
}
// gotNonKeepAliveMessage is whether we've yet received a MapResponse message without
@@ -1047,7 +1135,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond))
var resp tailcfg.MapResponse
if err := c.decodeMsg(msg, &resp); err != nil {
if err := sess.decodeMsg(msg, &resp); err != nil {
vlogf("netmap: decode error: %v", err)
return err
}
@@ -1072,21 +1160,19 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
c.logf("netmap: control says to open URL %v; no popBrowser func", u)
}
}
if resp.ClientVersion != nil && c.onClientVersion != nil {
c.onClientVersion(resp.ClientVersion)
if resp.ClientVersion != nil {
c.clientVersionPub.Publish(*resp.ClientVersion)
}
if resp.ControlTime != nil && !resp.ControlTime.IsZero() {
c.logf.JSON(1, "controltime", resp.ControlTime.UTC())
if c.onControlTime != nil {
c.onControlTime(*resp.ControlTime)
}
c.controlTimePub.Publish(ControlTime{c.controlClientID, *resp.ControlTime})
}
if resp.KeepAlive {
vlogf("netmap: got keep-alive")
} else {
vlogf("netmap: got new map")
}
if resp.ControlDialPlan != nil {
if resp.ControlDialPlan != nil && !ignoreDialPlan() {
if c.dialPlan != nil {
c.logf("netmap: got new dial plan from control")
c.dialPlan.Store(resp.ControlDialPlan)
@@ -1098,11 +1184,21 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
metricMapResponseKeepAlives.Add(1)
continue
}
if au, ok := resp.DefaultAutoUpdate.Get(); ok {
if c.onTailnetDefaultAutoUpdate != nil {
c.onTailnetDefaultAutoUpdate(au)
// DefaultAutoUpdate in its CapMap and deprecated top-level field forms.
if self := resp.Node; self != nil {
for _, v := range self.CapMap[tailcfg.NodeAttrDefaultAutoUpdate] {
switch v {
case "true", "false":
c.autoUpdatePub.Publish(AutoUpdate{c.controlClientID, v == "true"})
default:
c.logf("netmap: [unexpected] unknown %s in CapMap: %q", tailcfg.NodeAttrDefaultAutoUpdate, v)
}
}
}
if au, ok := resp.DeprecatedDefaultAutoUpdate.Get(); ok {
c.autoUpdatePub.Publish(AutoUpdate{c.controlClientID, au})
}
metricMapResponseMap.Add(1)
if gotNonKeepAliveMessage {
@@ -1125,12 +1221,33 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
return nil
}
// NetmapFromMapResponseForDebug returns a NetworkMap from the given MapResponse.
// It is intended for debugging only.
func NetmapFromMapResponseForDebug(ctx context.Context, pr persist.PersistView, resp *tailcfg.MapResponse) (*netmap.NetworkMap, error) {
if resp == nil {
return nil, errors.New("nil MapResponse")
}
if resp.Node == nil {
return nil, errors.New("MapResponse lacks Node")
}
nu := &rememberLastNetmapUpdater{}
sess := newMapSession(pr.PrivateNodeKey(), nu, nil)
defer sess.Close()
if err := sess.HandleNonKeepAliveMapResponse(ctx, resp); err != nil {
return nil, fmt.Errorf("HandleNonKeepAliveMapResponse: %w", err)
}
return sess.netmap(), nil
}
func (c *Direct) handleDebugMessage(ctx context.Context, debug *tailcfg.Debug) error {
if code := debug.Exit; code != nil {
c.logf("exiting process with status %v per controlplane", *code)
os.Exit(*code)
}
if debug.DisableLogTail {
if buildfeatures.HasLogTail && debug.DisableLogTail {
logtail.Disable()
envknob.SetNoLogsNoSupport()
}
@@ -1179,12 +1296,23 @@ func decode(res *http.Response, v any) error {
var jsonEscapedZero = []byte(`\u0000`)
const justKeepAliveStr = `{"KeepAlive":true}`
// decodeMsg is responsible for uncompressing msg and unmarshaling into v.
func (c *Direct) decodeMsg(compressedMsg []byte, v any) error {
func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) error {
// Fast path for common case of keep-alive message.
// See tailscale/tailscale#17343.
if sess.keepAliveZ != nil && bytes.Equal(compressedMsg, sess.keepAliveZ) {
v.KeepAlive = true
return nil
}
b, err := zstdframe.AppendDecode(nil, compressedMsg)
if err != nil {
return err
}
sess.ztdDecodesForTest++
if DevKnob.DumpNetMaps() {
var buf bytes.Buffer
json.Indent(&buf, b, "", " ")
@@ -1197,6 +1325,9 @@ func (c *Direct) decodeMsg(compressedMsg []byte, v any) error {
if err := json.Unmarshal(b, v); err != nil {
return fmt.Errorf("response: %v", err)
}
if v.KeepAlive && string(b) == justKeepAliveStr {
sess.keepAliveZ = compressedMsg
}
return nil
}
@@ -1244,7 +1375,7 @@ func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string
out = tailcfg.OverTLSPublicKeyResponse{}
k, err := key.ParseMachinePublicUntyped(mem.B(b))
if err != nil {
return nil, multierr.New(jsonErr, err)
return nil, errors.Join(jsonErr, err)
}
out.LegacyPublicKey = k
return &out, nil
@@ -1314,6 +1445,10 @@ func (c *Direct) isUniquePingRequest(pr *tailcfg.PingRequest) bool {
return true
}
// HookAnswerC2NPing is where feature/c2n conditionally registers support
// for handling C2N (control-to-node) HTTP requests.
var HookAnswerC2NPing feature.Hook[func(logger.Logf, http.Handler, *http.Client, *tailcfg.PingRequest)]
func (c *Direct) answerPing(pr *tailcfg.PingRequest) {
httpc := c.httpc
useNoise := pr.URLIsNoise || pr.Types == "c2n"
@@ -1334,11 +1469,16 @@ func (c *Direct) answerPing(pr *tailcfg.PingRequest) {
answerHeadPing(c.logf, httpc, pr)
return
case "c2n":
if !buildfeatures.HasC2N {
return
}
if !useNoise && !envknob.Bool("TS_DEBUG_PERMIT_HTTP_C2N") {
c.logf("refusing to answer c2n ping without noise")
return
}
answerC2NPing(c.logf, c.c2nHandler, httpc, pr)
if f, ok := HookAnswerC2NPing.GetOk(); ok {
f(c.logf, c.c2nHandler, httpc, pr)
}
return
}
for _, t := range strings.Split(pr.Types, ",") {
@@ -1373,54 +1513,6 @@ func answerHeadPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest) {
}
}
func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr *tailcfg.PingRequest) {
if c2nHandler == nil {
logf("answerC2NPing: c2nHandler not defined")
return
}
hreq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload)))
if err != nil {
logf("answerC2NPing: ReadRequest: %v", err)
return
}
if pr.Log {
logf("answerC2NPing: got c2n request for %v ...", hreq.RequestURI)
}
handlerTimeout := time.Minute
if v := hreq.Header.Get("C2n-Handler-Timeout"); v != "" {
handlerTimeout, _ = time.ParseDuration(v)
}
handlerCtx, cancel := context.WithTimeout(context.Background(), handlerTimeout)
defer cancel()
hreq = hreq.WithContext(handlerCtx)
rec := httprec.NewRecorder()
c2nHandler.ServeHTTP(rec, hreq)
cancel()
c2nResBuf := new(bytes.Buffer)
rec.Result().Write(c2nResBuf)
replyCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(replyCtx, "POST", pr.URL, c2nResBuf)
if err != nil {
logf("answerC2NPing: NewRequestWithContext: %v", err)
return
}
if pr.Log {
logf("answerC2NPing: sending POST ping to %v ...", pr.URL)
}
t0 := clock.Now()
_, err = c.Do(req)
d := time.Since(t0).Round(time.Millisecond)
if err != nil {
logf("answerC2NPing error: %v to %v (after %v)", err, pr.URL, d)
} else if pr.Log {
logf("answerC2NPing complete to %v (after %v)", pr.URL, d)
}
}
// sleepAsRequest implements the sleep for a tailcfg.Debug message requesting
// that the client sleep. The complication is that while we're sleeping (if for
// a long time), we need to periodically reset the watchdog timer before it
@@ -1445,7 +1537,7 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, d time.Duration, cl
}
// getNoiseClient returns the noise client, creating one if one doesn't exist.
func (c *Direct) getNoiseClient() (*NoiseClient, error) {
func (c *Direct) getNoiseClient() (*ts2021.Client, error) {
c.mu.Lock()
serverNoiseKey := c.serverNoiseKey
nc := c.noiseClient
@@ -1460,13 +1552,13 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) {
if c.dialPlan != nil {
dp = c.dialPlan.Load
}
nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) {
nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*ts2021.Client, error) {
k, err := c.getMachinePrivKey()
if err != nil {
return nil, err
}
c.logf("[v1] creating new noise client")
nc, err := NewNoiseClient(NoiseOpts{
nc, err := ts2021.NewClient(ts2021.ClientOpts{
PrivKey: k,
ServerPubKey: serverNoiseKey,
ServerURL: c.serverURL,
@@ -1500,7 +1592,7 @@ func (c *Direct) setDNSNoise(ctx context.Context, req *tailcfg.SetDNSRequest) er
if err != nil {
return err
}
res, err := nc.post(ctx, "/machine/set-dns", newReq.NodeKey, &newReq)
res, err := nc.Post(ctx, "/machine/set-dns", newReq.NodeKey, &newReq)
if err != nil {
return err
}
@@ -1521,6 +1613,9 @@ func (c *Direct) setDNSNoise(ctx context.Context, req *tailcfg.SetDNSRequest) er
// SetDNS sends the SetDNSRequest request to the control plane server,
// requesting a DNS record be created or updated.
func (c *Direct) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) (err error) {
if !buildfeatures.HasACME {
return feature.ErrUnavailable
}
metricSetDNS.Add(1)
defer func() {
if err != nil {
@@ -1541,20 +1636,6 @@ func (c *Direct) DoNoiseRequest(req *http.Request) (*http.Response, error) {
return nc.Do(req)
}
// GetSingleUseNoiseRoundTripper returns a RoundTripper that can be only be used
// once (and must be used once) to make a single HTTP request over the noise
// channel to the coordination server.
//
// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise
// payload, if any.
func (c *Direct) GetSingleUseNoiseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) {
nc, err := c.getNoiseClient()
if err != nil {
return nil, nil, err
}
return nc.GetSingleUseRoundTripper(ctx)
}
// doPingerPing sends a Ping to pr.IP using pinger, and sends an http request back to
// pr.URL with ping response data.
func doPingerPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest, pinger Pinger, pingType tailcfg.PingType) {
@@ -1611,47 +1692,6 @@ func postPingResult(start time.Time, logf logger.Logf, c *http.Client, pr *tailc
return nil
}
// ReportHealthChange reports to the control plane a change to this node's
// health. w must be non-nil. us can be nil to indicate a healthy state for w.
func (c *Direct) ReportHealthChange(w *health.Warnable, us *health.UnhealthyState) {
if w == health.NetworkStatusWarnable || w == health.IPNStateWarnable || w == health.LoginStateWarnable {
// We don't report these. These include things like the network is down
// (in which case we can't report anyway) or the user wanted things
// stopped, as opposed to the more unexpected failure types in the other
// subsystems.
return
}
np, err := c.getNoiseClient()
if err != nil {
// Don't report errors to control if the server doesn't support noise.
return
}
nodeKey, ok := c.GetPersist().PublicNodeKeyOK()
if !ok {
return
}
if c.panicOnUse {
panic("tainted client")
}
// TODO(angott): at some point, update `Subsys` in the request to be `Warnable`
req := &tailcfg.HealthChangeRequest{
Subsys: string(w.Code),
NodeKey: nodeKey,
}
if us != nil {
req.Error = us.Text
}
// Best effort, no logging:
ctx, cancel := context.WithTimeout(c.closedCtx, 5*time.Second)
defer cancel()
res, err := np.post(ctx, "/machine/update-health", nodeKey, req)
if err != nil {
return
}
res.Body.Close()
}
// SetDeviceAttrs does a synchronous call to the control plane to update
// the node's attributes.
//
@@ -1690,7 +1730,7 @@ func (c *Direct) SetDeviceAttrs(ctx context.Context, attrs tailcfg.AttrUpdate) e
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
res, err := nc.doWithBody(ctx, "PATCH", "/machine/set-device-attr", nodeKey, req)
res, err := nc.DoWithBody(ctx, "PATCH", "/machine/set-device-attr", nodeKey, req)
if err != nil {
return err
}
@@ -1731,7 +1771,7 @@ func (c *Direct) sendAuditLog(ctx context.Context, auditLog tailcfg.AuditLogRequ
panic("tainted client")
}
res, err := nc.post(ctx, "/machine/audit-log", nodeKey, req)
res, err := nc.Post(ctx, "/machine/audit-log", nodeKey, req)
if err != nil {
return fmt.Errorf("%w: %w", errHTTPPostFailure, err)
}
@@ -1743,20 +1783,12 @@ func (c *Direct) sendAuditLog(ctx context.Context, auditLog tailcfg.AuditLogRequ
return nil
}
func addLBHeader(req *http.Request, nodeKey key.NodePublic) {
if !nodeKey.IsZero() {
req.Header.Add(tailcfg.LBHeader, nodeKey.String())
}
}
type dialFunc = func(ctx context.Context, network, addr string) (net.Conn, error)
// makeScreenTimeDetectingDialFunc returns dialFunc, optionally wrapped (on
// Apple systems) with a func that sets the returned atomic.Bool for whether
// Screen Time seemed to intercept the connection.
//
// The returned *atomic.Bool is nil on non-Apple systems.
func makeScreenTimeDetectingDialFunc(dial dialFunc) (dialFunc, *atomic.Bool) {
func makeScreenTimeDetectingDialFunc(dial netx.DialFunc) (netx.DialFunc, *atomic.Bool) {
switch runtime.GOOS {
case "darwin", "ios":
// Continue below.
@@ -1774,6 +1806,13 @@ func makeScreenTimeDetectingDialFunc(dial dialFunc) (dialFunc, *atomic.Bool) {
}, ab
}
func ignoreDialPlan() bool {
// If we're running in v86 (a JavaScript-based emulation of a 32-bit x86)
// our networking is very limited. Let's ignore the dial plan since it's too
// complicated to race that many IPs anyway.
return hostinfo.IsInVM86()
}
func isTCPLoopback(a net.Addr) bool {
if ta, ok := a.(*net.TCPAddr); ok {
return ta.IP.IsLoopback()

View File

@@ -6,7 +6,10 @@ package controlclient
import (
"cmp"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"io"
"maps"
"net"
"reflect"
@@ -19,6 +22,7 @@ import (
"tailscale.com/control/controlknobs"
"tailscale.com/envknob"
"tailscale.com/hostinfo"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/types/key"
@@ -53,6 +57,9 @@ type mapSession struct {
altClock tstime.Clock // if nil, regular time is used
cancel context.CancelFunc // always non-nil, shuts down caller's base long poll context
keepAliveZ []byte // if non-nil, the learned zstd encoding of the just-KeepAlive message for this session
ztdDecodesForTest int // for testing
// sessionAliveCtx is a Background-based context that's alive for the
// duration of the mapSession that we own the lifetime of. It's closed by
// sessionAliveCtxClose.
@@ -86,6 +93,7 @@ type mapSession struct {
lastDomain string
lastDomainAuditLogID string
lastHealth []string
lastDisplayMessages map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage
lastPopBrowserURL string
lastTKAInfo *tailcfg.TKAInfo
lastNetmapSummary string // from NetworkMap.VeryConcise
@@ -308,6 +316,31 @@ func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) {
}
}
// In the copy/v86 wasm environment with limited networking, if the
// control plane didn't pick our DERP home for us, do it ourselves and
// mark all but the lowest region as NoMeasureNoHome. For prod, this
// will be Region 1, NYC, a compromise between the US and Europe. But
// really the control plane should pick this. This is only a fallback.
if hostinfo.IsInVM86() {
numCanMeasure := 0
lowest := 0
for rid, r := range dm.Regions {
if !r.NoMeasureNoHome {
numCanMeasure++
if lowest == 0 || rid < lowest {
lowest = rid
}
}
}
if numCanMeasure > 1 {
for rid, r := range dm.Regions {
if rid != lowest {
r.NoMeasureNoHome = true
}
}
}
}
// Zero-valued fields in a DERPMap mean that we're not changing
// anything and are using the previous value(s).
if ldm := ms.lastDERPMap; ldm != nil {
@@ -383,6 +416,21 @@ func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) {
if resp.Health != nil {
ms.lastHealth = resp.Health
}
if resp.DisplayMessages != nil {
if v, ok := resp.DisplayMessages["*"]; ok && v == nil {
ms.lastDisplayMessages = nil
}
for k, v := range resp.DisplayMessages {
if k == "*" {
continue
}
if v != nil {
mak.Set(&ms.lastDisplayMessages, k, *v)
} else {
delete(ms.lastDisplayMessages, k)
}
}
}
if resp.TKAInfo != nil {
ms.lastTKAInfo = resp.TKAInfo
}
@@ -802,9 +850,23 @@ func (ms *mapSession) sortedPeers() []tailcfg.NodeView {
func (ms *mapSession) netmap() *netmap.NetworkMap {
peerViews := ms.sortedPeers()
var msgs map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage
if len(ms.lastDisplayMessages) != 0 {
msgs = ms.lastDisplayMessages
} else if len(ms.lastHealth) > 0 {
// Convert all ms.lastHealth to the new [netmap.NetworkMap.DisplayMessages]
for _, h := range ms.lastHealth {
id := "health-" + strhash(h) // Unique ID in case there is more than one health message
mak.Set(&msgs, tailcfg.DisplayMessageID(id), tailcfg.DisplayMessage{
Title: "Coordination server reports an issue",
Severity: tailcfg.SeverityMedium,
Text: "The coordination server is reporting a health issue: " + h,
})
}
}
nm := &netmap.NetworkMap{
NodeKey: ms.publicNodeKey,
PrivateKey: ms.privateNodeKey,
MachineKey: ms.machinePubKey,
Peers: peerViews,
UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfileView),
@@ -816,7 +878,7 @@ func (ms *mapSession) netmap() *netmap.NetworkMap {
SSHPolicy: ms.lastSSHPolicy,
CollectServices: ms.collectServices,
DERPMap: ms.lastDERPMap,
ControlHealth: ms.lastHealth,
DisplayMessages: msgs,
TKAEnabled: ms.lastTKAInfo != nil && !ms.lastTKAInfo.Disabled,
}
@@ -829,8 +891,6 @@ func (ms *mapSession) netmap() *netmap.NetworkMap {
if node := ms.lastNode; node.Valid() {
nm.SelfNode = node
nm.Expiry = node.KeyExpiry()
nm.Name = node.Name()
nm.AllCaps = ms.lastCapSet
}
@@ -842,5 +902,12 @@ func (ms *mapSession) netmap() *netmap.NetworkMap {
if DevKnob.ForceProxyDNS() {
nm.DNS.Proxied = true
}
return nm
}
func strhash(h string) string {
s := sha256.New()
io.WriteString(s, h)
return hex.EncodeToString(s.Sum(nil))
}

View File

@@ -1,418 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package controlclient
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"math"
"net/http"
"net/netip"
"net/url"
"sync"
"time"
"golang.org/x/net/http2"
"tailscale.com/control/controlhttp"
"tailscale.com/health"
"tailscale.com/internal/noiseconn"
"tailscale.com/net/dnscache"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/mak"
"tailscale.com/util/multierr"
"tailscale.com/util/singleflight"
)
// NoiseClient provides a http.Client to connect to tailcontrol over
// the ts2021 protocol.
type NoiseClient struct {
// Client is an HTTP client to talk to the coordination server.
// It automatically makes a new Noise connection as needed.
// It does not support node key proofs. To do that, call
// noiseClient.getConn instead to make a connection.
*http.Client
// h2t is the HTTP/2 transport we use a bit to create new
// *http2.ClientConns. We don't use its connection pool and we don't use its
// dialing. We use it for exactly one reason: its idle timeout that can only
// be configured via the HTTP/1 config. And then we call NewClientConn (with
// an existing Noise connection) on the http2.Transport which sets up an
// http2.ClientConn using that idle timeout from an http1.Transport.
h2t *http2.Transport
// sfDial ensures that two concurrent requests for a noise connection only
// produce one shared one between the two callers.
sfDial singleflight.Group[struct{}, *noiseconn.Conn]
dialer *tsdial.Dialer
dnsCache *dnscache.Resolver
privKey key.MachinePrivate
serverPubKey key.MachinePublic
host string // the host part of serverURL
httpPort string // the default port to dial
httpsPort string // the fallback Noise-over-https port or empty if none
// dialPlan optionally returns a ControlDialPlan previously received
// from the control server; either the function or the return value can
// be nil.
dialPlan func() *tailcfg.ControlDialPlan
logf logger.Logf
netMon *netmon.Monitor
health *health.Tracker
// mu only protects the following variables.
mu sync.Mutex
closed bool
last *noiseconn.Conn // or nil
nextID int
connPool map[int]*noiseconn.Conn // active connections not yet closed; see noiseconn.Conn.Close
}
// NoiseOpts contains options for the NewNoiseClient function. All fields are
// required unless otherwise specified.
type NoiseOpts struct {
// PrivKey is this node's private key.
PrivKey key.MachinePrivate
// ServerPubKey is the public key of the server.
ServerPubKey key.MachinePublic
// ServerURL is the URL of the server to connect to.
ServerURL string
// Dialer's SystemDial function is used to connect to the server.
Dialer *tsdial.Dialer
// DNSCache is the caching Resolver to use to connect to the server.
//
// This field can be nil.
DNSCache *dnscache.Resolver
// Logf is the log function to use. This field can be nil.
Logf logger.Logf
// NetMon is the network monitor that, if set, will be used to get the
// network interface state. This field can be nil; if so, the current
// state will be looked up dynamically.
NetMon *netmon.Monitor
// HealthTracker, if non-nil, is the health tracker to use.
HealthTracker *health.Tracker
// DialPlan, if set, is a function that should return an explicit plan
// on how to connect to the server.
DialPlan func() *tailcfg.ControlDialPlan
}
// NewNoiseClient returns a new noiseClient for the provided server and machine key.
// serverURL is of the form https://<host>:<port> (no trailing slash).
//
// netMon may be nil, if non-nil it's used to do faster interface lookups.
// dialPlan may be nil
func NewNoiseClient(opts NoiseOpts) (*NoiseClient, error) {
logf := opts.Logf
u, err := url.Parse(opts.ServerURL)
if err != nil {
return nil, err
}
if u.Scheme != "http" && u.Scheme != "https" {
return nil, errors.New("invalid ServerURL scheme, must be http or https")
}
var httpPort string
var httpsPort string
addr, _ := netip.ParseAddr(u.Hostname())
isPrivateHost := addr.IsPrivate() || addr.IsLoopback() || u.Hostname() == "localhost"
if port := u.Port(); port != "" {
// If there is an explicit port specified, entirely rely on the scheme,
// unless it's http with a private host in which case we never try using HTTPS.
if u.Scheme == "https" {
httpPort = ""
httpsPort = port
} else if u.Scheme == "http" {
httpPort = port
httpsPort = "443"
if isPrivateHost {
logf("setting empty HTTPS port with http scheme and private host %s", u.Hostname())
httpsPort = ""
}
}
} else if u.Scheme == "http" && isPrivateHost {
// Whenever the scheme is http and the hostname is an IP address, do not set the HTTPS port,
// as there cannot be a TLS certificate issued for an IP, unless it's a public IP.
httpPort = "80"
httpsPort = ""
} else {
// Otherwise, use the standard ports
httpPort = "80"
httpsPort = "443"
}
np := &NoiseClient{
serverPubKey: opts.ServerPubKey,
privKey: opts.PrivKey,
host: u.Hostname(),
httpPort: httpPort,
httpsPort: httpsPort,
dialer: opts.Dialer,
dnsCache: opts.DNSCache,
dialPlan: opts.DialPlan,
logf: opts.Logf,
netMon: opts.NetMon,
health: opts.HealthTracker,
}
// Create the HTTP/2 Transport using a net/http.Transport
// (which only does HTTP/1) because it's the only way to
// configure certain properties on the http2.Transport. But we
// never actually use the net/http.Transport for any HTTP/1
// requests.
h2Transport, err := http2.ConfigureTransports(&http.Transport{
IdleConnTimeout: time.Minute,
})
if err != nil {
return nil, err
}
np.h2t = h2Transport
np.Client = &http.Client{Transport: np}
return np, nil
}
// GetSingleUseRoundTripper returns a RoundTripper that can be only be used once
// (and must be used once) to make a single HTTP request over the noise channel
// to the coordination server.
//
// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise
// payload, if any.
func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) {
for tries := 0; tries < 3; tries++ {
conn, err := nc.getConn(ctx)
if err != nil {
return nil, nil, err
}
ok, earlyPayloadMaybeNil, err := conn.ReserveNewRequest(ctx)
if err != nil {
return nil, nil, err
}
if ok {
return conn, earlyPayloadMaybeNil, nil
}
}
return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection")
}
// contextErr is an error that wraps another error and is used to indicate that
// the error was because a context expired.
type contextErr struct {
err error
}
func (e contextErr) Error() string {
return e.err.Error()
}
func (e contextErr) Unwrap() error {
return e.err
}
// getConn returns a noiseconn.Conn that can be used to make requests to the
// coordination server. It may return a cached connection or create a new one.
// Dials are singleflighted, so concurrent calls to getConn may only dial once.
// As such, context values may not be respected as there are no guarantees that
// the context passed to getConn is the same as the context passed to dial.
func (nc *NoiseClient) getConn(ctx context.Context) (*noiseconn.Conn, error) {
nc.mu.Lock()
if last := nc.last; last != nil && last.CanTakeNewRequest() {
nc.mu.Unlock()
return last, nil
}
nc.mu.Unlock()
for {
// We singeflight the dial to avoid making multiple connections, however
// that means that we can't simply cancel the dial if the context is
// canceled. Instead, we have to additionally check that the context
// which was canceled is our context and retry if our context is still
// valid.
conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseconn.Conn, error) {
c, err := nc.dial(ctx)
if err != nil {
if ctx.Err() != nil {
return nil, contextErr{ctx.Err()}
}
return nil, err
}
return c, nil
})
var ce contextErr
if err == nil || !errors.As(err, &ce) {
return conn, err
}
if ctx.Err() == nil {
// The dial failed because of a context error, but our context
// is still valid. Retry.
continue
}
// The dial failed because our context was canceled. Return the
// underlying error.
return nil, ce.Unwrap()
}
}
func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
conn, err := nc.getConn(ctx)
if err != nil {
return nil, err
}
return conn.RoundTrip(req)
}
// connClosed removes the connection with the provided ID from the pool
// of active connections.
func (nc *NoiseClient) connClosed(id int) {
nc.mu.Lock()
defer nc.mu.Unlock()
conn := nc.connPool[id]
if conn != nil {
delete(nc.connPool, id)
if nc.last == conn {
nc.last = nil
}
}
}
// Close closes all the underlying noise connections.
// It is a no-op and returns nil if the connection is already closed.
func (nc *NoiseClient) Close() error {
nc.mu.Lock()
nc.closed = true
conns := nc.connPool
nc.connPool = nil
nc.mu.Unlock()
var errors []error
for _, c := range conns {
if err := c.Close(); err != nil {
errors = append(errors, err)
}
}
return multierr.New(errors...)
}
// dial opens a new connection to tailcontrol, fetching the server noise key
// if not cached.
func (nc *NoiseClient) dial(ctx context.Context) (*noiseconn.Conn, error) {
nc.mu.Lock()
connID := nc.nextID
nc.nextID++
nc.mu.Unlock()
if tailcfg.CurrentCapabilityVersion > math.MaxUint16 {
// Panic, because a test should have started failing several
// thousand version numbers before getting to this point.
panic("capability version is too high to fit in the wire protocol")
}
var dialPlan *tailcfg.ControlDialPlan
if nc.dialPlan != nil {
dialPlan = nc.dialPlan()
}
// If we have a dial plan, then set our timeout as slightly longer than
// the maximum amount of time contained therein; we assume that
// explicit instructions on timeouts are more useful than a single
// hard-coded timeout.
//
// The default value of 5 is chosen so that, when there's no dial plan,
// we retain the previous behaviour of 10 seconds end-to-end timeout.
timeoutSec := 5.0
if dialPlan != nil {
for _, c := range dialPlan.Candidates {
if v := c.DialStartDelaySec + c.DialTimeoutSec; v > timeoutSec {
timeoutSec = v
}
}
}
// After we establish a connection, we need some time to actually
// upgrade it into a Noise connection. With a ballpark worst-case RTT
// of 1000ms, give ourselves an extra 5 seconds to complete the
// handshake.
timeoutSec += 5
// Be extremely defensive and ensure that the timeout is in the range
// [5, 60] seconds (e.g. if we accidentally get a negative number).
if timeoutSec > 60 {
timeoutSec = 60
} else if timeoutSec < 5 {
timeoutSec = 5
}
timeout := time.Duration(timeoutSec * float64(time.Second))
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
clientConn, err := (&controlhttp.Dialer{
Hostname: nc.host,
HTTPPort: nc.httpPort,
HTTPSPort: cmp.Or(nc.httpsPort, controlhttp.NoPort),
MachineKey: nc.privKey,
ControlKey: nc.serverPubKey,
ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion),
Dialer: nc.dialer.SystemDial,
DNSCache: nc.dnsCache,
DialPlan: dialPlan,
Logf: nc.logf,
NetMon: nc.netMon,
HealthTracker: nc.health,
Clock: tstime.StdClock{},
}).Dial(ctx)
if err != nil {
return nil, err
}
ncc, err := noiseconn.New(clientConn.Conn, nc.h2t, connID, nc.connClosed)
if err != nil {
return nil, err
}
nc.mu.Lock()
if nc.closed {
nc.mu.Unlock()
ncc.Close() // Needs to be called without holding the lock.
return nil, errors.New("noise client closed")
}
defer nc.mu.Unlock()
mak.Set(&nc.connPool, connID, ncc)
nc.last = ncc
return ncc, nil
}
// post does a POST to the control server at the given path, JSON-encoding body.
// The provided nodeKey is an optional load balancing hint.
func (nc *NoiseClient) post(ctx context.Context, path string, nodeKey key.NodePublic, body any) (*http.Response, error) {
return nc.doWithBody(ctx, "POST", path, nodeKey, body)
}
func (nc *NoiseClient) doWithBody(ctx context.Context, method, path string, nodeKey key.NodePublic, body any) (*http.Response, error) {
jbody, err := json.Marshal(body)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, method, "https://"+nc.host+path, bytes.NewReader(jbody))
if err != nil {
return nil, err
}
addLBHeader(req, nodeKey)
req.Header.Set("Content-Type", "application/json")
conn, err := nc.getConn(ctx)
if err != nil {
return nil, err
}
return conn.RoundTrip(req)
}

View File

@@ -18,7 +18,8 @@ import (
"github.com/tailscale/certstore"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/util/syspolicy"
"tailscale.com/util/syspolicy/pkey"
"tailscale.com/util/syspolicy/policyclient"
)
// getMachineCertificateSubject returns the exact name of a Subject that needs
@@ -30,8 +31,8 @@ import (
// each RegisterRequest will be unsigned.
//
// Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA"
func getMachineCertificateSubject() string {
machineCertSubject, _ := syspolicy.GetString(syspolicy.MachineCertificateSubject, "")
func getMachineCertificateSubject(polc policyclient.Client) string {
machineCertSubject, _ := polc.GetString(pkey.MachineCertificateSubject, "")
return machineCertSubject
}
@@ -136,7 +137,7 @@ func findIdentity(subject string, st certstore.Store) (certstore.Identity, []*x5
// using that identity's public key. In addition to the signature, the full
// certificate chain is included so that the control server can validate the
// certificate from a copy of the root CA's certificate.
func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) (err error) {
func signRegisterRequest(polc policyclient.Client, req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("signRegisterRequest: %w", err)
@@ -147,7 +148,7 @@ func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverP
return errBadRequest
}
machineCertificateSubject := getMachineCertificateSubject()
machineCertificateSubject := getMachineCertificateSubject(polc)
if machineCertificateSubject == "" {
return errCertificateNotConfigured
}

View File

@@ -8,9 +8,10 @@ package controlclient
import (
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/util/syspolicy/policyclient"
)
// signRegisterRequest on non-supported platforms always returns errNoCertStore.
func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error {
func signRegisterRequest(polc policyclient.Client, req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error {
return errNoCertStore
}

View File

@@ -4,8 +4,6 @@
package controlclient
import (
"encoding/json"
"fmt"
"reflect"
"tailscale.com/types/netmap"
@@ -13,57 +11,6 @@ import (
"tailscale.com/types/structs"
)
// State is the high-level state of the client. It is used only in
// unit tests for proper sequencing, don't depend on it anywhere else.
//
// TODO(apenwarr): eliminate the state, as it's now obsolete.
//
// apenwarr: Historical note: controlclient.Auto was originally
// intended to be the state machine for the whole tailscale client, but that
// turned out to not be the right abstraction layer, and it moved to
// ipn.Backend. Since ipn.Backend now has a state machine, it would be
// much better if controlclient could be a simple stateless API. But the
// current server-side API (two interlocking polling https calls) makes that
// very hard to implement. A server side API change could untangle this and
// remove all the statefulness.
type State int
const (
StateNew = State(iota)
StateNotAuthenticated
StateAuthenticating
StateURLVisitRequired
StateAuthenticated
StateSynchronized // connected and received map update
)
func (s State) AppendText(b []byte) ([]byte, error) {
return append(b, s.String()...), nil
}
func (s State) MarshalText() ([]byte, error) {
return []byte(s.String()), nil
}
func (s State) String() string {
switch s {
case StateNew:
return "state:new"
case StateNotAuthenticated:
return "state:not-authenticated"
case StateAuthenticating:
return "state:authenticating"
case StateURLVisitRequired:
return "state:url-visit-required"
case StateAuthenticated:
return "state:authenticated"
case StateSynchronized:
return "state:synchronized"
default:
return fmt.Sprintf("state:unknown:%d", int(s))
}
}
type Status struct {
_ structs.Incomparable
@@ -76,6 +23,14 @@ type Status struct {
// URL, if non-empty, is the interactive URL to visit to finish logging in.
URL string
// LoggedIn, if true, indicates that serveRegister has completed and no
// other login change is in progress.
LoggedIn bool
// InMapPoll, if true, indicates that we've received at least one netmap
// and are connected to receive updates.
InMapPoll bool
// NetMap is the latest server-pushed state of the tailnet network.
NetMap *netmap.NetworkMap
@@ -83,26 +38,8 @@ type Status struct {
//
// TODO(bradfitz,maisem): clarify this.
Persist persist.PersistView
// state is the internal state. It should not be exposed outside this
// package, but we have some automated tests elsewhere that need to
// use it via the StateForTest accessor.
// TODO(apenwarr): Unexport or remove these.
state State
}
// LoginFinished reports whether the controlclient is in its "StateAuthenticated"
// state where it's in a happy register state but not yet in a map poll.
//
// TODO(bradfitz): delete this and everything around Status.state.
func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated }
// StateForTest returns the internal state of s for tests only.
func (s *Status) StateForTest() State { return s.state }
// SetStateForTest sets the internal state of s for tests only.
func (s *Status) SetStateForTest(state State) { s.state = state }
// Equal reports whether s and s2 are equal.
func (s *Status) Equal(s2 *Status) bool {
if s == nil && s2 == nil {
@@ -111,15 +48,8 @@ func (s *Status) Equal(s2 *Status) bool {
return s != nil && s2 != nil &&
s.Err == s2.Err &&
s.URL == s2.URL &&
s.state == s2.state &&
s.LoggedIn == s2.LoggedIn &&
s.InMapPoll == s2.InMapPoll &&
reflect.DeepEqual(s.Persist, s2.Persist) &&
reflect.DeepEqual(s.NetMap, s2.NetMap)
}
func (s Status) String() string {
b, err := json.MarshalIndent(s, "", "\t")
if err != nil {
panic(err)
}
return s.state.String() + " " + string(b)
}

View File

@@ -20,37 +20,37 @@
package controlhttp
import (
"cmp"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"math"
"net"
"net/http"
"net/http/httptrace"
"net/netip"
"net/url"
"runtime"
"sort"
"sync/atomic"
"time"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpcommon"
"tailscale.com/envknob"
"tailscale.com/feature"
"tailscale.com/feature/buildfeatures"
"tailscale.com/health"
"tailscale.com/net/dnscache"
"tailscale.com/net/dnsfallback"
"tailscale.com/net/netutil"
"tailscale.com/net/netx"
"tailscale.com/net/sockstats"
"tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy"
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/util/multierr"
)
var stdDialer net.Dialer
@@ -81,7 +81,7 @@ func (a *Dialer) getProxyFunc() func(*http.Request) (*url.URL, error) {
if a.proxyFunc != nil {
return a.proxyFunc
}
return tshttpproxy.ProxyFromEnvironment
return feature.HookProxyFromEnvironment.GetOrNil()
}
// httpsFallbackDelay is how long we'll wait for a.HTTPPort to work before
@@ -103,157 +103,71 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
// host we know about.
useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 {
return a.dialHost(ctx, netip.Addr{})
return a.dialHost(ctx)
}
candidates := a.DialPlan.Candidates
// Otherwise, we try dialing per the plan. Store the highest priority
// in the list, so that if we get a connection to one of those
// candidates we can return quickly.
var highestPriority int = math.MinInt
for _, c := range candidates {
if c.Priority > highestPriority {
highestPriority = c.Priority
}
}
// This context allows us to cancel in-flight connections if we get a
// highest-priority connection before we're all done.
// Create a context to be canceled as we return, so once we get a good connection,
// we can drop all the other ones.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Now, for each candidate, kick off a dial in parallel.
type dialResult struct {
conn *ClientConn
err error
addr netip.Addr
priority int
}
resultsCh := make(chan dialResult, len(candidates))
var pending atomic.Int32
pending.Store(int32(len(candidates)))
for _, c := range candidates {
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
var (
conn *ClientConn
err error
)
// Always send results back to our channel.
defer func() {
resultsCh <- dialResult{conn, err, c.IP, c.Priority}
if pending.Add(-1) == 0 {
close(resultsCh)
}
}()
// If non-zero, wait the configured start timeout
// before we do anything.
if c.DialStartDelaySec > 0 {
a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
tmr, tmrChannel := a.clock().NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
defer tmr.Stop()
select {
case <-ctx.Done():
err = ctx.Err()
return
case <-tmrChannel:
}
}
// Now, create a sub-context with the given timeout and
// try dialing the provided host.
ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
defer cancel()
// This will dial, and the defer above sends it back to our parent.
a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
conn, err = a.dialHost(ctx, c.IP)
}(ctx, c)
}
var results []dialResult
for res := range resultsCh {
// If we get a response that has the highest priority, we don't
// need to wait for any of the other connections to finish; we
// can just return this connection.
//
// TODO(andrew): we could make this better by keeping track of
// the highest remaining priority dynamically, instead of just
// checking for the highest total
if res.priority == highestPriority && res.conn != nil {
a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr)
// Drain the channel and any existing connections in
// the background.
go func() {
for _, res := range results {
if res.conn != nil {
res.conn.Close()
}
}
for res := range resultsCh {
if res.conn != nil {
res.conn.Close()
}
}
if a.drainFinished != nil {
close(a.drainFinished)
}
}()
return res.conn, nil
}
// This isn't a highest-priority result, so just store it until
// we're done.
results = append(results, res)
}
// After we finish this function, close any remaining open connections.
defer func() {
for _, result := range results {
// Note: below, we nil out the returned connection (if
// any) in the slice so we don't close it.
if result.conn != nil {
result.conn.Close()
}
}
// We don't drain asynchronously after this point, so notify our
// channel when we return.
if a.drainFinished != nil {
close(a.drainFinished)
}
}()
// Sort by priority, then take the first non-error response.
sort.Slice(results, func(i, j int) bool {
// NOTE: intentionally inverted so that the highest priority
// item comes first
return results[i].priority > results[j].priority
})
var (
conn *ClientConn
errs []error
)
for i, result := range results {
if result.err != nil {
errs = append(errs, result.err)
continue
err error
}
resultsCh := make(chan dialResult) // unbuffered, never closed
dialCand := func(cand tailcfg.ControlIPCandidate) (*ClientConn, error) {
if cand.ACEHost != "" {
a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q via ACE %s (%s)", cand.DialStartDelaySec, a.Hostname, cand.ACEHost, cmp.Or(cand.IP.String(), "dns"))
} else {
a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q @ %s", cand.DialStartDelaySec, a.Hostname, cand.IP.String())
}
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr)
conn = result.conn
results[i].conn = nil // so we don't close it in the defer
return conn, nil
ctx, cancel := context.WithTimeout(ctx, time.Duration(cand.DialTimeoutSec*float64(time.Second)))
defer cancel()
return a.dialHostOpt(ctx, cand.IP, cand.ACEHost)
}
merr := multierr.New(errs...)
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
return a.dialHost(ctx, netip.Addr{})
for _, cand := range candidates {
timer := time.AfterFunc(time.Duration(cand.DialStartDelaySec*float64(time.Second)), func() {
go func() {
conn, err := dialCand(cand)
select {
case resultsCh <- dialResult{conn, err}:
if err == nil {
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(cand.ACEHost, cand.IP.String()))
}
case <-ctx.Done():
if conn != nil {
conn.Close()
}
}
}()
})
defer timer.Stop()
}
var errs []error
for {
select {
case res := <-resultsCh:
if res.err == nil {
return res.conn, nil
}
errs = append(errs, res.err)
if len(errs) == len(candidates) {
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", errors.Join(errs...))
return a.dialHost(ctx)
}
case <-ctx.Done():
a.logf("controlhttp: context aborted dialing")
return nil, ctx.Err()
}
}
}
// The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to
@@ -270,6 +184,15 @@ var forceNoise443 = envknob.RegisterBool("TS_FORCE_NOISE_443")
// use HTTPS connections as its underlay connection (double crypto). This can
// be necessary when networks or middle boxes are messing with port 80.
func (d *Dialer) forceNoise443() bool {
if runtime.GOOS == "plan9" {
// For running demos of Plan 9 in a browser with network relays,
// we want to minimize the number of connections we're making.
// The main reason to use port 80 is to avoid double crypto
// costs server-side but the costs are tiny and number of Plan 9
// users doesn't make it worth it. Just disable this and always use
// HTTPS for Plan 9. That also reduces some log spam.
return true
}
if forceNoise443() {
return true
}
@@ -301,10 +224,19 @@ var debugNoiseDial = envknob.RegisterBool("TS_DEBUG_NOISE_DIAL")
// dialHost connects to the configured Dialer.Hostname and upgrades the
// connection into a controlbase.Conn.
func (a *Dialer) dialHost(ctx context.Context) (*ClientConn, error) {
return a.dialHostOpt(ctx,
netip.Addr{}, // no pre-resolved IP
"", // don't use ACE
)
}
// dialHostOpt connects to the configured Dialer.Hostname and upgrades the
// connection into a controlbase.Conn.
//
// If optAddr is valid, then no DNS is used and the connection will be made to the
// provided address.
func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn, error) {
func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost string) (*ClientConn, error) {
// Create one shared context used by both port 80 and port 443 dials.
// If port 80 is still in flight when 443 returns, this deferred cancel
// will stop the port 80 dial.
@@ -326,7 +258,7 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn,
Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")),
Path: serverUpgradePath,
}
if a.HTTPSPort == NoPort {
if a.HTTPSPort == NoPort || optACEHost != "" {
u443 = nil
}
@@ -338,11 +270,11 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn,
ch := make(chan tryURLRes) // must be unbuffered
try := func(u *url.URL) {
if debugNoiseDial() {
a.logf("trying noise dial (%v, %v) ...", u, optAddr)
a.logf("trying noise dial (%v, %v) ...", u, cmp.Or(optACEHost, optAddr.String()))
}
cbConn, err := a.dialURL(ctx, u, optAddr)
cbConn, err := a.dialURL(ctx, u, optAddr, optACEHost)
if debugNoiseDial() {
a.logf("noise dial (%v, %v) = (%v, %v)", u, optAddr, cbConn, err)
a.logf("noise dial (%v, %v) = (%v, %v)", u, cmp.Or(optACEHost, optAddr.String()), cbConn, err)
}
select {
case ch <- tryURLRes{u, cbConn, err}:
@@ -373,6 +305,9 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn,
}
var err80, err443 error
if forceTLS {
err80 = errors.New("TLS forced: no port 80 dialed")
}
for {
select {
case <-ctx.Done():
@@ -408,12 +343,12 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn,
//
// If optAddr is valid, then no DNS is used and the connection will be made to the
// provided address.
func (a *Dialer) dialURL(ctx context.Context, u *url.URL, optAddr netip.Addr) (*ClientConn, error) {
func (a *Dialer) dialURL(ctx context.Context, u *url.URL, optAddr netip.Addr, optACEHost string) (*ClientConn, error) {
init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
if err != nil {
return nil, err
}
netConn, err := a.tryURLUpgrade(ctx, u, optAddr, init)
netConn, err := a.tryURLUpgrade(ctx, u, optAddr, optACEHost, init)
if err != nil {
return nil, err
}
@@ -459,13 +394,15 @@ var macOSScreenTime = health.Register(&health.Warnable{
ImpactsConnectivity: true,
})
var HookMakeACEDialer feature.Hook[func(dialer netx.DialFunc, aceHost string, optIP netip.Addr) netx.DialFunc]
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
//
// If optAddr is valid, then no DNS is used and the connection will be made to
// the provided address.
//
// Only the provided ctx is used, not a.ctx.
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Addr, init []byte) (_ net.Conn, retErr error) {
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Addr, optACEHost string, init []byte) (_ net.Conn, retErr error) {
var dns *dnscache.Resolver
// If we were provided an address to dial, then create a resolver that just
@@ -480,13 +417,24 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad
dns = a.resolver()
}
var dialer dnscache.DialContextFunc
var dialer netx.DialFunc
if a.Dialer != nil {
dialer = a.Dialer
} else {
dialer = stdDialer.DialContext
}
if optACEHost != "" {
if !buildfeatures.HasACE {
return nil, feature.ErrUnavailable
}
f, ok := HookMakeACEDialer.GetOk()
if !ok {
return nil, feature.ErrUnavailable
}
dialer = f(dialer, optACEHost, optAddr)
}
// On macOS, see if Screen Time is blocking things.
if runtime.GOOS == "darwin" {
var proxydIntercepted atomic.Bool // intercepted by macOS webfilterproxyd
@@ -513,13 +461,25 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad
tr := http.DefaultTransport.(*http.Transport).Clone()
defer tr.CloseIdleConnections()
tr.Proxy = a.getProxyFunc()
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
tr.DialContext = dnscache.Dialer(dialer, dns)
if optACEHost != "" {
// If using ACE, we don't want to use any HTTP proxy.
// ACE is already a tunnel+proxy.
// TODO(tailscale/corp#32483): use system proxy too?
tr.Proxy = nil
tr.DialContext = dialer
} else {
if buildfeatures.HasUseProxy {
tr.Proxy = a.getProxyFunc()
if set, ok := feature.HookProxySetTransportGetProxyConnectHeader.GetOk(); ok {
set(tr)
}
}
tr.DialContext = dnscache.Dialer(dialer, dns)
}
// Disable HTTP2, since h2 can't do protocol switching.
tr.TLSClientConfig.NextProtos = []string{}
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
tr.TLSClientConfig = tlsdial.Config(a.Hostname, a.HealthTracker, tr.TLSClientConfig)
tr.TLSClientConfig = tlsdial.Config(a.HealthTracker, tr.TLSClientConfig)
if !tr.TLSClientConfig.InsecureSkipVerify {
panic("unexpected") // should be set by tlsdial.Config
}

View File

@@ -12,6 +12,7 @@ import (
"tailscale.com/health"
"tailscale.com/net/dnscache"
"tailscale.com/net/netmon"
"tailscale.com/net/netx"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/types/key"
@@ -66,7 +67,7 @@ type Dialer struct {
// Dialer is the dialer used to make outbound connections.
//
// If not specified, this defaults to net.Dialer.DialContext.
Dialer dnscache.DialContextFunc
Dialer netx.DialFunc
// DNSCache is the caching Resolver used by this Dialer.
//
@@ -77,8 +78,8 @@ type Dialer struct {
// dropped.
Logf logger.Logf
// NetMon is the [netmon.Monitor] to use for this Dialer. It must be
// non-nil.
// NetMon is the [netmon.Monitor] to use for this Dialer.
// It is optional.
NetMon *netmon.Monitor
// HealthTracker, if non-nil, is the health tracker to use.
@@ -97,7 +98,6 @@ type Dialer struct {
logPort80Failure atomic.Bool
// For tests only
drainFinished chan struct{}
omitCertErrorLogging bool
testFallbackDelay time.Duration

View File

@@ -62,8 +62,9 @@ type Knobs struct {
// netfiltering, unless overridden by the user.
LinuxForceNfTables atomic.Bool
// SeamlessKeyRenewal is whether to enable the alpha functionality of
// renewing node keys without breaking connections.
// SeamlessKeyRenewal is whether to renew node keys without breaking connections.
// This is enabled by default in 1.90 and later, but we but we can remotely disable
// it from the control plane if there's a problem.
// http://go/seamless-key-renewal
SeamlessKeyRenewal atomic.Bool
@@ -98,10 +99,6 @@ type Knobs struct {
// allows us to disable the new behavior remotely if needed.
DisableLocalDNSOverrideViaNRPT atomic.Bool
// DisableCryptorouting indicates that the node should not use the
// magicsock crypto routing feature.
DisableCryptorouting atomic.Bool
// DisableCaptivePortalDetection is whether the node should not perform captive portal detection
// automatically when the network state changes.
DisableCaptivePortalDetection atomic.Bool
@@ -132,12 +129,12 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) {
forceIPTables = has(tailcfg.NodeAttrLinuxMustUseIPTables)
forceNfTables = has(tailcfg.NodeAttrLinuxMustUseNfTables)
seamlessKeyRenewal = has(tailcfg.NodeAttrSeamlessKeyRenewal)
disableSeamlessKeyRenewal = has(tailcfg.NodeAttrDisableSeamlessKeyRenewal)
probeUDPLifetime = has(tailcfg.NodeAttrProbeUDPLifetime)
appCStoreRoutes = has(tailcfg.NodeAttrStoreAppCRoutes)
userDialUseRoutes = has(tailcfg.NodeAttrUserDialUseRoutes)
disableSplitDNSWhenNoCustomResolvers = has(tailcfg.NodeAttrDisableSplitDNSWhenNoCustomResolvers)
disableLocalDNSOverrideViaNRPT = has(tailcfg.NodeAttrDisableLocalDNSOverrideViaNRPT)
disableCryptorouting = has(tailcfg.NodeAttrDisableMagicSockCryptoRouting)
disableCaptivePortalDetection = has(tailcfg.NodeAttrDisableCaptivePortalDetection)
disableSkipStatusQueue = has(tailcfg.NodeAttrDisableSkipStatusQueue)
)
@@ -159,15 +156,28 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) {
k.SilentDisco.Store(silentDisco)
k.LinuxForceIPTables.Store(forceIPTables)
k.LinuxForceNfTables.Store(forceNfTables)
k.SeamlessKeyRenewal.Store(seamlessKeyRenewal)
k.ProbeUDPLifetime.Store(probeUDPLifetime)
k.AppCStoreRoutes.Store(appCStoreRoutes)
k.UserDialUseRoutes.Store(userDialUseRoutes)
k.DisableSplitDNSWhenNoCustomResolvers.Store(disableSplitDNSWhenNoCustomResolvers)
k.DisableLocalDNSOverrideViaNRPT.Store(disableLocalDNSOverrideViaNRPT)
k.DisableCryptorouting.Store(disableCryptorouting)
k.DisableCaptivePortalDetection.Store(disableCaptivePortalDetection)
k.DisableSkipStatusQueue.Store(disableSkipStatusQueue)
// If both attributes are present, then "enable" should win. This reflects
// the history of seamless key renewal.
//
// Before 1.90, seamless was a private alpha, opt-in feature. Devices would
// only seamless do if customers opted in using the seamless renewal attr.
//
// In 1.90 and later, seamless is the default behaviour, and devices will use
// seamless unless explicitly told not to by control (e.g. if we discover
// a bug and want clients to use the prior behaviour).
//
// If a customer has opted in to the pre-1.90 seamless implementation, we
// don't want to switch it off for them -- we only want to switch it off for
// devices that haven't opted in.
k.SeamlessKeyRenewal.Store(seamlessKeyRenewal || !disableSeamlessKeyRenewal)
}
// AsDebugJSON returns k as something that can be marshalled with json.Marshal

312
vendor/tailscale.com/control/ts2021/client.go generated vendored Normal file
View File

@@ -0,0 +1,312 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ts2021
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"math"
"net"
"net/http"
"net/netip"
"net/url"
"sync"
"time"
"tailscale.com/control/controlhttp"
"tailscale.com/health"
"tailscale.com/net/dnscache"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/mak"
"tailscale.com/util/set"
)
// Client provides a http.Client to connect to tailcontrol over
// the ts2021 protocol.
type Client struct {
// Client is an HTTP client to talk to the coordination server.
// It automatically makes a new Noise connection as needed.
*http.Client
logf logger.Logf // non-nil
opts ClientOpts
host string // the host part of serverURL
httpPort string // the default port to dial
httpsPort string // the fallback Noise-over-https port or empty if none
// mu protects the following
mu sync.Mutex
closed bool
connPool set.HandleSet[*Conn] // all live connections
}
// ClientOpts contains options for the [NewClient] function. All fields are
// required unless otherwise specified.
type ClientOpts struct {
// ServerURL is the URL of the server to connect to.
ServerURL string
// PrivKey is this node's private key.
PrivKey key.MachinePrivate
// ServerPubKey is the public key of the server.
// It is of the form https://<host>:<port> (no trailing slash).
ServerPubKey key.MachinePublic
// Dialer's SystemDial function is used to connect to the server.
Dialer *tsdial.Dialer
// Optional fields follow
// Logf is the log function to use.
// If nil, log.Printf is used.
Logf logger.Logf
// NetMon is the network monitor that will be used to get the
// network interface state. This field can be nil; if so, the current
// state will be looked up dynamically.
NetMon *netmon.Monitor
// DNSCache is the caching Resolver to use to connect to the server.
//
// This field can be nil.
DNSCache *dnscache.Resolver
// HealthTracker, if non-nil, is the health tracker to use.
HealthTracker *health.Tracker
// DialPlan, if set, is a function that should return an explicit plan
// on how to connect to the server.
DialPlan func() *tailcfg.ControlDialPlan
// ProtocolVersion, if non-zero, specifies an alternate
// protocol version to use instead of the default,
// of [tailcfg.CurrentCapabilityVersion].
ProtocolVersion uint16
}
// NewClient returns a new noiseClient for the provided server and machine key.
//
// netMon may be nil, if non-nil it's used to do faster interface lookups.
// dialPlan may be nil
func NewClient(opts ClientOpts) (*Client, error) {
logf := opts.Logf
if logf == nil {
logf = log.Printf
}
if opts.ServerURL == "" {
return nil, errors.New("ServerURL is required")
}
if opts.PrivKey.IsZero() {
return nil, errors.New("PrivKey is required")
}
if opts.ServerPubKey.IsZero() {
return nil, errors.New("ServerPubKey is required")
}
if opts.Dialer == nil {
return nil, errors.New("Dialer is required")
}
u, err := url.Parse(opts.ServerURL)
if err != nil {
return nil, fmt.Errorf("invalid ClientOpts.ServerURL: %w", err)
}
if u.Scheme != "http" && u.Scheme != "https" {
return nil, errors.New("invalid ServerURL scheme, must be http or https")
}
httpPort, httpsPort := "80", "443"
addr, _ := netip.ParseAddr(u.Hostname())
isPrivateHost := addr.IsPrivate() || addr.IsLoopback() || u.Hostname() == "localhost"
if port := u.Port(); port != "" {
// If there is an explicit port specified, entirely rely on the scheme,
// unless it's http with a private host in which case we never try using HTTPS.
if u.Scheme == "https" {
httpPort = ""
httpsPort = port
} else if u.Scheme == "http" {
httpPort = port
httpsPort = "443"
if isPrivateHost {
logf("setting empty HTTPS port with http scheme and private host %s", u.Hostname())
httpsPort = ""
}
}
} else if u.Scheme == "http" && isPrivateHost {
// Whenever the scheme is http and the hostname is an IP address, do not set the HTTPS port,
// as there cannot be a TLS certificate issued for an IP, unless it's a public IP.
httpPort = "80"
httpsPort = ""
}
np := &Client{
opts: opts,
host: u.Hostname(),
httpPort: httpPort,
httpsPort: httpsPort,
logf: logf,
}
tr := &http.Transport{
Protocols: new(http.Protocols),
MaxConnsPerHost: 1,
}
// We force only HTTP/2 for this transport, which is what the control server
// speaks inside the ts2021 Noise encryption. But Go doesn't know about that,
// so we use "SetUnencryptedHTTP2" even though it's actually encrypted.
tr.Protocols.SetUnencryptedHTTP2(true)
tr.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return np.dial(ctx)
}
np.Client = &http.Client{Transport: tr}
return np, nil
}
// Close closes all the underlying noise connections.
// It is a no-op and returns nil if the connection is already closed.
func (nc *Client) Close() error {
nc.mu.Lock()
live := nc.connPool
nc.closed = true
nc.connPool = nil // stop noteConnClosed from mutating it as we loop over it (in live) below
nc.mu.Unlock()
for _, c := range live {
c.Close()
}
nc.Client.CloseIdleConnections()
return nil
}
// dial opens a new connection to tailcontrol, fetching the server noise key
// if not cached.
func (nc *Client) dial(ctx context.Context) (*Conn, error) {
if tailcfg.CurrentCapabilityVersion > math.MaxUint16 {
// Panic, because a test should have started failing several
// thousand version numbers before getting to this point.
panic("capability version is too high to fit in the wire protocol")
}
var dialPlan *tailcfg.ControlDialPlan
if nc.opts.DialPlan != nil {
dialPlan = nc.opts.DialPlan()
}
// If we have a dial plan, then set our timeout as slightly longer than
// the maximum amount of time contained therein; we assume that
// explicit instructions on timeouts are more useful than a single
// hard-coded timeout.
//
// The default value of 5 is chosen so that, when there's no dial plan,
// we retain the previous behaviour of 10 seconds end-to-end timeout.
timeoutSec := 5.0
if dialPlan != nil {
for _, c := range dialPlan.Candidates {
if v := c.DialStartDelaySec + c.DialTimeoutSec; v > timeoutSec {
timeoutSec = v
}
}
}
// After we establish a connection, we need some time to actually
// upgrade it into a Noise connection. With a ballpark worst-case RTT
// of 1000ms, give ourselves an extra 5 seconds to complete the
// handshake.
timeoutSec += 5
// Be extremely defensive and ensure that the timeout is in the range
// [5, 60] seconds (e.g. if we accidentally get a negative number).
if timeoutSec > 60 {
timeoutSec = 60
} else if timeoutSec < 5 {
timeoutSec = 5
}
timeout := time.Duration(timeoutSec * float64(time.Second))
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
chd := &controlhttp.Dialer{
Hostname: nc.host,
HTTPPort: nc.httpPort,
HTTPSPort: cmp.Or(nc.httpsPort, controlhttp.NoPort),
MachineKey: nc.opts.PrivKey,
ControlKey: nc.opts.ServerPubKey,
ProtocolVersion: cmp.Or(nc.opts.ProtocolVersion, uint16(tailcfg.CurrentCapabilityVersion)),
Dialer: nc.opts.Dialer.SystemDial,
DNSCache: nc.opts.DNSCache,
DialPlan: dialPlan,
Logf: nc.logf,
NetMon: nc.opts.NetMon,
HealthTracker: nc.opts.HealthTracker,
Clock: tstime.StdClock{},
}
clientConn, err := chd.Dial(ctx)
if err != nil {
return nil, err
}
nc.mu.Lock()
handle := set.NewHandle()
ncc := NewConn(clientConn.Conn, func() { nc.noteConnClosed(handle) })
mak.Set(&nc.connPool, handle, ncc)
if nc.closed {
nc.mu.Unlock()
ncc.Close() // Needs to be called without holding the lock.
return nil, errors.New("noise client closed")
}
defer nc.mu.Unlock()
return ncc, nil
}
// noteConnClosed notes that the *Conn with the given handle has closed and
// should be removed from the live connPool (which is usually of size 0 or 1,
// except perhaps briefly 2 during a network failure and reconnect).
func (nc *Client) noteConnClosed(handle set.Handle) {
nc.mu.Lock()
defer nc.mu.Unlock()
nc.connPool.Delete(handle)
}
// post does a POST to the control server at the given path, JSON-encoding body.
// The provided nodeKey is an optional load balancing hint.
func (nc *Client) Post(ctx context.Context, path string, nodeKey key.NodePublic, body any) (*http.Response, error) {
return nc.DoWithBody(ctx, "POST", path, nodeKey, body)
}
func (nc *Client) DoWithBody(ctx context.Context, method, path string, nodeKey key.NodePublic, body any) (*http.Response, error) {
jbody, err := json.Marshal(body)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, method, "https://"+nc.host+path, bytes.NewReader(jbody))
if err != nil {
return nil, err
}
AddLBHeader(req, nodeKey)
req.Header.Set("Content-Type", "application/json")
return nc.Do(req)
}
// AddLBHeader adds the load balancer header to req if nodeKey is non-zero.
func AddLBHeader(req *http.Request, nodeKey key.NodePublic) {
if !nodeKey.IsZero() {
req.Header.Add(tailcfg.LBHeader, nodeKey.String())
}
}

158
vendor/tailscale.com/control/ts2021/conn.go generated vendored Normal file
View File

@@ -0,0 +1,158 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package ts2021 handles the details of the Tailscale 2021 control protocol
// that are after (above) the Noise layer. In particular, the
// "tailcfg.EarlyNoise" message and the subsequent HTTP/2 connection.
package ts2021
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"errors"
"io"
"sync"
"tailscale.com/control/controlbase"
"tailscale.com/tailcfg"
)
// Conn is a wrapper around controlbase.Conn.
//
// It allows attaching an ID to a connection to allow cleaning up references in
// the pool when the connection is closed, properly handles an optional "early
// payload" that's sent prior to beginning the HTTP/2 session, and provides a
// way to return a connection to a pool when the connection is closed.
//
// Use [NewConn] to build a new Conn if you want [Conn.GetEarlyPayload] to work.
// Otherwise making a Conn directly, only setting Conn, is fine.
type Conn struct {
*controlbase.Conn
onClose func() // or nil
readHeaderOnce sync.Once // guards init of reader field
reader io.Reader // (effectively Conn.Reader after header)
earlyPayloadReady chan struct{} // closed after earlyPayload is set (including set to nil)
earlyPayload *tailcfg.EarlyNoise
earlyPayloadErr error
}
// NewConn creates a new Conn that wraps the given controlbase.Conn.
//
// h2t is the HTTP/2 transport to use for the connection; a new
// http2.ClientConn will be created that reads from the returned Conn.
//
// connID should be a unique ID for this connection. When the Conn is closed,
// the onClose function will be called if it is non-nil.
func NewConn(conn *controlbase.Conn, onClose func()) *Conn {
return &Conn{
Conn: conn,
earlyPayloadReady: make(chan struct{}),
onClose: sync.OnceFunc(onClose),
}
}
// GetEarlyPayload waits for the early Noise payload to arrive.
// It may return (nil, nil) if the server begins HTTP/2 without one.
//
// It is safe to call this multiple times; all callers will block until the
// early Noise payload is ready (if any) and will return the same result for
// the lifetime of the Conn.
func (c *Conn) GetEarlyPayload(ctx context.Context) (*tailcfg.EarlyNoise, error) {
if c.earlyPayloadReady == nil {
return nil, errors.New("Conn was not created with NewConn; early payload not supported")
}
select {
case <-c.earlyPayloadReady:
return c.earlyPayload, c.earlyPayloadErr
default:
go c.readHeaderOnce.Do(c.readHeader)
}
select {
case <-c.earlyPayloadReady:
return c.earlyPayload, c.earlyPayloadErr
case <-ctx.Done():
return nil, ctx.Err()
}
}
// The first 9 bytes from the server to client over Noise are either an HTTP/2
// settings frame (a normal HTTP/2 setup) or, as we added later, an "early payload"
// header that's also 9 bytes long: 5 bytes (EarlyPayloadMagic) followed by 4 bytes
// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise.
// The early payload is optional. Some servers may not send it.
const (
hdrLen = 9 // http2 frame header size; also size of our early payload size header
)
// EarlyPayloadMagic is the 5-byte magic prefix that indicates an early payload.
const EarlyPayloadMagic = "\xff\xff\xffTS"
// returnErrReader is an io.Reader that always returns an error.
type returnErrReader struct {
err error // the error to return
}
func (r returnErrReader) Read([]byte) (int, error) { return 0, r.err }
// Read is basically the same as controlbase.Conn.Read, but it first reads the
// "early payload" header from the server which may or may not be present,
// depending on the server.
func (c *Conn) Read(p []byte) (n int, err error) {
c.readHeaderOnce.Do(c.readHeader)
return c.reader.Read(p)
}
// Close closes the connection.
func (c *Conn) Close() error {
if c.onClose != nil {
defer c.onClose()
}
return c.Conn.Close()
}
// readHeader reads the optional "early payload" from the server that arrives
// after the Noise handshake but before the HTTP/2 session begins.
//
// readHeader is responsible for reading the header (if present), initializing
// c.earlyPayload, closing c.earlyPayloadReady, and initializing c.reader for
// future reads.
func (c *Conn) readHeader() {
if c.earlyPayloadReady != nil {
defer close(c.earlyPayloadReady)
}
setErr := func(err error) {
c.reader = returnErrReader{err}
c.earlyPayloadErr = err
}
var hdr [hdrLen]byte
if _, err := io.ReadFull(c.Conn, hdr[:]); err != nil {
setErr(err)
return
}
if string(hdr[:len(EarlyPayloadMagic)]) != EarlyPayloadMagic {
// No early payload. We have to return the 9 bytes read we already
// consumed.
c.reader = io.MultiReader(bytes.NewReader(hdr[:]), c.Conn)
return
}
epLen := binary.BigEndian.Uint32(hdr[len(EarlyPayloadMagic):])
if epLen > 10<<20 {
setErr(errors.New("invalid early payload length"))
return
}
payBuf := make([]byte, epLen)
if _, err := io.ReadFull(c.Conn, payBuf); err != nil {
setErr(err)
return
}
if err := json.Unmarshal(payBuf, &c.earlyPayload); err != nil {
setErr(err)
return
}
c.reader = c.Conn
}