Update dependencies

This commit is contained in:
bluepython508
2024-11-01 17:33:34 +00:00
parent 033ac0b400
commit 5cdfab398d
3596 changed files with 1033483 additions and 259 deletions

View File

@@ -0,0 +1,76 @@
package state
// A Range represents a contiguous range of T.
//
// +stateify savable
type addrRange struct {
// Start is the inclusive start of the range.
Start uintptr
// End is the exclusive end of the range.
End uintptr
}
// WellFormed returns true if r.Start <= r.End. All other methods on a Range
// require that the Range is well-formed.
//
//go:nosplit
func (r addrRange) WellFormed() bool {
return r.Start <= r.End
}
// Length returns the length of the range.
//
//go:nosplit
func (r addrRange) Length() uintptr {
return r.End - r.Start
}
// Contains returns true if r contains x.
//
//go:nosplit
func (r addrRange) Contains(x uintptr) bool {
return r.Start <= x && x < r.End
}
// Overlaps returns true if r and r2 overlap.
//
//go:nosplit
func (r addrRange) Overlaps(r2 addrRange) bool {
return r.Start < r2.End && r2.Start < r.End
}
// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
// contained within r.
//
//go:nosplit
func (r addrRange) IsSupersetOf(r2 addrRange) bool {
return r.Start <= r2.Start && r.End >= r2.End
}
// Intersect returns a range consisting of the intersection between r and r2.
// If r and r2 do not overlap, Intersect returns a range with unspecified
// bounds, but for which Length() == 0.
//
//go:nosplit
func (r addrRange) Intersect(r2 addrRange) addrRange {
if r.Start < r2.Start {
r.Start = r2.Start
}
if r.End > r2.End {
r.End = r2.End
}
if r.End < r.Start {
r.End = r.Start
}
return r
}
// CanSplitAt returns true if it is legal to split a segment spanning the range
// r at x; that is, splitting at x would produce two ranges, both of which have
// non-zero length.
//
//go:nosplit
func (r addrRange) CanSplitAt(x uintptr) bool {
return r.Contains(x) && r.Start < x
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,239 @@
package state
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type completeElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (completeElementMapper) linkerFor(elem *objectDecodeState) *objectDecodeState { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type completeList struct {
head *objectDecodeState
tail *objectDecodeState
}
// Reset resets list l to the empty state.
func (l *completeList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
//
//go:nosplit
func (l *completeList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
//
//go:nosplit
func (l *completeList) Front() *objectDecodeState {
return l.head
}
// Back returns the last element of list l or nil.
//
//go:nosplit
func (l *completeList) Back() *objectDecodeState {
return l.tail
}
// Len returns the number of elements in the list.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *completeList) Len() (count int) {
for e := l.Front(); e != nil; e = (completeElementMapper{}.linkerFor(e)).Next() {
count++
}
return count
}
// PushFront inserts the element e at the front of list l.
//
//go:nosplit
func (l *completeList) PushFront(e *objectDecodeState) {
linker := completeElementMapper{}.linkerFor(e)
linker.SetNext(l.head)
linker.SetPrev(nil)
if l.head != nil {
completeElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushFrontList inserts list m at the start of list l, emptying m.
//
//go:nosplit
func (l *completeList) PushFrontList(m *completeList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
completeElementMapper{}.linkerFor(l.head).SetPrev(m.tail)
completeElementMapper{}.linkerFor(m.tail).SetNext(l.head)
l.head = m.head
}
m.head = nil
m.tail = nil
}
// PushBack inserts the element e at the back of list l.
//
//go:nosplit
func (l *completeList) PushBack(e *objectDecodeState) {
linker := completeElementMapper{}.linkerFor(e)
linker.SetNext(nil)
linker.SetPrev(l.tail)
if l.tail != nil {
completeElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
//
//go:nosplit
func (l *completeList) PushBackList(m *completeList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
completeElementMapper{}.linkerFor(l.tail).SetNext(m.head)
completeElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
//
//go:nosplit
func (l *completeList) InsertAfter(b, e *objectDecodeState) {
bLinker := completeElementMapper{}.linkerFor(b)
eLinker := completeElementMapper{}.linkerFor(e)
a := bLinker.Next()
eLinker.SetNext(a)
eLinker.SetPrev(b)
bLinker.SetNext(e)
if a != nil {
completeElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
//
//go:nosplit
func (l *completeList) InsertBefore(a, e *objectDecodeState) {
aLinker := completeElementMapper{}.linkerFor(a)
eLinker := completeElementMapper{}.linkerFor(e)
b := aLinker.Prev()
eLinker.SetNext(a)
eLinker.SetPrev(b)
aLinker.SetPrev(e)
if b != nil {
completeElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
//
//go:nosplit
func (l *completeList) Remove(e *objectDecodeState) {
linker := completeElementMapper{}.linkerFor(e)
prev := linker.Prev()
next := linker.Next()
if prev != nil {
completeElementMapper{}.linkerFor(prev).SetNext(next)
} else if l.head == e {
l.head = next
}
if next != nil {
completeElementMapper{}.linkerFor(next).SetPrev(prev)
} else if l.tail == e {
l.tail = prev
}
linker.SetNext(nil)
linker.SetPrev(nil)
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type completeEntry struct {
next *objectDecodeState
prev *objectDecodeState
}
// Next returns the entry that follows e in the list.
//
//go:nosplit
func (e *completeEntry) Next() *objectDecodeState {
return e.next
}
// Prev returns the entry that precedes e in the list.
//
//go:nosplit
func (e *completeEntry) Prev() *objectDecodeState {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
//
//go:nosplit
func (e *completeEntry) SetNext(elem *objectDecodeState) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
//
//go:nosplit
func (e *completeEntry) SetPrev(elem *objectDecodeState) {
e.prev = elem
}

View File

@@ -0,0 +1,737 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package state
import (
"bytes"
"context"
"fmt"
"io"
"math"
"reflect"
"gvisor.dev/gvisor/pkg/state/wire"
)
// internalCallback is a interface called on object completion.
//
// There are two implementations: objectDecodeState & userCallback.
type internalCallback interface {
// source returns the dependent object. May be nil.
source() *objectDecodeState
// callbackRun executes the callback.
callbackRun()
}
// userCallback is an implementation of internalCallback.
type userCallback func()
// source implements internalCallback.source.
func (userCallback) source() *objectDecodeState {
return nil
}
// callbackRun implements internalCallback.callbackRun.
func (uc userCallback) callbackRun() {
uc()
}
// objectDecodeState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
type objectDecodeState struct {
// id is the id for this object.
id objectID
// typ is the id for this typeID. This may be zero if this is not a
// type-registered structure.
typ typeID
// obj is the object. This may or may not be valid yet, depending on
// whether complete returns true. However, regardless of whether the
// object is valid, obj contains a final storage location for the
// object. This is immutable.
//
// Note that this must be addressable (obj.Addr() must not panic).
//
// The obj passed to the decode methods below will equal this obj only
// in the case of decoding the top-level object. However, the passed
// obj may represent individual fields, elements of a slice, etc. that
// are effectively embedded within the reflect.Value below but with
// distinct types.
obj reflect.Value
// blockedBy is the number of dependencies this object has.
blockedBy int
// callbacksInline is inline storage for callbacks.
callbacksInline [2]internalCallback
// callbacks is a set of callbacks to execute on load.
callbacks []internalCallback
completeEntry
}
// addCallback adds a callback to the objectDecodeState.
func (ods *objectDecodeState) addCallback(ic internalCallback) {
if ods.callbacks == nil {
ods.callbacks = ods.callbacksInline[:0]
}
ods.callbacks = append(ods.callbacks, ic)
}
// findCycleFor returns when the given object is found in the blocking set.
func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState {
for _, ic := range ods.callbacks {
other := ic.source()
if other != nil && other == target {
return []*objectDecodeState{target}
} else if childList := other.findCycleFor(target); childList != nil {
return append(childList, other)
}
}
// This should not occur.
Failf("no deadlock found?")
panic("unreachable")
}
// findCycle finds a dependency cycle.
func (ods *objectDecodeState) findCycle() []*objectDecodeState {
return append(ods.findCycleFor(ods), ods)
}
// source implements internalCallback.source.
func (ods *objectDecodeState) source() *objectDecodeState {
return ods
}
// callbackRun implements internalCallback.callbackRun.
func (ods *objectDecodeState) callbackRun() {
ods.blockedBy--
}
// decodeState is a graph of objects in the process of being decoded.
//
// The decode process involves loading the breadth-first graph generated by
// encode. This graph is read in it's entirety, ensuring that all object
// storage is complete.
//
// As the graph is being serialized, a set of completion callbacks are
// executed. These completion callbacks should form a set of acyclic subgraphs
// over the original one. After decoding is complete, the objects are scanned
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
// ctx is the decode context.
ctx context.Context
// r is the input stream.
r io.Reader
// types is the type database.
types typeDecodeDatabase
// objectByID is the set of objects in progress.
objectsByID []*objectDecodeState
// deferred are objects that have been read, by no interest has been
// registered yet. These will be decoded once interest in registered.
deferred map[objectID]wire.Object
// pending is the set of objects that are not yet complete.
pending completeList
// stats tracks time data.
stats Stats
}
// lookup looks up an object in decodeState or returns nil if no such object
// has been previously registered.
func (ds *decodeState) lookup(id objectID) *objectDecodeState {
if len(ds.objectsByID) < int(id) {
return nil
}
return ds.objectsByID[id-1]
}
// checkComplete checks for completion.
func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
// Still blocked?
if ods.blockedBy > 0 {
return false
}
// Track stats if relevant.
if ods.callbacks != nil && ods.typ != 0 {
ds.stats.start(ods.typ)
defer ds.stats.done()
}
// Fire all callbacks.
for _, ic := range ods.callbacks {
ic.callbackRun()
}
// Mark completed.
cbs := ods.callbacks
ods.callbacks = nil
ds.pending.Remove(ods)
// Recursively check others.
for _, ic := range cbs {
if other := ic.source(); other != nil && other.blockedBy == 0 {
ds.checkComplete(other)
}
}
return true // All set.
}
// wait registers a dependency on an object.
//
// As a special case, we always allow _useable_ references back to the first
// decoding object because it may have fields that are already decoded. We also
// allow trivial self reference, since they can be handled internally.
func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) {
switch id {
case waiter.id:
// Trivial self reference.
fallthrough
case 1:
// Root object; see above.
if callback != nil {
callback()
}
return
}
// Mark as blocked.
waiter.blockedBy++
// No nil can be returned here.
other := ds.lookup(id)
if callback != nil {
// Add the additional user callback.
other.addCallback(userCallback(callback))
}
// Mark waiter as unblocked.
other.addCallback(waiter)
}
// waitObject notes a blocking relationship.
func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) {
if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 {
// Refs can encode pointers and maps.
ds.wait(ods, objectID(rv.Root), callback)
} else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 {
// See decodeObject; we need to wait for the array (if non-nil).
ds.wait(ods, objectID(sv.Ref.Root), callback)
} else if iv, ok := encoded.(*wire.Interface); ok {
// It's an interface (wait recursively).
ds.waitObject(ods, iv.Value, callback)
} else if callback != nil {
// Nothing to wait for: execute the callback immediately.
callback()
}
}
// walkChild returns a child object from obj, given an accessor path. This is
// the decode-side equivalent to traverse in encode.go.
//
// For the purposes of this function, a child object is either a field within a
// struct or an array element, with one such indirection per element in
// path. The returned value may be an unexported field, so it may not be
// directly assignable. See decode_unsafe.go.
func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
// See wire.Ref.Dots. The path here is specified in reverse order.
for i := len(path) - 1; i >= 0; i-- {
switch pc := path[i].(type) {
case *wire.FieldName: // Must be a pointer.
if obj.Kind() != reflect.Struct {
Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj)
}
obj = obj.FieldByName(string(*pc))
case wire.Index: // Embedded.
if obj.Kind() != reflect.Array {
Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj)
}
obj = obj.Index(int(pc))
default:
panic("unreachable: switch should be exhaustive")
}
}
return obj
}
// register registers a decode with a type.
//
// This type is only used to instantiate a new object if it has not been
// registered previously. This depends on the type provided if none is
// available in the object itself.
func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
// Grow the objectsByID slice.
id := objectID(r.Root)
if len(ds.objectsByID) < int(id) {
ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
}
// Does this object already exist?
ods := ds.objectsByID[id-1]
if ods != nil {
return walkChild(r.Dots, ods.obj)
}
// Create the object.
if len(r.Dots) != 0 {
typ = ds.findType(r.Type)
}
v := reflect.New(typ)
ods = &objectDecodeState{
id: id,
obj: v.Elem(),
}
ds.objectsByID[id-1] = ods
ds.pending.PushBack(ods)
// Process any deferred objects & callbacks.
if encoded, ok := ds.deferred[id]; ok {
delete(ds.deferred, id)
ds.decodeObject(ods, ods.obj, encoded)
}
return walkChild(r.Dots, ods.obj)
}
// objectDecoder is for decoding structs.
type objectDecoder struct {
// ds is decodeState.
ds *decodeState
// ods is current object being decoded.
ods *objectDecodeState
// reconciledTypeEntry is the reconciled type information.
rte *reconciledTypeEntry
// encoded is the encoded object state.
encoded *wire.Struct
}
// load is helper for the public methods on Source.
func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) {
// Note that we have reconciled the type and may remap the fields here
// to match what's expected by the decoder. The "slot" parameter here
// is in terms of the local type, where the fields in the encoded
// object are in terms of the wire object's type, which might be in a
// different order (but will have the same fields).
v := *od.encoded.Field(od.rte.FieldOrder[slot])
od.ds.decodeObject(od.ods, objPtr.Elem(), v)
if wait {
// Mark this individual object a blocker.
od.ds.waitObject(od.ods, v, fn)
}
}
// aterLoad implements Source.AfterLoad.
func (od *objectDecoder) afterLoad(fn func()) {
// Queue the local callback; this will execute when all of the above
// data dependencies have been cleared.
od.ods.addCallback(userCallback(fn))
}
// decodeStruct decodes a struct value.
func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) {
if encoded.TypeID == 0 {
// Allow anonymous empty structs, but only if the encoded
// object also has no fields.
if encoded.Fields() == 0 && obj.NumField() == 0 {
return
}
// Propagate an error.
Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name())
}
// Lookup the object type.
rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type())
ods.typ = typeID(encoded.TypeID)
// Invoke the loader.
od := objectDecoder{
ds: ds,
ods: ods,
rte: rte,
encoded: encoded,
}
ds.stats.start(ods.typ)
defer ds.stats.done()
if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
// Note: may be a registered empty struct which does not
// implement the saver/loader interfaces.
sl.StateLoad(ds.ctx, Source{internal: od})
}
}
// decodeMap decodes a map value.
func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) {
if obj.IsNil() {
// See pointerTo.
obj.Set(reflect.MakeMap(obj.Type()))
}
for i := 0; i < len(encoded.Keys); i++ {
// Decode the objects.
kv := reflect.New(obj.Type().Key()).Elem()
vv := reflect.New(obj.Type().Elem()).Elem()
ds.decodeObject(ods, kv, encoded.Keys[i])
ds.decodeObject(ods, vv, encoded.Values[i])
ds.waitObject(ods, encoded.Keys[i], nil)
ds.waitObject(ods, encoded.Values[i], nil)
// Set in the map.
obj.SetMapIndex(kv, vv)
}
}
// decodeArray decodes an array value.
func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) {
if len(encoded.Contents) != obj.Len() {
Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents))
}
// Decode the contents into the array.
for i := 0; i < len(encoded.Contents); i++ {
ds.decodeObject(ods, obj.Index(i), encoded.Contents[i])
ds.waitObject(ods, encoded.Contents[i], nil)
}
}
// findType finds the type for the given wire.TypeSpecs.
func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type {
switch x := t.(type) {
case wire.TypeID:
typ := ds.types.LookupType(typeID(x))
rte := ds.types.Lookup(typeID(x), typ)
return rte.LocalType
case *wire.TypeSpecPointer:
return reflect.PtrTo(ds.findType(x.Type))
case *wire.TypeSpecArray:
return reflect.ArrayOf(int(x.Count), ds.findType(x.Type))
case *wire.TypeSpecSlice:
return reflect.SliceOf(ds.findType(x.Type))
case *wire.TypeSpecMap:
return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value))
default:
// Should not happen.
Failf("unknown type %#v", t)
}
panic("unreachable")
}
// decodeInterface decodes an interface value.
func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) {
if _, ok := encoded.Type.(wire.TypeSpecNil); ok {
// Special case; the nil object. Just decode directly, which
// will read nil from the wire (if encoded correctly).
ds.decodeObject(ods, obj, encoded.Value)
return
}
// We now need to resolve the actual type.
typ := ds.findType(encoded.Type)
// We need to imbue type information here, then we can proceed to
// decode normally. In order to avoid issues with setting value-types,
// we create a new non-interface version of this object. We will then
// set the interface object to be equal to whatever we decode.
origObj := obj
obj = reflect.New(typ).Elem()
defer origObj.Set(obj)
// With the object now having sufficient type information to actually
// have Set called on it, we can proceed to decode the value.
ds.decodeObject(ods, obj, encoded.Value)
}
// isFloatEq determines if x and y represent the same value.
func isFloatEq(x float64, y float64) bool {
switch {
case math.IsNaN(x):
return math.IsNaN(y)
case math.IsInf(x, 1):
return math.IsInf(y, 1)
case math.IsInf(x, -1):
return math.IsInf(y, -1)
default:
return x == y
}
}
// isComplexEq determines if x and y represent the same value.
func isComplexEq(x complex128, y complex128) bool {
return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y))
}
// decodeObject decodes a object value.
func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) {
switch x := encoded.(type) {
case wire.Nil: // Fast path: first.
// We leave obj alone here. That's because if obj represents an
// interface, it may have been imbued with type information in
// decodeInterface, and we don't want to destroy that.
case *wire.Ref:
// Nil pointers may be encoded in a "forceValue" context. For
// those we just leave it alone as the value will already be
// correct (nil).
if id := objectID(x.Root); id == 0 {
return
}
// Note that if this is a map type, we go through a level of
// indirection to allow for map aliasing.
if obj.Kind() == reflect.Map {
v := ds.register(x, obj.Type())
if v.IsNil() {
// Note that we don't want to clobber the map
// if has already been decoded by decodeMap. We
// just make it so that we have a consistent
// reference when that eventually does happen.
v.Set(reflect.MakeMap(v.Type()))
}
obj.Set(v)
return
}
// Normal assignment: authoritative only if no dots.
v := ds.register(x, obj.Type().Elem())
obj.Set(reflectValueRWAddr(v))
case wire.Bool:
obj.SetBool(bool(x))
case wire.Int:
obj.SetInt(int64(x))
if obj.Int() != int64(x) {
Failf("signed integer truncated from %v to %v", int64(x), obj.Int())
}
case wire.Uint:
obj.SetUint(uint64(x))
if obj.Uint() != uint64(x) {
Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint())
}
case wire.Float32:
obj.SetFloat(float64(x))
case wire.Float64:
obj.SetFloat(float64(x))
if !isFloatEq(obj.Float(), float64(x)) {
Failf("floating point number truncated from %v to %v", float64(x), obj.Float())
}
case *wire.Complex64:
obj.SetComplex(complex128(*x))
case *wire.Complex128:
obj.SetComplex(complex128(*x))
if !isComplexEq(obj.Complex(), complex128(*x)) {
Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex())
}
case *wire.String:
obj.SetString(string(*x))
case *wire.Slice:
// See *wire.Ref above; same applies.
if id := objectID(x.Ref.Root); id == 0 {
return
}
// Note that it's fine to slice the array here and assume that
// contents will still be filled in later on.
typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
v := ds.register(&x.Ref, typ)
obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity)))
case *wire.Array:
ds.decodeArray(ods, obj, x)
case *wire.Struct:
ds.decodeStruct(ods, obj, x)
case *wire.Map:
ds.decodeMap(ods, obj, x)
case *wire.Interface:
ds.decodeInterface(ods, obj, x)
default:
// Should not happen, not propagated as an error.
Failf("unknown object %#v for %q", encoded, obj.Type().Name())
}
}
// Load deserializes the object graph rooted at obj.
//
// This function may panic and should be run in safely().
func (ds *decodeState) Load(obj reflect.Value) {
ds.stats.init()
defer ds.stats.fini(func(id typeID) string {
return ds.types.LookupName(id)
})
// Create the root object.
rootOds := &objectDecodeState{
id: 1,
obj: obj,
}
ds.objectsByID = append(ds.objectsByID, rootOds)
ds.pending.PushBack(rootOds)
// Read the number of objects.
numObjects, object, err := ReadHeader(ds.r)
if err != nil {
Failf("header error: %w", err)
}
if !object {
Failf("object missing")
}
// Decode all objects.
var (
encoded wire.Object
ods *objectDecodeState
id objectID
tid = typeID(1)
)
if err := safely(func() {
// Decode all objects in the stream.
//
// Note that the structure of this decoding loop should match the raw
// decoding loop in state/pretty/pretty.printer.printStream().
for i := uint64(0); i < numObjects; {
// Unmarshal either a type object or object ID.
encoded = wire.Load(ds.r)
switch we := encoded.(type) {
case *wire.Type:
ds.types.Register(we)
tid++
encoded = nil
continue
case wire.Uint:
id = objectID(we)
i++
// Unmarshal and resolve the actual object.
encoded = wire.Load(ds.r)
ods = ds.lookup(id)
if ods != nil {
// Decode the object.
ds.decodeObject(ods, ods.obj, encoded)
} else {
// If an object hasn't had interest registered
// previously or isn't yet valid, we deferred
// decoding until interest is registered.
ds.deferred[id] = encoded
}
// For error handling.
ods = nil
encoded = nil
default:
Failf("wanted type or object ID, got %T", encoded)
}
}
}); err != nil {
// Include as much information as we can, taking into account
// the possible state transitions above.
if ods != nil {
Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
} else if encoded != nil {
Failf("error decoding from %#v: %w", encoded, err)
} else {
Failf("general decoding error: %w", err)
}
}
// Check if we have any deferred objects.
numDeferred := 0
for id, encoded := range ds.deferred {
numDeferred++
if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 {
typ := ds.types.LookupType(typeID(s.TypeID))
Failf("unused deferred object: ID %d, type %v", id, typ)
} else {
Failf("unused deferred object: ID %d, %#v", id, encoded)
}
}
if numDeferred != 0 {
Failf("still had %d deferred objects", numDeferred)
}
// Scan and fire all callbacks. We iterate over the list of incomplete
// objects until all have been finished. We stop iterating if no
// objects become complete (there is a dependency cycle).
//
// Note that we iterate backwards here, because there will be a strong
// tendendcy for blocking relationships to go from earlier objects to
// later (deeper) objects in the graph. This will reduce the number of
// iterations required to finish all objects.
if err := safely(func() {
for ds.pending.Back() != nil {
thisCycle := false
for ods = ds.pending.Back(); ods != nil; {
if ds.checkComplete(ods) {
thisCycle = true
break
}
ods = ods.Prev()
}
if !thisCycle {
break
}
}
}); err != nil {
Failf("error executing callbacks: %w\nfor object %#v", err, ods.obj.Interface())
}
// Check if we have any remaining dependency cycles. If there are any
// objects left in the pending list, then it must be due to a cycle.
if ods := ds.pending.Front(); ods != nil {
// This must be the result of a dependency cycle.
cycle := ods.findCycle()
var buf bytes.Buffer
buf.WriteString("dependency cycle: {")
for i, cycleOS := range cycle {
if i > 0 {
buf.WriteString(" => ")
}
fmt.Fprintf(&buf, "%q", cycleOS.obj.Type())
}
buf.WriteString("}")
Failf("incomplete graph: %s", string(buf.Bytes()))
}
}
// ReadHeader reads an object header.
//
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
// Read the header.
err = safely(func() {
length = wire.LoadUint(r)
})
if err != nil {
// On the header, pass raw I/O errors.
if sErr, ok := err.(*ErrState); ok {
return 0, false, sErr.Unwrap()
}
}
// Decode whether the object is valid.
object = length&objectFlag != 0
length &^= objectFlag
return
}

