Update dependencies

bluepython508
2024-11-01 17:33:34 +00:00
parent 033ac0b400
commit 5cdfab398d
3596 changed files with 1033483 additions and 259 deletions

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type addressStateRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var addressStatelockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type addressStatelockNameIndex int
// DO NOT REMOVE: The following function is automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *addressStateRWMutex) Lock() {
locking.AddGLock(addressStateprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *addressStateRWMutex) NestedLock(i addressStatelockNameIndex) {
locking.AddGLock(addressStateprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *addressStateRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(addressStateprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *addressStateRWMutex) NestedUnlock(i addressStatelockNameIndex) {
m.mu.Unlock()
locking.DelGLock(addressStateprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *addressStateRWMutex) RLock() {
locking.AddGLock(addressStateprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *addressStateRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(addressStateprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *addressStateRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *addressStateRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *addressStateRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var addressStateprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func addressStateinitLockNames() {}
func init() {
addressStateinitLockNames()
addressStateprefixIndex = locking.NewMutexClass(reflect.TypeOf(addressStateRWMutex{}), addressStatelockNames)
}
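
For orientation, a minimal sketch (illustrative only, not part of the diff) of how a generated wrapper like this is consumed inside package stack: every Lock/RLock records the acquisition against the type's MutexClass via locking.AddGLock so the correctness validator can check lock ordering in builds where it is enabled, while the Bypass variants skip that bookkeeping.

func exampleAddressStateLocking(m *addressStateRWMutex) {
    m.Lock()   // records the acquisition with locking.AddGLock, then locks.
    m.Unlock() // unlocks, then drops the record with locking.DelGLock.
    m.RLock()  // read acquisitions are tracked the same way.
    m.RUnlock()
    m.RLockBypass() // skips the validator bookkeeping entirely.
    m.RUnlockBypass()
}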

View File

@@ -0,0 +1,142 @@
package stack
import (
"context"
"fmt"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/refs"
)
// enableLogging indicates whether reference-related events should be logged (with
// stack traces). This is false by default and should only be set to true for
// debugging purposes, as it can generate an extremely large amount of output
// and drastically degrade performance.
const addressStateenableLogging = false
// obj is used to customize logging. Note that we use a pointer to T so that
// we do not copy the entire object when passed as a format parameter.
var addressStateobj *addressState
// Refs implements refs.RefCounter. It keeps a reference count using atomic
// operations and calls the destructor when the count reaches zero.
//
// NOTE: Do not introduce additional fields to the Refs struct. It is used by
// many filesystem objects, and we want to keep it as small as possible (i.e.,
// the same size as using an int64 directly) to avoid taking up extra cache
// space. In general, this template should not be extended at the cost of
// performance. If it does not offer enough flexibility for a particular object
// (example: b/187877947), we should implement the RefCounter/CheckedObject
// interfaces manually.
//
// +stateify savable
type addressStateRefs struct {
// refCount is composed of two fields:
//
// [32-bit speculative references]:[32-bit real references]
//
// Speculative references are used for TryIncRef, to avoid a CompareAndSwap
// loop. See IncRef, DecRef and TryIncRef for details of how these fields are
// used.
refCount atomicbitops.Int64
}
// InitRefs initializes r with one reference and, if enabled, activates leak
// checking.
func (r *addressStateRefs) InitRefs() {
r.refCount.RacyStore(1)
refs.Register(r)
}
// RefType implements refs.CheckedObject.RefType.
func (r *addressStateRefs) RefType() string {
return fmt.Sprintf("%T", addressStateobj)[1:]
}
// LeakMessage implements refs.CheckedObject.LeakMessage.
func (r *addressStateRefs) LeakMessage() string {
return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs())
}
// LogRefs implements refs.CheckedObject.LogRefs.
func (r *addressStateRefs) LogRefs() bool {
return addressStateenableLogging
}
// ReadRefs returns the current number of references. The returned count is
// inherently racy and is unsafe to use without external synchronization.
func (r *addressStateRefs) ReadRefs() int64 {
return r.refCount.Load()
}
// IncRef implements refs.RefCounter.IncRef.
//
//go:nosplit
func (r *addressStateRefs) IncRef() {
v := r.refCount.Add(1)
if addressStateenableLogging {
refs.LogIncRef(r, v)
}
if v <= 1 {
panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType()))
}
}
// TryIncRef implements refs.TryRefCounter.TryIncRef.
//
// To do this safely without a loop, a speculative reference is first acquired
// on the object. This allows multiple concurrent TryIncRef calls to distinguish
// other TryIncRef calls from genuine references held.
//
//go:nosplit
func (r *addressStateRefs) TryIncRef() bool {
const speculativeRef = 1 << 32
if v := r.refCount.Add(speculativeRef); int32(v) == 0 {
r.refCount.Add(-speculativeRef)
return false
}
v := r.refCount.Add(-speculativeRef + 1)
if addressStateenableLogging {
refs.LogTryIncRef(r, v)
}
return true
}
// DecRef implements refs.RefCounter.DecRef.
//
// Note that speculative references are counted here. Since they were added
// prior to real references reaching zero, they will successfully convert to
// real references. In other words, we see speculative references only in the
// following case:
//
// A: TryIncRef [speculative increase => sees non-negative references]
// B: DecRef [real decrease]
// A: TryIncRef [transform speculative to real]
//
//go:nosplit
func (r *addressStateRefs) DecRef(destroy func()) {
v := r.refCount.Add(-1)
if addressStateenableLogging {
refs.LogDecRef(r, v)
}
switch {
case v < 0:
panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType()))
case v == 0:
refs.Unregister(r)
if destroy != nil {
destroy()
}
}
}
func (r *addressStateRefs) afterLoad(context.Context) {
if r.ReadRefs() > 0 {
refs.Register(r)
}
}
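
The refCount layout documented above packs two 32-bit counters into a single int64: TryIncRef adds a speculative reference (1 << 32), checks the low 32 bits (the real references), and either backs out or converts the speculative reference into a real one. A self-contained sketch of that arithmetic on a plain int64, without the atomicity of atomicbitops.Int64 (illustrative only; all names are hypothetical):

package main

import "fmt"

const speculativeRef = 1 << 32

// tryIncRef mirrors the arithmetic in addressStateRefs.TryIncRef on a plain
// int64 so it can run standalone (no atomicity, illustration only).
func tryIncRef(refCount *int64) bool {
    *refCount += speculativeRef
    if int32(*refCount) == 0 { // low 32 bits == real references
        *refCount -= speculativeRef // back out: the object already hit zero refs
        return false
    }
    *refCount += -speculativeRef + 1 // convert the speculative ref into a real one
    return true
}

func main() {
    live := int64(1) // one real reference held
    dead := int64(0) // no real references left
    fmt.Println(tryIncRef(&live), live) // true 2
    fmt.Println(tryIncRef(&dead), dead) // false 0
}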

View File

@@ -0,0 +1,950 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
func (lifetimes *AddressLifetimes) sanitize() {
if lifetimes.Deprecated {
lifetimes.PreferredUntil = tcpip.MonotonicTime{}
}
}
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
//
// +stateify savable
type AddressableEndpointState struct {
networkEndpoint NetworkEndpoint
options AddressableEndpointStateOptions
// Lock ordering (from outer to inner lock ordering):
//
// AddressableEndpointState.mu
// addressState.mu
mu addressableEndpointStateRWMutex `state:"nosave"`
// +checklocks:mu
endpoints map[tcpip.Address]*addressState
// +checklocks:mu
primary []*addressState
}
// AddressableEndpointStateOptions contains options used to configure an
// AddressableEndpointState.
//
// +stateify savable
type AddressableEndpointStateOptions struct {
// HiddenWhileDisabled determines whether addresses should be returned to
// callers while the NetworkEndpoint this AddressableEndpointState belongs
// to is disabled.
HiddenWhileDisabled bool
}
// Init initializes the AddressableEndpointState with networkEndpoint.
//
// Must be called before calling any other function on m.
func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint, options AddressableEndpointStateOptions) {
a.networkEndpoint = networkEndpoint
a.options = options
a.mu.Lock()
defer a.mu.Unlock()
a.endpoints = make(map[tcpip.Address]*addressState)
}
// OnNetworkEndpointEnabledChanged must be called every time the
// NetworkEndpoint this AddressableEndpointState belongs to is enabled or
// disabled so that any AddressDispatchers can be notified of the NIC enabled
// change.
func (a *AddressableEndpointState) OnNetworkEndpointEnabledChanged() {
a.mu.RLock()
defer a.mu.RUnlock()
for _, ep := range a.endpoints {
ep.mu.Lock()
ep.notifyChangedLocked()
ep.mu.Unlock()
}
}
// GetAddress returns the AddressEndpoint for the passed address.
//
// GetAddress does not increment the address's reference count or check if the
// address is considered bound to the endpoint.
//
// Returns nil if the passed address is not associated with the endpoint.
func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
a.mu.RLock()
defer a.mu.RUnlock()
ep, ok := a.endpoints[addr]
if !ok {
return nil
}
return ep
}
// ForEachEndpoint calls f for each address.
//
// Once f returns false, f will no longer be called.
func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
a.mu.RLock()
defer a.mu.RUnlock()
for _, ep := range a.endpoints {
if !f(ep) {
return
}
}
}
// ForEachPrimaryEndpoint calls f for each primary address.
//
// Once f returns false, f will no longer be called.
func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
a.mu.RLock()
defer a.mu.RUnlock()
for _, ep := range a.primary {
if !f(ep) {
return
}
}
}
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
a.mu.Lock()
defer a.mu.Unlock()
a.releaseAddressStateLocked(addrState)
}
// releaseAddressStateLocked removes addrState from a's address state
// (primary and endpoints list).
//
// +checklocks:a.mu
func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) {
oldPrimary := a.primary
for i, s := range a.primary {
if s == addrState {
a.primary = append(a.primary[:i], a.primary[i+1:]...)
oldPrimary[len(oldPrimary)-1] = nil
break
}
}
delete(a.endpoints, addrState.addr.Address)
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) {
return a.AddAndAcquireAddress(addr, properties, Permanent)
}
// AddAndAcquireTemporaryAddress adds a temporary address.
//
// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// The temporary address's endpoint is acquired and returned.
func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
return a.AddAndAcquireAddress(addr, AddressProperties{PEB: peb}, Temporary)
}
// AddAndAcquireAddress adds an address with the specified kind.
//
// Returns *tcpip.ErrDuplicateAddress if the address exists.
func (a *AddressableEndpointState) AddAndAcquireAddress(addr tcpip.AddressWithPrefix, properties AddressProperties, kind AddressKind) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
ep, err := a.addAndAcquireAddressLocked(addr, properties, kind)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
// a value V.
//
// An interface value is nil only if the V and T are both unset (T=nil, V is
// not set). In particular, a nil interface will always hold a nil type. If we
// store a nil pointer of type *int inside an interface value, the inner type
// will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
// an interface value will therefore be non-nil even when the pointer value V
// inside is nil.
//
// Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
// we need to explicitly return nil below if ep is (a typed) nil.
if ep == nil {
return nil, err
}
return ep, err
}
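
The typed-nil pitfall described in the comment inside AddAndAcquireAddress is easy to reproduce in isolation. A minimal standalone sketch (illustrative only; the types are hypothetical):

package main

import "fmt"

type iface interface{ m() }

type impl struct{}

func (*impl) m() {}

// returnsTypedNil returns a nil *impl through an interface, so the interface
// carries a type (T=*impl) with a nil value (V=nil) and compares non-nil.
func returnsTypedNil() iface {
    var p *impl
    return p
}

func main() {
    v := returnsTypedNil()
    fmt.Println(v == nil) // false: the interface holds (T=*impl, V=nil).

    // The fix used above: check the concrete pointer before converting it to an
    // interface, and return an untyped nil explicitly.
    var p *impl
    var v2 iface
    if p != nil {
        v2 = p
    }
    fmt.Println(v2 == nil) // true
}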
// addAndAcquireAddressLocked adds, acquires and returns a permanent or
// temporary address.
//
// If the addressable endpoint already has the address in a non-permanent state,
// and addAndAcquireAddressLocked is adding a permanent address, that address is
// promoted in place and its properties set to the properties provided. If the
// address already exists in any other state, then *tcpip.ErrDuplicateAddress is
// returned, regardless of the kind of address being added.
//
// +checklocks:a.mu
func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, kind AddressKind) (*addressState, tcpip.Error) {
var permanent bool
switch kind {
case PermanentExpired:
panic(fmt.Sprintf("cannot add address %s in PermanentExpired state", addr))
case Permanent, PermanentTentative:
permanent = true
case Temporary:
default:
panic(fmt.Sprintf("unknown address kind: %d", kind))
}
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
addrState, ok := a.endpoints[addr.Address]
if ok {
if !permanent {
// We are adding a non-permanent address but the address exists. No need
// to go any further since we can only promote existing temporary/expired
// addresses to permanent.
return nil, &tcpip.ErrDuplicateAddress{}
}
addrState.mu.RLock()
if addrState.refs.ReadRefs() == 0 {
panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr))
}
isPermanent := addrState.kind.IsPermanent()
addrState.mu.RUnlock()
if isPermanent {
// We are adding a permanent address but a permanent address already
// exists.
return nil, &tcpip.ErrDuplicateAddress{}
}
// We now promote the address.
for i, s := range a.primary {
if s == addrState {
switch properties.PEB {
case CanBePrimaryEndpoint:
// The address is already in the primary address list.
attemptAddToPrimary = false
case FirstPrimaryEndpoint:
if i == 0 {
// The address is already first in the primary address list.
attemptAddToPrimary = false
} else {
a.primary = append(a.primary[:i], a.primary[i+1:]...)
}
case NeverPrimaryEndpoint:
a.primary = append(a.primary[:i], a.primary[i+1:]...)
default:
panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
break
}
}
addrState.refs.IncRef()
} else {
addrState = &addressState{
addressableEndpointState: a,
addr: addr,
temporary: properties.Temporary,
// Cache the subnet in addrState to avoid calls to addr.Subnet() as that
// results in allocations on every call.
subnet: addr.Subnet(),
}
addrState.refs.InitRefs()
a.endpoints[addr.Address] = addrState
// We never promote an address to temporary - it can only be added as such.
// If we are actually adding a permanent address, it is promoted below.
addrState.kind = Temporary
}
// At this point we have an address we are either promoting from an expired or
// temporary address to permanent, promoting an expired address to temporary,
// or we are adding a new temporary or permanent address.
//
// The address MUST be write locked at this point.
addrState.mu.Lock()
defer addrState.mu.Unlock()
if permanent {
if addrState.kind.IsPermanent() {
panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr))
}
// Permanent addresses are biased by 1 (they hold an extra reference until
// explicitly removed).
addrState.refs.IncRef()
addrState.kind = kind
}
addrState.configType = properties.ConfigType
lifetimes := properties.Lifetimes
lifetimes.sanitize()
addrState.lifetimes = lifetimes
addrState.disp = properties.Disp
if attemptAddToPrimary {
switch properties.PEB {
case NeverPrimaryEndpoint:
case CanBePrimaryEndpoint:
a.primary = append(a.primary, addrState)
case FirstPrimaryEndpoint:
if cap(a.primary) == len(a.primary) {
a.primary = append([]*addressState{addrState}, a.primary...)
} else {
// Shift all the endpoints by 1 to make room for the new address at the
// front. We could have just created a new slice but this saves
// allocations when the slice has capacity for the new address.
primaryCount := len(a.primary)
a.primary = append(a.primary, nil)
if n := copy(a.primary[1:], a.primary); n != primaryCount {
panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount))
}
a.primary[0] = addrState
}
default:
panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
}
addrState.notifyChangedLocked()
return addrState, nil
}
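
The FirstPrimaryEndpoint branch above prepends to a.primary without reallocating when the slice has spare capacity: append a placeholder to grow the length, shift the existing elements right by one with copy (which handles overlapping memory), and write the new element at index 0. A self-contained sketch of the same trick on a plain []int (illustrative only; the helper name is hypothetical):

package main

import "fmt"

// prependInPlace puts v at the front of s, reusing s's backing array whenever
// capacity allows, mirroring the copy-shift done for a.primary above.
func prependInPlace(s []int, v int) []int {
    if cap(s) == len(s) {
        // No spare capacity: allocate once with v in front.
        return append([]int{v}, s...)
    }
    n := len(s)
    s = append(s, 0)   // grow the length by one within the existing capacity
    copy(s[1:], s[:n]) // shift the original n elements right by one
    s[0] = v           // place the new element at the front
    return s
}

func main() {
    s := make([]int, 3, 8)
    s[0], s[1], s[2] = 1, 2, 3
    fmt.Println(prependInPlace(s, 0)) // [0 1 2 3]
}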
// RemovePermanentAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
return a.removePermanentAddressLocked(addr)
}
// removePermanentAddressLocked is like RemovePermanentAddress but with locking
// requirements.
//
// +checklocks:a.mu
func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) tcpip.Error {
addrState, ok := a.endpoints[addr]
if !ok {
return &tcpip.ErrBadLocalAddress{}
}
return a.removePermanentEndpointLocked(addrState, AddressRemovalManualAction)
}
// RemovePermanentEndpoint removes the passed endpoint if it is associated with
// a and permanent.
func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint, reason AddressRemovalReason) tcpip.Error {
addrState, ok := ep.(*addressState)
if !ok || addrState.addressableEndpointState != a {
return &tcpip.ErrInvalidEndpointState{}
}
a.mu.Lock()
defer a.mu.Unlock()
return a.removePermanentEndpointLocked(addrState, reason)
}
// removePermanentEndpointLocked is like RemovePermanentEndpoint but with
// locking requirements.
//
// +checklocks:a.mu
func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState, reason AddressRemovalReason) tcpip.Error {
if !addrState.GetKind().IsPermanent() {
return &tcpip.ErrBadLocalAddress{}
}
addrState.remove(reason)
a.decAddressRefLocked(addrState)
return nil
}
// decAddressRef decrements the address's reference count and releases it once
// the reference count hits 0.
func (a *AddressableEndpointState) decAddressRef(addrState *addressState) {
a.mu.Lock()
defer a.mu.Unlock()
a.decAddressRefLocked(addrState)
}
// decAddressRefLocked is like decAddressRef but with locking requirements.
//
// +checklocks:a.mu
func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) {
destroy := false
addrState.refs.DecRef(func() {
destroy = true
})
if !destroy {
return
}
addrState.mu.Lock()
defer addrState.mu.Unlock()
// A non-expired permanent address must not have its reference count dropped
// to 0.
if addrState.kind.IsPermanent() {
panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.kind))
}
a.releaseAddressStateLocked(addrState)
}
// SetDeprecated implements stack.AddressableEndpoint.
func (a *AddressableEndpointState) SetDeprecated(addr tcpip.Address, deprecated bool) tcpip.Error {
a.mu.RLock()
defer a.mu.RUnlock()
addrState, ok := a.endpoints[addr]
if !ok {
return &tcpip.ErrBadLocalAddress{}
}
addrState.SetDeprecated(deprecated)
return nil
}
// SetLifetimes implements stack.AddressableEndpoint.
func (a *AddressableEndpointState) SetLifetimes(addr tcpip.Address, lifetimes AddressLifetimes) tcpip.Error {
a.mu.RLock()
defer a.mu.RUnlock()
addrState, ok := a.endpoints[addr]
if !ok {
return &tcpip.ErrBadLocalAddress{}
}
addrState.SetLifetimes(lifetimes)
return nil
}
// MainAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
a.mu.RLock()
defer a.mu.RUnlock()
ep := a.acquirePrimaryAddressRLocked(tcpip.Address{}, tcpip.Address{} /* srcHint */, func(ep *addressState) bool {
switch kind := ep.GetKind(); kind {
case Permanent:
return a.networkEndpoint.Enabled() || !a.options.HiddenWhileDisabled
case PermanentTentative, PermanentExpired, Temporary:
return false
default:
panic(fmt.Sprintf("unknown address kind: %d", kind))
}
})
if ep == nil {
return tcpip.AddressWithPrefix{}
}
addr := ep.AddressWithPrefix()
// Note that ep must have a ref count >= 2 here, because its ref count
// must be >=1 in order to be found and the ref count was incremented
// when a reference was acquired. The only way for the ref count to
// drop below 2 is for the endpoint to be removed, which requires a
// write lock; so we're guaranteed to be able to decrement the ref
// count and not need to remove the endpoint from a.primary.
ep.decRefMustNotFree()
return addr
}
// acquirePrimaryAddressRLocked returns an acquired primary address that is
// valid according to isValid.
//
// +checklocksread:a.mu
func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(remoteAddr, srcHint tcpip.Address, isValid func(*addressState) bool) *addressState {
// TODO: Move this out into IPv4-specific code.
// IPv6 handles source IP selection elsewhere. We have to do source
// selection only for IPv4, in which case ep is never deprecated. Thus
// we don't have to worry about refcounts.
if remoteAddr.Len() == header.IPv4AddressSize && remoteAddr != (tcpip.Address{}) {
var best *addressState
var bestLen uint8
for _, state := range a.primary {
if !isValid(state) {
continue
}
// Source hint takes precedence over prefix matching.
if state.addr.Address == srcHint && srcHint != (tcpip.Address{}) {
best = state
break
}
stateLen := state.addr.Address.MatchingPrefix(remoteAddr)
if best == nil || bestLen < stateLen {
best = state
bestLen = stateLen
}
}
if best != nil && best.TryIncRef() {
return best
}
}
var deprecatedEndpoint *addressState
for _, ep := range a.primary {
if !isValid(ep) {
continue
}
if !ep.Deprecated() {
if ep.TryIncRef() {
// ep is not deprecated, so return it immediately.
//
// If we kept track of a deprecated endpoint, decrement its reference
// count since it was incremented when we decided to keep track of it.
if deprecatedEndpoint != nil {
// Note that when deprecatedEndpoint was found, its ref count
// must have necessarily been >=1, and after incrementing it
// must be >=2. The only way for the ref count to drop below 2 is
// for the endpoint to be removed, which requires a write lock;
// so we're guaranteed to be able to decrement the ref count
// and not need to remove the endpoint from a.primary.
deprecatedEndpoint.decRefMustNotFree()
}
return ep
}
} else if deprecatedEndpoint == nil && ep.TryIncRef() {
// We prefer an endpoint that is not deprecated, but we keep track of
// ep in case a doesn't have any non-deprecated endpoints.
//
// If we end up finding a more preferred endpoint, ep's reference count
// will be decremented.
deprecatedEndpoint = ep
}
}
return deprecatedEndpoint
}
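
For the IPv4 case above, selection prefers the source hint and otherwise keeps the candidate whose address shares the longest bit prefix with the remote address (what tcpip.Address.MatchingPrefix reports). A standalone sketch of that selection over raw 4-byte addresses (illustrative only; helper names are hypothetical and the validity/deprecation checks are omitted):

package main

import (
    "encoding/binary"
    "fmt"
    "math/bits"
)

// matchingPrefix returns how many leading bits a and b share, the notion
// tcpip.Address.MatchingPrefix captures for the code above.
func matchingPrefix(a, b [4]byte) uint8 {
    x := binary.BigEndian.Uint32(a[:]) ^ binary.BigEndian.Uint32(b[:])
    return uint8(bits.LeadingZeros32(x))
}

// pickSource prefers srcHint when it is among the candidates, otherwise the
// candidate with the longest matching prefix against remote.
func pickSource(candidates [][4]byte, remote, srcHint [4]byte) [4]byte {
    var best [4]byte
    var bestLen uint8
    for i, c := range candidates {
        if srcHint != ([4]byte{}) && c == srcHint {
            return c
        }
        if l := matchingPrefix(c, remote); i == 0 || l > bestLen {
            best, bestLen = c, l
        }
    }
    return best
}

func main() {
    candidates := [][4]byte{{192, 168, 1, 10}, {10, 0, 0, 2}}
    remote := [4]byte{10, 0, 0, 99}
    fmt.Println(pickSource(candidates, remote, [4]byte{})) // [10 0 0 2]
}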
// AcquireAssignedAddressOrMatching returns an address endpoint that is
// considered assigned to the addressable endpoint.
//
// If the address is an exact match with an existing address, that address is
// returned. Otherwise, if f is provided, f is called with each address and
// the address that f returns true for is returned.
//
// If there is no matching address, a temporary address will be returned if
// allowTemp is true.
//
// If readOnly is true, the address will be returned without an extra reference.
// In this case it is not safe to modify the endpoint, only read attributes like
// subnet.
//
// Regardless how the address was obtained, it will be acquired before it is
// returned.
func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior, readOnly bool) AddressEndpoint {
lookup := func() *addressState {
if addrState, ok := a.endpoints[localAddr]; ok {
if !addrState.IsAssigned(allowTemp) {
return nil
}
if !readOnly && !addrState.TryIncRef() {
panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
}
return addrState
}
if f != nil {
for _, addrState := range a.endpoints {
if addrState.IsAssigned(allowTemp) && f(addrState) {
if !readOnly && !addrState.TryIncRef() {
continue
}
return addrState
}
}
}
return nil
}
// Avoid exclusive lock on mu unless we need to add a new address.
a.mu.RLock()
ep := lookup()
a.mu.RUnlock()
if ep != nil {
return ep
}
if !allowTemp {
return nil
}
// Acquire state lock in exclusive mode as we need to add a new temporary
// endpoint.
a.mu.Lock()
defer a.mu.Unlock()
// Do the lookup again in case another goroutine added the address in the time
// we released and acquired the lock.
ep = lookup()
if ep != nil {
return ep
}
// Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, Temporary)
if err != nil {
// addAndAcquireAddressLocked only returns an error if the address is
// already assigned but we just checked above if the address exists so we
// expect no error.
panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err))
}
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
// a value V.
//
// An interface value is nil only if the V and T are both unset (T=nil, V is
// not set). In particular, a nil interface will always hold a nil type. If we
// store a nil pointer of type *int inside an interface value, the inner type
// will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
// an interface value will therefore be non-nil even when the pointer value V
// inside is nil.
//
// Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
// we need to explicitly return nil below if ep is (a typed) nil.
if ep == nil {
return nil
}
if readOnly {
if ep.addressableEndpointState == a {
// Checklocks doesn't understand that we are logically guaranteed to have
// ep.mu locked already. We need to use checklocksignore to appease the
// analyzer.
ep.addressableEndpointState.decAddressRefLocked(ep) // +checklocksignore
} else {
ep.DecRef()
}
}
return ep
}
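
AcquireAssignedAddressOrMatching follows a common optimistic pattern: look up under the read lock, and only on a miss take the write lock and repeat the lookup before inserting, since another goroutine may have added the entry in the window between the two locks. A generic standalone sketch of the pattern (illustrative only; names are hypothetical):

package main

import (
    "fmt"
    "sync"
)

type cache struct {
    mu sync.RWMutex
    m  map[string]int
}

// getOrAdd mirrors the lookup/re-lookup dance above: a cheap read-locked
// lookup first, then a write-locked re-check before inserting.
func (c *cache) getOrAdd(key string, mk func() int) int {
    c.mu.RLock()
    v, ok := c.m[key]
    c.mu.RUnlock()
    if ok {
        return v
    }

    c.mu.Lock()
    defer c.mu.Unlock()
    // Re-check: another goroutine may have inserted key while we were unlocked.
    if v, ok := c.m[key]; ok {
        return v
    }
    v = mk()
    c.m[key] = v
    return v
}

func main() {
    c := &cache{m: map[string]int{}}
    fmt.Println(c.getOrAdd("a", func() int { return 1 })) // 1
    fmt.Println(c.getOrAdd("a", func() int { return 2 })) // still 1
}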
// AcquireAssignedAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior, readOnly bool) AddressEndpoint {
return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB, readOnly)
}
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, srcHint tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.Lock()
defer a.mu.Unlock()
ep := a.acquirePrimaryAddressRLocked(remoteAddr, srcHint, func(ep *addressState) bool {
return ep.IsAssigned(allowExpired)
})
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
// a value V.
//
// An interface value is nil only if the V and T are both unset (T=nil, V is
// not set). In particular, a nil interface will always hold a nil type. If we
// store a nil pointer of type *int inside an interface value, the inner type
// will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
// an interface value will therefore be non-nil even when the pointer value V
// inside is nil.
//
// Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type,
// we need to explicitly return nil below if ep is (a typed) nil.
if ep == nil {
return nil
}
return ep
}
// PrimaryAddresses implements AddressableEndpoint.
func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix {
a.mu.RLock()
defer a.mu.RUnlock()
var addrs []tcpip.AddressWithPrefix
if a.options.HiddenWhileDisabled && !a.networkEndpoint.Enabled() {
return addrs
}
for _, ep := range a.primary {
switch kind := ep.GetKind(); kind {
// Don't include tentative, expired or temporary endpoints
// to avoid confusion and prevent the caller from using
// those.
case PermanentTentative, PermanentExpired, Temporary:
continue
case Permanent:
default:
panic(fmt.Sprintf("address %s has unknown kind %d", ep.AddressWithPrefix(), kind))
}
addrs = append(addrs, ep.AddressWithPrefix())
}
return addrs
}
// PermanentAddresses implements AddressableEndpoint.
func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix {
a.mu.RLock()
defer a.mu.RUnlock()
var addrs []tcpip.AddressWithPrefix
for _, ep := range a.endpoints {
if !ep.GetKind().IsPermanent() {
continue
}
addrs = append(addrs, ep.AddressWithPrefix())
}
return addrs
}
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
for _, ep := range a.endpoints {
// removePermanentEndpointLocked returns *tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
switch err := a.removePermanentEndpointLocked(ep, AddressRemovalInterfaceRemoved); err.(type) {
case nil, *tcpip.ErrBadLocalAddress:
default:
panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err))
}
}
}
var _ AddressEndpoint = (*addressState)(nil)
// addressState holds state for an address.
//
// +stateify savable
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
subnet tcpip.Subnet
temporary bool
// Lock ordering (from outer to inner lock ordering):
//
// AddressableEndpointState.mu
// addressState.mu
mu addressStateRWMutex `state:"nosave"`
refs addressStateRefs
// +checklocks:mu
kind AddressKind
// +checklocks:mu
configType AddressConfigType
// lifetimes holds this address' lifetimes.
//
// Invariant: if lifetimes.deprecated is true, then lifetimes.PreferredUntil
// must be the zero value. Note that the converse does not need to be
// upheld!
//
// +checklocks:mu
lifetimes AddressLifetimes
// The enclosing mutex must be write-locked before calling methods on the
// dispatcher.
//
// +checklocks:mu
disp AddressDispatcher
}
// AddressWithPrefix implements AddressEndpoint.
func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
}
// Subnet implements AddressEndpoint.
func (a *addressState) Subnet() tcpip.Subnet {
return a.subnet
}
// GetKind implements AddressEndpoint.
func (a *addressState) GetKind() AddressKind {
a.mu.RLock()
defer a.mu.RUnlock()
return a.kind
}
// SetKind implements AddressEndpoint.
func (a *addressState) SetKind(kind AddressKind) {
a.mu.Lock()
defer a.mu.Unlock()
prevKind := a.kind
a.kind = kind
if kind == PermanentExpired {
a.notifyRemovedLocked(AddressRemovalManualAction)
} else if prevKind != kind && a.addressableEndpointState.networkEndpoint.Enabled() {
a.notifyChangedLocked()
}
}
// notifyRemovedLocked notifies integrators of address removal.
//
// +checklocks:a.mu
func (a *addressState) notifyRemovedLocked(reason AddressRemovalReason) {
if disp := a.disp; disp != nil {
a.disp.OnRemoved(reason)
a.disp = nil
}
}
func (a *addressState) remove(reason AddressRemovalReason) {
a.mu.Lock()
defer a.mu.Unlock()
a.kind = PermanentExpired
a.notifyRemovedLocked(reason)
}
// IsAssigned implements AddressEndpoint.
func (a *addressState) IsAssigned(allowExpired bool) bool {
switch kind := a.GetKind(); kind {
case PermanentTentative:
return false
case PermanentExpired:
return allowExpired
case Permanent, Temporary:
return true
default:
panic(fmt.Sprintf("address %s has unknown kind %d", a.AddressWithPrefix(), kind))
}
}
// TryIncRef implements AddressEndpoint.
func (a *addressState) TryIncRef() bool {
return a.refs.TryIncRef()
}
// DecRef implements AddressEndpoint.
func (a *addressState) DecRef() {
a.addressableEndpointState.decAddressRef(a)
}
// decRefMustNotFree decreases the reference count with the guarantee that the
// reference count will be greater than 0 after the decrement.
//
// Panics if the ref count is less than 2 after acquiring the lock in this
// function.
func (a *addressState) decRefMustNotFree() {
a.refs.DecRef(func() {
panic(fmt.Sprintf("cannot decrease addressState %s without freeing the endpoint", a.addr))
})
}
// ConfigType implements AddressEndpoint.
func (a *addressState) ConfigType() AddressConfigType {
a.mu.RLock()
defer a.mu.RUnlock()
return a.configType
}
// notifyChangedLocked notifies integrators of address property changes.
//
// +checklocks:a.mu
func (a *addressState) notifyChangedLocked() {
if a.disp == nil {
return
}
state := AddressDisabled
if a.addressableEndpointState.networkEndpoint.Enabled() {
switch a.kind {
case Permanent:
state = AddressAssigned
case PermanentTentative:
state = AddressTentative
case Temporary, PermanentExpired:
return
default:
panic(fmt.Sprintf("unrecognized address kind = %d", a.kind))
}
}
a.disp.OnChanged(a.lifetimes, state)
}
// SetDeprecated implements AddressEndpoint.
func (a *addressState) SetDeprecated(d bool) {
a.mu.Lock()
defer a.mu.Unlock()
var changed bool
if a.lifetimes.Deprecated != d {
a.lifetimes.Deprecated = d
changed = true
}
if d {
a.lifetimes.PreferredUntil = tcpip.MonotonicTime{}
}
if changed {
a.notifyChangedLocked()
}
}
// Deprecated implements AddressEndpoint.
func (a *addressState) Deprecated() bool {
a.mu.RLock()
defer a.mu.RUnlock()
return a.lifetimes.Deprecated
}
// SetLifetimes implements AddressEndpoint.
func (a *addressState) SetLifetimes(lifetimes AddressLifetimes) {
a.mu.Lock()
defer a.mu.Unlock()
lifetimes.sanitize()
var changed bool
if a.lifetimes != lifetimes {
changed = true
}
a.lifetimes = lifetimes
if changed {
a.notifyChangedLocked()
}
}
// Lifetimes implements AddressEndpoint.
func (a *addressState) Lifetimes() AddressLifetimes {
a.mu.RLock()
defer a.mu.RUnlock()
return a.lifetimes
}
// Temporary implements AddressEndpoint.
func (a *addressState) Temporary() bool {
return a.temporary
}
// RegisterDispatcher implements AddressEndpoint.
func (a *addressState) RegisterDispatcher(disp AddressDispatcher) {
a.mu.Lock()
defer a.mu.Unlock()
if disp != nil {
a.disp = disp
a.notifyChangedLocked()
}
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type addressableEndpointStateRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var addressableEndpointStatelockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type addressableEndpointStatelockNameIndex int
// DO NOT REMOVE: The following function is automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) Lock() {
locking.AddGLock(addressableEndpointStateprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) NestedLock(i addressableEndpointStatelockNameIndex) {
locking.AddGLock(addressableEndpointStateprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(addressableEndpointStateprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) NestedUnlock(i addressableEndpointStatelockNameIndex) {
m.mu.Unlock()
locking.DelGLock(addressableEndpointStateprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) RLock() {
locking.AddGLock(addressableEndpointStateprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(addressableEndpointStateprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var addressableEndpointStateprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func addressableEndpointStateinitLockNames() {}
func init() {
addressableEndpointStateinitLockNames()
addressableEndpointStateprefixIndex = locking.NewMutexClass(reflect.TypeOf(addressableEndpointStateRWMutex{}), addressableEndpointStatelockNames)
}

View File

@@ -0,0 +1,229 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
var _ NetworkLinkEndpoint = (*BridgeEndpoint)(nil)
type bridgePort struct {
bridge *BridgeEndpoint
nic *nic
}
// ParseHeader implements stack.LinkEndpoint.
func (p *bridgePort) ParseHeader(pkt *PacketBuffer) bool {
_, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
return ok
}
// DeliverNetworkPacket implements stack.NetworkDispatcher.
func (p *bridgePort) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
bridge := p.bridge
bridge.mu.RLock()
// Send the packet to all other ports.
for _, port := range bridge.ports {
if p == port {
continue
}
newPkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: int(port.nic.MaxHeaderLength()),
Payload: pkt.ToBuffer(),
})
port.nic.writeRawPacket(newPkt)
newPkt.DecRef()
}
d := bridge.dispatcher
bridge.mu.RUnlock()
if d != nil {
// The dispatcher may acquire Stack.mu in DeliverNetworkPacket(), which is
// ordered above bridge.mu. So call DeliverNetworkPacket() without holding
// bridge.mu to avoid circular locking.
d.DeliverNetworkPacket(protocol, pkt)
}
}
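
The comment above is an instance of a general deadlock-avoidance idiom: snapshot the callback while holding the lock, release the lock, then invoke the callback, because the callee may take a lock ordered above this one. A generic standalone sketch (illustrative only; names are hypothetical):

package main

import (
    "fmt"
    "sync"
)

type notifier struct {
    mu       sync.RWMutex
    callback func(string)
}

// notify copies the callback under the read lock and calls it after releasing
// the lock, so the callback is free to take locks ordered above notifier.mu.
func (n *notifier) notify(msg string) {
    n.mu.RLock()
    cb := n.callback
    n.mu.RUnlock()
    if cb != nil {
        cb(msg)
    }
}

func main() {
    n := &notifier{callback: func(s string) { fmt.Println("got:", s) }}
    n.notify("packet")
}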
func (p *bridgePort) DeliverLinkPacket(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
}
// NewBridgeEndpoint creates a new bridge endpoint.
func NewBridgeEndpoint(mtu uint32) *BridgeEndpoint {
b := &BridgeEndpoint{
mtu: mtu,
addr: tcpip.GetRandMacAddr(),
}
b.ports = make(map[tcpip.NICID]*bridgePort)
return b
}
// BridgeEndpoint is a bridge endpoint.
type BridgeEndpoint struct {
mu bridgeRWMutex
// +checklocks:mu
ports map[tcpip.NICID]*bridgePort
// +checklocks:mu
dispatcher NetworkDispatcher
// +checklocks:mu
addr tcpip.LinkAddress
// +checklocks:mu
attached bool
// +checklocks:mu
mtu uint32
maxHeaderLength atomicbitops.Uint32
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
func (b *BridgeEndpoint) WritePackets(pkts PacketBufferList) (int, tcpip.Error) {
b.mu.RLock()
defer b.mu.RUnlock()
pktsSlice := pkts.AsSlice()
n := len(pktsSlice)
for _, p := range b.ports {
for _, pkt := range pktsSlice {
// In order to properly loop back to the inbound side we must create a
// fresh packet that only contains the underlying payload with no headers
// or struct fields set.
newPkt := NewPacketBuffer(PacketBufferOptions{
Payload: pkt.ToBuffer(),
ReserveHeaderBytes: int(p.nic.MaxHeaderLength()),
})
newPkt.EgressRoute = pkt.EgressRoute
newPkt.NetworkProtocolNumber = pkt.NetworkProtocolNumber
p.nic.writePacket(newPkt)
newPkt.DecRef()
}
}
return n, nil
}
// AddNIC adds the specified NIC to the bridge.
func (b *BridgeEndpoint) AddNIC(n *nic) tcpip.Error {
b.mu.Lock()
defer b.mu.Unlock()
port := &bridgePort{
nic: n,
bridge: b,
}
n.NetworkLinkEndpoint.Attach(port)
b.ports[n.id] = port
if b.maxHeaderLength.Load() < uint32(n.MaxHeaderLength()) {
b.maxHeaderLength.Store(uint32(n.MaxHeaderLength()))
}
return nil
}
// DelNIC removes the specified NIC from the bridge.
func (b *BridgeEndpoint) DelNIC(nic *nic) tcpip.Error {
b.mu.Lock()
defer b.mu.Unlock()
delete(b.ports, nic.id)
nic.NetworkLinkEndpoint.Attach(nic)
return nil
}
// MTU implements stack.LinkEndpoint.MTU.
func (b *BridgeEndpoint) MTU() uint32 {
b.mu.RLock()
defer b.mu.RUnlock()
if b.mtu > header.EthernetMinimumSize {
return b.mtu - header.EthernetMinimumSize
}
return 0
}
// SetMTU implements stack.LinkEndpoint.SetMTU.
func (b *BridgeEndpoint) SetMTU(mtu uint32) {
b.mu.Lock()
defer b.mu.Unlock()
b.mtu = mtu
}
// MaxHeaderLength implements stack.LinkEndpoint.
func (b *BridgeEndpoint) MaxHeaderLength() uint16 {
return uint16(b.maxHeaderLength.Load())
}
// LinkAddress implements stack.LinkEndpoint.LinkAddress.
func (b *BridgeEndpoint) LinkAddress() tcpip.LinkAddress {
b.mu.Lock()
defer b.mu.Unlock()
return b.addr
}
// SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress.
func (b *BridgeEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
b.mu.Lock()
defer b.mu.Unlock()
b.addr = addr
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
func (b *BridgeEndpoint) Capabilities() LinkEndpointCapabilities {
return CapabilityRXChecksumOffload | CapabilitySaveRestore | CapabilityResolutionRequired
}
// Attach implements stack.LinkEndpoint.Attach.
func (b *BridgeEndpoint) Attach(dispatcher NetworkDispatcher) {
b.mu.Lock()
defer b.mu.Unlock()
for _, p := range b.ports {
p.nic.Primary = nil
}
b.dispatcher = dispatcher
b.ports = make(map[tcpip.NICID]*bridgePort)
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
func (b *BridgeEndpoint) IsAttached() bool {
b.mu.RLock()
defer b.mu.RUnlock()
return b.dispatcher != nil
}
// Wait implements stack.LinkEndpoint.Wait.
func (b *BridgeEndpoint) Wait() {
}
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (b *BridgeEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareEther
}
// AddHeader implements stack.LinkEndpoint.AddHeader.
func (b *BridgeEndpoint) AddHeader(pkt *PacketBuffer) {
}
// ParseHeader implements stack.LinkEndpoint.ParseHeader.
func (b *BridgeEndpoint) ParseHeader(*PacketBuffer) bool {
return true
}
// Close implements stack.LinkEndpoint.Close.
func (b *BridgeEndpoint) Close() {}
// SetOnCloseAction implements stack.LinkEndpoint.SetOnCloseAction.
func (b *BridgeEndpoint) SetOnCloseAction(func()) {}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type bridgeRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var bridgelockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type bridgelockNameIndex int
// DO NOT REMOVE: The following function is automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *bridgeRWMutex) Lock() {
locking.AddGLock(bridgeprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *bridgeRWMutex) NestedLock(i bridgelockNameIndex) {
locking.AddGLock(bridgeprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *bridgeRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(bridgeprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *bridgeRWMutex) NestedUnlock(i bridgelockNameIndex) {
m.mu.Unlock()
locking.DelGLock(bridgeprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *bridgeRWMutex) RLock() {
locking.AddGLock(bridgeprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *bridgeRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(bridgeprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *bridgeRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *bridgeRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *bridgeRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var bridgeprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func bridgeinitLockNames() {}
func init() {
bridgeinitLockNames()
bridgeprefixIndex = locking.NewMutexClass(reflect.TypeOf(bridgeRWMutex{}), bridgelockNames)
}

View File

@@ -0,0 +1,98 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type bucketRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var bucketlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type bucketlockNameIndex int
// DO NOT REMOVE: The following function is automatically replaced with lock index constants.
const (
bucketLockOthertuple = bucketlockNameIndex(0)
)
const ()
// Lock locks m.
// +checklocksignore
func (m *bucketRWMutex) Lock() {
locking.AddGLock(bucketprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *bucketRWMutex) NestedLock(i bucketlockNameIndex) {
locking.AddGLock(bucketprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *bucketRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(bucketprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *bucketRWMutex) NestedUnlock(i bucketlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(bucketprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *bucketRWMutex) RLock() {
locking.AddGLock(bucketprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *bucketRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(bucketprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *bucketRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *bucketRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *bucketRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var bucketprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func bucketinitLockNames() { bucketlockNames = []string{"otherTuple"} }
func init() {
bucketinitLockNames()
bucketprefixIndex = locking.NewMutexClass(reflect.TypeOf(bucketRWMutex{}), bucketlockNames)
}
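
Unlike the other generated mutexes in this change, this one carries a lock name ("otherTuple"), which is what NestedLock/NestedUnlock consume when two buckets of the same class must be held at once. A minimal sketch of how that index is typically used, assuming the caller has already chosen a consistent bucket order (illustrative only, not part of the diff):

// lockTwoBuckets takes the first bucket with a plain Lock and the second with
// NestedLock so the validator records the second acquisition under the
// "otherTuple" name instead of flagging a same-class reentry.
func lockTwoBuckets(first, second *bucketRWMutex) {
    first.Lock()
    second.NestedLock(bucketLockOthertuple)
}

func unlockTwoBuckets(first, second *bucketRWMutex) {
    second.NestedUnlock(bucketLockOthertuple)
    first.Unlock()
}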

View File

@@ -0,0 +1,64 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// Mutex is sync.Mutex with the correctness validator.
type cleanupEndpointsMutex struct {
mu sync.Mutex
}
var cleanupEndpointsprefixIndex *locking.MutexClass
// lockNames is a list of user-friendly lock names.
// Populated in init.
var cleanupEndpointslockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type cleanupEndpointslockNameIndex int
// DO NOT REMOVE: The following function is automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *cleanupEndpointsMutex) Lock() {
locking.AddGLock(cleanupEndpointsprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *cleanupEndpointsMutex) NestedLock(i cleanupEndpointslockNameIndex) {
locking.AddGLock(cleanupEndpointsprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *cleanupEndpointsMutex) Unlock() {
locking.DelGLock(cleanupEndpointsprefixIndex, -1)
m.mu.Unlock()
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *cleanupEndpointsMutex) NestedUnlock(i cleanupEndpointslockNameIndex) {
locking.DelGLock(cleanupEndpointsprefixIndex, int(i))
m.mu.Unlock()
}
// DO NOT REMOVE: The following function is automatically replaced.
func cleanupEndpointsinitLockNames() {}
func init() {
cleanupEndpointsinitLockNames()
cleanupEndpointsprefixIndex = locking.NewMutexClass(reflect.TypeOf(cleanupEndpointsMutex{}), cleanupEndpointslockNames)
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type connRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var connlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type connlockNameIndex int
// DO NOT REMOVE: The following function is automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *connRWMutex) Lock() {
locking.AddGLock(connprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *connRWMutex) NestedLock(i connlockNameIndex) {
locking.AddGLock(connprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *connRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(connprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *connRWMutex) NestedUnlock(i connlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(connprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *connRWMutex) RLock() {
locking.AddGLock(connprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *connRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(connprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *connRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *connRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *connRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var connprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func conninitLockNames() {}
func init() {
conninitLockNames()
connprefixIndex = locking.NewMutexClass(reflect.TypeOf(connRWMutex{}), connlockNames)
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type connTrackRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var connTracklockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type connTracklockNameIndex int
// DO NOT REMOVE: The following function is automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *connTrackRWMutex) Lock() {
locking.AddGLock(connTrackprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *connTrackRWMutex) NestedLock(i connTracklockNameIndex) {
locking.AddGLock(connTrackprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *connTrackRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(connTrackprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *connTrackRWMutex) NestedUnlock(i connTracklockNameIndex) {
m.mu.Unlock()
locking.DelGLock(connTrackprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *connTrackRWMutex) RLock() {
locking.AddGLock(connTrackprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *connTrackRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(connTrackprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *connTrackRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *connTrackRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *connTrackRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var connTrackprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func connTrackinitLockNames() {}
func init() {
connTrackinitLockNames()
connTrackprefixIndex = locking.NewMutexClass(reflect.TypeOf(connTrackRWMutex{}), connTracklockNames)
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type endpointsByNICRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var endpointsByNIClockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type endpointsByNIClockNameIndex int
// DO NOT REMOVE: The following function is automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *endpointsByNICRWMutex) Lock() {
locking.AddGLock(endpointsByNICprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *endpointsByNICRWMutex) NestedLock(i endpointsByNIClockNameIndex) {
locking.AddGLock(endpointsByNICprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *endpointsByNICRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(endpointsByNICprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *endpointsByNICRWMutex) NestedUnlock(i endpointsByNIClockNameIndex) {
m.mu.Unlock()
locking.DelGLock(endpointsByNICprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *endpointsByNICRWMutex) RLock() {
locking.AddGLock(endpointsByNICprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *endpointsByNICRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(endpointsByNICprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *endpointsByNICRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *endpointsByNICRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *endpointsByNICRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var endpointsByNICprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func endpointsByNICinitLockNames() {}
func init() {
endpointsByNICinitLockNames()
endpointsByNICprefixIndex = locking.NewMutexClass(reflect.TypeOf(endpointsByNICRWMutex{}), endpointsByNIClockNames)
}

View File

@@ -0,0 +1,599 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package gro implements generic receive offload.
package gro
import (
"bytes"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// TODO(b/256037250): Enable by default.
// TODO(b/256037250): We parse headers here. We should save those headers in
// PacketBuffers so they don't have to be re-parsed later.
// TODO(b/256037250): I still see the occasional SACK block in the zero-loss
// benchmark, which should not happen.
// TODO(b/256037250): Some dispatchers, e.g. XDP and RecvMmsg, can receive
// multiple packets at a time. Even if the GRO interval is 0, there is an
// opportunity for coalescing.
// TODO(b/256037250): We're doing some header parsing here, which presents the
// opportunity to skip it later.
// TODO(b/256037250): Can we pass a packet list up the stack too?
const (
// groNBuckets is the number of GRO buckets.
groNBuckets = 8
groNBucketsMask = groNBuckets - 1
// groBucketSize is the size of each GRO bucket.
groBucketSize = 8
// groMaxPacketSize is the maximum size of a GRO'd packet.
groMaxPacketSize = 1 << 16 // 64 KiB.
)
// A groBucket holds packets that are undergoing GRO.
type groBucket struct {
// count is the number of packets in the bucket.
count int
// packets is the linked list of packets.
packets groPacketList
// packetsPrealloc and allocIdxs are used to preallocate and reuse
// groPacket structs and avoid allocation.
packetsPrealloc [groBucketSize]groPacket
allocIdxs [groBucketSize]int
}
func (gb *groBucket) full() bool {
return gb.count == groBucketSize
}
// insert inserts pkt into the bucket.
func (gb *groBucket) insert(pkt *stack.PacketBuffer, ipHdr []byte, tcpHdr header.TCP) {
groPkt := &gb.packetsPrealloc[gb.allocIdxs[gb.count]]
*groPkt = groPacket{
pkt: pkt,
ipHdr: ipHdr,
tcpHdr: tcpHdr,
initialLength: pkt.Data().Size(), // pkt.Data() contains network header.
idx: groPkt.idx,
}
gb.count++
gb.packets.PushBack(groPkt)
}
// removeOldest removes the oldest packet from gb and returns the contained
// PacketBuffer. gb must not be empty.
func (gb *groBucket) removeOldest() *stack.PacketBuffer {
pkt := gb.packets.Front()
gb.packets.Remove(pkt)
gb.count--
gb.allocIdxs[gb.count] = pkt.idx
ret := pkt.pkt
pkt.reset()
return ret
}
// removeOne removes a packet from gb. It also resets pkt's mutable fields.
func (gb *groBucket) removeOne(pkt *groPacket) {
gb.packets.Remove(pkt)
gb.count--
gb.allocIdxs[gb.count] = pkt.idx
pkt.reset()
}
// findGROPacket4 returns the groPkt that matches ipHdr and tcpHdr, or nil if
// none exists. It also returns whether the groPkt should be flushed based on
// differences between the two headers.
func (gb *groBucket) findGROPacket4(pkt *stack.PacketBuffer, ipHdr header.IPv4, tcpHdr header.TCP) (*groPacket, bool) {
for groPkt := gb.packets.Front(); groPkt != nil; groPkt = groPkt.Next() {
// Do the addresses match?
groIPHdr := header.IPv4(groPkt.ipHdr)
if ipHdr.SourceAddress() != groIPHdr.SourceAddress() || ipHdr.DestinationAddress() != groIPHdr.DestinationAddress() {
continue
}
// Do the ports match?
if tcpHdr.SourcePort() != groPkt.tcpHdr.SourcePort() || tcpHdr.DestinationPort() != groPkt.tcpHdr.DestinationPort() {
continue
}
// We've found a packet of the same flow.
// IP checks.
TOS, _ := ipHdr.TOS()
groTOS, _ := groIPHdr.TOS()
if ipHdr.TTL() != groIPHdr.TTL() || TOS != groTOS {
return groPkt, true
}
// TCP checks.
if shouldFlushTCP(groPkt, tcpHdr) {
return groPkt, true
}
// There's an upper limit on coalesced packet size.
if pkt.Data().Size()-header.IPv4MinimumSize-int(tcpHdr.DataOffset())+groPkt.pkt.Data().Size() >= groMaxPacketSize {
return groPkt, true
}
return groPkt, false
}
return nil, false
}
// findGROPacket6 returns the groPkt that matches ipHdr and tcpHdr, or nil if
// none exists. It also returns whether the groPkt should be flushed based on
// differences between the two headers.
func (gb *groBucket) findGROPacket6(pkt *stack.PacketBuffer, ipHdr header.IPv6, tcpHdr header.TCP) (*groPacket, bool) {
for groPkt := gb.packets.Front(); groPkt != nil; groPkt = groPkt.Next() {
// Do the addresses match?
groIPHdr := header.IPv6(groPkt.ipHdr)
if ipHdr.SourceAddress() != groIPHdr.SourceAddress() || ipHdr.DestinationAddress() != groIPHdr.DestinationAddress() {
continue
}
// Need to check that headers are the same except:
// - Traffic class, a difference of which causes a flush.
// - Hop limit, a difference of which causes a flush.
// - Length, which is checked later.
// - Version, which is checked by an earlier call to IsValid().
trafficClass, flowLabel := ipHdr.TOS()
groTrafficClass, groFlowLabel := groIPHdr.TOS()
if flowLabel != groFlowLabel || ipHdr.NextHeader() != groIPHdr.NextHeader() {
continue
}
// Unlike IPv4, IPv6 packets with extension headers can be coalesced.
if !bytes.Equal(ipHdr[header.IPv6MinimumSize:], groIPHdr[header.IPv6MinimumSize:]) {
continue
}
// Do the ports match?
if tcpHdr.SourcePort() != groPkt.tcpHdr.SourcePort() || tcpHdr.DestinationPort() != groPkt.tcpHdr.DestinationPort() {
continue
}
// We've found a packet of the same flow.
// TCP checks.
if shouldFlushTCP(groPkt, tcpHdr) {
return groPkt, true
}
// Do the traffic class and hop limit match?
if trafficClass != groTrafficClass || ipHdr.HopLimit() != groIPHdr.HopLimit() {
return groPkt, true
}
// This limit is artificial for IPv6 -- we could allow even
// larger packets via jumbograms.
if pkt.Data().Size()-len(ipHdr)-int(tcpHdr.DataOffset())+groPkt.pkt.Data().Size() >= groMaxPacketSize {
return groPkt, true
}
return groPkt, false
}
return nil, false
}
func (gb *groBucket) found(gd *GRO, groPkt *groPacket, flushGROPkt bool, pkt *stack.PacketBuffer, ipHdr []byte, tcpHdr header.TCP, updateIPHdr func([]byte, int)) {
// Flush groPkt or merge the packets.
pktSize := pkt.Data().Size()
flags := tcpHdr.Flags()
dataOff := tcpHdr.DataOffset()
tcpPayloadSize := pkt.Data().Size() - len(ipHdr) - int(dataOff)
if flushGROPkt {
// Flush the existing GRO packet.
pkt := groPkt.pkt
gb.removeOne(groPkt)
gd.handlePacket(pkt)
pkt.DecRef()
groPkt = nil
} else if groPkt != nil {
// Merge pkt in to GRO packet.
pkt.Data().TrimFront(len(ipHdr) + int(dataOff))
groPkt.pkt.Data().Merge(pkt.Data())
// Update the IP total length.
updateIPHdr(groPkt.ipHdr, tcpPayloadSize)
// Add flags from the packet to the GRO packet.
groPkt.tcpHdr.SetFlags(uint8(groPkt.tcpHdr.Flags() | (flags & (header.TCPFlagFin | header.TCPFlagPsh))))
pkt = nil
}
// Flush if the packet isn't the same size as the previous packets or
// if certain flags are set. The reason for checking size equality is:
// - If the packet is smaller than the others, this is likely the end
// of some message. Peers will send MSS-sized packets until they have
// insufficient data to do so.
// - If the packet is larger than the others, this packet is either
// malformed, a local GSO packet, or has already been handled by host
// GRO.
flush := header.TCPFlags(flags)&(header.TCPFlagUrg|header.TCPFlagPsh|header.TCPFlagRst|header.TCPFlagSyn|header.TCPFlagFin) != 0
flush = flush || tcpPayloadSize == 0
if groPkt != nil {
flush = flush || pktSize != groPkt.initialLength
}
switch {
case flush && groPkt != nil:
// A merge occurred and we need to flush groPkt.
pkt := groPkt.pkt
gb.removeOne(groPkt)
gd.handlePacket(pkt)
pkt.DecRef()
case flush && groPkt == nil:
// No merge occurred and the incoming packet needs to be flushed.
gd.handlePacket(pkt)
case !flush && groPkt == nil:
// New flow and we don't need to flush. Insert pkt into GRO.
if gb.full() {
// Head is always the oldest packet.
toFlush := gb.removeOldest()
gb.insert(pkt.IncRef(), ipHdr, tcpHdr)
gd.handlePacket(toFlush)
toFlush.DecRef()
} else {
gb.insert(pkt.IncRef(), ipHdr, tcpHdr)
}
default:
// A merge occurred and we don't need to flush anything.
}
}
// A groPacket is a packet undergoing GRO. It may be several packets coalesced
// together.
type groPacket struct {
// groPacketEntry is an intrusive list.
groPacketEntry
// pkt is the coalesced packet.
pkt *stack.PacketBuffer
// ipHdr is the IP (v4 or v6) header for the coalesced packet.
ipHdr []byte
// tcpHdr is the TCP header for the coalesced packet.
tcpHdr header.TCP
// initialLength is the length of the first packet in the flow. It is
// used as a best-effort guess at MSS: senders will send MSS-sized
// packets until they run out of data, so we coalesce as long as
// packets are the same size.
initialLength int
// idx is the groPacket's index in its bucket packetsPrealloc. It is
// immutable.
idx int
}
// reset resets all mutable fields of the groPacket.
func (pk *groPacket) reset() {
*pk = groPacket{
idx: pk.idx,
}
}
// payloadSize is the payload size of the coalesced packet, which does not
// include the network or transport headers.
func (pk *groPacket) payloadSize() int {
return pk.pkt.Data().Size() - len(pk.ipHdr) - int(pk.tcpHdr.DataOffset())
}
// GRO coalesces incoming packets to increase throughput.
type GRO struct {
enabled bool
buckets [groNBuckets]groBucket
Dispatcher stack.NetworkDispatcher
}
// Init initializes GRO.
func (gd *GRO) Init(enabled bool) {
gd.enabled = enabled
for i := range gd.buckets {
bucket := &gd.buckets[i]
for j := range bucket.packetsPrealloc {
bucket.allocIdxs[j] = j
bucket.packetsPrealloc[j].idx = j
}
}
}
// Enqueue queues pkt in GRO. This does not flush packets; Flush() must be
// called explicitly for that.
//
// pkt.NetworkProtocolNumber and pkt.RXChecksumValidated must be set.
func (gd *GRO) Enqueue(pkt *stack.PacketBuffer) {
if !gd.enabled {
gd.handlePacket(pkt)
return
}
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
gd.dispatch4(pkt)
case header.IPv6ProtocolNumber:
gd.dispatch6(pkt)
default:
gd.handlePacket(pkt)
}
}
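// Illustrative sketch (not part of the original file): how a link endpoint
// might drive GRO. It assumes a stack.NetworkDispatcher (disp) and a batch of
// received *stack.PacketBuffer values; the helper name is hypothetical.
func exampleGRODispatch(disp stack.NetworkDispatcher, pkts []*stack.PacketBuffer) {
	var gd GRO
	gd.Init(true /* enabled */)
	gd.Dispatcher = disp
	for _, pkt := range pkts {
		// Per Enqueue's contract, NetworkProtocolNumber and
		// RXChecksumValidated must already be set on pkt.
		gd.Enqueue(pkt)
	}
	// Coalesced packets are held in buckets until Flush delivers them.
	gd.Flush()
}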
func (gd *GRO) dispatch4(pkt *stack.PacketBuffer) {
// Immediately get the IPv4 and TCP headers. We need a way to hash the
// packet into its bucket, which requires addresses and ports. Linux
// simply gets a hash passed by hardware, but we're not so lucky.
// We only GRO TCP packets. The check for the transport protocol number
// is done below so that we can PullUp both the IP and TCP headers
// together.
hdrBytes, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize)
if !ok {
gd.handlePacket(pkt)
return
}
ipHdr := header.IPv4(hdrBytes)
// We don't handle fragments. That should be the vast majority of
// traffic, and simplifies handling.
if ipHdr.FragmentOffset() != 0 || ipHdr.Flags()&header.IPv4FlagMoreFragments != 0 {
gd.handlePacket(pkt)
return
}
// We only handle TCP packets without IP options.
if ipHdr.HeaderLength() != header.IPv4MinimumSize || tcpip.TransportProtocolNumber(ipHdr.Protocol()) != header.TCPProtocolNumber {
gd.handlePacket(pkt)
return
}
tcpHdr := header.TCP(hdrBytes[header.IPv4MinimumSize:])
ipHdr = ipHdr[:header.IPv4MinimumSize]
dataOff := tcpHdr.DataOffset()
if dataOff < header.TCPMinimumSize {
// Malformed packet: will be handled further up the stack.
gd.handlePacket(pkt)
return
}
hdrBytes, ok = pkt.Data().PullUp(header.IPv4MinimumSize + int(dataOff))
if !ok {
// Malformed packet: will be handled further up the stack.
gd.handlePacket(pkt)
return
}
tcpHdr = header.TCP(hdrBytes[header.IPv4MinimumSize:])
// If either checksum is bad, flush the packet. Since we don't know
// what bits were flipped, we can't identify this packet with a flow.
if !pkt.RXChecksumValidated {
if !ipHdr.IsValid(pkt.Data().Size()) || !ipHdr.IsChecksumValid() {
gd.handlePacket(pkt)
return
}
payloadChecksum := pkt.Data().ChecksumAtOffset(header.IPv4MinimumSize + int(dataOff))
tcpPayloadSize := pkt.Data().Size() - header.IPv4MinimumSize - int(dataOff)
if !tcpHdr.IsChecksumValid(ipHdr.SourceAddress(), ipHdr.DestinationAddress(), payloadChecksum, uint16(tcpPayloadSize)) {
gd.handlePacket(pkt)
return
}
// We've validated the checksum, no reason for others to do it
// again.
pkt.RXChecksumValidated = true
}
// Now we can get the bucket for the packet.
bucket := &gd.buckets[gd.bucketForPacket4(ipHdr, tcpHdr)&groNBucketsMask]
groPkt, flushGROPkt := bucket.findGROPacket4(pkt, ipHdr, tcpHdr)
bucket.found(gd, groPkt, flushGROPkt, pkt, ipHdr, tcpHdr, updateIPv4Hdr)
}
func (gd *GRO) dispatch6(pkt *stack.PacketBuffer) {
// Immediately get the IPv6 and TCP headers. We need a way to hash the
// packet into its bucket, which requires addresses and ports. Linux
// simply gets a hash passed by hardware, but we're not so lucky.
hdrBytes, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
if !ok {
gd.handlePacket(pkt)
return
}
ipHdr := header.IPv6(hdrBytes)
// Getting the IP header (+ extension headers) size is a bit of a pain
// on IPv6.
transProto := tcpip.TransportProtocolNumber(ipHdr.NextHeader())
buf := pkt.Data().ToBuffer()
buf.TrimFront(header.IPv6MinimumSize)
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(transProto), buf)
ipHdrSize := int(header.IPv6MinimumSize)
for {
transProto = tcpip.TransportProtocolNumber(it.NextHeaderIdentifier())
extHdr, done, err := it.Next()
if err != nil {
gd.handlePacket(pkt)
return
}
if done {
break
}
switch extHdr.(type) {
// We can GRO these, so just skip over them.
case header.IPv6HopByHopOptionsExtHdr:
case header.IPv6RoutingExtHdr:
case header.IPv6DestinationOptionsExtHdr:
default:
// This is either a TCP header or something we can't handle.
ipHdrSize = int(it.HeaderOffset())
done = true
}
extHdr.Release()
if done {
break
}
}
hdrBytes, ok = pkt.Data().PullUp(ipHdrSize + header.TCPMinimumSize)
if !ok {
gd.handlePacket(pkt)
return
}
ipHdr = header.IPv6(hdrBytes[:ipHdrSize])
// We only handle TCP packets.
if transProto != header.TCPProtocolNumber {
gd.handlePacket(pkt)
return
}
tcpHdr := header.TCP(hdrBytes[ipHdrSize:])
dataOff := tcpHdr.DataOffset()
if dataOff < header.TCPMinimumSize {
// Malformed packet: will be handled further up the stack.
gd.handlePacket(pkt)
return
}
hdrBytes, ok = pkt.Data().PullUp(ipHdrSize + int(dataOff))
if !ok {
// Malformed packet: will be handled further up the stack.
gd.handlePacket(pkt)
return
}
tcpHdr = header.TCP(hdrBytes[ipHdrSize:])
// If either checksum is bad, flush the packet. Since we don't know
// what bits were flipped, we can't identify this packet with a flow.
if !pkt.RXChecksumValidated {
if !ipHdr.IsValid(pkt.Data().Size()) {
gd.handlePacket(pkt)
return
}
payloadChecksum := pkt.Data().ChecksumAtOffset(ipHdrSize + int(dataOff))
tcpPayloadSize := pkt.Data().Size() - ipHdrSize - int(dataOff)
if !tcpHdr.IsChecksumValid(ipHdr.SourceAddress(), ipHdr.DestinationAddress(), payloadChecksum, uint16(tcpPayloadSize)) {
gd.handlePacket(pkt)
return
}
// We've validated the checksum, no reason for others to do it
// again.
pkt.RXChecksumValidated = true
}
// Now we can get the bucket for the packet.
bucket := &gd.buckets[gd.bucketForPacket6(ipHdr, tcpHdr)&groNBucketsMask]
groPkt, flushGROPkt := bucket.findGROPacket6(pkt, ipHdr, tcpHdr)
bucket.found(gd, groPkt, flushGROPkt, pkt, ipHdr, tcpHdr, updateIPv6Hdr)
}
func (gd *GRO) bucketForPacket4(ipHdr header.IPv4, tcpHdr header.TCP) int {
// TODO(b/256037250): Use jenkins or checksum. Write a test to print
// distribution.
var sum int
srcAddr := ipHdr.SourceAddress()
for _, val := range srcAddr.AsSlice() {
sum += int(val)
}
dstAddr := ipHdr.DestinationAddress()
for _, val := range dstAddr.AsSlice() {
sum += int(val)
}
sum += int(tcpHdr.SourcePort())
sum += int(tcpHdr.DestinationPort())
return sum
}
func (gd *GRO) bucketForPacket6(ipHdr header.IPv6, tcpHdr header.TCP) int {
// TODO(b/256037250): Use jenkins or checksum. Write a test to print
// distribution.
var sum int
srcAddr := ipHdr.SourceAddress()
for _, val := range srcAddr.AsSlice() {
sum += int(val)
}
dstAddr := ipHdr.DestinationAddress()
for _, val := range dstAddr.AsSlice() {
sum += int(val)
}
sum += int(tcpHdr.SourcePort())
sum += int(tcpHdr.DestinationPort())
return sum
}
// Flush sends all packets up the stack.
func (gd *GRO) Flush() {
for i := range gd.buckets {
for groPkt := gd.buckets[i].packets.Front(); groPkt != nil; groPkt = groPkt.Next() {
pkt := groPkt.pkt
gd.buckets[i].removeOne(groPkt)
gd.handlePacket(pkt)
pkt.DecRef()
}
}
}
func (gd *GRO) handlePacket(pkt *stack.PacketBuffer) {
gd.Dispatcher.DeliverNetworkPacket(pkt.NetworkProtocolNumber, pkt)
}
// String implements fmt.Stringer.
func (gd *GRO) String() string {
ret := "GRO state: \n"
for i := range gd.buckets {
bucket := &gd.buckets[i]
ret += fmt.Sprintf("bucket %d: %d packets: ", i, bucket.count)
for groPkt := bucket.packets.Front(); groPkt != nil; groPkt = groPkt.Next() {
ret += fmt.Sprintf("%d, ", groPkt.pkt.Data().Size())
}
ret += "\n"
}
return ret
}
// shouldFlushTCP returns whether the TCP headers indicate that groPkt should
// be flushed.
func shouldFlushTCP(groPkt *groPacket, tcpHdr header.TCP) bool {
flags := tcpHdr.Flags()
groPktFlags := groPkt.tcpHdr.Flags()
dataOff := tcpHdr.DataOffset()
if flags&header.TCPFlagCwr != 0 || // Is congestion control occurring?
(flags^groPktFlags)&^(header.TCPFlagCwr|header.TCPFlagFin|header.TCPFlagPsh) != 0 || // Do the flags differ besides CWR, FIN, and PSH?
tcpHdr.AckNumber() != groPkt.tcpHdr.AckNumber() || // Do the ACKs match?
dataOff != groPkt.tcpHdr.DataOffset() || // Are the TCP headers the same length?
groPkt.tcpHdr.SequenceNumber()+uint32(groPkt.payloadSize()) != tcpHdr.SequenceNumber() { // Does the incoming packet match the expected sequence number?
return true
}
// The options, including timestamps, must be identical.
return !bytes.Equal(tcpHdr[header.TCPMinimumSize:], groPkt.tcpHdr[header.TCPMinimumSize:])
}
func updateIPv4Hdr(ipHdrBytes []byte, newBytes int) {
ipHdr := header.IPv4(ipHdrBytes)
ipHdr.SetTotalLength(ipHdr.TotalLength() + uint16(newBytes))
}
func updateIPv6Hdr(ipHdrBytes []byte, newBytes int) {
ipHdr := header.IPv6(ipHdrBytes)
ipHdr.SetPayloadLength(ipHdr.PayloadLength() + uint16(newBytes))
}

View File

@@ -0,0 +1,239 @@
package gro
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type groPacketElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (groPacketElementMapper) linkerFor(elem *groPacket) *groPacket { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type groPacketList struct {
head *groPacket
tail *groPacket
}
// Reset resets list l to the empty state.
func (l *groPacketList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
//
//go:nosplit
func (l *groPacketList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
//
//go:nosplit
func (l *groPacketList) Front() *groPacket {
return l.head
}
// Back returns the last element of list l or nil.
//
//go:nosplit
func (l *groPacketList) Back() *groPacket {
return l.tail
}
// Len returns the number of elements in the list.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *groPacketList) Len() (count int) {
for e := l.Front(); e != nil; e = (groPacketElementMapper{}.linkerFor(e)).Next() {
count++
}
return count
}
// PushFront inserts the element e at the front of list l.
//
//go:nosplit
func (l *groPacketList) PushFront(e *groPacket) {
linker := groPacketElementMapper{}.linkerFor(e)
linker.SetNext(l.head)
linker.SetPrev(nil)
if l.head != nil {
groPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushFrontList inserts list m at the start of list l, emptying m.
//
//go:nosplit
func (l *groPacketList) PushFrontList(m *groPacketList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
groPacketElementMapper{}.linkerFor(l.head).SetPrev(m.tail)
groPacketElementMapper{}.linkerFor(m.tail).SetNext(l.head)
l.head = m.head
}
m.head = nil
m.tail = nil
}
// PushBack inserts the element e at the back of list l.
//
//go:nosplit
func (l *groPacketList) PushBack(e *groPacket) {
linker := groPacketElementMapper{}.linkerFor(e)
linker.SetNext(nil)
linker.SetPrev(l.tail)
if l.tail != nil {
groPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
//
//go:nosplit
func (l *groPacketList) PushBackList(m *groPacketList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
groPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
groPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
//
//go:nosplit
func (l *groPacketList) InsertAfter(b, e *groPacket) {
bLinker := groPacketElementMapper{}.linkerFor(b)
eLinker := groPacketElementMapper{}.linkerFor(e)
a := bLinker.Next()
eLinker.SetNext(a)
eLinker.SetPrev(b)
bLinker.SetNext(e)
if a != nil {
groPacketElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
//
//go:nosplit
func (l *groPacketList) InsertBefore(a, e *groPacket) {
aLinker := groPacketElementMapper{}.linkerFor(a)
eLinker := groPacketElementMapper{}.linkerFor(e)
b := aLinker.Prev()
eLinker.SetNext(a)
eLinker.SetPrev(b)
aLinker.SetPrev(e)
if b != nil {
groPacketElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
//
//go:nosplit
func (l *groPacketList) Remove(e *groPacket) {
linker := groPacketElementMapper{}.linkerFor(e)
prev := linker.Prev()
next := linker.Next()
if prev != nil {
groPacketElementMapper{}.linkerFor(prev).SetNext(next)
} else if l.head == e {
l.head = next
}
if next != nil {
groPacketElementMapper{}.linkerFor(next).SetPrev(prev)
} else if l.tail == e {
l.tail = prev
}
linker.SetNext(nil)
linker.SetPrev(nil)
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type groPacketEntry struct {
next *groPacket
prev *groPacket
}
// Next returns the entry that follows e in the list.
//
//go:nosplit
func (e *groPacketEntry) Next() *groPacket {
return e.next
}
// Prev returns the entry that precedes e in the list.
//
//go:nosplit
func (e *groPacketEntry) Prev() *groPacket {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
//
//go:nosplit
func (e *groPacketEntry) SetNext(elem *groPacket) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
//
//go:nosplit
func (e *groPacketEntry) SetPrev(elem *groPacket) {
e.prev = elem
}
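// Illustrative sketch (not part of the generated file): basic use of the
// intrusive list. Because groPacket embeds groPacketEntry, insertion and
// removal allocate nothing. The helper name is hypothetical.
func exampleListUsage(pkts []*groPacket) {
	var l groPacketList
	for _, p := range pkts {
		l.PushBack(p)
	}
	// Drain the list front to back, capturing Next before Remove clears the
	// element's links.
	for e := l.Front(); e != nil; {
		next := e.Next()
		l.Remove(e)
		e = next
	}
}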

View File

@@ -0,0 +1,70 @@
// automatically generated by stateify.
package gro
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
func (l *groPacketList) StateTypeName() string {
return "pkg/tcpip/stack/gro.groPacketList"
}
func (l *groPacketList) StateFields() []string {
return []string{
"head",
"tail",
}
}
func (l *groPacketList) beforeSave() {}
// +checklocksignore
func (l *groPacketList) StateSave(stateSinkObject state.Sink) {
l.beforeSave()
stateSinkObject.Save(0, &l.head)
stateSinkObject.Save(1, &l.tail)
}
func (l *groPacketList) afterLoad(context.Context) {}
// +checklocksignore
func (l *groPacketList) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &l.head)
stateSourceObject.Load(1, &l.tail)
}
func (e *groPacketEntry) StateTypeName() string {
return "pkg/tcpip/stack/gro.groPacketEntry"
}
func (e *groPacketEntry) StateFields() []string {
return []string{
"next",
"prev",
}
}
func (e *groPacketEntry) beforeSave() {}
// +checklocksignore
func (e *groPacketEntry) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
stateSinkObject.Save(0, &e.next)
stateSinkObject.Save(1, &e.prev)
}
func (e *groPacketEntry) afterLoad(context.Context) {}
// +checklocksignore
func (e *groPacketEntry) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.next)
stateSourceObject.Load(1, &e.prev)
}
func init() {
state.Register((*groPacketList)(nil))
state.Register((*groPacketEntry)(nil))
}

View File

@@ -0,0 +1,40 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type headerType ."; DO NOT EDIT.
package stack
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[virtioNetHeader-0]
_ = x[linkHeader-1]
_ = x[networkHeader-2]
_ = x[transportHeader-3]
_ = x[numHeaderType-4]
}
const _headerType_name = "virtioNetHeaderlinkHeadernetworkHeadertransportHeadernumHeaderType"
var _headerType_index = [...]uint8{0, 10, 23, 38, 51}
func (i headerType) String() string {
if i < 0 || i >= headerType(len(_headerType_index)-1) {
return "headerType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _headerType_name[_headerType_index[i]:_headerType_index[i+1]]
}

View File

@@ -0,0 +1,41 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type Hook ."; DO NOT EDIT.
package stack
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Prerouting-0]
_ = x[Input-1]
_ = x[Forward-2]
_ = x[Output-3]
_ = x[Postrouting-4]
_ = x[NumHooks-5]
}
const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks"
var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47}
func (i Hook) String() string {
if i >= Hook(len(_Hook_index)-1) {
return "Hook(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Hook_name[_Hook_index[i]:_Hook_index[i+1]]
}

View File

@@ -0,0 +1,75 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// icmpLimit is the default maximum number of ICMP messages permitted by this
// rate limiter.
icmpLimit = 1000
// icmpBurst is the default number of ICMP messages that can be sent in a single
// burst.
icmpBurst = 50
)
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
//
// +stateify savable
type ICMPRateLimiter struct {
// TODO(b/341946753): Restore when netstack is savable.
limiter *rate.Limiter `state:"nosave"`
clock tcpip.Clock
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
// at which ICMP messages are generated by the stack. The returned limiter
// does not apply limits to any ICMP types by default.
func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter {
return &ICMPRateLimiter{
clock: clock,
limiter: rate.NewLimiter(icmpLimit, icmpBurst),
}
}
// SetLimit sets a new Limit for the limiter.
func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) {
l.limiter.SetLimitAt(l.clock.Now(), limit)
}
// Limit returns the maximum overall event rate.
func (l *ICMPRateLimiter) Limit() rate.Limit {
return l.limiter.Limit()
}
// SetBurst sets a new burst size for the limiter.
func (l *ICMPRateLimiter) SetBurst(burst int) {
l.limiter.SetBurstAt(l.clock.Now(), burst)
}
// Burst returns the maximum burst size.
func (l *ICMPRateLimiter) Burst() int {
return l.limiter.Burst()
}
// Allow reports whether one ICMP message may be sent now.
func (l *ICMPRateLimiter) Allow() bool {
return l.limiter.AllowN(l.clock.Now(), 1)
}
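// Illustrative sketch (not part of the original file): gating ICMP error
// generation on the limiter. clock is assumed to be the stack's tcpip.Clock;
// sendICMPError stands in for a hypothetical send routine.
func exampleICMPRateLimit(clock tcpip.Clock, sendICMPError func()) {
	limiter := NewICMPRateLimiter(clock)
	if limiter.Allow() {
		sendICMPError()
	}
	// Both the sustained rate and the burst size can be tuned at runtime.
	limiter.SetLimit(rate.Limit(500))
	limiter.SetBurst(25)
}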

View File

@@ -0,0 +1,717 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"context"
"fmt"
"math/rand"
"reflect"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// TableID identifies a specific table.
type TableID int
// Each value identifies a specific table.
const (
NATID TableID = iota
MangleID
FilterID
NumTables
)
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
// reaperDelay is how long to wait before starting to reap connections.
const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables(clock tcpip.Clock, rand *rand.Rand) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
Rules: []Rule{
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
},
MangleID: {
Rules: []Rule{
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
Underflows: [NumHooks]int{
Prerouting: 0,
Input: HookUnset,
Forward: HookUnset,
Output: 1,
Postrouting: HookUnset,
},
},
FilterID: {
Rules: []Rule{
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
Input: 0,
Forward: 1,
Output: 2,
Postrouting: HookUnset,
},
Underflows: [NumHooks]int{
Prerouting: HookUnset,
Input: 0,
Forward: 1,
Output: 2,
Postrouting: HookUnset,
},
},
},
v6Tables: [NumTables]Table{
NATID: {
Rules: []Rule{
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
},
MangleID: {
Rules: []Rule{
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
Underflows: [NumHooks]int{
Prerouting: 0,
Input: HookUnset,
Forward: HookUnset,
Output: 1,
Postrouting: HookUnset,
},
},
FilterID: {
Rules: []Rule{
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
Input: 0,
Forward: 1,
Output: 2,
Postrouting: HookUnset,
},
Underflows: [NumHooks]int{
Prerouting: HookUnset,
Input: 0,
Forward: 1,
Output: 2,
Postrouting: HookUnset,
},
},
},
connections: ConnTrack{
seed: rand.Uint32(),
clock: clock,
rand: rand,
},
}
}
// EmptyFilterTable returns a Table with no rules and the filter table chains
// mapped to HookUnset.
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
Postrouting: HookUnset,
},
Underflows: [NumHooks]int{
Prerouting: HookUnset,
Postrouting: HookUnset,
},
}
}
// EmptyNATTable returns a Table with no rules and the NAT table chains
// mapped to HookUnset.
func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
BuiltinChains: [NumHooks]int{
Forward: HookUnset,
},
Underflows: [NumHooks]int{
Forward: HookUnset,
},
}
}
// GetTable returns a table with the given id and IP version. It panics when an
// invalid id is provided.
func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
it.mu.RLock()
defer it.mu.RUnlock()
return it.getTableRLocked(id, ipv6)
}
// +checklocksread:it.mu
func (it *IPTables) getTableRLocked(id TableID, ipv6 bool) Table {
if ipv6 {
return it.v6Tables[id]
}
return it.v4Tables[id]
}
// ReplaceTable replaces or inserts table by name. It panics when an invalid id
// is provided.
func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) {
it.replaceTable(id, table, ipv6, false /* force */)
}
// ForceReplaceTable replaces or inserts table by name. It panics when an invalid id
// is provided. It enables iptables even when the inserted table is all
// conditionless ACCEPT, skipping our optimization that disables iptables until
// they're modified.
func (it *IPTables) ForceReplaceTable(id TableID, table Table, ipv6 bool) {
it.replaceTable(id, table, ipv6, true /* force */)
}
func (it *IPTables) replaceTable(id TableID, table Table, ipv6, force bool) {
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
// reaper.
if !it.modified {
// Don't do anything if the table is identical.
if ((ipv6 && reflect.DeepEqual(table, it.v6Tables[id])) || (!ipv6 && reflect.DeepEqual(table, it.v4Tables[id]))) && !force {
return
}
it.connections.init()
it.startReaper(reaperDelay)
}
it.modified = true
if ipv6 {
it.v6Tables[id] = table
} else {
it.v4Tables[id] = table
}
}
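// Illustrative sketch (not part of the original file): installing a minimal
// IPv4 filter table shaped like the one built by DefaultTables above (one
// ACCEPT rule per chain, plus a trailing error target). The helper name is
// hypothetical.
func exampleInstallFilterTable(it *IPTables) {
	table := Table{
		Rules: []Rule{
			{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
			{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
			{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
			{Filter: EmptyFilter4(), Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
		},
		BuiltinChains: [NumHooks]int{
			Prerouting:  HookUnset,
			Input:       0,
			Forward:     1,
			Output:      2,
			Postrouting: HookUnset,
		},
		Underflows: [NumHooks]int{
			Prerouting:  HookUnset,
			Input:       0,
			Forward:     1,
			Output:      2,
			Postrouting: HookUnset,
		},
	}
	// ReplaceTable marks iptables as modified, which enables rule traversal
	// on the hot path and starts the conntrack reaper.
	it.ReplaceTable(FilterID, table, false /* ipv6 */)
}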
// A chainVerdict is what a table decides should be done with a packet.
type chainVerdict int
const (
// chainAccept indicates the packet should continue through netstack.
chainAccept chainVerdict = iota
// chainDrop indicates the packet should be dropped.
chainDrop
// chainReturn indicates the packet should return to the calling chain
// or the underflow rule of a builtin chain.
chainReturn
)
type checkTable struct {
fn checkTableFn
tableID TableID
table Table
}
// shouldSkipOrPopulateTables returns true iff IPTables should be skipped.
//
// If IPTables should not be skipped, tables will be updated with the
// specified table.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// it does not allocate. We check recursively for heap allocations, but not for:
// - Stack splitting, which can allocate.
// - Calls to interfaces, which can allocate.
// - Calls to dynamic functions, which can allocate.
//
// +checkescape:hard
func (it *IPTables) shouldSkipOrPopulateTables(tables []checkTable, pkt *PacketBuffer) bool {
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
default:
// IPTables only supports IPv4/IPv6.
return true
}
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
return true
}
for i := range tables {
table := &tables[i]
table.table = it.getTableRLocked(table.tableID, pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber)
}
return false
}
// CheckPrerouting performs the prerouting hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
tables := [...]checkTable{
{
fn: check,
tableID: MangleID,
},
{
fn: checkNAT,
tableID: NATID,
},
}
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
pkt.tuple = it.connections.getConnAndUpdate(pkt, false /* skipChecksumValidation */)
for _, table := range tables {
if !table.fn(it, table.table, Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) {
return false
}
}
return true
}
// CheckInput performs the input hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
tables := [...]checkTable{
{
fn: checkNAT,
tableID: NATID,
},
{
fn: check,
tableID: FilterID,
},
}
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
for _, table := range tables {
if !table.fn(it, table.table, Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) {
return false
}
}
if t := pkt.tuple; t != nil {
pkt.tuple = nil
return t.conn.finalize()
}
return true
}
// CheckForward performs the forward hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
tables := [...]checkTable{
{
fn: check,
tableID: FilterID,
},
}
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
for _, table := range tables {
if !table.fn(it, table.table, Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) {
return false
}
}
return true
}
// CheckOutput performs the output hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
tables := [...]checkTable{
{
fn: check,
tableID: MangleID,
},
{
fn: checkNAT,
tableID: NATID,
},
{
fn: check,
tableID: FilterID,
},
}
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
// We don't need to validate the checksum in the Output path: we can assume
// we calculate it correctly, plus checksumming may be deferred due to GSO.
pkt.tuple = it.connections.getConnAndUpdate(pkt, true /* skipChecksumValidation */)
for _, table := range tables {
if !table.fn(it, table.table, Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) {
return false
}
}
return true
}
// CheckPostrouting performs the postrouting hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
tables := [...]checkTable{
{
fn: check,
tableID: MangleID,
},
{
fn: checkNAT,
tableID: NATID,
},
}
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
for _, table := range tables {
if !table.fn(it, table.table, Postrouting, pkt, r, addressEP, "" /* inNicName */, outNicName) {
return false
}
}
if t := pkt.tuple; t != nil {
pkt.tuple = nil
return t.conn.finalize()
}
return true
}
// Note: this used to omit the *IPTables parameter, but doing so caused
// unnecessary allocations.
type checkTableFn func(it *IPTables, table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool
func checkNAT(it *IPTables, table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
return it.checkNAT(table, hook, pkt, r, addressEP, inNicName, outNicName)
}
// checkNAT runs the packet through the NAT table.
//
// See check.
func (it *IPTables) checkNAT(table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
t := pkt.tuple
if t != nil && t.conn.handlePacket(pkt, hook, r) {
return true
}
if !it.check(table, hook, pkt, r, addressEP, inNicName, outNicName) {
return false
}
if t == nil {
return true
}
dnat, natDone := func() (bool, bool) {
switch hook {
case Prerouting, Output:
return true, pkt.dnatDone
case Input, Postrouting:
return false, pkt.snatDone
case Forward:
panic("should not attempt NAT in forwarding")
default:
panic(fmt.Sprintf("unhandled hook = %d", hook))
}
}()
// Make sure the connection is NATed.
//
// If the packet was already NATed, the connection must be NATed.
if !natDone {
t.conn.maybePerformNoopNAT(pkt, hook, r, dnat)
}
return true
}
func check(it *IPTables, table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
return it.check(table, hook, pkt, r, addressEP, inNicName, outNicName)
}
// check runs the packet through the rules in the specified table for the
// hook. It returns true if the packet should continue to traverse through the
// network stack or tables, or false when it must be dropped.
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) check(table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
return true
// The Drop verdict is final.
case chainDrop:
return false
case chainReturn:
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v {
case RuleAccept:
return true
case RuleDrop:
return false
case RuleJump, RuleReturn:
panic("Underflows should only return RuleAccept or RuleDrop.")
default:
panic(fmt.Sprintf("Unknown verdict: %d", v))
}
default:
panic(fmt.Sprintf("Unknown verdict %v.", verdict))
}
}
// beforeSave is invoked by stateify.
func (it *IPTables) beforeSave() {
// Ensure the reaper exits cleanly.
it.reaper.Stop()
// Prevent others from modifying the connection table.
it.connections.mu.Lock()
}
// afterLoad is invoked by stateify.
func (it *IPTables) afterLoad(context.Context) {
it.startReaper(reaperDelay)
}
// startReaper periodically reaps timed out connections.
func (it *IPTables) startReaper(interval time.Duration) {
bucket := 0
it.reaper = it.connections.clock.AfterFunc(interval, func() {
bucket, interval = it.connections.reapUnused(bucket, interval)
it.reaper.Reset(interval)
})
}
// Preconditions:
// - pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
case RuleDrop:
return chainDrop
case RuleReturn:
return chainReturn
case RuleJump:
// "Jumping" to the next rule just means we're
// continuing on down the list.
if jumpTo == ruleIdx+1 {
ruleIdx++
continue
}
switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
return chainDrop
case chainReturn:
ruleIdx++
continue
default:
panic(fmt.Sprintf("Unknown verdict: %d", verdict))
}
default:
panic(fmt.Sprintf("Unknown verdict: %d", verdict))
}
}
// We got through the entire table without a decision. Default to DROP
// for safety.
return chainDrop
}
// Preconditions:
// - pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
if !rule.Filter.match(pkt, hook, inNicName, outNicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName)
if hotdrop {
return RuleDrop, 0
}
if !matches {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
}
// All the matchers matched, so run the target.
return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return tcpip.Address{}, 0, &tcpip.ErrNotConnected{}
}
return it.connections.originalDst(epID, netProto, transProto)
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type ipTablesRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var ipTableslockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type ipTableslockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *ipTablesRWMutex) Lock() {
locking.AddGLock(ipTablesprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *ipTablesRWMutex) NestedLock(i ipTableslockNameIndex) {
locking.AddGLock(ipTablesprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *ipTablesRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(ipTablesprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *ipTablesRWMutex) NestedUnlock(i ipTableslockNameIndex) {
m.mu.Unlock()
locking.DelGLock(ipTablesprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *ipTablesRWMutex) RLock() {
locking.AddGLock(ipTablesprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *ipTablesRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(ipTablesprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *ipTablesRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *ipTablesRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *ipTablesRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var ipTablesprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func ipTablesinitLockNames() {}
func init() {
ipTablesinitLockNames()
ipTablesprefixIndex = locking.NewMutexClass(reflect.TypeOf(ipTablesRWMutex{}), ipTableslockNames)
}

View File

@@ -0,0 +1,493 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"math"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// AcceptTarget accepts packets.
//
// +stateify savable
type AcceptTarget struct {
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
// DropTarget drops packets.
//
// +stateify savable
type DropTarget struct {
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
// RejectIPv4WithHandler handles rejecting a packet.
type RejectIPv4WithHandler interface {
// SendRejectionError sends an error packet in response to the packet.
SendRejectionError(pkt *PacketBuffer, rejectWith RejectIPv4WithICMPType, inputHook bool) tcpip.Error
}
// RejectIPv4WithICMPType indicates the type of ICMP error that should be sent.
type RejectIPv4WithICMPType int
// The types of errors that may be returned when rejecting IPv4 packets.
const (
_ RejectIPv4WithICMPType = iota
RejectIPv4WithICMPNetUnreachable
RejectIPv4WithICMPHostUnreachable
RejectIPv4WithICMPPortUnreachable
RejectIPv4WithICMPNetProhibited
RejectIPv4WithICMPHostProhibited
RejectIPv4WithICMPAdminProhibited
)
// RejectIPv4Target drops packets and sends back an error packet in response to the
// matched packet.
//
// +stateify savable
type RejectIPv4Target struct {
Handler RejectIPv4WithHandler
RejectWith RejectIPv4WithICMPType
}
// Action implements Target.Action.
func (rt *RejectIPv4Target) Action(pkt *PacketBuffer, hook Hook, _ *Route, _ AddressableEndpoint) (RuleVerdict, int) {
switch hook {
case Input, Forward, Output:
// There is nothing reasonable for us to do in response to an error here;
// we already drop the packet.
_ = rt.Handler.SendRejectionError(pkt, rt.RejectWith, hook == Input)
return RuleDrop, 0
case Prerouting, Postrouting:
panic(fmt.Sprintf("%s hook not supported for REDIRECT", hook))
default:
panic(fmt.Sprintf("unhandled hook = %s", hook))
}
}
// RejectIPv6WithHandler handles rejecting a packet.
type RejectIPv6WithHandler interface {
// SendRejectionError sends an error packet in response to the packet.
SendRejectionError(pkt *PacketBuffer, rejectWith RejectIPv6WithICMPType, forwardingHook bool) tcpip.Error
}
// RejectIPv6WithICMPType indicates the type of ICMP error that should be sent.
type RejectIPv6WithICMPType int
// The types of errors that may be returned when rejecting IPv6 packets.
const (
_ RejectIPv6WithICMPType = iota
RejectIPv6WithICMPNoRoute
RejectIPv6WithICMPAddrUnreachable
RejectIPv6WithICMPPortUnreachable
RejectIPv6WithICMPAdminProhibited
)
// RejectIPv6Target drops packets and sends back an error packet in response to the
// matched packet.
//
// +stateify savable
type RejectIPv6Target struct {
Handler RejectIPv6WithHandler
RejectWith RejectIPv6WithICMPType
}
// Action implements Target.Action.
func (rt *RejectIPv6Target) Action(pkt *PacketBuffer, hook Hook, _ *Route, _ AddressableEndpoint) (RuleVerdict, int) {
switch hook {
case Input, Forward, Output:
// There is nothing reasonable for us to do in response to an error here;
// we already drop the packet.
_ = rt.Handler.SendRejectionError(pkt, rt.RejectWith, hook == Input)
return RuleDrop, 0
case Prerouting, Postrouting:
panic(fmt.Sprintf("%s hook not supported for REDIRECT", hook))
default:
panic(fmt.Sprintf("unhandled hook = %s", hook))
}
}
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
//
// +stateify savable
type ErrorTarget struct {
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
// UserChainTarget marks a rule as the beginning of a user chain.
//
// +stateify savable
type UserChainTarget struct {
// Name is the chain name.
Name string
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
// ReturnTarget returns from the current chain. If the chain is a built-in, the
// hook's underflow should be called.
//
// +stateify savable
type ReturnTarget struct {
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
// DNATTarget modifies the destination port/IP of packets.
//
// +stateify savable
type DNATTarget struct {
// The new destination address for packets.
//
// Immutable.
Addr tcpip.Address
// The new destination port for packets.
//
// Immutable.
Port uint16
// NetworkProtocol is the network protocol the target is used with.
//
// Immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
	// ChangeAddress indicates whether the destination address should be changed.
//
// Immutable.
ChangeAddress bool
	// ChangePort indicates whether the destination port should be changed.
//
// Immutable.
ChangePort bool
}
// Action implements Target.Action.
func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
"DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
rt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
switch hook {
case Prerouting, Output:
case Input, Forward, Postrouting:
panic(fmt.Sprintf("%s not supported for DNAT", hook))
default:
panic(fmt.Sprintf("%s unrecognized", hook))
}
return dnatAction(pkt, hook, r, rt.Port, rt.Addr, rt.ChangePort, rt.ChangeAddress)
}
// RedirectTarget redirects the packet to this machine by modifying the
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
// forwarded).
//
// +stateify savable
type RedirectTarget struct {
// Port indicates port used to redirect. It is immutable.
Port uint16
// NetworkProtocol is the network protocol the target is used with. It
// is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
"RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
rt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
// Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
var address tcpip.Address
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
address = tcpip.AddrFrom4([4]byte{127, 0, 0, 1})
} else {
address = header.IPv6Loopback
}
case Prerouting:
// addressEP is expected to be set for the prerouting hook.
address = addressEP.MainAddress().Address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
return dnatAction(pkt, hook, r, rt.Port, address, true /* changePort */, true /* changeAddress */)
}
// SNATTarget modifies the source port/IP in the outgoing packets.
//
// +stateify savable
type SNATTarget struct {
Addr tcpip.Address
Port uint16
// NetworkProtocol is the network protocol the target is used with. It
// is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
	// ChangeAddress indicates whether the source address should be changed.
//
// Immutable.
ChangeAddress bool
	// ChangePort indicates whether the source port should be changed.
//
// Immutable.
ChangePort bool
}
func dnatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, changePort, changeAddress bool) (RuleVerdict, int) {
return natAction(pkt, hook, r, portOrIdentRange{start: port, size: 1}, address, true /* dnat */, changePort, changeAddress)
}
func targetPortRangeForTCPAndUDP(originalSrcPort uint16) portOrIdentRange {
// As per iptables(8),
//
// If no port range is specified, then source ports below 512 will be
// mapped to other ports below 512: those between 512 and 1023 inclusive
// will be mapped to ports below 1024, and other ports will be mapped to
// 1024 or above.
switch {
case originalSrcPort < 512:
return portOrIdentRange{start: 1, size: 511}
case originalSrcPort < 1024:
return portOrIdentRange{start: 1, size: 1023}
default:
return portOrIdentRange{start: 1024, size: math.MaxUint16 - 1023}
}
}
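// For example (illustrative values only), the mapping above yields:
//
//	targetPortRangeForTCPAndUDP(22)    // portOrIdentRange{start: 1, size: 511}
//	targetPortRangeForTCPAndUDP(700)   // portOrIdentRange{start: 1, size: 1023}
//	targetPortRangeForTCPAndUDP(40000) // portOrIdentRange{start: 1024, size: 64512}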
func snatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, changePort, changeAddress bool) (RuleVerdict, int) {
portsOrIdents := portOrIdentRange{start: port, size: 1}
switch pkt.TransportProtocolNumber {
case header.UDPProtocolNumber:
if port == 0 {
portsOrIdents = targetPortRangeForTCPAndUDP(header.UDP(pkt.TransportHeader().Slice()).SourcePort())
}
case header.TCPProtocolNumber:
if port == 0 {
portsOrIdents = targetPortRangeForTCPAndUDP(header.TCP(pkt.TransportHeader().Slice()).SourcePort())
}
case header.ICMPv4ProtocolNumber, header.ICMPv6ProtocolNumber:
// Allow NAT-ing to any 16-bit value for ICMP's Ident field to match Linux
// behaviour.
//
// https://github.com/torvalds/linux/blob/58e1100fdc5990b0cc0d4beaf2562a92e621ac7d/net/netfilter/nf_nat_core.c#L391
portsOrIdents = portOrIdentRange{start: 0, size: math.MaxUint16 + 1}
}
return natAction(pkt, hook, r, portsOrIdents, address, false /* dnat */, changePort, changeAddress)
}
func natAction(pkt *PacketBuffer, hook Hook, r *Route, portsOrIdents portOrIdentRange, address tcpip.Address, dnat, changePort, changeAddress bool) (RuleVerdict, int) {
	// Drop the packet if the network or transport header is not set.
if len(pkt.NetworkHeader().Slice()) == 0 || len(pkt.TransportHeader().Slice()) == 0 {
return RuleDrop, 0
}
if t := pkt.tuple; t != nil {
t.conn.performNAT(pkt, hook, r, portsOrIdents, address, dnat, changePort, changeAddress)
return RuleAccept, 0
}
return RuleDrop, 0
}
// Action implements Target.Action.
func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if st.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
"SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
st.NetworkProtocol, pkt.NetworkProtocolNumber))
}
switch hook {
case Postrouting, Input:
case Prerouting, Output, Forward:
panic(fmt.Sprintf("%s not supported", hook))
default:
panic(fmt.Sprintf("%s unrecognized", hook))
}
return snatAction(pkt, hook, r, st.Port, st.Addr, st.ChangePort, st.ChangeAddress)
}
// MasqueradeTarget modifies the source port/IP in the outgoing packets.
//
// +stateify savable
type MasqueradeTarget struct {
// NetworkProtocol is the network protocol the target is used with. It
// is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
"MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
mt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
switch hook {
case Postrouting:
case Prerouting, Input, Forward, Output:
panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
default:
panic(fmt.Sprintf("%s unrecognized", hook))
}
// addressEP is expected to be set for the postrouting hook.
ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), tcpip.Address{} /* srcHint */, false /* allowExpired */)
if ep == nil {
// No address exists that we can use as a source address.
return RuleDrop, 0
}
address := ep.AddressWithPrefix().Address
ep.DecRef()
return snatAction(pkt, hook, r, 0 /* port */, address, true /* changePort */, true /* changeAddress */)
}
func rewritePacket(n header.Network, t header.Transport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPortOrIdent uint16, newAddr tcpip.Address) {
switch t := t.(type) {
case header.ChecksummableTransport:
if updateSRCFields {
if fullChecksum {
t.SetSourcePortWithChecksumUpdate(newPortOrIdent)
} else {
t.SetSourcePort(newPortOrIdent)
}
} else {
if fullChecksum {
t.SetDestinationPortWithChecksumUpdate(newPortOrIdent)
} else {
t.SetDestinationPort(newPortOrIdent)
}
}
if updatePseudoHeader {
var oldAddr tcpip.Address
if updateSRCFields {
oldAddr = n.SourceAddress()
} else {
oldAddr = n.DestinationAddress()
}
t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum)
}
case header.ICMPv4:
switch icmpType := t.Type(); icmpType {
case header.ICMPv4Echo:
if updateSRCFields {
t.SetIdentWithChecksumUpdate(newPortOrIdent)
}
case header.ICMPv4EchoReply:
if !updateSRCFields {
t.SetIdentWithChecksumUpdate(newPortOrIdent)
}
default:
panic(fmt.Sprintf("unexpected ICMPv4 type = %d", icmpType))
}
case header.ICMPv6:
switch icmpType := t.Type(); icmpType {
case header.ICMPv6EchoRequest:
if updateSRCFields {
t.SetIdentWithChecksumUpdate(newPortOrIdent)
}
case header.ICMPv6EchoReply:
if !updateSRCFields {
t.SetIdentWithChecksumUpdate(newPortOrIdent)
}
default:
panic(fmt.Sprintf("unexpected ICMPv4 type = %d", icmpType))
}
var oldAddr tcpip.Address
if updateSRCFields {
oldAddr = n.SourceAddress()
} else {
oldAddr = n.DestinationAddress()
}
t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr)
default:
panic(fmt.Sprintf("unhandled transport = %#v", t))
}
if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok {
if updateSRCFields {
checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr)
} else {
checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr)
}
} else if updateSRCFields {
n.SetSourceAddress(newAddr)
} else {
n.SetDestinationAddress(newAddr)
}
}

View File

@@ -0,0 +1,387 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"strings"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// A Hook specifies one of the hooks built into the network stack.
//
//	Userspace app          Userspace app
//	      ^                      |
//	      |                      v
//	   [Input]               [Output]
//	      ^                      |
//	      |                      v
//	      |                   routing
//	      |                      |
//	      |                      v
//	----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
type Hook uint
const (
// Prerouting happens before a packet is routed to applications or to
// be forwarded.
Prerouting Hook = iota
// Input happens before a packet reaches an application.
Input
// Forward happens once it's decided that a packet should be forwarded
// to another host.
Forward
// Output happens after a packet is written by an application to be
// sent out.
Output
// Postrouting happens just before a packet goes out on the wire.
Postrouting
// NumHooks is the total number of hooks.
NumHooks
)
// A RuleVerdict is what a rule decides should be done with a packet.
type RuleVerdict int
const (
// RuleAccept indicates the packet should continue through netstack.
RuleAccept RuleVerdict = iota
// RuleDrop indicates the packet should be dropped.
RuleDrop
// RuleJump indicates the packet should jump to another chain.
RuleJump
// RuleReturn indicates the packet should return to the previous chain.
RuleReturn
)
// IPTables holds all the tables for a netstack.
//
// +stateify savable
type IPTables struct {
connections ConnTrack
reaper tcpip.Timer
mu ipTablesRWMutex `state:"nosave"`
	// v4Tables and v6Tables map tableIDs to tables. They hold builtin
// tables only, not user tables.
//
// mu protects the array of tables, but not the tables themselves.
// +checklocks:mu
v4Tables [NumTables]Table
//
// mu protects the array of tables, but not the tables themselves.
// +checklocks:mu
v6Tables [NumTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
//
// +checklocks:mu
modified bool
}
// Modified returns whether iptables has been modified. It is inherently racy
// and intended for use only in tests.
func (it *IPTables) Modified() bool {
it.mu.Lock()
defer it.mu.Unlock()
return it.modified
}
// VisitTargets traverses all the targets of all tables and replaces each with
// transform(target).
func (it *IPTables) VisitTargets(transform func(Target) Target) {
it.mu.Lock()
defer it.mu.Unlock()
for tid := range it.v4Tables {
for i, rule := range it.v4Tables[tid].Rules {
it.v4Tables[tid].Rules[i].Target = transform(rule.Target)
}
}
for tid := range it.v6Tables {
for i, rule := range it.v6Tables[tid].Rules {
it.v6Tables[tid].Rules[i].Target = transform(rule.Target)
}
}
}
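// A minimal usage sketch (the caller and the substitution shown are
// hypothetical; only the callback shape is prescribed by VisitTargets):
// replace every ErrorTarget with a DropTarget and leave other targets alone.
//
//	it.VisitTargets(func(t Target) Target {
//		if et, ok := t.(*ErrorTarget); ok {
//			return &DropTarget{NetworkProtocol: et.NetworkProtocol}
//		}
//		return t
//	})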
// A Table defines a set of chains and hooks into the network stack.
//
// It is a list of Rules, entry points (BuiltinChains), and error handlers
// (Underflows). As packets traverse netstack, they hit hooks. When a packet
// hits a hook, iptables compares it to Rules starting from that hook's entry
// point. So if a packet hits the Input hook, we look up the corresponding
// entry point in BuiltinChains and jump to that point.
//
// If the Rule doesn't match the packet, iptables continues to the next Rule.
// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept
// or RuleDrop) that causes the packet to stop traversing iptables. It can also
// jump to other rules or perform custom actions based on Rule.Target.
//
// Underflow Rules are invoked when a chain returns without reaching a verdict.
//
// +stateify savable
type Table struct {
// Rules holds the rules that make up the table.
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
Underflows [NumHooks]int
}
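// As a hedged construction sketch (not a table the stack installs on its
// own): a single accept-all rule can serve as both the entry point and the
// underflow of every built-in chain, since with one rule at index 0 the zero
// values of BuiltinChains and Underflows already point each hook at it.
//
//	acceptAll := Rule{
//		Filter: EmptyFilter4(),
//		Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
//	}
//	table := Table{Rules: []Rule{acceptAll}}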
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
for hook, ruleIdx := range table.BuiltinChains {
if ruleIdx != HookUnset {
hooks |= 1 << hook
}
}
return hooks
}
// A Rule is a packet processing rule. It consists of two pieces. First, it
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to; if there are no matchers in the rule, it
// applies to any packet. Second, it contains a Target, the action to take if
// all the matchers match.
//
// +stateify savable
type Rule struct {
// Filter holds basic IP filtering fields common to every rule.
Filter IPHeaderFilter
// Matchers is the list of matchers for this rule.
Matchers []Matcher
// Target is the action to invoke if all the matchers match the packet.
Target Target
}
// IPHeaderFilter performs basic IP header matching common to every rule.
//
// +stateify savable
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
// CheckProtocol determines whether the Protocol field should be
// checked during matching.
CheckProtocol bool
// Dst matches the destination IP address.
Dst tcpip.Address
// DstMask masks bits of the destination IP address when comparing with
// Dst.
DstMask tcpip.Address
// DstInvert inverts the meaning of the destination IP check, i.e. when
// true the filter will match packets that fail the destination
// comparison.
DstInvert bool
// Src matches the source IP address.
Src tcpip.Address
// SrcMask masks bits of the source IP address when comparing with Src.
SrcMask tcpip.Address
// SrcInvert inverts the meaning of the source IP check, i.e. when true the
// filter will match packets that fail the source comparison.
SrcInvert bool
// InputInterface matches the name of the incoming interface for the packet.
InputInterface string
// InputInterfaceMask masks the characters of the interface name when
// comparing with InputInterface.
InputInterfaceMask string
// InputInterfaceInvert inverts the meaning of incoming interface check,
// i.e. when true the filter will match packets that fail the incoming
// interface comparison.
InputInterfaceInvert bool
// OutputInterface matches the name of the outgoing interface for the packet.
OutputInterface string
// OutputInterfaceMask masks the characters of the interface name when
// comparing with OutputInterface.
OutputInterfaceMask string
// OutputInterfaceInvert inverts the meaning of outgoing interface check,
// i.e. when true the filter will match packets that fail the outgoing
// interface comparison.
OutputInterfaceInvert bool
}
// EmptyFilter4 returns an initialized IPv4 header filter.
func EmptyFilter4() IPHeaderFilter {
return IPHeaderFilter{
Dst: tcpip.AddrFrom4([4]byte{}),
DstMask: tcpip.AddrFrom4([4]byte{}),
Src: tcpip.AddrFrom4([4]byte{}),
SrcMask: tcpip.AddrFrom4([4]byte{}),
}
}
// EmptyFilter6 returns an initialized IPv6 header filter.
func EmptyFilter6() IPHeaderFilter {
return IPHeaderFilter{
Dst: tcpip.AddrFrom16([16]byte{}),
DstMask: tcpip.AddrFrom16([16]byte{}),
Src: tcpip.AddrFrom16([16]byte{}),
SrcMask: tcpip.AddrFrom16([16]byte{}),
}
}
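// A hedged example of filling in a filter (the values are illustrative):
// match TCP packets whose source address falls in 10.0.0.0/8, on any
// interface.
//
//	filter := EmptyFilter4()
//	filter.Protocol = header.TCPProtocolNumber
//	filter.CheckProtocol = true
//	filter.Src = tcpip.AddrFrom4([4]byte{10, 0, 0, 0})
//	filter.SrcMask = tcpip.AddrFrom4([4]byte{255, 0, 0, 0})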
// match returns whether pkt matches the filter.
//
// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
// or IPv6 header length.
func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
transProto tcpip.TransportProtocolNumber
dstAddr tcpip.Address
srcAddr tcpip.Address
)
switch proto := pkt.NetworkProtocolNumber; proto {
case header.IPv4ProtocolNumber:
hdr := header.IPv4(pkt.NetworkHeader().Slice())
transProto = hdr.TransportProtocol()
dstAddr = hdr.DestinationAddress()
srcAddr = hdr.SourceAddress()
case header.IPv6ProtocolNumber:
hdr := header.IPv6(pkt.NetworkHeader().Slice())
transProto = hdr.TransportProtocol()
dstAddr = hdr.DestinationAddress()
srcAddr = hdr.SourceAddress()
default:
panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto))
}
// Check the transport protocol.
if fl.CheckProtocol && fl.Protocol != transProto {
return false
}
// Check the addresses.
if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) ||
!filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) {
return false
}
switch hook {
case Prerouting, Input:
return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
case Output:
return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
case Forward:
if !matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) {
return false
}
if !matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) {
return false
}
return true
case Postrouting:
return true
default:
panic(fmt.Sprintf("unknown hook: %d", hook))
}
}
func matchIfName(nicName string, ifName string, invert bool) bool {
n := len(ifName)
if n == 0 {
// If the interface name is omitted in the filter, any interface will match.
return true
}
// If the interface name ends with '+', any interface which begins with the
// name should be matched.
var matches bool
if strings.HasSuffix(ifName, "+") {
matches = strings.HasPrefix(nicName, ifName[:n-1])
} else {
matches = nicName == ifName
}
return matches != invert
}
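// For example (illustrative only):
//
//	matchIfName("eth0", "eth+", false) // true:  "eth0" begins with "eth"
//	matchIfName("eth0", "eth0", false) // true:  exact match
//	matchIfName("eth0", "lo", false)   // false: names differ
//	matchIfName("eth0", "eth+", true)  // false: same match, inverted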
// NetworkProtocol returns the protocol (IPv4 or IPv6) to which the header
// filter applies.
func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber {
switch fl.Src.BitLen() {
case header.IPv4AddressSizeBits:
return header.IPv4ProtocolNumber
case header.IPv6AddressSizeBits:
return header.IPv6ProtocolNumber
}
panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src))
}
// filterAddress returns whether addr matches the filter.
func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
matches := true
addrBytes := addr.AsSlice()
maskBytes := mask.AsSlice()
filterBytes := filterAddr.AsSlice()
for i := range filterAddr.AsSlice() {
if addrBytes[i]&maskBytes[i] != filterBytes[i] {
matches = false
break
}
}
return matches != invert
}
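// For example (illustrative only), with a /8 mask:
//
//	addr := tcpip.AddrFrom4([4]byte{10, 1, 2, 3})
//	mask := tcpip.AddrFrom4([4]byte{255, 0, 0, 0})
//	want := tcpip.AddrFrom4([4]byte{10, 0, 0, 0})
//	filterAddress(addr, mask, want, false) // true:  10.1.2.3 & 255.0.0.0 == 10.0.0.0
//	filterAddress(addr, mask, want, true)  // false: same comparison, inverted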
// A Matcher is the interface for matching packets.
type Matcher interface {
// Match returns whether the packet matches and whether the packet
// should be "hotdropped", i.e. dropped immediately. This is usually
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool)
}
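// exampleNICMatcher is a hedged sketch (the type is hypothetical, not part of
// the stack API) of the minimum needed to satisfy Matcher: match packets that
// arrived on a specific interface and never hotdrop.
type exampleNICMatcher struct {
	name string
}

// Match implements Matcher.Match.
func (m *exampleNICMatcher) Match(_ Hook, _ *PacketBuffer, inputInterfaceName, _ string) (bool, bool) {
	return inputInterfaceName == m.name, false
}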
// A Target is the interface for taking an action for a packet.
type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
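// exampleCountingTarget is a hedged sketch (the type and its field are
// hypothetical, not part of the stack API) of a Target implementation: count
// every packet the rule sees and let traversal continue.
type exampleCountingTarget struct {
	packets *tcpip.StatCounter
}

// Action implements Target.Action.
func (t *exampleCountingTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
	t.packets.Increment()
	return RuleAccept, 0
}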

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type multiPortEndpointRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var multiPortEndpointlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type multiPortEndpointlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *multiPortEndpointRWMutex) Lock() {
locking.AddGLock(multiPortEndpointprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *multiPortEndpointRWMutex) NestedLock(i multiPortEndpointlockNameIndex) {
locking.AddGLock(multiPortEndpointprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *multiPortEndpointRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(multiPortEndpointprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *multiPortEndpointRWMutex) NestedUnlock(i multiPortEndpointlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(multiPortEndpointprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *multiPortEndpointRWMutex) RLock() {
locking.AddGLock(multiPortEndpointprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *multiPortEndpointRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(multiPortEndpointprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *multiPortEndpointRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *multiPortEndpointRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *multiPortEndpointRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var multiPortEndpointprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func multiPortEndpointinitLockNames() {}
func init() {
multiPortEndpointinitLockNames()
multiPortEndpointprefixIndex = locking.NewMutexClass(reflect.TypeOf(multiPortEndpointRWMutex{}), multiPortEndpointlockNames)
}

View File

@@ -0,0 +1,314 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
// NeighborCacheSize is the size of the neighborCache. Exceeding this size will
// result in the least recently used entry being evicted.
const NeighborCacheSize = 512 // max entries per interface
// NeighborStats holds metrics for the neighbor table.
type NeighborStats struct {
// UnreachableEntryLookups counts the number of lookups performed on an
// entry in Unreachable state.
UnreachableEntryLookups *tcpip.StatCounter
}
// +stateify savable
type dynamicCacheEntry struct {
lru neighborEntryList
	// count tracks the number of dynamic entries in the cache. This is
// needed since static entries do not count towards the LRU cache
// eviction strategy.
count uint16
}
// +stateify savable
type neighborCacheMu struct {
neighborCacheRWMutex `state:"nosave"`
cache map[tcpip.Address]*neighborEntry
dynamic dynamicCacheEntry
}
// neighborCache maps IP addresses to link addresses. It uses the Least
// Recently Used (LRU) eviction strategy to implement a bounded cache for
// dynamically acquired entries. It contains the state machine and configuration
// for running Neighbor Unreachability Detection (NUD).
//
// There are two types of entries in the neighbor cache:
// 1. Dynamic entries are discovered automatically by neighbor discovery
// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm
// reachability with the device once the entry's state becomes Stale.
// 2. Static entries are explicitly added by a user and have no expiration.
// Their state is always Static. The amount of static entries stored in the
// cache is unbounded.
//
// +stateify savable
type neighborCache struct {
nic *nic
state *NUDState
linkRes LinkAddressResolver
mu neighborCacheMu
}
// getOrCreateEntry retrieves a cache entry associated with addr. The
// returned entry is always refreshed in the cache (it is reachable via the
// map, and its place is bumped in LRU).
//
// If a matching entry exists in the cache, it is returned. If no matching
// entry exists and the cache is full, an existing entry is evicted via LRU,
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntry {
n.mu.Lock()
defer n.mu.Unlock()
if entry, ok := n.mu.cache[remoteAddr]; ok {
entry.mu.RLock()
if entry.mu.neigh.State != Static {
n.mu.dynamic.lru.Remove(entry)
n.mu.dynamic.lru.PushFront(entry)
}
entry.mu.RUnlock()
return entry
}
// The entry that needs to be created must be dynamic since all static
// entries are directly added to the cache via addStaticEntry.
entry := newNeighborEntry(n, remoteAddr, n.state)
if n.mu.dynamic.count == NeighborCacheSize {
e := n.mu.dynamic.lru.Back()
e.mu.Lock()
delete(n.mu.cache, e.mu.neigh.Addr)
n.mu.dynamic.lru.Remove(e)
n.mu.dynamic.count--
e.removeLocked()
e.mu.Unlock()
}
n.mu.cache[remoteAddr] = entry
n.mu.dynamic.lru.PushFront(entry)
n.mu.dynamic.count++
return entry
}
// entry looks up neighbor information matching the remote address, and returns
// it if readily available.
//
// Returns ErrWouldBlock if the link address is not readily available, along
// with a notification channel for the caller to block on. Triggers address
// resolution asynchronously.
//
// If onResolve is provided, it will be called either immediately, if resolution
// is not required, or when address resolution is complete, with the resolved
// link address and whether resolution succeeded. After any callbacks have been
// called, the returned notification channel is closed.
//
// NB: if a callback is provided, it should not call into the neighbor cache.
//
// If specified, the local address must be an address local to the interface the
// neighbor cache belongs to. The local address is the source address of a
// packet prompting NUD/link address resolution.
func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (*neighborEntry, <-chan struct{}, tcpip.Error) {
entry := n.getOrCreateEntry(remoteAddr)
entry.mu.Lock()
defer entry.mu.Unlock()
switch s := entry.mu.neigh.State; s {
case Stale:
entry.handlePacketQueuedLocked(localAddr)
fallthrough
case Reachable, Static, Delay, Probe:
// As per RFC 4861 section 7.3.3:
// "Neighbor Unreachability Detection operates in parallel with the sending
// of packets to a neighbor. While reasserting a neighbor's reachability,
// a node continues sending packets to that neighbor using the cached
// link-layer address."
if onResolve != nil {
onResolve(LinkResolutionResult{LinkAddress: entry.mu.neigh.LinkAddr, Err: nil})
}
return entry, nil, nil
case Unknown, Incomplete, Unreachable:
if onResolve != nil {
entry.mu.onResolve = append(entry.mu.onResolve, onResolve)
}
if entry.mu.done == nil {
// Address resolution needs to be initiated.
entry.mu.done = make(chan struct{})
}
entry.handlePacketQueuedLocked(localAddr)
return entry, entry.mu.done, &tcpip.ErrWouldBlock{}
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
}
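// A hedged sketch of the calling pattern (the caller code is hypothetical):
// resolution either completes synchronously or returns ErrWouldBlock together
// with a channel that is closed after the onResolve callbacks have run.
//
//	_, ch, err := cache.entry(remoteAddr, localAddr, func(r LinkResolutionResult) {
//		// r.LinkAddress and r.Err are valid once resolution finishes.
//	})
//	if _, ok := err.(*tcpip.ErrWouldBlock); ok {
//		<-ch // resolution is in progress; wait for it to finish or fail.
//	}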
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
n.mu.RLock()
defer n.mu.RUnlock()
entries := make([]NeighborEntry, 0, len(n.mu.cache))
for _, entry := range n.mu.cache {
entry.mu.RLock()
entries = append(entries, entry.mu.neigh)
entry.mu.RUnlock()
}
return entries
}
// addStaticEntry adds a static entry to the neighbor cache, mapping an IP
// address to a link address. If a dynamic entry exists in the neighbor cache
// with the same address, it will be replaced with this static entry. If a
// static entry exists with the same address but different link address, it
// will be updated with the new link address. If a static entry exists with the
// same address and link address, nothing will happen.
func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
n.mu.Lock()
defer n.mu.Unlock()
if entry, ok := n.mu.cache[addr]; ok {
entry.mu.Lock()
if entry.mu.neigh.State != Static {
// Dynamic entry found with the same address.
n.mu.dynamic.lru.Remove(entry)
n.mu.dynamic.count--
} else if entry.mu.neigh.LinkAddr == linkAddr {
// Static entry found with the same address and link address.
entry.mu.Unlock()
return
} else {
// Static entry found with the same address but different link address.
entry.mu.neigh.LinkAddr = linkAddr
entry.dispatchChangeEventLocked()
entry.mu.Unlock()
return
}
entry.removeLocked()
entry.mu.Unlock()
}
entry := newStaticNeighborEntry(n, addr, linkAddr, n.state)
n.mu.cache[addr] = entry
entry.mu.Lock()
defer entry.mu.Unlock()
entry.dispatchAddEventLocked()
}
// removeEntry removes a dynamic or static entry by address from the neighbor
// cache. Returns true if the entry was found and deleted.
func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
n.mu.Lock()
defer n.mu.Unlock()
entry, ok := n.mu.cache[addr]
if !ok {
return false
}
entry.mu.Lock()
defer entry.mu.Unlock()
if entry.mu.neigh.State != Static {
n.mu.dynamic.lru.Remove(entry)
n.mu.dynamic.count--
}
entry.removeLocked()
delete(n.mu.cache, entry.mu.neigh.Addr)
return true
}
// clear removes all dynamic and static entries from the neighbor cache.
func (n *neighborCache) clear() {
n.mu.Lock()
defer n.mu.Unlock()
for _, entry := range n.mu.cache {
entry.mu.Lock()
entry.removeLocked()
entry.mu.Unlock()
}
n.mu.dynamic.lru = neighborEntryList{}
clear(n.mu.cache)
n.mu.dynamic.count = 0
}
// config returns the NUD configuration.
func (n *neighborCache) config() NUDConfigurations {
return n.state.Config()
}
// setConfig changes the NUD configuration.
//
// If config contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
func (n *neighborCache) setConfig(config NUDConfigurations) {
config.resetInvalidFields()
n.state.SetConfig(config)
}
// handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3.
//
// Validation of the probe is expected to be handled by the caller.
func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
entry := n.getOrCreateEntry(remoteAddr)
entry.mu.Lock()
entry.handleProbeLocked(remoteLinkAddr)
entry.mu.Unlock()
}
// handleConfirmation handles a neighbor confirmation as defined by
// RFC 4861 section 7.2.5.
//
// Validation of the confirmation is expected to be handled by the caller.
func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
n.mu.RLock()
entry, ok := n.mu.cache[addr]
n.mu.RUnlock()
if ok {
entry.mu.Lock()
entry.handleConfirmationLocked(linkAddr, flags)
entry.mu.Unlock()
} else {
// The confirmation SHOULD be silently discarded if the recipient did not
// initiate any communication with the target. This is indicated if there is
// no matching entry for the remote address.
n.nic.stats.neighbor.droppedConfirmationForNoninitiatedNeighbor.Increment()
}
}
func (n *neighborCache) init(nic *nic, r LinkAddressResolver) {
*n = neighborCache{
nic: nic,
state: NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.insecureRNG),
linkRes: r,
}
n.mu.Lock()
n.mu.cache = make(map[tcpip.Address]*neighborEntry, NeighborCacheSize)
n.mu.Unlock()
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type neighborCacheRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var neighborCachelockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type neighborCachelockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *neighborCacheRWMutex) Lock() {
locking.AddGLock(neighborCacheprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *neighborCacheRWMutex) NestedLock(i neighborCachelockNameIndex) {
locking.AddGLock(neighborCacheprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *neighborCacheRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(neighborCacheprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *neighborCacheRWMutex) NestedUnlock(i neighborCachelockNameIndex) {
m.mu.Unlock()
locking.DelGLock(neighborCacheprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *neighborCacheRWMutex) RLock() {
locking.AddGLock(neighborCacheprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *neighborCacheRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(neighborCacheprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *neighborCacheRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *neighborCacheRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *neighborCacheRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var neighborCacheprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func neighborCacheinitLockNames() {}
func init() {
neighborCacheinitLockNames()
neighborCacheprefixIndex = locking.NewMutexClass(reflect.TypeOf(neighborCacheRWMutex{}), neighborCachelockNames)
}

View File

@@ -0,0 +1,646 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
// immediateDuration is a duration of zero for scheduling work that needs to
// be done immediately but asynchronously to avoid deadlock.
immediateDuration time.Duration = 0
)
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
Addr tcpip.Address
LinkAddr tcpip.LinkAddress
State NeighborState
UpdatedAt tcpip.MonotonicTime
}
// NeighborState defines the state of a NeighborEntry within the Neighbor
// Unreachability Detection state machine, as per RFC 4861 section 7.3.2 and
// RFC 7048.
type NeighborState uint8
const (
// Unknown means reachability has not been verified yet. This is the initial
// state of entries that have been created automatically by the Neighbor
// Unreachability Detection state machine.
Unknown NeighborState = iota
// Incomplete means that there is an outstanding request to resolve the
// address.
Incomplete
// Reachable means the path to the neighbor is functioning properly for both
// receive and transmit paths.
Reachable
// Stale means reachability to the neighbor is unknown, but packets are still
// able to be transmitted to the possibly stale link address.
Stale
// Delay means reachability to the neighbor is unknown and pending
// confirmation from an upper-level protocol like TCP, but packets are still
// able to be transmitted to the possibly stale link address.
Delay
// Probe means a reachability confirmation is actively being sought by
// periodically retransmitting reachability probes until a reachability
// confirmation is received, or until the maximum number of probes has been
// sent.
Probe
// Static describes entries that have been explicitly added by the user. They
// do not expire and are not deleted until explicitly removed.
Static
// Unreachable means reachability confirmation failed; the maximum number of
// reachability probes has been sent and no replies have been received.
//
// TODO(gvisor.dev/issue/5472): Add the following sentence when we implement
// RFC 7048: "Packets continue to be sent to the neighbor while
// re-attempting to resolve the address."
Unreachable
)
type timer struct {
// done indicates to the timer that the timer was stopped.
done *bool
timer tcpip.Timer
}
// neighborEntry implements a neighbor entry's individual node behavior, as per
// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in
// parallel with the sending of packets to a neighbor, necessitating the
// entry's lock to be acquired for all operations.
type neighborEntry struct {
neighborEntryEntry
cache *neighborCache
// nudState points to the Neighbor Unreachability Detection configuration.
nudState *NUDState
mu struct {
neighborEntryRWMutex
neigh NeighborEntry
		// done is closed when address resolution is complete. It is nil iff the
		// entry's state is incomplete and resolution is not yet in progress.
done chan struct{}
// onResolve is called with the result of address resolution.
onResolve []func(LinkResolutionResult)
isRouter bool
timer timer
}
}
// newNeighborEntry creates a neighbor cache entry starting at the default
// state, Unknown. Transition out of Unknown by calling either
// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
// neighborEntry.
func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry {
n := &neighborEntry{
cache: cache,
nudState: nudState,
}
n.mu.Lock()
n.mu.neigh = NeighborEntry{
Addr: remoteAddr,
State: Unknown,
}
n.mu.Unlock()
return n
}
// newStaticNeighborEntry creates a neighbor cache entry starting at the
// Static state. The entry can only transition out of Static by directly
// calling `setStateLocked`.
func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
entry := NeighborEntry{
Addr: addr,
LinkAddr: linkAddr,
State: Static,
UpdatedAt: cache.nic.stack.clock.NowMonotonic(),
}
n := &neighborEntry{
cache: cache,
nudState: state,
}
n.mu.Lock()
n.mu.neigh = entry
n.mu.Unlock()
return n
}
// notifyCompletionLocked notifies those waiting for address resolution, with
// the link address if resolution completed successfully.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) {
res := LinkResolutionResult{LinkAddress: e.mu.neigh.LinkAddr, Err: err}
for _, callback := range e.mu.onResolve {
callback(res)
}
e.mu.onResolve = nil
if ch := e.mu.done; ch != nil {
close(ch)
e.mu.done = nil
// Dequeue the pending packets asynchronously to not hold up the current
// goroutine as writing packets may be a costly operation.
//
// At the time of writing, when writing packets, a neighbor's link address
// is resolved (which ends up obtaining the entry's lock) while holding the
// link resolution queue's lock. Dequeuing packets asynchronously avoids a
// lock ordering violation.
//
// NB: this is equivalent to spawning a goroutine directly using the go
// keyword but allows tests that use manual clocks to deterministically
// wait for this work to complete.
e.cache.nic.stack.clock.AfterFunc(0, func() {
e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err)
})
}
}
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborAdded(e.cache.nic.id, e.mu.neigh)
}
}
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborChanged(e.cache.nic.id, e.mu.neigh)
}
}
// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
// has been removed.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborRemoved(e.cache.nic.id, e.mu.neigh)
}
}
// cancelTimerLocked cancels the currently scheduled action, if there is one.
// Entries in Unknown, Stale, or Static state do not have a scheduled action.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) cancelTimerLocked() {
if e.mu.timer.timer != nil {
e.mu.timer.timer.Stop()
*e.mu.timer.done = true
e.mu.timer = timer{}
}
}
// removeLocked prepares the entry for removal.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) removeLocked() {
e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.NowMonotonic()
e.dispatchRemoveEventLocked()
// Set state to unknown to invalidate this entry if it's cached in a Route.
e.setStateLocked(Unknown)
e.cancelTimerLocked()
// TODO(https://gvisor.dev/issues/5583): test the case where this function is
// called during resolution; that can happen in at least these scenarios:
//
// - manual address removal during resolution
//
// - neighbor cache eviction during resolution
e.notifyCompletionLocked(&tcpip.ErrAborted{})
}
// setStateLocked transitions the entry to the specified state immediately.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) setStateLocked(next NeighborState) {
e.cancelTimerLocked()
prev := e.mu.neigh.State
e.mu.neigh.State = next
e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.NowMonotonic()
config := e.nudState.Config()
switch next {
case Incomplete:
panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.mu.neigh, prev))
case Reachable:
// Protected by e.mu.
done := false
e.mu.timer = timer{
done: &done,
timer: e.cache.nic.stack.Clock().AfterFunc(e.nudState.ReachableTime(), func() {
e.mu.Lock()
defer e.mu.Unlock()
if done {
// The timer was stopped because the entry changed state.
return
}
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
}),
}
case Delay:
// Protected by e.mu.
done := false
e.mu.timer = timer{
done: &done,
timer: e.cache.nic.stack.Clock().AfterFunc(config.DelayFirstProbeTime, func() {
e.mu.Lock()
defer e.mu.Unlock()
if done {
// The timer was stopped because the entry changed state.
return
}
e.setStateLocked(Probe)
e.dispatchChangeEventLocked()
}),
}
case Probe:
// Protected by e.mu.
done := false
remaining := config.MaxUnicastProbes
addr := e.mu.neigh.Addr
linkAddr := e.mu.neigh.LinkAddr
		// Send a probe in another goroutine to free this thread of execution
// for finishing the state transition. This is necessary to escape the
// currently held lock so we can send the probe message without holding
// a shared lock.
e.mu.timer = timer{
done: &done,
timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() {
var err tcpip.Error = &tcpip.ErrTimeout{}
if remaining != 0 {
err = e.cache.linkRes.LinkAddressRequest(addr, tcpip.Address{} /* localAddr */, linkAddr)
}
e.mu.Lock()
defer e.mu.Unlock()
if done {
// The timer was stopped because the entry changed state.
return
}
if err != nil {
e.setStateLocked(Unreachable)
e.notifyCompletionLocked(err)
e.dispatchChangeEventLocked()
return
}
remaining--
e.mu.timer.timer.Reset(config.RetransmitTimer)
}),
}
case Unreachable:
case Unknown, Stale, Static:
// Do nothing
default:
panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next))
}
}
// handlePacketQueuedLocked advances the state machine according to a packet
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.mu.neigh.State {
case Unknown, Unreachable:
prev := e.mu.neigh.State
e.mu.neigh.State = Incomplete
e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.NowMonotonic()
switch prev {
case Unknown:
e.dispatchAddEventLocked()
case Unreachable:
e.dispatchChangeEventLocked()
e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment()
}
config := e.nudState.Config()
// Protected by e.mu.
done := false
remaining := config.MaxMulticastProbes
addr := e.mu.neigh.Addr
		// Send a probe in another goroutine to free this thread of execution
// for finishing the state transition. This is necessary to escape the
// currently held lock so we can send the probe message without holding
// a shared lock.
e.mu.timer = timer{
done: &done,
timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() {
var err tcpip.Error = &tcpip.ErrTimeout{}
if remaining != 0 {
// As per RFC 4861 section 7.2.2:
//
// If the source address of the packet prompting the solicitation is
// the same as one of the addresses assigned to the outgoing interface,
// that address SHOULD be placed in the IP Source Address of the
// outgoing solicitation.
//
err = e.cache.linkRes.LinkAddressRequest(addr, localAddr, "" /* linkAddr */)
}
e.mu.Lock()
defer e.mu.Unlock()
if done {
// The timer was stopped because the entry changed state.
return
}
if err != nil {
e.setStateLocked(Unreachable)
e.notifyCompletionLocked(err)
e.dispatchChangeEventLocked()
return
}
remaining--
e.mu.timer.timer.Reset(config.RetransmitTimer)
}),
}
case Stale:
e.setStateLocked(Delay)
e.dispatchChangeEventLocked()
case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or
// Neighbor Solicitation for ARP or NDP, respectively).
//
// Follows the logic defined in RFC 4861 section 7.2.3.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// Probes MUST be silently discarded if the target address is tentative, does
// not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
// checks MUST be done by the NetworkEndpoint.
switch e.mu.neigh.State {
case Unknown:
e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.dispatchAddEventLocked()
case Incomplete:
// "If an entry already exists, and the cached link-layer address
// differs from the one in the received Source Link-Layer option, the
// cached address should be replaced by the received address, and the
// entry's reachability state MUST be set to STALE."
// - RFC 4861 section 7.2.3
e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.notifyCompletionLocked(nil)
e.dispatchChangeEventLocked()
case Reachable, Delay, Probe:
if e.mu.neigh.LinkAddr != remoteLinkAddr {
e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
}
case Stale:
if e.mu.neigh.LinkAddr != remoteLinkAddr {
e.mu.neigh.LinkAddr = remoteLinkAddr
e.dispatchChangeEventLocked()
}
case Unreachable:
// TODO(gvisor.dev/issue/5472): Do not change the entry if the link
// address is the same, as per RFC 7048.
e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
case Static:
// Do nothing
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
// handleConfirmationLocked processes an incoming neighbor confirmation
// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively).
//
// Follows the state machine defined by RFC 4861 section 7.2.5.
//
// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
// should be deployed where preventing access to the broadcast segment might
// not be possible. SEND uses RSA key pairs to produce Cryptographically
// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
// claimed source of an NDP message is the owner of the claimed address.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
switch e.mu.neigh.State {
case Incomplete:
if len(linkAddr) == 0 {
// "If the link layer has addresses and no Target Link-Layer Address
// option is included, the receiving node SHOULD silently discard the
// received advertisement." - RFC 4861 section 7.2.5
e.cache.nic.stats.neighbor.droppedInvalidLinkAddressConfirmations.Increment()
break
}
e.mu.neigh.LinkAddr = linkAddr
if flags.Solicited {
e.setStateLocked(Reachable)
} else {
e.setStateLocked(Stale)
}
e.dispatchChangeEventLocked()
e.mu.isRouter = flags.IsRouter
e.notifyCompletionLocked(nil)
// "Note that the Override flag is ignored if the entry is in the
// INCOMPLETE state." - RFC 4861 section 7.2.5
case Reachable, Stale, Delay, Probe:
isLinkAddrDifferent := len(linkAddr) != 0 && e.mu.neigh.LinkAddr != linkAddr
if isLinkAddrDifferent {
if !flags.Override {
if e.mu.neigh.State == Reachable {
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
}
break
}
e.mu.neigh.LinkAddr = linkAddr
if !flags.Solicited {
if e.mu.neigh.State != Stale {
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
} else {
// Notify the LinkAddr change, even though NUD state hasn't changed.
e.dispatchChangeEventLocked()
}
break
}
}
if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
wasReachable := e.mu.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
e.notifyCompletionLocked(nil)
if !wasReachable {
e.dispatchChangeEventLocked()
}
}
if e.mu.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.mu.neigh.Addr) {
// "In those cases where the IsRouter flag changes from TRUE to FALSE as
// a result of this update, the node MUST remove that router from the
// Default Router List and update the Destination Cache entries for all
// destinations using that neighbor as a router as specified in Section
// 7.3.3. This is needed to detect when a node that is used as a router
// stops forwarding packets due to being configured as a host."
// - RFC 4861 section 7.2.5
//
// TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6
// here.
ep := e.cache.nic.getNetworkEndpoint(header.IPv6ProtocolNumber)
if ep == nil {
panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint"))
}
if ndpEP, ok := ep.(NDPEndpoint); ok {
ndpEP.InvalidateDefaultRouter(e.mu.neigh.Addr)
}
}
e.mu.isRouter = flags.IsRouter
case Unknown, Unreachable, Static:
// Do nothing
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
// handleUpperLevelConfirmation processes an incoming upper-level protocol
// (e.g. TCP acknowledgements) reachability confirmation.
func (e *neighborEntry) handleUpperLevelConfirmation() {
tryHandleConfirmation := func() bool {
switch e.mu.neigh.State {
case Stale, Delay, Probe:
return true
case Reachable:
// Avoid setStateLocked; Timer.Reset is cheaper.
//
// Note that setting the timer does not need to be protected by the
// entry's write lock since we do not modify the timer pointer, but the
// time the timer should fire. The timer should have internal locks to
			// synchronize timer resets with the clock.
e.mu.timer.timer.Reset(e.nudState.ReachableTime())
return false
case Unknown, Incomplete, Unreachable, Static:
// Do nothing
return false
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
e.mu.RLock()
needsTransition := tryHandleConfirmation()
e.mu.RUnlock()
if !needsTransition {
return
}
// We need to transition the neighbor to Reachable so take the write lock and
// perform the transition, but only if we still need the transition since the
// state could have changed since we dropped the read lock above.
e.mu.Lock()
defer e.mu.Unlock()
if needsTransition := tryHandleConfirmation(); needsTransition {
e.setStateLocked(Reachable)
e.dispatchChangeEventLocked()
}
}
// getRemoteLinkAddress returns the entry's link address and whether that link
// address is valid.
func (e *neighborEntry) getRemoteLinkAddress() (tcpip.LinkAddress, bool) {
e.mu.RLock()
defer e.mu.RUnlock()
switch e.mu.neigh.State {
case Reachable, Static, Delay, Probe:
return e.mu.neigh.LinkAddr, true
case Unknown, Incomplete, Unreachable, Stale:
return "", false
default:
panic(fmt.Sprintf("invalid state for neighbor entry %v: %v", e.mu.neigh, e.mu.neigh.State))
}
}

View File

@@ -0,0 +1,239 @@
package stack
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type neighborEntryElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (neighborEntryElementMapper) linkerFor(elem *neighborEntry) *neighborEntry { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type neighborEntryList struct {
head *neighborEntry
tail *neighborEntry
}
// Reset resets list l to the empty state.
func (l *neighborEntryList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
//
//go:nosplit
func (l *neighborEntryList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
//
//go:nosplit
func (l *neighborEntryList) Front() *neighborEntry {
return l.head
}
// Back returns the last element of list l or nil.
//
//go:nosplit
func (l *neighborEntryList) Back() *neighborEntry {
return l.tail
}
// Len returns the number of elements in the list.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *neighborEntryList) Len() (count int) {
for e := l.Front(); e != nil; e = (neighborEntryElementMapper{}.linkerFor(e)).Next() {
count++
}
return count
}
// PushFront inserts the element e at the front of list l.
//
//go:nosplit
func (l *neighborEntryList) PushFront(e *neighborEntry) {
linker := neighborEntryElementMapper{}.linkerFor(e)
linker.SetNext(l.head)
linker.SetPrev(nil)
if l.head != nil {
neighborEntryElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushFrontList inserts list m at the start of list l, emptying m.
//
//go:nosplit
func (l *neighborEntryList) PushFrontList(m *neighborEntryList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
neighborEntryElementMapper{}.linkerFor(l.head).SetPrev(m.tail)
neighborEntryElementMapper{}.linkerFor(m.tail).SetNext(l.head)
l.head = m.head
}
m.head = nil
m.tail = nil
}
// PushBack inserts the element e at the back of list l.
//
//go:nosplit
func (l *neighborEntryList) PushBack(e *neighborEntry) {
linker := neighborEntryElementMapper{}.linkerFor(e)
linker.SetNext(nil)
linker.SetPrev(l.tail)
if l.tail != nil {
neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
//
//go:nosplit
func (l *neighborEntryList) PushBackList(m *neighborEntryList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head)
neighborEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
//
//go:nosplit
func (l *neighborEntryList) InsertAfter(b, e *neighborEntry) {
bLinker := neighborEntryElementMapper{}.linkerFor(b)
eLinker := neighborEntryElementMapper{}.linkerFor(e)
a := bLinker.Next()
eLinker.SetNext(a)
eLinker.SetPrev(b)
bLinker.SetNext(e)
if a != nil {
neighborEntryElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
//
//go:nosplit
func (l *neighborEntryList) InsertBefore(a, e *neighborEntry) {
aLinker := neighborEntryElementMapper{}.linkerFor(a)
eLinker := neighborEntryElementMapper{}.linkerFor(e)
b := aLinker.Prev()
eLinker.SetNext(a)
eLinker.SetPrev(b)
aLinker.SetPrev(e)
if b != nil {
neighborEntryElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
//
//go:nosplit
func (l *neighborEntryList) Remove(e *neighborEntry) {
linker := neighborEntryElementMapper{}.linkerFor(e)
prev := linker.Prev()
next := linker.Next()
if prev != nil {
neighborEntryElementMapper{}.linkerFor(prev).SetNext(next)
} else if l.head == e {
l.head = next
}
if next != nil {
neighborEntryElementMapper{}.linkerFor(next).SetPrev(prev)
} else if l.tail == e {
l.tail = prev
}
linker.SetNext(nil)
linker.SetPrev(nil)
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type neighborEntryEntry struct {
next *neighborEntry
prev *neighborEntry
}
// Next returns the entry that follows e in the list.
//
//go:nosplit
func (e *neighborEntryEntry) Next() *neighborEntry {
return e.next
}
// Prev returns the entry that precedes e in the list.
//
//go:nosplit
func (e *neighborEntryEntry) Prev() *neighborEntry {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
//
//go:nosplit
func (e *neighborEntryEntry) SetNext(elem *neighborEntry) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
//
//go:nosplit
func (e *neighborEntryEntry) SetPrev(elem *neighborEntry) {
e.prev = elem
}
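// An illustrative sketch of how the generated intrusive list is used (the
// entries e1 and e2 are hypothetical *neighborEntry values; the identity
// ElementMapper above assumes neighborEntry embeds neighborEntryEntry):
//
//	var l neighborEntryList
//	l.PushBack(e1)
//	l.PushBack(e2)
//	for e := l.Front(); e != nil; e = e.Next() {
//		// Inspect e. To remove during iteration, save e.Next() before
//		// calling l.Remove(e), since Remove clears the entry's links.
//	}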

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type neighborEntryRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var neighborEntrylockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type neighborEntrylockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *neighborEntryRWMutex) Lock() {
locking.AddGLock(neighborEntryprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *neighborEntryRWMutex) NestedLock(i neighborEntrylockNameIndex) {
locking.AddGLock(neighborEntryprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *neighborEntryRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(neighborEntryprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *neighborEntryRWMutex) NestedUnlock(i neighborEntrylockNameIndex) {
m.mu.Unlock()
locking.DelGLock(neighborEntryprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *neighborEntryRWMutex) RLock() {
locking.AddGLock(neighborEntryprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *neighborEntryRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(neighborEntryprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *neighborEntryRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *neighborEntryRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *neighborEntryRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var neighborEntryprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func neighborEntryinitLockNames() {}
func init() {
neighborEntryinitLockNames()
neighborEntryprefixIndex = locking.NewMutexClass(reflect.TypeOf(neighborEntryRWMutex{}), neighborEntrylockNames)
}
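// An illustrative sketch of using the generated mutex: it behaves like a
// plain sync.RWMutex, with the AddGLock/DelGLock calls feeding the lock-order
// validator.
//
//	var mu neighborEntryRWMutex
//	mu.Lock()
//	// ... mutate state guarded by mu ...
//	mu.Unlock()
//	mu.RLock()
//	// ... read state guarded by mu ...
//	mu.RUnlock()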

View File

@@ -0,0 +1,44 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type NeighborState"; DO NOT EDIT.
package stack
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Unknown-0]
_ = x[Incomplete-1]
_ = x[Reachable-2]
_ = x[Stale-3]
_ = x[Delay-4]
_ = x[Probe-5]
_ = x[Static-6]
_ = x[Unreachable-7]
}
const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticUnreachable"
var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 58}
func (i NeighborState) String() string {
if i >= NeighborState(len(_NeighborState_index)-1) {
return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]]
}
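// An illustrative sketch of the generated String method: in-range values map
// to their names, anything else falls back to the numeric form.
//
//	Reachable.String()         // "Reachable"
//	NeighborState(42).String() // "NeighborState(42)"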

File diff suppressed because it is too large

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type nicRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var niclockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type niclockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *nicRWMutex) Lock() {
locking.AddGLock(nicprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *nicRWMutex) NestedLock(i niclockNameIndex) {
locking.AddGLock(nicprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *nicRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(nicprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *nicRWMutex) NestedUnlock(i niclockNameIndex) {
m.mu.Unlock()
locking.DelGLock(nicprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *nicRWMutex) RLock() {
locking.AddGLock(nicprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *nicRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(nicprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *nicRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *nicRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *nicRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var nicprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func nicinitLockNames() {}
func init() {
nicinitLockNames()
nicprefixIndex = locking.NewMutexClass(reflect.TypeOf(nicRWMutex{}), niclockNames)
}

View File

@@ -0,0 +1,84 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"gvisor.dev/gvisor/pkg/tcpip"
)
// +stateify savable
type sharedStats struct {
local tcpip.NICStats
multiCounterNICStats
}
// LINT.IfChange(multiCounterNICPacketStats)
// +stateify savable
type multiCounterNICPacketStats struct {
packets tcpip.MultiCounterStat
bytes tcpip.MultiCounterStat
}
func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) {
m.packets.Init(a.Packets, b.Packets)
m.bytes.Init(a.Bytes, b.Bytes)
}
// LINT.ThenChange(../tcpip.go:NICPacketStats)
// LINT.IfChange(multiCounterNICNeighborStats)
// +stateify savable
type multiCounterNICNeighborStats struct {
unreachableEntryLookups tcpip.MultiCounterStat
droppedConfirmationForNoninitiatedNeighbor tcpip.MultiCounterStat
droppedInvalidLinkAddressConfirmations tcpip.MultiCounterStat
}
func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) {
m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups)
m.droppedConfirmationForNoninitiatedNeighbor.Init(a.DroppedConfirmationForNoninitiatedNeighbor, b.DroppedConfirmationForNoninitiatedNeighbor)
m.droppedInvalidLinkAddressConfirmations.Init(a.DroppedInvalidLinkAddressConfirmations, b.DroppedInvalidLinkAddressConfirmations)
}
// LINT.ThenChange(../tcpip.go:NICNeighborStats)
// LINT.IfChange(multiCounterNICStats)
// +stateify savable
type multiCounterNICStats struct {
unknownL3ProtocolRcvdPacketCounts tcpip.MultiIntegralStatCounterMap
unknownL4ProtocolRcvdPacketCounts tcpip.MultiIntegralStatCounterMap
malformedL4RcvdPackets tcpip.MultiCounterStat
tx multiCounterNICPacketStats
txPacketsDroppedNoBufferSpace tcpip.MultiCounterStat
rx multiCounterNICPacketStats
disabledRx multiCounterNICPacketStats
neighbor multiCounterNICNeighborStats
}
func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) {
m.unknownL3ProtocolRcvdPacketCounts.Init(a.UnknownL3ProtocolRcvdPacketCounts, b.UnknownL3ProtocolRcvdPacketCounts)
m.unknownL4ProtocolRcvdPacketCounts.Init(a.UnknownL4ProtocolRcvdPacketCounts, b.UnknownL4ProtocolRcvdPacketCounts)
m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets)
m.tx.init(&a.Tx, &b.Tx)
m.txPacketsDroppedNoBufferSpace.Init(a.TxPacketsDroppedNoBufferSpace, b.TxPacketsDroppedNoBufferSpace)
m.rx.init(&a.Rx, &b.Rx)
m.disabledRx.init(&a.DisabledRx, &b.DisabledRx)
m.neighbor.init(&a.Neighbor, &b.Neighbor)
}
// LINT.ThenChange(../tcpip.go:NICStats)

View File

@@ -0,0 +1,429 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"math"
"math/rand"
"sync"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// defaultBaseReachableTime is the default base duration for computing the
// random reachable time.
//
// Reachable time is the duration for which a neighbor is considered
// reachable after a positive reachability confirmation is received. It is a
// function of a uniformly distributed random value between the minimum and
// maximum random factors, multiplied by the base reachable time. Using a
// random component eliminates the possibility that Neighbor Unreachability
// Detection messages will synchronize with each other.
//
// Default taken from REACHABLE_TIME of RFC 4861 section 10.
defaultBaseReachableTime = 30 * time.Second
// minimumBaseReachableTime is the minimum base duration for computing the
// random reachable time.
//
// Minimum = 1ms
minimumBaseReachableTime = time.Millisecond
// defaultMinRandomFactor is the default minimum value of the random factor
// used for computing reachable time.
//
// Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10.
defaultMinRandomFactor = 0.5
// defaultMaxRandomFactor is the default maximum value of the random factor
// used for computing reachable time.
//
// The default value depends on the value of MinRandomFactor.
// If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10,
// the value from the RFC will be used; otherwise, the default is
// MinRandomFactor multiplied by three.
defaultMaxRandomFactor = 1.5
// defaultRetransmitTimer is the default amount of time to wait between
// sending reachability probes.
//
// Default taken from RETRANS_TIMER of RFC 4861 section 10.
defaultRetransmitTimer = time.Second
// minimumRetransmitTimer is the minimum amount of time to wait between
// sending reachability probes.
//
// Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
// to make sure the messages are not sent all at once. We also come to this
// value because in the RetransmitTimer field of a Router Advertisement, a
// value of 0 means unspecified, so the smallest valid value is 1. Note, the
// unit of the RetransmitTimer field in the Router Advertisement is
// milliseconds.
minimumRetransmitTimer = time.Millisecond
// defaultDelayFirstProbeTime is the default duration to wait for a
// non-Neighbor-Discovery related protocol to reconfirm reachability after
// entering the DELAY state. After this time, a reachability probe will be
// sent and the entry will transition to the PROBE state.
//
// Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10.
defaultDelayFirstProbeTime = 5 * time.Second
// defaultMaxMulticastProbes is the default number of reachability probes
// to send before concluding negative reachability and deleting the neighbor
// entry from the INCOMPLETE state.
//
// Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10.
defaultMaxMulticastProbes = 3
// defaultMaxUnicastProbes is the default number of reachability probes to
// send before concluding retransmission from within the PROBE state should
// cease and the entry SHOULD be deleted.
//
// Default taken from MAX_UNICAST_SOLICIT of RFC 4861 section 10.
defaultMaxUnicastProbes = 3
// defaultMaxAnycastDelayTime is the default time in which the stack SHOULD
// delay sending a response for a random time between 0 and this time, if the
// target address is an anycast address.
//
// Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10.
defaultMaxAnycastDelayTime = time.Second
// defaultMaxReachbilityConfirmations is the default amount of unsolicited
// reachability confirmation messages a node MAY send to all-node multicast
// address when it determines its link-layer address has changed.
//
// Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
defaultMaxReachbilityConfirmations = 3
)
// NUDDispatcher is the interface integrators of netstack must implement to
// receive and handle NUD related events.
type NUDDispatcher interface {
// OnNeighborAdded will be called when a new entry is added to a NIC's (with
// ID nicID) neighbor table.
//
// This function is permitted to block indefinitely without interfering with
// the stack's operation.
//
// May be called concurrently.
OnNeighborAdded(tcpip.NICID, NeighborEntry)
// OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
// neighbor table changes state and/or link address.
//
// This function is permitted to block indefinitely without interfering with
// the stack's operation.
//
// May be called concurrently.
OnNeighborChanged(tcpip.NICID, NeighborEntry)
// OnNeighborRemoved will be called when an entry is removed from a NIC's
// (with ID nicID) neighbor table.
//
// This function is permitted to block indefinitely without interfering with
// the stack's operation.
//
// May be called concurrently.
OnNeighborRemoved(tcpip.NICID, NeighborEntry)
}
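// An illustrative sketch of a minimal NUDDispatcher that only logs events.
// The type name is hypothetical and the standard library log package is
// assumed; integrators hand such a value to the stack when constructing it
// (via the stack options' NUD dispatcher field).
//
//	type loggingNUDDispatcher struct{}
//
//	func (loggingNUDDispatcher) OnNeighborAdded(id tcpip.NICID, e NeighborEntry) {
//		log.Printf("NIC %d: neighbor added: %+v", id, e)
//	}
//
//	func (loggingNUDDispatcher) OnNeighborChanged(id tcpip.NICID, e NeighborEntry) {
//		log.Printf("NIC %d: neighbor changed: %+v", id, e)
//	}
//
//	func (loggingNUDDispatcher) OnNeighborRemoved(id tcpip.NICID, e NeighborEntry) {
//		log.Printf("NIC %d: neighbor removed: %+v", id, e)
//	}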
// ReachabilityConfirmationFlags describes the flags used within a reachability
// confirmation (e.g. an ARP reply or an NDP Neighbor Advertisement).
type ReachabilityConfirmationFlags struct {
// Solicited indicates that the advertisement was sent in response to a
// reachability probe.
Solicited bool
// Override indicates that the reachability confirmation should override an
// existing neighbor cache entry and update the cached link-layer address.
// When Override is not set the confirmation will not update a cached
// link-layer address, but will update an existing neighbor cache entry for
// which no link-layer address is known.
Override bool
// IsRouter indicates that the sender is a router.
IsRouter bool
}
// NUDConfigurations is the NUD configurations for the netstack. This is used
// by the neighbor cache to operate the NUD state machine on each device in the
// local network.
//
// +stateify savable
type NUDConfigurations struct {
// BaseReachableTime is the base duration for computing the random reachable
// time.
//
// Reachable time is the duration for which a neighbor is considered
// reachable after a positive reachability confirmation is received. It is a
// function of a uniformly distributed random value between minRandomFactor and
// maxRandomFactor multiplied by baseReachableTime. Using a random component
// eliminates the possibility that Neighbor Unreachability Detection messages
// will synchronize with each other.
//
// After this time, a neighbor entry will transition from REACHABLE to STALE
// state.
//
// Must be greater than 0.
BaseReachableTime time.Duration
// LearnBaseReachableTime enables learning BaseReachableTime during runtime
// from the neighbor discovery protocol, if supported.
//
// TODO(gvisor.dev/issue/2240): Implement this NUD configuration option.
LearnBaseReachableTime bool
// MinRandomFactor is the minimum value of the random factor used for
// computing reachable time.
//
// See BaseReachableTime for more information on computing the reachable time.
//
// Must be greater than 0.
MinRandomFactor float32
// MaxRandomFactor is the maximum value of the random factor used for
// computing reachable time.
//
// See BaseReachableTime for more information on computing the reachable time.
//
// Must be greater than or equal to MinRandomFactor.
MaxRandomFactor float32
// RetransmitTimer is the duration between retransmission of reachability
// probes in the PROBE state.
RetransmitTimer time.Duration
// LearnRetransmitTimer enables learning RetransmitTimer during runtime from
// the neighbor discovery protocol, if supported.
//
// TODO(gvisor.dev/issue/2241): Implement this NUD configuration option.
LearnRetransmitTimer bool
// DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery
// related protocol to reconfirm reachability after entering the DELAY state.
// After this time, a reachability probe will be sent and the entry will
// transition to the PROBE state.
//
// Must be greater than 0.
DelayFirstProbeTime time.Duration
// MaxMulticastProbes is the number of reachability probes to send before
// concluding negative reachability and deleting the neighbor entry from the
// INCOMPLETE state.
//
// Must be greater than 0.
MaxMulticastProbes uint32
// MaxUnicastProbes is the number of reachability probes to send before
// concluding retransmission from within the PROBE state should cease and
// entry SHOULD be deleted.
//
// Must be greater than 0.
MaxUnicastProbes uint32
// MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a
// response for a random time between 0 and this time, if the target address
// is an anycast address.
//
// TODO(gvisor.dev/issue/2242): Use this option when sending solicited
// neighbor confirmations to anycast addresses and proxying neighbor
// confirmations.
MaxAnycastDelayTime time.Duration
// MaxReachabilityConfirmations is the number of unsolicited reachability
// confirmation messages a node MAY send to all-node multicast address when
// it determines its link-layer address has changed.
//
// TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
// configuration option is necessary.
MaxReachabilityConfirmations uint32
}
// DefaultNUDConfigurations returns a NUDConfigurations populated with default
// values defined by RFC 4861 section 10.
func DefaultNUDConfigurations() NUDConfigurations {
return NUDConfigurations{
BaseReachableTime: defaultBaseReachableTime,
LearnBaseReachableTime: true,
MinRandomFactor: defaultMinRandomFactor,
MaxRandomFactor: defaultMaxRandomFactor,
RetransmitTimer: defaultRetransmitTimer,
LearnRetransmitTimer: true,
DelayFirstProbeTime: defaultDelayFirstProbeTime,
MaxMulticastProbes: defaultMaxMulticastProbes,
MaxUnicastProbes: defaultMaxUnicastProbes,
MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
}
}
// resetInvalidFields modifies an invalid NUDConfigurations with valid values.
// If invalid values are present in c, the corresponding default values will be
// used instead. This is needed to check, and conditionally fix, user-specified
// NUDConfigurations.
func (c *NUDConfigurations) resetInvalidFields() {
if c.BaseReachableTime < minimumBaseReachableTime {
c.BaseReachableTime = defaultBaseReachableTime
}
if c.MinRandomFactor <= 0 {
c.MinRandomFactor = defaultMinRandomFactor
}
if c.MaxRandomFactor < c.MinRandomFactor {
c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor)
}
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
}
if c.DelayFirstProbeTime == 0 {
c.DelayFirstProbeTime = defaultDelayFirstProbeTime
}
if c.MaxMulticastProbes == 0 {
c.MaxMulticastProbes = defaultMaxMulticastProbes
}
if c.MaxUnicastProbes == 0 {
c.MaxUnicastProbes = defaultMaxUnicastProbes
}
}
// calcMaxRandomFactor calculates the maximum value of the random factor used
// for computing reachable time. This function is necessary for when the
// default specified in RFC 4861 section 10 is less than the current
// MinRandomFactor.
//
// Assumes minRandomFactor is positive since validation of the minimum value
// should come before the validation of the maximum.
func calcMaxRandomFactor(minRandomFactor float32) float32 {
if minRandomFactor > defaultMaxRandomFactor {
return minRandomFactor * 3
}
return defaultMaxRandomFactor
}
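// An illustrative sketch of how resetInvalidFields repairs a user-supplied
// configuration: only the out-of-range fields are replaced with defaults.
//
//	c := DefaultNUDConfigurations()
//	c.BaseReachableTime = 0 // invalid: must be > 0
//	c.MinRandomFactor = -1  // invalid: must be > 0
//	c.resetInvalidFields()
//	// c.BaseReachableTime == defaultBaseReachableTime
//	// c.MinRandomFactor == defaultMinRandomFactor
//	// All other fields are left as DefaultNUDConfigurations() set them.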
// +stateify savable
type nudStateMu struct {
sync.RWMutex `state:"nosave"`
config NUDConfigurations
// reachableTime is the duration to wait for a REACHABLE entry to
// transition into STALE after inactivity. This value is calculated with
// the algorithm defined in RFC 4861 section 6.3.2.
reachableTime time.Duration
expiration tcpip.MonotonicTime
prevBaseReachableTime time.Duration
prevMinRandomFactor float32
prevMaxRandomFactor float32
}
// NUDState stores states needed for calculating reachable time.
//
// +stateify savable
type NUDState struct {
clock tcpip.Clock
// TODO(b/341946753): Restore when netstack is savable.
rng *rand.Rand `state:"nosave"`
mu nudStateMu
}
// NewNUDState returns new NUDState using c as configuration and the specified
// random number generator for use in recomputing ReachableTime.
func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState {
s := &NUDState{
clock: clock,
rng: rng,
}
s.mu.config = c
return s
}
// Config returns the NUD configuration.
func (s *NUDState) Config() NUDConfigurations {
s.mu.RLock()
defer s.mu.RUnlock()
return s.mu.config
}
// SetConfig replaces the existing NUD configurations with c.
func (s *NUDState) SetConfig(c NUDConfigurations) {
s.mu.Lock()
defer s.mu.Unlock()
s.mu.config = c
}
// ReachableTime returns the duration to wait for a REACHABLE entry to
// transition into STALE after inactivity. This value is recalculated for new
// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the
// algorithm defined in RFC 4861 section 6.3.2.
func (s *NUDState) ReachableTime() time.Duration {
s.mu.Lock()
defer s.mu.Unlock()
if s.clock.NowMonotonic().After(s.mu.expiration) ||
s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime ||
s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor ||
s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor {
s.recomputeReachableTimeLocked()
}
return s.mu.reachableTime
}
// recomputeReachableTimeLocked forces a recalculation of ReachableTime using
// the algorithm defined in RFC 4861 section 6.3.2.
//
// This SHOULD automatically be invoked during certain situations, as per
// RFC 4861 section 6.3.4:
//
// If the received Reachable Time value is non-zero, the host SHOULD set its
// BaseReachableTime variable to the received value. If the new value
// differs from the previous value, the host SHOULD re-compute a new random
// ReachableTime value. ReachableTime is computed as a uniformly
// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR
// times the BaseReachableTime. Using a random component eliminates the
// possibility that Neighbor Unreachability Detection messages will
// synchronize with each other.
//
// In most cases, the advertised Reachable Time value will be the same in
// consecutive Router Advertisements, and a host's BaseReachableTime rarely
// changes. In such cases, an implementation SHOULD ensure that a new
// random value gets re-computed at least once every few hours.
//
// s.mu MUST be locked for writing.
func (s *NUDState) recomputeReachableTimeLocked() {
s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime
s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor
s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor
randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor)
// Check for overflow, given that minRandomFactor and maxRandomFactor are
// guaranteed to be positive numbers.
if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) {
s.mu.reachableTime = time.Duration(math.MaxInt64)
} else if randomFactor == 1 {
// Avoid loss of precision when a large base reachable time is used.
s.mu.reachableTime = s.mu.config.BaseReachableTime
} else {
reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor)
s.mu.reachableTime = time.Duration(reachableTime)
}
s.mu.expiration = s.clock.NowMonotonic().Add(2 * time.Hour)
}
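// An illustrative sketch of constructing a NUDState and sampling the
// randomized reachable time. The use of tcpip.NewStdClock and a math/rand
// source seeded from the wall clock are assumptions for the sketch; in the
// stack both are normally supplied by the Stack itself.
//
//	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
//	s := NewNUDState(DefaultNUDConfigurations(), tcpip.NewStdClock(), rng)
//	rt := s.ReachableTime() // uniformly distributed in [0.5, 1.5) * BaseReachableTime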

View File

@@ -0,0 +1,769 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"io"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
type headerType int
const (
virtioNetHeader headerType = iota
linkHeader
networkHeader
transportHeader
numHeaderType
)
var pkPool = sync.Pool{
New: func() any {
return &PacketBuffer{}
},
}
// PacketBufferOptions specifies options for PacketBuffer creation.
type PacketBufferOptions struct {
// ReserveHeaderBytes is the number of bytes to reserve for headers. Total
// number of bytes pushed onto the headers must not exceed this value.
ReserveHeaderBytes int
// Payload is the initial unparsed data for the new packet. If set, it will
// be owned by the new packet.
Payload buffer.Buffer
// IsForwardedPacket identifies that the PacketBuffer being created is for a
// forwarded packet.
IsForwardedPacket bool
// OnRelease is a function to be run when the packet buffer is no longer
// referenced (released back to the pool).
OnRelease func()
}
// A PacketBuffer contains all the data of a network packet.
//
// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
// multiple endpoints.
//
// The whole packet is expected to be a series of bytes in the following order:
// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be
// empty. Use of PacketBuffer in any other order is unsupported.
//
// PacketBuffer must be created with NewPacketBuffer, which sets the initial
// reference count to 1. Owners should call `DecRef()` when they are finished
// with the buffer to return it to the pool.
//
// Internal structure: A PacketBuffer holds a pointer to buffer.Buffer, which
// exposes logically-contiguous byte storage. The underlying storage structure
// is abstracted away and is usually not a concern here.
//
// |- reserved ->|
// |--->| consumed (incoming)
// 0 V V
// +--------+----+----+--------------------+
// | | | | current data ... | (buf)
// +--------+----+----+--------------------+
// ^ |
// |<---| pushed (outgoing)
//
// When a PacketBuffer is created, a `reserved` header region can be specified,
// into which the stack pushes headers for an outgoing packet. There may be no
// such region for an incoming packet, in which case `reserved` is 0. The value
// of `reserved` never changes over the entire lifetime of the packet.
//
// Outgoing Packet: When a header is pushed, `pushed` gets incremented by the
// pushed length, and the current value is stored for each header. PacketBuffer
// subtracts this value from `reserved` to compute the starting offset of each
// header in `buf`.
//
// Incoming Packet: When a header is consumed (a.k.a. parsed), the current
// `consumed` value is stored for each header, and it gets incremented by the
// consumed length. PacketBuffer adds this value to `reserved` to compute the
// starting offset of each header in `buf`.
//
// +stateify savable
type PacketBuffer struct {
_ sync.NoCopy
packetBufferRefs
// buf is the underlying buffer for the packet. See struct level docs for
// details.
buf buffer.Buffer
reserved int
pushed int
consumed int
// headers stores metadata about each header.
headers [numHeaderType]headerInfo
// NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty()
// returns false.
// TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
// numbers in registration APIs that take a PacketBuffer.
NetworkProtocolNumber tcpip.NetworkProtocolNumber
// TransportProtocolNumber is only valid if it is non-zero.
// TODO(gvisor.dev/issue/3810): This and the network protocol number should
// be moved into the headerinfo. This should resolve the validity issue.
TransportProtocolNumber tcpip.TransportProtocolNumber
// Hash is the transport layer hash of this packet. A value of zero
// indicates no valid hash has been set.
Hash uint32
// Owner is implemented by task to get the uid and gid.
// Only set for locally generated packets.
Owner tcpip.PacketOwner
// The following fields are only set by the qdisc layer when the packet
// is added to a queue.
EgressRoute RouteInfo
GSOOptions GSO
// snatDone indicates if the packet's source has been manipulated as per
// iptables NAT table.
snatDone bool
// dnatDone indicates if the packet's destination has been manipulated as per
// iptables NAT table.
dnatDone bool
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
// NICID is the ID of the last interface the network packet was handled at.
NICID tcpip.NICID
// RXChecksumValidated indicates that checksum verification may be
// safely skipped.
RXChecksumValidated bool
// NetworkPacketInfo holds an incoming packet's network-layer information.
NetworkPacketInfo NetworkPacketInfo
tuple *tuple
// onRelease is a function to be run when the packet buffer is no longer
// referenced (released back to the pool).
onRelease func() `state:"nosave"`
}
// NewPacketBuffer creates a new PacketBuffer with opts.
func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
pk := pkPool.Get().(*PacketBuffer)
pk.reset()
if opts.ReserveHeaderBytes != 0 {
v := buffer.NewViewSize(opts.ReserveHeaderBytes)
pk.buf.Append(v)
pk.reserved = opts.ReserveHeaderBytes
}
if opts.Payload.Size() > 0 {
pk.buf.Merge(&opts.Payload)
}
pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket
pk.onRelease = opts.OnRelease
pk.InitRefs()
return pk
}
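// An illustrative sketch of building an outgoing packet. The header sizes and
// the ethHdr/ipv4Hdr/udpHdr byte slices are hypothetical placeholders; real
// callers size the reservation from the route and use the header package's
// constants. Headers are pushed innermost-first so each lands in front of the
// previously pushed one.
//
//	pk := NewPacketBuffer(PacketBufferOptions{
//		ReserveHeaderBytes: 14 + 20 + 8, // link + network + transport
//	})
//	defer pk.DecRef()
//	payload := buffer.NewView(4)
//	payload.Write([]byte("ping"))
//	pk.Data().AppendView(payload)
//	copy(pk.TransportHeader().Push(8), udpHdr)
//	copy(pk.NetworkHeader().Push(20), ipv4Hdr)
//	copy(pk.LinkHeader().Push(14), ethHdr)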
// IncRef increments the PacketBuffer's refcount.
func (pk *PacketBuffer) IncRef() *PacketBuffer {
pk.packetBufferRefs.IncRef()
return pk
}
// DecRef decrements the PacketBuffer's refcount. If the refcount is
// decremented to zero, the PacketBuffer is returned to the PacketBuffer
// pool.
func (pk *PacketBuffer) DecRef() {
pk.packetBufferRefs.DecRef(func() {
if pk.onRelease != nil {
pk.onRelease()
}
pk.buf.Release()
pkPool.Put(pk)
})
}
func (pk *PacketBuffer) reset() {
*pk = PacketBuffer{}
}
// ReservedHeaderBytes returns the number of bytes initially reserved for
// headers.
func (pk *PacketBuffer) ReservedHeaderBytes() int {
return pk.reserved
}
// AvailableHeaderBytes returns the number of bytes currently available for
// headers. This is relevant to PacketHeader.Push method only.
func (pk *PacketBuffer) AvailableHeaderBytes() int {
return pk.reserved - pk.pushed
}
// VirtioNetHeader returns the handle to virtio-layer header.
func (pk *PacketBuffer) VirtioNetHeader() PacketHeader {
return PacketHeader{
pk: pk,
typ: virtioNetHeader,
}
}
// LinkHeader returns the handle to link-layer header.
func (pk *PacketBuffer) LinkHeader() PacketHeader {
return PacketHeader{
pk: pk,
typ: linkHeader,
}
}
// NetworkHeader returns the handle to network-layer header.
func (pk *PacketBuffer) NetworkHeader() PacketHeader {
return PacketHeader{
pk: pk,
typ: networkHeader,
}
}
// TransportHeader returns the handle to transport-layer header.
func (pk *PacketBuffer) TransportHeader() PacketHeader {
return PacketHeader{
pk: pk,
typ: transportHeader,
}
}
// HeaderSize returns the total size of all headers in bytes.
func (pk *PacketBuffer) HeaderSize() int {
return pk.pushed + pk.consumed
}
// Size returns the size of packet in bytes.
func (pk *PacketBuffer) Size() int {
return int(pk.buf.Size()) - pk.headerOffset()
}
// MemSize returns the estimation size of the pk in memory, including backing
// buffer data.
func (pk *PacketBuffer) MemSize() int {
return int(pk.buf.Size()) + PacketBufferStructSize
}
// Data returns the handle to data portion of pk.
func (pk *PacketBuffer) Data() PacketData {
return PacketData{pk: pk}
}
// AsSlices returns the underlying storage of the whole packet.
//
// Note that AsSlices can allocate a lot. In hot paths it may be preferable to
// iterate over a PacketBuffer's data via AsViewList.
func (pk *PacketBuffer) AsSlices() [][]byte {
vl := pk.buf.AsViewList()
views := make([][]byte, 0, vl.Len())
offset := pk.headerOffset()
pk.buf.SubApply(offset, int(pk.buf.Size())-offset, func(v *buffer.View) {
views = append(views, v.AsSlice())
})
return views
}
// AsViewList returns the list of Views backing the PacketBuffer along with the
// header offset into them. Users may not save or modify the ViewList returned.
func (pk *PacketBuffer) AsViewList() (buffer.ViewList, int) {
return pk.buf.AsViewList(), pk.headerOffset()
}
// ToBuffer returns a caller-owned copy of the underlying storage of the whole
// packet.
func (pk *PacketBuffer) ToBuffer() buffer.Buffer {
b := pk.buf.Clone()
b.TrimFront(int64(pk.headerOffset()))
return b
}
// ToView returns a caller-owned copy of the underlying storage of the whole
// packet as a view.
func (pk *PacketBuffer) ToView() *buffer.View {
p := buffer.NewView(int(pk.buf.Size()))
offset := pk.headerOffset()
pk.buf.SubApply(offset, int(pk.buf.Size())-offset, func(v *buffer.View) {
p.Write(v.AsSlice())
})
return p
}
func (pk *PacketBuffer) headerOffset() int {
return pk.reserved - pk.pushed
}
func (pk *PacketBuffer) headerOffsetOf(typ headerType) int {
return pk.reserved + pk.headers[typ].offset
}
func (pk *PacketBuffer) dataOffset() int {
return pk.reserved + pk.consumed
}
func (pk *PacketBuffer) push(typ headerType, size int) []byte {
h := &pk.headers[typ]
if h.length > 0 {
panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size))
}
if pk.pushed+size > pk.reserved {
panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved))
}
pk.pushed += size
h.offset = -pk.pushed
h.length = size
view := pk.headerView(typ)
return view.AsSlice()
}
func (pk *PacketBuffer) consume(typ headerType, size int) (v []byte, consumed bool) {
h := &pk.headers[typ]
if h.length > 0 {
panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
}
if pk.reserved+pk.consumed+size > int(pk.buf.Size()) {
return nil, false
}
h.offset = pk.consumed
h.length = size
pk.consumed += size
view := pk.headerView(typ)
return view.AsSlice(), true
}
func (pk *PacketBuffer) headerView(typ headerType) buffer.View {
h := &pk.headers[typ]
if h.length == 0 {
return buffer.View{}
}
v, ok := pk.buf.PullUp(pk.headerOffsetOf(typ), h.length)
if !ok {
panic("PullUp failed")
}
return v
}
// Clone makes a semi-deep copy of pk. The underlying packet payload is
// shared. Hence, no modifications are made to the underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
newPk := pkPool.Get().(*PacketBuffer)
newPk.reset()
newPk.buf = pk.buf.Clone()
newPk.reserved = pk.reserved
newPk.pushed = pk.pushed
newPk.consumed = pk.consumed
newPk.headers = pk.headers
newPk.Hash = pk.Hash
newPk.Owner = pk.Owner
newPk.GSOOptions = pk.GSOOptions
newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
newPk.dnatDone = pk.dnatDone
newPk.snatDone = pk.snatDone
newPk.TransportProtocolNumber = pk.TransportProtocolNumber
newPk.PktType = pk.PktType
newPk.NICID = pk.NICID
newPk.RXChecksumValidated = pk.RXChecksumValidated
newPk.NetworkPacketInfo = pk.NetworkPacketInfo
newPk.tuple = pk.tuple
newPk.InitRefs()
return newPk
}
// ReserveHeaderBytes prepends reserved space for headers at the front
// of the underlying buf. Can only be called once per packet.
func (pk *PacketBuffer) ReserveHeaderBytes(reserved int) {
if pk.reserved != 0 {
panic(fmt.Sprintf("ReserveHeaderBytes(...) called on packet with reserved=%d, want reserved=0", pk.reserved))
}
pk.reserved = reserved
pk.buf.Prepend(buffer.NewViewSize(reserved))
}
// Network returns the network header as a header.Network.
//
// Network should only be called when NetworkHeader has been set.
func (pk *PacketBuffer) Network() header.Network {
switch netProto := pk.NetworkProtocolNumber; netProto {
case header.IPv4ProtocolNumber:
return header.IPv4(pk.NetworkHeader().Slice())
case header.IPv6ProtocolNumber:
return header.IPv6(pk.NetworkHeader().Slice())
default:
panic(fmt.Sprintf("unknown network protocol number %d", netProto))
}
}
// CloneToInbound makes a semi-deep copy of the packet buffer (similar to
// Clone) to be used as an inbound packet.
//
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
newPk := pkPool.Get().(*PacketBuffer)
newPk.reset()
newPk.buf = pk.buf.Clone()
newPk.InitRefs()
// Treat unfilled header portion as reserved.
newPk.reserved = pk.AvailableHeaderBytes()
newPk.tuple = pk.tuple
return newPk
}
// DeepCopyForForwarding creates a deep copy of the packet buffer for
// forwarding.
//
// The returned packet buffer will have the network and transport headers
// set if the original packet buffer did.
func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
payload := BufferSince(pk.NetworkHeader())
defer payload.Release()
newPk := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: reservedHeaderBytes,
Payload: payload.DeepClone(),
IsForwardedPacket: true,
})
{
consumeBytes := len(pk.NetworkHeader().Slice())
if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
}
newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
{
consumeBytes := len(pk.TransportHeader().Slice())
if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
}
newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
newPk.tuple = pk.tuple
return newPk
}
// headerInfo stores metadata about a header in a packet.
//
// +stateify savable
type headerInfo struct {
// offset is the offset of the header in pk.buf relative to
// pk.buf[pk.reserved]. See the PacketBuffer struct for details.
offset int
// length is the length of this header.
length int
}
// PacketHeader is a handle object to a header in the underlying packet.
type PacketHeader struct {
pk *PacketBuffer
typ headerType
}
// View returns a caller-owned copy of the underlying storage of h as a
// *buffer.View.
func (h PacketHeader) View() *buffer.View {
view := h.pk.headerView(h.typ)
if view.Size() == 0 {
return nil
}
return view.Clone()
}
// Slice returns the underlying storage of h as a []byte. The returned slice
// should not be modified if the underlying packet could be shared, cloned, or
// borrowed.
func (h PacketHeader) Slice() []byte {
view := h.pk.headerView(h.typ)
return view.AsSlice()
}
// Push pushes size bytes in the front of its residing packet, and returns the
// backing storage. Callers may only call one of Push or Consume once on each
// header in the lifetime of the underlying packet.
func (h PacketHeader) Push(size int) []byte {
return h.pk.push(h.typ, size)
}
// Consume moves the first size bytes of the unparsed data portion in the packet
// to h, and returns the backing storage. If the data is shorter than size,
// consumed will be false and the state of h will not be affected.
// Callers may only call one of Push or Consume once on each header in the
// lifetime of the underlying packet.
func (h PacketHeader) Consume(size int) (v []byte, consumed bool) {
return h.pk.consume(h.typ, size)
}
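// An illustrative sketch of parsing an incoming packet by consuming headers
// front-to-back. The wire byte slice and the fixed 20-byte header length are
// hypothetical; real parsers derive header lengths from the protocol.
//
//	pk := NewPacketBuffer(PacketBufferOptions{})
//	defer pk.DecRef()
//	raw := buffer.NewView(len(wire))
//	raw.Write(wire)
//	pk.Data().AppendView(raw)
//	if hdr, ok := pk.NetworkHeader().Consume(20); ok {
//		_ = hdr // The header bytes; everything after them remains in pk.Data().
//	}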
// PacketData represents the data portion of a PacketBuffer.
//
// +stateify savable
type PacketData struct {
pk *PacketBuffer
}
// PullUp returns a contiguous slice of size bytes from the beginning of d.
// Callers should not keep the returned slice for later use. Callers can write
// to the returned slice if they have singular ownership over the underlying
// Buffer.
func (d PacketData) PullUp(size int) (b []byte, ok bool) {
view, ok := d.pk.buf.PullUp(d.pk.dataOffset(), size)
return view.AsSlice(), ok
}
// Consume is the same as PullUp except that it additionally consumes the
// returned bytes. Subsequent PullUp or Consume will not return these bytes.
func (d PacketData) Consume(size int) ([]byte, bool) {
v, ok := d.PullUp(size)
if ok {
d.pk.consumed += size
}
return v, ok
}
// ReadTo reads bytes from d to dst. It also removes these bytes from d
// unless peek is true.
func (d PacketData) ReadTo(dst io.Writer, peek bool) (int, error) {
var (
err error
done int
)
offset := d.pk.dataOffset()
d.pk.buf.SubApply(offset, int(d.pk.buf.Size())-offset, func(v *buffer.View) {
if err != nil {
return
}
var n int
n, err = dst.Write(v.AsSlice())
done += n
if err != nil {
return
}
if n != v.Size() {
panic(fmt.Sprintf("io.Writer.Write succeeded with incomplete write: %d != %d", n, v.Size()))
}
})
if !peek {
d.pk.buf.TrimFront(int64(done))
}
return done, err
}
// CapLength reduces d to at most length bytes.
func (d PacketData) CapLength(length int) {
if length < 0 {
panic("length < 0")
}
d.pk.buf.Truncate(int64(length + d.pk.dataOffset()))
}
// ToBuffer returns the underlying storage of d in a buffer.Buffer.
func (d PacketData) ToBuffer() buffer.Buffer {
buf := d.pk.buf.Clone()
offset := d.pk.dataOffset()
buf.TrimFront(int64(offset))
return buf
}
// AppendView appends v into d, taking the ownership of v.
func (d PacketData) AppendView(v *buffer.View) {
d.pk.buf.Append(v)
}
// MergeBuffer merges b into d and clears b.
func (d PacketData) MergeBuffer(b *buffer.Buffer) {
d.pk.buf.Merge(b)
}
// MergeFragment appends the data portion of frag to dst. It modifies
// frag and frag should not be used again.
func MergeFragment(dst, frag *PacketBuffer) {
frag.buf.TrimFront(int64(frag.dataOffset()))
dst.buf.Merge(&frag.buf)
}
// ReadFrom moves at most count bytes from the beginning of src to the end
// of d and returns the number of bytes moved.
func (d PacketData) ReadFrom(src *buffer.Buffer, count int) int {
toRead := int64(count)
if toRead > src.Size() {
toRead = src.Size()
}
clone := src.Clone()
clone.Truncate(toRead)
d.pk.buf.Merge(&clone)
src.TrimFront(toRead)
return int(toRead)
}
// ReadFromPacketData moves count bytes from the beginning of oth to the end of
// d.
func (d PacketData) ReadFromPacketData(oth PacketData, count int) {
buf := oth.ToBuffer()
buf.Truncate(int64(count))
d.MergeBuffer(&buf)
oth.TrimFront(count)
buf.Release()
}
// Merge clears headers in oth and merges its data with d.
func (d PacketData) Merge(oth PacketData) {
oth.pk.buf.TrimFront(int64(oth.pk.dataOffset()))
d.pk.buf.Merge(&oth.pk.buf)
}
// TrimFront removes up to count bytes from the front of d's payload.
func (d PacketData) TrimFront(count int) {
if count > d.Size() {
count = d.Size()
}
buf := d.pk.Data().ToBuffer()
buf.TrimFront(int64(count))
d.pk.buf.Truncate(int64(d.pk.dataOffset()))
d.pk.buf.Merge(&buf)
}
// Size returns the number of bytes in the data payload of the packet.
func (d PacketData) Size() int {
return int(d.pk.buf.Size()) - d.pk.dataOffset()
}
// AsRange returns a Range representing the current data payload of the packet.
func (d PacketData) AsRange() Range {
return Range{
pk: d.pk,
offset: d.pk.dataOffset(),
length: d.Size(),
}
}
// Checksum returns a checksum over the data payload of the packet.
func (d PacketData) Checksum() uint16 {
return d.pk.buf.Checksum(d.pk.dataOffset())
}
// ChecksumAtOffset returns a checksum over the data payload of the packet
// starting from offset.
func (d PacketData) ChecksumAtOffset(offset int) uint16 {
return d.pk.buf.Checksum(offset)
}
// Range represents a contiguous subportion of a PacketBuffer.
type Range struct {
pk *PacketBuffer
offset int
length int
}
// Size returns the number of bytes in r.
func (r Range) Size() int {
return r.length
}
// SubRange returns a new Range starting at off bytes of r. It returns an empty
// range if off is out-of-bounds.
func (r Range) SubRange(off int) Range {
if off > r.length {
return Range{pk: r.pk}
}
return Range{
pk: r.pk,
offset: r.offset + off,
length: r.length - off,
}
}
// Capped returns a new Range with the same starting point of r and length
// capped at max.
func (r Range) Capped(max int) Range {
if r.length <= max {
return r
}
return Range{
pk: r.pk,
offset: r.offset,
length: max,
}
}
// ToSlice returns a caller-owned copy of data in r.
func (r Range) ToSlice() []byte {
if r.length == 0 {
return nil
}
all := make([]byte, 0, r.length)
r.iterate(func(v *buffer.View) {
all = append(all, v.AsSlice()...)
})
return all
}
// ToView returns a caller-owned copy of data in r.
func (r Range) ToView() *buffer.View {
if r.length == 0 {
return nil
}
newV := buffer.NewView(r.length)
r.iterate(func(v *buffer.View) {
newV.Write(v.AsSlice())
})
return newV
}
// iterate calls fn for each piece in r. fn is always called with a non-empty
// slice.
func (r Range) iterate(fn func(*buffer.View)) {
r.pk.buf.SubApply(r.offset, r.length, fn)
}
// PayloadSince returns a caller-owned view containing the payload starting from
// and including a particular header.
func PayloadSince(h PacketHeader) *buffer.View {
offset := h.pk.headerOffset()
for i := headerType(0); i < h.typ; i++ {
offset += h.pk.headers[i].length
}
return Range{
pk: h.pk,
offset: offset,
length: int(h.pk.buf.Size()) - offset,
}.ToView()
}
// BufferSince returns a caller-owned view containing the packet payload
// starting from and including a particular header.
func BufferSince(h PacketHeader) buffer.Buffer {
offset := h.pk.headerOffset()
for i := headerType(0); i < h.typ; i++ {
offset += h.pk.headers[i].length
}
clone := h.pk.buf.Clone()
clone.TrimFront(int64(offset))
return clone
}
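// An illustrative sketch of PayloadSince: it copies everything from the given
// header onward, which is useful when a packet must be re-serialized (e.g.
// when generating an ICMP error that quotes the offending packet). Releasing
// the returned view is assumed to follow the usual caller-owned convention.
//
//	v := PayloadSince(pk.NetworkHeader()) // network header plus all following bytes
//	defer v.Release()
//	wire := v.AsSlice()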

View File

@@ -0,0 +1,87 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
// PacketBufferList is a slice-backed list. All operations are O(1) unless
// otherwise noted.
//
// Note: this is intentionally backed by a slice, not an intrusive list. We've
// switched PacketBufferList back-and-forth between intrusive list and
// slice-backed implementations, and the latter has proven to be preferable:
//
// - Intrusive lists are a refcounting nightmare, as modifying the list
// sometimes-but-not-always modifies the list for others.
// - The slice-backed implementation has been benchmarked and is slightly more
// performant.
//
// +stateify savable
type PacketBufferList struct {
pbs []*PacketBuffer
}
// AsSlice returns a slice containing the packets in the list.
//
//go:nosplit
func (pl *PacketBufferList) AsSlice() []*PacketBuffer {
return pl.pbs
}
// Reset decrements the reference count of each element and resets the list to the empty state.
//
//go:nosplit
func (pl *PacketBufferList) Reset() {
for i, pb := range pl.pbs {
pb.DecRef()
pl.pbs[i] = nil
}
pl.pbs = pl.pbs[:0]
}
// Len returns the number of elements in the list.
//
//go:nosplit
func (pl *PacketBufferList) Len() int {
return len(pl.pbs)
}
// PushBack inserts the PacketBuffer at the back of the list.
//
//go:nosplit
func (pl *PacketBufferList) PushBack(pb *PacketBuffer) {
pl.pbs = append(pl.pbs, pb)
}
// PopFront removes the first element in the list if it exists and returns it.
//
//go:nosplit
func (pl *PacketBufferList) PopFront() *PacketBuffer {
if len(pl.pbs) == 0 {
return nil
}
pkt := pl.pbs[0]
pl.pbs = pl.pbs[1:]
return pkt
}
// DecRef decreases the reference count on each PacketBuffer
// stored in the list.
//
// NOTE: runs in O(n) time.
//
//go:nosplit
func (pl PacketBufferList) DecRef() {
for _, pb := range pl.pbs {
pb.DecRef()
}
}
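// An illustrative sketch of a typical produce/consume cycle with the list
// (pk1 and pk2 are hypothetical *PacketBuffer values). Ownership of each
// popped packet moves to the consumer; Reset releases whatever is left.
//
//	var pl PacketBufferList
//	pl.PushBack(pk1)
//	pl.PushBack(pk2)
//	for pkt := pl.PopFront(); pkt != nil; pkt = pl.PopFront() {
//		// Handle pkt; the consumer is now responsible for pkt.DecRef().
//	}
//	pl.Reset()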

View File

@@ -0,0 +1,142 @@
package stack
import (
"context"
"fmt"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/refs"
)
// enableLogging indicates whether reference-related events should be logged (with
// stack traces). This is false by default and should only be set to true for
// debugging purposes, as it can generate an extremely large amount of output
// and drastically degrade performance.
const packetBufferenableLogging = false
// obj is used to customize logging. Note that we use a pointer to T so that
// we do not copy the entire object when passed as a format parameter.
var packetBufferobj *PacketBuffer
// Refs implements refs.RefCounter. It keeps a reference count using atomic
// operations and calls the destructor when the count reaches zero.
//
// NOTE: Do not introduce additional fields to the Refs struct. It is used by
// many filesystem objects, and we want to keep it as small as possible (i.e.,
// the same size as using an int64 directly) to avoid taking up extra cache
// space. In general, this template should not be extended at the cost of
// performance. If it does not offer enough flexibility for a particular object
// (example: b/187877947), we should implement the RefCounter/CheckedObject
// interfaces manually.
//
// +stateify savable
type packetBufferRefs struct {
// refCount is composed of two fields:
//
// [32-bit speculative references]:[32-bit real references]
//
// Speculative references are used for TryIncRef, to avoid a CompareAndSwap
// loop. See IncRef, DecRef and TryIncRef for details of how these fields are
// used.
refCount atomicbitops.Int64
}
// InitRefs initializes r with one reference and, if enabled, activates leak
// checking.
func (r *packetBufferRefs) InitRefs() {
r.refCount.RacyStore(1)
refs.Register(r)
}
// RefType implements refs.CheckedObject.RefType.
func (r *packetBufferRefs) RefType() string {
return fmt.Sprintf("%T", packetBufferobj)[1:]
}
// LeakMessage implements refs.CheckedObject.LeakMessage.
func (r *packetBufferRefs) LeakMessage() string {
return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs())
}
// LogRefs implements refs.CheckedObject.LogRefs.
func (r *packetBufferRefs) LogRefs() bool {
return packetBufferenableLogging
}
// ReadRefs returns the current number of references. The returned count is
// inherently racy and is unsafe to use without external synchronization.
func (r *packetBufferRefs) ReadRefs() int64 {
return r.refCount.Load()
}
// IncRef implements refs.RefCounter.IncRef.
//
//go:nosplit
func (r *packetBufferRefs) IncRef() {
v := r.refCount.Add(1)
if packetBufferenableLogging {
refs.LogIncRef(r, v)
}
if v <= 1 {
panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType()))
}
}
// TryIncRef implements refs.TryRefCounter.TryIncRef.
//
// To do this safely without a loop, a speculative reference is first acquired
// on the object. This allows multiple concurrent TryIncRef calls to distinguish
// other TryIncRef calls from genuine references held.
//
//go:nosplit
func (r *packetBufferRefs) TryIncRef() bool {
const speculativeRef = 1 << 32
if v := r.refCount.Add(speculativeRef); int32(v) == 0 {
r.refCount.Add(-speculativeRef)
return false
}
v := r.refCount.Add(-speculativeRef + 1)
if packetBufferenableLogging {
refs.LogTryIncRef(r, v)
}
return true
}
// DecRef implements refs.RefCounter.DecRef.
//
// Note that speculative references are counted here. Since they were added
// prior to real references reaching zero, they will successfully convert to
// real references. In other words, we see speculative references only in the
// following case:
//
// A: TryIncRef [speculative increase => sees non-negative references]
// B: DecRef [real decrease]
// A: TryIncRef [transform speculative to real]
//
//go:nosplit
func (r *packetBufferRefs) DecRef(destroy func()) {
v := r.refCount.Add(-1)
if packetBufferenableLogging {
refs.LogDecRef(r, v)
}
switch {
case v < 0:
panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType()))
case v == 0:
refs.Unregister(r)
if destroy != nil {
destroy()
}
}
}
func (r *packetBufferRefs) afterLoad(context.Context) {
if r.ReadRefs() > 0 {
refs.Register(r)
}
}
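
A hypothetical in-package sketch (not gVisor code) of the intended call pattern for this generated counter: InitRefs establishes the initial reference and registers for leak checking, IncRef is called for each additional holder, and DecRef takes a destructor that runs exactly once, when the count drops to zero.

type trackedBuffer struct {
	packetBufferRefs
	data []byte
}

func newTrackedBuffer(n int) *trackedBuffer {
	b := &trackedBuffer{data: make([]byte, n)}
	b.InitRefs() // one reference, owned by the caller
	return b
}

func (b *trackedBuffer) release() {
	b.DecRef(func() {
		// Destructor: runs only when the last reference is dropped.
		b.data = nil
	})
}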

View File

@@ -0,0 +1,28 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import "unsafe"
// PacketBufferStructSize is the minimal size of the packet buffer overhead.
const PacketBufferStructSize = int(unsafe.Sizeof(PacketBuffer{}))
// ID returns a unique ID for the underlying storage of the packet.
//
// Two *PacketBuffers have the same IDs if and only if they point to the same
// location in memory.
func (pk *PacketBuffer) ID() uintptr {
return uintptr(unsafe.Pointer(pk))
}
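
Because the ID is simply the address of the packet's storage, it can serve as a cheap map key. A hypothetical in-package illustration of using it to detect whether the same underlying packet has already been observed:

func seenBefore(seen map[uintptr]struct{}, pkt *PacketBuffer) bool {
	id := pkt.ID()
	if _, ok := seen[id]; ok {
		return true
	}
	seen[id] = struct{}{}
	return false
}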

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type packetEndpointListRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var packetEndpointListlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type packetEndpointListlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *packetEndpointListRWMutex) Lock() {
locking.AddGLock(packetEndpointListprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *packetEndpointListRWMutex) NestedLock(i packetEndpointListlockNameIndex) {
locking.AddGLock(packetEndpointListprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *packetEndpointListRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(packetEndpointListprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *packetEndpointListRWMutex) NestedUnlock(i packetEndpointListlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(packetEndpointListprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *packetEndpointListRWMutex) RLock() {
locking.AddGLock(packetEndpointListprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *packetEndpointListRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(packetEndpointListprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *packetEndpointListRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *packetEndpointListRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *packetEndpointListRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var packetEndpointListprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func packetEndpointListinitLockNames() {}
func init() {
packetEndpointListinitLockNames()
packetEndpointListprefixIndex = locking.NewMutexClass(reflect.TypeOf(packetEndpointListRWMutex{}), packetEndpointListlockNames)
}
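
The generated wrapper is a drop-in replacement for sync.RWMutex that additionally feeds gVisor's lock-order validator. A hedged, in-package sketch of how such a type typically guards a field, mirroring the +checklocks annotations used elsewhere in this package; the endpointRegistry type and its element type are illustrative, not part of gVisor.

type endpointRegistry struct {
	mu packetEndpointListRWMutex
	// +checklocks:mu
	eps []PacketEndpoint
}

func (r *endpointRegistry) forEach(fn func(PacketEndpoint)) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	for _, ep := range r.eps {
		fn(ep)
	}
}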

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type packetEPsRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var packetEPslockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type packetEPslockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *packetEPsRWMutex) Lock() {
locking.AddGLock(packetEPsprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *packetEPsRWMutex) NestedLock(i packetEPslockNameIndex) {
locking.AddGLock(packetEPsprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *packetEPsRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(packetEPsprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *packetEPsRWMutex) NestedUnlock(i packetEPslockNameIndex) {
m.mu.Unlock()
locking.DelGLock(packetEPsprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *packetEPsRWMutex) RLock() {
locking.AddGLock(packetEPsprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *packetEPsRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(packetEPsprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *packetEPsRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *packetEPsRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *packetEPsRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var packetEPsprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func packetEPsinitLockNames() {}
func init() {
packetEPsinitLockNames()
packetEPsprefixIndex = locking.NewMutexClass(reflect.TypeOf(packetEPsRWMutex{}), packetEPslockNames)
}

View File

@@ -0,0 +1,64 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// Mutex is sync.Mutex with the correctness validator.
type packetsPendingLinkResolutionMutex struct {
mu sync.Mutex
}
var packetsPendingLinkResolutionprefixIndex *locking.MutexClass
// lockNames is a list of user-friendly lock names.
// Populated in init.
var packetsPendingLinkResolutionlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type packetsPendingLinkResolutionlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *packetsPendingLinkResolutionMutex) Lock() {
locking.AddGLock(packetsPendingLinkResolutionprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *packetsPendingLinkResolutionMutex) NestedLock(i packetsPendingLinkResolutionlockNameIndex) {
locking.AddGLock(packetsPendingLinkResolutionprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *packetsPendingLinkResolutionMutex) Unlock() {
locking.DelGLock(packetsPendingLinkResolutionprefixIndex, -1)
m.mu.Unlock()
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *packetsPendingLinkResolutionMutex) NestedUnlock(i packetsPendingLinkResolutionlockNameIndex) {
locking.DelGLock(packetsPendingLinkResolutionprefixIndex, int(i))
m.mu.Unlock()
}
// DO NOT REMOVE: The following function is automatically replaced.
func packetsPendingLinkResolutioninitLockNames() {}
func init() {
packetsPendingLinkResolutioninitLockNames()
packetsPendingLinkResolutionprefixIndex = locking.NewMutexClass(reflect.TypeOf(packetsPendingLinkResolutionMutex{}), packetsPendingLinkResolutionlockNames)
}

View File

@@ -0,0 +1,224 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// maxPendingResolutions is the maximum number of pending link-address
// resolutions.
maxPendingResolutions = 64
maxPendingPacketsPerResolution = 256
)
// +stateify savable
type pendingPacket struct {
routeInfo RouteInfo
pkt *PacketBuffer
}
// +stateify savable
type packetsPendingLinkResolutionMu struct {
packetsPendingLinkResolutionMutex `state:"nosave"`
// The packets to send once the resolver completes.
//
// The link resolution channel is used as the key for this map.
packets map[<-chan struct{}][]pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
//
// cancelChans holds the same channels that are used as keys to packets.
cancelChans []<-chan struct{}
}
// packetsPendingLinkResolution is a queue of packets pending link resolution.
//
// Once link resolution completes successfully, the packets will be written.
//
// +stateify savable
type packetsPendingLinkResolution struct {
nic *nic
mu packetsPendingLinkResolutionMu
}
func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(pkt *PacketBuffer) {
f.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
if ipEndpointStats, ok := f.nic.getNetworkEndpoint(pkt.NetworkProtocolNumber).Stats().(IPNetworkEndpointStats); ok {
ipEndpointStats.IPStats().OutgoingPacketErrors.Increment()
}
}
func (f *packetsPendingLinkResolution) init(nic *nic) {
f.mu.Lock()
defer f.mu.Unlock()
f.nic = nic
f.mu.packets = make(map[<-chan struct{}][]pendingPacket)
}
// cancel drains all pending packet queues and releases all packet
// references.
func (f *packetsPendingLinkResolution) cancel() {
f.mu.Lock()
defer f.mu.Unlock()
for ch, pendingPackets := range f.mu.packets {
for _, p := range pendingPackets {
p.pkt.DecRef()
}
delete(f.mu.packets, ch)
}
f.mu.cancelChans = nil
}
// dequeue any pending packets associated with ch.
//
// If err is nil, packets will be written and sent to the given remote link
// address.
func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, err tcpip.Error) {
f.mu.Lock()
packets, ok := f.mu.packets[ch]
delete(f.mu.packets, ch)
if ok {
for i, cancelChan := range f.mu.cancelChans {
if cancelChan == ch {
f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...)
break
}
}
}
f.mu.Unlock()
if ok {
f.dequeuePackets(packets, linkAddr, err)
}
}
// enqueue a packet to be sent once link resolution completes.
//
// If the maximum number of pending resolutions is reached, the packets
// associated with the oldest link resolution will be dequeued as if they failed
// link resolution.
func (f *packetsPendingLinkResolution) enqueue(r *Route, pkt *PacketBuffer) tcpip.Error {
f.mu.Lock()
// Make sure we attempt resolution while holding f's lock so that we avoid
// a race where link resolution completes before we enqueue the packets.
//
// A @ T1: Call ResolvedFields (get link resolution channel)
// B @ T2: Complete link resolution, dequeue pending packets
// C @ T1: Enqueue packet that already completed link resolution (which will
// never dequeue)
//
// To make sure B does not interleave with A and C, we make sure A and C are
// done while holding the lock.
routeInfo, ch, err := r.resolvedFields(nil)
switch err.(type) {
case nil:
// The route resolved immediately, so we don't need to wait for link
// resolution to send the packet.
f.mu.Unlock()
pkt.EgressRoute = routeInfo
return f.nic.writePacket(pkt)
case *tcpip.ErrWouldBlock:
// We need to wait for link resolution to complete.
default:
f.mu.Unlock()
return err
}
defer f.mu.Unlock()
packets, ok := f.mu.packets[ch]
packets = append(packets, pendingPacket{
routeInfo: routeInfo,
pkt: pkt.IncRef(),
})
if len(packets) > maxPendingPacketsPerResolution {
f.incrementOutgoingPacketErrors(packets[0].pkt)
packets[0].pkt.DecRef()
packets[0] = pendingPacket{}
packets = packets[1:]
if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution {
panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution))
}
}
f.mu.packets[ch] = packets
if ok {
return nil
}
cancelledPackets := f.newCancelChannelLocked(ch)
if len(cancelledPackets) != 0 {
// Dequeue the pending packets in a new goroutine to not hold up the current
// goroutine as handling link resolution failures may be a costly operation.
go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, &tcpip.ErrAborted{})
}
return nil
}
// newCancelChannelLocked appends the link resolution channel to a FIFO. If the
// maximum number of pending resolutions is reached, the oldest channel will be
// removed and its associated pending packets will be returned.
func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket {
f.mu.cancelChans = append(f.mu.cancelChans, newCH)
if len(f.mu.cancelChans) <= maxPendingResolutions {
return nil
}
ch := f.mu.cancelChans[0]
f.mu.cancelChans[0] = nil
f.mu.cancelChans = f.mu.cancelChans[1:]
if l := len(f.mu.cancelChans); l > maxPendingResolutions {
panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
}
packets, ok := f.mu.packets[ch]
if !ok {
panic("must have a packet queue for an uncancelled channel")
}
delete(f.mu.packets, ch)
return packets
}
func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, err tcpip.Error) {
for _, p := range packets {
if err == nil {
p.routeInfo.RemoteLinkAddress = linkAddr
p.pkt.EgressRoute = p.routeInfo
_ = f.nic.writePacket(p.pkt)
} else {
f.incrementOutgoingPacketErrors(p.pkt)
if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.pkt.NetworkProtocolNumber).(LinkResolvableNetworkEndpoint); ok {
linkResolvableEP.HandleLinkResolutionFailure(p.pkt)
}
}
p.pkt.DecRef()
}
}
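
An in-package sketch (illustrative only, mirroring newCancelChannelLocked rather than reproducing it) of the bookkeeping above: packets are queued per resolution channel, channels are tracked in a FIFO, and once more than maxPendingResolutions resolutions are outstanding the oldest channel's queue is evicted so its packets can be failed.

type pendingQueues struct {
	packets map[<-chan struct{}][]*PacketBuffer
	fifo    []<-chan struct{}
}

func (q *pendingQueues) add(ch <-chan struct{}, pkt *PacketBuffer) (evicted []*PacketBuffer) {
	if q.packets == nil {
		q.packets = make(map[<-chan struct{}][]*PacketBuffer)
	}
	_, existed := q.packets[ch]
	q.packets[ch] = append(q.packets[ch], pkt)
	if existed {
		return nil
	}
	q.fifo = append(q.fifo, ch)
	if len(q.fifo) <= maxPendingResolutions {
		return nil
	}
	// Too many outstanding resolutions: cancel the oldest one and hand its
	// queued packets back to the caller for failure handling.
	oldest := q.fifo[0]
	q.fifo = q.fifo[1:]
	evicted = q.packets[oldest]
	delete(q.packets, oldest)
	return evicted
}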

View File

@@ -0,0 +1,40 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"math/rand"
"gvisor.dev/gvisor/pkg/sync"
)
// lockedRandomSource provides a thread-safe rand.Source.
type lockedRandomSource struct {
mu sync.Mutex
src rand.Source
}
func (r *lockedRandomSource) Int63() (n int64) {
r.mu.Lock()
n = r.src.Int63()
r.mu.Unlock()
return n
}
func (r *lockedRandomSource) Seed(seed int64) {
r.mu.Lock()
r.src.Seed(seed)
r.mu.Unlock()
}
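
A short in-package sketch of how such a source is typically consumed: wrap it in a *rand.Rand so multiple goroutines can draw values without racing on the shared state (the constructor name is illustrative).

func newSharedRNG(seed int64) *rand.Rand {
	return rand.New(&lockedRandomSource{src: rand.NewSource(seed)})
}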

File diff suppressed because it is too large

View File

@@ -0,0 +1,598 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
//
// It is safe to call Route's methods from multiple goroutines.
type Route struct {
routeInfo routeInfo
// localAddressNIC is the interface the address is associated with.
// TODO(gvisor.dev/issue/4548): Remove this field once we can query the
// address's assigned status without the NIC.
localAddressNIC *nic
// mu protects annotated fields below.
mu routeRWMutex
// localAddressEndpoint is the local address this route is associated with.
// +checklocks:mu
localAddressEndpoint AssignableAddressEndpoint
// remoteLinkAddress is the link-layer (MAC) address of the next hop.
// +checklocks:mu
remoteLinkAddress tcpip.LinkAddress
// outgoingNIC is the interface this route uses to write packets.
outgoingNIC *nic
// linkRes is set if link address resolution is enabled for this protocol on
// the route's NIC.
linkRes *linkResolver
// neighborEntry is the cached result of fetching a neighbor entry from the
// neighbor cache.
// +checklocks:mu
neighborEntry *neighborEntry
// mtu is the maximum transmission unit to use for this route.
// If mtu is 0, this field is ignored and the MTU of the outgoing NIC
// is used for egress packets.
mtu uint32
}
// +stateify savable
type routeInfo struct {
RemoteAddress tcpip.Address
LocalAddress tcpip.Address
LocalLinkAddress tcpip.LinkAddress
NextHop tcpip.Address
NetProto tcpip.NetworkProtocolNumber
Loop PacketLooping
}
// RemoteAddress returns the route's destination.
func (r *Route) RemoteAddress() tcpip.Address {
return r.routeInfo.RemoteAddress
}
// LocalAddress returns the route's local address.
func (r *Route) LocalAddress() tcpip.Address {
return r.routeInfo.LocalAddress
}
// LocalLinkAddress returns the route's local link-layer address.
func (r *Route) LocalLinkAddress() tcpip.LinkAddress {
return r.routeInfo.LocalLinkAddress
}
// NextHop returns the next node in the route's path to the destination.
func (r *Route) NextHop() tcpip.Address {
return r.routeInfo.NextHop
}
// NetProto returns the route's network-layer protocol number.
func (r *Route) NetProto() tcpip.NetworkProtocolNumber {
return r.routeInfo.NetProto
}
// Loop returns the route's required packet looping.
func (r *Route) Loop() PacketLooping {
return r.routeInfo.Loop
}
// OutgoingNIC returns the route's outgoing NIC.
func (r *Route) OutgoingNIC() tcpip.NICID {
return r.outgoingNIC.id
}
// RouteInfo contains all of Route's exported fields.
//
// +stateify savable
type RouteInfo struct {
routeInfo
// RemoteLinkAddress is the link-layer (MAC) address of the next hop in the
// route.
RemoteLinkAddress tcpip.LinkAddress
}
// Fields returns a RouteInfo with all of the known values for the route's
// fields.
//
// If any fields are unknown (e.g. remote link address when it is waiting for
// link address resolution), they will be unset.
func (r *Route) Fields() RouteInfo {
r.mu.RLock()
defer r.mu.RUnlock()
return r.fieldsLocked()
}
// +checklocksread:r.mu
func (r *Route) fieldsLocked() RouteInfo {
return RouteInfo{
routeInfo: r.routeInfo,
RemoteLinkAddress: r.remoteLinkAddress,
}
}
// constructAndValidateRoute validates and initializes a route. It takes
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *nic, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool, mtu uint32) *Route {
if localAddr.BitLen() == 0 {
localAddr = addressEndpoint.AddressWithPrefix().Address
}
if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) {
addressEndpoint.DecRef()
return nil
}
// If no remote address is provided, use the local address.
if remoteAddr.BitLen() == 0 {
remoteAddr = localAddr
}
r := makeRoute(
netProto,
gateway,
localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
addressEndpoint,
handleLocal,
multicastLoop,
mtu,
)
return r
}
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool, mtu uint32) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
if localAddr.BitLen() == 0 {
localAddr = localAddressEndpoint.AddressWithPrefix().Address
}
loop := PacketOut
// Loopback interface loops back packets at the link endpoint level. We
// could remove this check if loopback interfaces looped back packets
// at the network layer.
if !outgoingNIC.IsLoopback() {
if handleLocal && localAddr != (tcpip.Address{}) && remoteAddr == localAddr {
loop = PacketLoop
} else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
loop |= PacketLoop
} else if remoteAddr == header.IPv4Broadcast {
loop |= PacketLoop
} else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
loop |= PacketLoop
}
}
r := makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop, mtu)
if r.Loop()&PacketOut == 0 {
// Packet will not leave the stack, no need for a gateway or a remote link
// address.
return r
}
if r.outgoingNIC.NetworkLinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := r.outgoingNIC.linkAddrResolvers[r.NetProto()]; ok {
r.linkRes = linkRes
}
}
if gateway.BitLen() > 0 {
r.routeInfo.NextHop = gateway
return r
}
if r.linkRes == nil {
return r
}
if linkAddr, ok := r.linkRes.resolver.ResolveStaticAddress(r.RemoteAddress()); ok {
r.ResolveWith(linkAddr)
return r
}
if subnet := localAddressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) {
r.ResolveWith(header.EthernetBroadcastAddress)
return r
}
if r.RemoteAddress() == r.LocalAddress() {
// Local link address is already known.
r.ResolveWith(r.LocalLinkAddress())
}
return r
}
func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping, mtu uint32) *Route {
r := &Route{
routeInfo: routeInfo{
NetProto: netProto,
LocalAddress: localAddr,
LocalLinkAddress: outgoingNIC.NetworkLinkEndpoint.LinkAddress(),
RemoteAddress: remoteAddr,
Loop: loop,
},
localAddressNIC: localAddressNIC,
outgoingNIC: outgoingNIC,
mtu: mtu,
}
r.mu.Lock()
r.localAddressEndpoint = localAddressEndpoint
r.mu.Unlock()
return r
}
// makeLocalRoute initializes a new local route. It takes ownership of the
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// Loopback interface loops back packets at the link endpoint level. We
// could remove this check if loopback interfaces looped back packets
// at the network layer.
if outgoingNIC.IsLoopback() {
loop = PacketOut
}
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop, 0 /* mtu */)
}
// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in
// the route.
func (r *Route) RemoteLinkAddress() tcpip.LinkAddress {
r.mu.RLock()
defer r.mu.RUnlock()
return r.remoteLinkAddress
}
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
return r.outgoingNIC.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
return r.outgoingNIC.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
// implementation.
func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, totalLen uint16) uint16 {
return header.PseudoHeaderChecksum(protocol, r.LocalAddress(), r.RemoteAddress(), totalLen)
}
// RequiresTXTransportChecksum returns false if the route does not require
// transport checksums to be populated.
func (r *Route) RequiresTXTransportChecksum() bool {
if r.local() {
return false
}
return r.outgoingNIC.NetworkLinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0
}
// HasGVisorGSOCapability returns true if the route supports gVisor GSO.
func (r *Route) HasGVisorGSOCapability() bool {
if gso, ok := r.outgoingNIC.NetworkLinkEndpoint.(GSOEndpoint); ok {
return gso.SupportedGSO() == GVisorGSOSupported
}
return false
}
// HasHostGSOCapability returns true if the route supports host GSO.
func (r *Route) HasHostGSOCapability() bool {
if gso, ok := r.outgoingNIC.NetworkLinkEndpoint.(GSOEndpoint); ok {
return gso.SupportedGSO() == HostGSOSupported
}
return false
}
// HasSaveRestoreCapability returns true if the route supports save/restore.
func (r *Route) HasSaveRestoreCapability() bool {
return r.outgoingNIC.NetworkLinkEndpoint.Capabilities()&CapabilitySaveRestore != 0
}
// HasDisconnectOkCapability returns true if the route supports disconnecting.
func (r *Route) HasDisconnectOkCapability() bool {
return r.outgoingNIC.NetworkLinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
if gso, ok := r.outgoingNIC.NetworkLinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
}
// ResolveWith immediately resolves a route with the specified remote link
// address.
func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
r.mu.Lock()
defer r.mu.Unlock()
r.remoteLinkAddress = addr
}
// ResolvedFieldsResult is the result of a route resolution attempt.
type ResolvedFieldsResult struct {
RouteInfo RouteInfo
Err tcpip.Error
}
// ResolvedFields attempts to resolve the remote link address if it is not
// known.
//
// If a callback is provided, it will be called before ResolvedFields returns
// when address resolution is not required. If address resolution is required,
// the callback will be called once address resolution is complete, regardless
// of success or failure.
//
// Note, the route will not cache the remote link address when address
// resolution completes.
func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) tcpip.Error {
_, _, err := r.resolvedFields(afterResolve)
return err
}
// resolvedFields is like ResolvedFields but also returns a notification channel
// when address resolution is required. This channel will become readable once
// address resolution is complete.
//
// The route's fields will also be returned, regardless of whether address
// resolution is required or not.
func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, tcpip.Error) {
r.mu.RLock()
fields := r.fieldsLocked()
resolutionRequired := r.isResolutionRequiredRLocked()
r.mu.RUnlock()
if !resolutionRequired {
if afterResolve != nil {
afterResolve(ResolvedFieldsResult{RouteInfo: fields, Err: nil})
}
return fields, nil, nil
}
// If specified, the local address used for link address resolution must be an
// address on the outgoing interface.
var linkAddressResolutionRequestLocalAddr tcpip.Address
if r.localAddressNIC == r.outgoingNIC {
linkAddressResolutionRequestLocalAddr = r.LocalAddress()
}
nEntry := r.getCachedNeighborEntry()
if nEntry != nil {
if addr, ok := nEntry.getRemoteLinkAddress(); ok {
fields.RemoteLinkAddress = addr
if afterResolve != nil {
afterResolve(ResolvedFieldsResult{RouteInfo: fields, Err: nil})
}
return fields, nil, nil
}
}
afterResolveFields := fields
entry, ch, err := r.linkRes.neigh.entry(r.nextHop(), linkAddressResolutionRequestLocalAddr, func(lrr LinkResolutionResult) {
if afterResolve != nil {
if lrr.Err == nil {
afterResolveFields.RemoteLinkAddress = lrr.LinkAddress
}
afterResolve(ResolvedFieldsResult{RouteInfo: afterResolveFields, Err: lrr.Err})
}
})
if err == nil {
fields.RemoteLinkAddress, _ = entry.getRemoteLinkAddress()
}
r.setCachedNeighborEntry(entry)
return fields, ch, err
}
func (r *Route) getCachedNeighborEntry() *neighborEntry {
r.mu.RLock()
defer r.mu.RUnlock()
return r.neighborEntry
}
func (r *Route) setCachedNeighborEntry(entry *neighborEntry) {
r.mu.Lock()
defer r.mu.Unlock()
r.neighborEntry = entry
}
func (r *Route) nextHop() tcpip.Address {
if r.NextHop().BitLen() == 0 {
return r.RemoteAddress()
}
return r.NextHop()
}
// local returns true if the route is a local route.
func (r *Route) local() bool {
return r.Loop() == PacketLoop || r.outgoingNIC.IsLoopback()
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the route can be written to.
//
// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
r.mu.RLock()
defer r.mu.RUnlock()
return r.isResolutionRequiredRLocked()
}
// +checklocksread:r.mu
func (r *Route) isResolutionRequiredRLocked() bool {
return len(r.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local()
}
func (r *Route) isValidForOutgoing() bool {
r.mu.RLock()
defer r.mu.RUnlock()
return r.isValidForOutgoingRLocked()
}
// +checklocksread:r.mu
func (r *Route) isValidForOutgoingRLocked() bool {
if !r.outgoingNIC.Enabled() {
return false
}
localAddressEndpoint := r.localAddressEndpoint
if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
return false
}
// If the source NIC and outgoing NIC are different, make sure the stack has
// forwarding enabled, or the packet will be handled locally.
if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
return false
}
return true
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
if !r.isValidForOutgoing() {
return &tcpip.ErrInvalidEndpointState{}
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, params, pkt)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) tcpip.Error {
if !r.isValidForOutgoing() {
return &tcpip.ErrInvalidEndpointState{}
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).DefaultTTL()
}
// MTU returns the MTU of the route if present, otherwise the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
if r.mtu > 0 {
return r.mtu
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).MTU()
}
// Release decrements the reference counter of the resources associated with the
// route.
func (r *Route) Release() {
r.mu.Lock()
defer r.mu.Unlock()
if ep := r.localAddressEndpoint; ep != nil {
ep.DecRef()
}
}
// Acquire increments the reference counter of the resources associated with the
// route.
func (r *Route) Acquire() {
r.mu.RLock()
defer r.mu.RUnlock()
r.acquireLocked()
}
// +checklocksread:r.mu
func (r *Route) acquireLocked() {
if ep := r.localAddressEndpoint; ep != nil {
if !ep.TryIncRef() {
panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress()))
}
}
}
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
return r.outgoingNIC.stack
}
func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
if addr == header.IPv4Broadcast {
return true
}
r.mu.RLock()
localAddressEndpoint := r.localAddressEndpoint
r.mu.RUnlock()
if localAddressEndpoint == nil {
return false
}
subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
// IsOutboundBroadcast returns true if the route is for an outbound broadcast
// packet.
func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress())
}
// ConfirmReachable informs the network/link layer that the neighbour used for
// the route is reachable.
//
// "Reachable" is defined as having full-duplex communication between the
// local and remote ends of the route.
func (r *Route) ConfirmReachable() {
if entry := r.getCachedNeighborEntry(); entry != nil {
entry.handleUpperLevelConfirmation()
}
}
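
A hedged sketch of how a caller typically drives the resolution API above: request the route's resolved fields and, when resolution must happen asynchronously (ErrWouldBlock), finish the work from the callback. writeWithFields is a hypothetical continuation supplied by the caller, not a gVisor API.

func writeWhenResolved(r *Route, writeWithFields func(RouteInfo)) tcpip.Error {
	err := r.ResolvedFields(func(res ResolvedFieldsResult) {
		if res.Err == nil {
			writeWithFields(res.RouteInfo)
		}
	})
	if _, ok := err.(*tcpip.ErrWouldBlock); ok {
		// Resolution is in flight; the callback will run when it completes.
		return nil
	}
	return err
}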

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type routeRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var routelockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type routelockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *routeRWMutex) Lock() {
locking.AddGLock(routeprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *routeRWMutex) NestedLock(i routelockNameIndex) {
locking.AddGLock(routeprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *routeRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(routeprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *routeRWMutex) NestedUnlock(i routelockNameIndex) {
m.mu.Unlock()
locking.DelGLock(routeprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *routeRWMutex) RLock() {
locking.AddGLock(routeprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *routeRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(routeprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *routeRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *routeRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *routeRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var routeprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func routeinitLockNames() {}
func init() {
routeinitLockNames()
routeprefixIndex = locking.NewMutexClass(reflect.TypeOf(routeRWMutex{}), routelockNames)
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type routeStackRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var routeStacklockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type routeStacklockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *routeStackRWMutex) Lock() {
locking.AddGLock(routeStackprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *routeStackRWMutex) NestedLock(i routeStacklockNameIndex) {
locking.AddGLock(routeStackprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *routeStackRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(routeStackprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *routeStackRWMutex) NestedUnlock(i routeStacklockNameIndex) {
m.mu.Unlock()
locking.DelGLock(routeStackprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *routeStackRWMutex) RLock() {
locking.AddGLock(routeStackprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *routeStackRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(routeStackprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *routeStackRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *routeStackRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *routeStackRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var routeStackprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func routeStackinitLockNames() {}
func init() {
routeStackinitLockNames()
routeStackprefixIndex = locking.NewMutexClass(reflect.TypeOf(routeStackRWMutex{}), routeStacklockNames)
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type stackRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var stacklockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type stacklockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *stackRWMutex) Lock() {
locking.AddGLock(stackprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *stackRWMutex) NestedLock(i stacklockNameIndex) {
locking.AddGLock(stackprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *stackRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(stackprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *stackRWMutex) NestedUnlock(i stacklockNameIndex) {
m.mu.Unlock()
locking.DelGLock(stackprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *stackRWMutex) RLock() {
locking.AddGLock(stackprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *stackRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(stackprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *stackRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *stackRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *stackRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var stackprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func stackinitLockNames() {}
func init() {
stackinitLockNames()
stackprefixIndex = locking.NewMutexClass(reflect.TypeOf(stackRWMutex{}), stacklockNames)
}

View File

@@ -0,0 +1,125 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// MinBufferSize is the smallest size of a receive or send buffer.
MinBufferSize = 4 << 10 // 4 KiB
// DefaultBufferSize is the default size of the send/recv buffer for a
// transport endpoint.
DefaultBufferSize = 212 << 10 // 212 KiB
// DefaultMaxBufferSize is the default maximum permitted size of a
// send/receive buffer.
DefaultMaxBufferSize = 4 << 20 // 4 MiB
// defaultTCPInvalidRateLimit is the default value for
// stack.TCPInvalidRateLimit.
defaultTCPInvalidRateLimit = 500 * time.Millisecond
)
// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to
// get/set the default, min and max receive buffer sizes.
type ReceiveBufferSizeOption struct {
Min int
Default int
Max int
}
// TCPInvalidRateLimitOption is used by stack.(Stack*).Option/SetOption to get/set
// stack.tcpInvalidRateLimit.
type TCPInvalidRateLimitOption time.Duration
// SetOption allows setting stack wide options.
func (s *Stack) SetOption(option any) tcpip.Error {
switch v := option.(type) {
case tcpip.SendBufferSizeOption:
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
return &tcpip.ErrInvalidOptionValue{}
}
if v.Default < v.Min || v.Default > v.Max {
return &tcpip.ErrInvalidOptionValue{}
}
s.mu.Lock()
s.sendBufferSize = v
s.mu.Unlock()
return nil
case tcpip.ReceiveBufferSizeOption:
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
return &tcpip.ErrInvalidOptionValue{}
}
if v.Default < v.Min || v.Default > v.Max {
return &tcpip.ErrInvalidOptionValue{}
}
s.mu.Lock()
s.receiveBufferSize = v
s.mu.Unlock()
return nil
case TCPInvalidRateLimitOption:
if v < 0 {
return &tcpip.ErrInvalidOptionValue{}
}
s.mu.Lock()
s.tcpInvalidRateLimit = time.Duration(v)
s.mu.Unlock()
return nil
default:
return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option allows retrieving stack wide options.
func (s *Stack) Option(option any) tcpip.Error {
switch v := option.(type) {
case *tcpip.SendBufferSizeOption:
s.mu.RLock()
*v = s.sendBufferSize
s.mu.RUnlock()
return nil
case *tcpip.ReceiveBufferSizeOption:
s.mu.RLock()
*v = s.receiveBufferSize
s.mu.RUnlock()
return nil
case *TCPInvalidRateLimitOption:
s.mu.RLock()
*v = TCPInvalidRateLimitOption(s.tcpInvalidRateLimit)
s.mu.RUnlock()
return nil
default:
return &tcpip.ErrUnknownProtocolOption{}
}
}
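
A minimal usage sketch, assuming the usual gvisor.dev/gvisor/pkg/tcpip and pkg/tcpip/stack imports and an already-constructed *stack.Stack: the supplied values must satisfy Min >= MinBufferSize and Min <= Default <= Max, otherwise SetOption returns ErrInvalidOptionValue.

func configureSendBuffers(s *stack.Stack) tcpip.Error {
	return s.SetOption(tcpip.SendBufferSizeOption{
		Min:     stack.MinBufferSize,
		Default: stack.DefaultBufferSize,
		Max:     stack.DefaultMaxBufferSize,
	})
}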

File diff suppressed because it is too large

View File

@@ -0,0 +1,3 @@
// automatically generated by stateify.
package stack

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type stateConnRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var stateConnlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type stateConnlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *stateConnRWMutex) Lock() {
locking.AddGLock(stateConnprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *stateConnRWMutex) NestedLock(i stateConnlockNameIndex) {
locking.AddGLock(stateConnprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *stateConnRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(stateConnprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *stateConnRWMutex) NestedUnlock(i stateConnlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(stateConnprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *stateConnRWMutex) RLock() {
locking.AddGLock(stateConnprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *stateConnRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(stateConnprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *stateConnRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *stateConnRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *stateConnRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var stateConnprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func stateConninitLockNames() {}
func init() {
stateConninitLockNames()
stateConnprefixIndex = locking.NewMutexClass(reflect.TypeOf(stateConnRWMutex{}), stateConnlockNames)
}

View File

@@ -0,0 +1,494 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"context"
"time"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
// contextID is this package's type for context.Context.Value keys.
type contextID int
const (
// CtxRestoreStack is a Context.Value key for the stack to be used in restore.
CtxRestoreStack contextID = iota
)
// RestoreStackFromContext returns the stack to be used during restore.
func RestoreStackFromContext(ctx context.Context) *Stack {
return ctx.Value(CtxRestoreStack).(*Stack)
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
// passed to stack.AddTCPProbe.
type TCPProbeFunc func(s *TCPEndpointState)
// TCPCubicState is used to hold a copy of the internal cubic state when the
// TCPProbeFunc is invoked.
//
// +stateify savable
type TCPCubicState struct {
// WLastMax is the previous wMax value.
WLastMax float64
// WMax is the value of the congestion window at the time of the last
// congestion event.
WMax float64
// T is the time when the current congestion avoidance was entered.
T tcpip.MonotonicTime
// TimeSinceLastCongestion denotes the time since the current
// congestion avoidance was entered.
TimeSinceLastCongestion time.Duration
// C is the cubic constant as specified in RFC8312, page 11.
C float64
// K is the time period (in seconds) that the above function takes to
// increase the current window size to WMax if there are no further
// congestion events and is calculated using the following equation:
//
// K = cubic_root(WMax*(1-beta_cubic)/C) (Eq. 2, page 5)
K float64
// Beta is the CUBIC multiplication decrease factor. That is, when a
// congestion event is detected, CUBIC reduces its cwnd to
// WC(0)=WMax*beta_cubic.
Beta float64
// WC is window computed by CUBIC at time TimeSinceLastCongestion. It's
// calculated using the formula:
//
// WC(TimeSinceLastCongestion) = C*(t-K)^3 + WMax (Eq. 1)
WC float64
// WEst is the window computed by CUBIC at time
// TimeSinceLastCongestion+RTT i.e WC(TimeSinceLastCongestion+RTT).
WEst float64
// EndSeq is the sequence number that, when cumulatively ACK'd, ends the
// HyStart round.
EndSeq seqnum.Value
// CurrRTT is the minimum round-trip time from the current round.
CurrRTT time.Duration
// LastRTT is the minimum round-trip time from the previous round.
LastRTT time.Duration
// SampleCount is the number of samples from the current round.
SampleCount uint
// LastAck is the time we received the most recent ACK (or start of round if
// more recent).
LastAck tcpip.MonotonicTime
// RoundStart is the time we started the most recent HyStart round.
RoundStart tcpip.MonotonicTime
}
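
A small illustrative helper (not part of gVisor) that evaluates the window equation quoted above, WC(t) = C*(t-K)^3 + WMax, for a snapshot of TCPCubicState, with t given as seconds since the last congestion event.

func cubicWindow(s TCPCubicState, tSeconds float64) float64 {
	d := tSeconds - s.K
	return s.C*d*d*d + s.WMax
}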
// TCPRACKState is used to hold a copy of the internal RACK state when the
// TCPProbeFunc is invoked.
//
// +stateify savable
type TCPRACKState struct {
// XmitTime is the transmission timestamp of the most recent
// acknowledged segment.
XmitTime tcpip.MonotonicTime
// EndSequence is the ending TCP sequence number of the most recent
// acknowledged segment.
EndSequence seqnum.Value
// FACK is the highest selectively or cumulatively acknowledged
// sequence.
FACK seqnum.Value
// RTT is the round trip time of the most recently delivered packet on
// the connection (either cumulatively acknowledged or selectively
// acknowledged) that was not marked invalid as a possible spurious
// retransmission.
RTT time.Duration
// Reord is true iff reordering has been detected on this connection.
Reord bool
// DSACKSeen is true iff the connection has seen a DSACK.
DSACKSeen bool
// ReoWnd is the reordering window time used for recording packet
// transmission times. It is used to defer the moment at which RACK
// marks a packet lost.
ReoWnd time.Duration
// ReoWndIncr is the multiplier applied to adjust reorder window.
ReoWndIncr uint8
// ReoWndPersist is the number of loss recoveries before resetting
// reorder window.
ReoWndPersist int8
// RTTSeq is the SND.NXT when RTT is updated.
RTTSeq seqnum.Value
}
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
//
// +stateify savable
type TCPEndpointID struct {
// LocalPort is the local port associated with the endpoint.
LocalPort uint16
// LocalAddress is the local [network layer] address associated with
// the endpoint.
LocalAddress tcpip.Address
// RemotePort is the remote port associated with the endpoint.
RemotePort uint16
// RemoteAddress is the remote [network layer] address associated with
// the endpoint.
RemoteAddress tcpip.Address
}
// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
// TCP endpoint.
//
// +stateify savable
type TCPFastRecoveryState struct {
// Active if true indicates the endpoint is in fast recovery. The
// following fields are only meaningful when Active is true.
Active bool
// First is the first unacknowledged sequence number being recovered.
First seqnum.Value
// Last is the 'recover' sequence number that indicates the point at
// which we should exit recovery barring any timeouts etc.
Last seqnum.Value
// MaxCwnd is the maximum value we are permitted to grow the congestion
// window during recovery. This is set at the time we enter recovery.
// It exists to avoid attacks where the receiver intentionally sends
// duplicate acks to artificially inflate the sender's cwnd.
MaxCwnd int
// HighRxt is the highest sequence number which has been retransmitted
// during the current loss recovery phase. See: RFC 6675 Section 2 for
// details.
HighRxt seqnum.Value
// RescueRxt is the highest sequence number which has been
// optimistically retransmitted to prevent stalling of the ACK clock
// when there is loss at the end of the window and no new data is
// available for transmission. See: RFC 6675 Section 2 for details.
RescueRxt seqnum.Value
}
// TCPReceiverState holds a copy of the internal state of the receiver for a
// given TCP endpoint.
//
// +stateify savable
type TCPReceiverState struct {
// RcvNxt is the TCP variable RCV.NXT.
RcvNxt seqnum.Value
// RcvAcc is one beyond the last acceptable sequence number. That is,
// the "largest" sequence value that the receiver has announced to its
// peer that it's willing to accept. This may be different than RcvNxt
// + (last advertised receive window) if the receive window is reduced;
// in that case we have to reduce the window as we receive more data
// instead of shrinking it.
RcvAcc seqnum.Value
// RcvWndScale is the window scaling to use for inbound segments.
RcvWndScale uint8
// PendingBufUsed is the number of bytes pending in the receive queue.
PendingBufUsed int
}
// TCPRTTState holds a copy of information about the endpoint's round trip
// time.
//
// +stateify savable
type TCPRTTState struct {
// SRTT is the smoothed round trip time defined in section 2 of RFC
// 6298.
SRTT time.Duration
// RTTVar is the round-trip time variation as defined in section 2 of
// RFC 6298.
RTTVar time.Duration
// SRTTInited if true indicates that a valid RTT measurement has been
// completed.
SRTTInited bool
}
// TCPSenderState holds a copy of the internal state of the sender for a given
// TCP Endpoint.
//
// +stateify savable
type TCPSenderState struct {
// LastSendTime is the timestamp at which we sent the last segment.
LastSendTime tcpip.MonotonicTime
// DupAckCount is the number of Duplicate ACKs received. It is used for
// fast retransmit.
DupAckCount int
// SndCwnd is the size of the sending congestion window in packets.
SndCwnd int
// Ssthresh is the threshold between slow start and congestion
// avoidance.
Ssthresh int
// SndCAAckCount is the number of packets acknowledged during
// congestion avoidance. When enough packets have been ack'd (typically
// cwnd packets), the congestion window is incremented by one.
SndCAAckCount int
// Outstanding is the number of packets that have been sent but not yet
// acknowledged.
Outstanding int
// SackedOut is the number of packets which have been selectively
// acked.
SackedOut int
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
// SndUna is the next unacknowledged sequence number.
SndUna seqnum.Value
// SndNxt is the sequence number of the next segment to be sent.
SndNxt seqnum.Value
// RTTMeasureSeqNum is the sequence number being used for the latest
// RTT measurement.
RTTMeasureSeqNum seqnum.Value
// RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
RTTMeasureTime tcpip.MonotonicTime
// Closed indicates that the caller has closed the endpoint for
// sending.
Closed bool
// RTO is the retransmit timeout as defined in section 2 of RFC
// 6298.
RTO time.Duration
// RTTState holds information about the endpoint's round trip time.
RTTState TCPRTTState
// MaxPayloadSize is the maximum size of the payload of a given
// segment. It is initialized on demand.
MaxPayloadSize int
// SndWndScale is the number of bits to shift left when reading the
// send window size from a segment.
SndWndScale uint8
// MaxSentAck is the highest acknowledgement number sent till now.
MaxSentAck seqnum.Value
// FastRecovery holds the fast recovery state for the endpoint.
FastRecovery TCPFastRecoveryState
// Cubic holds the state related to CUBIC congestion control.
Cubic TCPCubicState
// RACKState holds the state related to RACK loss detection algorithm.
RACKState TCPRACKState
// RetransmitTS records the timestamp used to detect spurious recovery.
RetransmitTS uint32
// SpuriousRecovery indicates if the sender entered recovery spuriously.
SpuriousRecovery bool
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
//
// +stateify savable
type TCPSACKInfo struct {
// Blocks is the list of SACK Blocks that identify the out of order
// segments held by a given TCP endpoint.
Blocks []header.SACKBlock
// ReceivedBlocks are the SACK blocks received by this endpoint from
// the peer endpoint.
ReceivedBlocks []header.SACKBlock
// MaxSACKED is the highest sequence number that has been SACKED by the
// peer.
MaxSACKED seqnum.Value
}
// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
//
// +stateify savable
type RcvBufAutoTuneParams struct {
// MeasureTime is the time at which the current measurement was
// started.
MeasureTime tcpip.MonotonicTime
// CopiedBytes is the number of bytes copied to user space since this
// measure began.
CopiedBytes int
// PrevCopiedBytes is the number of bytes copied to userspace in the
// previous RTT period.
PrevCopiedBytes int
// RcvBufSize is the auto tuned receive buffer size.
RcvBufSize int
// RTT is the smoothed RTT as measured by observing the time between
// when a byte is first acknowledged and the receipt of data that is at
// least one window beyond the sequence number that was acknowledged.
RTT time.Duration
// RTTVar is the "round-trip time variation" as defined in section 2 of
// RFC 6298.
RTTVar time.Duration
// RTTMeasureSeqNumber is the highest acceptable sequence number at the
// time this RTT measurement period began.
RTTMeasureSeqNumber seqnum.Value
// RTTMeasureTime is the absolute time at which the current RTT
// measurement period began.
RTTMeasureTime tcpip.MonotonicTime
// Disabled is true if an explicit receive buffer is set for the
// endpoint.
Disabled bool
}
// TCPRcvBufState contains information about the state of an endpoint's receive
// socket buffer.
//
// +stateify savable
type TCPRcvBufState struct {
// RcvBufUsed is the amount of bytes actually held in the receive
// socket buffer for the endpoint.
RcvBufUsed int
// RcvBufAutoTuneParams is used to hold state variables to compute the
// auto tuned receive buffer size.
RcvAutoParams RcvBufAutoTuneParams
// RcvClosed if true, indicates the endpoint has been closed for
// reading.
RcvClosed bool
}
// TCPSndBufState contains information about the state of an endpoint's send
// socket buffer.
//
// +stateify savable
type TCPSndBufState struct {
// SndBufSize is the size of the socket send buffer.
SndBufSize int
// SndBufUsed is the number of bytes held in the socket send buffer.
SndBufUsed int
// SndClosed indicates that the endpoint has been closed for sends.
SndClosed bool
// PacketTooBigCount is used to notify the main protocol routine how
// many times a "packet too big" control packet is received.
PacketTooBigCount int
// SndMTU is the smallest MTU seen in the control packets received.
SndMTU int
// AutoTuneSndBufDisabled indicates that the auto tuning of send buffer
// is disabled.
AutoTuneSndBufDisabled atomicbitops.Uint32
}
// TCPEndpointStateInner contains the members of TCPEndpointState used directly
// (that is, not within another containing struct) within the endpoint's
// internal implementation.
//
// +stateify savable
type TCPEndpointStateInner struct {
// TSOffset is a randomized offset added to the value of the TSVal
// field in the timestamp option.
TSOffset tcp.TSOffset
// SACKPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
SACKPermitted bool
// SendTSOk is used to indicate when the TS Option has been negotiated.
// When sendTSOk is true every non-RST segment should carry a TS as per
// RFC7323#section-1.1.
SendTSOk bool
// RecentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field
// is updated if required when a new segment is received by this
// endpoint.
RecentTS uint32
}
// TCPEndpointState is a copy of the internal state of a TCP endpoint.
//
// +stateify savable
type TCPEndpointState struct {
// TCPEndpointStateInner contains the members of TCPEndpointState used
// by the endpoint's internal implementation.
TCPEndpointStateInner
// ID is a copy of the TransportEndpointID for the endpoint.
ID TCPEndpointID
// SegTime denotes the absolute time when this segment was received.
SegTime tcpip.MonotonicTime
// RcvBufState contains information about the state of the endpoint's
// receive socket buffer.
RcvBufState TCPRcvBufState
// SndBufState contains information about the state of the endpoint's
// send socket buffer.
SndBufState TCPSndBufState
// SACK holds TCP SACK related information for this endpoint.
SACK TCPSACKInfo
// Receiver holds variables related to the TCP receiver for the
// endpoint.
Receiver TCPReceiverState
// Sender holds state related to the TCP Sender for the endpoint.
Sender TCPSenderState
}

View File

@@ -0,0 +1,733 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
)
// +stateify savable
type protocolIDs struct {
network tcpip.NetworkProtocolNumber
transport tcpip.TransportProtocolNumber
}
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
//
// +stateify savable
type transportEndpoints struct {
mu transportEndpointsRWMutex `state:"nosave"`
// +checklocks:mu
endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
//
// +checklocks:mu
rawEndpoints []RawTransportEndpoint
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
epsByNIC, ok := eps.endpoints[id]
if !ok {
return
}
if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
return
}
delete(eps.endpoints, id)
}
func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
eps.mu.RLock()
defer eps.mu.RUnlock()
es := make([]TransportEndpoint, 0, len(eps.endpoints))
for _, e := range eps.endpoints {
es = append(es, e.transportEndpoints()...)
}
return es
}
// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
// +checklocksread:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
if !yield(ep) {
return
}
}
// Try to find a match with the id minus the local address.
nid := id
nid.LocalAddress = tcpip.Address{}
if ep, ok := eps.endpoints[nid]; ok {
if !yield(ep) {
return
}
}
// Try to find a match with the id minus the remote part.
nid.LocalAddress = id.LocalAddress
nid.RemoteAddress = tcpip.Address{}
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
if !yield(ep) {
return
}
}
// Try to find a match with only the local port.
nid.LocalAddress = tcpip.Address{}
if ep, ok := eps.endpoints[nid]; ok {
if !yield(ep) {
return
}
}
}
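// exampleLookupOrder is an illustrative sketch, not part of the original
// gVisor sources: it materializes the candidate IDs that iterEndpointsLocked
// probes above, from most to least specific, for a hypothetical inbound
// segment.
func exampleLookupOrder(id TransportEndpointID) [4]TransportEndpointID {
	fullMatch := id // connected endpoint: all four fields match.
	noLocalAddr := id
	noLocalAddr.LocalAddress = tcpip.Address{} // bound to a port on all local addresses.
	noRemote := id
	noRemote.RemoteAddress = tcpip.Address{}
	noRemote.RemotePort = 0 // bound to a specific local address and port.
	portOnly := noRemote
	portOnly.LocalAddress = tcpip.Address{} // wildcard listener on the local port only.
	return [4]TransportEndpointID{fullMatch, noLocalAddr, noRemote, portOnly}
}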
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
// +checklocksread:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
var matchedEPs []*endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
matchedEPs = append(matchedEPs, ep)
return true
})
return matchedEPs
}
// findEndpointLocked returns the endpoint that most closely matches the given id.
//
// +checklocksread:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
var matchedEP *endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
matchedEP = ep
return false
})
return matchedEP
}
// +stateify savable
type endpointsByNIC struct {
// seed is a random secret for a jenkins hash.
seed uint32
mu endpointsByNICRWMutex `state:"nosave"`
// +checklocks:mu
endpoints map[tcpip.NICID]*multiPortEndpoint
}
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
var eps []TransportEndpoint
for _, ep := range epsByNIC.endpoints {
eps = append(eps, ep.transportEndpoints()...)
}
return eps
}
// handlePacket is called by the stack when new packets arrive at this transport
// endpoint. It returns false if the packet could not be matched to any
// transport endpoint, true otherwise.
func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return false
}
}
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return true
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := mpep.selectEndpoint(id, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
return true
}
epsByNIC.mu.RUnlock()
transEP.HandlePacket(id, pkt)
return true
}
// handleError delivers an error to the transport endpoint identified by id.
func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[n.ID()]
if !ok {
mpep, ok = epsByNIC.endpoints[0]
}
if !ok {
epsByNIC.mu.RUnlock()
return
}
// TODO(eyalsoha): Why don't we look at id to see if this packet needs to
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
transEP := mpep.selectEndpoint(id, epsByNIC.seed)
epsByNIC.mu.RUnlock()
transEP.HandleError(transErr, pkt)
}
// registerEndpoint registers t for the given NIC (bindToDevice), creating a
// multiPortEndpoint for that NIC if one does not exist yet. It returns an
// error (for example ErrPortInUse) if the registration conflicts with the
// binding flags of the endpoints already registered there.
func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
multiPortEp = &multiPortEndpoint{
demux: d,
netProto: netProto,
transProto: transProto,
}
}
if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil {
return err
}
// Only add the newly created multiPortEndpoint if singleRegisterEndpoint
// succeeded.
if !ok {
epsByNIC.endpoints[bindToDevice] = multiPortEp
}
return nil
}
func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
return nil
}
return multiPortEp.singleCheckEndpoint(flags)
}
// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
return false
}
if multiPortEp.unregisterEndpoint(t, flags) {
delete(epsByNIC.endpoints, bindToDevice)
}
return len(epsByNIC.endpoints) == 0
}
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
//
// +stateify savable
type transportDemuxer struct {
stack *Stack
// protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
queuedProtocols map[protocolIDs]queuedTransportProtocol
}
// queuedTransportProtocol, if implemented by a transport protocol, causes the
// dispatcher to deliver packets to the QueuePacket method instead of calling
// HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{
stack: stack,
protocol: make(map[protocolIDs]*transportEndpoints),
queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
}
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
protoIDs := protocolIDs{netProto, proto}
d.protocol[protoIDs] = &transportEndpoints{
endpoints: make(map[TransportEndpointID]*endpointsByNIC),
}
qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
if isQueued {
d.queuedProtocols[protoIDs] = qTransProto
}
}
}
return d
}
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
for i, n := range netProtos {
if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
return err
}
}
return nil
}
// checkEndpoint checks if an endpoint can be registered with the dispatcher.
func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
for _, n := range netProtos {
if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
return err
}
}
return nil
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpoints always has at least one
// element.
//
// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
// this to ensure that the underlying endpoints get saved/restored, but do not
// use the restored copy.
//
// +stateify savable
type multiPortEndpoint struct {
demux *transportDemuxer
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
flags ports.FlagCounter
mu multiPortEndpointRWMutex `state:"nosave"`
// endpoints stores the transport endpoints in the order in which they
// were bound. This is required for UDP SO_REUSEADDR.
//
// +checklocks:mu
endpoints []TransportEndpoint
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
ep.mu.RLock()
eps := append([]TransportEndpoint(nil), ep.endpoints...)
ep.mu.RUnlock()
return eps
}
// reciprocalScale scales a value into range [0, n).
//
// This is similar to val % n, but faster.
// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
func reciprocalScale(val, n uint32) uint32 {
return uint32((uint64(val) * uint64(n)) >> 32)
}
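// exampleReciprocalScale is an illustrative sketch, not part of the original
// gVisor sources: for a uniformly distributed 32-bit hash value, the high 32
// bits of val*n always fall in [0, n), so reciprocalScale can stand in for
// val % n without a division.
func exampleReciprocalScale() {
	const n = 4 // hypothetical bucket count, e.g. number of load-balanced endpoints.
	for _, val := range []uint32{0, 1 << 30, 1 << 31, 3 << 30, ^uint32(0)} {
		// The buckets come out as 0, 1, 2, 3, 3 respectively, all within [0, n).
		fmt.Printf("val=%#x -> bucket %d\n", val, reciprocalScale(val, n))
	}
}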
// selectEndpoint calculates a hash of the destination and source addresses
// and ports, then uses it to select a socket. As a result, all packets of a
// given flow are delivered to the same endpoint.
func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
ep.mu.RLock()
defer ep.mu.RUnlock()
if len(ep.endpoints) == 1 {
return ep.endpoints[0]
}
if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
return ep.endpoints[len(ep.endpoints)-1]
}
payload := []byte{
byte(id.LocalPort),
byte(id.LocalPort >> 8),
byte(id.RemotePort),
byte(id.RemotePort >> 8),
}
h := jenkins.Sum32(seed)
h.Write(payload)
h.Write(id.LocalAddress.AsSlice())
h.Write(id.RemoteAddress.AsSlice())
hash := h.Sum32()
idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
return ep.endpoints[idx]
}
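// exampleFlowHash is an illustrative sketch, not part of the original gVisor
// sources: it recomputes the hash used by selectEndpoint above to show that
// the chosen index depends only on the connection 4-tuple and the per-stack
// seed, so every segment of a given flow lands on the same endpoint.
func exampleFlowHash(id TransportEndpointID, seed uint32, numEndpoints int) uint32 {
	payload := []byte{
		byte(id.LocalPort),
		byte(id.LocalPort >> 8),
		byte(id.RemotePort),
		byte(id.RemotePort >> 8),
	}
	h := jenkins.Sum32(seed)
	h.Write(payload)
	h.Write(id.LocalAddress.AsSlice())
	h.Write(id.RemoteAddress.AsSlice())
	return reciprocalScale(h.Sum32(), uint32(numEndpoints))
}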
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket may modify pkt, so each endpoint needs
// its own copy except for the final one.
for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
clone := pkt.Clone()
if mustQueue {
queuedProtocol.QueuePacket(endpoint, id, clone)
} else {
endpoint.HandlePacket(id, clone)
}
clone.DecRef()
}
if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
queuedProtocol.QueuePacket(endpoint, id, pkt)
} else {
endpoint.HandlePacket(id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list may still be empty when this is called.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
bits := flags.Bits() & ports.MultiBindFlagMask
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
return &tcpip.ErrPortInUse{}
}
}
ep.endpoints = append(ep.endpoints, t)
ep.flags.AddRef(bits)
return nil
}
func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error {
ep.mu.RLock()
defer ep.mu.RUnlock()
bits := flags.Bits() & ports.MultiBindFlagMask
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
return &tcpip.ErrPortInUse{}
}
}
return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
ep.mu.Lock()
defer ep.mu.Unlock()
for i, endpoint := range ep.endpoints {
if endpoint == t {
copy(ep.endpoints[i:], ep.endpoints[i+1:])
ep.endpoints[len(ep.endpoints)-1] = nil
ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
break
}
}
return len(ep.endpoints) == 0
}
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
if id.RemotePort != 0 {
// SO_REUSEPORT only applies to bound/listening endpoints.
flags.LoadBalanced = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
return &tcpip.ErrUnknownProtocol{}
}
eps.mu.Lock()
defer eps.mu.Unlock()
epsByNIC, ok := eps.endpoints[id]
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
seed: d.stack.seed,
}
}
if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
return err
}
// Only add this newly created epsByNIC if registerEndpoint succeeded.
if !ok {
eps.endpoints[id] = epsByNIC
}
return nil
}
func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
if id.RemotePort != 0 {
// SO_REUSEPORT only applies to bound/listening endpoints.
flags.LoadBalanced = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
return &tcpip.ErrUnknownProtocol{}
}
eps.mu.RLock()
defer eps.mu.RUnlock()
epsByNIC, ok := eps.endpoints[id]
if !ok {
return nil
}
return epsByNIC.checkEndpoint(flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
if id.RemotePort != 0 {
// SO_REUSEPORT only applies to bound/listening endpoints.
flags.LoadBalanced = false
}
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
eps.unregisterEndpoint(id, ep, flags, bindToDevice)
}
}
}
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
if len(destEPs) == 0 {
d.stack.stats.UDP.UnknownPortErrors.Increment()
return false
}
// handlePacket may modify pkt, so each endpoint needs its own
// copy except for the final one.
for _, ep := range destEPs[:len(destEPs)-1] {
clone := pkt.Clone()
ep.handlePacket(id, clone)
clone.DecRef()
}
destEPs[len(destEPs)-1].handlePacket(id, pkt)
return true
}
// If the packet is a TCP packet with an unspecified source or non-unicast
// destination address, then do nothing further and instruct the caller to do
// the same. The network layer handles address validation for specified source
// addresses.
if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
return true
}
eps.mu.RLock()
ep := eps.findEndpointLocked(id)
eps.mu.RUnlock()
if ep == nil {
if protocol == header.UDPProtocolNumber {
d.stack.stats.UDP.UnknownPortErrors.Increment()
}
return false
}
return ep.handlePacket(id, pkt)
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
// raw endpoint first. If there are multiple raw endpoints, they all
// receive the packet.
eps.mu.RLock()
// Copy the list of raw endpoints to avoid packet handling under lock.
var rawEPs []RawTransportEndpoint
if n := len(eps.rawEndpoints); n != 0 {
rawEPs = make([]RawTransportEndpoint, n)
if m := copy(rawEPs, eps.rawEndpoints); m != n {
panic(fmt.Sprintf("unexpected copy = %d, want %d", m, n))
}
}
eps.mu.RUnlock()
for _, rawEP := range rawEPs {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
clone := pkt.Clone()
rawEP.HandlePacket(clone)
clone.DecRef()
}
return len(rawEPs) != 0
}
// deliverError attempts to deliver the given error to the appropriate transport
// endpoint.
//
// Returns true if the error was delivered.
func (d *transportDemuxer) deliverError(n *nic, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
}
eps.mu.RLock()
ep := eps.findEndpointLocked(id)
eps.mu.RUnlock()
if ep == nil {
return false
}
ep.handleError(n, id, transErr, pkt)
return true
}
// findTransportEndpoint finds the single endpoint that most closely matches the provided id.
func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
}
eps.mu.RLock()
epsByNIC := eps.findEndpointLocked(id)
if epsByNIC == nil {
eps.mu.RUnlock()
return nil
}
epsByNIC.mu.RLock()
eps.mu.RUnlock()
mpep, ok := epsByNIC.endpoints[nicID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return nil
}
}
ep := mpep.selectEndpoint(id, epsByNIC.seed)
epsByNIC.mu.RUnlock()
return ep
}
// registerRawEndpoint registers the given endpoint with the dispatcher such
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
// endpoint.
func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return &tcpip.ErrNotSupported{}
}
eps.mu.Lock()
eps.rawEndpoints = append(eps.rawEndpoints, ep)
eps.mu.Unlock()
return nil
}
// unregisterRawEndpoint unregisters the raw endpoint for the given transport
// protocol such that it won't receive any more packets.
func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
}
eps.mu.Lock()
for i, rawEP := range eps.rawEndpoints {
if rawEP == ep {
lastIdx := len(eps.rawEndpoints) - 1
eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
eps.rawEndpoints[lastIdx] = nil
eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
break
}
}
eps.mu.Unlock()
}
func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}
func isSpecified(addr tcpip.Address) bool {
return addr != header.IPv4Any && addr != header.IPv6Any
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type transportEndpointsRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var transportEndpointslockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type transportEndpointslockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *transportEndpointsRWMutex) Lock() {
locking.AddGLock(transportEndpointsprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *transportEndpointsRWMutex) NestedLock(i transportEndpointslockNameIndex) {
locking.AddGLock(transportEndpointsprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *transportEndpointsRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(transportEndpointsprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *transportEndpointsRWMutex) NestedUnlock(i transportEndpointslockNameIndex) {
m.mu.Unlock()
locking.DelGLock(transportEndpointsprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *transportEndpointsRWMutex) RLock() {
locking.AddGLock(transportEndpointsprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *transportEndpointsRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(transportEndpointsprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *transportEndpointsRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *transportEndpointsRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *transportEndpointsRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var transportEndpointsprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func transportEndpointsinitLockNames() {}
func init() {
transportEndpointsinitLockNames()
transportEndpointsprefixIndex = locking.NewMutexClass(reflect.TypeOf(transportEndpointsRWMutex{}), transportEndpointslockNames)
}

View File

@@ -0,0 +1,239 @@
package stack
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type tupleElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (tupleElementMapper) linkerFor(elem *tuple) *tuple { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type tupleList struct {
head *tuple
tail *tuple
}
// Reset resets list l to the empty state.
func (l *tupleList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
//
//go:nosplit
func (l *tupleList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
//
//go:nosplit
func (l *tupleList) Front() *tuple {
return l.head
}
// Back returns the last element of list l or nil.
//
//go:nosplit
func (l *tupleList) Back() *tuple {
return l.tail
}
// Len returns the number of elements in the list.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *tupleList) Len() (count int) {
for e := l.Front(); e != nil; e = (tupleElementMapper{}.linkerFor(e)).Next() {
count++
}
return count
}
// PushFront inserts the element e at the front of list l.
//
//go:nosplit
func (l *tupleList) PushFront(e *tuple) {
linker := tupleElementMapper{}.linkerFor(e)
linker.SetNext(l.head)
linker.SetPrev(nil)
if l.head != nil {
tupleElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushFrontList inserts list m at the start of list l, emptying m.
//
//go:nosplit
func (l *tupleList) PushFrontList(m *tupleList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
tupleElementMapper{}.linkerFor(l.head).SetPrev(m.tail)
tupleElementMapper{}.linkerFor(m.tail).SetNext(l.head)
l.head = m.head
}
m.head = nil
m.tail = nil
}
// PushBack inserts the element e at the back of list l.
//
//go:nosplit
func (l *tupleList) PushBack(e *tuple) {
linker := tupleElementMapper{}.linkerFor(e)
linker.SetNext(nil)
linker.SetPrev(l.tail)
if l.tail != nil {
tupleElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
//
//go:nosplit
func (l *tupleList) PushBackList(m *tupleList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
tupleElementMapper{}.linkerFor(l.tail).SetNext(m.head)
tupleElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
//
//go:nosplit
func (l *tupleList) InsertAfter(b, e *tuple) {
bLinker := tupleElementMapper{}.linkerFor(b)
eLinker := tupleElementMapper{}.linkerFor(e)
a := bLinker.Next()
eLinker.SetNext(a)
eLinker.SetPrev(b)
bLinker.SetNext(e)
if a != nil {
tupleElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
//
//go:nosplit
func (l *tupleList) InsertBefore(a, e *tuple) {
aLinker := tupleElementMapper{}.linkerFor(a)
eLinker := tupleElementMapper{}.linkerFor(e)
b := aLinker.Prev()
eLinker.SetNext(a)
eLinker.SetPrev(b)
aLinker.SetPrev(e)
if b != nil {
tupleElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
//
//go:nosplit
func (l *tupleList) Remove(e *tuple) {
linker := tupleElementMapper{}.linkerFor(e)
prev := linker.Prev()
next := linker.Next()
if prev != nil {
tupleElementMapper{}.linkerFor(prev).SetNext(next)
} else if l.head == e {
l.head = next
}
if next != nil {
tupleElementMapper{}.linkerFor(next).SetPrev(prev)
} else if l.tail == e {
l.tail = prev
}
linker.SetNext(nil)
linker.SetPrev(nil)
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type tupleEntry struct {
next *tuple
prev *tuple
}
// Next returns the entry that follows e in the list.
//
//go:nosplit
func (e *tupleEntry) Next() *tuple {
return e.next
}
// Prev returns the entry that precedes e in the list.
//
//go:nosplit
func (e *tupleEntry) Prev() *tuple {
return e.prev
}
// SetNext assigns 'elem' as the entry that follows e in the list.
//
//go:nosplit
func (e *tupleEntry) SetNext(elem *tuple) {
e.next = elem
}
// SetPrev assigns 'elem' as the entry that precedes e in the list.
//
//go:nosplit
func (e *tupleEntry) SetPrev(elem *tuple) {
e.prev = elem
}
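// exampleTupleListUsage is an illustrative sketch, not part of the original
// gVisor sources: it exercises the intrusive list above, assuming only that
// tuple embeds tupleEntry (the identity mapping linkerFor relies on).
// Zero-value tuples are used purely to demonstrate the list mechanics.
func exampleTupleListUsage() {
	var l tupleList
	a, b := &tuple{}, &tuple{}
	l.PushBack(a) // list: a
	l.PushBack(b) // list: a, b
	for e := l.Front(); e != nil; e = e.Next() {
		_ = e // visit each element in insertion order, O(1) per step.
	}
	l.Remove(a) // list: b; no allocation or search required.
}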