Update dependencies

This commit is contained in:
bluepython508
2025-04-09 01:00:12 +01:00
parent f0641ffd6e
commit 5a9cfc022c
882 changed files with 68930 additions and 24201 deletions

View File

@@ -318,8 +318,11 @@ func (b *Buffer) PullUp(offset, length int) (View, bool) {
if x := curr.Intersect(tgt); x.Len() == tgt.Len() {
// buf covers the whole requested target range.
sub := x.Offset(-curr.begin)
// Don't increment the reference count of the underlying chunk. Views
// returned by PullUp are explicitly unowned and read only
if v.sharesChunk() {
old := v.chunk
v.chunk = v.chunk.Clone()
old.DecRef()
}
new := View{
read: v.read + sub.begin,
write: v.read + sub.end,

View File

@@ -308,7 +308,7 @@ func (fs FeatureSet) HasFeature(feature Feature) bool {
// WriteCPUInfoTo is to generate a section of one cpu in /proc/cpuinfo. This is
// a minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
// not always printed in Linux. The bogomips field is simply made up.
// not always printed in Linux. Several fields are simply made up.
func (fs FeatureSet) WriteCPUInfoTo(cpu, numCPU uint, w io.Writer) {
// Avoid many redundant calls here, since this can occasionally appear
// in the hot path. Read all basic information up front, see above.
@@ -322,6 +322,13 @@ func (fs FeatureSet) WriteCPUInfoTo(cpu, numCPU uint, w io.Writer) {
fmt.Fprintf(w, "model name\t: %s\n", "unknown") // Unknown for now.
fmt.Fprintf(w, "stepping\t: %s\n", "unknown") // Unknown for now.
fmt.Fprintf(w, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
// Pretend the CPU has 8192 KB of cache. Note that real /proc/cpuinfo exposes total L3 cache
// size on Intel and per-core L2 cache size on AMD (as of Linux 6.1.0), so the value of this
// field is not really important in practice. Any value that is chosen here will be wrong
// by an order of magnitude on a significant chunk of x86 machines.
// 8192 KB is selected because it is a reasonable size that will be effectively usable on
// lightly loaded machines - most machines have 1-4MB of L3 cache per core.
fmt.Fprintf(w, "cache size\t: 8192 KB\n")
fmt.Fprintf(w, "physical id\t: 0\n") // Pretend all CPUs are in the same socket.
fmt.Fprintf(w, "siblings\t: %d\n", numCPU)
fmt.Fprintf(w, "core id\t\t: %d\n", cpu)
@@ -473,3 +480,17 @@ func (fs FeatureSet) archCheckHostCompatible(hfs FeatureSet) error {
return nil
}
// AllowedHWCap1 returns the HWCAP1 bits that the guest is allowed to depend
// on. On amd64 this is always 0, i.e. no HWCAP1 bits are exposed.
func (fs FeatureSet) AllowedHWCap1() uint64 {
// HWCAPS are not supported on amd64.
return 0
}
// AllowedHWCap2 returns the HWCAP2 bits that the guest is allowed to depend
// on. On amd64 this is always 0, i.e. no HWCAP2 bits are exposed.
func (fs FeatureSet) AllowedHWCap2() uint64 {
// HWCAPS are not supported on amd64.
return 0
}

View File

@@ -1,7 +1,7 @@
// automatically generated by stateify.
//go:build amd64 && amd64 && amd64 && amd64
// +build amd64,amd64,amd64,amd64
//go:build amd64 && amd64 && amd64 && amd64 && amd64
// +build amd64,amd64,amd64,amd64,amd64
package cpuid

View File

@@ -108,3 +108,47 @@ func (fs FeatureSet) WriteCPUInfoTo(cpu, numCPU uint, w io.Writer) {
func (FeatureSet) archCheckHostCompatible(FeatureSet) error {
return nil
}
// AllowedHWCap1 reports the HWCAP1 bits that the guest may depend on: the
// bits recorded in fs.hwCap.hwCap1, restricted to a fixed allowlist.
func (fs FeatureSet) AllowedHWCap1() uint64 {
	// Allowlist of HWCAPS that are safe to expose. None of them rely on cpu
	// state that gvisor does not restore after a context switch.
	const safe uint64 = HWCAP_AES |
		HWCAP_ASIMD |
		HWCAP_ASIMDDP |
		HWCAP_ASIMDFHM |
		HWCAP_ASIMDHP |
		HWCAP_ASIMDRDM |
		HWCAP_ATOMICS |
		HWCAP_CRC32 |
		HWCAP_DCPOP |
		HWCAP_DIT |
		HWCAP_EVTSTRM |
		HWCAP_FCMA |
		HWCAP_FLAGM |
		HWCAP_FP |
		HWCAP_FPHP |
		HWCAP_ILRCPC |
		HWCAP_JSCVT |
		HWCAP_LRCPC |
		HWCAP_PMULL |
		HWCAP_SHA1 |
		HWCAP_SHA2 |
		HWCAP_SHA3 |
		HWCAP_SHA512 |
		HWCAP_SM3 |
		HWCAP_SM4 |
		HWCAP_USCAT

	// Only bits that are both present in the feature set and allowlisted
	// survive the mask.
	return fs.hwCap.hwCap1 & safe
}
// AllowedHWCap2 reports the HWCAP2 bits that the guest may depend on: the
// bits recorded in fs.hwCap.hwCap2, restricted to a fixed allowlist.
func (fs FeatureSet) AllowedHWCap2() uint64 {
	// Nothing is exposed here yet, so the allowlist is left at its zero
	// value. It could be expanded to include features that do not rely on cpu
	// state that is not restored after a context switch.
	var safe uint64
	return fs.hwCap.hwCap2 & safe
}

View File

@@ -1,7 +1,7 @@
// automatically generated by stateify.
//go:build arm64 && arm64 && arm64
// +build arm64,arm64,arm64
//go:build arm64 && arm64 && arm64 && arm64
// +build arm64,arm64,arm64,arm64
package cpuid

View File

@@ -0,0 +1,24 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build amd64
// +build amd64
package cpuid
// HWCAP2 bit flags for amd64. Each constant occupies the next bit, starting
// at bit 0.
//
// See arch/x86/include/uapi/asm/hwcap2.h
const (
	HWCAP2_RING3MWAIT = 1 << iota // bit 0
	HWCAP2_FSGSBASE               // bit 1
)

View File

@@ -0,0 +1,79 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build arm64
// +build arm64
package cpuid
// See arch/arm64/include/uapi/asm/hwcap.h

// HWCAP flags for AT_HWCAP. Each constant occupies the next bit, from bit 0
// (HWCAP_FP) through bit 31 (HWCAP_PACG).
const (
	HWCAP_FP       = 1 << iota // bit 0
	HWCAP_ASIMD                // bit 1
	HWCAP_EVTSTRM              // bit 2
	HWCAP_AES                  // bit 3
	HWCAP_PMULL                // bit 4
	HWCAP_SHA1                 // bit 5
	HWCAP_SHA2                 // bit 6
	HWCAP_CRC32                // bit 7
	HWCAP_ATOMICS              // bit 8
	HWCAP_FPHP                 // bit 9
	HWCAP_ASIMDHP              // bit 10
	HWCAP_CPUID                // bit 11
	HWCAP_ASIMDRDM             // bit 12
	HWCAP_JSCVT                // bit 13
	HWCAP_FCMA                 // bit 14
	HWCAP_LRCPC                // bit 15
	HWCAP_DCPOP                // bit 16
	HWCAP_SHA3                 // bit 17
	HWCAP_SM3                  // bit 18
	HWCAP_SM4                  // bit 19
	HWCAP_ASIMDDP              // bit 20
	HWCAP_SHA512               // bit 21
	HWCAP_SVE                  // bit 22
	HWCAP_ASIMDFHM             // bit 23
	HWCAP_DIT                  // bit 24
	HWCAP_USCAT                // bit 25
	HWCAP_ILRCPC               // bit 26
	HWCAP_FLAGM                // bit 27
	HWCAP_SSBS                 // bit 28
	HWCAP_SB                   // bit 29
	HWCAP_PACA                 // bit 30
	HWCAP_PACG                 // bit 31
)

// HWCAP2 flags for AT_HWCAP2. Each constant occupies the next bit, from bit 0
// (HWCAP2_DCPODP) through bit 21 (HWCAP2_RPRES).
const (
	HWCAP2_DCPODP     = 1 << iota // bit 0
	HWCAP2_SVE2                   // bit 1
	HWCAP2_SVEAES                 // bit 2
	HWCAP2_SVEPMULL               // bit 3
	HWCAP2_SVEBITPERM             // bit 4
	HWCAP2_SVESHA3                // bit 5
	HWCAP2_SVESM4                 // bit 6
	HWCAP2_FLAGM2                 // bit 7
	HWCAP2_FRINT                  // bit 8
	HWCAP2_SVEI8MM                // bit 9
	HWCAP2_SVEF32MM               // bit 10
	HWCAP2_SVEF64MM               // bit 11
	HWCAP2_SVEBF16                // bit 12
	HWCAP2_I8MM                   // bit 13
	HWCAP2_BF16                   // bit 14
	HWCAP2_DGH                    // bit 15
	HWCAP2_RNG                    // bit 16
	HWCAP2_BTI                    // bit 17
	HWCAP2_MTE                    // bit 18
	HWCAP2_ECV                    // bit 19
	HWCAP2_AFP                    // bit 20
	HWCAP2_RPRES                  // bit 21
)

View File

@@ -18,9 +18,10 @@
package cpuid
import (
"io/ioutil"
"bufio"
"bytes"
"os"
"strconv"
"strings"
"gvisor.dev/gvisor/pkg/log"
)
@@ -180,39 +181,44 @@ var (
// filter installation. This value is used to create the fake /proc/cpuinfo
// from a FeatureSet.
func readMaxCPUFreq() {
cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
cpuinfoFile, err := os.Open("/proc/cpuinfo")
if err != nil {
// Leave it as 0... the VDSO bails out in the same way.
log.Warningf("Could not read /proc/cpuinfo: %v", err)
log.Warningf("Could not open /proc/cpuinfo: %v", err)
return
}
cpuinfo := string(cpuinfob)
defer cpuinfoFile.Close()
// We get the value straight from host /proc/cpuinfo. On machines with
// frequency scaling enabled, this will only get the current value
// which will likely be inaccurate. This is fine on machines with
// frequency scaling disabled.
for _, line := range strings.Split(cpuinfo, "\n") {
if strings.Contains(line, "cpu MHz") {
splitMHz := strings.Split(line, ":")
s := bufio.NewScanner(cpuinfoFile)
for s.Scan() {
line := s.Bytes()
if bytes.Contains(line, []byte("cpu MHz")) {
splitMHz := bytes.Split(line, []byte(":"))
if len(splitMHz) < 2 {
log.Warningf("Could not read /proc/cpuinfo: malformed cpu MHz line")
log.Warningf("Could not parse /proc/cpuinfo: malformed cpu MHz line: %q", line)
return
}
// If there was a problem, leave cpuFreqMHz as 0.
var err error
cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
splitMHzStr := string(bytes.TrimSpace(splitMHz[1]))
f64MHz, err := strconv.ParseFloat(splitMHzStr, 64)
if err != nil {
log.Warningf("Could not parse cpu MHz value %v: %v", splitMHz[1], err)
cpuFreqMHz = 0
log.Warningf("Could not parse cpu MHz value %q: %v", splitMHzStr, err)
return
}
cpuFreqMHz = f64MHz
return
}
}
if err := s.Err(); err != nil {
log.Warningf("Could not read /proc/cpuinfo: %v", err)
return
}
log.Warningf("Could not parse /proc/cpuinfo, it is empty or does not contain cpu MHz")
}
// xgetbv reads an extended control register.

View File

@@ -18,7 +18,7 @@
package cpuid
import (
"io/ioutil"
"os"
"runtime"
"strconv"
"strings"
@@ -51,7 +51,7 @@ func initCPUInfo() {
// warn about them not existing.
return
}
cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
cpuinfob, err := os.ReadFile("/proc/cpuinfo")
if err != nil {
// Leave everything at 0, nothing can be done.
log.Warningf("Could not read /proc/cpuinfo: %v", err)

View File

@@ -26,7 +26,9 @@ type Static map[In]Out
// Fixed converts the FeatureSet to a fixed set.
func (fs FeatureSet) Fixed() FeatureSet {
return fs.ToStatic().ToFeatureSet()
sfs := fs.ToStatic().ToFeatureSet()
sfs.hwCap = fs.hwCap
return sfs
}
// ToStatic converts a FeatureSet to a Static function.

View File

@@ -394,6 +394,7 @@ type Waker struct {
allWakersNext *Waker
}
// +stateify savable
type wakerState struct {
asserted bool
other *Sleeper

View File

@@ -74,7 +74,36 @@ func (w *Waker) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.LoadValue(0, new(wakerState), func(y any) { w.loadS(ctx, y.(wakerState)) })
}
func (w *wakerState) StateTypeName() string {
return "pkg/sleep.wakerState"
}
func (w *wakerState) StateFields() []string {
return []string{
"asserted",
"other",
}
}
func (w *wakerState) beforeSave() {}
// +checklocksignore
func (w *wakerState) StateSave(stateSinkObject state.Sink) {
w.beforeSave()
stateSinkObject.Save(0, &w.asserted)
stateSinkObject.Save(1, &w.other)
}
func (w *wakerState) afterLoad(context.Context) {}
// +checklocksignore
func (w *wakerState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &w.asserted)
stateSourceObject.Load(1, &w.other)
}
func init() {
state.Register((*Sleeper)(nil))
state.Register((*Waker)(nil))
state.Register((*wakerState)(nil))
}

View File

@@ -530,21 +530,52 @@ func (s *addrSet) RemoveAll() {
// if the caller needs to do additional work before removing each segment,
// iterate segments and call Remove in a loop instead.
func (s *addrSet) RemoveRange(r addrRange) addrGapIterator {
seg, gap := s.Find(r.Start)
if seg.Ok() {
seg = s.Isolate(seg, r)
gap = s.Remove(seg)
}
for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
seg = s.SplitAfter(seg, r.End)
gap = s.Remove(seg)
}
return gap
return s.RemoveRangeWith(r, nil)
}
// RemoveFullRange is equivalent to RemoveRange, except that if any key in the
// given range does not correspond to a segment, RemoveFullRange panics.
//
// It delegates to RemoveFullRangeWith with a nil per-segment callback.
func (s *addrSet) RemoveFullRange(r addrRange) addrGapIterator {
return s.RemoveFullRangeWith(r, nil)
}
// RemoveRangeWith deletes every segment that falls inside r and returns an
// iterator to the gap left behind. All previously obtained iterators are
// invalidated.
//
// If f is non-nil it is invoked on each segment, in ascending key order, just
// before that segment is removed. A segment that straddles a boundary of r is
// split first, so f only ever sees segments that lie wholly within r.
// Non-empty gaps between segments are skipped.
//
// RemoveRangeWith has to search the set to locate the segments to remove.
// Callers that already hold an iterator to either end of the affected run, or
// that need to do extra work per segment, should iterate and call Remove
// themselves instead.
//
// N.B. f must not invalidate iterators into s.
func (s *addrSet) RemoveRangeWith(r addrRange, f func(seg addrIterator)) addrGapIterator {
	first, gap := s.Find(r.Start)
	if first.Ok() {
		// r.Start lands inside a segment: trim it down to the part within r
		// and remove that part.
		first = s.Isolate(first, r)
		if f != nil {
			f(first)
		}
		gap = s.Remove(first)
	}
	for {
		// Each removal yields a fresh gap; the next candidate is whatever
		// segment follows it.
		next := gap.NextSegment()
		if !next.Ok() || next.Start() >= r.End {
			return gap
		}
		// Cut off any tail that extends past r.End before removing.
		next = s.SplitAfter(next, r.End)
		if f != nil {
			f(next)
		}
		gap = s.Remove(next)
	}
}
// RemoveFullRangeWith is equivalent to RemoveRangeWith, except that if any key
// in the given range does not correspond to a segment, RemoveFullRangeWith
// panics.
func (s *addrSet) RemoveFullRangeWith(r addrRange, f func(seg addrIterator)) addrGapIterator {
seg := s.FindSegment(r.Start)
if !seg.Ok() {
panic(fmt.Sprintf("missing segment at %v", r.Start))
@@ -552,6 +583,9 @@ func (s *addrSet) RemoveFullRange(r addrRange) addrGapIterator {
seg = s.SplitBefore(seg, r.Start)
for {
seg = s.SplitAfter(seg, r.End)
if f != nil {
f(seg)
}
end := seg.End()
gap := s.Remove(seg)
if r.End <= end {
@@ -817,11 +851,11 @@ func (s *addrSet) Isolate(seg addrIterator, r addrRange) addrIterator {
// LowerBoundSegmentSplitBefore provides an iterator to the first segment to be
// mutated, suitable as the initial value for a loop variable.
func (s *addrSet) LowerBoundSegmentSplitBefore(min uintptr) addrIterator {
seg := s.LowerBoundSegment(min)
seg, gap := s.Find(min)
if seg.Ok() {
seg = s.SplitBefore(seg, min)
return s.SplitBefore(seg, min)
}
return seg
return gap.NextSegment()
}
// UpperBoundSegmentSplitAfter combines UpperBoundSegment and SplitAfter.
@@ -831,11 +865,11 @@ func (s *addrSet) LowerBoundSegmentSplitBefore(min uintptr) addrIterator {
// UpperBoundSegmentSplitAfter provides an iterator to the first segment to be
// mutated, suitable as the initial value for a loop variable.
func (s *addrSet) UpperBoundSegmentSplitAfter(max uintptr) addrIterator {
seg := s.UpperBoundSegment(max)
seg, gap := s.Find(max)
if seg.Ok() {
seg = s.SplitAfter(seg, max)
return s.SplitAfter(seg, max)
}
return seg
return gap.PrevSegment()
}
// VisitRange applies the function f to all segments intersecting the range r,

View File

@@ -18,7 +18,6 @@ import (
"bytes"
"context"
"fmt"
"io"
"math"
"reflect"
@@ -143,7 +142,7 @@ type decodeState struct {
ctx context.Context
// r is the input stream.
r io.Reader
r wire.Reader
// types is the type database.
types typeDecodeDatabase
@@ -591,7 +590,7 @@ func (ds *decodeState) Load(obj reflect.Value) {
ds.pending.PushBack(rootOds)
// Read the number of objects.
numObjects, object, err := ReadHeader(ds.r)
numObjects, object, err := ReadHeader(&ds.r)
if err != nil {
Failf("header error: %w", err)
}
@@ -613,7 +612,7 @@ func (ds *decodeState) Load(obj reflect.Value) {
// decoding loop in state/pretty/pretty.printer.printStream().
for i := uint64(0); i < numObjects; {
// Unmarshal either a type object or object ID.
encoded = wire.Load(ds.r)
encoded = wire.Load(&ds.r)
switch we := encoded.(type) {
case *wire.Type:
ds.types.Register(we)
@@ -624,7 +623,7 @@ func (ds *decodeState) Load(obj reflect.Value) {
id = objectID(we)
i++
// Unmarshal and resolve the actual object.
encoded = wire.Load(ds.r)
encoded = wire.Load(&ds.r)
ods = ds.lookup(id)
if ods != nil {
// Decode the object.
@@ -718,7 +717,7 @@ func (ds *decodeState) Load(obj reflect.Value) {
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
func ReadHeader(r *wire.Reader) (length uint64, object bool, err error) {
// Read the header.
err = safely(func() {
length = wire.LoadUint(r)

View File

@@ -16,7 +16,6 @@ package state
import (
"context"
"io"
"reflect"
"sort"
@@ -62,7 +61,7 @@ type encodeState struct {
ctx context.Context
// w is the output stream.
w io.Writer
w wire.Writer
// types is the type database.
types typeEncodeDatabase
@@ -781,7 +780,7 @@ func (es *encodeState) Save(obj reflect.Value) {
}
// Write the header with the number of objects.
if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil {
if err := WriteHeader(&es.w, uint64(len(es.pending)), true); err != nil {
Failf("error writing header: %w", err)
}
@@ -791,7 +790,7 @@ func (es *encodeState) Save(obj reflect.Value) {
if err := safely(func() {
for _, wt := range es.pendingTypes {
// Encode the type.
wire.Save(es.w, &wt)
wire.Save(&es.w, &wt)
}
// Emit objects in ID order.
ids := make([]objectID, 0, len(es.pending))
@@ -803,10 +802,10 @@ func (es *encodeState) Save(obj reflect.Value) {
})
for _, id := range ids {
// Encode the id.
wire.Save(es.w, wire.Uint(id))
wire.Save(&es.w, wire.Uint(id))
// Marshal the object.
oes := es.pending[id]
wire.Save(es.w, oes.encoded)
wire.Save(&es.w, oes.encoded)
}
}); err != nil {
// Include the object and the error.
@@ -825,7 +824,7 @@ const objectFlag uint64 = 1 << 63
// order to generate statefiles that play nicely with debugging tools, raw
// writes should be prefixed with a header with object set to false and the
// appropriate length. This will allow tools to skip these regions.
func WriteHeader(w io.Writer, length uint64, object bool) error {
func WriteHeader(w *wire.Writer, length uint64, object bool) error {
// Sanity check the length.
if length&objectFlag != 0 {
Failf("impossibly huge length: %d", length)

View File

@@ -92,7 +92,7 @@ func Save(ctx context.Context, w io.Writer, rootPtr any) (Stats, error) {
// Create the encoding state.
es := encodeState{
ctx: ctx,
w: w,
w: wire.Writer{Writer: w},
types: makeTypeEncodeDatabase(),
zeroValues: make(map[reflect.Type]*objectEncodeState),
pending: make(map[objectID]*objectEncodeState),
@@ -111,7 +111,7 @@ func Load(ctx context.Context, r io.Reader, rootPtr any) (Stats, error) {
// Create the decoding state.
ds := decodeState{
ctx: ctx,
r: r,
r: wire.Reader{Reader: r},
types: makeTypeDecodeDatabase(),
deferred: make(map[objectID]wire.Object),
}

View File

@@ -204,7 +204,7 @@ var singleFieldOrder = []int{0}
//
// This method never returns nil.
func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry {
if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil {
if len(tbd.byID) >= int(id) && tbd.byID[id-1] != nil {
// Already reconciled.
return tbd.byID[id-1]
}

View File

@@ -33,17 +33,39 @@ import (
"math"
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
var oneByteArrayPool = sync.Pool{
New: func() any { return &[1]byte{} },
// Reader bundles an io.Reader with a one-byte scratch buffer used to implement
// readByte efficiently.
type Reader struct {
	io.Reader

	// buf is the scratch space handed to Read by readByte.
	buf [1]byte
}

// readByte reads a single byte from r.Reader without allocation. It panics
// with the reader's error if the byte cannot be read.
//
// io.Reader implementations are permitted (though discouraged) to return
// (0, nil), and callers are expected to treat that as "nothing happened"
// rather than as a failure. Such reads are retried, which also avoids calling
// panic with a nil value.
func (r *Reader) readByte() byte {
	for {
		n, err := r.Read(r.buf[:])
		if n == 1 {
			// A byte delivered together with an error is still valid data; the
			// error will resurface on the next read.
			return r.buf[0]
		}
		if err != nil {
			panic(err)
		}
	}
}
// Writer bundles an io.Writer with a scratch buffer that Uint.save uses to
// assemble an encoded value before issuing a single Write.
type Writer struct {
io.Writer
// buf is used by Uint as a scratch buffer. Uint.save emits 7 bits of a
// 64-bit value per byte, so 10 bytes always suffice.
buf [10]byte
}
// readFull is a utility. The equivalent is not needed for Write, but the API
// contract dictates that it must always complete all bytes given or return an
// error.
func readFull(r io.Reader, p []byte) {
func readFull(r *Reader, p []byte) {
for done := 0; done < len(p); {
n, err := r.Read(p[done:])
done += n
@@ -58,25 +80,25 @@ type Object interface {
// save saves the given object.
//
// Panic is used for error control flow.
save(io.Writer)
save(*Writer)
// load loads a new object of the given type.
//
// Panic is used for error control flow.
load(io.Reader) Object
load(*Reader) Object
}
// Bool is a boolean.
type Bool bool
// loadBool loads an object of type Bool.
func loadBool(r io.Reader) Bool {
func loadBool(r *Reader) Bool {
b := loadUint(r)
return Bool(b == 1)
}
// save implements Object.save.
func (b Bool) save(w io.Writer) {
func (b Bool) save(w *Writer) {
var v Uint
if b {
v = 1
@@ -87,7 +109,7 @@ func (b Bool) save(w io.Writer) {
}
// load implements Object.load.
func (Bool) load(r io.Reader) Object { return loadBool(r) }
func (Bool) load(r *Reader) Object { return loadBool(r) }
// Int is a signed integer.
//
@@ -95,7 +117,7 @@ func (Bool) load(r io.Reader) Object { return loadBool(r) }
type Int int64
// loadInt loads an object of type Int.
func loadInt(r io.Reader) Int {
func loadInt(r *Reader) Int {
u := loadUint(r)
x := Int(u >> 1)
if u&1 != 0 {
@@ -105,7 +127,7 @@ func loadInt(r io.Reader) Int {
}
// save implements Object.save.
func (i Int) save(w io.Writer) {
func (i Int) save(w *Writer) {
u := Uint(i) << 1
if i < 0 {
u = ^u
@@ -114,29 +136,19 @@ func (i Int) save(w io.Writer) {
}
// load implements Object.load.
func (Int) load(r io.Reader) Object { return loadInt(r) }
func (Int) load(r *Reader) Object { return loadInt(r) }
// Uint is an unsigned integer.
type Uint uint64
func readByte(r io.Reader) byte {
p := oneByteArrayPool.Get().(*[1]byte)
defer oneByteArrayPool.Put(p)
n, err := r.Read(p[:])
if n != 1 {
panic(err)
}
return p[0]
}
// loadUint loads an object of type Uint.
func loadUint(r io.Reader) Uint {
func loadUint(r *Reader) Uint {
var (
u Uint
s uint
)
for i := 0; i <= 9; i++ {
b := readByte(r)
b := r.readByte()
if b < 0x80 {
if i == 9 && b > 1 {
panic("overflow")
@@ -150,76 +162,71 @@ func loadUint(r io.Reader) Uint {
panic("unreachable")
}
func writeByte(w io.Writer, b byte) {
p := oneByteArrayPool.Get().(*[1]byte)
defer oneByteArrayPool.Put(p)
p[0] = b
n, err := w.Write(p[:])
if n != 1 {
// save implements Object.save.
func (u Uint) save(w *Writer) {
i := 0
for u >= 0x80 {
w.buf[i] = byte(u) | 0x80
i++
u >>= 7
}
w.buf[i] = byte(u)
if _, err := w.Write(w.buf[:i+1]); err != nil {
panic(err)
}
}
// save implements Object.save.
func (u Uint) save(w io.Writer) {
for u >= 0x80 {
writeByte(w, byte(u)|0x80)
u >>= 7
}
writeByte(w, byte(u))
}
// load implements Object.load.
func (Uint) load(r io.Reader) Object { return loadUint(r) }
func (Uint) load(r *Reader) Object { return loadUint(r) }
// Float32 is a 32-bit floating point number.
type Float32 float32
// loadFloat32 loads an object of type Float32.
func loadFloat32(r io.Reader) Float32 {
func loadFloat32(r *Reader) Float32 {
n := loadUint(r)
return Float32(math.Float32frombits(uint32(n)))
}
// save implements Object.save.
func (f Float32) save(w io.Writer) {
func (f Float32) save(w *Writer) {
n := Uint(math.Float32bits(float32(f)))
n.save(w)
}
// load implements Object.load.
func (Float32) load(r io.Reader) Object { return loadFloat32(r) }
func (Float32) load(r *Reader) Object { return loadFloat32(r) }
// Float64 is a 64-bit floating point number.
type Float64 float64
// loadFloat64 loads an object of type Float64.
func loadFloat64(r io.Reader) Float64 {
func loadFloat64(r *Reader) Float64 {
n := loadUint(r)
return Float64(math.Float64frombits(uint64(n)))
}
// save implements Object.save.
func (f Float64) save(w io.Writer) {
func (f Float64) save(w *Writer) {
n := Uint(math.Float64bits(float64(f)))
n.save(w)
}
// load implements Object.load.
func (Float64) load(r io.Reader) Object { return loadFloat64(r) }
func (Float64) load(r *Reader) Object { return loadFloat64(r) }
// Complex64 is a 64-bit complex number.
type Complex64 complex128
// loadComplex64 loads an object of type Complex64.
func loadComplex64(r io.Reader) Complex64 {
func loadComplex64(r *Reader) Complex64 {
re := loadFloat32(r)
im := loadFloat32(r)
return Complex64(complex(float32(re), float32(im)))
}
// save implements Object.save.
func (c *Complex64) save(w io.Writer) {
func (c *Complex64) save(w *Writer) {
re := Float32(real(*c))
im := Float32(imag(*c))
re.save(w)
@@ -227,7 +234,7 @@ func (c *Complex64) save(w io.Writer) {
}
// load implements Object.load.
func (*Complex64) load(r io.Reader) Object {
func (*Complex64) load(r *Reader) Object {
c := loadComplex64(r)
return &c
}
@@ -236,14 +243,14 @@ func (*Complex64) load(r io.Reader) Object {
type Complex128 complex128
// loadComplex128 loads an object of type Complex128.
func loadComplex128(r io.Reader) Complex128 {
func loadComplex128(r *Reader) Complex128 {
re := loadFloat64(r)
im := loadFloat64(r)
return Complex128(complex(float64(re), float64(im)))
}
// save implements Object.save.
func (c *Complex128) save(w io.Writer) {
func (c *Complex128) save(w *Writer) {
re := Float64(real(*c))
im := Float64(imag(*c))
re.save(w)
@@ -251,7 +258,7 @@ func (c *Complex128) save(w io.Writer) {
}
// load implements Object.load.
func (*Complex128) load(r io.Reader) Object {
func (*Complex128) load(r *Reader) Object {
c := loadComplex128(r)
return &c
}
@@ -260,7 +267,7 @@ func (*Complex128) load(r io.Reader) Object {
type String string
// loadString loads an object of type String.
func loadString(r io.Reader) String {
func loadString(r *Reader) String {
l := loadUint(r)
p := make([]byte, l)
readFull(r, p)
@@ -268,7 +275,7 @@ func loadString(r io.Reader) String {
}
// save implements Object.save.
func (s *String) save(w io.Writer) {
func (s *String) save(w *Writer) {
l := Uint(len(*s))
l.save(w)
p := gohacks.ImmutableBytesFromString(string(*s))
@@ -279,7 +286,7 @@ func (s *String) save(w io.Writer) {
}
// load implements Object.load.
func (*String) load(r io.Reader) Object {
func (*String) load(r *Reader) Object {
s := loadString(r)
return &s
}
@@ -315,7 +322,7 @@ type Ref struct {
}
// loadRef loads an object of type Ref (abstract).
func loadRef(r io.Reader) Ref {
func loadRef(r *Reader) Ref {
ref := Ref{
Root: loadUint(r),
}
@@ -343,7 +350,7 @@ func loadRef(r io.Reader) Ref {
}
// save implements Object.save.
func (r *Ref) save(w io.Writer) {
func (r *Ref) save(w *Writer) {
r.Root.save(w)
l := Uint(len(r.Dots))
l.save(w)
@@ -372,7 +379,7 @@ func (r *Ref) save(w io.Writer) {
}
// load implements Object.load.
func (*Ref) load(r io.Reader) Object {
func (*Ref) load(r *Reader) Object {
ref := loadRef(r)
return &ref
}
@@ -381,15 +388,15 @@ func (*Ref) load(r io.Reader) Object {
type Nil struct{}
// loadNil loads an object of type Nil.
func loadNil(r io.Reader) Nil {
func loadNil(r *Reader) Nil {
return Nil{}
}
// save implements Object.save.
func (Nil) save(w io.Writer) {}
func (Nil) save(w *Writer) {}
// load implements Object.load.
func (Nil) load(r io.Reader) Object { return loadNil(r) }
func (Nil) load(r *Reader) Object { return loadNil(r) }
// Slice is a slice value.
type Slice struct {
@@ -399,7 +406,7 @@ type Slice struct {
}
// loadSlice loads an object of type Slice.
func loadSlice(r io.Reader) Slice {
func loadSlice(r *Reader) Slice {
return Slice{
Length: loadUint(r),
Capacity: loadUint(r),
@@ -408,14 +415,14 @@ func loadSlice(r io.Reader) Slice {
}
// save implements Object.save.
func (s *Slice) save(w io.Writer) {
func (s *Slice) save(w *Writer) {
s.Length.save(w)
s.Capacity.save(w)
s.Ref.save(w)
}
// load implements Object.load.
func (*Slice) load(r io.Reader) Object {
func (*Slice) load(r *Reader) Object {
s := loadSlice(r)
return &s
}
@@ -426,7 +433,7 @@ type Array struct {
}
// loadArray loads an object of type Array.
func loadArray(r io.Reader) Array {
func loadArray(r *Reader) Array {
l := loadUint(r)
if l == 0 {
// Note that there isn't a single object available to encode
@@ -448,7 +455,7 @@ func loadArray(r io.Reader) Array {
}
// save implements Object.save.
func (a *Array) save(w io.Writer) {
func (a *Array) save(w *Writer) {
l := Uint(len(a.Contents))
l.save(w)
if l == 0 {
@@ -463,7 +470,7 @@ func (a *Array) save(w io.Writer) {
}
// load implements Object.load.
func (*Array) load(r io.Reader) Object {
func (*Array) load(r *Reader) Object {
a := loadArray(r)
return &a
}
@@ -475,7 +482,7 @@ type Map struct {
}
// loadMap loads an object of type Map.
func loadMap(r io.Reader) Map {
func loadMap(r *Reader) Map {
l := loadUint(r)
if l == 0 {
// See LoadArray.
@@ -499,7 +506,7 @@ func loadMap(r io.Reader) Map {
}
// save implements Object.save.
func (m *Map) save(w io.Writer) {
func (m *Map) save(w *Writer) {
l := Uint(len(m.Keys))
if int(l) != len(m.Values) {
panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values)))
@@ -519,7 +526,7 @@ func (m *Map) save(w io.Writer) {
}
// load implements Object.load.
func (*Map) load(r io.Reader) Object {
func (*Map) load(r *Reader) Object {
m := loadMap(r)
return &m
}
@@ -584,7 +591,7 @@ const (
)
// loadTypeSpec loads TypeSpec values.
func loadTypeSpec(r io.Reader) TypeSpec {
func loadTypeSpec(r *Reader) TypeSpec {
switch hdr := loadUint(r); hdr {
case typeSpecTypeID:
return TypeID(loadUint(r))
@@ -615,7 +622,7 @@ func loadTypeSpec(r io.Reader) TypeSpec {
}
// saveTypeSpec saves TypeSpec values.
func saveTypeSpec(w io.Writer, t TypeSpec) {
func saveTypeSpec(w *Writer, t TypeSpec) {
switch x := t.(type) {
case TypeID:
typeSpecTypeID.save(w)
@@ -649,7 +656,7 @@ type Interface struct {
}
// loadInterface loads an object of type Interface.
func loadInterface(r io.Reader) Interface {
func loadInterface(r *Reader) Interface {
return Interface{
Type: loadTypeSpec(r),
Value: Load(r),
@@ -657,13 +664,13 @@ func loadInterface(r io.Reader) Interface {
}
// save implements Object.save.
func (i *Interface) save(w io.Writer) {
func (i *Interface) save(w *Writer) {
saveTypeSpec(w, i.Type)
Save(w, i.Value)
}
// load implements Object.load.
func (*Interface) load(r io.Reader) Object {
func (*Interface) load(r *Reader) Object {
i := loadInterface(r)
return &i
}
@@ -675,7 +682,7 @@ type Type struct {
}
// loadType loads an object of type Type.
func loadType(r io.Reader) Type {
func loadType(r *Reader) Type {
name := string(loadString(r))
l := loadUint(r)
fields := make([]string, l)
@@ -689,7 +696,7 @@ func loadType(r io.Reader) Type {
}
// save implements Object.save.
func (t *Type) save(w io.Writer) {
func (t *Type) save(w *Writer) {
s := String(t.Name)
s.save(w)
l := Uint(len(t.Fields))
@@ -701,7 +708,7 @@ func (t *Type) save(w io.Writer) {
}
// load implements Object.load.
func (*Type) load(r io.Reader) Object {
func (*Type) load(r *Reader) Object {
t := loadType(r)
return &t
}
@@ -710,7 +717,7 @@ func (*Type) load(r io.Reader) Object {
type multipleObjects []Object
// loadMultipleObjects loads a series of objects.
func loadMultipleObjects(r io.Reader) multipleObjects {
func loadMultipleObjects(r *Reader) multipleObjects {
l := loadUint(r)
m := make(multipleObjects, l)
for i := 0; i < int(l); i++ {
@@ -720,7 +727,7 @@ func loadMultipleObjects(r io.Reader) multipleObjects {
}
// save implements Object.save.
func (m *multipleObjects) save(w io.Writer) {
func (m *multipleObjects) save(w *Writer) {
l := Uint(len(*m))
l.save(w)
for i := 0; i < int(l); i++ {
@@ -729,7 +736,7 @@ func (m *multipleObjects) save(w io.Writer) {
}
// load implements Object.load.
func (*multipleObjects) load(r io.Reader) Object {
func (*multipleObjects) load(r *Reader) Object {
m := loadMultipleObjects(r)
return &m
}
@@ -738,13 +745,13 @@ func (*multipleObjects) load(r io.Reader) Object {
type noObjects struct{}
// loadNoObjects loads a sentinel.
func loadNoObjects(r io.Reader) noObjects { return noObjects{} }
func loadNoObjects(r *Reader) noObjects { return noObjects{} }
// save implements Object.save.
func (noObjects) save(w io.Writer) {}
func (noObjects) save(w *Writer) {}
// load implements Object.load.
func (noObjects) load(r io.Reader) Object { return loadNoObjects(r) }
func (noObjects) load(r *Reader) Object { return loadNoObjects(r) }
// Struct is a basic composite value.
type Struct struct {
@@ -799,7 +806,7 @@ func (s *Struct) Fields() int {
}
// loadStruct loads an object of type Struct.
func loadStruct(r io.Reader) Struct {
func loadStruct(r *Reader) Struct {
return Struct{
TypeID: TypeID(loadUint(r)),
fields: Load(r),
@@ -810,13 +817,13 @@ func loadStruct(r io.Reader) Struct {
//
// Precondition: Alloc must have been called, and the fields all filled in
// appropriately. See Alloc and Add for more details.
func (s *Struct) save(w io.Writer) {
func (s *Struct) save(w *Writer) {
Uint(s.TypeID).save(w)
Save(w, s.fields)
}
// load implements Object.load.
func (*Struct) load(r io.Reader) Object {
func (*Struct) load(r *Reader) Object {
s := loadStruct(r)
return &s
}
@@ -851,7 +858,7 @@ const (
// +checkescape all
//
// N.B. This function will panic on error.
func Save(w io.Writer, obj Object) {
func Save(w *Writer, obj Object) {
switch x := obj.(type) {
case Bool:
typeBool.save(w)
@@ -917,7 +924,7 @@ func Save(w io.Writer, obj Object) {
// +checkescape all
//
// N.B. This function will panic on error.
func Load(r io.Reader) Object {
func Load(r *Reader) Object {
switch hdr := loadUint(r); hdr {
case typeBool:
return loadBool(r)
@@ -964,13 +971,13 @@ func Load(r io.Reader) Object {
// LoadUint loads a single unsigned integer.
//
// N.B. This function will panic on error.
func LoadUint(r io.Reader) uint64 {
func LoadUint(r *Reader) uint64 {
return uint64(loadUint(r))
}
// SaveUint saves a single unsigned integer.
//
// N.B. This function will panic on error.
func SaveUint(w io.Writer, v uint64) {
func SaveUint(w *Writer, v uint64) {
Uint(v).save(w)
}

View File

@@ -34,3 +34,18 @@ type (
func NewCond(l Locker) *Cond {
return sync.NewCond(l)
}
// OnceFunc wraps sync.OnceFunc: it hands back a function that runs f only
// once, however many times the returned function is invoked.
func OnceFunc(f func()) func() {
	once := sync.OnceFunc(f)
	return once
}
// OnceValue wraps sync.OnceValue: it hands back a function that runs f only
// once and returns the value produced by that single run on every call.
func OnceValue[T any](f func() T) func() T {
	once := sync.OnceValue(f)
	return once
}
// OnceValues wraps sync.OnceValues: it hands back a function that runs f only
// once and returns the pair of values produced by that single run on every
// call.
func OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) {
	once := sync.OnceValues(f)
	return once
}

View File

@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
//go:build go1.21 && !go1.24
package sync

View File

@@ -0,0 +1,14 @@
// Copyright 2024 The gVisor Authors.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.24
package sync
import "unsafe"
// Use checkoffset to assert that maptype.Hasher (the only field we use) has
// the correct offset.
const maptypeHasherOffset = unsafe.Offsetof(maptype{}.Hasher) // +checkoffset internal/abi SwissMapType.Hasher

View File

@@ -620,4 +620,22 @@ func (*ErrMulticastInputCannotBeOutput) IgnoreStats() bool {
}
func (*ErrMulticastInputCannotBeOutput) String() string { return "output cannot contain input" }
// ErrEndpointBusy indicates that the operation cannot be completed because the
// endpoint is busy.
//
// +stateify savable
type ErrEndpointBusy struct{}

// isError implements Error.
func (*ErrEndpointBusy) isError() {}

// IgnoreStats implements Error.
//
// It always reports true.
func (*ErrEndpointBusy) IgnoreStats() bool { return true }

// String returns a fixed human-readable description of the error.
func (*ErrEndpointBusy) String() string {
	const description = "operation cannot be completed because the endpoint is busy"
	return description
}
// LINT.ThenChange(../syserr/netstack.go)

View File

@@ -603,6 +603,9 @@ const (
// IPv4OptionTimestampType is the option type for the Timestamp option.
IPv4OptionTimestampType IPv4OptionType = 68
// IPv4OptionExperimentType is the option type for the Experiment option.
IPv4OptionExperimentType IPv4OptionType = 30
// ipv4OptionTypeOffset is the offset in an option of its type field.
ipv4OptionTypeOffset = 0
@@ -800,6 +803,17 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, *IPv4OptParameterProblem)
}
retval := IPv4OptionRouterAlert(optionBody)
return &retval, false, nil
case IPv4OptionExperimentType:
if optLen != IPv4OptionExperimentLength {
i.ErrCursor++
return nil, false, &IPv4OptParameterProblem{
Pointer: i.ErrCursor,
NeedICMP: true,
}
}
retval := IPv4OptionExperiment(optionBody)
return &retval, false, nil
}
retval := IPv4OptionGeneric(optionBody)
return &retval, false, nil
@@ -1074,6 +1088,35 @@ func (ra *IPv4OptionRouterAlert) Value() uint16 {
return binary.BigEndian.Uint16(ra.Contents()[IPv4OptionRouterAlertValueOffset:])
}
// Experiment option specific related constants.
const (
// IPv4OptionExperimentLength is the length of an Experiment option.
IPv4OptionExperimentLength = 4
// IPv4OptionExperimentValueOffset is the offset for the value of an
// Experiment option.
IPv4OptionExperimentValueOffset = 2
)
var _ IPv4Option = (*IPv4OptionExperiment)(nil)
// IPv4OptionExperiment is an IPv4 option defined by RFC 4727.
type IPv4OptionExperiment []byte

// Type implements IPv4Option.
func (*IPv4OptionExperiment) Type() IPv4OptionType { return IPv4OptionExperimentType }

// Size implements IPv4Option.
func (*IPv4OptionExperiment) Size() uint8 { return uint8(IPv4OptionExperimentLength) }

// Contents implements IPv4Option.
func (ex *IPv4OptionExperiment) Contents() []byte { return *ex }

// Value returns the value of the IPv4OptionExperiment.
//
// The value is the big-endian 16-bit quantity found at
// IPv4OptionExperimentValueOffset within the option's contents.
func (ex *IPv4OptionExperiment) Value() uint16 {
	return binary.BigEndian.Uint16(ex.Contents()[IPv4OptionExperimentValueOffset:])
}
// IPv4SerializableOption is an interface to represent serializable IPv4 option
// types.
type IPv4SerializableOption interface {
@@ -1179,6 +1222,28 @@ func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 {
return o.length()
}
var _ IPv4SerializableOptionPayload = (*IPv4SerializableExperimentOption)(nil)
var _ IPv4SerializableOption = (*IPv4SerializableExperimentOption)(nil)
// IPv4SerializableExperimentOption provides serialization for the IPv4
// Experiment option.
type IPv4SerializableExperimentOption struct {
	// Tag is the 16-bit value written, big-endian, by serializeInto.
	Tag uint16
}

// optionType reports the type of this option, IPv4OptionExperimentType.
func (*IPv4SerializableExperimentOption) optionType() IPv4OptionType {
	return IPv4OptionExperimentType
}

// length reports the size of the serialized value: the full option length
// minus the offset at which the value starts.
func (*IPv4SerializableExperimentOption) length() uint8 {
	const valueLen = IPv4OptionExperimentLength - IPv4OptionExperimentValueOffset
	return valueLen
}

// serializeInto writes Tag in big-endian order at the start of buffer and
// returns the value reported by length.
func (opt *IPv4SerializableExperimentOption) serializeInto(buffer []byte) uint8 {
	binary.BigEndian.PutUint16(buffer, opt.Tag)
	written := opt.length()
	return written
}
var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil)
// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option.

View File

@@ -49,6 +49,10 @@ const (
// of an IPv6 payload, as per RFC 8200 section 4.7.
IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
// IPv6ExperimentExtHdrIdentifier is the header identifier of an Experiment
// extension header, as per RFC 4727 section 3.3.
IPv6ExperimentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 253
// IPv6UnknownExtHdrIdentifier is reserved by IANA.
// https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header
// "254 Use for experimentation and testing [RFC3692][RFC4727]"
@@ -411,6 +415,17 @@ type IPv6DestinationOptionsExtHdr struct {
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
// IPv6ExperimentExtHdr holds the value carried by an Experiment extension
// header.
type IPv6ExperimentExtHdr struct {
// Value is the 16-bit value of the Experiment extension header.
Value uint16
}
// Release implements IPv6PayloadHeader.Release.
//
// It is a no-op.
func (IPv6ExperimentExtHdr) Release() {}
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6ExperimentExtHdr) isIPv6PayloadHeader() {}
// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
// data as outlined in RFC 8200 section 4.4.
type IPv6RoutingExtHdr struct {
@@ -580,7 +595,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
// Is the header we are parsing a known extension header?
switch i.nextHdrIdentifier {
case IPv6HopByHopOptionsExtHdrIdentifier:
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
nextHdrIdentifier, view, err := i.nextHeaderData(false /* ignoreLength */, nil)
if err != nil {
return nil, true, err
}
@@ -588,7 +603,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
i.nextHdrIdentifier = nextHdrIdentifier
return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil
case IPv6RoutingExtHdrIdentifier:
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
nextHdrIdentifier, view, err := i.nextHeaderData(false /* ignoreLength */, nil)
if err != nil {
return nil, true, err
}
@@ -599,7 +614,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
var data [6]byte
// We ignore the returned bytes because we know the fragment extension
// header specific data will fit in data.
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
nextHdrIdentifier, _, err := i.nextHeaderData(true /* ignoreLength */, data[:])
if err != nil {
return nil, true, err
}
@@ -618,13 +633,24 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
i.nextHdrIdentifier = nextHdrIdentifier
return fragmentExtHdr, false, nil
case IPv6DestinationOptionsExtHdrIdentifier:
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
nextHdrIdentifier, view, err := i.nextHeaderData(false /* ignoreLength */, nil)
if err != nil {
return nil, true, err
}
i.nextHdrIdentifier = nextHdrIdentifier
return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil
case IPv6ExperimentExtHdrIdentifier:
var data [IPv6ExperimentHdrLength - ipv6ExperimentHdrValueOffset]byte
nextHdrIdentifier, _, err := i.nextHeaderData(true /* ignoreLength */, data[:])
if err != nil {
return nil, true, err
}
i.nextHdrIdentifier = nextHdrIdentifier
hdr := IPv6ExperimentExtHdr{
Value: binary.BigEndian.Uint16(data[:ipv6ExperimentHdrTagLength]),
}
return hdr, false, nil
case IPv6NoNextHeaderIdentifier:
// This indicates the end of the IPv6 payload.
return nil, true, nil
@@ -644,14 +670,14 @@ func (i *IPv6PayloadIterator) NextHeaderIdentifier() IPv6ExtensionHeaderIdentifi
// nextHeaderData returns the extension header's Next Header field and raw data.
//
// fragmentHdr indicates that the extension header being parsed is the Fragment
// extension header so the Length field should be ignored as it is Reserved
// for the Fragment extension header.
// ignoreLength indicates that the extension header being parsed should ignore
// the Length field as it is reserved. This is for the Fragment and Experiment
// extension headers.
//
// If bytes is not nil, extension header specific data will be read into bytes
// if it has enough capacity. If bytes is provided but does not have enough
// capacity for the data, nextHeaderData will panic.
func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, *buffer.View, error) {
func (i *IPv6PayloadIterator) nextHeaderData(ignoreLength bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, *buffer.View, error) {
// We ignore the number of bytes read because we know we will only ever read
// at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read
// would return io.EOF to indicate that io.Reader has reached the end of the
@@ -667,13 +693,13 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
length, err = rdr.ReadByte()
if err != nil {
if fragmentHdr {
if ignoreLength {
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
if fragmentHdr {
if ignoreLength {
length = 0
}
@@ -689,7 +715,7 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
if fragmentHdr {
if ignoreLength {
if n := len(bytes); n < bytesLen {
panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
}
@@ -735,6 +761,36 @@ type IPv6SerializableExtHdr interface {
serializeInto(nextHeader uint8, b []byte) int
}
// Layout constants for the Experiment extension header, as defined in
// RFC 4727 section 3.3.
const (
// IPv6ExperimentHdrLength is the length in bytes of a serialized Experiment
// extension header.
IPv6ExperimentHdrLength = 8
// ipv6ExperimentNextHeaderOffset is the offset of the Next Header field.
ipv6ExperimentNextHeaderOffset = 0
// ipv6ExperimentLengthOffset is the offset of the Length field.
ipv6ExperimentLengthOffset = 1
// ipv6ExperimentHdrValueOffset is the offset of the experiment value.
ipv6ExperimentHdrValueOffset = 2
// ipv6ExperimentHdrTagLength is the length in bytes of the experiment value.
ipv6ExperimentHdrTagLength = 2
)
var _ IPv6SerializableExtHdr = (*IPv6ExperimentExtHdr)(nil)
// identifier implements IPv6SerializableExtHdr.
//
// It always reports IPv6ExperimentExtHdrIdentifier.
func (IPv6ExperimentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
	return IPv6ExperimentExtHdrIdentifier
}
// length implements IPv6SerializableExtHdr.
//
// It always reports IPv6ExperimentHdrLength.
func (IPv6ExperimentExtHdr) length() int {
	return IPv6ExperimentHdrLength
}
// serializeInto implements IPv6SerializableExtHdr.
//
// It writes the Next Header byte, the Length byte and the big-endian Value
// into b and returns IPv6ExperimentHdrLength. Bytes of b past the value field
// are not written.
func (h IPv6ExperimentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
	b[ipv6ExperimentNextHeaderOffset] = nextHeader
	// The Length field is the header length in units of
	// ipv6ExtHdrLenBytesPerUnit bytes, minus one.
	b[ipv6ExperimentLengthOffset] = (IPv6ExperimentHdrLength / ipv6ExtHdrLenBytesPerUnit) - 1

	valueField := b[ipv6ExperimentHdrValueOffset:][:ipv6ExperimentHdrTagLength]
	binary.BigEndian.PutUint16(valueField, h.Value)

	return IPv6ExperimentHdrLength
}
var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop

View File

@@ -135,7 +135,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
last: last,
filled: true,
final: currentHole.final,
pkt: pkt.IncRef(),
pkt: pkt.Clone(),
}
r.filled++
// For IPv6, it is possible to have different Protocol values between
@@ -150,7 +150,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
if r.pkt != nil {
r.pkt.DecRef()
}
r.pkt = pkt.IncRef()
r.pkt = pkt.Clone()
r.proto = proto
}
break

View File

@@ -294,6 +294,7 @@ func (m *MultiCounterIPForwardingStats) StateFields() []string {
"NoMulticastPendingQueueBufferSpace",
"OutgoingDeviceNoBufferSpace",
"Errors",
"OutgoingDeviceClosedForSend",
}
}
@@ -315,6 +316,7 @@ func (m *MultiCounterIPForwardingStats) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(10, &m.NoMulticastPendingQueueBufferSpace)
stateSinkObject.Save(11, &m.OutgoingDeviceNoBufferSpace)
stateSinkObject.Save(12, &m.Errors)
stateSinkObject.Save(13, &m.OutgoingDeviceClosedForSend)
}
func (m *MultiCounterIPForwardingStats) afterLoad(context.Context) {}
@@ -334,6 +336,7 @@ func (m *MultiCounterIPForwardingStats) StateLoad(ctx context.Context, stateSour
stateSourceObject.Load(10, &m.NoMulticastPendingQueueBufferSpace)
stateSourceObject.Load(11, &m.OutgoingDeviceNoBufferSpace)
stateSourceObject.Load(12, &m.Errors)
stateSourceObject.Load(13, &m.OutgoingDeviceClosedForSend)
}
func (m *MultiCounterIPStats) StateTypeName() string {

View File

@@ -78,6 +78,10 @@ type MultiCounterIPForwardingStats struct {
// Errors is the number of IP packets received which could not be
// successfully forwarded.
Errors tcpip.MultiCounterStat
// OutgoingDeviceClosedForSend is the number of packets that were dropped due
// to the outgoing device being closed for send.
OutgoingDeviceClosedForSend tcpip.MultiCounterStat
}
// Init sets internal counters to track a and b counters.
@@ -95,9 +99,10 @@ func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
m.UnknownOutputEndpoint.Init(a.UnknownOutputEndpoint, b.UnknownOutputEndpoint)
m.NoMulticastPendingQueueBufferSpace.Init(a.NoMulticastPendingQueueBufferSpace, b.NoMulticastPendingQueueBufferSpace)
m.OutgoingDeviceNoBufferSpace.Init(a.OutgoingDeviceNoBufferSpace, b.OutgoingDeviceNoBufferSpace)
m.OutgoingDeviceClosedForSend.Init(a.OutgoingDeviceClosedForSend, b.OutgoingDeviceClosedForSend)
}
// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats)
// LINT.ThenChange(../../../tcpip.go:IPForwardingStats)
// LINT.IfChange(MultiCounterIPStats)
@@ -211,4 +216,4 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.Forwarding.Init(&a.Forwarding, &b.Forwarding)
}
// LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats)
// LINT.ThenChange(../../../tcpip.go:IPStats)

View File

@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -69,6 +70,8 @@ const (
forwardingEnabled = 1
)
var martianPacketLogger = log.BasicRateLimitedLogger(time.Minute)
var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
@@ -446,6 +449,9 @@ func (e *endpoint) getID() uint16 {
}
func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) tcpip.Error {
if expVal := params.ExperimentOptionValue; expVal != 0 {
options = append(options, &header.IPv4SerializableExperimentOption{Tag: expVal})
}
hdrLen := header.IPv4MinimumSize
var optLen int
if options != nil {
@@ -839,11 +845,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.nic.IsLoopback() {
if !e.protocol.options.AllowExternalLoopbackTraffic {
if header.IsV4LoopbackAddress(h.SourceAddress()) {
martianPacketLogger.Infof("Martian packet dropped with loopback source address. If your traffic is unexpectedly dropped, you may want to allow martian packets.")
stats.InvalidSourceAddressesReceived.Increment()
return
}
if header.IsV4LoopbackAddress(h.DestinationAddress()) {
martianPacketLogger.Infof("Martian packet dropped with loopback destination address. If your traffic is unexpectedly dropped, you may want to allow martian packets.")
stats.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -868,7 +876,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
}
// CheckPrerouting can modify the backing storage of the packet, so refresh
// the header.
h = header.IPv4(pkt.NetworkHeader().Slice())
e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
@@ -1198,6 +1208,13 @@ func (e *endpoint) handleForwardingError(err ip.ForwardingError) {
stats.Forwarding.UnknownOutputEndpoint.Increment()
case *ip.ErrOutgoingDeviceNoBufferSpace:
stats.Forwarding.OutgoingDeviceNoBufferSpace.Increment()
case *ip.ErrOther:
switch err := err.Err.(type) {
case *tcpip.ErrClosedForSend:
stats.Forwarding.OutgoingDeviceClosedForSend.Increment()
default:
panic(fmt.Sprintf("unrecognized tcpip forwarding error: %s", err))
}
default:
panic(fmt.Sprintf("unrecognized forwarding error: %s", err))
}

View File

@@ -385,6 +385,13 @@ func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) {
e.mu.ndp.configs = c
}
// NDPConfigurations implements NDPEndpoint.
//
// It returns the NDP configurations currently stored on e, read while holding
// e.mu for reading.
func (e *endpoint) NDPConfigurations() NDPConfigurations {
	e.mu.RLock()
	defer e.mu.RUnlock()

	configs := e.mu.ndp.configs
	return configs
}
// hasTentativeAddr returns true if addr is tentative on e.
func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
e.mu.RLock()
@@ -731,6 +738,9 @@ func (e *endpoint) MaxHeaderLength() uint16 {
}
func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) tcpip.Error {
if params.ExperimentOptionValue != 0 {
extensionHeaders = append(extensionHeaders, &header.IPv6ExperimentExtHdr{Value: params.ExperimentOptionValue})
}
extHdrsLen := extensionHeaders.Length()
length := pkt.Size() + extensionHeaders.Length()
if length > math.MaxUint16 {
@@ -1133,6 +1143,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
// CheckPrerouting can modify the backing storage of the packet, so refresh
// the header.
h = header.IPv6(pkt.NetworkHeader().Slice())
e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
@@ -1300,6 +1313,13 @@ func (e *endpoint) handleForwardingError(err ip.ForwardingError) {
stats.Forwarding.UnknownOutputEndpoint.Increment()
case *ip.ErrOutgoingDeviceNoBufferSpace:
stats.Forwarding.OutgoingDeviceNoBufferSpace.Increment()
case *ip.ErrOther:
switch err := err.Err.(type) {
case *tcpip.ErrClosedForSend:
stats.Forwarding.OutgoingDeviceClosedForSend.Increment()
default:
panic(fmt.Sprintf("unrecognized tcpip forwarding error: %s", err))
}
default:
panic(fmt.Sprintf("unrecognized forwarding error: %s", err))
}
@@ -1446,6 +1466,7 @@ func (e *endpoint) processExtensionHeader(it *header.IPv6PayloadIterator, pkt **
if err := e.processIPv6RawPayloadHeader(&extHdr, it, *pkt, *routerAlert, previousHeaderStart, *hasFragmentHeader); err != nil {
return true, err
}
case header.IPv6ExperimentExtHdr:
default:
// Since the iterator returns IPv6RawPayloadHeader for unknown Extension
// Header IDs this should never happen unless we missed a supported type
@@ -1488,6 +1509,7 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
routerAlert *header.IPv6RouterAlertOption
)
for {
h := header.IPv6(pkt.NetworkHeader().Slice())
if done, err := e.processExtensionHeader(&it, &pkt, h, &routerAlert, &hasFragmentHeader, forwarding); err != nil || done {
return err
}

View File

@@ -172,6 +172,9 @@ const (
type NDPEndpoint interface {
// SetNDPConfigurations sets the NDP configurations.
SetNDPConfigurations(NDPConfigurations)
// NDPConfigurations returns the NDP configurations.
NDPConfigurations() NDPConfigurations
}
// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an

View File

@@ -263,6 +263,10 @@ type SocketOptions struct {
// rcvlowat specifies the minimum number of bytes which should be
// received to indicate the socket as readable.
rcvlowat atomicbitops.Int32
// experimentOptionValue is the value set for the IP option experiment header
// if it is not zero.
experimentOptionValue atomicbitops.Uint32
}
// InitHandler initializes the handler. This must be called before using the
@@ -539,6 +543,17 @@ func (so *SocketOptions) SetLinger(linger LingerOption) {
so.mu.Unlock()
}
// GetExperimentOptionValue gets value for the experiment IP option header.
func (so *SocketOptions) GetExperimentOptionValue() uint16 {
v := so.experimentOptionValue.Load()
return uint16(v)
}
// SetExperimentOptionValue sets the value for the experiment IP option header.
//
// The value is widened to uint32 and stored atomically.
func (so *SocketOptions) SetExperimentOptionValue(v uint16) {
	widened := uint32(v)
	so.experimentOptionValue.Store(widened)
}
// SockErrOrigin represents the constants for error origin.
type SockErrOrigin uint8

View File

@@ -41,10 +41,12 @@ type AddressableEndpointState struct {
// AddressableEndpointState.mu
// addressState.mu
mu addressableEndpointStateRWMutex `state:"nosave"`
// TODO(b/361075310): Enable s/r for the below fields.
//
// +checklocks:mu
endpoints map[tcpip.Address]*addressState
endpoints map[tcpip.Address]*addressState `state:"nosave"`
// +checklocks:mu
primary []*addressState
primary []*addressState `state:"nosave"`
}
// AddressableEndpointStateOptions contains options used to configure an
@@ -736,8 +738,6 @@ func (a *AddressableEndpointState) Cleanup() {
var _ AddressEndpoint = (*addressState)(nil)
// addressState holds state for an address.
//
// +stateify savable
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
@@ -748,7 +748,7 @@ type addressState struct {
//
// AddressableEndpointState.mu
// addressState.mu
mu addressStateRWMutex `state:"nosave"`
mu addressStateRWMutex
refs addressStateRefs
// checklocks:mu
kind AddressKind

View File

@@ -22,11 +22,28 @@ import (
var _ NetworkLinkEndpoint = (*BridgeEndpoint)(nil)
// +stateify savable
type bridgePort struct {
bridge *BridgeEndpoint
nic *nic
}
// BridgeFDBKey is the MAC address of a device which a bridge port is associated with.
type BridgeFDBKey tcpip.LinkAddress
// BridgeFDBEntry consists of all metadata for a FDB record.
type BridgeFDBEntry struct {
port *bridgePort
}
// PortLinkAddress returns the MAC address of the device that is bound to the
// entry's bridge port, or the empty address when the entry has no port.
func (e BridgeFDBEntry) PortLinkAddress() tcpip.LinkAddress {
	if port := e.port; port != nil {
		return port.nic.LinkAddress()
	}
	return ""
}
// ParseHeader implements stack.LinkEndpoint.
func (p *bridgePort) ParseHeader(pkt *PacketBuffer) bool {
_, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
@@ -36,23 +53,49 @@ func (p *bridgePort) ParseHeader(pkt *PacketBuffer) bool {
// DeliverNetworkPacket implements stack.NetworkDispatcher.
func (p *bridgePort) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
bridge := p.bridge
eth := header.Ethernet(pkt.LinkHeader().Slice())
updateFDB := false
bridge.mu.RLock()
// Send the packet to all other ports.
for _, port := range bridge.ports {
if p == port {
continue
// Add an entry at the bridge FDB, it maps a MAC address
// to a bridge port where the traffic is received when
// the MAC address is not multicast.
// Network packets that are sent to the learned MAC address
// will be forwarded to the bridge port that is stored in
// the FDB table.
sourceAddress := eth.SourceAddress()
if _, hasSourceFDB := bridge.fdbTable[BridgeFDBKey(sourceAddress)]; !header.IsMulticastEthernetAddress(sourceAddress) && !hasSourceFDB {
updateFDB = true
}
if entry, exist := bridge.fdbTable[BridgeFDBKey(eth.DestinationAddress())]; !exist {
// When no FDB entry is found, send the packet to all ports.
for _, port := range bridge.ports {
if p == port {
continue
}
newPkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: int(port.nic.MaxHeaderLength()),
Payload: pkt.ToBuffer(),
})
port.nic.writeRawPacket(newPkt)
newPkt.DecRef()
}
} else if entry.port != p {
destPort := entry.port
newPkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: int(port.nic.MaxHeaderLength()),
ReserveHeaderBytes: int(destPort.nic.MaxHeaderLength()),
Payload: pkt.ToBuffer(),
})
port.nic.writeRawPacket(newPkt)
destPort.nic.writeRawPacket(newPkt)
newPkt.DecRef()
}
d := bridge.dispatcher
bridge.mu.RUnlock()
if updateFDB {
bridge.mu.Lock()
bridge.addFDBEntryLocked(eth.SourceAddress(), p, 0)
bridge.mu.Unlock()
}
if d != nil {
// The dispatcher may acquire Stack.mu in DeliverNetworkPacket(), which is
// ordered above bridge.mu. So call DeliverNetworkPacket() without holding
@@ -71,12 +114,15 @@ func NewBridgeEndpoint(mtu uint32) *BridgeEndpoint {
addr: tcpip.GetRandMacAddr(),
}
b.ports = make(map[tcpip.NICID]*bridgePort)
b.fdbTable = make(map[BridgeFDBKey]BridgeFDBEntry)
return b
}
// BridgeEndpoint is a bridge endpoint.
//
// +stateify savable
type BridgeEndpoint struct {
mu bridgeRWMutex
mu bridgeRWMutex `state:"nosave"`
// +checklocks:mu
ports map[tcpip.NICID]*bridgePort
// +checklocks:mu
@@ -86,7 +132,9 @@ type BridgeEndpoint struct {
// +checklocks:mu
attached bool
// +checklocks:mu
mtu uint32
mtu uint32
// +checklocks:mu
fdbTable map[BridgeFDBKey]BridgeFDBEntry
maxHeaderLength atomicbitops.Uint32
}
@@ -140,6 +188,12 @@ func (b *BridgeEndpoint) DelNIC(nic *nic) tcpip.Error {
b.mu.Lock()
defer b.mu.Unlock()
port := b.ports[nic.id]
for k, e := range b.fdbTable {
if e.port == port {
delete(b.fdbTable, k)
}
}
delete(b.ports, nic.id)
nic.NetworkLinkEndpoint.Attach(nic)
return nil
@@ -169,8 +223,8 @@ func (b *BridgeEndpoint) MaxHeaderLength() uint16 {
// LinkAddress implements stack.LinkEndpoint.LinkAddress.
func (b *BridgeEndpoint) LinkAddress() tcpip.LinkAddress {
b.mu.Lock()
defer b.mu.Unlock()
b.mu.RLock()
defer b.mu.RUnlock()
return b.addr
}
@@ -195,6 +249,7 @@ func (b *BridgeEndpoint) Attach(dispatcher NetworkDispatcher) {
}
b.dispatcher = dispatcher
b.ports = make(map[tcpip.NICID]*bridgePort)
b.fdbTable = make(map[BridgeFDBKey]BridgeFDBEntry)
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
@@ -227,3 +282,25 @@ func (b *BridgeEndpoint) Close() {}
// SetOnCloseAction implements stack.LinkEndpoint.Close.
func (b *BridgeEndpoint) SetOnCloseAction(func()) {}
// addFDBEntryLocked adds a new FDB entry by learning. Learning happens when a
// packet is received by a bridge port; the bridge then uses that port for
// future deliveries to the NIC device.
//
// addr is the key under which the entry is stored and later looked up, and
// source is the port recorded in the entry. Any existing entry for addr is
// overwritten. flags is currently unused and the function always returns true.
//
// +checklocks:b.mu
func (b *BridgeEndpoint) addFDBEntryLocked(addr tcpip.LinkAddress, source *bridgePort, flags uint64) bool {
	// TODO(b/376924093): limit bridge FDB size.
	key := BridgeFDBKey(addr)
	entry := BridgeFDBEntry{port: source}
	b.fdbTable[key] = entry
	return true
}
// FindFDBEntry finds the FDB entry for the given address. If no entry exists,
// it returns an empty (zero-value) entry.
func (b *BridgeEndpoint) FindFDBEntry(addr tcpip.LinkAddress) BridgeFDBEntry {
	b.mu.RLock()
	defer b.mu.RUnlock()

	entry := b.fdbTable[BridgeFDBKey(addr)]
	return entry
}

View File

@@ -24,17 +24,16 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// TODO(b/256037250): Enable by default.
// TODO(b/256037250): We parse headers here. We should save those headers in
// PacketBuffers so they don't have to be re-parsed later.
// TODO(b/256037250): I still see the occasional SACK block in the zero-loss
// benchmark, which should not happen.
// TODO(b/256037250): Some dispatchers, e.g. XDP and RecvMmsg, can receive
// multiple packets at a time. Even if the GRO interval is 0, there is an
// opportunity for coalescing.
// TODO(b/256037250): We're doing some header parsing here, which presents the
// opportunity to skip it later.
// TODO(b/256037250): Can we pass a packet list up the stack too?
// There is room for improvement to the GRO engine:
// - We parse headers here. We should save those headers in
// PacketBuffers so they don't have to be re-parsed later.
// - We still see the occasional SACK block in the zero-loss
// benchmark, which should not happen.
// - Some dispatchers, e.g. XDP and RecvMmsg, can receive
// multiple packets at a time. Even if the GRO interval is 0, there is an
// opportunity for coalescing.
// - We could pass a packet list up the stack to reduce traversals up the
// stack.
const (
// groNBuckets is the number of GRO buckets.
@@ -50,6 +49,8 @@ const (
)
// A groBucket holds packets that are undergoing GRO.
//
// +stateify savable
type groBucket struct {
// count is the number of packets in the bucket.
count int
@@ -265,6 +266,8 @@ func (gb *groBucket) found(gd *GRO, groPkt *groPacket, flushGROPkt bool, pkt *st
// A groPacket is packet undergoing GRO. It may be several packets coalesced
// together.
//
// +stateify savable
type groPacket struct {
// groPacketEntry is an intrusive list.
groPacketEntry
@@ -303,6 +306,8 @@ func (pk *groPacket) payloadSize() int {
}
// GRO coalesces incoming packets to increase throughput.
//
// +stateify savable
type GRO struct {
enabled bool
buckets [groNBuckets]groBucket
@@ -444,6 +449,7 @@ func (gd *GRO) dispatch6(pkt *stack.PacketBuffer) {
case header.IPv6HopByHopOptionsExtHdr:
case header.IPv6RoutingExtHdr:
case header.IPv6DestinationOptionsExtHdr:
case header.IPv6ExperimentExtHdr:
default:
// This is either a TCP header or something we can't handle.
ipHdrSize = int(it.HeaderOffset())
@@ -508,8 +514,7 @@ func (gd *GRO) dispatch6(pkt *stack.PacketBuffer) {
}
func (gd *GRO) bucketForPacket4(ipHdr header.IPv4, tcpHdr header.TCP) int {
// TODO(b/256037250): Use jenkins or checksum. Write a test to print
// distribution.
// It would be better to use jenkins or checksum.
var sum int
srcAddr := ipHdr.SourceAddress()
for _, val := range srcAddr.AsSlice() {
@@ -525,8 +530,7 @@ func (gd *GRO) bucketForPacket4(ipHdr header.IPv4, tcpHdr header.TCP) int {
}
func (gd *GRO) bucketForPacket6(ipHdr header.IPv6, tcpHdr header.TCP) int {
// TODO(b/256037250): Use jenkins or checksum. Write a test to print
// distribution.
// It would be better to use jenkins or checksum.
var sum int
srcAddr := ipHdr.SourceAddress()
for _, val := range srcAddr.AsSlice() {

View File

@@ -8,6 +8,111 @@ import (
"gvisor.dev/gvisor/pkg/state"
)
// StateTypeName returns the fixed type-name string used for groBucket.
func (gb *groBucket) StateTypeName() string {
return "pkg/tcpip/stack/gro.groBucket"
}
// StateFields returns the names of the groBucket fields that are saved. The
// position of each name matches the index passed to Save and Load in
// StateSave and StateLoad.
func (gb *groBucket) StateFields() []string {
return []string{
"count",
"packets",
"packetsPrealloc",
"allocIdxs",
}
}
// beforeSave is a no-op hook invoked at the start of StateSave.
func (gb *groBucket) beforeSave() {}
// StateSave calls beforeSave and then writes every field listed in
// StateFields to stateSinkObject under its index.
//
// +checklocksignore
func (gb *groBucket) StateSave(stateSinkObject state.Sink) {
gb.beforeSave()
stateSinkObject.Save(0, &gb.count)
stateSinkObject.Save(1, &gb.packets)
stateSinkObject.Save(2, &gb.packetsPrealloc)
stateSinkObject.Save(3, &gb.allocIdxs)
}
// afterLoad is a no-op; StateLoad below does not invoke it.
func (gb *groBucket) afterLoad(context.Context) {}
// StateLoad reads every field listed in StateFields back from
// stateSourceObject, using the same indices as StateSave. ctx is unused here.
//
// +checklocksignore
func (gb *groBucket) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &gb.count)
stateSourceObject.Load(1, &gb.packets)
stateSourceObject.Load(2, &gb.packetsPrealloc)
stateSourceObject.Load(3, &gb.allocIdxs)
}
// StateTypeName returns the fixed type-name string used for groPacket.
func (pk *groPacket) StateTypeName() string {
return "pkg/tcpip/stack/gro.groPacket"
}
// StateFields returns the names of the groPacket fields that are saved. The
// position of each name matches the index passed to Save and Load in
// StateSave and StateLoad.
func (pk *groPacket) StateFields() []string {
return []string{
"groPacketEntry",
"pkt",
"ipHdr",
"tcpHdr",
"initialLength",
"idx",
}
}
// beforeSave is a no-op hook invoked at the start of StateSave.
func (pk *groPacket) beforeSave() {}
// StateSave calls beforeSave and then writes every field listed in
// StateFields to stateSinkObject under its index.
//
// +checklocksignore
func (pk *groPacket) StateSave(stateSinkObject state.Sink) {
pk.beforeSave()
stateSinkObject.Save(0, &pk.groPacketEntry)
stateSinkObject.Save(1, &pk.pkt)
stateSinkObject.Save(2, &pk.ipHdr)
stateSinkObject.Save(3, &pk.tcpHdr)
stateSinkObject.Save(4, &pk.initialLength)
stateSinkObject.Save(5, &pk.idx)
}
// afterLoad is a no-op; StateLoad below does not invoke it.
func (pk *groPacket) afterLoad(context.Context) {}
// StateLoad reads every field listed in StateFields back from
// stateSourceObject, using the same indices as StateSave. ctx is unused here.
//
// +checklocksignore
func (pk *groPacket) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &pk.groPacketEntry)
stateSourceObject.Load(1, &pk.pkt)
stateSourceObject.Load(2, &pk.ipHdr)
stateSourceObject.Load(3, &pk.tcpHdr)
stateSourceObject.Load(4, &pk.initialLength)
stateSourceObject.Load(5, &pk.idx)
}
// StateTypeName returns the fixed type-name string used for GRO.
func (gd *GRO) StateTypeName() string {
return "pkg/tcpip/stack/gro.GRO"
}
// StateFields returns the names of the GRO fields that are saved. The
// position of each name matches the index passed to Save and Load in
// StateSave and StateLoad.
func (gd *GRO) StateFields() []string {
return []string{
"enabled",
"buckets",
"Dispatcher",
}
}
// beforeSave is a no-op hook invoked at the start of StateSave.
func (gd *GRO) beforeSave() {}
// StateSave calls beforeSave and then writes every field listed in
// StateFields to stateSinkObject under its index.
//
// +checklocksignore
func (gd *GRO) StateSave(stateSinkObject state.Sink) {
gd.beforeSave()
stateSinkObject.Save(0, &gd.enabled)
stateSinkObject.Save(1, &gd.buckets)
stateSinkObject.Save(2, &gd.Dispatcher)
}
// afterLoad is a no-op; StateLoad below does not invoke it.
func (gd *GRO) afterLoad(context.Context) {}
// StateLoad reads every field listed in StateFields back from
// stateSourceObject, using the same indices as StateSave. ctx is unused here.
//
// +checklocksignore
func (gd *GRO) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &gd.enabled)
stateSourceObject.Load(1, &gd.buckets)
stateSourceObject.Load(2, &gd.Dispatcher)
}
func (l *groPacketList) StateTypeName() string {
return "pkg/tcpip/stack/gro.groPacketList"
}
@@ -65,6 +170,9 @@ func (e *groPacketEntry) StateLoad(ctx context.Context, stateSourceObject state.
}
// init registers each type listed below with the state package by passing a
// typed nil pointer to state.Register.
func init() {
state.Register((*groBucket)(nil))
state.Register((*groPacket)(nil))
state.Register((*GRO)(nil))
state.Register((*groPacketList)(nil))
state.Register((*groPacketEntry)(nil))
}

View File

@@ -335,9 +335,9 @@ func (it *IPTables) shouldSkipOrPopulateTables(tables []checkTable, pkt *PacketB
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
// +checkescape
func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
tables := [...]checkTable{
tables := [...]checkTable{ // escapes: on arm this causes an allocation.
{
fn: check,
tableID: MangleID,
@@ -373,9 +373,9 @@ func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndp
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
// +checkescape
func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
tables := [...]checkTable{
tables := [...]checkTable{ // escapes: on arm this causes an allocation.
{
fn: checkNAT,
tableID: NATID,
@@ -413,9 +413,9 @@ func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
// +checkescape
func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
tables := [...]checkTable{
tables := [...]checkTable{ // escapes: on arm this causes an allocation.
{
fn: check,
tableID: FilterID,
@@ -445,9 +445,9 @@ func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
// +checkescape
func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
tables := [...]checkTable{
tables := [...]checkTable{ // escapes: on arm this causes an allocation.
{
fn: check,
tableID: MangleID,
@@ -489,9 +489,9 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string)
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
// +checkescape
func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
tables := [...]checkTable{
tables := [...]checkTable{ // escapes: on arm this causes an allocation.
{
fn: check,
tableID: MangleID,

View File

@@ -29,6 +29,8 @@ const (
)
// NeighborEntry describes a neighboring device in the local network.
//
// +stateify savable
type NeighborEntry struct {
Addr tcpip.Address
LinkAddr tcpip.LinkAddress
@@ -76,17 +78,38 @@ const (
Unreachable
)
// +stateify savable
type timer struct {
// done indicates to the timer that the timer was stopped.
done *bool
timer tcpip.Timer
timer tcpip.Timer `state:"nosave"`
}
// neighborEntryMu bundles an embedded neighborEntryRWMutex with the mutable
// state of a neighbor entry. The mutex, done and onResolve are tagged
// `state:"nosave"`; the remaining fields carry no such tag.
//
// NOTE(review): presumably the embedded mutex guards every other field in this
// struct — confirm against the users of neighborEntry.mu.
//
// +stateify savable
type neighborEntryMu struct {
neighborEntryRWMutex `state:"nosave"`
// neigh is the NeighborEntry value held for this entry.
neigh NeighborEntry
// done is closed when address resolution is complete. It is nil iff s is
// incomplete and resolution is not yet in progress.
//
// NOTE(review): "s" is not defined in this struct; presumably it refers to
// the entry's state — confirm.
done chan struct{} `state:"nosave"`
// onResolve is called with the result of address resolution.
onResolve []func(LinkResolutionResult) `state:"nosave"`
// isRouter presumably records whether the neighbor is known to be a router
// — confirm.
isRouter bool
// timer is the timer value associated with this entry.
timer timer
}
// neighborEntry implements a neighbor entry's individual node behavior, as per
// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in
// parallel with the sending of packets to a neighbor, necessitating the
// entry's lock to be acquired for all operations.
//
// +stateify savable
type neighborEntry struct {
neighborEntryEntry
@@ -95,22 +118,7 @@ type neighborEntry struct {
// nudState points to the Neighbor Unreachability Detection configuration.
nudState *NUDState
mu struct {
neighborEntryRWMutex
neigh NeighborEntry
// done is closed when address resolution is complete. It is nil iff s is
// incomplete and resolution is not yet in progress.
done chan struct{}
// onResolve is called with the result of address resolution.
onResolve []func(LinkResolutionResult)
isRouter bool
timer timer
}
mu neighborEntryMu
}
// newNeighborEntry creates a neighbor cache entry starting at the default

View File

@@ -90,6 +90,10 @@ type nic struct {
// Primary is the main controlling interface in a bonded setup.
Primary *nic
// experimentIPOptionEnabled indicates whether the NIC supports the
// experiment IP option.
experimentIPOptionEnabled bool
}
// makeNICStats initializes the NIC statistics and associates them to the global
@@ -103,7 +107,7 @@ func makeNICStats(global tcpip.NICStats) sharedStats {
// +stateify savable
type packetEndpointList struct {
mu packetEndpointListRWMutex
mu packetEndpointListRWMutex `state:"nosave"`
// eps is protected by mu, but the contained PacketEndpoint values are not.
//
@@ -188,6 +192,7 @@ func newNIC(stack *Stack, id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *nic
duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector),
qDisc: qDisc,
deliverLinkPackets: opts.DeliverLinkPackets,
experimentIPOptionEnabled: opts.EnableExperimentIPOption,
}
nic.linkResQueue.init(nic)
@@ -1095,6 +1100,12 @@ func (n *nic) multicastForwarding(protocol tcpip.NetworkProtocolNumber) (bool, t
return ep.MulticastForwarding(), nil
}
// GetExperimentIPOptionEnabled reports the value of the NIC's
// experimentIPOptionEnabled flag, i.e. whether the NIC is responsible for
// passing the experiment IP option.
func (n *nic) GetExperimentIPOptionEnabled() bool {
	enabled := n.experimentIPOptionEnabled
	return enabled
}
// CoordinatorNIC represents NetworkLinkEndpoint that can join multiple network devices.
type CoordinatorNIC interface {
// AddNIC adds the specified NIC device.

View File

@@ -381,6 +381,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
newPk.Hash = pk.Hash
newPk.Owner = pk.Owner
newPk.GSOOptions = pk.GSOOptions
newPk.EgressRoute = pk.EgressRoute
newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
newPk.dnatDone = pk.dnatDone
newPk.snatDone = pk.snatDone

View File

@@ -33,9 +33,8 @@ type pendingPacket struct {
pkt *PacketBuffer
}
// +stateify savable
type packetsPendingLinkResolutionMu struct {
packetsPendingLinkResolutionMutex `state:"nosave"`
packetsPendingLinkResolutionMutex
// The packets to send once the resolver completes.
//
@@ -56,7 +55,7 @@ type packetsPendingLinkResolutionMu struct {
// +stateify savable
type packetsPendingLinkResolution struct {
nic *nic
mu packetsPendingLinkResolutionMu
mu packetsPendingLinkResolutionMu `state:"nosave"`
}
func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(pkt *PacketBuffer) {
@@ -150,7 +149,7 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, pkt *PacketBuffer) tcpi
packets, ok := f.mu.packets[ch]
packets = append(packets, pendingPacket{
routeInfo: routeInfo,
pkt: pkt.IncRef(),
pkt: pkt.Clone(),
})
if len(packets) > maxPendingPacketsPerResolution {

View File

@@ -162,7 +162,7 @@ type PacketEndpoint interface {
// match the endpoint.
//
// Implementers should treat packet as immutable and should copy it
// before before modification.
// before modification.
//
// linkHeader may have a length of 0, in which case the PacketEndpoint
// should construct its own ethernet header for applications.
@@ -171,6 +171,67 @@ type PacketEndpoint interface {
HandlePacket(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// MappablePacketEndpoint is a packet endpoint that supports forwarding its
// packets to a PacketMMapEndpoint. It embeds PacketEndpoint, so implementers
// must also provide PacketEndpoint's methods.
type MappablePacketEndpoint interface {
PacketEndpoint
// GetPacketMMapOpts returns the options for initializing a PacketMMapEndpoint
// for this endpoint, given the requested ring parameters req. isRx
// presumably selects the receive ring as opposed to the transmit ring —
// confirm against implementers.
GetPacketMMapOpts(req *tcpip.TpacketReq, isRx bool) PacketMMapOpts
// SetPacketMMapEndpoint sets the PacketMMapEndpoint for this endpoint. All
// packets received by this endpoint will be forwarded to the provided
// PacketMMapEndpoint.
SetPacketMMapEndpoint(ep PacketMMapEndpoint)
// GetPacketMMapEndpoint returns the PacketMMapEndpoint for this endpoint or
// nil if there is none.
GetPacketMMapEndpoint() PacketMMapEndpoint
// HandlePacketMMapCopy is called when a received packet is too large for
// the buffer size specified for the memory mapped endpoint. In this case,
// the packet is copied and passed to the original packet endpoint.
HandlePacketMMapCopy(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// PacketMMapOpts are the options for initializing a PacketMMapEndpoint. A
// value of this type is what MappablePacketEndpoint.GetPacketMMapOpts returns.
//
// NOTE(review): the field meanings below are inferred from names and types
// only; confirm against the code that fills in and consumes these options.
//
// +stateify savable
type PacketMMapOpts struct {
// Req and IsRx presumably carry the req and isRx arguments that were
// passed to GetPacketMMapOpts — confirm.
Req *tcpip.TpacketReq
IsRx bool
// Cooked presumably distinguishes cooked from raw packet sockets — confirm.
Cooked bool
Stack *Stack
Stats *tcpip.TransportEndpointStats
Wq *waiter.Queue
NICID tcpip.NICID
NetProto tcpip.NetworkProtocolNumber
// PacketEndpoint is presumably the endpoint these options were obtained
// from — confirm.
PacketEndpoint MappablePacketEndpoint
}
// PacketMMapEndpoint is the interface implemented by endpoints to handle memory
// mapped packets over the packet transport protocol (PACKET_MMAP).
type PacketMMapEndpoint interface {
// HandlePacket is called by the stack when new packets arrive that
// match the endpoint.
//
// Implementers should treat packet as immutable and should copy it
// before modification.
//
// linkHeader may have a length of 0, in which case the PacketEndpoint
// should construct its own ethernet header for applications.
//
// HandlePacket may modify pkt.
//
// NOTE(review): this text mirrors the PacketEndpoint.HandlePacket comment.
// It refers to "linkHeader", which is not a parameter of this method, and
// it says both that the packet should be treated as immutable and that
// HandlePacket may modify pkt — confirm the intended contract.
HandlePacket(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
// Close releases any resources associated with the endpoint.
Close()
// Readiness returns the events that the endpoint is ready for. mask is
// presumably the set of events the caller is interested in — confirm.
Readiness(mask waiter.EventMask) waiter.EventMask
}
// UnknownDestinationPacketDisposition enumerates the possible return values from
// HandleUnknownDestinationPacket().
type UnknownDestinationPacketDisposition int
@@ -244,6 +305,9 @@ type TransportProtocol interface {
// previously paused by Pause.
Resume()
// Restore starts any protocol level background workers during restore.
Restore()
// Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does
// neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() <
// MinimumPacketSize()
@@ -319,6 +383,10 @@ type NetworkHeaderParams struct {
// DF indicates whether the DF bit should be set.
DF bool
// ExperimentOptionValue is a 16 bit value that is set for the IP experiment
// option headers if it is not zero.
ExperimentOptionValue uint16
}
// GroupAddressableEndpoint is an endpoint that supports group addressing.
@@ -1142,7 +1210,7 @@ type NetworkLinkEndpoint interface {
// Close is called when the endpoint is removed from a stack.
Close()
// SetOnCloseAction sets the action that will be exected before closing the
// SetOnCloseAction sets the action that will be executed before closing the
// endpoint. It is used to destroy a network device when its endpoint
// is closed. Endpoints that are closed only after destroying their
// network devices can implement this method as no-op.

View File

@@ -0,0 +1,29 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"context"
"math/rand"
"time"
cryptorand "gvisor.dev/gvisor/pkg/rand"
)
// afterLoad is invoked by stateify.
//
// It installs fresh random number generators on the stack: insecureRNG is a
// math/rand generator seeded from the current wall-clock time in nanoseconds,
// and secureRNG is built from cryptorand.Reader.
func (s *Stack) afterLoad(context.Context) {
	seed := time.Now().UnixNano()
	s.insecureRNG = rand.New(rand.NewSource(seed))
	s.secureRNG = cryptorand.RNGFrom(cryptorand.Reader)
}

View File

@@ -20,11 +20,11 @@
package stack
import (
"context"
"encoding/binary"
"fmt"
"io"
"math/rand"
"sync/atomic"
"time"
"golang.org/x/time/rate"
@@ -90,16 +90,16 @@ type Stack struct {
// routeTable is a list of routes sorted by prefix length, longest (most specific) first.
// +checklocks:routeMu
routeTable tcpip.RouteList
routeTable tcpip.RouteList `state:"nosave"`
mu stackRWMutex `state:"nosave"`
// +checklocks:mu
nics map[tcpip.NICID]*nic
nics map[tcpip.NICID]*nic `state:"nosave"`
// +checklocks:mu
defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{}
// nicIDGen is used to generate NIC IDs.
nicIDGen atomicbitops.Int32
nicIDGen atomicbitops.Int32 `state:"nosave"`
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu cleanupEndpointsMutex `state:"nosave"`
@@ -108,11 +108,6 @@ type Stack struct {
*ports.PortManager
// If not nil, then any new endpoints will have this probe function
// invoked everytime they receive a TCP segment.
// TODO(b/341946753): Restore them when netstack is savable.
tcpProbeFunc atomic.Value `state:"nosave"` // TCPProbeFunc
// clock is used to generate user-visible times.
clock tcpip.Clock
@@ -150,11 +145,9 @@ type Stack struct {
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required. It must not be used in
// security-sensitive contexts.
// TODO(b/341946753): Restore them when netstack is savable.
insecureRNG *rand.Rand `state:"nosave"`
// secureRNG is a cryptographically secure random number generator.
// TODO(b/341946753): Restore them when netstack is savable.
secureRNG cryptorand.RNG `state:"nosave"`
// sendBufferSize holds the min/default/max send buffer sizes for
@@ -180,6 +173,9 @@ type Stack struct {
// tsOffsetSecret is the secret key for generating timestamp offsets
// initialized at stack startup.
tsOffsetSecret uint32
// saveRestoreEnabled indicates whether the stack is saved and restored.
saveRestoreEnabled bool
}
// NetworkProtocolFactory instantiates a network protocol.
@@ -779,23 +775,27 @@ func (s *Stack) addRouteLocked(route *tcpip.Route) {
s.routeTable.PushBack(route)
}
// RemoveRoutes removes matching routes from the route table.
func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) {
// RemoveRoutes removes matching routes from the route table, it
// returns the number of routes that are removed.
func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) int {
s.routeMu.Lock()
defer s.routeMu.Unlock()
s.removeRoutesLocked(match)
return s.removeRoutesLocked(match)
}
// +checklocks:s.routeMu
func (s *Stack) removeRoutesLocked(match func(tcpip.Route) bool) {
func (s *Stack) removeRoutesLocked(match func(tcpip.Route) bool) int {
count := 0
for route := s.routeTable.Front(); route != nil; {
next := route.Next()
if match(*route) {
s.routeTable.Remove(route)
count++
}
route = next
}
return count
}
// ReplaceRoute replaces the route in the routing table which matchse
@@ -878,6 +878,10 @@ type NICOptions struct {
// DeliverLinkPackets specifies whether the NIC is responsible for
// delivering raw packets to packet sockets.
DeliverLinkPackets bool
// EnableExperimentIPOption specifies whether the NIC is responsible for
// passing the experiment IP option.
EnableExperimentIPOption bool
}
// GetNICByID return a network device associated with the specified ID.
@@ -1049,7 +1053,10 @@ func (s *Stack) SetNICCoordinator(id tcpip.NICID, mid tcpip.NICID) tcpip.Error {
if !ok {
return &tcpip.ErrUnknownNICID{}
}
// Setting a coordinator for a coordinator NIC is not allowed.
if _, ok := nic.NetworkLinkEndpoint.(CoordinatorNIC); ok {
return &tcpip.ErrNoSuchFile{}
}
m, ok := s.nics[mid]
if !ok {
return &tcpip.ErrUnknownNICID{}
@@ -1959,6 +1966,36 @@ func (s *Stack) Pause() {
}
}
// getNICs returns the stack's NIC map, read while holding s.mu for reading.
//
// The map itself is returned, not a copy, so callers use it after the read
// lock has been released.
func (s *Stack) getNICs() map[tcpip.NICID]*nic {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.nics
}
// ReplaceConfig replaces config in the loaded stack.
//
// It copies st's route table into s, rebuilds s.nics from the NICs of st
// (pointing each NIC's stack field at s), and takes over st.tables. It panics
// if st is nil.
func (s *Stack) ReplaceConfig(st *Stack) {
	if st == nil {
		panic("stack.Stack cannot be nil when netstack s/r is enabled")
	}

	// The route table is replaced before s.mu is acquired.
	s.SetRouteTable(st.GetRouteTable())

	// Fetch the source NIC map before taking s.mu for writing.
	srcNICs := st.getNICs()

	s.mu.Lock()
	defer s.mu.Unlock()

	s.nics = make(map[tcpip.NICID]*nic)
	for nicID, n := range srcNICs {
		n.stack = s
		s.nics[nicID] = n
		// The returned ID is discarded; presumably the call only advances the
		// NIC ID generator once per adopted NIC — confirm.
		_ = s.NextNICID()
	}
	s.tables = st.tables
}
// Restore restarts the stack after a restore. This must be called after the
// entire system has been restored.
func (s *Stack) Restore() {
@@ -1967,13 +2004,18 @@ func (s *Stack) Restore() {
s.mu.Lock()
eps := s.restoredEndpoints
s.restoredEndpoints = nil
saveRestoreEnabled := s.saveRestoreEnabled
s.mu.Unlock()
for _, e := range eps {
e.Restore(s)
}
// Now resume any protocol level background workers.
for _, p := range s.transportProtocols {
p.proto.Resume()
if saveRestoreEnabled {
p.proto.Restore()
} else {
p.proto.Resume()
}
}
}
@@ -2102,41 +2144,6 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra
return nil
}
// AddTCPProbe installs a probe function that will be invoked on every segment
// received by a given TCP endpoint. The probe function is passed a copy of the
// TCP endpoint state before and after processing of the segment.
//
// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
// created prior to this call will not call the probe function.
//
// Further, installing two different probes back to back can result in some
// endpoints calling the first one and some the second one. There is no
// guarantee provided on which probe will be invoked. Ideally this should only
// be called once per stack.
func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
s.tcpProbeFunc.Store(probe)
}
// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
// otherwise.
func (s *Stack) GetTCPProbe() TCPProbeFunc {
p := s.tcpProbeFunc.Load()
if p == nil {
return nil
}
return p.(TCPProbeFunc)
}
// RemoveTCPProbe removes an installed TCP probe.
//
// NOTE: This only ensures that endpoints created after this call do not
// have a probe attached. Endpoints already created will continue to invoke
// TCP probe.
func (s *Stack) RemoveTCPProbe() {
// This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics.
s.tcpProbeFunc.Store(TCPProbeFunc(nil))
}
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error {
s.mu.RLock()
@@ -2399,3 +2406,32 @@ func (s *Stack) SetNICStack(id tcpip.NICID, peer *Stack) (tcpip.NICID, tcpip.Err
id = tcpip.NICID(peer.NextNICID())
return id, peer.CreateNICWithOptions(id, ne, NICOptions{Name: nic.Name()})
}
// EnableSaveRestore sets the stack's saveRestoreEnabled flag to true while
// holding s.mu for writing.
func (s *Stack) EnableSaveRestore() {
	s.mu.Lock()
	s.saveRestoreEnabled = true
	s.mu.Unlock()
}
// IsSaveRestoreEnabled returns true if save restore is enabled for the stack.
//
// The flag is only read here, so s.mu is taken for reading rather than
// writing; concurrent callers do not serialize behind each other.
func (s *Stack) IsSaveRestoreEnabled() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.saveRestoreEnabled
}
// contextID is this package's type for context.Context.Value keys. Using a
// distinct unexported type keeps these keys from comparing equal to keys
// defined by other packages.
type contextID int
const (
// CtxRestoreStack is a Context.Value key for the stack to be used in restore.
// RestoreStackFromContext looks the stack up under this key.
CtxRestoreStack contextID = iota
)
// RestoreStackFromContext returns the stack to be used during restore.
//
// The stack is looked up in ctx under CtxRestoreStack. If ctx carries no value
// for that key, or the value is not a *Stack, nil is returned instead of
// panicking on the type assertion; callers that require a stack must check
// for nil.
func RestoreStackFromContext(ctx context.Context) *Stack {
	st, _ := ctx.Value(CtxRestoreStack).(*Stack)
	return st
}

File diff suppressed because it is too large Load Diff

View File

@@ -35,7 +35,6 @@ import (
"io"
"math"
"math/bits"
"math/rand"
"net"
"reflect"
"strconv"
@@ -43,6 +42,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -1185,6 +1185,19 @@ func (*ICMPv6Filter) isGettableSocketOption() {}
func (*ICMPv6Filter) isSettableSocketOption() {}
// TpacketReq is the tpacket_req structure as described in
// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
//
// The field descriptions below follow that document's description of the
// corresponding tp_* members.
//
// +stateify savable
type TpacketReq struct {
// TpBlockSize corresponds to tp_block_size: minimal size of a contiguous
// block.
TpBlockSize uint32
// TpBlockNr corresponds to tp_block_nr: number of blocks.
TpBlockNr uint32
// TpFrameSize corresponds to tp_frame_size: size of a frame.
TpFrameSize uint32
// TpFrameNr corresponds to tp_frame_nr: total number of frames.
TpFrameNr uint32
}
// isSettableSocketOption is a marker method with an empty body; presumably it
// makes *TpacketReq satisfy the package's settable-socket-option interface —
// confirm.
func (*TpacketReq) isSettableSocketOption() {}
// EndpointState represents the state of an endpoint.
type EndpointState uint8
@@ -1981,6 +1994,10 @@ type IPForwardingStats struct {
// successfully forwarded.
Errors *StatCounter
// OutgoingDeviceClosedForSend is the number of packets that were dropped due
// to the outgoing device being closed for send.
OutgoingDeviceClosedForSend *StatCounter
// LINT.ThenChange(network/internal/ip/stats.go:MultiCounterIPForwardingStats)
}

View File

@@ -932,6 +932,27 @@ func (e *ErrMulticastInputCannotBeOutput) afterLoad(context.Context) {}
func (e *ErrMulticastInputCannotBeOutput) StateLoad(ctx context.Context, stateSourceObject state.Source) {
}
// StateTypeName returns the fixed type-name string used for ErrEndpointBusy.
func (e *ErrEndpointBusy) StateTypeName() string {
return "pkg/tcpip.ErrEndpointBusy"
}
// StateFields returns an empty list: no ErrEndpointBusy fields are saved.
func (e *ErrEndpointBusy) StateFields() []string {
return []string{}
}
// beforeSave is a no-op hook invoked at the start of StateSave.
func (e *ErrEndpointBusy) beforeSave() {}
// StateSave only calls beforeSave; there are no fields to write.
//
// +checklocksignore
func (e *ErrEndpointBusy) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
}
// afterLoad is a no-op; StateLoad below does not invoke it.
func (e *ErrEndpointBusy) afterLoad(context.Context) {}
// StateLoad does nothing; there are no fields to read back.
//
// +checklocksignore
func (e *ErrEndpointBusy) StateLoad(ctx context.Context, stateSourceObject state.Source) {
}
func (l *RouteList) StateTypeName() string {
return "pkg/tcpip.RouteList"
}
@@ -1078,6 +1099,7 @@ func (so *SocketOptions) StateFields() []string {
"receiveBufferSize",
"linger",
"rcvlowat",
"experimentOptionValue",
}
}
@@ -1114,6 +1136,7 @@ func (so *SocketOptions) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(25, &so.receiveBufferSize)
stateSinkObject.Save(26, &so.linger)
stateSinkObject.Save(27, &so.rcvlowat)
stateSinkObject.Save(28, &so.experimentOptionValue)
}
func (so *SocketOptions) afterLoad(context.Context) {}
@@ -1148,6 +1171,7 @@ func (so *SocketOptions) StateLoad(ctx context.Context, stateSourceObject state.
stateSourceObject.Load(25, &so.receiveBufferSize)
stateSourceObject.Load(26, &so.linger)
stateSourceObject.Load(27, &so.rcvlowat)
stateSourceObject.Load(28, &so.experimentOptionValue)
}
func (l *LocalSockError) StateTypeName() string {
@@ -1644,6 +1668,40 @@ func (f *ICMPv6Filter) StateLoad(ctx context.Context, stateSourceObject state.So
stateSourceObject.Load(0, &f.DenyType)
}
// StateTypeName returns the fixed type-name string used for TpacketReq.
func (t *TpacketReq) StateTypeName() string {
return "pkg/tcpip.TpacketReq"
}
// StateFields returns the names of the TpacketReq fields that are saved. The
// position of each name matches the index passed to Save and Load in
// StateSave and StateLoad.
func (t *TpacketReq) StateFields() []string {
return []string{
"TpBlockSize",
"TpBlockNr",
"TpFrameSize",
"TpFrameNr",
}
}
// beforeSave is a no-op hook invoked at the start of StateSave.
func (t *TpacketReq) beforeSave() {}
// StateSave calls beforeSave and then writes every field listed in
// StateFields to stateSinkObject under its index.
//
// +checklocksignore
func (t *TpacketReq) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.TpBlockSize)
stateSinkObject.Save(1, &t.TpBlockNr)
stateSinkObject.Save(2, &t.TpFrameSize)
stateSinkObject.Save(3, &t.TpFrameNr)
}
// afterLoad is a no-op; StateLoad below does not invoke it.
func (t *TpacketReq) afterLoad(context.Context) {}
// StateLoad reads every field listed in StateFields back from
// stateSourceObject, using the same indices as StateSave. ctx is unused here.
//
// +checklocksignore
func (t *TpacketReq) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.TpBlockSize)
stateSourceObject.Load(1, &t.TpBlockNr)
stateSourceObject.Load(2, &t.TpFrameSize)
stateSourceObject.Load(3, &t.TpFrameNr)
}
func (l *LingerOption) StateTypeName() string {
return "pkg/tcpip.LingerOption"
}
@@ -2362,6 +2420,7 @@ func (i *IPForwardingStats) StateFields() []string {
"NoMulticastPendingQueueBufferSpace",
"OutgoingDeviceNoBufferSpace",
"Errors",
"OutgoingDeviceClosedForSend",
}
}
@@ -2383,6 +2442,7 @@ func (i *IPForwardingStats) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(10, &i.NoMulticastPendingQueueBufferSpace)
stateSinkObject.Save(11, &i.OutgoingDeviceNoBufferSpace)
stateSinkObject.Save(12, &i.Errors)
stateSinkObject.Save(13, &i.OutgoingDeviceClosedForSend)
}
func (i *IPForwardingStats) afterLoad(context.Context) {}
@@ -2402,6 +2462,7 @@ func (i *IPForwardingStats) StateLoad(ctx context.Context, stateSourceObject sta
stateSourceObject.Load(10, &i.NoMulticastPendingQueueBufferSpace)
stateSourceObject.Load(11, &i.OutgoingDeviceNoBufferSpace)
stateSourceObject.Load(12, &i.Errors)
stateSourceObject.Load(13, &i.OutgoingDeviceClosedForSend)
}
func (i *IPStats) StateTypeName() string {
@@ -3230,6 +3291,7 @@ func init() {
state.Register((*ErrWouldBlock)(nil))
state.Register((*ErrMissingRequiredFields)(nil))
state.Register((*ErrMulticastInputCannotBeOutput)(nil))
state.Register((*ErrEndpointBusy)(nil))
state.Register((*RouteList)(nil))
state.Register((*RouteEntry)(nil))
state.Register((*sockErrorList)(nil))
@@ -3250,6 +3312,7 @@ func init() {
state.Register((*TCPSendBufferSizeRangeOption)(nil))
state.Register((*TCPReceiveBufferSizeRangeOption)(nil))
state.Register((*ICMPv6Filter)(nil))
state.Register((*TpacketReq)(nil))
state.Register((*LingerOption)(nil))
state.Register((*IPPacketInfo)(nil))
state.Register((*IPv6PacketInfo)(nil))

View File

@@ -57,7 +57,7 @@ type endpoint struct {
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
net network.Endpoint

View File

@@ -36,7 +36,11 @@ func (p *icmpPacket) loadReceivedAt(_ context.Context, nsec int64) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad(ctx context.Context) {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
if e.stack.IsSaveRestoreEnabled() {
e.stack.RegisterRestoredEndpoint(e)
} else {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
}
}
// beforeSave is invoked by stateify.
@@ -50,6 +54,10 @@ func (e *endpoint) Restore(s *stack.Stack) {
e.thaw()
e.net.Resume(s)
if e.stack.IsSaveRestoreEnabled() {
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
return
}
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)

View File

@@ -60,6 +60,7 @@ func (e *endpoint) StateTypeName() string {
func (e *endpoint) StateFields() []string {
return []string{
"DefaultSocketOptionsHandler",
"stack",
"transProto",
"waiterQueue",
"net",
@@ -78,33 +79,35 @@ func (e *endpoint) StateFields() []string {
func (e *endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
stateSinkObject.Save(0, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(1, &e.transProto)
stateSinkObject.Save(2, &e.waiterQueue)
stateSinkObject.Save(3, &e.net)
stateSinkObject.Save(4, &e.stats)
stateSinkObject.Save(5, &e.ops)
stateSinkObject.Save(6, &e.rcvReady)
stateSinkObject.Save(7, &e.rcvList)
stateSinkObject.Save(8, &e.rcvBufSize)
stateSinkObject.Save(9, &e.rcvClosed)
stateSinkObject.Save(10, &e.frozen)
stateSinkObject.Save(11, &e.ident)
stateSinkObject.Save(1, &e.stack)
stateSinkObject.Save(2, &e.transProto)
stateSinkObject.Save(3, &e.waiterQueue)
stateSinkObject.Save(4, &e.net)
stateSinkObject.Save(5, &e.stats)
stateSinkObject.Save(6, &e.ops)
stateSinkObject.Save(7, &e.rcvReady)
stateSinkObject.Save(8, &e.rcvList)
stateSinkObject.Save(9, &e.rcvBufSize)
stateSinkObject.Save(10, &e.rcvClosed)
stateSinkObject.Save(11, &e.frozen)
stateSinkObject.Save(12, &e.ident)
}
// +checklocksignore
func (e *endpoint) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.DefaultSocketOptionsHandler)
stateSourceObject.Load(1, &e.transProto)
stateSourceObject.Load(2, &e.waiterQueue)
stateSourceObject.Load(3, &e.net)
stateSourceObject.Load(4, &e.stats)
stateSourceObject.Load(5, &e.ops)
stateSourceObject.Load(6, &e.rcvReady)
stateSourceObject.Load(7, &e.rcvList)
stateSourceObject.Load(8, &e.rcvBufSize)
stateSourceObject.Load(9, &e.rcvClosed)
stateSourceObject.Load(10, &e.frozen)
stateSourceObject.Load(11, &e.ident)
stateSourceObject.Load(1, &e.stack)
stateSourceObject.Load(2, &e.transProto)
stateSourceObject.Load(3, &e.waiterQueue)
stateSourceObject.Load(4, &e.net)
stateSourceObject.Load(5, &e.stats)
stateSourceObject.Load(6, &e.ops)
stateSourceObject.Load(7, &e.rcvReady)
stateSourceObject.Load(8, &e.rcvList)
stateSourceObject.Load(9, &e.rcvBufSize)
stateSourceObject.Load(10, &e.rcvClosed)
stateSourceObject.Load(11, &e.frozen)
stateSourceObject.Load(12, &e.ident)
stateSourceObject.AfterLoad(func() { e.afterLoad(ctx) })
}

View File

@@ -128,6 +128,9 @@ func (*protocol) Pause() {}
// Resume implements stack.TransportProtocol.Resume.
func (*protocol) Resume() {}
// Restore implements stack.TransportProtocol.Restore. It is a no-op for this
// protocol.
func (*protocol) Restore() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
// Right now, the Parse() method is tied to enabled protocols passed into

View File

@@ -35,7 +35,7 @@ import (
// +stateify savable
type Endpoint struct {
// The following fields must only be set once then never changed.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
ops *tcpip.SocketOptions
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
@@ -53,7 +53,7 @@ type Endpoint struct {
// +checklocks:mu
effectiveNetProto tcpip.NetworkProtocolNumber
// +checklocks:mu
connectedRoute *stack.Route `state:"manual"`
connectedRoute *stack.Route `state:"nosave"`
// +checklocks:mu
multicastMemberships map[multicastMembership]struct{}
// +checklocks:mu
@@ -310,6 +310,13 @@ func (c *WriteContext) newPacketBufferLocked(reserveHdrBytes int, data buffer.Bu
// This matches Linux behaviour:
// https://github.com/torvalds/linux/blob/38d741cb70b/include/net/sock.h#L2519
// https://github.com/torvalds/linux/blob/38d741cb70b/net/core/sock.c#L2588
var expOptVal uint16
if nic, err := c.e.stack.GetNICByID(c.route.OutgoingNIC()); err == nil && nic.GetExperimentIPOptionEnabled() {
expOptVal = c.e.ops.GetExperimentOptionValue()
}
if c.route.NetProto() == header.IPv6ProtocolNumber && expOptVal != 0 {
reserveHdrBytes += header.IPv6ExperimentHdrLength
}
pktSize := int64(reserveHdrBytes) + int64(data.Size())
e.sendBufferSizeInUse += pktSize
@@ -344,10 +351,16 @@ func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool)
return c.route.WriteHeaderIncludedPacket(pkt)
}
var expOptVal uint16
if nic, err := c.e.stack.GetNICByID(c.route.OutgoingNIC()); err == nil && nic.GetExperimentIPOptionEnabled() {
expOptVal = c.e.ops.GetExperimentOptionValue()
}
err := c.route.WritePacket(stack.NetworkHeaderParams{
Protocol: c.e.transProto,
TTL: c.ttl,
TOS: c.tos,
Protocol: c.e.transProto,
TTL: c.ttl,
TOS: c.tos,
ExperimentOptionValue: expOptVal,
}, pkt)
if _, ok := err.(*tcpip.ErrNoBufferSpace); ok {

View File

@@ -14,6 +14,7 @@ func (e *Endpoint) StateTypeName() string {
func (e *Endpoint) StateFields() []string {
return []string{
"stack",
"ops",
"netProto",
"transProto",
@@ -40,48 +41,50 @@ func (e *Endpoint) beforeSave() {}
// +checklocksignore
func (e *Endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
stateSinkObject.Save(0, &e.ops)
stateSinkObject.Save(1, &e.netProto)
stateSinkObject.Save(2, &e.transProto)
stateSinkObject.Save(3, &e.waiterQueue)
stateSinkObject.Save(4, &e.wasBound)
stateSinkObject.Save(5, &e.owner)
stateSinkObject.Save(6, &e.writeShutdown)
stateSinkObject.Save(7, &e.effectiveNetProto)
stateSinkObject.Save(8, &e.multicastMemberships)
stateSinkObject.Save(9, &e.ipv4TTL)
stateSinkObject.Save(10, &e.ipv6HopLimit)
stateSinkObject.Save(11, &e.multicastTTL)
stateSinkObject.Save(12, &e.multicastAddr)
stateSinkObject.Save(13, &e.multicastNICID)
stateSinkObject.Save(14, &e.ipv4TOS)
stateSinkObject.Save(15, &e.ipv6TClass)
stateSinkObject.Save(16, &e.info)
stateSinkObject.Save(17, &e.state)
stateSinkObject.Save(0, &e.stack)
stateSinkObject.Save(1, &e.ops)
stateSinkObject.Save(2, &e.netProto)
stateSinkObject.Save(3, &e.transProto)
stateSinkObject.Save(4, &e.waiterQueue)
stateSinkObject.Save(5, &e.wasBound)
stateSinkObject.Save(6, &e.owner)
stateSinkObject.Save(7, &e.writeShutdown)
stateSinkObject.Save(8, &e.effectiveNetProto)
stateSinkObject.Save(9, &e.multicastMemberships)
stateSinkObject.Save(10, &e.ipv4TTL)
stateSinkObject.Save(11, &e.ipv6HopLimit)
stateSinkObject.Save(12, &e.multicastTTL)
stateSinkObject.Save(13, &e.multicastAddr)
stateSinkObject.Save(14, &e.multicastNICID)
stateSinkObject.Save(15, &e.ipv4TOS)
stateSinkObject.Save(16, &e.ipv6TClass)
stateSinkObject.Save(17, &e.info)
stateSinkObject.Save(18, &e.state)
}
// afterLoad is a no-op that ignores its context.Context argument.
//
// NOTE(review): presumably kept to satisfy the stateify hook naming
// convention — confirm against the generator.
func (e *Endpoint) afterLoad(context.Context) {}
// +checklocksignore
func (e *Endpoint) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.ops)
stateSourceObject.Load(1, &e.netProto)
stateSourceObject.Load(2, &e.transProto)
stateSourceObject.Load(3, &e.waiterQueue)
stateSourceObject.Load(4, &e.wasBound)
stateSourceObject.Load(5, &e.owner)
stateSourceObject.Load(6, &e.writeShutdown)
stateSourceObject.Load(7, &e.effectiveNetProto)
stateSourceObject.Load(8, &e.multicastMemberships)
stateSourceObject.Load(9, &e.ipv4TTL)
stateSourceObject.Load(10, &e.ipv6HopLimit)
stateSourceObject.Load(11, &e.multicastTTL)
stateSourceObject.Load(12, &e.multicastAddr)
stateSourceObject.Load(13, &e.multicastNICID)
stateSourceObject.Load(14, &e.ipv4TOS)
stateSourceObject.Load(15, &e.ipv6TClass)
stateSourceObject.Load(16, &e.info)
stateSourceObject.Load(17, &e.state)
stateSourceObject.Load(0, &e.stack)
stateSourceObject.Load(1, &e.ops)
stateSourceObject.Load(2, &e.netProto)
stateSourceObject.Load(3, &e.transProto)
stateSourceObject.Load(4, &e.waiterQueue)
stateSourceObject.Load(5, &e.wasBound)
stateSourceObject.Load(6, &e.owner)
stateSourceObject.Load(7, &e.writeShutdown)
stateSourceObject.Load(8, &e.effectiveNetProto)
stateSourceObject.Load(9, &e.multicastMemberships)
stateSourceObject.Load(10, &e.ipv4TTL)
stateSourceObject.Load(11, &e.ipv6HopLimit)
stateSourceObject.Load(12, &e.multicastTTL)
stateSourceObject.Load(13, &e.multicastAddr)
stateSourceObject.Load(14, &e.multicastNICID)
stateSourceObject.Load(15, &e.ipv4TOS)
stateSourceObject.Load(16, &e.ipv6TClass)
stateSourceObject.Load(17, &e.info)
stateSourceObject.Load(18, &e.state)
}
func (m *multicastMembership) StateTypeName() string {

View File

@@ -63,7 +63,7 @@ type endpoint struct {
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
waiterQueue *waiter.Queue
cooked bool
ops tcpip.SocketOptions

View File

@@ -43,12 +43,20 @@ func (ep *endpoint) beforeSave() {
// afterLoad is invoked by stateify.
//
// If the loaded ep.stack reports that save/restore is not enabled, ep.stack
// is replaced, while holding ep.mu, with the stack obtained from ctx. The
// endpoint is then registered with ep.stack as a restored endpoint.
//
// NOTE(review): ep.stack is dereferenced before any nil check, so this
// assumes the loaded stack pointer is non-nil — confirm.
func (ep *endpoint) afterLoad(ctx context.Context) {
if !ep.stack.IsSaveRestoreEnabled() {
ep.mu.Lock()
ep.stack = stack.RestoreStackFromContext(ctx)
ep.mu.Unlock()
}
ep.stack.RegisterRestoredEndpoint(ep)
}
// Restore implements tcpip.RestoredEndpoint.Restore.
func (ep *endpoint) Restore(_ *stack.Stack) {
ep.mu.Lock()
defer ep.mu.Unlock()
ep.stack = stack.RestoreStackFromContext(ctx)
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}

View File

@@ -54,6 +54,7 @@ func (ep *endpoint) StateTypeName() string {
func (ep *endpoint) StateFields() []string {
return []string{
"DefaultSocketOptionsHandler",
"stack",
"waiterQueue",
"cooked",
"ops",
@@ -73,35 +74,37 @@ func (ep *endpoint) StateFields() []string {
func (ep *endpoint) StateSave(stateSinkObject state.Sink) {
ep.beforeSave()
stateSinkObject.Save(0, &ep.DefaultSocketOptionsHandler)
stateSinkObject.Save(1, &ep.waiterQueue)
stateSinkObject.Save(2, &ep.cooked)
stateSinkObject.Save(3, &ep.ops)
stateSinkObject.Save(4, &ep.stats)
stateSinkObject.Save(5, &ep.rcvList)
stateSinkObject.Save(6, &ep.rcvBufSize)
stateSinkObject.Save(7, &ep.rcvClosed)
stateSinkObject.Save(8, &ep.rcvDisabled)
stateSinkObject.Save(9, &ep.closed)
stateSinkObject.Save(10, &ep.boundNetProto)
stateSinkObject.Save(11, &ep.boundNIC)
stateSinkObject.Save(12, &ep.lastError)
stateSinkObject.Save(1, &ep.stack)
stateSinkObject.Save(2, &ep.waiterQueue)
stateSinkObject.Save(3, &ep.cooked)
stateSinkObject.Save(4, &ep.ops)
stateSinkObject.Save(5, &ep.stats)
stateSinkObject.Save(6, &ep.rcvList)
stateSinkObject.Save(7, &ep.rcvBufSize)
stateSinkObject.Save(8, &ep.rcvClosed)
stateSinkObject.Save(9, &ep.rcvDisabled)
stateSinkObject.Save(10, &ep.closed)
stateSinkObject.Save(11, &ep.boundNetProto)
stateSinkObject.Save(12, &ep.boundNIC)
stateSinkObject.Save(13, &ep.lastError)
}
// +checklocksignore
func (ep *endpoint) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &ep.DefaultSocketOptionsHandler)
stateSourceObject.Load(1, &ep.waiterQueue)
stateSourceObject.Load(2, &ep.cooked)
stateSourceObject.Load(3, &ep.ops)
stateSourceObject.Load(4, &ep.stats)
stateSourceObject.Load(5, &ep.rcvList)
stateSourceObject.Load(6, &ep.rcvBufSize)
stateSourceObject.Load(7, &ep.rcvClosed)
stateSourceObject.Load(8, &ep.rcvDisabled)
stateSourceObject.Load(9, &ep.closed)
stateSourceObject.Load(10, &ep.boundNetProto)
stateSourceObject.Load(11, &ep.boundNIC)
stateSourceObject.Load(12, &ep.lastError)
stateSourceObject.Load(1, &ep.stack)
stateSourceObject.Load(2, &ep.waiterQueue)
stateSourceObject.Load(3, &ep.cooked)
stateSourceObject.Load(4, &ep.ops)
stateSourceObject.Load(5, &ep.stats)
stateSourceObject.Load(6, &ep.rcvList)
stateSourceObject.Load(7, &ep.rcvBufSize)
stateSourceObject.Load(8, &ep.rcvClosed)
stateSourceObject.Load(9, &ep.rcvDisabled)
stateSourceObject.Load(10, &ep.closed)
stateSourceObject.Load(11, &ep.boundNetProto)
stateSourceObject.Load(12, &ep.boundNIC)
stateSourceObject.Load(13, &ep.lastError)
stateSourceObject.AfterLoad(func() { ep.afterLoad(ctx) })
}

View File

@@ -73,7 +73,7 @@ type endpoint struct {
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool

View File

@@ -16,7 +16,6 @@ package raw
import (
"context"
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,7 +34,11 @@ func (p *rawPacket) loadReceivedAt(_ context.Context, nsec int64) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad(ctx context.Context) {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
if e.stack.IsSaveRestoreEnabled() {
e.stack.RegisterRestoredEndpoint(e)
} else {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
}
}
// beforeSave is invoked by stateify.
@@ -46,16 +49,20 @@ func (e *endpoint) beforeSave() {
// Restore implements tcpip.RestoredEndpoint.Restore.
func (e *endpoint) Restore(s *stack.Stack) {
e.net.Resume(s)
e.setReceiveDisabled(false)
e.net.Resume(s)
if e.stack.IsSaveRestoreEnabled() {
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
return
}
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
if e.associated {
netProto := e.net.NetProto()
if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err))
panic("RegisterRawTransportEndpoint failed during restore")
}
}
}

View File

@@ -60,6 +60,7 @@ func (e *endpoint) StateTypeName() string {
func (e *endpoint) StateFields() []string {
return []string{
"DefaultSocketOptionsHandler",
"stack",
"transProto",
"waiterQueue",
"associated",
@@ -79,35 +80,37 @@ func (e *endpoint) StateFields() []string {
func (e *endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
stateSinkObject.Save(0, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(1, &e.transProto)
stateSinkObject.Save(2, &e.waiterQueue)
stateSinkObject.Save(3, &e.associated)
stateSinkObject.Save(4, &e.net)
stateSinkObject.Save(5, &e.stats)
stateSinkObject.Save(6, &e.ops)
stateSinkObject.Save(7, &e.rcvList)
stateSinkObject.Save(8, &e.rcvBufSize)
stateSinkObject.Save(9, &e.rcvClosed)
stateSinkObject.Save(10, &e.rcvDisabled)
stateSinkObject.Save(11, &e.ipv6ChecksumOffset)
stateSinkObject.Save(12, &e.icmpv6Filter)
stateSinkObject.Save(1, &e.stack)
stateSinkObject.Save(2, &e.transProto)
stateSinkObject.Save(3, &e.waiterQueue)
stateSinkObject.Save(4, &e.associated)
stateSinkObject.Save(5, &e.net)
stateSinkObject.Save(6, &e.stats)
stateSinkObject.Save(7, &e.ops)
stateSinkObject.Save(8, &e.rcvList)
stateSinkObject.Save(9, &e.rcvBufSize)
stateSinkObject.Save(10, &e.rcvClosed)
stateSinkObject.Save(11, &e.rcvDisabled)
stateSinkObject.Save(12, &e.ipv6ChecksumOffset)
stateSinkObject.Save(13, &e.icmpv6Filter)
}
// +checklocksignore
func (e *endpoint) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.DefaultSocketOptionsHandler)
stateSourceObject.Load(1, &e.transProto)
stateSourceObject.Load(2, &e.waiterQueue)
stateSourceObject.Load(3, &e.associated)
stateSourceObject.Load(4, &e.net)
stateSourceObject.Load(5, &e.stats)
stateSourceObject.Load(6, &e.ops)
stateSourceObject.Load(7, &e.rcvList)
stateSourceObject.Load(8, &e.rcvBufSize)
stateSourceObject.Load(9, &e.rcvClosed)
stateSourceObject.Load(10, &e.rcvDisabled)
stateSourceObject.Load(11, &e.ipv6ChecksumOffset)
stateSourceObject.Load(12, &e.icmpv6Filter)
stateSourceObject.Load(1, &e.stack)
stateSourceObject.Load(2, &e.transProto)
stateSourceObject.Load(3, &e.waiterQueue)
stateSourceObject.Load(4, &e.associated)
stateSourceObject.Load(5, &e.net)
stateSourceObject.Load(6, &e.stats)
stateSourceObject.Load(7, &e.ops)
stateSourceObject.Load(8, &e.rcvList)
stateSourceObject.Load(9, &e.rcvBufSize)
stateSourceObject.Load(10, &e.rcvClosed)
stateSourceObject.Load(11, &e.rcvDisabled)
stateSourceObject.Load(12, &e.ipv6ChecksumOffset)
stateSourceObject.Load(13, &e.icmpv6Filter)
stateSourceObject.AfterLoad(func() { e.afterLoad(ctx) })
}

View File

@@ -526,13 +526,14 @@ func (e *Endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
fields := tcpFields{
id: s.id,
ttl: calculateTTL(route, e.ipv4TTL, e.ipv6HopLimit),
tos: e.sendTOS,
flags: header.TCPFlagSyn | header.TCPFlagAck,
seq: cookie,
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
id: s.id,
ttl: calculateTTL(route, e.ipv4TTL, e.ipv6HopLimit),
tos: e.sendTOS,
flags: header.TCPFlagSyn | header.TCPFlagAck,
seq: cookie,
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
expOptVal: e.getExperimentOptionValue(route),
}
if err := e.sendSynTCP(route, fields, synOpts); err != nil {
return err

View File

@@ -30,15 +30,22 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
// InitialRTO is the initial retransmission timeout.
// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142
const InitialRTO = time.Second
const (
// tcpMinTimeout is the minimum timeout for a SYN retransmit.
// This mirrors the TCP_TIMEOUT_MIN variable in Linux.
// See: https://github.com/torvalds/linux/blob/249aca0d3d631660aa3583c6a3559b75b6e971b4/include/net/tcp.h#L143
tcpMinTimeout = 2 * time.Microsecond
// maxSegmentsPerWake is the maximum number of segments to process in the main
// protocol goroutine per wake-up. Yielding [after this number of segments are
// processed] allows other events to be processed as well (e.g., timeouts,
// resets, etc.).
const maxSegmentsPerWake = 100
// InitialRTO is the initial retransmission timeout.
// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142
InitialRTO = time.Second
// maxSegmentsPerWake is the maximum number of segments to process per
// wake-up. Yielding [after this number of segments are processed]
// allows other events to be processed as well (e.g., timeouts, resets,
// etc.).
maxSegmentsPerWake = 100
)
type handshakeState int
@@ -297,6 +304,9 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// <SEQ=SEG.ACK><CTL=RST>
// and send it.
h.ep.sendEmptyRaw(header.TCPFlagRst, s.ackNumber, 0, 0)
// Since this was a challenge ACK reschedule the retransmit timer to fire
// soon so that the SYN is retransmitted quickly.
h.retransmitTimer.reinit(tcpMinTimeout)
return nil
}
@@ -354,13 +364,14 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
ttl = h.ep.route.DefaultTTL()
}
h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.TransportEndpointInfo.ID,
ttl: ttl,
tos: h.ep.sendTOS,
flags: h.flags,
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
id: h.ep.TransportEndpointInfo.ID,
ttl: ttl,
tos: h.ep.sendTOS,
flags: h.flags,
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
expOptVal: h.ep.getExperimentOptionValue(h.ep.route),
}, synOpts)
return nil
}
@@ -440,13 +451,14 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
MSS: h.ep.amss,
}
h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.TransportEndpointInfo.ID,
ttl: calculateTTL(h.ep.route, h.ep.ipv4TTL, h.ep.ipv6HopLimit),
tos: h.ep.sendTOS,
flags: h.flags,
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
id: h.ep.TransportEndpointInfo.ID,
ttl: calculateTTL(h.ep.route, h.ep.ipv4TTL, h.ep.ipv6HopLimit),
tos: h.ep.sendTOS,
flags: h.flags,
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
expOptVal: h.ep.getExperimentOptionValue(h.ep.route),
}, synOpts)
return nil
}
@@ -533,7 +545,7 @@ func (h *handshake) processSegments() tcpip.Error {
// We stop processing packets once the handshake is completed,
// otherwise we may process packets meant to be processed by
// the main protocol goroutine.
// the TCP processor goroutine.
if h.state == handshakeCompleted {
break
}
@@ -577,13 +589,14 @@ func (h *handshake) start() {
h.sendSYNOpts = synOpts
h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.TransportEndpointInfo.ID,
ttl: calculateTTL(h.ep.route, h.ep.ipv4TTL, h.ep.ipv6HopLimit),
tos: h.ep.sendTOS,
flags: h.flags,
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
id: h.ep.TransportEndpointInfo.ID,
ttl: calculateTTL(h.ep.route, h.ep.ipv4TTL, h.ep.ipv6HopLimit),
tos: h.ep.sendTOS,
flags: h.flags,
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
expOptVal: h.ep.getExperimentOptionValue(h.ep.route),
}, synOpts)
}
@@ -613,13 +626,14 @@ func (h *handshake) retransmitHandlerLocked() tcpip.Error {
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && e.stack.Clock().NowMonotonic().Sub(h.startTime) > h.deferAccept {
e.sendSynTCP(e.route, tcpFields{
id: e.TransportEndpointInfo.ID,
ttl: calculateTTL(e.route, e.ipv4TTL, e.ipv6HopLimit),
tos: e.sendTOS,
flags: h.flags,
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
id: e.TransportEndpointInfo.ID,
ttl: calculateTTL(e.route, e.ipv4TTL, e.ipv6HopLimit),
tos: e.sendTOS,
flags: h.flags,
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
expOptVal: e.getExperimentOptionValue(e.route),
}, h.sendSYNOpts)
// If we have ever retransmitted the SYN-ACK or
// SYN segment, we should only measure RTT if
@@ -633,6 +647,7 @@ func (h *handshake) retransmitHandlerLocked() tcpip.Error {
// to an established state given the last segment received from peer. It also
// initializes sender/receiver.
// +checklocks:h.ep.mu
// +checklocksalias:h.ep.snd.ep.mu=h.ep.mu
func (h *handshake) transitionToStateEstablishedLocked(s *segment) {
// Stop the SYN retransmissions now that handshake is complete.
if h.retransmitTimer != nil {
@@ -701,6 +716,11 @@ func (bt *backoffTimer) reset() tcpip.Error {
return nil
}
// reinit overwrites the timer's stored timeout with the given duration and
// calls Reset on the underlying timer with that new value.
func (bt *backoffTimer) reinit(timeout time.Duration) {
bt.timeout = timeout
bt.t.Reset(bt.timeout)
}
// stop calls Stop on the underlying timer.
func (bt *backoffTimer) stop() {
bt.t.Stop()
}
@@ -785,22 +805,27 @@ func makeSynOptions(opts header.TCPSynOptions) []byte {
// tcpFields is a struct to carry different parameters required by the
// send*TCP variant functions below.
type tcpFields struct {
id stack.TransportEndpointID
ttl uint8
tos uint8
flags header.TCPFlags
seq seqnum.Value
ack seqnum.Value
rcvWnd seqnum.Size
opts []byte
txHash uint32
df bool
id stack.TransportEndpointID
ttl uint8
tos uint8
flags header.TCPFlags
seq seqnum.Value
ack seqnum.Value
rcvWnd seqnum.Size
opts []byte
txHash uint32
df bool
expOptVal uint16
}
func (e *Endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error {
tf.opts = makeSynOptions(opts)
// We ignore SYN send errors and let the callers re-attempt send.
p := stack.NewPacketBuffer(stack.PacketBufferOptions{ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + len(tf.opts)})
hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + len(tf.opts)
if r.NetProto() == header.IPv6ProtocolNumber && tf.expOptVal != 0 {
hdrSize += header.IPv6ExperimentHdrLength
}
p := stack.NewPacketBuffer(stack.PacketBufferOptions{ReserveHeaderBytes: hdrSize})
defer p.DecRef()
if err := e.sendTCP(r, tf, p, stack.GSO{}); err != nil {
e.stats.SendErrors.SynSendToNetworkFailed.Increment()
@@ -872,6 +897,10 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso sta
// packet already has the truncated data.
shouldSplitPacket := i != n-1
if shouldSplitPacket {
if r.NetProto() == header.IPv6ProtocolNumber && tf.expOptVal != 0 {
// Reserve extra bytes for the experiment option.
hdrSize += header.IPv6ExperimentHdrLength
}
splitPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ReserveHeaderBytes: hdrSize})
splitPkt.Data().ReadFromPacketData(pkt.Data(), packetSize)
pkt = splitPkt
@@ -882,7 +911,13 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso sta
buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
pkt.GSOOptions = gso
if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos, DF: tf.df}, pkt); err != nil {
if err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: ProtocolNumber,
TTL: tf.ttl,
TOS: tf.tos,
DF: tf.df,
ExperimentOptionValue: tf.expOptVal,
}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
if shouldSplitPacket {
pkt.DecRef()
@@ -914,7 +949,13 @@ func sendTCP(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso stack.GS
pkt.Owner = owner
buildTCPHdr(r, tf, pkt, gso)
if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos, DF: tf.df}, pkt); err != nil {
if err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: ProtocolNumber,
TTL: tf.ttl,
TOS: tf.tos,
DF: tf.df,
ExperimentOptionValue: tf.expOptVal,
}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
@@ -986,17 +1027,23 @@ func (e *Endpoint) sendRaw(pkt *stack.PacketBuffer, flags header.TCPFlags, seq,
}
options := e.makeOptions(sackBlocks)
defer putOptions(options)
pkt.ReserveHeaderBytes(header.TCPMinimumSize + int(e.route.MaxHeaderLength()) + len(options))
hdrSize := header.TCPMinimumSize + int(e.route.MaxHeaderLength()) + len(options)
expOptVal := e.getExperimentOptionValue(e.route)
if e.route.NetProto() == header.IPv6ProtocolNumber && expOptVal != 0 {
hdrSize += header.IPv6ExperimentHdrLength
}
pkt.ReserveHeaderBytes(hdrSize)
return e.sendTCP(e.route, tcpFields{
id: e.TransportEndpointInfo.ID,
ttl: calculateTTL(e.route, e.ipv4TTL, e.ipv6HopLimit),
tos: e.sendTOS,
flags: flags,
seq: seq,
ack: ack,
rcvWnd: rcvWnd,
opts: options,
df: e.pmtud == tcpip.PMTUDiscoveryWant || e.pmtud == tcpip.PMTUDiscoveryDo,
id: e.TransportEndpointInfo.ID,
ttl: calculateTTL(e.route, e.ipv4TTL, e.ipv6HopLimit),
tos: e.sendTOS,
flags: flags,
seq: seq,
ack: ack,
rcvWnd: rcvWnd,
opts: options,
df: e.pmtud == tcpip.PMTUDiscoveryWant || e.pmtud == tcpip.PMTUDiscoveryDo,
expOptVal: expOptVal,
}, pkt, e.gso)
}
@@ -1017,9 +1064,9 @@ func (e *Endpoint) sendData(next *segment) {
// resetConnectionLocked puts the endpoint in an error state with the given
// error code and sends a RST if and only if the error is not ErrConnectionReset
// indicating that the connection is being reset due to receiving a RST. This
// method must only be called from the protocol goroutine.
// indicating that the connection is being reset due to receiving a RST.
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) resetConnectionLocked(err tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
@@ -1152,10 +1199,6 @@ func (e *Endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) {
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
// Notify protocol goroutine. This is required when
// handleSegment is invoked from the processor goroutine
// rather than the worker goroutine.
return false, &tcpip.ErrConnectionReset{}
}
}
@@ -1210,7 +1253,7 @@ func (e *Endpoint) handleSegmentsLocked() tcpip.Error {
// +checklocks:e.mu
func (e *Endpoint) probeSegmentLocked() {
if fn := e.probe; fn != nil {
var state stack.TCPEndpointState
var state TCPEndpointState
e.completeStateLocked(&state)
fn(&state)
}
@@ -1277,12 +1320,7 @@ func (e *Endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error)
state := e.EndpointState()
if state == StateClose {
// When we get into StateClose while processing from the queue,
// return immediately and let the protocolMainloop handle it.
//
// We can reach StateClose only while processing a previous segment
// or a notification from the protocolMainLoop (caller goroutine).
// This means that with this return, the segment dequeue below can
// never occur on a closed endpoint.
// return immediately and let the TCP processors handle it.
return false, nil
}
@@ -1336,6 +1374,9 @@ func (e *Endpoint) keepaliveTimerExpired() tcpip.Error {
// resetKeepaliveTimer restarts or stops the keepalive timer, depending on
// whether it is enabled for this endpoint.
//
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) resetKeepaliveTimer(receivedData bool) {
e.keepalive.Lock()
defer e.keepalive.Unlock()

View File

@@ -19,7 +19,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// effectivelyInfinity is an initialization value used for round-trip times
@@ -58,7 +57,7 @@ const (
// See: https://tools.ietf.org/html/rfc8312.
// +stateify savable
type cubicState struct {
stack.TCPCubicState
TCPCubicState
// numCongestionEvents tracks the number of congestion events since last
// RTO.
@@ -69,10 +68,12 @@ type cubicState struct {
// newCubicCC returns a partially initialized cubic state with the constants
// beta and c set and t set to current time.
//
// +checklocks:s.ep.mu
func newCubicCC(s *sender) *cubicState {
now := s.ep.stack.Clock().NowMonotonic()
return &cubicState{
TCPCubicState: stack.TCPCubicState{
TCPCubicState: TCPCubicState{
T: now,
Beta: 0.7,
C: 0.4,
@@ -94,6 +95,8 @@ func newCubicCC(s *sender) *cubicState {
// previously lowered ssThresh without experiencing packet loss.
//
// Refer: https://tools.ietf.org/html/rfc8312#section-4.8
//
// +checklocks:c.s.ep.mu
func (c *cubicState) enterCongestionAvoidance() {
// See: https://tools.ietf.org/html/rfc8312#section-4.7 &
// https://tools.ietf.org/html/rfc8312#section-4.8
@@ -117,6 +120,8 @@ func (c *cubicState) enterCongestionAvoidance() {
// increase'). The RFC version includes only the latter algorithm and adds an
// intermediate phase called Conservative Slow Start, which is not implemented
// here.
//
// +checklocks:c.s.ep.mu
func (c *cubicState) updateHyStart(rtt time.Duration) {
if rtt < 0 {
// negative indicates unknown
@@ -152,6 +157,7 @@ func (c *cubicState) updateHyStart(rtt time.Duration) {
}
}
// +checklocks:c.s.ep.mu
func (c *cubicState) beginHyStartRound(now tcpip.MonotonicTime) {
c.EndSeq = c.s.SndNxt
c.SampleCount = 0
@@ -165,6 +171,8 @@ func (c *cubicState) beginHyStartRound(now tcpip.MonotonicTime) {
// algorithm used by NewReno. If after adjusting the congestion window we cross
// the ssThresh then it will return the number of packets that must be consumed
// in congestion avoidance mode.
//
// +checklocks:c.s.ep.mu
func (c *cubicState) updateSlowStart(packetsAcked int) int {
// Don't let the congestion window cross into the congestion
// avoidance range.
@@ -187,6 +195,8 @@ func (c *cubicState) updateSlowStart(packetsAcked int) int {
// Update updates cubic's internal state variables. It must be called on every
// ACK received.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
//
// +checklocks:c.s.ep.mu
func (c *cubicState) Update(packetsAcked int, rtt time.Duration) {
if c.s.Ssthresh == InitialSsthresh && c.s.SndCwnd < c.s.Ssthresh {
c.updateHyStart(rtt)
@@ -247,6 +257,8 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int
}
// HandleLossDetected implements congestionControl.HandleLossDetected.
//
// +checklocks:c.s.ep.mu
func (c *cubicState) HandleLossDetected() {
// See: https://tools.ietf.org/html/rfc8312#section-4.5
c.numCongestionEvents++
@@ -259,6 +271,8 @@ func (c *cubicState) HandleLossDetected() {
}
// HandleRTOExpired implements congestionControl.HandleRTOExpired.
//
// +checklocks:c.s.ep.mu
func (c *cubicState) HandleRTOExpired() {
// See: https://tools.ietf.org/html/rfc8312#section-4.6
c.T = c.s.ep.stack.Clock().NowMonotonic()
@@ -297,6 +311,8 @@ func (c *cubicState) PostRecovery() {
// reduceSlowStartThreshold sets the sender's Ssthresh as described in
// https://tools.ietf.org/html/rfc8312#section-4.7: the current SndCwnd scaled
// by Beta, with a floor of 2, converted to int. It does not return a value;
// the result is stored in c.s.Ssthresh.
//
// +checklocks:c.s.ep.mu
func (c *cubicState) reduceSlowStartThreshold() {
c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0))
}

View File

@@ -78,9 +78,8 @@ func (q *epQueue) empty() bool {
//
// +stateify savable
type processor struct {
epQ epQueue
sleeper sleep.Sleeper
// TODO(b/341946753): Restore them when netstack is savable.
epQ epQueue
sleeper sleep.Sleeper `state:"nosave"`
newEndpointWaker sleep.Waker `state:"nosave"`
closeWaker sleep.Waker `state:"nosave"`
pauseWaker sleep.Waker `state:"nosave"`
@@ -381,9 +380,18 @@ func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
d.mu.Lock()
defer d.mu.Unlock()
d.closed = false
d.processors = make([]processor, nProcessors)
d.hasher = jenkinsHasher{seed: rng.Uint32()}
d.startLocked()
}
// +checklocks:d.mu
func (d *dispatcher) startLocked() {
if d.closed {
return
}
for i := range d.processors {
p := &d.processors[i]
p.sleeper.AddWaker(&p.newEndpointWaker)
@@ -399,6 +407,13 @@ func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
}
}
// start acquires d.mu and calls startLocked, releasing the lock on return.
func (d *dispatcher) start() {
d.mu.Lock()
defer d.mu.Unlock()
d.startLocked()
}
// close closes a dispatcher and its processors.
func (d *dispatcher) close() {
d.mu.Lock()

View File

@@ -286,15 +286,11 @@ func (*Stats) IsEndpointStats() {}
// +stateify savable
type sndQueueInfo struct {
sndQueueMu sync.Mutex `state:"nosave"`
stack.TCPSndBufState
// sndWaker is used to signal the protocol goroutine when there may be
// segments that need to be sent.
sndWaker sleep.Waker `state:"manual"`
TCPSndBufState
}
// CloneState clones sq into other. It is not thread safe
func (sq *sndQueueInfo) CloneState(other *stack.TCPSndBufState) {
func (sq *sndQueueInfo) CloneState(other *TCPSndBufState) {
other.SndBufSize = sq.SndBufSize
other.SndBufUsed = sq.SndBufUsed
other.SndClosed = sq.SndClosed
@@ -342,9 +338,12 @@ func (sq *sndQueueInfo) CloneState(other *stack.TCPSndBufState) {
// For more details please see the detailed documentation on
// e.LockUser/e.UnlockUser methods.
//
// TODO(b/339664055): Checklocks should be used more extensively here. Coverage
// is currently sparse.
//
// +stateify savable
type Endpoint struct {
stack.TCPEndpointStateInner
TCPEndpointStateInner
stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
@@ -364,8 +363,8 @@ type Endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
protocol *protocol `state:"manual"`
stack *stack.Stack
protocol *protocol
waiterQueue *waiter.Queue `state:"wait"`
// hardError is meaningful only when state is stateError. It stores the
@@ -381,7 +380,7 @@ type Endpoint struct {
rcvQueueMu sync.Mutex `state:"nosave"`
// +checklocks:rcvQueueMu
stack.TCPRcvBufState
TCPRcvBufState
// rcvMemUsed tracks the total amount of memory in use by received segments
// held in rcvQueue, pendingRcvdSegments and the segment queue. This is used to
@@ -416,10 +415,10 @@ type Endpoint struct {
// state.
origEndpointState uint32 `state:"nosave"`
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
isPortReserved bool
isRegistered bool
boundNICID tcpip.NICID
route *stack.Route `state:"manual"`
route *stack.Route `state:"nosave"`
ipv4TTL uint8
ipv6HopLimit int16
isConnectNotified bool
@@ -524,10 +523,9 @@ type Endpoint struct {
// +checklocks:acceptMu
acceptQueue acceptQueue
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
rcv *receiver `state:"wait"`
snd *sender `state:"wait"`
snd *sender `state:"wait"`
// The goroutine drain completion notification channel.
drainDone chan struct{} `state:"nosave"`
@@ -541,7 +539,7 @@ type Endpoint struct {
// probe if not nil is invoked on every received segment. It is passed
// a copy of the current state of the endpoint.
probe stack.TCPProbeFunc `state:"nosave"`
probe TCPProbeFunc `state:"nosave"`
// The following are only used to assist the restore run to re-connect.
connectingAddress tcpip.Address
@@ -638,6 +636,7 @@ func (e *Endpoint) isOwnedByUser() bool {
// should not be holding the lock for long and spinning reduces latency as we
// avoid an expensive sleep/wakeup of the syscall goroutine).
// +checklocksacquire:e.mu
// +checklocksacquire:e.snd.ep.mu
func (e *Endpoint) LockUser() {
const iterations = 5
for i := 0; i < iterations; i++ {
@@ -650,14 +649,14 @@ func (e *Endpoint) LockUser() {
if e.ownedByUser.Load() == 1 {
e.mu.Lock()
e.ownedByUser.Store(1)
return
return // +checklocksforce: this locks e.snd.ep.mu
}
// Spin but don't yield the processor since the lower half
// should yield the lock soon.
continue
}
e.ownedByUser.Store(1)
return
return // +checklocksforce: this locks e.snd.ep.mu
}
for i := 0; i < iterations; i++ {
@@ -670,7 +669,7 @@ func (e *Endpoint) LockUser() {
if e.ownedByUser.Load() == 1 {
e.mu.Lock()
e.ownedByUser.Store(1)
return
return // +checklocksforce: this locks e.snd.ep.mu
}
// Spin but yield the processor since the lower half
// should yield the lock soon.
@@ -678,7 +677,7 @@ func (e *Endpoint) LockUser() {
continue
}
e.ownedByUser.Store(1)
return
return // +checklocksforce: this locks e.snd.ep.mu
}
// Finally just give up and wait for the Lock.
@@ -849,7 +848,7 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto
TransProto: header.TCPProtocolNumber,
},
sndQueueInfo: sndQueueInfo{
TCPSndBufState: stack.TCPSndBufState{
TCPSndBufState: TCPSndBufState{
SndMTU: math.MaxInt32,
},
},
@@ -910,10 +909,7 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto
e.maxSynRetries = uint8(synRetries)
}
if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
e.probe = protocol.probe
e.segmentQueue.ep = e
// TODO(https://gvisor.dev/issues/7493): Defer creating the timer until TCP connection becomes
@@ -1014,21 +1010,19 @@ func (e *Endpoint) purgeReadQueue() {
}
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) purgeWriteQueue() {
if e.snd != nil {
e.sndQueueInfo.sndQueueMu.Lock()
defer e.sndQueueInfo.sndQueueMu.Unlock()
e.snd.updateWriteNext(nil)
for {
s := e.snd.writeList.Front()
if s == nil {
break
}
for s := e.snd.writeList.Front(); s != nil; s = e.snd.writeList.Front() {
e.snd.writeList.Remove(s)
s.DecRef()
}
e.sndQueueInfo.SndBufUsed = 0
e.sndQueueInfo.SndClosed = true
e.snd.SndNxt = e.snd.SndUna
}
}
@@ -1450,8 +1444,7 @@ func (e *Endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
if memDelta > 0 {
// If the window was small before this read and if the read freed up
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
// (whichever smaller), then send a window update.
if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above {
sendNonZeroWindowUpdate = true
}
@@ -1595,6 +1588,7 @@ func (e *Endpoint) readFromPayloader(p tcpip.Payloader, opts tcpip.WriteOptions,
// queueSegment reads data from the payloader and returns a segment to be sent.
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) queueSegment(p tcpip.Payloader, opts tcpip.WriteOptions) (*segment, int, tcpip.Error) {
e.sndQueueInfo.sndQueueMu.Lock()
defer e.sndQueueInfo.sndQueueMu.Unlock()
@@ -2385,6 +2379,7 @@ func (e *Endpoint) registerEndpoint(addr tcpip.FullAddress, netProto tcpip.Netwo
// connect connects the endpoint to its peer.
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) connect(addr tcpip.FullAddress, handshake bool) tcpip.Error {
connectingAddr := addr.Addr
@@ -2472,14 +2467,12 @@ func (e *Endpoint) connect(addr tcpip.FullAddress, handshake bool) tcpip.Error {
// connection setting here.
if !handshake {
e.segmentQueue.mu.Lock()
for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList} {
for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.TransportEndpointInfo.ID
e.sndQueueInfo.sndWaker.Assert()
}
}
e.segmentQueue.mu.Unlock()
e.snd.ep.AssertLockHeld(e)
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
e.setEndpointState(StateEstablished)
// Set the new auto tuned send buffer size after entering
@@ -2522,6 +2515,7 @@ func (e *Endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
}
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.shutdownFlags |= flags
switch {
@@ -2962,8 +2956,11 @@ func (e *Endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
}
}
// updateSndBufferUsage is called by the protocol goroutine when room opens up
// in the send buffer. The number of newly available bytes is v.
// updateSndBufferUsage is called when room opens up in the send buffer. The
// number of newly available bytes is v.
//
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) updateSndBufferUsage(v int) {
sendBufferSize := e.getSendBufferSize()
e.sndQueueInfo.sndQueueMu.Lock()
@@ -2987,9 +2984,8 @@ func (e *Endpoint) updateSndBufferUsage(v int) {
}
}
// readyToRead is called by the protocol goroutine when a new segment is ready
// to be read, or when the connection is closed for receiving (in which case
// s will be nil).
// readyToRead is called when a new segment is ready to be read, or when the
// connection is closed for receiving (in which case s will be nil).
//
// +checklocks:e.mu
func (e *Endpoint) readyToRead(s *segment) {
@@ -3146,9 +3142,10 @@ func (e *Endpoint) maxOptionSize() (size int) {
// used before invoking the probe.
//
// +checklocks:e.mu
func (e *Endpoint) completeStateLocked(s *stack.TCPEndpointState) {
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) completeStateLocked(s *TCPEndpointState) {
s.TCPEndpointStateInner = e.TCPEndpointStateInner
s.ID = stack.TCPEndpointID(e.TransportEndpointInfo.ID)
s.ID = TCPEndpointID(e.TransportEndpointInfo.ID)
s.SegTime = e.stack.Clock().NowMonotonic()
s.Receiver = e.rcv.TCPReceiverState
s.Sender = e.snd.TCPSenderState
@@ -3295,6 +3292,9 @@ func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOpti
// computeTCPSendBufferSize implements auto tuning of send buffer size and
// returns the new send buffer size.
//
// +checklocks:e.mu
// +checklocksalias:e.snd.ep.mu=e.mu
func (e *Endpoint) computeTCPSendBufferSize() int64 {
curSndBufSz := int64(e.getSendBufferSize())
@@ -3330,3 +3330,12 @@ func (e *Endpoint) computeTCPSendBufferSize() int64 {
func (e *Endpoint) GetAcceptConn() bool {
return EndpointState(e.State()) == StateListen
}
// getExperimentOptionValue reports the experiment option value configured on
// the endpoint's socket options, but only when the route's outgoing NIC can be
// looked up on the stack and that NIC has experiment IP options enabled. In
// every other case (lookup failure or option disabled) it returns 0.
func (e *Endpoint) getExperimentOptionValue(route *stack.Route) uint16 {
	nic, err := e.stack.GetNICByID(route.OutgoingNIC())
	if err != nil || !nic.GetExperimentIPOptionEnabled() {
		return 0
	}
	return e.ops.GetExperimentOptionValue()
}

View File

@@ -19,13 +19,24 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// logDisconnectOnce guards the "connections terminated during save" log line
// so that terminating many connections does not spam the log.
var logDisconnectOnce sync.Once

// logDisconnect records, through logDisconnectOnce, that at least one TCP
// connection was terminated during save.
func logDisconnect() {
	report := func() {
		log.Infof("One or more TCP connections terminated during save")
	}
	logDisconnectOnce.Do(report)
}
// beforeSave is invoked by stateify.
func (e *Endpoint) beforeSave() {
// Stop incoming packets.
@@ -44,6 +55,7 @@ func (e *Endpoint) beforeSave() {
Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort),
})
}
logDisconnect()
e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
e.mu.Unlock()
e.Close()
@@ -95,8 +107,9 @@ var connectingLoading sync.WaitGroup
func (e *Endpoint) loadState(_ context.Context, epState EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
// For restore purposes we treat TimeWait like a connected endpoint.
if epState.connected() || epState == StateTimeWait {
// For restore purposes we treat all endpoints with a state after
// StateEstablished and before StateClosed like connected endpoints.
if epState.connected() {
connectedLoading.Add(1)
}
switch {
@@ -118,7 +131,11 @@ func (e *Endpoint) afterLoad(ctx context.Context) {
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Restore.
e.state = atomicbitops.FromUint32(uint32(StateInitial))
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
if e.stack.IsSaveRestoreEnabled() {
e.stack.RegisterRestoredEndpoint(e)
} else {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
}
}
// Restore implements tcpip.RestoredEndpoint.Restore.
@@ -132,29 +149,34 @@ func (e *Endpoint) Restore(s *stack.Stack) {
snd.probeTimer.init(s.Clock(), timerHandler(e, e.snd.probeTimerExpired))
snd.corkTimer.init(s.Clock(), timerHandler(e, e.snd.corkTimerExpired))
}
e.stack = s
e.protocol = protocolFromStack(s)
saveRestoreEnabled := e.stack.IsSaveRestoreEnabled()
if !saveRestoreEnabled {
e.stack = s
e.protocol = protocolFromStack(s)
}
e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
bind := func() {
e.mu.Lock()
defer e.mu.Unlock()
addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}, true /* bind */)
if err != nil {
panic("unable to parse BindAddr: " + err.String())
}
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
Addr: addr.Addr,
Port: addr.Port,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: e.boundDest,
}
if ok := e.stack.ReserveTuple(portRes); !ok {
panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
if !saveRestoreEnabled {
addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}, true /* bind */)
if err != nil {
panic("unable to parse BindAddr: " + err.String())
}
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
Addr: addr.Addr,
Port: addr.Port,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: e.boundDest,
}
if ok := e.stack.ReserveTuple(portRes); !ok {
panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
}
e.isPortReserved = true
@@ -182,6 +204,10 @@ func (e *Endpoint) Restore(s *stack.Stack) {
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
if saveRestoreEnabled {
// Unregister the endpoint before registering again during Connect.
e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, header.TCPProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice)
}
e.mu.Lock()
err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false /* handshake */)
if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
@@ -206,23 +232,39 @@ func (e *Endpoint) Restore(s *stack.Stack) {
connectedLoading.Done()
case epState == StateListen:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
bind()
e.acceptMu.Lock()
backlog := e.acceptQueue.capacity
e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
e.LockUser()
if e.shutdownFlags != 0 {
e.shutdownLocked(e.shutdownFlags)
}
e.UnlockUser()
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
if !saveRestoreEnabled {
go func() {
connectedLoading.Wait()
bind()
e.acceptMu.Lock()
backlog := e.acceptQueue.capacity
e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
e.LockUser()
if e.shutdownFlags != 0 {
e.shutdownLocked(e.shutdownFlags)
}
e.UnlockUser()
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
} else {
go func() {
connectedLoading.Wait()
e.LockUser()
// All endpoints will be moved to initial state after
// restore. Set endpoint to its original listen state.
e.setEndpointState(StateListen)
// Initialize the listening context.
rcvWnd := seqnum.Size(e.receiveBufferAvailable())
e.listenCtx = newListenContext(e.stack, e.protocol, e, rcvWnd, e.ops.GetV6Only(), e.NetProto)
e.UnlockUser()
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
}
case epState == StateConnecting:
// Initial SYN hasn't been sent yet so initiate a connect.
tcpip.AsyncLoading.Add(1)
@@ -238,26 +280,30 @@ func (e *Endpoint) Restore(s *stack.Stack) {
tcpip.AsyncLoading.Done()
}()
case epState == StateSynSent || epState == StateSynRecv:
connectedLoading.Wait()
listenLoading.Wait()
// Initial SYN has been sent/received so we should bind the
// ports start the retransmit timer for the SYNs and let it
// naturally complete the connection.
bind()
e.mu.Lock()
defer e.mu.Unlock()
e.setEndpointState(epState)
r, err := e.stack.FindRoute(e.boundNICID, e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.RemoteAddress, e.effectiveNetProtos[0], false /* multicastLoop */)
if err != nil {
panic(fmt.Sprintf("FindRoute failed when restoring endpoint w/ ID: %+v", e.ID))
}
e.route = r
timer, err := newBackoffTimer(e.stack.Clock(), InitialRTO, MaxRTO, timerHandler(e, e.h.retransmitHandlerLocked))
if err != nil {
panic(fmt.Sprintf("newBackOffTimer(_, %s, %s, _) failed: %s", InitialRTO, MaxRTO, err))
}
e.h.retransmitTimer = timer
connectingLoading.Done()
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
listenLoading.Wait()
// Initial SYN has been sent/received, so we should bind the
// ports, start the retransmit timer for the SYNs, and let it
// naturally complete the connection.
bind()
e.mu.Lock()
defer e.mu.Unlock()
e.setEndpointState(epState)
r, err := e.stack.FindRoute(e.boundNICID, e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.RemoteAddress, e.effectiveNetProtos[0], false /* multicastLoop */)
if err != nil {
panic(fmt.Sprintf("FindRoute failed when restoring endpoint w/ ID: %+v", e.ID))
}
e.route = r
timer, err := newBackoffTimer(e.stack.Clock(), InitialRTO, MaxRTO, timerHandler(e, e.h.retransmitHandlerLocked))
if err != nil {
panic(fmt.Sprintf("newBackOffTimer(_, %s, %s, _) failed: %s", InitialRTO, MaxRTO, err))
}
e.h.retransmitTimer = timer
connectingLoading.Done()
tcpip.AsyncLoading.Done()
}()
case epState == StateBound:
tcpip.AsyncLoading.Add(1)
go func() {

View File

@@ -15,6 +15,9 @@
package tcp
import (
"fmt"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -170,3 +173,56 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
return ep, nil
}
// ForwardedPacketExperimentOption looks for an experiment option in the
// network-layer header of the forwarded packet. It returns the option's value
// together with true when one is found, and (0, false) otherwise.
//
// For IPv4 packets the IPv4 header options are walked; for IPv6 packets the
// extension headers are walked. Any iteration error is treated the same as
// "not found". The method panics if the packet's network protocol is neither
// IPv4 nor IPv6. r.mu is held for the duration of the call.
func (r *ForwarderRequest) ForwardedPacketExperimentOption() (uint16, bool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	pkt := r.segment.pkt
	switch pkt.NetworkProtocolNumber {
	case header.IPv4ProtocolNumber:
		ipv4Hdr := header.IPv4(pkt.NetworkHeader().Slice())
		ipv4Opts := ipv4Hdr.Options()
		optIter := ipv4Opts.MakeIterator()
		for {
			opt, done, err := optIter.Next()
			if done || err != nil {
				return 0, false
			}
			if opt.Type() != header.IPv4OptionExperimentType {
				continue
			}
			return opt.(*header.IPv4OptionExperiment).Value(), true
		}

	case header.IPv6ProtocolNumber:
		ipv6Hdr := header.IPv6(pkt.NetworkHeader().Slice())

		// Rebuild the bytes that follow the fixed IPv6 header: the rest of
		// the network header, then the transport header, then the data.
		extView := pkt.NetworkHeader().View()
		if extView != nil {
			extView.TrimFront(header.IPv6MinimumSize)
		}
		payload := buffer.MakeWithView(extView)
		payload.Append(pkt.TransportHeader().View())
		data := pkt.Data().ToBuffer()
		payload.Merge(&data)

		extIter := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipv6Hdr.NextHeader()), payload)
		for {
			extHdr, done, err := extIter.Next()
			if done || err != nil {
				return 0, false
			}
			// The assertion copies the header value, so it stays readable
			// after the iterator's header has been released.
			experiment, isExperiment := extHdr.(header.IPv6ExperimentExtHdr)
			extHdr.Release()
			if isExperiment {
				return experiment.Value, true
			}
		}

	default:
		panic(fmt.Sprintf("Unexpected network protocol number %d", r.segment.pkt.NetworkProtocolNumber))
	}
}

View File

@@ -109,6 +109,12 @@ type protocol struct {
synRetries uint8
dispatcher dispatcher
// probe, if not nil, will be invoked any time an endpoint receives a
// TCP segment.
//
// This is immutable after creation.
probe TCPProbeFunc `state:"nosave"`
// The following secrets are initialized once and stay unchanged after.
seqnumSecret [16]byte
tsOffsetSecret [16]byte
@@ -227,16 +233,26 @@ func replyWithReset(st *stack.Stack, s *segment, tos, ipv4TTL uint8, ipv6HopLimi
ack = s.sequenceNumber.Add(s.logicalLen())
}
p := stack.NewPacketBuffer(stack.PacketBufferOptions{ReserveHeaderBytes: header.TCPMinimumSize + int(route.MaxHeaderLength())})
var expOptVal uint16
if s.ep != nil {
expOptVal = s.ep.getExperimentOptionValue(route)
}
hdrSize := header.TCPMinimumSize + int(route.MaxHeaderLength())
if route.NetProto() == header.IPv6ProtocolNumber && expOptVal != 0 {
hdrSize += header.IPv6ExperimentHdrLength
}
p := stack.NewPacketBuffer(stack.PacketBufferOptions{ReserveHeaderBytes: hdrSize})
defer p.DecRef()
return sendTCP(route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
flags: flags,
seq: seq,
ack: ack,
rcvWnd: 0,
id: s.id,
ttl: ttl,
tos: tos,
flags: flags,
seq: seq,
ack: ack,
rcvWnd: 0,
expOptVal: expOptVal,
}, p, stack.GSO{}, nil /* PacketOwner */)
}
@@ -364,7 +380,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
return nil
case *tcpip.TCPSynRetriesOption:
if *v < 1 || *v > 255 {
if *v < 1 {
return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
@@ -508,6 +524,11 @@ func (p *protocol) Resume() {
p.dispatcher.resume()
}
// Restore implements stack.TransportProtocol.Restore.
//
// It starts the protocol's dispatcher.
func (p *protocol) Restore() {
p.dispatcher.start()
}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
return parse.TCP(pkt)
@@ -515,7 +536,19 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
// NewProtocol returns a TCP transport protocol with Reno congestion control.
func NewProtocol(s *stack.Stack) stack.TransportProtocol {
return newProtocol(s, ccReno)
return newProtocol(s, ccReno, nil)
}
// NewProtocolProbe builds a protocol factory: the returned function creates a
// TCP transport protocol that uses Reno congestion control and carries the
// given probe.
//
// The probe will be invoked on every segment received by TCP endpoints. The
// probe function is passed a copy of the TCP endpoint state before and after
// processing of the segment.
func NewProtocolProbe(probe TCPProbeFunc) func(*stack.Stack) stack.TransportProtocol {
	factory := func(s *stack.Stack) stack.TransportProtocol {
		return newProtocol(s, ccReno, probe)
	}
	return factory
}
// NewProtocolCUBIC returns a TCP transport protocol with CUBIC congestion
@@ -523,10 +556,10 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
//
// TODO(b/345835636): Remove this and make CUBIC the default across the board.
func NewProtocolCUBIC(s *stack.Stack) stack.TransportProtocol {
return newProtocol(s, ccCubic)
return newProtocol(s, ccCubic, nil)
}
func newProtocol(s *stack.Stack, cc string) stack.TransportProtocol {
func newProtocol(s *stack.Stack, cc string, probe TCPProbeFunc) stack.TransportProtocol {
rng := s.SecureRNG()
var seqnumSecret [16]byte
var tsOffsetSecret [16]byte
@@ -562,6 +595,7 @@ func newProtocol(s *stack.Stack, cc string) stack.TransportProtocol {
recovery: tcpip.TCPRACKLossDetection,
seqnumSecret: seqnumSecret,
tsOffsetSecret: tsOffsetSecret,
probe: probe,
}
p.dispatcher.init(s.InsecureRNG(), runtime.GOMAXPROCS(0))
return &p

View File

@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -47,7 +46,7 @@ const (
//
// +stateify savable
type rackControl struct {
stack.TCPRACKState
TCPRACKState
// exitedRecovery indicates if the connection is exiting loss recovery.
// This flag is set if the sender is leaving the recovery after
@@ -162,6 +161,8 @@ func (s *sender) shouldSchedulePTO() bool {
// schedulePTO schedules the probe timeout as defined in
// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.1.
//
// +checklocks:s.ep.mu
func (s *sender) schedulePTO() {
pto := time.Second
s.rtt.Lock()
@@ -237,6 +238,8 @@ func (s *sender) probeTimerExpired() tcpip.Error {
// detectTLPRecovery detects if recovery was accomplished by the loss probes
// and updates TLP state accordingly.
// See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3.
//
// +checklocks:s.ep.mu
func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
if !(s.ep.SACKPermitted && s.rc.tlpRxtOut) {
return
@@ -280,6 +283,8 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
// been observed RACK uses reo_wnd of zero during loss recovery, in order to
// retransmit quickly, or when the number of DUPACKs exceeds the classic
// DUPACKthreshold.
//
// +checklocks:rc.snd.ep.mu
func (rc *rackControl) updateRACKReorderWindow() {
dsackSeen := rc.DSACKSeen
snd := rc.snd
@@ -353,6 +358,8 @@ func (rc *rackControl) exitRecovery() {
// detectLoss marks the segment as lost if the reordering window has elapsed
// and the ACK is not received. It will also arm the reorder timer.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5.
//
// +checklocks:rc.snd.ep.mu
func (rc *rackControl) detectLoss(rcvTime tcpip.MonotonicTime) int {
var timeout time.Duration
numLost := 0

View File

@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// receiver holds the state necessary to receive TCP segments and turn them
@@ -29,7 +28,7 @@ import (
//
// +stateify savable
type receiver struct {
stack.TCPReceiverState
TCPReceiverState
ep *Endpoint
// rcvWnd is the non-scaled receive window last advertised to the peer.
@@ -55,7 +54,7 @@ type receiver struct {
func newReceiver(ep *Endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
ep: ep,
TCPReceiverState: stack.TCPReceiverState{
TCPReceiverState: TCPReceiverState{
RcvNxt: irs + 1,
RcvAcc: irs.Add(rcvWnd + 1),
RcvWndScale: rcvWndScale,
@@ -97,6 +96,7 @@ func (r *receiver) currentWindow() (curWnd seqnum.Size) {
// getSendParams returns the parameters needed by the sender when building
// segments to send.
// +checklocks:r.ep.mu
// +checklocksalias:r.ep.snd.ep.mu=r.ep.mu
func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) {
newWnd := r.ep.selectWindow()
curWnd := r.currentWindow()

View File

@@ -35,6 +35,8 @@ func newRenoCC(s *sender) *renoState {
// algorithm used by NewReno. If after adjusting the congestion window
// we cross the SSthreshold then it will return the number of packets that
// must be consumed in congestion avoidance mode.
//
// +checklocks:r.s.ep.mu
func (r *renoState) updateSlowStart(packetsAcked int) int {
// Don't let the congestion window cross into the congestion
// avoidance range.
@@ -51,6 +53,8 @@ func (r *renoState) updateSlowStart(packetsAcked int) int {
// updateCongestionAvoidance will update congestion window in congestion
// avoidance mode as described in RFC5681 section 3.1
//
// +checklocks:r.s.ep.mu
func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
// Consume the packets in congestion avoidance mode.
r.s.SndCAAckCount += packetsAcked
@@ -62,6 +66,8 @@ func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
// page 6, eq. 4. It is called when we detect congestion in the network.
//
// +checklocks:r.s.ep.mu
func (r *renoState) reduceSlowStartThreshold() {
r.s.Ssthresh = r.s.Outstanding / 2
if r.s.Ssthresh < 2 {
@@ -73,6 +79,8 @@ func (r *renoState) reduceSlowStartThreshold() {
// Update updates the congestion state based on the number of packets that
// were acknowledged.
// Update implements congestionControl.Update.
//
// +checklocks:r.s.ep.mu
func (r *renoState) Update(packetsAcked int, _ time.Duration) {
if r.s.SndCwnd < r.s.Ssthresh {
packetsAcked = r.updateSlowStart(packetsAcked)
@@ -84,6 +92,8 @@ func (r *renoState) Update(packetsAcked int, _ time.Duration) {
}
// HandleLossDetected implements congestionControl.HandleLossDetected.
//
// +checklocks:r.s.ep.mu
func (r *renoState) HandleLossDetected() {
// A retransmit was triggered due to nDupAckThreshold or when RACK
// detected loss. Reduce our slow start threshold.
@@ -91,6 +101,8 @@ func (r *renoState) HandleLossDetected() {
}
// HandleRTOExpired implements congestionControl.HandleRTOExpired.
//
// +checklocks:r.s.ep.mu
func (r *renoState) HandleRTOExpired() {
// We lost a packet, so reduce ssthresh.
r.reduceSlowStartThreshold()

View File

@@ -130,7 +130,7 @@ func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt *st
s.window = seqnum.Size(hdr.WindowSize())
s.rcvdTime = clock.NowMonotonic()
s.dataMemSize = pkt.MemSize()
s.pkt = pkt.IncRef()
s.pkt = pkt.Clone()
s.csumValid = csumValid
if !s.pkt.RXChecksumValidated {

View File

@@ -99,7 +99,9 @@ type lossRecovery interface {
//
// +stateify savable
type sender struct {
stack.TCPSenderState
// +checklocks:ep.mu
TCPSenderState
ep *Endpoint
// lr is the loss recovery algorithm used by the sender.
@@ -124,7 +126,9 @@ type sender struct {
// writeList holds all writable data: both unsent data and
// sent-but-unacknowledged data. Alternatively: it holds all bytes
// starting from SND.UNA.
writeList segmentList
//
// +checklocks:ep.mu
writeList protectedWriteList
// resendTimer is used for RTOs.
resendTimer timer `state:"nosave"`
@@ -180,6 +184,54 @@ type sender struct {
corkTimer timer `state:"nosave"`
}
// protectedWriteList wraps the write list, checking for invalid state when
// segments are added or removed.
//
// Alongside the list it keeps a set of the segments that have been inserted
// and not yet removed. Inserting a segment that is already in the set, or
// removing one that is not, panics.
//
// The zero value is an empty list that is ready to use.
//
// TODO(b/339664055): Revert once bug is fixed.
//
// +stateify savable
type protectedWriteList struct {
	// writeList is the underlying segment list.
	writeList segmentList
	// set mirrors the membership of writeList. It is allocated on first
	// insertion if the creator did not provide one.
	set map[*segment]struct{}
}

// Front returns the front of the write list.
func (wl *protectedWriteList) Front() *segment {
	return wl.writeList.Front()
}

// Back returns the back of the write list.
func (wl *protectedWriteList) Back() *segment {
	return wl.writeList.Back()
}

// Remove removes seg from the write list.
//
// It panics if seg is not recorded as a member of the list.
func (wl *protectedWriteList) Remove(seg *segment) {
	if _, ok := wl.set[seg]; !ok {
		panic("segment not found in write list")
	}
	wl.writeList.Remove(seg)
	delete(wl.set, seg)
}

// PushBack pushes seg onto the back of the write list.
//
// It panics if seg is already recorded as a member of the list.
func (wl *protectedWriteList) PushBack(seg *segment) {
	if _, ok := wl.set[seg]; ok {
		panic("segment already in write list")
	}
	wl.writeList.PushBack(seg)
	wl.track(seg)
}

// InsertAfter inserts seg into the write list immediately after the segment
// passed as before.
//
// It panics if seg is already recorded as a member of the list.
func (wl *protectedWriteList) InsertAfter(before, seg *segment) {
	if _, ok := wl.set[seg]; ok {
		panic("segment already in write list")
	}
	wl.writeList.InsertAfter(before, seg)
	wl.track(seg)
}

// track records seg as a member of the list. The set is allocated here when
// it is still nil so that a zero-value protectedWriteList does not fail with
// a nil-map assignment on its first insertion.
func (wl *protectedWriteList) track(seg *segment) {
	if wl.set == nil {
		wl.set = make(map[*segment]struct{})
	}
	wl.set[seg] = struct{}{}
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
// in sender, where it is used.
//
@@ -187,7 +239,7 @@ type sender struct {
type rtt struct {
sync.Mutex `state:"nosave"`
stack.TCPRTTState
TCPRTTState
}
// +checklocks:ep.mu
@@ -199,7 +251,7 @@ func newSender(ep *Endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s := &sender{
ep: ep,
TCPSenderState: stack.TCPSenderState{
TCPSenderState: TCPSenderState{
SndWnd: sndWnd,
SndUna: iss + 1,
SndNxt: iss + 1,
@@ -207,7 +259,7 @@ func newSender(ep *Endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
LastSendTime: ep.stack.Clock().NowMonotonic(),
MaxPayloadSize: maxPayloadSize,
MaxSentAck: irs + 1,
FastRecovery: stack.TCPFastRecoveryState{
FastRecovery: TCPFastRecoveryState{
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
Last: iss,
HighRxt: iss,
@@ -216,8 +268,18 @@ func newSender(ep *Endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
RTO: 1 * time.Second,
},
gso: ep.gso.Type != stack.GSONone,
writeList: protectedWriteList{
set: make(map[*segment]struct{}),
},
}
return newSenderHelper(ep, iss, irs, sndWnd, mss, sndWndScale, maxPayloadSize, s)
}
// newSenderHelper exists to sate checklocks.
//
// +checklocks:ep.mu
// +checklocksalias:s.ep.mu=ep.mu
func newSenderHelper(ep *Endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int, maxPayloadSize int, s *sender) *sender {
if s.gso {
s.ep.gso.MSS = uint16(maxPayloadSize)
}
@@ -237,7 +299,6 @@ func newSender(ep *Endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s.probeTimer.init(s.ep.stack.Clock(), timerHandler(s.ep, s.probeTimerExpired))
s.corkTimer.init(s.ep.stack.Clock(), timerHandler(s.ep, s.corkTimerExpired))
s.ep.AssertLockHeld(ep)
s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
// Initialize SACK Scoreboard after updating max payload size as we use
// the maxPayloadSize as the smss when determining if a segment is lost
@@ -269,6 +330,8 @@ func newSender(ep *Endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// initCongestionControl initializes the specified congestion control module and
// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to
// their initial values.
//
// +checklocks:s.ep.mu
func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
s.SndCwnd = InitialCwnd
s.Ssthresh = InitialSsthresh
@@ -369,6 +432,8 @@ func (s *sender) sendAck() {
// updateRTO updates the retransmit timeout when a new round-trip time is
// available. This is done in accordance with section 2 of RFC 6298.
//
// +checklocks:s.ep.mu
func (s *sender) updateRTO(rtt time.Duration) {
s.rtt.Lock()
if !s.rtt.TCPRTTState.SRTTInited {
@@ -418,6 +483,7 @@ func (s *sender) updateRTO(rtt time.Duration) {
}
s.RTO = s.rtt.TCPRTTState.SRTT + 4*s.rtt.TCPRTTState.RTTVar
s.RTTState = s.rtt.TCPRTTState
s.rtt.Unlock()
if s.RTO < s.minRTO {
s.RTO = s.minRTO
@@ -614,6 +680,8 @@ func (s *sender) pCount(seg *segment, maxPayloadSize int) int {
// splitSeg splits a given segment at the size specified and inserts the
// remainder as a new segment after the current one in the write list.
//
// +checklocks:s.ep.mu
func (s *sender) splitSeg(seg *segment, size int) {
if seg.payloadSize() <= size {
return
@@ -649,6 +717,8 @@ func (s *sender) splitSeg(seg *segment, size int) {
//
// rescueRtx will be true only if nextSeg is a rescue retransmission as
// described by Step 4) of the NextSeg algorithm.
//
// +checklocks:s.ep.mu
func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRtx bool) {
var s3 *segment
var s4 *segment
@@ -911,6 +981,13 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.payloadSize()))
}
// TODO(b/379932042): Below is the only place we update SND.NXT besides
// initialization. It's possible that we're increasing SND.NXT by
// trying to write a segment that isn't in the write list.
if _, ok := s.writeList.set[seg]; !ok {
panic("attempted to send segment not in write list")
}
s.sendSegment(seg)
// Update sndNxt if we actually sent new data (as opposed to
@@ -945,6 +1022,7 @@ func (s *sender) sendZeroWindowProbe() {
s.resendTimer.enable(s.RTO)
}
// +checklocks:s.ep.mu
func (s *sender) enableZeroWindowProbing() {
s.zeroWindowProbing = true
// We piggyback the probing on the retransmit timer with the
@@ -963,6 +1041,7 @@ func (s *sender) disableZeroWindowProbing() {
s.resendTimer.disable()
}
// +checklocks:s.ep.mu
func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) {
if dataSent {
// We sent data, so we should stop the keepalive timer to ensure
@@ -1039,6 +1118,7 @@ func (s *sender) sendData() {
s.postXmit(dataSent, true /* shouldScheduleProbe */)
}
// +checklocks:s.ep.mu
func (s *sender) enterRecovery() {
// Initialize the variables used to detect spurious recovery after
// entering recovery.
@@ -1082,6 +1162,7 @@ func (s *sender) enterRecovery() {
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
// +checklocks:s.ep.mu
func (s *sender) leaveRecovery() {
s.FastRecovery.Active = false
s.FastRecovery.MaxCwnd = 0
@@ -1104,6 +1185,8 @@ func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
// maintains the congestion window in number of packets and not bytes, so
// SetPipe() here measures number of outstanding packets rather than actual
// outstanding bytes in the network.
//
// +checklocks:s.ep.mu
func (s *sender) SetPipe() {
// If SACK isn't permitted or it is permitted but recovery is not active
// then ignore pipe calculations.
@@ -1157,6 +1240,8 @@ func (s *sender) SetPipe() {
// shouldEnterRecovery returns true if the sender should enter fast recovery
// based on dupAck count and sack scoreboard.
// See RFC 6675 section 5.
//
// +checklocks:s.ep.mu
func (s *sender) shouldEnterRecovery() bool {
return s.DupAckCount >= nDupAckThreshold ||
(s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.SndUna))
@@ -1165,6 +1250,8 @@ func (s *sender) shouldEnterRecovery() bool {
// detectLoss is called when an ack is received and returns whether a loss is
// detected. It manages the state related to duplicate acks and determines if
// a retransmit is needed according to the rules in RFC 6582 (NewReno).
//
// +checklocks:s.ep.mu
func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// We're not in fast recovery yet.
@@ -1215,6 +1302,8 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// isDupAck determines if seg is a duplicate ack as defined in
// https://tools.ietf.org/html/rfc5681#section-2.
//
// +checklocks:s.ep.mu
func (s *sender) isDupAck(seg *segment) bool {
// A TCP that utilizes selective acknowledgments (SACKs) [RFC2018, RFC2883]
// can leverage the SACK information to determine when an incoming ACK is a
@@ -1245,6 +1334,8 @@ func (s *sender) isDupAck(seg *segment) bool {
//
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
//
// +checklocks:s.ep.mu
func (s *sender) walkSACK(rcvdSeg *segment) bool {
s.rc.setDSACKSeen(false)
@@ -1362,6 +1453,7 @@ func (s *sender) recordRetransmitTS() {
s.retransmitTS = s.ep.tsValNow()
}
// +checklocks:s.ep.mu
func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) {
// Return if the sender has already detected spurious recovery.
if s.spuriousRecovery {
@@ -1567,7 +1659,8 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// have no data, but do consume a sequence number.
seg := s.writeList.Front()
if seg == nil {
panic(fmt.Sprintf("invalid state: there are %d unacknowledged bytes left, but the write list is empty:\n%+v", ackLeft, s.TCPSenderState))
panic(fmt.Sprintf("invalid state: there are %d unacknowledged bytes left, but the write list is empty:\n"+
"TCPSenderState: %+v\nsender: %+v\nendpoint: %+v", ackLeft, s.TCPSenderState, s, s.ep))
}
datalen := seg.logicalLen()
@@ -1730,6 +1823,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// flags and sequence number.
// +checklocks:s.ep.mu
// +checklocksalias:s.ep.rcv.ep.mu=s.ep.mu
// +checklocksalias:s.ep.rcv.ep.snd.ep.mu=s.ep.mu
func (s *sender) sendSegmentFromPacketBuffer(pkt *stack.PacketBuffer, flags header.TCPFlags, seq seqnum.Value) tcpip.Error {
s.LastSendTime = s.ep.stack.Clock().NowMonotonic()
if seq == s.RTTMeasureSeqNum {
@@ -1751,7 +1845,9 @@ func (s *sender) sendSegmentFromPacketBuffer(pkt *stack.PacketBuffer, flags head
// sendEmptySegment sends a new empty segment, flags and sequence number.
// +checklocks:s.ep.mu
// +checklocksalias:s.ep.rcv.ep.snd.ep.mu=s.ep.mu
// +checklocksalias:s.ep.rcv.ep.mu=s.ep.mu
// +checklocksalias:s.ep.snd.ep.mu=s.ep.mu
func (s *sender) sendEmptySegment(flags header.TCPFlags, seq seqnum.Value) tcpip.Error {
s.LastSendTime = s.ep.stack.Clock().NowMonotonic()
if seq == s.RTTMeasureSeqNum {

View File

@@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
package tcp
import (
"context"
"time"
"gvisor.dev/gvisor/pkg/atomicbitops"
@@ -25,19 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
// contextID is this package's type for context.Context.Value keys.
type contextID int
const (
// CtxRestoreStack is a Context.Value key for the stack to be used in restore.
CtxRestoreStack contextID = iota
)
// RestoreStackFromContext returns the stack to be used during restore.
func RestoreStackFromContext(ctx context.Context) *Stack {
return ctx.Value(CtxRestoreStack).(*Stack)
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
// passed to stack.AddTCPProbe.
type TCPProbeFunc func(s *TCPEndpointState)

View File

@@ -177,7 +177,6 @@ func (p *processor) StateTypeName() string {
func (p *processor) StateFields() []string {
return []string{
"epQ",
"sleeper",
}
}
@@ -187,7 +186,6 @@ func (p *processor) beforeSave() {}
func (p *processor) StateSave(stateSinkObject state.Sink) {
p.beforeSave()
stateSinkObject.Save(0, &p.epQ)
stateSinkObject.Save(1, &p.sleeper)
}
func (p *processor) afterLoad(context.Context) {}
@@ -195,7 +193,6 @@ func (p *processor) afterLoad(context.Context) {}
// +checklocksignore
func (p *processor) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &p.epQ)
stateSourceObject.Load(1, &p.sleeper)
}
func (d *dispatcher) StateTypeName() string {
@@ -445,6 +442,8 @@ func (e *Endpoint) StateFields() []string {
"TCPEndpointStateInner",
"TransportEndpointInfo",
"DefaultSocketOptionsHandler",
"stack",
"protocol",
"waiterQueue",
"hardError",
"lastError",
@@ -454,6 +453,8 @@ func (e *Endpoint) StateFields() []string {
"rcvQueue",
"state",
"connectionDirectionState",
"isPortReserved",
"isRegistered",
"boundNICID",
"ipv4TTL",
"ipv6HopLimit",
@@ -502,58 +503,62 @@ func (e *Endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
var stateValue EndpointState
stateValue = e.saveState()
stateSinkObject.SaveValue(10, stateValue)
stateSinkObject.SaveValue(12, stateValue)
stateSinkObject.Save(0, &e.TCPEndpointStateInner)
stateSinkObject.Save(1, &e.TransportEndpointInfo)
stateSinkObject.Save(2, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(3, &e.waiterQueue)
stateSinkObject.Save(4, &e.hardError)
stateSinkObject.Save(5, &e.lastError)
stateSinkObject.Save(6, &e.TCPRcvBufState)
stateSinkObject.Save(7, &e.rcvMemUsed)
stateSinkObject.Save(8, &e.ownedByUser)
stateSinkObject.Save(9, &e.rcvQueue)
stateSinkObject.Save(11, &e.connectionDirectionState)
stateSinkObject.Save(12, &e.boundNICID)
stateSinkObject.Save(13, &e.ipv4TTL)
stateSinkObject.Save(14, &e.ipv6HopLimit)
stateSinkObject.Save(15, &e.isConnectNotified)
stateSinkObject.Save(16, &e.h)
stateSinkObject.Save(17, &e.portFlags)
stateSinkObject.Save(18, &e.boundBindToDevice)
stateSinkObject.Save(19, &e.boundPortFlags)
stateSinkObject.Save(20, &e.boundDest)
stateSinkObject.Save(21, &e.effectiveNetProtos)
stateSinkObject.Save(22, &e.recentTSTime)
stateSinkObject.Save(23, &e.shutdownFlags)
stateSinkObject.Save(24, &e.tcpRecovery)
stateSinkObject.Save(25, &e.sack)
stateSinkObject.Save(26, &e.delay)
stateSinkObject.Save(27, &e.scoreboard)
stateSinkObject.Save(28, &e.segmentQueue)
stateSinkObject.Save(29, &e.userMSS)
stateSinkObject.Save(30, &e.maxSynRetries)
stateSinkObject.Save(31, &e.windowClamp)
stateSinkObject.Save(32, &e.sndQueueInfo)
stateSinkObject.Save(33, &e.cc)
stateSinkObject.Save(34, &e.keepalive)
stateSinkObject.Save(35, &e.userTimeout)
stateSinkObject.Save(36, &e.deferAccept)
stateSinkObject.Save(37, &e.acceptQueue)
stateSinkObject.Save(38, &e.rcv)
stateSinkObject.Save(39, &e.snd)
stateSinkObject.Save(40, &e.connectingAddress)
stateSinkObject.Save(41, &e.amss)
stateSinkObject.Save(42, &e.sendTOS)
stateSinkObject.Save(43, &e.gso)
stateSinkObject.Save(44, &e.stats)
stateSinkObject.Save(45, &e.tcpLingerTimeout)
stateSinkObject.Save(46, &e.closed)
stateSinkObject.Save(47, &e.txHash)
stateSinkObject.Save(48, &e.owner)
stateSinkObject.Save(49, &e.ops)
stateSinkObject.Save(50, &e.lastOutOfWindowAckTime)
stateSinkObject.Save(51, &e.pmtud)
stateSinkObject.Save(3, &e.stack)
stateSinkObject.Save(4, &e.protocol)
stateSinkObject.Save(5, &e.waiterQueue)
stateSinkObject.Save(6, &e.hardError)
stateSinkObject.Save(7, &e.lastError)
stateSinkObject.Save(8, &e.TCPRcvBufState)
stateSinkObject.Save(9, &e.rcvMemUsed)
stateSinkObject.Save(10, &e.ownedByUser)
stateSinkObject.Save(11, &e.rcvQueue)
stateSinkObject.Save(13, &e.connectionDirectionState)
stateSinkObject.Save(14, &e.isPortReserved)
stateSinkObject.Save(15, &e.isRegistered)
stateSinkObject.Save(16, &e.boundNICID)
stateSinkObject.Save(17, &e.ipv4TTL)
stateSinkObject.Save(18, &e.ipv6HopLimit)
stateSinkObject.Save(19, &e.isConnectNotified)
stateSinkObject.Save(20, &e.h)
stateSinkObject.Save(21, &e.portFlags)
stateSinkObject.Save(22, &e.boundBindToDevice)
stateSinkObject.Save(23, &e.boundPortFlags)
stateSinkObject.Save(24, &e.boundDest)
stateSinkObject.Save(25, &e.effectiveNetProtos)
stateSinkObject.Save(26, &e.recentTSTime)
stateSinkObject.Save(27, &e.shutdownFlags)
stateSinkObject.Save(28, &e.tcpRecovery)
stateSinkObject.Save(29, &e.sack)
stateSinkObject.Save(30, &e.delay)
stateSinkObject.Save(31, &e.scoreboard)
stateSinkObject.Save(32, &e.segmentQueue)
stateSinkObject.Save(33, &e.userMSS)
stateSinkObject.Save(34, &e.maxSynRetries)
stateSinkObject.Save(35, &e.windowClamp)
stateSinkObject.Save(36, &e.sndQueueInfo)
stateSinkObject.Save(37, &e.cc)
stateSinkObject.Save(38, &e.keepalive)
stateSinkObject.Save(39, &e.userTimeout)
stateSinkObject.Save(40, &e.deferAccept)
stateSinkObject.Save(41, &e.acceptQueue)
stateSinkObject.Save(42, &e.rcv)
stateSinkObject.Save(43, &e.snd)
stateSinkObject.Save(44, &e.connectingAddress)
stateSinkObject.Save(45, &e.amss)
stateSinkObject.Save(46, &e.sendTOS)
stateSinkObject.Save(47, &e.gso)
stateSinkObject.Save(48, &e.stats)
stateSinkObject.Save(49, &e.tcpLingerTimeout)
stateSinkObject.Save(50, &e.closed)
stateSinkObject.Save(51, &e.txHash)
stateSinkObject.Save(52, &e.owner)
stateSinkObject.Save(53, &e.ops)
stateSinkObject.Save(54, &e.lastOutOfWindowAckTime)
stateSinkObject.Save(55, &e.pmtud)
}
// +checklocksignore
@@ -561,55 +566,59 @@ func (e *Endpoint) StateLoad(ctx context.Context, stateSourceObject state.Source
stateSourceObject.Load(0, &e.TCPEndpointStateInner)
stateSourceObject.Load(1, &e.TransportEndpointInfo)
stateSourceObject.Load(2, &e.DefaultSocketOptionsHandler)
stateSourceObject.LoadWait(3, &e.waiterQueue)
stateSourceObject.Load(4, &e.hardError)
stateSourceObject.Load(5, &e.lastError)
stateSourceObject.Load(6, &e.TCPRcvBufState)
stateSourceObject.Load(7, &e.rcvMemUsed)
stateSourceObject.Load(8, &e.ownedByUser)
stateSourceObject.LoadWait(9, &e.rcvQueue)
stateSourceObject.Load(11, &e.connectionDirectionState)
stateSourceObject.Load(12, &e.boundNICID)
stateSourceObject.Load(13, &e.ipv4TTL)
stateSourceObject.Load(14, &e.ipv6HopLimit)
stateSourceObject.Load(15, &e.isConnectNotified)
stateSourceObject.Load(16, &e.h)
stateSourceObject.Load(17, &e.portFlags)
stateSourceObject.Load(18, &e.boundBindToDevice)
stateSourceObject.Load(19, &e.boundPortFlags)
stateSourceObject.Load(20, &e.boundDest)
stateSourceObject.Load(21, &e.effectiveNetProtos)
stateSourceObject.Load(22, &e.recentTSTime)
stateSourceObject.Load(23, &e.shutdownFlags)
stateSourceObject.Load(24, &e.tcpRecovery)
stateSourceObject.Load(25, &e.sack)
stateSourceObject.Load(26, &e.delay)
stateSourceObject.Load(27, &e.scoreboard)
stateSourceObject.LoadWait(28, &e.segmentQueue)
stateSourceObject.Load(29, &e.userMSS)
stateSourceObject.Load(30, &e.maxSynRetries)
stateSourceObject.Load(31, &e.windowClamp)
stateSourceObject.Load(32, &e.sndQueueInfo)
stateSourceObject.Load(33, &e.cc)
stateSourceObject.Load(34, &e.keepalive)
stateSourceObject.Load(35, &e.userTimeout)
stateSourceObject.Load(36, &e.deferAccept)
stateSourceObject.Load(37, &e.acceptQueue)
stateSourceObject.LoadWait(38, &e.rcv)
stateSourceObject.LoadWait(39, &e.snd)
stateSourceObject.Load(40, &e.connectingAddress)
stateSourceObject.Load(41, &e.amss)
stateSourceObject.Load(42, &e.sendTOS)
stateSourceObject.Load(43, &e.gso)
stateSourceObject.Load(44, &e.stats)
stateSourceObject.Load(45, &e.tcpLingerTimeout)
stateSourceObject.Load(46, &e.closed)
stateSourceObject.Load(47, &e.txHash)
stateSourceObject.Load(48, &e.owner)
stateSourceObject.Load(49, &e.ops)
stateSourceObject.Load(50, &e.lastOutOfWindowAckTime)
stateSourceObject.Load(51, &e.pmtud)
stateSourceObject.LoadValue(10, new(EndpointState), func(y any) { e.loadState(ctx, y.(EndpointState)) })
stateSourceObject.Load(3, &e.stack)
stateSourceObject.Load(4, &e.protocol)
stateSourceObject.LoadWait(5, &e.waiterQueue)
stateSourceObject.Load(6, &e.hardError)
stateSourceObject.Load(7, &e.lastError)
stateSourceObject.Load(8, &e.TCPRcvBufState)
stateSourceObject.Load(9, &e.rcvMemUsed)
stateSourceObject.Load(10, &e.ownedByUser)
stateSourceObject.LoadWait(11, &e.rcvQueue)
stateSourceObject.Load(13, &e.connectionDirectionState)
stateSourceObject.Load(14, &e.isPortReserved)
stateSourceObject.Load(15, &e.isRegistered)
stateSourceObject.Load(16, &e.boundNICID)
stateSourceObject.Load(17, &e.ipv4TTL)
stateSourceObject.Load(18, &e.ipv6HopLimit)
stateSourceObject.Load(19, &e.isConnectNotified)
stateSourceObject.Load(20, &e.h)
stateSourceObject.Load(21, &e.portFlags)
stateSourceObject.Load(22, &e.boundBindToDevice)
stateSourceObject.Load(23, &e.boundPortFlags)
stateSourceObject.Load(24, &e.boundDest)
stateSourceObject.Load(25, &e.effectiveNetProtos)
stateSourceObject.Load(26, &e.recentTSTime)
stateSourceObject.Load(27, &e.shutdownFlags)
stateSourceObject.Load(28, &e.tcpRecovery)
stateSourceObject.Load(29, &e.sack)
stateSourceObject.Load(30, &e.delay)
stateSourceObject.Load(31, &e.scoreboard)
stateSourceObject.LoadWait(32, &e.segmentQueue)
stateSourceObject.Load(33, &e.userMSS)
stateSourceObject.Load(34, &e.maxSynRetries)
stateSourceObject.Load(35, &e.windowClamp)
stateSourceObject.Load(36, &e.sndQueueInfo)
stateSourceObject.Load(37, &e.cc)
stateSourceObject.Load(38, &e.keepalive)
stateSourceObject.Load(39, &e.userTimeout)
stateSourceObject.Load(40, &e.deferAccept)
stateSourceObject.Load(41, &e.acceptQueue)
stateSourceObject.LoadWait(42, &e.rcv)
stateSourceObject.LoadWait(43, &e.snd)
stateSourceObject.Load(44, &e.connectingAddress)
stateSourceObject.Load(45, &e.amss)
stateSourceObject.Load(46, &e.sendTOS)
stateSourceObject.Load(47, &e.gso)
stateSourceObject.Load(48, &e.stats)
stateSourceObject.Load(49, &e.tcpLingerTimeout)
stateSourceObject.Load(50, &e.closed)
stateSourceObject.Load(51, &e.txHash)
stateSourceObject.Load(52, &e.owner)
stateSourceObject.Load(53, &e.ops)
stateSourceObject.Load(54, &e.lastOutOfWindowAckTime)
stateSourceObject.Load(55, &e.pmtud)
stateSourceObject.LoadValue(12, new(EndpointState), func(y any) { e.loadState(ctx, y.(EndpointState)) })
stateSourceObject.AfterLoad(func() { e.afterLoad(ctx) })
}
@@ -1106,6 +1115,34 @@ func (s *sender) StateLoad(ctx context.Context, stateSourceObject state.Source)
stateSourceObject.Load(16, &s.startCork)
}
// StateTypeName returns the package-qualified name of protectedWriteList.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (wl *protectedWriteList) StateTypeName() string {
return "pkg/tcpip/transport/tcp.protectedWriteList"
}
// StateFields lists the serialized fields of protectedWriteList. The position
// of each name is the slot number StateSave and StateLoad use for that field,
// so the three must stay in step.
func (wl *protectedWriteList) StateFields() []string {
return []string{
"writeList",
"set",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (wl *protectedWriteList) beforeSave() {}
// StateSave writes both fields to stateSinkObject, one slot per field in
// StateFields order (0: writeList, 1: set).
//
// +checklocksignore
func (wl *protectedWriteList) StateSave(stateSinkObject state.Sink) {
wl.beforeSave()
stateSinkObject.Save(0, &wl.writeList)
stateSinkObject.Save(1, &wl.set)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (wl *protectedWriteList) afterLoad(context.Context) {}
// StateLoad reads both fields from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (wl *protectedWriteList) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &wl.writeList)
stateSourceObject.Load(1, &wl.set)
}
// StateTypeName returns the package-qualified name of rtt.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (r *rtt) StateTypeName() string {
return "pkg/tcpip/transport/tcp.rtt"
}
@@ -1131,6 +1168,586 @@ func (r *rtt) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &r.TCPRTTState)
}
// StateTypeName returns the package-qualified name of TCPCubicState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPCubicState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPCubicState"
}
// StateFields lists the 15 serialized fields of TCPCubicState. The position of
// each name is the slot number StateSave and StateLoad use for that field, so
// the three must stay in step.
func (t *TCPCubicState) StateFields() []string {
return []string{
"WLastMax",
"WMax",
"T",
"TimeSinceLastCongestion",
"C",
"K",
"Beta",
"WC",
"WEst",
"EndSeq",
"CurrRTT",
"LastRTT",
"SampleCount",
"LastAck",
"RoundStart",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPCubicState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-14) per field
// in StateFields order.
//
// +checklocksignore
func (t *TCPCubicState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.WLastMax)
stateSinkObject.Save(1, &t.WMax)
stateSinkObject.Save(2, &t.T)
stateSinkObject.Save(3, &t.TimeSinceLastCongestion)
stateSinkObject.Save(4, &t.C)
stateSinkObject.Save(5, &t.K)
stateSinkObject.Save(6, &t.Beta)
stateSinkObject.Save(7, &t.WC)
stateSinkObject.Save(8, &t.WEst)
stateSinkObject.Save(9, &t.EndSeq)
stateSinkObject.Save(10, &t.CurrRTT)
stateSinkObject.Save(11, &t.LastRTT)
stateSinkObject.Save(12, &t.SampleCount)
stateSinkObject.Save(13, &t.LastAck)
stateSinkObject.Save(14, &t.RoundStart)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPCubicState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPCubicState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.WLastMax)
stateSourceObject.Load(1, &t.WMax)
stateSourceObject.Load(2, &t.T)
stateSourceObject.Load(3, &t.TimeSinceLastCongestion)
stateSourceObject.Load(4, &t.C)
stateSourceObject.Load(5, &t.K)
stateSourceObject.Load(6, &t.Beta)
stateSourceObject.Load(7, &t.WC)
stateSourceObject.Load(8, &t.WEst)
stateSourceObject.Load(9, &t.EndSeq)
stateSourceObject.Load(10, &t.CurrRTT)
stateSourceObject.Load(11, &t.LastRTT)
stateSourceObject.Load(12, &t.SampleCount)
stateSourceObject.Load(13, &t.LastAck)
stateSourceObject.Load(14, &t.RoundStart)
}
// StateTypeName returns the package-qualified name of TCPRACKState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPRACKState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPRACKState"
}
// StateFields lists the 10 serialized fields of TCPRACKState. The position of
// each name is the slot number StateSave and StateLoad use for that field, so
// the three must stay in step.
func (t *TCPRACKState) StateFields() []string {
return []string{
"XmitTime",
"EndSequence",
"FACK",
"RTT",
"Reord",
"DSACKSeen",
"ReoWnd",
"ReoWndIncr",
"ReoWndPersist",
"RTTSeq",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPRACKState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-9) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPRACKState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.XmitTime)
stateSinkObject.Save(1, &t.EndSequence)
stateSinkObject.Save(2, &t.FACK)
stateSinkObject.Save(3, &t.RTT)
stateSinkObject.Save(4, &t.Reord)
stateSinkObject.Save(5, &t.DSACKSeen)
stateSinkObject.Save(6, &t.ReoWnd)
stateSinkObject.Save(7, &t.ReoWndIncr)
stateSinkObject.Save(8, &t.ReoWndPersist)
stateSinkObject.Save(9, &t.RTTSeq)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPRACKState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPRACKState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.XmitTime)
stateSourceObject.Load(1, &t.EndSequence)
stateSourceObject.Load(2, &t.FACK)
stateSourceObject.Load(3, &t.RTT)
stateSourceObject.Load(4, &t.Reord)
stateSourceObject.Load(5, &t.DSACKSeen)
stateSourceObject.Load(6, &t.ReoWnd)
stateSourceObject.Load(7, &t.ReoWndIncr)
stateSourceObject.Load(8, &t.ReoWndPersist)
stateSourceObject.Load(9, &t.RTTSeq)
}
// StateTypeName returns the package-qualified name of TCPEndpointID.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPEndpointID) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPEndpointID"
}
// StateFields lists the 4 serialized fields of TCPEndpointID. The position of
// each name is the slot number StateSave and StateLoad use for that field, so
// the three must stay in step.
func (t *TCPEndpointID) StateFields() []string {
return []string{
"LocalPort",
"LocalAddress",
"RemotePort",
"RemoteAddress",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPEndpointID) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-3) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPEndpointID) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.LocalPort)
stateSinkObject.Save(1, &t.LocalAddress)
stateSinkObject.Save(2, &t.RemotePort)
stateSinkObject.Save(3, &t.RemoteAddress)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPEndpointID) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPEndpointID) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.LocalPort)
stateSourceObject.Load(1, &t.LocalAddress)
stateSourceObject.Load(2, &t.RemotePort)
stateSourceObject.Load(3, &t.RemoteAddress)
}
// StateTypeName returns the package-qualified name of TCPFastRecoveryState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPFastRecoveryState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPFastRecoveryState"
}
// StateFields lists the 6 serialized fields of TCPFastRecoveryState. The
// position of each name is the slot number StateSave and StateLoad use for
// that field, so the three must stay in step.
func (t *TCPFastRecoveryState) StateFields() []string {
return []string{
"Active",
"First",
"Last",
"MaxCwnd",
"HighRxt",
"RescueRxt",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPFastRecoveryState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-5) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPFastRecoveryState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.Active)
stateSinkObject.Save(1, &t.First)
stateSinkObject.Save(2, &t.Last)
stateSinkObject.Save(3, &t.MaxCwnd)
stateSinkObject.Save(4, &t.HighRxt)
stateSinkObject.Save(5, &t.RescueRxt)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPFastRecoveryState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPFastRecoveryState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.Active)
stateSourceObject.Load(1, &t.First)
stateSourceObject.Load(2, &t.Last)
stateSourceObject.Load(3, &t.MaxCwnd)
stateSourceObject.Load(4, &t.HighRxt)
stateSourceObject.Load(5, &t.RescueRxt)
}
// StateTypeName returns the package-qualified name of TCPReceiverState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPReceiverState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPReceiverState"
}
// StateFields lists the 4 serialized fields of TCPReceiverState. The position
// of each name is the slot number StateSave and StateLoad use for that field,
// so the three must stay in step.
func (t *TCPReceiverState) StateFields() []string {
return []string{
"RcvNxt",
"RcvAcc",
"RcvWndScale",
"PendingBufUsed",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPReceiverState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-3) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPReceiverState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.RcvNxt)
stateSinkObject.Save(1, &t.RcvAcc)
stateSinkObject.Save(2, &t.RcvWndScale)
stateSinkObject.Save(3, &t.PendingBufUsed)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPReceiverState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPReceiverState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.RcvNxt)
stateSourceObject.Load(1, &t.RcvAcc)
stateSourceObject.Load(2, &t.RcvWndScale)
stateSourceObject.Load(3, &t.PendingBufUsed)
}
// StateTypeName returns the package-qualified name of TCPRTTState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPRTTState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPRTTState"
}
// StateFields lists the 3 serialized fields of TCPRTTState. The position of
// each name is the slot number StateSave and StateLoad use for that field, so
// the three must stay in step.
func (t *TCPRTTState) StateFields() []string {
return []string{
"SRTT",
"RTTVar",
"SRTTInited",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPRTTState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-2) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPRTTState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.SRTT)
stateSinkObject.Save(1, &t.RTTVar)
stateSinkObject.Save(2, &t.SRTTInited)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPRTTState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPRTTState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.SRTT)
stateSourceObject.Load(1, &t.RTTVar)
stateSourceObject.Load(2, &t.SRTTInited)
}
// StateTypeName returns the package-qualified name of TCPSenderState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPSenderState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPSenderState"
}
// StateFields lists the 23 serialized fields of TCPSenderState. The position
// of each name is the slot number StateSave and StateLoad use for that field,
// so the three must stay in step.
func (t *TCPSenderState) StateFields() []string {
return []string{
"LastSendTime",
"DupAckCount",
"SndCwnd",
"Ssthresh",
"SndCAAckCount",
"Outstanding",
"SackedOut",
"SndWnd",
"SndUna",
"SndNxt",
"RTTMeasureSeqNum",
"RTTMeasureTime",
"Closed",
"RTO",
"RTTState",
"MaxPayloadSize",
"SndWndScale",
"MaxSentAck",
"FastRecovery",
"Cubic",
"RACKState",
"RetransmitTS",
"SpuriousRecovery",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPSenderState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-22) per field
// in StateFields order.
//
// +checklocksignore
func (t *TCPSenderState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.LastSendTime)
stateSinkObject.Save(1, &t.DupAckCount)
stateSinkObject.Save(2, &t.SndCwnd)
stateSinkObject.Save(3, &t.Ssthresh)
stateSinkObject.Save(4, &t.SndCAAckCount)
stateSinkObject.Save(5, &t.Outstanding)
stateSinkObject.Save(6, &t.SackedOut)
stateSinkObject.Save(7, &t.SndWnd)
stateSinkObject.Save(8, &t.SndUna)
stateSinkObject.Save(9, &t.SndNxt)
stateSinkObject.Save(10, &t.RTTMeasureSeqNum)
stateSinkObject.Save(11, &t.RTTMeasureTime)
stateSinkObject.Save(12, &t.Closed)
stateSinkObject.Save(13, &t.RTO)
stateSinkObject.Save(14, &t.RTTState)
stateSinkObject.Save(15, &t.MaxPayloadSize)
stateSinkObject.Save(16, &t.SndWndScale)
stateSinkObject.Save(17, &t.MaxSentAck)
stateSinkObject.Save(18, &t.FastRecovery)
stateSinkObject.Save(19, &t.Cubic)
stateSinkObject.Save(20, &t.RACKState)
stateSinkObject.Save(21, &t.RetransmitTS)
stateSinkObject.Save(22, &t.SpuriousRecovery)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPSenderState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPSenderState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.LastSendTime)
stateSourceObject.Load(1, &t.DupAckCount)
stateSourceObject.Load(2, &t.SndCwnd)
stateSourceObject.Load(3, &t.Ssthresh)
stateSourceObject.Load(4, &t.SndCAAckCount)
stateSourceObject.Load(5, &t.Outstanding)
stateSourceObject.Load(6, &t.SackedOut)
stateSourceObject.Load(7, &t.SndWnd)
stateSourceObject.Load(8, &t.SndUna)
stateSourceObject.Load(9, &t.SndNxt)
stateSourceObject.Load(10, &t.RTTMeasureSeqNum)
stateSourceObject.Load(11, &t.RTTMeasureTime)
stateSourceObject.Load(12, &t.Closed)
stateSourceObject.Load(13, &t.RTO)
stateSourceObject.Load(14, &t.RTTState)
stateSourceObject.Load(15, &t.MaxPayloadSize)
stateSourceObject.Load(16, &t.SndWndScale)
stateSourceObject.Load(17, &t.MaxSentAck)
stateSourceObject.Load(18, &t.FastRecovery)
stateSourceObject.Load(19, &t.Cubic)
stateSourceObject.Load(20, &t.RACKState)
stateSourceObject.Load(21, &t.RetransmitTS)
stateSourceObject.Load(22, &t.SpuriousRecovery)
}
// StateTypeName returns the package-qualified name of TCPSACKInfo.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPSACKInfo) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPSACKInfo"
}
// StateFields lists the 3 serialized fields of TCPSACKInfo. The position of
// each name is the slot number StateSave and StateLoad use for that field, so
// the three must stay in step.
func (t *TCPSACKInfo) StateFields() []string {
return []string{
"Blocks",
"ReceivedBlocks",
"MaxSACKED",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPSACKInfo) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-2) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPSACKInfo) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.Blocks)
stateSinkObject.Save(1, &t.ReceivedBlocks)
stateSinkObject.Save(2, &t.MaxSACKED)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPSACKInfo) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPSACKInfo) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.Blocks)
stateSourceObject.Load(1, &t.ReceivedBlocks)
stateSourceObject.Load(2, &t.MaxSACKED)
}
// StateTypeName returns the package-qualified name of RcvBufAutoTuneParams.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (r *RcvBufAutoTuneParams) StateTypeName() string {
return "pkg/tcpip/transport/tcp.RcvBufAutoTuneParams"
}
// StateFields lists the 9 serialized fields of RcvBufAutoTuneParams. The
// position of each name is the slot number StateSave and StateLoad use for
// that field, so the three must stay in step.
func (r *RcvBufAutoTuneParams) StateFields() []string {
return []string{
"MeasureTime",
"CopiedBytes",
"PrevCopiedBytes",
"RcvBufSize",
"RTT",
"RTTVar",
"RTTMeasureSeqNumber",
"RTTMeasureTime",
"Disabled",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (r *RcvBufAutoTuneParams) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-8) per field in
// StateFields order.
//
// +checklocksignore
func (r *RcvBufAutoTuneParams) StateSave(stateSinkObject state.Sink) {
r.beforeSave()
stateSinkObject.Save(0, &r.MeasureTime)
stateSinkObject.Save(1, &r.CopiedBytes)
stateSinkObject.Save(2, &r.PrevCopiedBytes)
stateSinkObject.Save(3, &r.RcvBufSize)
stateSinkObject.Save(4, &r.RTT)
stateSinkObject.Save(5, &r.RTTVar)
stateSinkObject.Save(6, &r.RTTMeasureSeqNumber)
stateSinkObject.Save(7, &r.RTTMeasureTime)
stateSinkObject.Save(8, &r.Disabled)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (r *RcvBufAutoTuneParams) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (r *RcvBufAutoTuneParams) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &r.MeasureTime)
stateSourceObject.Load(1, &r.CopiedBytes)
stateSourceObject.Load(2, &r.PrevCopiedBytes)
stateSourceObject.Load(3, &r.RcvBufSize)
stateSourceObject.Load(4, &r.RTT)
stateSourceObject.Load(5, &r.RTTVar)
stateSourceObject.Load(6, &r.RTTMeasureSeqNumber)
stateSourceObject.Load(7, &r.RTTMeasureTime)
stateSourceObject.Load(8, &r.Disabled)
}
// StateTypeName returns the package-qualified name of TCPRcvBufState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPRcvBufState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPRcvBufState"
}
// StateFields lists the 3 serialized fields of TCPRcvBufState. The position of
// each name is the slot number StateSave and StateLoad use for that field, so
// the three must stay in step.
func (t *TCPRcvBufState) StateFields() []string {
return []string{
"RcvBufUsed",
"RcvAutoParams",
"RcvClosed",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPRcvBufState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-2) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPRcvBufState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.RcvBufUsed)
stateSinkObject.Save(1, &t.RcvAutoParams)
stateSinkObject.Save(2, &t.RcvClosed)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPRcvBufState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPRcvBufState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.RcvBufUsed)
stateSourceObject.Load(1, &t.RcvAutoParams)
stateSourceObject.Load(2, &t.RcvClosed)
}
// StateTypeName returns the package-qualified name of TCPSndBufState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPSndBufState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPSndBufState"
}
// StateFields lists the 6 serialized fields of TCPSndBufState. The position of
// each name is the slot number StateSave and StateLoad use for that field, so
// the three must stay in step.
func (t *TCPSndBufState) StateFields() []string {
return []string{
"SndBufSize",
"SndBufUsed",
"SndClosed",
"PacketTooBigCount",
"SndMTU",
"AutoTuneSndBufDisabled",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPSndBufState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-5) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPSndBufState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.SndBufSize)
stateSinkObject.Save(1, &t.SndBufUsed)
stateSinkObject.Save(2, &t.SndClosed)
stateSinkObject.Save(3, &t.PacketTooBigCount)
stateSinkObject.Save(4, &t.SndMTU)
stateSinkObject.Save(5, &t.AutoTuneSndBufDisabled)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPSndBufState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPSndBufState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.SndBufSize)
stateSourceObject.Load(1, &t.SndBufUsed)
stateSourceObject.Load(2, &t.SndClosed)
stateSourceObject.Load(3, &t.PacketTooBigCount)
stateSourceObject.Load(4, &t.SndMTU)
stateSourceObject.Load(5, &t.AutoTuneSndBufDisabled)
}
// StateTypeName returns the package-qualified name of TCPEndpointStateInner.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPEndpointStateInner) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPEndpointStateInner"
}
// StateFields lists the 4 serialized fields of TCPEndpointStateInner. The
// position of each name is the slot number StateSave and StateLoad use for
// that field, so the three must stay in step.
func (t *TCPEndpointStateInner) StateFields() []string {
return []string{
"TSOffset",
"SACKPermitted",
"SendTSOk",
"RecentTS",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPEndpointStateInner) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-3) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPEndpointStateInner) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.TSOffset)
stateSinkObject.Save(1, &t.SACKPermitted)
stateSinkObject.Save(2, &t.SendTSOk)
stateSinkObject.Save(3, &t.RecentTS)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPEndpointStateInner) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPEndpointStateInner) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.TSOffset)
stateSourceObject.Load(1, &t.SACKPermitted)
stateSourceObject.Load(2, &t.SendTSOk)
stateSourceObject.Load(3, &t.RecentTS)
}
// StateTypeName returns the package-qualified name of TCPEndpointState.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (t *TCPEndpointState) StateTypeName() string {
return "pkg/tcpip/transport/tcp.TCPEndpointState"
}
// StateFields lists the 8 serialized fields of TCPEndpointState. The position
// of each name is the slot number StateSave and StateLoad use for that field,
// so the three must stay in step.
func (t *TCPEndpointState) StateFields() []string {
return []string{
"TCPEndpointStateInner",
"ID",
"SegTime",
"RcvBufState",
"SndBufState",
"SACK",
"Receiver",
"Sender",
}
}
// beforeSave is an empty pre-save hook; StateSave calls it before writing any
// field.
func (t *TCPEndpointState) beforeSave() {}
// StateSave writes every field to stateSinkObject, one slot (0-7) per field in
// StateFields order.
//
// +checklocksignore
func (t *TCPEndpointState) StateSave(stateSinkObject state.Sink) {
t.beforeSave()
stateSinkObject.Save(0, &t.TCPEndpointStateInner)
stateSinkObject.Save(1, &t.ID)
stateSinkObject.Save(2, &t.SegTime)
stateSinkObject.Save(3, &t.RcvBufState)
stateSinkObject.Save(4, &t.SndBufState)
stateSinkObject.Save(5, &t.SACK)
stateSinkObject.Save(6, &t.Receiver)
stateSinkObject.Save(7, &t.Sender)
}
// afterLoad is an empty post-load hook; StateLoad for this type never invokes
// it.
func (t *TCPEndpointState) afterLoad(context.Context) {}
// StateLoad reads every field from stateSourceObject using the same slot
// numbers StateSave wrote them under. ctx is not used here.
//
// +checklocksignore
func (t *TCPEndpointState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &t.TCPEndpointStateInner)
stateSourceObject.Load(1, &t.ID)
stateSourceObject.Load(2, &t.SegTime)
stateSourceObject.Load(3, &t.RcvBufState)
stateSourceObject.Load(4, &t.SndBufState)
stateSourceObject.Load(5, &t.SACK)
stateSourceObject.Load(6, &t.Receiver)
stateSourceObject.Load(7, &t.Sender)
}
// StateTypeName returns the package-qualified name of endpointList.
// NOTE(review): presumably the key the state package identifies this type by
// in saved images — confirm against pkg/state.
func (l *endpointList) StateTypeName() string {
return "pkg/tcpip/transport/tcp.endpointList"
}
@@ -1292,7 +1909,21 @@ func init() {
state.Register((*segment)(nil))
state.Register((*segmentQueue)(nil))
state.Register((*sender)(nil))
state.Register((*protectedWriteList)(nil))
state.Register((*rtt)(nil))
state.Register((*TCPCubicState)(nil))
state.Register((*TCPRACKState)(nil))
state.Register((*TCPEndpointID)(nil))
state.Register((*TCPFastRecoveryState)(nil))
state.Register((*TCPReceiverState)(nil))
state.Register((*TCPRTTState)(nil))
state.Register((*TCPSenderState)(nil))
state.Register((*TCPSACKInfo)(nil))
state.Register((*RcvBufAutoTuneParams)(nil))
state.Register((*TCPRcvBufState)(nil))
state.Register((*TCPSndBufState)(nil))
state.Register((*TCPEndpointStateInner)(nil))
state.Register((*TCPEndpointState)(nil))
state.Register((*endpointList)(nil))
state.Register((*endpointEntry)(nil))
state.Register((*segmentList)(nil))

View File

@@ -61,7 +61,7 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
waiterQueue *waiter.Queue
net network.Endpoint
stats tcpip.TransportEndpointStats
@@ -160,11 +160,16 @@ func (e *endpoint) Abort() {
// associated with it.
func (e *endpoint) Close() {
// Hold e.mu for the whole call: closeLocked requires the lock, and the
// deferred Unlock releases it once closeLocked has returned.
e.mu.Lock()
defer e.mu.Unlock()
e.closeLocked()
}
// Preconditions: e.mu is locked.
// +checklocks:e.mu
func (e *endpoint) closeLocked() {
switch state := e.net.State(); state {
case transport.DatagramEndpointStateInitial:
case transport.DatagramEndpointStateClosed:
e.mu.Unlock()
return
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
id := e.net.Info().ID
@@ -201,7 +206,6 @@ func (e *endpoint) Close() {
e.net.Shutdown()
e.net.Close()
e.readShutdown = true
e.mu.Unlock()
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -952,7 +956,9 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
Addr: id.LocalAddress,
Port: hdr.DestinationPort(),
},
pkt: pkt.IncRef(),
// We need to clone the packet because ReadTo modifies the write index of
// the underlying buffer. Clone does not copy the data, just the metadata.
pkt: pkt.Clone(),
}
e.rcvList.PushBack(packet)
e.rcvBufSize += pkt.Data().Size()

View File

@@ -16,7 +16,6 @@ package udp
import (
"context"
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -36,7 +35,11 @@ func (p *udpPacket) loadReceivedAt(_ context.Context, nsec int64) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad(ctx context.Context) {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
if e.stack.IsSaveRestoreEnabled() {
e.stack.RegisterRestoredEndpoint(e)
} else {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
}
}
// beforeSave is invoked by stateify.
@@ -53,7 +56,10 @@ func (e *endpoint) Restore(s *stack.Stack) {
defer e.mu.Unlock()
e.net.Resume(s)
if e.stack.IsSaveRestoreEnabled() {
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
return
}
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
@@ -69,12 +75,12 @@ func (e *endpoint) Restore(s *stack.Stack) {
id.RemotePort = e.remotePort
id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
panic(err)
panic("registering udp endpoint with the stack failed during restore")
}
e.localPort = id.LocalPort
e.remotePort = id.RemotePort
default:
panic(fmt.Sprintf("unhandled state = %s", state))
panic("unhandled state")
}
}

View File

@@ -47,7 +47,7 @@ func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Packet
f.handler(&ForwarderRequest{
stack: f.stack,
id: id,
pkt: pkt.IncRef(),
pkt: pkt.Clone(),
})
return true
@@ -76,15 +76,17 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
netHdr := r.pkt.Network()
if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil {
ep.closeLocked()
return nil, err
}
if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil {
ep.closeLocked()
return nil, err
}
if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
ep.closeLocked()
return nil, err
}

View File

@@ -124,6 +124,9 @@ func (*protocol) Pause() {}
// Resume implements stack.TransportProtocol.Resume. This protocol has nothing
// to do on resume, so the method is a no-op.
func (*protocol) Resume() {
}
// Restore implements stack.TransportProtocol.Restore. This protocol has
// nothing to do on restore, so the method is a no-op.
func (*protocol) Restore() {
}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
return parse.UDP(pkt)

View File

@@ -66,6 +66,7 @@ func (e *endpoint) StateTypeName() string {
func (e *endpoint) StateFields() []string {
return []string{
"DefaultSocketOptionsHandler",
"stack",
"waiterQueue",
"net",
"stats",
@@ -90,45 +91,47 @@ func (e *endpoint) StateFields() []string {
func (e *endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
stateSinkObject.Save(0, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(1, &e.waiterQueue)
stateSinkObject.Save(2, &e.net)
stateSinkObject.Save(3, &e.stats)
stateSinkObject.Save(4, &e.ops)
stateSinkObject.Save(5, &e.rcvReady)
stateSinkObject.Save(6, &e.rcvList)
stateSinkObject.Save(7, &e.rcvBufSize)
stateSinkObject.Save(8, &e.rcvClosed)
stateSinkObject.Save(9, &e.lastError)
stateSinkObject.Save(10, &e.portFlags)
stateSinkObject.Save(11, &e.boundBindToDevice)
stateSinkObject.Save(12, &e.boundPortFlags)
stateSinkObject.Save(13, &e.readShutdown)
stateSinkObject.Save(14, &e.effectiveNetProtos)
stateSinkObject.Save(15, &e.frozen)
stateSinkObject.Save(16, &e.localPort)
stateSinkObject.Save(17, &e.remotePort)
stateSinkObject.Save(1, &e.stack)
stateSinkObject.Save(2, &e.waiterQueue)
stateSinkObject.Save(3, &e.net)
stateSinkObject.Save(4, &e.stats)
stateSinkObject.Save(5, &e.ops)
stateSinkObject.Save(6, &e.rcvReady)
stateSinkObject.Save(7, &e.rcvList)
stateSinkObject.Save(8, &e.rcvBufSize)
stateSinkObject.Save(9, &e.rcvClosed)
stateSinkObject.Save(10, &e.lastError)
stateSinkObject.Save(11, &e.portFlags)
stateSinkObject.Save(12, &e.boundBindToDevice)
stateSinkObject.Save(13, &e.boundPortFlags)
stateSinkObject.Save(14, &e.readShutdown)
stateSinkObject.Save(15, &e.effectiveNetProtos)
stateSinkObject.Save(16, &e.frozen)
stateSinkObject.Save(17, &e.localPort)
stateSinkObject.Save(18, &e.remotePort)
}
// +checklocksignore
func (e *endpoint) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.DefaultSocketOptionsHandler)
stateSourceObject.Load(1, &e.waiterQueue)
stateSourceObject.Load(2, &e.net)
stateSourceObject.Load(3, &e.stats)
stateSourceObject.Load(4, &e.ops)
stateSourceObject.Load(5, &e.rcvReady)
stateSourceObject.Load(6, &e.rcvList)
stateSourceObject.Load(7, &e.rcvBufSize)
stateSourceObject.Load(8, &e.rcvClosed)
stateSourceObject.Load(9, &e.lastError)
stateSourceObject.Load(10, &e.portFlags)
stateSourceObject.Load(11, &e.boundBindToDevice)
stateSourceObject.Load(12, &e.boundPortFlags)
stateSourceObject.Load(13, &e.readShutdown)
stateSourceObject.Load(14, &e.effectiveNetProtos)
stateSourceObject.Load(15, &e.frozen)
stateSourceObject.Load(16, &e.localPort)
stateSourceObject.Load(17, &e.remotePort)
stateSourceObject.Load(1, &e.stack)
stateSourceObject.Load(2, &e.waiterQueue)
stateSourceObject.Load(3, &e.net)
stateSourceObject.Load(4, &e.stats)
stateSourceObject.Load(5, &e.ops)
stateSourceObject.Load(6, &e.rcvReady)
stateSourceObject.Load(7, &e.rcvList)
stateSourceObject.Load(8, &e.rcvBufSize)
stateSourceObject.Load(9, &e.rcvClosed)
stateSourceObject.Load(10, &e.lastError)
stateSourceObject.Load(11, &e.portFlags)
stateSourceObject.Load(12, &e.boundBindToDevice)
stateSourceObject.Load(13, &e.boundPortFlags)
stateSourceObject.Load(14, &e.readShutdown)
stateSourceObject.Load(15, &e.effectiveNetProtos)
stateSourceObject.Load(16, &e.frozen)
stateSourceObject.Load(17, &e.localPort)
stateSourceObject.Load(18, &e.remotePort)
stateSourceObject.AfterLoad(func() { e.afterLoad(ctx) })
}

View File

@@ -77,7 +77,7 @@ const (
EventInternal EventMask = 0x1000
EventRdHUp EventMask = 0x2000 // POLLRDHUP
allEvents EventMask = 0x1f | EventRdNorm | EventWrNorm | EventRdHUp
AllEvents EventMask = 0x1f | EventRdNorm | EventWrNorm | EventRdHUp
ReadableEvents EventMask = EventIn | EventRdNorm
WritableEvents EventMask = EventOut | EventWrNorm
)
@@ -86,7 +86,7 @@ const (
// from the Linux events e, which is in the format used by poll(2).
func EventMaskFromLinux(e uint32) EventMask {
// Our flag definitions are currently identical to Linux.
return EventMask(e) & allEvents
return EventMask(e) & AllEvents
}
// ToLinux returns e in the format used by Linux poll(2).
@@ -259,45 +259,23 @@ func (q *Queue) IsEmpty() bool {
return q.list.Front() == nil
}
// AlwaysReady implements the Waitable interface but is always ready. Embedding
// this struct into another struct makes it implement the boilerplate empty
// functions automatically.
type AlwaysReady struct {
}
// Readiness always returns the input mask because this object is always ready.
func (*AlwaysReady) Readiness(mask EventMask) EventMask {
return mask
}
// EventRegister doesn't do anything because this object doesn't need to issue
// notifications because its readiness never changes.
func (*AlwaysReady) EventRegister(*Entry) error {
return nil
}
// EventUnregister doesn't do anything because this object doesn't need to issue
// notifications because its readiness never changes.
func (*AlwaysReady) EventUnregister(e *Entry) {
}
// NeverReady implements the Waitable interface but is never ready. Otherwise,
// this is exactly the same as AlwaysReady.
type NeverReady struct {
}
// Readiness always returns the input mask because this object is always ready.
func (*NeverReady) Readiness(mask EventMask) EventMask {
return mask
// Readiness always returns 0 because this object is never ready.
func (*NeverReady) Readiness(EventMask) EventMask {
return 0
}
// EventRegister doesn't do anything because this object doesn't need to issue
// notifications because its readiness never changes.
func (*NeverReady) EventRegister(e *Entry) error {
func (*NeverReady) EventRegister(*Entry) error {
return nil
}
// EventUnregister doesn't do anything because this object doesn't need to issue
// notifications because its readiness never changes.
func (*NeverReady) EventUnregister(e *Entry) {
func (*NeverReady) EventUnregister(*Entry) {
}