View File

@@ -0,0 +1,76 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package state
import (
"fmt"
"reflect"
"runtime"
"unsafe"
)
// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned
// reflect.Value is usable in assignments even if obj was obtained by the use
// of unexported struct fields.
//
// Preconditions: obj.CanAddr().
func reflectValueRWAddr(obj reflect.Value) reflect.Value {
return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
}
// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the
// returned reflect.Value is usable in assignments even if obj was obtained by
// the use of unexported struct fields.
//
// Preconditions:
// - arr.Kind() == reflect.Array.
// - i, j, k >= 0.
// - i <= j <= k <= arr.Len().
func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value {
if arr.Kind() != reflect.Array {
panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array))
}
if i < 0 || j < 0 || k < 0 {
panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k))
}
if i > j {
panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j))
}
if j > k {
panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k))
}
if k > arr.Len() {
panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len()))
}
sliceTyp := reflect.SliceOf(arr.Type().Elem())
if i == arr.Len() {
// By precondition, i == j == k == arr.Len().
return reflect.MakeSlice(sliceTyp, 0, 0)
}
slh := reflect.SliceHeader{
// reflect.Value.CanAddr() == false for arrays, so we need to get the
// address from the first element of the array.
Data: arr.Index(i).UnsafeAddr(),
Len: j - i,
Cap: k - i,
}
slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem()
// Before slobj is constructed, arr holds the only pointer-typed pointer to
// the array since reflect.SliceHeader.Data is a uintptr, so arr must be
// kept alive.
runtime.KeepAlive(arr)
return slobj
}

View File

@@ -0,0 +1,239 @@
package state
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type deferredElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (deferredElementMapper) linkerFor(elem *objectEncodeState) *objectEncodeState { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type deferredList struct {
head *objectEncodeState
tail *objectEncodeState
}
// Reset resets list l to the empty state.
func (l *deferredList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
//
//go:nosplit
func (l *deferredList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
//
//go:nosplit
func (l *deferredList) Front() *objectEncodeState {
return l.head
}
// Back returns the last element of list l or nil.
//
//go:nosplit
func (l *deferredList) Back() *objectEncodeState {
return l.tail
}
// Len returns the number of elements in the list.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *deferredList) Len() (count int) {
for e := l.Front(); e != nil; e = (deferredElementMapper{}.linkerFor(e)).Next() {
count++
}
return count
}
// PushFront inserts the element e at the front of list l.
//
//go:nosplit
func (l *deferredList) PushFront(e *objectEncodeState) {
linker := deferredElementMapper{}.linkerFor(e)
linker.SetNext(l.head)
linker.SetPrev(nil)
if l.head != nil {
deferredElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushFrontList inserts list m at the start of list l, emptying m.
//
//go:nosplit
func (l *deferredList) PushFrontList(m *deferredList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
deferredElementMapper{}.linkerFor(l.head).SetPrev(m.tail)
deferredElementMapper{}.linkerFor(m.tail).SetNext(l.head)
l.head = m.head
}
m.head = nil
m.tail = nil
}
// PushBack inserts the element e at the back of list l.
//
//go:nosplit
func (l *deferredList) PushBack(e *objectEncodeState) {
linker := deferredElementMapper{}.linkerFor(e)
linker.SetNext(nil)
linker.SetPrev(l.tail)
if l.tail != nil {
deferredElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
//
//go:nosplit
func (l *deferredList) PushBackList(m *deferredList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
deferredElementMapper{}.linkerFor(l.tail).SetNext(m.head)
deferredElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
//
//go:nosplit
func (l *deferredList) InsertAfter(b, e *objectEncodeState) {
bLinker := deferredElementMapper{}.linkerFor(b)
eLinker := deferredElementMapper{}.linkerFor(e)
a := bLinker.Next()
eLinker.SetNext(a)
eLinker.SetPrev(b)
bLinker.SetNext(e)
if a != nil {
deferredElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
//
//go:nosplit
func (l *deferredList) InsertBefore(a, e *objectEncodeState) {
aLinker := deferredElementMapper{}.linkerFor(a)
eLinker := deferredElementMapper{}.linkerFor(e)
b := aLinker.Prev()
eLinker.SetNext(a)
eLinker.SetPrev(b)
aLinker.SetPrev(e)
if b != nil {
deferredElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
//
//go:nosplit
func (l *deferredList) Remove(e *objectEncodeState) {
linker := deferredElementMapper{}.linkerFor(e)
prev := linker.Prev()
next := linker.Next()
if prev != nil {
deferredElementMapper{}.linkerFor(prev).SetNext(next)
} else if l.head == e {
l.head = next
}
if next != nil {
deferredElementMapper{}.linkerFor(next).SetPrev(prev)
} else if l.tail == e {
l.tail = prev
}
linker.SetNext(nil)
linker.SetPrev(nil)
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type deferredEntry struct {
next *objectEncodeState
prev *objectEncodeState
}
// Next returns the entry that follows e in the list.
//
//go:nosplit
func (e *deferredEntry) Next() *objectEncodeState {
return e.next
}
// Prev returns the entry that precedes e in the list.
//
//go:nosplit
func (e *deferredEntry) Prev() *objectEncodeState {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
//
//go:nosplit
func (e *deferredEntry) SetNext(elem *objectEncodeState) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
//
//go:nosplit
func (e *deferredEntry) SetPrev(elem *objectEncodeState) {
e.prev = elem
}

View File

@@ -0,0 +1,874 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package state
import (
"context"
"io"
"reflect"
"sort"
"gvisor.dev/gvisor/pkg/state/wire"
)
// objectEncodeState the type and identity of an object occupying a memory
// address range. This is the value type for addrSet, and the intrusive entry
// for the deferred list.
type objectEncodeState struct {
// id is the assigned ID for this object.
id objectID
// obj is the object value. Note that this may be replaced if we
// encounter an object that contains this object. When this happens (in
// resolve), we will update existing references appropriately, below,
// and defer a re-encoding of the object.
obj reflect.Value
// encoded is the encoded value of this object. Note that this may not
// be up to date if this object is still in the deferred list.
encoded wire.Object
// how indicates whether this object should be encoded as a value. This
// is used only for deferred encoding.
how encodeStrategy
// refs are the list of reference objects used by other objects
// referring to this object. When the object is updated, these
// references may be updated directly and automatically.
refs []*wire.Ref
deferredEntry
}
// encodeState is state used for encoding.
//
// The encoding process constructs a representation of the in-memory graph of
// objects before a single object is serialized. This is done to ensure that
// all references can be fully disambiguated. See resolve for more details.
type encodeState struct {
// ctx is the encode context.
ctx context.Context
// w is the output stream.
w io.Writer
// types is the type database.
types typeEncodeDatabase
// lastID is the last allocated object ID.
lastID objectID
// values tracks the address ranges occupied by objects, along with the
// types of these objects. This is used to locate pointer targets,
// including pointers to fields within another type.
//
// Multiple objects may overlap in memory iff the larger object fully
// contains the smaller one, and the type of the smaller object matches
// a field or array element's type at the appropriate offset. An
// arbitrary number of objects may be nested in this manner.
//
// Note that this does not track zero-sized objects, those are tracked
// by zeroValues below.
values addrSet
// zeroValues tracks zero-sized objects.
zeroValues map[reflect.Type]*objectEncodeState
// deferred is the list of objects to be encoded.
deferred deferredList
// pendingTypes is the list of types to be serialized. Serialization
// will occur when all objects have been encoded, but before pending is
// serialized.
pendingTypes []wire.Type
// pending maps object IDs to objects to be serialized. Serialization does
// not actually occur until the full object graph is computed.
pending map[objectID]*objectEncodeState
// encodedStructs maps reflect.Values representing structs to previous
// encodings of those structs. This is necessary to avoid duplicate calls
// to SaverLoader.StateSave() that may result in multiple calls to
// Sink.SaveValue() for a given field, resulting in object duplication.
encodedStructs map[reflect.Value]*wire.Struct
// stats tracks time data.
stats Stats
}
// isSameSizeParent returns true if child is a field value or element within
// parent. Only a struct or array can have a child value.
//
// isSameSizeParent deals with objects like this:
//
// struct child {
// // fields..
// }
//
// struct parent {
// c child
// }
//
// var p parent
// record(&p.c)
//
// Here, &p and &p.c occupy the exact same address range.
//
// Or like this:
//
// struct child {
// // fields
// }
//
// var arr [1]parent
// record(&arr[0])
//
// Similarly, &arr[0] and &arr[0].c have the exact same address range.
//
// Precondition: parent and child must occupy the same memory.
func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool {
switch parent.Kind() {
case reflect.Struct:
for i := 0; i < parent.NumField(); i++ {
field := parent.Field(i)
if field.Type() == childType {
return true
}
// Recurse through any intermediate types.
if isSameSizeParent(field, childType) {
return true
}
// Does it make sense to keep going if the first field
// doesn't match? Yes, because there might be an
// arbitrary number of zero-sized fields before we get
// a match, and childType itself can be zero-sized.
}
return false
case reflect.Array:
// The only case where an array with more than one elements can
// return true is if childType is zero-sized. In such cases,
// it's ambiguous which element contains the match since a
// zero-sized child object fully fits in any of the zero-sized
// elements in an array... However since all elements are of
// the same type, we only need to check one element.
//
// For non-zero-sized childTypes, parent.Len() must be 1, but a
// combination of the precondition and an implicit comparison
// between the array element size and childType ensures this.
return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType)
default:
return false
}
}
// nextID returns the next valid ID.
func (es *encodeState) nextID() objectID {
es.lastID++
return objectID(es.lastID)
}
// dummyAddr points to the dummy zero-sized address.
var dummyAddr = reflect.ValueOf(new(struct{})).Pointer()
// resolve records the address range occupied by an object.
func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
addr := obj.Pointer()
// Is this a map pointer? Just record the single address. It is not
// possible to take any pointers into the map internals.
if obj.Kind() == reflect.Map {
if addr == 0 {
// Just leave the nil reference alone. This is fine, we
// may need to encode as a reference in this way. We
// return nil for our objectEncodeState so that anyone
// depending on this value knows there's nothing there.
return
}
seg, gap := es.values.Find(addr)
if seg.Ok() {
// Ensure the map types match.
existing := seg.Value()
if existing.obj.Type() != obj.Type() {
Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj)
}
// No sense recording refs, maps may not be replaced by
// covering objects, they are maximal.
ref.Root = wire.Uint(existing.id)
return
}
// Record the map.
r := addrRange{addr, addr + 1}
oes := &objectEncodeState{
id: es.nextID(),
obj: obj,
how: encodeMapAsValue,
}
// Use Insert instead of InsertWithoutMergingUnchecked when race
// detection is enabled to get additional sanity-checking from Merge.
if !raceEnabled {
es.values.InsertWithoutMergingUnchecked(gap, r, oes)
} else {
es.values.Insert(gap, r, oes)
}
es.pending[oes.id] = oes
es.deferred.PushBack(oes)
// See above: no ref recording.
ref.Root = wire.Uint(oes.id)
return
}
// If not a map, then the object must be a pointer.
if obj.Kind() != reflect.Ptr {
Failf("attempt to record non-map and non-pointer object %#v", obj)
}
obj = obj.Elem() // Value from here.
// Is this a zero-sized type?
typ := obj.Type()
size := typ.Size()
if size == 0 {
if addr == dummyAddr {
// Zero-sized objects point to a dummy byte within the
// runtime. There's no sense recording this in the
// address map. We add this to the dedicated
// zeroValues.
//
// Note that zero-sized objects must be *true*
// zero-sized objects. They cannot be part of some
// larger object. In that case, they are assigned a
// 1-byte address at the end of the object.
oes, ok := es.zeroValues[typ]
if !ok {
oes = &objectEncodeState{
id: es.nextID(),
obj: obj,
}
es.zeroValues[typ] = oes
es.pending[oes.id] = oes
es.deferred.PushBack(oes)
}
// There's also no sense tracking back references. We
// know that this is a true zero-sized object, and not
// part of a larger container, so it will not change.
ref.Root = wire.Uint(oes.id)
return
}
size = 1 // See above.
}
end := addr + size
r := addrRange{addr, end}
seg := es.values.LowerBoundSegment(addr)
var (
oes *objectEncodeState
gap addrGapIterator
)
// Does at least one previously-registered object overlap this one?
if seg.Ok() && seg.Start() < end {
existing := seg.Value()
if seg.Range() == r && typ == existing.obj.Type() {
// This exact object is already registered. Avoid the traversal and
// just return directly. We don't need to encode the type
// information or any dots here.
ref.Root = wire.Uint(existing.id)
existing.refs = append(existing.refs, ref)
return
}
if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) {
// This object is contained within a previously-registered object.
// Perform traversal from the container to the new object.
ref.Root = wire.Uint(existing.id)
ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr)
ref.Type = es.findType(existing.obj.Type())
existing.refs = append(existing.refs, ref)
return
}
// This object contains one or more previously-registered objects.
// Remove them and update existing references to use the new one.
oes := &objectEncodeState{
// Reuse the root ID of the first contained element.
id: existing.id,
obj: obj,
}
type elementEncodeState struct {
addr uintptr
typ reflect.Type
refs []*wire.Ref
}
var (
elems []elementEncodeState
gap addrGapIterator
)
for {
// Each contained object should be completely contained within
// this one.
if raceEnabled && !r.IsSupersetOf(seg.Range()) {
Failf("containing object %#v does not contain existing object %#v", obj, existing.obj)
}
elems = append(elems, elementEncodeState{
addr: seg.Start(),
typ: existing.obj.Type(),
refs: existing.refs,
})
delete(es.pending, existing.id)
es.deferred.Remove(existing)
gap = es.values.Remove(seg)
seg = gap.NextSegment()
if !seg.Ok() || seg.Start() >= end {
break
}
existing = seg.Value()
}
wt := es.findType(typ)
for _, elem := range elems {
dots := traverse(typ, elem.typ, addr, elem.addr)
for _, ref := range elem.refs {
ref.Root = wire.Uint(oes.id)
ref.Dots = append(ref.Dots, dots...)
ref.Type = wt
}
oes.refs = append(oes.refs, elem.refs...)
}
// Finally register the new containing object.
if !raceEnabled {
es.values.InsertWithoutMergingUnchecked(gap, r, oes)
} else {
es.values.Insert(gap, r, oes)
}
es.pending[oes.id] = oes
es.deferred.PushBack(oes)
ref.Root = wire.Uint(oes.id)
oes.refs = append(oes.refs, ref)
return
}
// No existing object overlaps this one. Register a new object.
oes = &objectEncodeState{
id: es.nextID(),
obj: obj,
}
if seg.Ok() {
gap = seg.PrevGap()
} else {
gap = es.values.LastGap()
}
if !raceEnabled {
es.values.InsertWithoutMergingUnchecked(gap, r, oes)
} else {
es.values.Insert(gap, r, oes)
}
es.pending[oes.id] = oes
es.deferred.PushBack(oes)
ref.Root = wire.Uint(oes.id)
oes.refs = append(oes.refs, ref)
}
// traverse searches for a target object within a root object, where the target
// object is a struct field or array element within root, with potentially
// multiple intervening types. traverse returns the set of field or element
// traversals required to reach the target.
//
// Note that for efficiency, traverse returns the dots in the reverse order.
// That is, the first traversal required will be the last element of the list.
//
// Precondition: The target object must lie completely within the range defined
// by [rootAddr, rootAddr + sizeof(rootType)].
func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot {
// Recursion base case: the types actually match.
if targetType == rootType && targetAddr == rootAddr {
return nil
}
switch rootType.Kind() {
case reflect.Struct:
offset := targetAddr - rootAddr
for i := rootType.NumField(); i > 0; i-- {
field := rootType.Field(i - 1)
// The first field from the end with an offset that is
// smaller than or equal to our address offset is where
// the target is located. Traverse from there.
if field.Offset <= offset {
dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr)
fieldName := wire.FieldName(field.Name)
return append(dots, &fieldName)
}
}
// Should never happen; the target should be reachable.
Failf("no field in root type %v contains target type %v", rootType, targetType)
case reflect.Array:
// Since arrays have homogeneous types, all elements have the
// same size and we can compute where the target lives. This
// does not matter for the purpose of typing, but matters for
// the purpose of computing the address of the given index.
elemSize := int(rootType.Elem().Size())
n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down.
if rootType.Len() < n {
Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements",
targetType, targetAddr, rootType, rootAddr, rootType.Len())
}
dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr)
return append(dots, wire.Index(n))
default:
// For any other type, there's no possibility of aliasing so if
// the types didn't match earlier then we have an address
// collision which shouldn't be possible at this point.
Failf("traverse failed for root type %v and target type %v", rootType, targetType)
}
panic("unreachable")
}
// encodeMap encodes a map.
func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) {
if obj.IsNil() {
// Because there is a difference between a nil map and an empty
// map, we need to not decode in the case of a truly nil map.
*dest = wire.Nil{}
return
}
l := obj.Len()
m := &wire.Map{
Keys: make([]wire.Object, l),
Values: make([]wire.Object, l),
}
*dest = m
for i, k := range obj.MapKeys() {
v := obj.MapIndex(k)
// Map keys must be encoded using the full value because the
// type will be omitted after the first key.
es.encodeObject(k, encodeAsValue, &m.Keys[i])
es.encodeObject(v, encodeAsValue, &m.Values[i])
}
}
// objectEncoder is for encoding structs.
type objectEncoder struct {
// es is encodeState.
es *encodeState
// encoded is the encoded struct.
encoded *wire.Struct
}
// save is called by the public methods on Sink.
func (oe *objectEncoder) save(slot int, obj reflect.Value) {
fieldValue := oe.encoded.Field(slot)
oe.es.encodeObject(obj, encodeDefault, fieldValue)
}
// encodeStruct encodes a composite object.
func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
if s, ok := es.encodedStructs[obj]; ok {
*dest = s
return
}
s := &wire.Struct{}
*dest = s
es.encodedStructs[obj] = s
// Ensure that the obj is addressable. There are two cases when it is
// not. First, is when this is dispatched via SaveValue. Second, when
// this is a map key as a struct. Either way, we need to make a copy to
// obtain an addressable value.
if !obj.CanAddr() {
localObj := reflect.New(obj.Type())
localObj.Elem().Set(obj)
obj = localObj.Elem()
}
// Look the type up in the database.
te, ok := es.types.Lookup(obj.Type())
if te == nil {
if obj.NumField() == 0 {
// Allow unregistered anonymous, empty structs. This
// will just return success without ever invoking the
// passed function. This uses the immutable EmptyStruct
// variable to prevent an allocation in this case.
//
// Note that this mechanism does *not* work for
// interfaces in general. So you can't dispatch
// non-registered empty structs via interfaces because
// then they can't be restored.
s.Alloc(0)
return
}
// We need a SaverLoader for struct types.
Failf("struct %T does not implement SaverLoader", obj.Interface())
}
if !ok {
// Queue the type to be serialized.
es.pendingTypes = append(es.pendingTypes, te.Type)
}
// Invoke the provided saver.
s.TypeID = wire.TypeID(te.ID)
s.Alloc(len(te.Fields))
oe := objectEncoder{
es: es,
encoded: s,
}
es.stats.start(te.ID)
defer es.stats.done()
if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
// Note: may be a registered empty struct which does not
// implement the saver/loader interfaces.
sl.StateSave(Sink{internal: oe})
}
}
// encodeArray encodes an array.
func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) {
l := obj.Len()
a := &wire.Array{
Contents: make([]wire.Object, l),
}
*dest = a
for i := 0; i < l; i++ {
// We need to encode the full value because arrays are encoded
// using the type information from only the first element.
es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i])
}
}
// findType recursively finds type information.
func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec {
// First: check if this is a proper type. It's possible for pointers,
// slices, arrays, maps, etc to all have some different type.
te, ok := es.types.Lookup(typ)
if te != nil {
if !ok {
// See encodeStruct.
es.pendingTypes = append(es.pendingTypes, te.Type)
}
return wire.TypeID(te.ID)
}
switch typ.Kind() {
case reflect.Ptr:
return &wire.TypeSpecPointer{
Type: es.findType(typ.Elem()),
}
case reflect.Slice:
return &wire.TypeSpecSlice{
Type: es.findType(typ.Elem()),
}
case reflect.Array:
return &wire.TypeSpecArray{
Count: wire.Uint(typ.Len()),
Type: es.findType(typ.Elem()),
}
case reflect.Map:
return &wire.TypeSpecMap{
Key: es.findType(typ.Key()),
Value: es.findType(typ.Elem()),
}
default:
// After potentially chasing many pointers, the
// ultimate type of the object is not known.
Failf("type %q is not known", typ)
}
panic("unreachable")
}
// encodeInterface encodes an interface.
func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) {
// Dereference the object.
obj = obj.Elem()
if !obj.IsValid() {
// Special case: the nil object.
*dest = &wire.Interface{
Type: wire.TypeSpecNil{},
Value: wire.Nil{},
}
return
}
// Encode underlying object.
i := &wire.Interface{
Type: es.findType(obj.Type()),
}
*dest = i
es.encodeObject(obj, encodeAsValue, &i.Value)
}
// isPrimitive returns true if this is a primitive object, or a composite
// object composed entirely of primitives.
func isPrimitiveZero(typ reflect.Type) bool {
switch typ.Kind() {
case reflect.Ptr:
// Pointers are always treated as primitive types because we
// won't encode directly from here. Returning true here won't
// prevent the object from being encoded correctly.
return true
case reflect.Bool:
return true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return true
case reflect.Float32, reflect.Float64:
return true
case reflect.Complex64, reflect.Complex128:
return true
case reflect.String:
return true
case reflect.Slice:
// The slice itself a primitive, but not necessarily the array
// that points to. This is similar to a pointer.
return true
case reflect.Array:
// We cannot treat an array as a primitive, because it may be
// composed of structures or other things with side-effects.
return isPrimitiveZero(typ.Elem())
case reflect.Interface:
// Since we now that this type is the zero type, the interface
// value must be zero. Therefore this is primitive.
return true
case reflect.Struct:
return false
case reflect.Map:
// The isPrimitiveZero function is called only on zero-types to
// see if it's safe to serialize. Since a zero map has no
// elements, it is safe to treat as a primitive.
return true
default:
Failf("unknown type %q", typ.Name())
}
panic("unreachable")
}
// encodeStrategy is the strategy used for encodeObject.
type encodeStrategy int
const (
// encodeDefault means types are encoded normally as references.
encodeDefault encodeStrategy = iota
// encodeAsValue means that types will never take short-circuited and
// will always be encoded as a normal value.
encodeAsValue
// encodeMapAsValue means that even maps will be fully encoded.
encodeMapAsValue
)
// encodeObject encodes an object.
func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) {
if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() {
*dest = wire.Nil{}
return
}
switch obj.Kind() {
case reflect.Ptr: // Fast path: first.
r := new(wire.Ref)
*dest = r
if obj.IsNil() {
// May be in an array or elsewhere such that a value is
// required. So we encode as a reference to the zero
// object, which does not exist. Note that this has to
// be handled correctly in the decode path as well.
return
}
es.resolve(obj, r)
case reflect.Bool:
*dest = wire.Bool(obj.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
*dest = wire.Int(obj.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
*dest = wire.Uint(obj.Uint())
case reflect.Float32:
*dest = wire.Float32(obj.Float())
case reflect.Float64:
*dest = wire.Float64(obj.Float())
case reflect.Complex64:
c := wire.Complex64(obj.Complex())
*dest = &c // Needs alloc.
case reflect.Complex128:
c := wire.Complex128(obj.Complex())
*dest = &c // Needs alloc.
case reflect.String:
s := wire.String(obj.String())
*dest = &s // Needs alloc.
case reflect.Array:
es.encodeArray(obj, dest)
case reflect.Slice:
s := &wire.Slice{
Capacity: wire.Uint(obj.Cap()),
Length: wire.Uint(obj.Len()),
}
*dest = s
// Note that we do need to provide a wire.Slice type here as
// how is not encodeDefault. If this were the case, then it
// would have been caught by the IsZero check above and we
// would have just used wire.Nil{}.
if obj.IsNil() {
return
}
// Slices need pointer resolution.
es.resolve(arrayFromSlice(obj), &s.Ref)
case reflect.Interface:
es.encodeInterface(obj, dest)
case reflect.Struct:
es.encodeStruct(obj, dest)
case reflect.Map:
if how == encodeMapAsValue {
es.encodeMap(obj, dest)
return
}
r := new(wire.Ref)
*dest = r
es.resolve(obj, r)
default:
Failf("unknown object %#v", obj.Interface())
panic("unreachable")
}
}
// Save serializes the object graph rooted at obj.
func (es *encodeState) Save(obj reflect.Value) {
es.stats.init()
defer es.stats.fini(func(id typeID) string {
return es.pendingTypes[id-1].Name
})
// Resolve the first object, which should queue a pile of additional
// objects on the pending list. All queued objects should be fully
// resolved, and we should be able to serialize after this call.
var root wire.Ref
es.resolve(obj.Addr(), &root)
// Encode the graph.
var oes *objectEncodeState
if err := safely(func() {
for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() {
// Remove and encode the object. Note that as a result
// of this encoding, the object may be enqueued on the
// deferred list yet again. That's expected, and why it
// is removed first.
es.deferred.Remove(oes)
es.encodeObject(oes.obj, oes.how, &oes.encoded)
}
}); err != nil {
// Include the object in the error message.
Failf("encoding error: %w\nfor object %#v", err, oes.obj.Interface())
}
// Check that we have objects to serialize.
if len(es.pending) == 0 {
Failf("pending is empty?")
}
// Write the header with the number of objects.
if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil {
Failf("error writing header: %w", err)
}
// Serialize all pending types and pending objects. Note that we don't
// bother removing from this list as we walk it because that just
// wastes time. It will not change after this point.
if err := safely(func() {
for _, wt := range es.pendingTypes {
// Encode the type.
wire.Save(es.w, &wt)
}
// Emit objects in ID order.
ids := make([]objectID, 0, len(es.pending))
for id := range es.pending {
ids = append(ids, id)
}
sort.Slice(ids, func(i, j int) bool {
return ids[i] < ids[j]
})
for _, id := range ids {
// Encode the id.
wire.Save(es.w, wire.Uint(id))
// Marshal the object.
oes := es.pending[id]
wire.Save(es.w, oes.encoded)
}
}); err != nil {
// Include the object and the error.
Failf("error serializing object %#v: %w", oes.encoded, err)
}
}
// objectFlag indicates that the length is a # of objects, rather than a raw
// byte length. When this is set on a length header in the stream, it may be
// decoded appropriately.
const objectFlag uint64 = 1 << 63
// WriteHeader writes a header.
//
// Each object written to the statefile should be prefixed with a header. In
// order to generate statefiles that play nicely with debugging tools, raw
// writes should be prefixed with a header with object set to false and the
// appropriate length. This will allow tools to skip these regions.
func WriteHeader(w io.Writer, length uint64, object bool) error {
// Sanity check the length.
if length&objectFlag != 0 {
Failf("impossibly huge length: %d", length)
}
if object {
length |= objectFlag
}
// Write a header.
return safely(func() {
wire.SaveUint(w, length)
})
}
// addrSetFunctions is used by addrSet.
type addrSetFunctions struct{}
func (addrSetFunctions) MinKey() uintptr {
return 0
}
func (addrSetFunctions) MaxKey() uintptr {
return ^uintptr(0)
}
func (addrSetFunctions) ClearValue(val **objectEncodeState) {
*val = nil
}
func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) {
if val1.obj == val2.obj {
// This, should never happen. It would indicate that the same
// object exists in two non-contiguous address ranges. Note
// that this assertion can only be triggered if the race
// detector is enabled.
Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj)
}
// Reject the merge.
return val1, false
}
func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) {
// A split should never happen: we don't remove ranges.
Failf("unexpected split in addrSet @ %v: %#v", r, val.obj)
panic("unreachable")
}

View File

@@ -0,0 +1,32 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package state
import (
"reflect"
"unsafe"
)
// arrayFromSlice constructs a new pointer to the slice data.
//
// It would be similar to the following:
//
// x := make([]Foo, l, c)
// a := ([l]Foo*)(unsafe.Pointer(x[0]))
func arrayFromSlice(obj reflect.Value) reflect.Value {
return reflect.NewAt(
reflect.ArrayOf(obj.Cap(), obj.Type().Elem()),
unsafe.Pointer(obj.Pointer()))
}

View File

@@ -0,0 +1,324 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package state provides functionality related to saving and loading object
// graphs. For most types, it provides a set of default saving / loading logic
// that will be invoked automatically if custom logic is not defined.
//
// Kind Support
// ---- -------
// Bool default
// Int default
// Int8 default
// Int16 default
// Int32 default
// Int64 default
// Uint default
// Uint8 default
// Uint16 default
// Uint32 default
// Uint64 default
// Float32 default
// Float64 default
// Complex64 default
// Complex128 default
// Array default
// Chan custom
// Func custom
// Interface default
// Map default
// Ptr default
// Slice default
// String default
// Struct custom (*) Unless zero-sized.
// UnsafePointer custom
//
// See README.md for an overview of how encoding and decoding works.
package state
import (
"context"
"fmt"
"io"
"reflect"
"runtime"
"gvisor.dev/gvisor/pkg/state/wire"
)
// objectID is a unique identifier assigned to each object to be serialized.
// Each instance of an object is considered separately, i.e. if there are two
// objects of the same type in the object graph being serialized, they'll be
// assigned unique objectIDs.
type objectID uint32
// typeID is the identifier for a type. Types are serialized and tracked
// alongside objects in order to avoid the overhead of encoding field names in
// all objects.
type typeID uint32
// ErrState is returned when an error is encountered during encode/decode.
type ErrState struct {
// err is the underlying error.
err error
// trace is the stack trace.
trace string
}
// Error returns a sensible description of the state error.
func (e *ErrState) Error() string {
return fmt.Sprintf("%v:\n%s", e.err, e.trace)
}
// Unwrap implements standard unwrapping.
func (e *ErrState) Unwrap() error {
return e.err
}
// Save saves the given object state.
func Save(ctx context.Context, w io.Writer, rootPtr any) (Stats, error) {
// Create the encoding state.
es := encodeState{
ctx: ctx,
w: w,
types: makeTypeEncodeDatabase(),
zeroValues: make(map[reflect.Type]*objectEncodeState),
pending: make(map[objectID]*objectEncodeState),
encodedStructs: make(map[reflect.Value]*wire.Struct),
}
// Perform the encoding.
err := safely(func() {
es.Save(reflect.ValueOf(rootPtr).Elem())
})
return es.stats, err
}
// Load loads a checkpoint.
func Load(ctx context.Context, r io.Reader, rootPtr any) (Stats, error) {
// Create the decoding state.
ds := decodeState{
ctx: ctx,
r: r,
types: makeTypeDecodeDatabase(),
deferred: make(map[objectID]wire.Object),
}
// Attempt our decode.
err := safely(func() {
ds.Load(reflect.ValueOf(rootPtr).Elem())
})
return ds.stats, err
}
// Sink is used for Type.StateSave.
type Sink struct {
internal objectEncoder
}
// Save adds the given object to the map.
//
// You should pass always pointers to the object you are saving. For example:
//
// type X struct {
// A int
// B *int
// }
//
// func (x *X) StateTypeInfo(m Sink) state.TypeInfo {
// return state.TypeInfo{
// Name: "pkg.X",
// Fields: []string{
// "A",
// "B",
// },
// }
// }
//
// func (x *X) StateSave(m Sink) {
// m.Save(0, &x.A) // Field is A.
// m.Save(1, &x.B) // Field is B.
// }
//
// func (x *X) StateLoad(m Source) {
// m.Load(0, &x.A) // Field is A.
// m.Load(1, &x.B) // Field is B.
// }
func (s Sink) Save(slot int, objPtr any) {
s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
}
// SaveValue adds the given object value to the map.
//
// This should be used for values where pointers are not available, or casts
// are required during Save/Load.
//
// For example, if we want to cast external package type P.Foo to int64:
//
// func (x *X) StateSave(m Sink) {
// m.SaveValue(0, "A", int64(x.A))
// }
//
// func (x *X) StateLoad(m Source) {
// m.LoadValue(0, new(int64), func(x any) {
// x.A = P.Foo(x.(int64))
// })
// }
func (s Sink) SaveValue(slot int, obj any) {
s.internal.save(slot, reflect.ValueOf(obj))
}
// Context returns the context object provided at save time.
func (s Sink) Context() context.Context {
return s.internal.es.ctx
}
// Type is an interface that must be implemented by Struct objects. This allows
// these objects to be serialized while minimizing runtime reflection required.
//
// All these methods can be automatically generated by the go_statify tool.
type Type interface {
// StateTypeName returns the type's name.
//
// This is used for matching type information during encoding and
// decoding, as well as dynamic interface dispatch. This should be
// globally unique.
StateTypeName() string
// StateFields returns information about the type.
//
// Fields is the set of fields for the object. Calls to Sink.Save and
// Source.Load must be made in-order with respect to these fields.
//
// This will be called at most once per serialization.
StateFields() []string
}
// SaverLoader must be implemented by struct types.
type SaverLoader interface {
// StateSave saves the state of the object to the given Map.
StateSave(Sink)
// StateLoad loads the state of the object.
StateLoad(context.Context, Source)
}
// Source is used for Type.StateLoad.
type Source struct {
internal objectDecoder
}
// Load loads the given object passed as a pointer..
//
// See Sink.Save for an example.
func (s Source) Load(slot int, objPtr any) {
s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
}
// LoadWait loads the given objects from the map, and marks it as requiring all
// AfterLoad executions to complete prior to running this object's AfterLoad.
//
// See Sink.Save for an example.
func (s Source) LoadWait(slot int, objPtr any) {
s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
}
// LoadValue loads the given object value from the map.
//
// See Sink.SaveValue for an example.
func (s Source) LoadValue(slot int, objPtr any, fn func(any)) {
o := reflect.ValueOf(objPtr)
s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
}
// AfterLoad schedules a function execution when all objects have been
// allocated and their automated loading and customized load logic have been
// executed. fn will not be executed until all of current object's
// dependencies' AfterLoad() logic, if exist, have been executed.
func (s Source) AfterLoad(fn func()) {
s.internal.afterLoad(fn)
}
// Context returns the context object provided at load time.
func (s Source) Context() context.Context {
return s.internal.ds.ctx
}
// IsZeroValue checks if the given value is the zero value.
//
// This function is used by the stateify tool.
func IsZeroValue(val any) bool {
return val == nil || reflect.ValueOf(val).Elem().IsZero()
}
// Failf is a wrapper around panic that should be used to generate errors that
// can be caught during saving and loading.
func Failf(fmtStr string, v ...any) {
panic(fmt.Errorf(fmtStr, v...))
}
// safely executes the given function, catching a panic and unpacking as an
// error.
//
// The error flow through the state package uses panic and recover. There are
// two important reasons for this:
//
// 1) Many of the reflection methods will already panic with invalid data or
// violated assumptions. We would want to recover anyways here.
//
// 2) It allows us to eliminate boilerplate within Save() and Load() functions.
// In nearly all cases, when the low-level serialization functions fail, you
// will want the checkpoint to fail anyways. Plumbing errors through every
// method doesn't add a lot of value. If there are specific error conditions
// that you'd like to handle, you should add appropriate functionality to
// objects themselves prior to calling Save() and Load().
func safely(fn func()) (err error) {
defer func() {
if r := recover(); r != nil {
if es, ok := r.(*ErrState); ok {
err = es // Propagate.
return
}
// Build a new state error.
es := new(ErrState)
if e, ok := r.(error); ok {
es.err = e
} else {
es.err = fmt.Errorf("%v", r)
}
// Make a stack. We don't know how big it will be ahead
// of time, but want to make sure we get the whole
// thing. So we just do a stupid brute force approach.
var stack []byte
for sz := 1024; ; sz *= 2 {
stack = make([]byte, sz)
n := runtime.Stack(stack, false)
if n < sz {
es.trace = string(stack[:n])
break
}
}
// Set the error.
err = es
}
}()
// Execute the function.
fn()
return nil
}

View File

@@ -0,0 +1,20 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !race
// +build !race
package state
var raceEnabled = false

View File

@@ -0,0 +1,20 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build race
// +build race
package state
var raceEnabled = true

View File

@@ -0,0 +1,145 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package state
import (
"bytes"
"fmt"
"sort"
"time"
)
type statEntry struct {
count uint
total time.Duration
}
// Stats tracks encode / decode timing.
//
// This currently provides a meaningful String function and no other way to
// extract stats about individual types.
//
// All exported receivers accept nil.
type Stats struct {
// byType contains a breakdown of time spent by type.
//
// This is indexed *directly* by typeID, including zero.
byType []statEntry
// stack contains objects in progress.
stack []typeID
// names contains type names.
//
// This is also indexed *directly* by typeID, including zero, which we
// hard-code as "state.default". This is only resolved by calling fini
// on the stats object.
names []string
// last is the last start time.
last time.Time
}
// init initializes statistics.
func (s *Stats) init() {
s.last = time.Now()
s.stack = append(s.stack, 0)
}
// fini finalizes statistics.
func (s *Stats) fini(resolve func(id typeID) string) {
s.done()
// Resolve all type names.
s.names = make([]string, len(s.byType))
s.names[0] = "state.default" // See above.
for id := typeID(1); int(id) < len(s.names); id++ {
s.names[id] = resolve(id)
}
}
// sample adds the samples to the given object.
func (s *Stats) sample(id typeID) {
now := time.Now()
if len(s.byType) <= int(id) {
// Allocate all the missing entries in one fell swoop.
s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...)
}
s.byType[id].total += now.Sub(s.last)
s.last = now
}
// start starts a sample.
func (s *Stats) start(id typeID) {
last := s.stack[len(s.stack)-1]
s.sample(last)
s.stack = append(s.stack, id)
}
// done finishes the current sample.
func (s *Stats) done() {
last := s.stack[len(s.stack)-1]
s.sample(last)
s.byType[last].count++
s.stack = s.stack[:len(s.stack)-1]
}
type sliceEntry struct {
name string
entry *statEntry
}
// String returns a table representation of the stats.
func (s *Stats) String() string {
// Build a list of stat entries.
ss := make([]sliceEntry, 0, len(s.byType))
for id := 0; id < len(s.names); id++ {
ss = append(ss, sliceEntry{
name: s.names[id],
entry: &s.byType[id],
})
}
// Sort by total time (descending).
sort.Slice(ss, func(i, j int) bool {
return ss[i].entry.total > ss[j].entry.total
})
// Print the stat results.
var (
buf bytes.Buffer
count uint
total time.Duration
)
buf.WriteString("\n")
buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type"))
buf.WriteString("-----------------+----------+------------------+----------------\n")
for _, se := range ss {
if se.entry.count == 0 {
// Since we store all types linearly, we are not
// guaranteed that any entry actually has time.
continue
}
count += se.entry.count
total += se.entry.total
per := se.entry.total / time.Duration(se.entry.count)
buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n",
se.entry.total, se.entry.count, per, se.name))
}
buf.WriteString("-----------------+----------+------------------+----------------\n")
buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]",
total, count, total/time.Duration(count)))
return string(buf.Bytes())
}

View File

@@ -0,0 +1,384 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package state
import (
"reflect"
"sort"
"gvisor.dev/gvisor/pkg/state/wire"
)
// assertValidType asserts that the type is valid.
func assertValidType(name string, fields []string) {
if name == "" {
Failf("type has empty name")
}
fieldsCopy := make([]string, len(fields))
for i := 0; i < len(fields); i++ {
if fields[i] == "" {
Failf("field has empty name for type %q", name)
}
fieldsCopy[i] = fields[i]
}
sort.Slice(fieldsCopy, func(i, j int) bool {
return fieldsCopy[i] < fieldsCopy[j]
})
for i := range fieldsCopy {
if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] {
Failf("duplicate field %q for type %s", fieldsCopy[i], name)
}
}
}
// typeEntry is an entry in the typeDatabase.
type typeEntry struct {
ID typeID
wire.Type
}
// reconciledTypeEntry is a reconciled entry in the typeDatabase.
type reconciledTypeEntry struct {
wire.Type
LocalType reflect.Type
FieldOrder []int
}
// typeEncodeDatabase is an internal TypeInfo database for encoding.
type typeEncodeDatabase struct {
// byType maps by type to the typeEntry.
byType map[reflect.Type]*typeEntry
// lastID is the last used ID.
lastID typeID
}
// makeTypeEncodeDatabase makes a typeDatabase.
func makeTypeEncodeDatabase() typeEncodeDatabase {
return typeEncodeDatabase{
byType: make(map[reflect.Type]*typeEntry),
}
}
// typeDecodeDatabase is an internal TypeInfo database for decoding.
type typeDecodeDatabase struct {
// byID maps by ID to type.
byID []*reconciledTypeEntry
// pending are entries that are pending validation by Lookup. These
// will be reconciled with actual objects. Note that these will also be
// used to lookup types by name, since they may not be reconciled and
// there's little value to deleting from this map.
pending []*wire.Type
}
// makeTypeDecodeDatabase makes a typeDatabase.
func makeTypeDecodeDatabase() typeDecodeDatabase {
return typeDecodeDatabase{}
}
// lookupNameFields extracts the name and fields from an object.
func lookupNameFields(typ reflect.Type) (string, []string, bool) {
v := reflect.Zero(reflect.PtrTo(typ)).Interface()
t, ok := v.(Type)
if !ok {
// Is this a primitive?
if typ.Kind() == reflect.Interface {
return interfaceType, nil, true
}
name := typ.Name()
if _, ok := primitiveTypeDatabase[name]; !ok {
// This is not a known type, and not a primitive. The
// encoder may proceed for anonymous empty structs, or
// it may deference the type pointer and try again.
return "", nil, false
}
return name, nil, true
}
// Sanity check the type.
if raceEnabled {
if _, ok := reverseTypeDatabase[typ]; !ok {
// The type was not registered? Must be an embedded
// structure or something else.
return "", nil, false
}
}
// Extract the name from the object.
name := t.StateTypeName()
fields := t.StateFields()
assertValidType(name, fields)
return name, fields, true
}
// Lookup looks up or registers the given object.
//
// The bool indicates whether this is an existing entry: false means the entry
// did not exist, and true means the entry did exist. If this bool is false and
// the returned typeEntry are nil, then the obj did not implement the Type
// interface.
func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) {
te, ok := tdb.byType[typ]
if !ok {
// Lookup the type information.
name, fields, ok := lookupNameFields(typ)
if !ok {
// Empty structs may still be encoded, so let the
// caller decide what to do from here.
return nil, false
}
// Register the new type.
tdb.lastID++
te = &typeEntry{
ID: tdb.lastID,
Type: wire.Type{
Name: name,
Fields: fields,
},
}
// All done.
tdb.byType[typ] = te
return te, false
}
return te, true
}
// Register adds a typeID entry.
func (tbd *typeDecodeDatabase) Register(typ *wire.Type) {
assertValidType(typ.Name, typ.Fields)
tbd.pending = append(tbd.pending, typ)
}
// LookupName looks up the type name by ID.
func (tbd *typeDecodeDatabase) LookupName(id typeID) string {
if len(tbd.pending) < int(id) {
// This is likely an encoder error?
Failf("type ID %d not available", id)
}
return tbd.pending[id-1].Name
}
// LookupType looks up the type by ID.
func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type {
name := tbd.LookupName(id)
typ, ok := globalTypeDatabase[name]
if !ok {
// If not available, see if it's primitive.
typ, ok = primitiveTypeDatabase[name]
if !ok && name == interfaceType {
// Matches the built-in interface type.
var i any
return reflect.TypeOf(&i).Elem()
}
if !ok {
// The type is perhaps not registered?
Failf("type name %q is not available", name)
}
return typ // Primitive type.
}
return typ // Registered type.
}
// singleFieldOrder defines the field order for a single field.
var singleFieldOrder = []int{0}
// Lookup looks up or registers the given object.
//
// First, the typeID is searched to see if this has already been appropriately
// reconciled. If no, then a reconciliation will take place that may result in a
// field ordering. If a nil reconciledTypeEntry is returned from this method,
// then the object does not support the Type interface.
//
// This method never returns nil.
func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry {
if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil {
// Already reconciled.
return tbd.byID[id-1]
}
// The ID has not been reconciled yet. That's fine. We need to make
// sure it aligns with the current provided object.
if len(tbd.pending) < int(id) {
// This id was never registered. Probably an encoder error?
Failf("typeDatabase does not contain id %d", id)
}
// Extract the pending info.
pending := tbd.pending[id-1]
// Grow the byID list.
if len(tbd.byID) < int(id) {
tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...)
}
// Reconcile the type.
name, fields, ok := lookupNameFields(typ)
if !ok {
// Empty structs are decoded only when the type is nil. Since
// this isn't the case, we fail here.
Failf("unsupported type %q during decode; can't reconcile", pending.Name)
}
if name != pending.Name {
// Are these the same type? Print a helpful message as this may
// actually happen in practice if types change.
Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)",
id, name, fields, pending.Name, pending.Fields)
}
rte := &reconciledTypeEntry{
Type: wire.Type{
Name: name,
Fields: fields,
},
LocalType: typ,
}
// If there are zero or one fields, then we skip allocating the field
// slice. There is special handling for decoding in this case. If the
// field name does not match, it will be caught in the general purpose
// code below.
if len(fields) != len(pending.Fields) {
Failf("type %q contains different fields: %v (decode) and %v (encode)",
name, fields, pending.Fields)
}
if len(fields) == 0 {
tbd.byID[id-1] = rte // Save.
return rte
}
if len(fields) == 1 && fields[0] == pending.Fields[0] {
tbd.byID[id-1] = rte // Save.
rte.FieldOrder = singleFieldOrder
return rte
}
// For each field in the current object's information, match it to a
// field in the destination object. We know from the assertion above
// and the insertion on insertion to pending that neither field
// contains any duplicates.
fieldOrder := make([]int, len(fields))
for i, name := range fields {
fieldOrder[i] = -1 // Sentinel.
// Is it an exact match?
if pending.Fields[i] == name {
fieldOrder[i] = i
continue
}
// Find the matching field.
for j, otherName := range pending.Fields {
if name == otherName {
fieldOrder[i] = j
break
}
}
if fieldOrder[i] == -1 {
// The type name matches but we are lacking some common fields.
Failf("type %q has mismatched fields: %v (decode) and %v (encode)",
name, fields, pending.Fields)
}
}
// The type has been reeconciled.
rte.FieldOrder = fieldOrder
tbd.byID[id-1] = rte
return rte
}
// interfaceType defines all interfaces.
const interfaceType = "interface"
// primitiveTypeDatabase is a set of fixed types.
var primitiveTypeDatabase = func() map[string]reflect.Type {
r := make(map[string]reflect.Type)
for _, t := range []reflect.Type{
reflect.TypeOf(false),
reflect.TypeOf(int(0)),
reflect.TypeOf(int8(0)),
reflect.TypeOf(int16(0)),
reflect.TypeOf(int32(0)),
reflect.TypeOf(int64(0)),
reflect.TypeOf(uint(0)),
reflect.TypeOf(uintptr(0)),
reflect.TypeOf(uint8(0)),
reflect.TypeOf(uint16(0)),
reflect.TypeOf(uint32(0)),
reflect.TypeOf(uint64(0)),
reflect.TypeOf(""),
reflect.TypeOf(float32(0.0)),
reflect.TypeOf(float64(0.0)),
reflect.TypeOf(complex64(0.0)),
reflect.TypeOf(complex128(0.0)),
} {
r[t.Name()] = t
}
return r
}()
// globalTypeDatabase is used for dispatching interfaces on decode.
var globalTypeDatabase = map[string]reflect.Type{}
// reverseTypeDatabase is a reverse mapping.
var reverseTypeDatabase = map[reflect.Type]string{}
// Release releases references to global type databases.
// Must only be called in contexts where they will definitely never be used,
// in order to save memory.
func Release() {
globalTypeDatabase = nil
reverseTypeDatabase = nil
}
// Register registers a type.
//
// This must be called on init and only done once.
func Register(t Type) {
name := t.StateTypeName()
typ := reflect.TypeOf(t)
if raceEnabled {
assertValidType(name, t.StateFields())
// Register must always be called on pointers.
if typ.Kind() != reflect.Ptr {
Failf("Register must be called on pointers")
}
}
typ = typ.Elem()
if raceEnabled {
if typ.Kind() == reflect.Struct {
// All registered structs must implement SaverLoader. We allow
// the registration is non-struct types with just the Type
// interface, but we need to call StateSave/StateLoad methods
// on aggregate types.
if _, ok := t.(SaverLoader); !ok {
Failf("struct %T does not implement SaverLoader", t)
}
} else {
// Non-structs must not have any fields. We don't support
// calling StateSave/StateLoad methods on any non-struct types.
// If custom behavior is required, these types should be
// wrapped in a structure of some kind.
if fields := t.StateFields(); len(fields) != 0 {
Failf("non-struct %T has non-zero fields %v", t, fields)
}
// We don't allow non-structs to implement StateSave/StateLoad
// methods, because they won't be called and it's confusing.
if _, ok := t.(SaverLoader); ok {
Failf("non-struct %T implements SaverLoader", t)
}
}
if _, ok := primitiveTypeDatabase[name]; ok {
Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
}
if _, ok := globalTypeDatabase[name]; ok {
Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
}
if name == interfaceType {
Failf("conflicting name for %T: matches interfaceType", t)
}
reverseTypeDatabase[typ] = name
}
globalTypeDatabase[name] = typ
}

View File

@@ -0,0 +1,976 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package wire contains a few basic types that can be composed to serialize
// graph information for the state package. This package defines the wire
// protocol.
//
// Note that these types are careful about how they implement the relevant
// interfaces (either value receiver or pointer receiver), so that native-sized
// types, such as integers and simple pointers, can fit inside the interface
// object.
//
// This package also uses panic as control flow, so called should be careful to
// wrap calls in appropriate handlers.
//
// Testing for this package is driven by the state test package.
package wire
import (
"fmt"
"io"
"math"
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
var oneByteArrayPool = sync.Pool{
New: func() any { return &[1]byte{} },
}
// readFull is a utility. The equivalent is not needed for Write, but the API
// contract dictates that it must always complete all bytes given or return an
// error.
func readFull(r io.Reader, p []byte) {
for done := 0; done < len(p); {
n, err := r.Read(p[done:])
done += n
if n == 0 && err != nil {
panic(err)
}
}
}
// Object is a generic object.
type Object interface {
// save saves the given object.
//
// Panic is used for error control flow.
save(io.Writer)
// load loads a new object of the given type.
//
// Panic is used for error control flow.
load(io.Reader) Object
}
// Bool is a boolean.
type Bool bool
// loadBool loads an object of type Bool.
func loadBool(r io.Reader) Bool {
b := loadUint(r)
return Bool(b == 1)
}
// save implements Object.save.
func (b Bool) save(w io.Writer) {
var v Uint
if b {
v = 1
} else {
v = 0
}
v.save(w)
}
// load implements Object.load.
func (Bool) load(r io.Reader) Object { return loadBool(r) }
// Int is a signed integer.
//
// This uses varint encoding.
type Int int64
// loadInt loads an object of type Int.
func loadInt(r io.Reader) Int {
u := loadUint(r)
x := Int(u >> 1)
if u&1 != 0 {
x = ^x
}
return x
}
// save implements Object.save.
func (i Int) save(w io.Writer) {
u := Uint(i) << 1
if i < 0 {
u = ^u
}
u.save(w)
}
// load implements Object.load.
func (Int) load(r io.Reader) Object { return loadInt(r) }
// Uint is an unsigned integer.
type Uint uint64
func readByte(r io.Reader) byte {
p := oneByteArrayPool.Get().(*[1]byte)
defer oneByteArrayPool.Put(p)
n, err := r.Read(p[:])
if n != 1 {
panic(err)
}
return p[0]
}
// loadUint loads an object of type Uint.
func loadUint(r io.Reader) Uint {
var (
u Uint
s uint
)
for i := 0; i <= 9; i++ {
b := readByte(r)
if b < 0x80 {
if i == 9 && b > 1 {
panic("overflow")
}
u |= Uint(b) << s
return u
}
u |= Uint(b&0x7f) << s
s += 7
}
panic("unreachable")
}
func writeByte(w io.Writer, b byte) {
p := oneByteArrayPool.Get().(*[1]byte)
defer oneByteArrayPool.Put(p)
p[0] = b
n, err := w.Write(p[:])
if n != 1 {
panic(err)
}
}
// save implements Object.save.
func (u Uint) save(w io.Writer) {
for u >= 0x80 {
writeByte(w, byte(u)|0x80)
u >>= 7
}
writeByte(w, byte(u))
}
// load implements Object.load.
func (Uint) load(r io.Reader) Object { return loadUint(r) }
// Float32 is a 32-bit floating point number.
type Float32 float32
// loadFloat32 loads an object of type Float32.
func loadFloat32(r io.Reader) Float32 {
n := loadUint(r)
return Float32(math.Float32frombits(uint32(n)))
}
// save implements Object.save.
func (f Float32) save(w io.Writer) {
n := Uint(math.Float32bits(float32(f)))
n.save(w)
}
// load implements Object.load.
func (Float32) load(r io.Reader) Object { return loadFloat32(r) }
// Float64 is a 64-bit floating point number.
type Float64 float64
// loadFloat64 loads an object of type Float64.
func loadFloat64(r io.Reader) Float64 {
n := loadUint(r)
return Float64(math.Float64frombits(uint64(n)))
}
// save implements Object.save.
func (f Float64) save(w io.Writer) {
n := Uint(math.Float64bits(float64(f)))
n.save(w)
}
// load implements Object.load.
func (Float64) load(r io.Reader) Object { return loadFloat64(r) }
// Complex64 is a 64-bit complex number.
type Complex64 complex128
// loadComplex64 loads an object of type Complex64.
func loadComplex64(r io.Reader) Complex64 {
re := loadFloat32(r)
im := loadFloat32(r)
return Complex64(complex(float32(re), float32(im)))
}
// save implements Object.save.
func (c *Complex64) save(w io.Writer) {
re := Float32(real(*c))
im := Float32(imag(*c))
re.save(w)
im.save(w)
}
// load implements Object.load.
func (*Complex64) load(r io.Reader) Object {
c := loadComplex64(r)
return &c
}
// Complex128 is a 128-bit complex number.
type Complex128 complex128
// loadComplex128 loads an object of type Complex128.
func loadComplex128(r io.Reader) Complex128 {
re := loadFloat64(r)
im := loadFloat64(r)
return Complex128(complex(float64(re), float64(im)))
}
// save implements Object.save.
func (c *Complex128) save(w io.Writer) {
re := Float64(real(*c))
im := Float64(imag(*c))
re.save(w)
im.save(w)
}
// load implements Object.load.
func (*Complex128) load(r io.Reader) Object {
c := loadComplex128(r)
return &c
}
// String is a string.
type String string
// loadString loads an object of type String.
func loadString(r io.Reader) String {
l := loadUint(r)
p := make([]byte, l)
readFull(r, p)
return String(gohacks.StringFromImmutableBytes(p))
}
// save implements Object.save.
func (s *String) save(w io.Writer) {
l := Uint(len(*s))
l.save(w)
p := gohacks.ImmutableBytesFromString(string(*s))
_, err := w.Write(p) // Must write all bytes.
if err != nil {
panic(err)
}
}
// load implements Object.load.
func (*String) load(r io.Reader) Object {
s := loadString(r)
return &s
}
// Dot is a kind of reference: one of Index and FieldName.
type Dot interface {
isDot()
}
// Index is a reference resolution.
type Index uint32
func (Index) isDot() {}
// FieldName is a reference resolution.
type FieldName string
func (*FieldName) isDot() {}
// Ref is a reference to an object.
type Ref struct {
// Root is the root object.
Root Uint
// Dots is the set of traversals required from the Root object above.
// Note that this will be stored in reverse order for efficiency.
Dots []Dot
// Type is the base type for the root object. This is non-nil iff Dots
// is non-zero length (that is, this is a complex reference). This is
// not *strictly* necessary, but can be used to simplify decoding.
Type TypeSpec
}
// loadRef loads an object of type Ref (abstract).
func loadRef(r io.Reader) Ref {
ref := Ref{
Root: loadUint(r),
}
l := loadUint(r)
ref.Dots = make([]Dot, l)
for i := 0; i < int(l); i++ {
// Disambiguate between an Index (non-negative) and a field
// name (negative). This does some space and avoids a dedicate
// loadDot function. See Ref.save for the other side.
d := loadInt(r)
if d >= 0 {
ref.Dots[i] = Index(d)
continue
}
p := make([]byte, -d)
readFull(r, p)
fieldName := FieldName(gohacks.StringFromImmutableBytes(p))
ref.Dots[i] = &fieldName
}
if l != 0 {
// Only if dots is non-zero.
ref.Type = loadTypeSpec(r)
}
return ref
}
// save implements Object.save.
func (r *Ref) save(w io.Writer) {
r.Root.save(w)
l := Uint(len(r.Dots))
l.save(w)
for _, d := range r.Dots {
// See LoadRef. We use non-negative numbers to encode Index
// objects and negative numbers to encode field lengths.
switch x := d.(type) {
case Index:
i := Int(x)
i.save(w)
case *FieldName:
d := Int(-len(*x))
d.save(w)
p := gohacks.ImmutableBytesFromString(string(*x))
if _, err := w.Write(p); err != nil {
panic(err)
}
default:
panic("unknown dot implementation")
}
}
if l != 0 {
// See above.
saveTypeSpec(w, r.Type)
}
}
// load implements Object.load.
func (*Ref) load(r io.Reader) Object {
ref := loadRef(r)
return &ref
}
// Nil is a primitive zero value of any type.
type Nil struct{}
// loadNil loads an object of type Nil.
func loadNil(r io.Reader) Nil {
return Nil{}
}
// save implements Object.save.
func (Nil) save(w io.Writer) {}
// load implements Object.load.
func (Nil) load(r io.Reader) Object { return loadNil(r) }
// Slice is a slice value.
type Slice struct {
Length Uint
Capacity Uint
Ref Ref
}
// loadSlice loads an object of type Slice.
func loadSlice(r io.Reader) Slice {
return Slice{
Length: loadUint(r),
Capacity: loadUint(r),
Ref: loadRef(r),
}
}
// save implements Object.save.
func (s *Slice) save(w io.Writer) {
s.Length.save(w)
s.Capacity.save(w)
s.Ref.save(w)
}
// load implements Object.load.
func (*Slice) load(r io.Reader) Object {
s := loadSlice(r)
return &s
}
// Array is an array value.
type Array struct {
Contents []Object
}
// loadArray loads an object of type Array.
func loadArray(r io.Reader) Array {
l := loadUint(r)
if l == 0 {
// Note that there isn't a single object available to encode
// the type of, so we need this additional branch.
return Array{}
}
// All the objects here have the same type, so use dynamic dispatch
// only once. All other objects will automatically take the same type
// as the first object.
contents := make([]Object, l)
v := Load(r)
contents[0] = v
for i := 1; i < int(l); i++ {
contents[i] = v.load(r)
}
return Array{
Contents: contents,
}
}
// save implements Object.save.
func (a *Array) save(w io.Writer) {
l := Uint(len(a.Contents))
l.save(w)
if l == 0 {
// See LoadArray.
return
}
// See above.
Save(w, a.Contents[0])
for i := 1; i < int(l); i++ {
a.Contents[i].save(w)
}
}
// load implements Object.load.
func (*Array) load(r io.Reader) Object {
a := loadArray(r)
return &a
}
// Map is a map value.
type Map struct {
Keys []Object
Values []Object
}
// loadMap loads an object of type Map.
func loadMap(r io.Reader) Map {
l := loadUint(r)
if l == 0 {
// See LoadArray.
return Map{}
}
// See type dispatch notes in Array.
keys := make([]Object, l)
values := make([]Object, l)
k := Load(r)
v := Load(r)
keys[0] = k
values[0] = v
for i := 1; i < int(l); i++ {
keys[i] = k.load(r)
values[i] = v.load(r)
}
return Map{
Keys: keys,
Values: values,
}
}
// save implements Object.save.
func (m *Map) save(w io.Writer) {
l := Uint(len(m.Keys))
if int(l) != len(m.Values) {
panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values)))
}
l.save(w)
if l == 0 {
// See LoadArray.
return
}
// See above.
Save(w, m.Keys[0])
Save(w, m.Values[0])
for i := 1; i < int(l); i++ {
m.Keys[i].save(w)
m.Values[i].save(w)
}
}
// load implements Object.load.
func (*Map) load(r io.Reader) Object {
m := loadMap(r)
return &m
}
// TypeSpec is a type dereference.
type TypeSpec interface {
isTypeSpec()
}
// TypeID is a concrete type ID.
type TypeID Uint
func (TypeID) isTypeSpec() {}
// TypeSpecPointer is a pointer type.
type TypeSpecPointer struct {
Type TypeSpec
}
func (*TypeSpecPointer) isTypeSpec() {}
// TypeSpecArray is an array type.
type TypeSpecArray struct {
Count Uint
Type TypeSpec
}
func (*TypeSpecArray) isTypeSpec() {}
// TypeSpecSlice is a slice type.
type TypeSpecSlice struct {
Type TypeSpec
}
func (*TypeSpecSlice) isTypeSpec() {}
// TypeSpecMap is a map type.
type TypeSpecMap struct {
Key TypeSpec
Value TypeSpec
}
func (*TypeSpecMap) isTypeSpec() {}
// TypeSpecNil is an empty type.
type TypeSpecNil struct{}
func (TypeSpecNil) isTypeSpec() {}
// TypeSpec types.
//
// These use a distinct encoding on the wire, as they are used only in the
// interface object. They are decoded through the dedicated loadTypeSpec and
// saveTypeSpec functions.
const (
typeSpecTypeID Uint = iota
typeSpecPointer
typeSpecArray
typeSpecSlice
typeSpecMap
typeSpecNil
)
// loadTypeSpec loads TypeSpec values.
func loadTypeSpec(r io.Reader) TypeSpec {
switch hdr := loadUint(r); hdr {
case typeSpecTypeID:
return TypeID(loadUint(r))
case typeSpecPointer:
return &TypeSpecPointer{
Type: loadTypeSpec(r),
}
case typeSpecArray:
return &TypeSpecArray{
Count: loadUint(r),
Type: loadTypeSpec(r),
}
case typeSpecSlice:
return &TypeSpecSlice{
Type: loadTypeSpec(r),
}
case typeSpecMap:
return &TypeSpecMap{
Key: loadTypeSpec(r),
Value: loadTypeSpec(r),
}
case typeSpecNil:
return TypeSpecNil{}
default:
// This is not a valid stream?
panic(fmt.Errorf("unknown header: %d", hdr))
}
}
// saveTypeSpec saves TypeSpec values.
func saveTypeSpec(w io.Writer, t TypeSpec) {
switch x := t.(type) {
case TypeID:
typeSpecTypeID.save(w)
Uint(x).save(w)
case *TypeSpecPointer:
typeSpecPointer.save(w)
saveTypeSpec(w, x.Type)
case *TypeSpecArray:
typeSpecArray.save(w)
x.Count.save(w)
saveTypeSpec(w, x.Type)
case *TypeSpecSlice:
typeSpecSlice.save(w)
saveTypeSpec(w, x.Type)
case *TypeSpecMap:
typeSpecMap.save(w)
saveTypeSpec(w, x.Key)
saveTypeSpec(w, x.Value)
case TypeSpecNil:
typeSpecNil.save(w)
default:
// This should not happen?
panic(fmt.Errorf("unknown type %T", t))
}
}
// Interface is an interface value.
type Interface struct {
Type TypeSpec
Value Object
}
// loadInterface loads an object of type Interface.
func loadInterface(r io.Reader) Interface {
return Interface{
Type: loadTypeSpec(r),
Value: Load(r),
}
}
// save implements Object.save.
func (i *Interface) save(w io.Writer) {
saveTypeSpec(w, i.Type)
Save(w, i.Value)
}
// load implements Object.load.
func (*Interface) load(r io.Reader) Object {
i := loadInterface(r)
return &i
}
// Type is type information.
type Type struct {
Name string
Fields []string
}
// loadType loads an object of type Type.
func loadType(r io.Reader) Type {
name := string(loadString(r))
l := loadUint(r)
fields := make([]string, l)
for i := 0; i < int(l); i++ {
fields[i] = string(loadString(r))
}
return Type{
Name: name,
Fields: fields,
}
}
// save implements Object.save.
func (t *Type) save(w io.Writer) {
s := String(t.Name)
s.save(w)
l := Uint(len(t.Fields))
l.save(w)
for i := 0; i < int(l); i++ {
s := String(t.Fields[i])
s.save(w)
}
}
// load implements Object.load.
func (*Type) load(r io.Reader) Object {
t := loadType(r)
return &t
}
// multipleObjects is a special type for serializing multiple objects.
type multipleObjects []Object
// loadMultipleObjects loads a series of objects.
func loadMultipleObjects(r io.Reader) multipleObjects {
l := loadUint(r)
m := make(multipleObjects, l)
for i := 0; i < int(l); i++ {
m[i] = Load(r)
}
return m
}
// save implements Object.save.
func (m *multipleObjects) save(w io.Writer) {
l := Uint(len(*m))
l.save(w)
for i := 0; i < int(l); i++ {
Save(w, (*m)[i])
}
}
// load implements Object.load.
func (*multipleObjects) load(r io.Reader) Object {
m := loadMultipleObjects(r)
return &m
}
// noObjects represents no objects.
type noObjects struct{}
// loadNoObjects loads a sentinel.
func loadNoObjects(r io.Reader) noObjects { return noObjects{} }
// save implements Object.save.
func (noObjects) save(w io.Writer) {}
// load implements Object.load.
func (noObjects) load(r io.Reader) Object { return loadNoObjects(r) }
// Struct is a basic composite value.
type Struct struct {
TypeID TypeID
fields Object // Optionally noObjects or *multipleObjects.
}
// Field returns a pointer to the given field slot.
//
// This must be called after Alloc.
func (s *Struct) Field(i int) *Object {
if fields, ok := s.fields.(*multipleObjects); ok {
return &((*fields)[i])
}
if _, ok := s.fields.(noObjects); ok {
// Alloc may be optionally called; can't call twice.
panic("Field called inappropriately, wrong Alloc?")
}
return &s.fields
}
// Alloc allocates the given number of fields.
//
// This must be called before Add and Save.
//
// Precondition: slots must be positive.
func (s *Struct) Alloc(slots int) {
switch {
case slots == 0:
s.fields = noObjects{}
case slots == 1:
// Leave it alone.
case slots > 1:
fields := make(multipleObjects, slots)
s.fields = &fields
default:
// Violates precondition.
panic(fmt.Sprintf("Alloc called with negative slots %d?", slots))
}
}
// Fields returns the number of fields.
func (s *Struct) Fields() int {
switch x := s.fields.(type) {
case *multipleObjects:
return len(*x)
case noObjects:
return 0
default:
return 1
}
}
// loadStruct loads an object of type Struct.
func loadStruct(r io.Reader) Struct {
return Struct{
TypeID: TypeID(loadUint(r)),
fields: Load(r),
}
}
// save implements Object.save.
//
// Precondition: Alloc must have been called, and the fields all filled in
// appropriately. See Alloc and Add for more details.
func (s *Struct) save(w io.Writer) {
Uint(s.TypeID).save(w)
Save(w, s.fields)
}
// load implements Object.load.
func (*Struct) load(r io.Reader) Object {
s := loadStruct(r)
return &s
}
// Object types.
//
// N.B. Be careful about changing the order or introducing new elements in the
// middle here. This is part of the wire format and shouldn't change.
const (
typeBool Uint = iota
typeInt
typeUint
typeFloat32
typeFloat64
typeNil
typeRef
typeString
typeSlice
typeArray
typeMap
typeStruct
typeNoObjects
typeMultipleObjects
typeInterface
typeComplex64
typeComplex128
typeType
)
// Save saves the given object.
//
// +checkescape all
//
// N.B. This function will panic on error.
func Save(w io.Writer, obj Object) {
switch x := obj.(type) {
case Bool:
typeBool.save(w)
x.save(w)
case Int:
typeInt.save(w)
x.save(w)
case Uint:
typeUint.save(w)
x.save(w)
case Float32:
typeFloat32.save(w)
x.save(w)
case Float64:
typeFloat64.save(w)
x.save(w)
case Nil:
typeNil.save(w)
x.save(w)
case *Ref:
typeRef.save(w)
x.save(w)
case *String:
typeString.save(w)
x.save(w)
case *Slice:
typeSlice.save(w)
x.save(w)
case *Array:
typeArray.save(w)
x.save(w)
case *Map:
typeMap.save(w)
x.save(w)
case *Struct:
typeStruct.save(w)
x.save(w)
case noObjects:
typeNoObjects.save(w)
x.save(w)
case *multipleObjects:
typeMultipleObjects.save(w)
x.save(w)
case *Interface:
typeInterface.save(w)
x.save(w)
case *Type:
typeType.save(w)
x.save(w)
case *Complex64:
typeComplex64.save(w)
x.save(w)
case *Complex128:
typeComplex128.save(w)
x.save(w)
default:
panic(fmt.Errorf("unknown type: %#v", obj))
}
}
// Load loads a new object.
//
// +checkescape all
//
// N.B. This function will panic on error.
func Load(r io.Reader) Object {
switch hdr := loadUint(r); hdr {
case typeBool:
return loadBool(r)
case typeInt:
return loadInt(r)
case typeUint:
return loadUint(r)
case typeFloat32:
return loadFloat32(r)
case typeFloat64:
return loadFloat64(r)
case typeNil:
return loadNil(r)
case typeRef:
return ((*Ref)(nil)).load(r) // Escapes.
case typeString:
return ((*String)(nil)).load(r) // Escapes.
case typeSlice:
return ((*Slice)(nil)).load(r) // Escapes.
case typeArray:
return ((*Array)(nil)).load(r) // Escapes.
case typeMap:
return ((*Map)(nil)).load(r) // Escapes.
case typeStruct:
return ((*Struct)(nil)).load(r) // Escapes.
case typeNoObjects: // Special for struct.
return loadNoObjects(r)
case typeMultipleObjects: // Special for struct.
return ((*multipleObjects)(nil)).load(r) // Escapes.
case typeInterface:
return ((*Interface)(nil)).load(r) // Escapes.
case typeComplex64:
return ((*Complex64)(nil)).load(r) // Escapes.
case typeComplex128:
return ((*Complex128)(nil)).load(r) // Escapes.
case typeType:
return ((*Type)(nil)).load(r) // Escapes.
default:
// This is not a valid stream?
panic(fmt.Errorf("unknown header: %d", hdr))
}
}
// LoadUint loads a single unsigned integer.
//
// N.B. This function will panic on error.
func LoadUint(r io.Reader) uint64 {
return uint64(loadUint(r))
}
// SaveUint saves a single unsigned integer.
//
// N.B. This function will panic on error.
func SaveUint(w io.Writer, v uint64) {
Uint(v).save(w)
}