Update dependencies

This commit is contained in:
bluepython508
2024-11-01 17:33:34 +00:00
parent 033ac0b400
commit 5cdfab398d
3596 changed files with 1033483 additions and 259 deletions

View File

@@ -0,0 +1,713 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package gonet provides a Go net package compatible wrapper for a tcpip stack.
package gonet
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
// errCanceled is returned by TCPListener.Accept when the listener's cancel
// channel is closed while Accept is waiting for a connection.
var errCanceled = errors.New("operation canceled")

// errWouldBlock describes an operation that cannot complete without blocking.
// It is not referenced in the visible part of this file.
var errWouldBlock = errors.New("operation would block")
// timeoutError is the error value this package uses to report an expired
// deadline, mirroring how the net package reports timeouts.
type timeoutError struct{}

// Error returns the fixed message "i/o timeout".
func (*timeoutError) Error() string {
	return "i/o timeout"
}

// Timeout always reports true.
func (*timeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (*timeoutError) Temporary() bool {
	return true
}
// TCPListener wraps a listening TCP tcpip.Endpoint so that it can be used as
// a net.Listener.
type TCPListener struct {
	// stack is the network stack handed to NewTCPListener.
	stack *stack.Stack

	// ep is the listening endpoint that Accept, Close, Shutdown and Addr
	// operate on.
	ep tcpip.Endpoint

	// wq is the waiter queue on which Accept registers for readable events.
	wq *waiter.Queue

	// cancelOnce guarantees that cancel is closed at most once.
	cancelOnce sync.Once

	// cancel is closed by Shutdown to wake up any blocked Accept call.
	cancel chan struct{}
}
// NewTCPListener wraps an already listening tcpip.Endpoint, together with the
// stack it belongs to and its waiter queue, in a new TCPListener. The
// listener's cancel channel is created here.
func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener {
	l := new(TCPListener)
	l.stack = s
	l.ep = ep
	l.wq = wq
	l.cancel = make(chan struct{})
	return l
}
// maxListenBacklog is the backlog passed to Endpoint.Listen by ListenTCP. It
// is set to be reasonably high for most uses of gonet. The Go net package
// uses the value in the /proc/sys/net/core/somaxconn file on Linux as the
// default listen backlog. The value below matches the default in common Linux
// distros.
//
// See: https://cs.opensource.google/go/go/+/refs/tags/go1.18.1:src/net/sock_linux.go;drc=refs%2Ftags%2Fgo1.18.1;l=66
const maxListenBacklog = 4096
// ListenTCP creates a TCP endpoint on s for the given network protocol, binds
// it to addr, starts listening with a backlog of maxListenBacklog and returns
// it wrapped in a TCPListener.
//
// A failure to create the endpoint is returned as a plain error. A bind or
// listen failure closes the endpoint and is returned as a *net.OpError.
func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) {
	var wq waiter.Queue
	ep, newErr := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
	if newErr != nil {
		return nil, errors.New(newErr.String())
	}

	// fail releases the endpoint and wraps cause in a *net.OpError for op.
	fail := func(op string, cause tcpip.Error) (*TCPListener, error) {
		ep.Close()
		return nil, &net.OpError{
			Op:   op,
			Net:  "tcp",
			Addr: fullToTCPAddr(addr),
			Err:  errors.New(cause.String()),
		}
	}

	if bindErr := ep.Bind(addr); bindErr != nil {
		return fail("bind", bindErr)
	}
	if listenErr := ep.Listen(maxListenBacklog); listenErr != nil {
		return fail("listen", listenErr)
	}
	return NewTCPListener(s, &wq, ep), nil
}
// Close implements net.Listener.Close. It closes the listening endpoint and
// always returns nil.
func (l *TCPListener) Close() error {
	listening := l.ep
	listening.Close()
	return nil
}
// Shutdown shuts down both directions of the listening endpoint and closes
// the listener's cancel channel, which makes a blocked Accept return
// errCanceled. Calling Shutdown more than once is safe: the channel is closed
// only on the first call.
func (l *TCPListener) Shutdown() {
	l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)

	// Closing the channel broadcasts the cancellation to every waiter.
	broadcast := func() { close(l.cancel) }
	l.cancelOnce.Do(broadcast)
}
// Addr implements net.Listener.Addr. It returns the endpoint's local address
// as a *net.TCPAddr, or nil when the local address cannot be obtained.
func (l *TCPListener) Addr() net.Addr {
	if local, err := l.ep.GetLocalAddress(); err == nil {
		return fullToTCPAddr(local)
	}
	return nil
}
// deadlineTimer holds the read and write deadline state shared by TCPConn and
// UDPConn. Each direction has a cancel channel that is closed once its
// deadline has passed, and a timer that closes the channel when a future
// deadline expires.
type deadlineTimer struct {
	// mu protects the fields below.
	mu sync.Mutex

	readTimer    *time.Timer
	readCancelCh chan struct{}

	writeTimer    *time.Timer
	writeCancelCh chan struct{}
}

// init allocates the read and write cancel channels. It must be called before
// the deadlineTimer is used.
func (d *deadlineTimer) init() {
	d.readCancelCh, d.writeCancelCh = make(chan struct{}), make(chan struct{})
}

// readCancel returns the current read cancel channel, which is closed once
// the read deadline has passed.
func (d *deadlineTimer) readCancel() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.readCancelCh
}

// writeCancel returns the current write cancel channel, which is closed once
// the write deadline has passed.
func (d *deadlineTimer) writeCancel() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.writeCancelCh
}

// setDeadline contains the shared logic for setting a deadline.
//
// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and
// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and
// deadlineTimer.writeTimer.
//
// setDeadline must only be called while holding d.mu.
func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
	// A previous timer that could not be stopped has fired (or is firing) and
	// owns the old channel, so operations from now on need a fresh one.
	if prev := *timer; prev != nil {
		if stopped := prev.Stop(); !stopped {
			*cancelCh = make(chan struct{})
		}
	}

	// The channel may also have been closed directly by an earlier call with
	// an already expired time; replace it in that case. This cannot race with
	// the timer, which was dealt with above.
	select {
	case <-*cancelCh:
		*cancelCh = make(chan struct{})
	default:
	}

	// "A zero value for t means I/O operations will not time out."
	// - net.Conn.SetDeadline
	if t.IsZero() {
		*timer = nil
		return
	}

	remaining := time.Until(t)
	if remaining <= 0 {
		// Already expired: signal cancellation immediately.
		close(*cancelCh)
		return
	}

	// Timer.Stop reports whether the AfterFunc has started, not whether it
	// has completed. Capture the channel now so that the callback cannot race
	// with a later setDeadline call replacing *cancelCh.
	expiring := *cancelCh
	*timer = time.AfterFunc(remaining, func() { close(expiring) })
}

// SetReadDeadline implements net.Conn.SetReadDeadline and
// net.PacketConn.SetReadDeadline. It always returns nil.
func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.setDeadline(&d.readCancelCh, &d.readTimer, t)
	return nil
}

// SetWriteDeadline implements net.Conn.SetWriteDeadline and
// net.PacketConn.SetWriteDeadline. It always returns nil.
func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
	return nil
}

// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline.
// It applies t to the read side and then to the write side, and always
// returns nil.
func (d *deadlineTimer) SetDeadline(t time.Time) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.setDeadline(&d.readCancelCh, &d.readTimer, t)
	d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
	return nil
}
// TCPConn wraps a TCP tcpip.Endpoint so that it satisfies the net.Conn
// interface. Deadline handling comes from the embedded deadlineTimer.
type TCPConn struct {
	deadlineTimer

	// wq is the waiter queue used to wait for the endpoint to become writable
	// or readable.
	wq *waiter.Queue

	// ep is the wrapped endpoint.
	ep tcpip.Endpoint

	// readMu serializes reads and implicitly protects read.
	//
	// Lock ordering:
	//   If both readMu and deadlineTimer.mu are to be used in a single
	//   request, readMu must be acquired before deadlineTimer.mu.
	readMu sync.Mutex

	// read contains bytes that have been read from the endpoint,
	// but haven't yet been returned.
	read []byte
}
// NewTCPConn wraps ep and its waiter queue wq in a new TCPConn and
// initializes the connection's deadline channels.
func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
	conn := new(TCPConn)
	conn.wq, conn.ep = wq, ep
	conn.deadlineTimer.init()
	return conn
}
// Accept implements net.Listener.Accept.
//
// It blocks until the listening endpoint produces a connection, the endpoint
// reports an error, or Shutdown closes the cancel channel. Cancellation
// yields errCanceled; any other failure is returned as a *net.OpError.
func (l *TCPListener) Accept() (net.Conn, error) {
	accepted, connWQ, err := l.ep.Accept(nil)

	if _, blocked := err.(*tcpip.ErrWouldBlock); blocked {
		// Register for readable events before retrying, so that a connection
		// arriving between the retry and the wait is not missed.
		entry, readable := waiter.NewChannelEntry(waiter.ReadableEvents)
		l.wq.EventRegister(&entry)
		defer l.wq.EventUnregister(&entry)

		for {
			accepted, connWQ, err = l.ep.Accept(nil)
			if _, stillBlocked := err.(*tcpip.ErrWouldBlock); !stillBlocked {
				break
			}

			select {
			case <-l.cancel:
				return nil, errCanceled
			case <-readable:
			}
		}
	}

	if err == nil {
		return NewTCPConn(connWQ, accepted), nil
	}
	return nil, &net.OpError{
		Op:   "accept",
		Net:  "tcp",
		Addr: l.Addr(),
		Err:  errors.New(err.String()),
	}
}
// opErrorer is implemented by connection types that can wrap an error in a
// *net.OpError describing the failed operation. commonRead uses it to build
// its errors.
type opErrorer interface {
	// newOpError returns a *net.OpError for operation op caused by cause.
	newOpError(op string, cause error) *net.OpError
}
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
//
// It reads from ep into b, waiting on wq while the endpoint would block. When
// addr is non-nil the remote address of the data is requested and stored in
// *addr. The read fails with a timeout error, built by errorer, as soon as
// deadline is closed. An endpoint that is closed for receive yields io.EOF.
// On success the number of bytes read is returned.
func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) {
	// Fail fast when the deadline has already passed.
	select {
	case <-deadline:
		return 0, errorer.newOpError("read", &timeoutError{})
	default:
	}

	dst := tcpip.SliceWriter(b)
	readOpts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}

	result, err := ep.Read(&dst, readOpts)
	if _, blocked := err.(*tcpip.ErrWouldBlock); blocked {
		// Register for readable events, then retry until the endpoint stops
		// reporting that it would block.
		entry, readable := waiter.NewChannelEntry(waiter.ReadableEvents)
		wq.EventRegister(&entry)
		defer wq.EventUnregister(&entry)

		for {
			result, err = ep.Read(&dst, readOpts)
			if _, stillBlocked := err.(*tcpip.ErrWouldBlock); !stillBlocked {
				break
			}

			select {
			case <-deadline:
				return 0, errorer.newOpError("read", &timeoutError{})
			case <-readable:
			}
		}
	}

	switch err.(type) {
	case nil:
		if addr != nil {
			*addr = result.RemoteAddr
		}
		return result.Count, nil
	case *tcpip.ErrClosedForReceive:
		return 0, io.EOF
	default:
		return 0, errorer.newOpError("read", errors.New(err.String()))
	}
}
// Read implements net.Conn.Read. Reads are serialized by readMu. After a read
// that returned data, the endpoint is told how many bytes were consumed via
// ModerateRecvBuf.
func (c *TCPConn) Read(b []byte) (int, error) {
	c.readMu.Lock()
	defer c.readMu.Unlock()

	count, err := commonRead(b, c.ep, c.wq, c.readCancel(), nil, c)
	if count != 0 {
		c.ep.ModerateRecvBuf(count)
	}
	return count, err
}
// Write implements net.Conn.Write.
//
// It keeps writing until all of b has been handed to the endpoint, the write
// deadline expires, or the endpoint reports an error other than
// *tcpip.ErrWouldBlock. The number of bytes written so far is returned in
// every case.
func (c *TCPConn) Write(b []byte) (int, error) {
	deadline := c.writeCancel()

	// Fail fast when the deadline has already expired.
	select {
	case <-deadline:
		return 0, c.newOpError("write", &timeoutError{})
	default:
	}

	// Two soft failure conditions have to be handled, and they may be
	// interleaved in any order:
	//  1. Write may write nothing and return *tcpip.ErrWouldBlock. We then
	//     register for notifications (once) and wait before trying again.
	//  2. Write may write fewer than the requested bytes without an error. We
	//     then simply try again with the remainder; no notification is
	//     needed.
	var (
		reader   bytes.Reader
		written  int
		entry    waiter.Entry
		writable <-chan struct{}
	)
	for written != len(b) {
		reader.Reset(b[written:])
		n, err := c.ep.Write(&reader, tcpip.WriteOptions{})
		written += int(n)

		if err == nil {
			continue
		}
		if _, blocked := err.(*tcpip.ErrWouldBlock); !blocked {
			return written, c.newOpError("write", errors.New(err.String()))
		}

		if writable == nil {
			// Don't wait immediately after registration in case more buffer
			// space became available between the last attempt and the setup
			// of the notification; retry first.
			entry, writable = waiter.NewChannelEntry(waiter.WritableEvents)
			c.wq.EventRegister(&entry)
			defer c.wq.EventUnregister(&entry)
			continue
		}

		select {
		case <-deadline:
			return written, c.newOpError("write", &timeoutError{})
		case <-writable:
		}
	}
	return written, nil
}
// Close implements net.Conn.Close. It closes the wrapped endpoint and always
// returns nil.
func (c *TCPConn) Close() error {
	endpoint := c.ep
	endpoint.Close()
	return nil
}
// CloseRead shuts down the reading side of the TCP connection. Most callers
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
//
// A shutdown failure is returned as a *net.OpError with Op "close".
func (c *TCPConn) CloseRead() error {
	terr := c.ep.Shutdown(tcpip.ShutdownRead)
	if terr == nil {
		return nil
	}
	return c.newOpError("close", errors.New(terr.String()))
}
// CloseWrite shuts down the writing side of the TCP connection. Most callers
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
//
// A shutdown failure is returned as a *net.OpError with Op "close".
func (c *TCPConn) CloseWrite() error {
	terr := c.ep.Shutdown(tcpip.ShutdownWrite)
	if terr == nil {
		return nil
	}
	return c.newOpError("close", errors.New(terr.String()))
}
// LocalAddr implements net.Conn.LocalAddr. It returns nil when the endpoint
// cannot report its local address.
func (c *TCPConn) LocalAddr() net.Addr {
	if local, err := c.ep.GetLocalAddress(); err == nil {
		return fullToTCPAddr(local)
	}
	return nil
}
// RemoteAddr implements net.Conn.RemoteAddr. It returns nil when the endpoint
// cannot report its remote address.
func (c *TCPConn) RemoteAddr() net.Addr {
	if remote, err := c.ep.GetRemoteAddress(); err == nil {
		return fullToTCPAddr(remote)
	}
	return nil
}
// newOpError wraps err in a *net.OpError for operation op on network "tcp",
// filling in the connection's local address as Source and its remote address
// as Addr.
func (c *TCPConn) newOpError(op string, err error) *net.OpError {
	local, remote := c.LocalAddr(), c.RemoteAddr()
	return &net.OpError{Op: op, Net: "tcp", Source: local, Addr: remote, Err: err}
}
// fullToTCPAddr converts a tcpip.FullAddress into a *net.TCPAddr carrying the
// same IP bytes and port.
func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr {
	tcpAddr := new(net.TCPAddr)
	tcpAddr.IP = net.IP(addr.Addr.AsSlice())
	tcpAddr.Port = int(addr.Port)
	return tcpAddr
}
// fullToUDPAddr converts a tcpip.FullAddress into a *net.UDPAddr carrying the
// same IP bytes and port.
func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
	udpAddr := new(net.UDPAddr)
	udpAddr.IP = net.IP(addr.Addr.AsSlice())
	udpAddr.Port = int(addr.Port)
	return udpAddr
}
// DialTCP creates a new TCPConn connected to the specified address. It is
// DialContextTCP with a background context, so the dial cannot be cancelled.
func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
	ctx := context.Background()
	return DialContextTCP(ctx, s, addr, network)
}
// DialTCPWithBind creates a new TCPConn connected to the specified
// remoteAddr with its local address bound to localAddr.
//
// A zero localAddr skips the bind step. The dial is abandoned, and ctx.Err()
// returned, if ctx is done before the connection attempt starts or while it
// is in progress. A connect failure is returned as a *net.OpError.
//
// The endpoint created for the dial is closed on every failure path, so an
// unsuccessful call leaves no endpoint behind.
func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
	// Create TCP endpoint, then connect.
	var wq waiter.Queue
	ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
	if err != nil {
		return nil, errors.New(err.String())
	}

	// Create wait queue entry that notifies a channel.
	//
	// We do this unconditionally as Connect will always return an error.
	waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
	wq.EventRegister(&waitEntry)
	defer wq.EventUnregister(&waitEntry)

	select {
	case <-ctx.Done():
		// Release the endpoint that was just created; nothing else owns it.
		ep.Close()
		return nil, ctx.Err()
	default:
	}

	// Bind before connect if requested.
	if localAddr != (tcpip.FullAddress{}) {
		if err = ep.Bind(localAddr); err != nil {
			ep.Close()
			return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
		}
	}

	err = ep.Connect(remoteAddr)
	if _, ok := err.(*tcpip.ErrConnectStarted); ok {
		// The connection is being established asynchronously; wait for the
		// endpoint to signal the outcome, or for the caller to give up.
		select {
		case <-ctx.Done():
			ep.Close()
			return nil, ctx.Err()
		case <-notifyCh:
		}

		err = ep.LastError()
	}
	if err != nil {
		ep.Close()
		return nil, &net.OpError{
			Op:   "connect",
			Net:  "tcp",
			Addr: fullToTCPAddr(remoteAddr),
			Err:  errors.New(err.String()),
		}
	}

	return NewTCPConn(&wq, ep), nil
}
// DialContextTCP creates a new TCPConn connected to the specified address
// with the option of adding cancellation and timeouts through ctx. No local
// address is bound.
func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
	// The zero FullAddress tells DialTCPWithBind to skip the bind step.
	var unbound tcpip.FullAddress
	return DialTCPWithBind(ctx, s, unbound, addr, network)
}
// UDPConn wraps a UDP tcpip.Endpoint so that it satisfies both net.Conn and
// net.PacketConn. Deadline handling comes from the embedded deadlineTimer.
type UDPConn struct {
	deadlineTimer

	// ep is the wrapped endpoint.
	ep tcpip.Endpoint

	// wq is the waiter queue used to wait for the endpoint to become readable
	// or writable.
	wq *waiter.Queue
}
// NewUDPConn wraps ep and its waiter queue wq in a new UDPConn and
// initializes the connection's deadline channels.
func NewUDPConn(wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn {
	conn := new(UDPConn)
	conn.ep, conn.wq = ep, wq
	conn.deadlineTimer.init()
	return conn
}
// DialUDP creates a new UDPConn.
//
// If laddr is nil, a local address is automatically chosen.
//
// If raddr is nil, the UDPConn is left unconnected.
//
// A failure to create the endpoint is returned as a plain error. A bind or
// connect failure closes the endpoint and is returned as a *net.OpError.
func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) {
	var wq waiter.Queue
	ep, newErr := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
	if newErr != nil {
		return nil, errors.New(newErr.String())
	}

	// fail releases the endpoint and wraps cause in a *net.OpError for op on
	// the address at.
	fail := func(op string, at *tcpip.FullAddress, cause tcpip.Error) (*UDPConn, error) {
		ep.Close()
		return nil, &net.OpError{
			Op:   op,
			Net:  "udp",
			Addr: fullToUDPAddr(*at),
			Err:  errors.New(cause.String()),
		}
	}

	if laddr != nil {
		if bindErr := ep.Bind(*laddr); bindErr != nil {
			return fail("bind", laddr, bindErr)
		}
	}

	conn := NewUDPConn(&wq, ep)

	if raddr != nil {
		if connectErr := conn.ep.Connect(*raddr); connectErr != nil {
			return fail("connect", raddr, connectErr)
		}
	}
	return conn, nil
}
// newOpError wraps err in a *net.OpError for operation op with no remote
// address. It makes UDPConn usable as the errorer argument of commonRead.
func (c *UDPConn) newOpError(op string, err error) *net.OpError {
	var noRemote net.Addr
	return c.newRemoteOpError(op, noRemote, err)
}
// newRemoteOpError wraps err in a *net.OpError for operation op on network
// "udp", with the connection's local address as Source and remote as Addr.
func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
	opErr := &net.OpError{Op: op, Net: "udp", Err: err}
	opErr.Source = c.LocalAddr()
	opErr.Addr = remote
	return opErr
}
// RemoteAddr implements net.Conn.RemoteAddr. It returns nil when the endpoint
// cannot report a remote address.
func (c *UDPConn) RemoteAddr() net.Addr {
	if remote, err := c.ep.GetRemoteAddress(); err == nil {
		return fullToUDPAddr(remote)
	}
	return nil
}
// Read implements net.Conn.Read. It is ReadFrom with the sender's address
// discarded.
func (c *UDPConn) Read(b []byte) (int, error) {
	n, _, err := c.ReadFrom(b)
	return n, err
}
// ReadFrom implements net.PacketConn.ReadFrom. On success it returns the
// number of bytes read and the sender's address as a *net.UDPAddr; on failure
// it returns 0, a nil address and the error.
func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
	var sender tcpip.FullAddress
	count, err := commonRead(b, c.ep, c.wq, c.readCancel(), &sender, c)
	if err == nil {
		return count, fullToUDPAddr(sender), nil
	}
	return 0, nil, err
}
// Write implements net.Conn.Write. It is WriteTo without a destination
// address.
func (c *UDPConn) Write(b []byte) (int, error) {
	var noAddr net.Addr
	return c.WriteTo(b, noAddr)
}
// WriteTo implements net.PacketConn.WriteTo.
//
// When addr is nil (as when called by Write) no destination is passed to the
// endpoint. Otherwise addr must be a non-nil *net.UDPAddr; any other value is
// rejected with a *net.OpError instead of panicking.
//
// The write blocks while the endpoint would block, until the write deadline
// expires. It returns the number of bytes reported by the endpoint and, on
// failure, a *net.OpError.
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
	deadline := c.writeCancel()

	// Check if deadline has already expired.
	select {
	case <-deadline:
		return 0, c.newRemoteOpError("write", addr, &timeoutError{})
	default:
	}

	// If we're being called by Write, there is no addr.
	writeOptions := tcpip.WriteOptions{}
	if addr != nil {
		ua, ok := addr.(*net.UDPAddr)
		if !ok {
			return 0, c.newRemoteOpError("write", addr, fmt.Errorf("unsupported address type %T", addr))
		}
		if ua == nil {
			// A typed nil pointer inside a non-nil interface value.
			return 0, c.newRemoteOpError("write", addr, errors.New("missing address"))
		}
		writeOptions.To = &tcpip.FullAddress{
			Addr: tcpip.AddrFromSlice(ua.IP),
			Port: uint16(ua.Port),
		}
	}

	var r bytes.Reader
	r.Reset(b)
	n, err := c.ep.Write(&r, writeOptions)
	if _, ok := err.(*tcpip.ErrWouldBlock); ok {
		// Create wait queue entry that notifies a channel.
		waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
		c.wq.EventRegister(&waitEntry)
		defer c.wq.EventUnregister(&waitEntry)
		for {
			select {
			case <-deadline:
				return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
			case <-notifyCh:
			}

			n, err = c.ep.Write(&r, writeOptions)
			if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
				break
			}
		}
	}

	if err == nil {
		return int(n), nil
	}

	return int(n), c.newRemoteOpError("write", addr, errors.New(err.String()))
}
// Close implements net.PacketConn.Close. It closes the wrapped endpoint and
// always returns nil.
func (c *UDPConn) Close() error {
	endpoint := c.ep
	endpoint.Close()
	return nil
}
// LocalAddr implements net.PacketConn.LocalAddr. It returns nil when the
// endpoint cannot report its local address.
func (c *UDPConn) LocalAddr() net.Addr {
	if local, err := c.ep.GetLocalAddress(); err == nil {
		return fullToUDPAddr(local)
	}
	return nil
}

View File

@@ -0,0 +1,3 @@
// automatically generated by stateify.
package gonet

View File

@@ -0,0 +1,68 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package checksum provides the implementation of the checksum defined in
// RFC 1071, along with helpers to store and combine checksums.
package checksum
import (
"encoding/binary"
)
// Size is the size of a checksum, in bytes.
//
// The checksum is held in a uint16, which occupies 2 bytes.
const Size = 2
// Put stores the checksum xsum in the first two bytes of b in big-endian
// byte order. It panics if b is shorter than two bytes.
func Put(b []byte, xsum uint16) {
	order := binary.BigEndian
	order.PutUint16(b, xsum)
}
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in
// buf, continuing from the partial checksum initial. It uses an optimized
// implementation of the checksum algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
	// The returned odd-length flag is not needed for a one-shot checksum.
	xsum, _ := calculateChecksum(buf, false /* odd */, initial)
	return xsum
}
// Checksumer incrementally calculates the checksum defined in RFC 1071 over
// data supplied in several pieces.
type Checksumer struct {
	// sum is the checksum of all bytes added so far.
	sum uint16
	// odd records whether an odd number of bytes has been added so far.
	odd bool
}

// Add folds the bytes of b into the running checksum. An empty b leaves the
// state untouched.
func (c *Checksumer) Add(b []byte) {
	if len(b) == 0 {
		return
	}
	c.sum, c.odd = calculateChecksum(b, c.odd, c.sum)
}

// Checksum returns the checksum of everything added so far.
func (c *Checksumer) Checksum() uint16 {
	latest := c.sum
	return latest
}
// Combine merges the two checksums a and b into one by adding them and then
// adding the carry out of the low 16 bits back into the result.
//
// Note that checksum a must have been computed on an even number of bytes.
func Combine(a, b uint16) uint16 {
	sum := uint32(a) + uint32(b)
	carry := sum >> 16
	return uint16(sum + carry)
}

View File

@@ -0,0 +1,3 @@
// automatically generated by stateify.
package checksum

View File

@@ -0,0 +1,182 @@
// Copyright 2023 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checksum
import (
"encoding/binary"
"math/bits"
"unsafe"
)
// calculateChecksum folds the bytes of buf into the partial checksum initial
// and returns the resulting 16-bit checksum, together with a flag reporting
// whether the total number of bytes summed so far is odd.
//
// Note: odd indicates whether initial is a partial checksum over an odd number
// of bytes.
//
// When odd is true, buf must not be empty: its first byte is read
// unconditionally to complete the 16-bit word left unfinished by the previous
// call.
func calculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) {
// Use a larger-than-uint16 accumulator to benefit from parallel summation
// as described in RFC 1071 1.2.C.
acc := uint64(initial)
// Handle an odd number of previously-summed bytes, and get the return
// value for odd.
if odd {
// The first byte is added without a shift: it is the low byte of the
// word whose high byte ended the previously-summed data.
acc += uint64(buf[0])
buf = buf[1:]
}
odd = len(buf)&1 != 0
// Aligning &buf[0] below is much simpler if len(buf) >= 8; special-case
// smaller bufs.
if len(buf) < 8 {
// Sum the 0-7 bytes as big-endian 16-bit words; a trailing single byte
// is the high byte of its word.
if len(buf) >= 4 {
acc += (uint64(buf[0]) << 8) + uint64(buf[1])
acc += (uint64(buf[2]) << 8) + uint64(buf[3])
buf = buf[4:]
}
if len(buf) >= 2 {
acc += (uint64(buf[0]) << 8) + uint64(buf[1])
buf = buf[2:]
}
if len(buf) >= 1 {
acc += uint64(buf[0]) << 8
// buf = buf[1:] is skipped because it's unused and nogo will
// complain.
}
return reduce(acc), odd
}
// On little-endian architectures, multi-byte loads from buf will load
// bytes in the wrong order. Rather than byte-swap after each load (slow),
// we byte-swap the accumulator before summing any bytes and byte-swap it
// back before returning, which still produces the correct result as
// described in RFC 1071 1.2.B "Byte Order Independence".
//
// acc is at most a uint16 + a uint8, so its upper 32 bits must be 0s. We
// preserve this property by byte-swapping only the lower 32 bits of acc,
// so that additions to acc performed during alignment can't overflow.
acc = uint64(bswapIfLittleEndian32(uint32(acc)))
// Align &buf[0] to an 8-byte boundary.
bswapped := false
if sliceAddr(buf)&1 != 0 {
// Compute the rest of the partial checksum with bytes swapped, and
// swap back before returning; see the last paragraph of
// RFC 1071 1.2.B.
acc = uint64(bits.ReverseBytes32(uint32(acc)))
bswapped = true
// No `<< 8` here due to the byte swap we just did.
acc += uint64(bswapIfLittleEndian16(uint16(buf[0])))
buf = buf[1:]
}
// buf now starts on an even address; consume 2 and then 4 bytes as needed
// to reach an address that is a multiple of 8.
if sliceAddr(buf)&2 != 0 {
acc += uint64(*(*uint16)(unsafe.Pointer(&buf[0])))
buf = buf[2:]
}
if sliceAddr(buf)&4 != 0 {
acc += uint64(*(*uint32)(unsafe.Pointer(&buf[0])))
buf = buf[4:]
}
// Sum 64 bytes at a time. Beyond this point, additions to acc may
// overflow, so we have to handle carrying.
for len(buf) >= 64 {
var carry uint64
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[32])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[40])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[48])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[56])), carry)
// Fold the final carry back into acc (end-around carry).
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[64:]
}
// Sum the remaining 0-63 bytes.
if len(buf) >= 32 {
var carry uint64
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[32:]
}
if len(buf) >= 16 {
var carry uint64
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[16:]
}
if len(buf) >= 8 {
var carry uint64
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[8:]
}
if len(buf) >= 4 {
var carry uint64
acc, carry = bits.Add64(acc, uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), 0)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[4:]
}
if len(buf) >= 2 {
var carry uint64
acc, carry = bits.Add64(acc, uint64(*(*uint16)(unsafe.Pointer(&buf[0]))), 0)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[2:]
}
if len(buf) >= 1 {
// bswapIfBigEndian16(buf[0]) == bswapIfLittleEndian16(buf[0]<<8).
var carry uint64
acc, carry = bits.Add64(acc, uint64(bswapIfBigEndian16(uint16(buf[0]))), 0)
acc, _ = bits.Add64(acc, 0, carry)
// buf = buf[1:] is skipped because it's unused and nogo will complain.
}
// Reduce the checksum to 16 bits and undo byte swaps before returning.
acc16 := bswapIfLittleEndian16(reduce(acc))
if bswapped {
acc16 = bits.ReverseBytes16(acc16)
}
return acc16, odd
}
// reduce folds the 64-bit accumulator acc down to 16 bits. At each width
// (64 to 32, then 32 to 16) the two halves are added and the carry out of the
// low half is added back in.
func reduce(acc uint64) uint16 {
	// Ideally we would do:
	//   return uint16(acc>>48) +' uint16(acc>>32) +' uint16(acc>>16) +' uint16(acc)
	// for more instruction-level parallelism; however, there is no
	// bits.Add16().
	upper, lower := acc>>32, acc&0xffff_ffff
	sum33 := upper + lower                 // at most 0x1_ffff_fffe
	folded32 := uint32(sum33 + sum33>>32)  // at most 0xffff_ffff
	high16, low16 := folded32>>16, folded32&0xffff
	sum17 := high16 + low16                // at most 0x1_fffe
	return uint16(sum17 + sum17>>16)       // at most 0xffff
}
// bswapIfLittleEndian32 reads the in-memory bytes of val as a big-endian
// uint32. On a little-endian machine this reverses the byte order of val; on
// a big-endian machine it returns val unchanged.
func bswapIfLittleEndian32(val uint32) uint32 {
	raw := (*[4]byte)(unsafe.Pointer(&val))
	return binary.BigEndian.Uint32(raw[:])
}
// bswapIfLittleEndian16 reads the in-memory bytes of val as a big-endian
// uint16. On a little-endian machine this swaps the two bytes of val; on a
// big-endian machine it returns val unchanged.
func bswapIfLittleEndian16(val uint16) uint16 {
	raw := (*[2]byte)(unsafe.Pointer(&val))
	return binary.BigEndian.Uint16(raw[:])
}
// bswapIfBigEndian16 reads the in-memory bytes of val as a little-endian
// uint16. On a big-endian machine this swaps the two bytes of val; on a
// little-endian machine it returns val unchanged.
func bswapIfBigEndian16(val uint16) uint16 {
	raw := (*[2]byte)(unsafe.Pointer(&val))
	return binary.LittleEndian.Uint16(raw[:])
}
// sliceAddr returns the address of buf's underlying array as an integer, for
// use in alignment checks.
func sliceAddr(buf []byte) uintptr {
	first := unsafe.SliceData(buf)
	return uintptr(unsafe.Pointer(first))
}

View File

@@ -0,0 +1,3 @@
// automatically generated by stateify.
package checksum

View File

@@ -0,0 +1,623 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcpip
import (
"fmt"
)
// Error represents an error in the netstack error space.
//
// The standard error interface is intentionally left out, so that these
// values are not passed around as error and lose their type information.
type Error interface {
	// String returns a description of the error.
	fmt.Stringer

	// IgnoreStats indicates whether this error should be included in failure
	// counts in tcpip.Stats structs.
	IgnoreStats() bool

	// isError is unexported, so only types declared in this package can
	// implement Error.
	isError()
}
// maxErrno is an untyped constant with the value 134.
// NOTE(review): presumably the largest errno value covered when netstack
// errors are translated to errno values; its users are outside this file, so
// confirm before relying on this.
const maxErrno = 134
// LINT.IfChange
// ErrAborted indicates the operation was aborted.
//
// +stateify savable
type ErrAborted struct{}

// isError marks *ErrAborted as an implementation of Error.
func (*ErrAborted) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrAborted) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrAborted) String() string { return "operation aborted" }
// ErrAddressFamilyNotSupported indicates the operation does not support the
// given address family.
//
// +stateify savable
type ErrAddressFamilyNotSupported struct{}

// isError marks *ErrAddressFamilyNotSupported as an implementation of Error.
func (*ErrAddressFamilyNotSupported) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrAddressFamilyNotSupported) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrAddressFamilyNotSupported) String() string {
	const description = "address family not supported by protocol"
	return description
}
// ErrAlreadyBound indicates the endpoint is already bound.
//
// +stateify savable
type ErrAlreadyBound struct{}

// isError marks *ErrAlreadyBound as an implementation of Error.
func (*ErrAlreadyBound) isError() {}

// IgnoreStats implements Error. It always returns true.
func (*ErrAlreadyBound) IgnoreStats() bool { return true }

// String returns the fixed description of the error.
func (*ErrAlreadyBound) String() string {
	return "endpoint already bound"
}
// ErrAlreadyConnected indicates the endpoint is already connected.
//
// +stateify savable
type ErrAlreadyConnected struct{}

// isError marks *ErrAlreadyConnected as an implementation of Error.
func (*ErrAlreadyConnected) isError() {}

// IgnoreStats implements Error. It always returns true.
func (*ErrAlreadyConnected) IgnoreStats() bool { return true }

// String returns the fixed description of the error.
func (*ErrAlreadyConnected) String() string {
	return "endpoint is already connected"
}
// ErrAlreadyConnecting indicates the endpoint is already connecting.
//
// +stateify savable
type ErrAlreadyConnecting struct{}

// isError marks *ErrAlreadyConnecting as an implementation of Error.
func (*ErrAlreadyConnecting) isError() {}

// IgnoreStats implements Error. It always returns true.
func (*ErrAlreadyConnecting) IgnoreStats() bool { return true }

// String returns the fixed description of the error.
func (*ErrAlreadyConnecting) String() string {
	return "endpoint is already connecting"
}
// ErrBadAddress indicates a bad address was provided.
//
// +stateify savable
type ErrBadAddress struct{}

// isError marks *ErrBadAddress as an implementation of Error.
func (*ErrBadAddress) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrBadAddress) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrBadAddress) String() string {
	return "bad address"
}
// ErrBadBuffer indicates a bad buffer was provided.
//
// +stateify savable
type ErrBadBuffer struct{}

// isError marks *ErrBadBuffer as an implementation of Error.
func (*ErrBadBuffer) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrBadBuffer) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrBadBuffer) String() string {
	return "bad buffer"
}
// ErrBadLocalAddress indicates a bad local address was provided.
//
// +stateify savable
type ErrBadLocalAddress struct{}

// isError marks *ErrBadLocalAddress as an implementation of Error.
func (*ErrBadLocalAddress) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrBadLocalAddress) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrBadLocalAddress) String() string {
	return "bad local address"
}
// ErrBroadcastDisabled indicates broadcast is not enabled on the endpoint.
//
// +stateify savable
type ErrBroadcastDisabled struct{}

// isError marks *ErrBroadcastDisabled as an implementation of Error.
func (*ErrBroadcastDisabled) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrBroadcastDisabled) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrBroadcastDisabled) String() string {
	return "broadcast socket option disabled"
}
// ErrClosedForReceive indicates the endpoint is closed for incoming data.
//
// +stateify savable
type ErrClosedForReceive struct{}

// isError marks *ErrClosedForReceive as an implementation of Error.
func (*ErrClosedForReceive) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrClosedForReceive) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrClosedForReceive) String() string {
	return "endpoint is closed for receive"
}
// ErrClosedForSend indicates the endpoint is closed for outgoing data.
//
// +stateify savable
type ErrClosedForSend struct{}

// isError marks *ErrClosedForSend as an implementation of Error.
func (*ErrClosedForSend) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrClosedForSend) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrClosedForSend) String() string {
	return "endpoint is closed for send"
}
// ErrConnectStarted indicates the endpoint is connecting asynchronously.
//
// +stateify savable
type ErrConnectStarted struct{}

// isError marks *ErrConnectStarted as an implementation of Error.
func (*ErrConnectStarted) isError() {}

// IgnoreStats implements Error. It always returns true.
func (*ErrConnectStarted) IgnoreStats() bool { return true }

// String returns the fixed description of the error.
func (*ErrConnectStarted) String() string {
	return "connection attempt started"
}
// ErrConnectionAborted indicates the connection was aborted.
//
// +stateify savable
type ErrConnectionAborted struct{}

// isError marks *ErrConnectionAborted as an implementation of Error.
func (*ErrConnectionAborted) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrConnectionAborted) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrConnectionAborted) String() string {
	return "connection aborted"
}
// ErrConnectionRefused indicates the connection was refused.
//
// +stateify savable
type ErrConnectionRefused struct{}

// isError marks *ErrConnectionRefused as an implementation of Error.
func (*ErrConnectionRefused) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrConnectionRefused) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrConnectionRefused) String() string {
	return "connection was refused"
}
// ErrConnectionReset indicates the connection was reset.
//
// +stateify savable
type ErrConnectionReset struct{}

// isError marks *ErrConnectionReset as an implementation of Error.
func (*ErrConnectionReset) isError() {}

// IgnoreStats implements Error. It always returns false.
func (*ErrConnectionReset) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrConnectionReset) String() string {
	return "connection reset by peer"
}
// ErrDestinationRequired is reported when an operation needs a destination
// address and none was provided.
//
// +stateify savable
type ErrDestinationRequired struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrDestinationRequired) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrDestinationRequired) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrDestinationRequired) String() string {
	const msg = "destination address is required"
	return msg
}
// ErrDuplicateAddress is reported when an operation encountered a duplicate
// address.
//
// +stateify savable
type ErrDuplicateAddress struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrDuplicateAddress) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrDuplicateAddress) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrDuplicateAddress) String() string {
	const msg = "duplicate address"
	return msg
}
// ErrDuplicateNICID is reported when an operation encountered a duplicate NIC
// ID.
//
// +stateify savable
type ErrDuplicateNICID struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrDuplicateNICID) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrDuplicateNICID) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrDuplicateNICID) String() string {
	const msg = "duplicate nic id"
	return msg
}
// ErrInvalidNICID is reported when an operation used an invalid NIC ID.
//
// +stateify savable
type ErrInvalidNICID struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrInvalidNICID) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrInvalidNICID) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrInvalidNICID) String() string {
	const msg = "invalid nic id"
	return msg
}
// ErrInvalidEndpointState is reported when the endpoint is in an invalid
// state.
//
// +stateify savable
type ErrInvalidEndpointState struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrInvalidEndpointState) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrInvalidEndpointState) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrInvalidEndpointState) String() string {
	const msg = "endpoint is in invalid state"
	return msg
}
// ErrInvalidOptionValue is reported when an invalid option value was
// provided.
//
// +stateify savable
type ErrInvalidOptionValue struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrInvalidOptionValue) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrInvalidOptionValue) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrInvalidOptionValue) String() string {
	const msg = "invalid option value specified"
	return msg
}
// ErrInvalidPortRange is reported on an attempt to set an invalid port range.
//
// +stateify savable
type ErrInvalidPortRange struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrInvalidPortRange) isError() {}

// IgnoreStats implements Error. It always reports true for this error.
func (*ErrInvalidPortRange) IgnoreStats() bool { return true }

// String returns the fixed description of the error.
func (*ErrInvalidPortRange) String() string {
	const msg = "invalid port range"
	return msg
}
// ErrMalformedHeader is reported when an operation encountered a malformed
// header.
//
// +stateify savable
type ErrMalformedHeader struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrMalformedHeader) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrMalformedHeader) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrMalformedHeader) String() string {
	const msg = "header is malformed"
	return msg
}
// ErrMessageTooLong is reported when an operation encountered a message whose
// length exceeds the maximum permitted.
//
// +stateify savable
type ErrMessageTooLong struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrMessageTooLong) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrMessageTooLong) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrMessageTooLong) String() string {
	const msg = "message too long"
	return msg
}
// ErrNetworkUnreachable is reported when an operation is not able to reach the
// destination network.
//
// +stateify savable
type ErrNetworkUnreachable struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrNetworkUnreachable) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrNetworkUnreachable) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrNetworkUnreachable) String() string {
	const msg = "network is unreachable"
	return msg
}
// ErrNoBufferSpace is reported when no buffer space is available.
//
// +stateify savable
type ErrNoBufferSpace struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrNoBufferSpace) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrNoBufferSpace) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrNoBufferSpace) String() string {
	const msg = "no buffer space available"
	return msg
}
// ErrNoPortAvailable is reported when no port could be allocated for the
// operation.
//
// +stateify savable
type ErrNoPortAvailable struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrNoPortAvailable) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrNoPortAvailable) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrNoPortAvailable) String() string {
	const msg = "no ports are available"
	return msg
}
// ErrHostUnreachable is reported when a destination host could not be
// reached.
//
// +stateify savable
type ErrHostUnreachable struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrHostUnreachable) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrHostUnreachable) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrHostUnreachable) String() string {
	const msg = "no route to host"
	return msg
}
// ErrHostDown is reported when a destination host is down.
//
// +stateify savable
type ErrHostDown struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrHostDown) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrHostDown) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrHostDown) String() string {
	const msg = "host is down"
	return msg
}
// ErrNoNet is reported when the host is not on the network.
//
// +stateify savable
type ErrNoNet struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrNoNet) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrNoNet) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrNoNet) String() string {
	const msg = "machine is not on the network"
	return msg
}
// ErrNoSuchFile is used to indicate that ENOENT should be returned to the
// calling application.
//
// +stateify savable
type ErrNoSuchFile struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrNoSuchFile) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrNoSuchFile) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrNoSuchFile) String() string {
	const msg = "no such file"
	return msg
}
// ErrNotConnected is reported when the endpoint is not connected.
//
// +stateify savable
type ErrNotConnected struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrNotConnected) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrNotConnected) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrNotConnected) String() string {
	const msg = "endpoint not connected"
	return msg
}
// ErrNotPermitted is reported when the operation is not permitted.
//
// +stateify savable
type ErrNotPermitted struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrNotPermitted) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrNotPermitted) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrNotPermitted) String() string {
	const msg = "operation not permitted"
	return msg
}
// ErrNotSupported is reported when the operation is not supported.
//
// +stateify savable
type ErrNotSupported struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrNotSupported) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrNotSupported) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrNotSupported) String() string {
	const msg = "operation not supported"
	return msg
}
// ErrPortInUse is reported when the provided port is in use.
//
// +stateify savable
type ErrPortInUse struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrPortInUse) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrPortInUse) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrPortInUse) String() string {
	const msg = "port is in use"
	return msg
}
// ErrQueueSizeNotSupported is reported when the endpoint does not allow the
// queue size operation.
//
// +stateify savable
type ErrQueueSizeNotSupported struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrQueueSizeNotSupported) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrQueueSizeNotSupported) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrQueueSizeNotSupported) String() string {
	const msg = "queue size querying not supported"
	return msg
}
// ErrTimeout is reported when the operation timed out.
//
// +stateify savable
type ErrTimeout struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrTimeout) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrTimeout) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrTimeout) String() string {
	const msg = "operation timed out"
	return msg
}
// ErrUnknownDevice is reported when an unknown device identifier was
// provided.
//
// +stateify savable
type ErrUnknownDevice struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrUnknownDevice) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrUnknownDevice) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrUnknownDevice) String() string {
	const msg = "unknown device"
	return msg
}
// ErrUnknownNICID is reported when an unknown NIC ID was provided.
//
// +stateify savable
type ErrUnknownNICID struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrUnknownNICID) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrUnknownNICID) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrUnknownNICID) String() string {
	const msg = "unknown nic id"
	return msg
}
// ErrUnknownProtocol is reported when an unknown protocol was requested.
//
// +stateify savable
type ErrUnknownProtocol struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrUnknownProtocol) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrUnknownProtocol) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrUnknownProtocol) String() string {
	const msg = "unknown protocol"
	return msg
}
// ErrUnknownProtocolOption is reported when an unknown protocol option was
// provided.
//
// +stateify savable
type ErrUnknownProtocolOption struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrUnknownProtocolOption) isError() {}

// IgnoreStats implements Error. It always reports false for this error.
func (*ErrUnknownProtocolOption) IgnoreStats() bool { return false }

// String returns the fixed description of the error.
func (*ErrUnknownProtocolOption) String() string {
	const msg = "unknown option for protocol"
	return msg
}
// ErrWouldBlock is reported when the operation would block.
//
// +stateify savable
type ErrWouldBlock struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrWouldBlock) isError() {}

// IgnoreStats implements Error. It always reports true for this error.
func (*ErrWouldBlock) IgnoreStats() bool { return true }

// String returns the fixed description of the error.
func (*ErrWouldBlock) String() string {
	const msg = "operation would block"
	return msg
}
// ErrMissingRequiredFields is reported when a required field is missing.
//
// +stateify savable
type ErrMissingRequiredFields struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrMissingRequiredFields) isError() {}

// IgnoreStats implements Error. It always reports true for this error.
func (*ErrMissingRequiredFields) IgnoreStats() bool { return true }

// String returns the fixed description of the error.
func (*ErrMissingRequiredFields) String() string {
	const msg = "missing required fields"
	return msg
}
// ErrMulticastInputCannotBeOutput is reported when an input interface matches
// an output interface in the same multicast route.
//
// +stateify savable
type ErrMulticastInputCannotBeOutput struct{}

// isError is an empty marker method; it has no behavior of its own.
func (*ErrMulticastInputCannotBeOutput) isError() {}

// IgnoreStats implements Error. It always reports true for this error.
func (*ErrMulticastInputCannotBeOutput) IgnoreStats() bool { return true }

// String returns the fixed description of the error.
func (*ErrMulticastInputCannotBeOutput) String() string {
	const msg = "output cannot contain input"
	return msg
}
// LINT.ThenChange(../syserr/netstack.go)

View File

@@ -0,0 +1,74 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build linux
// +build linux
package tcpip
import (
"golang.org/x/sys/unix"
)
// TranslateErrno translates an errno from the syscall package into a
// tcpip Error.
//
// Valid, but unrecognized errnos will be translated to
// *ErrInvalidEndpointState (EINVAL). This includes the "zero" value.
func TranslateErrno(e unix.Errno) Error {
	switch e {
	case unix.EEXIST:
		return &ErrDuplicateAddress{}
	case unix.ENETUNREACH:
		// "Network is unreachable" has its own error type; it must not be
		// folded into the host-unreachable ("no route to host") error.
		return &ErrNetworkUnreachable{}
	case unix.EHOSTUNREACH:
		return &ErrHostUnreachable{}
	case unix.EINVAL:
		return &ErrInvalidEndpointState{}
	case unix.EALREADY:
		return &ErrAlreadyConnecting{}
	case unix.EISCONN:
		return &ErrAlreadyConnected{}
	case unix.EADDRINUSE:
		return &ErrPortInUse{}
	case unix.EADDRNOTAVAIL:
		return &ErrBadLocalAddress{}
	case unix.EPIPE:
		return &ErrClosedForSend{}
	case unix.EWOULDBLOCK:
		return &ErrWouldBlock{}
	case unix.ECONNREFUSED:
		return &ErrConnectionRefused{}
	case unix.ETIMEDOUT:
		return &ErrTimeout{}
	case unix.EINPROGRESS:
		return &ErrConnectStarted{}
	case unix.EDESTADDRREQ:
		return &ErrDestinationRequired{}
	case unix.ENOTSUP:
		return &ErrNotSupported{}
	case unix.ENOTTY:
		return &ErrQueueSizeNotSupported{}
	case unix.ENOTCONN:
		return &ErrNotConnected{}
	case unix.ECONNRESET:
		return &ErrConnectionReset{}
	case unix.ECONNABORTED:
		return &ErrConnectionAborted{}
	case unix.EMSGSIZE:
		return &ErrMessageTooLong{}
	case unix.ENOBUFS:
		return &ErrNoBufferSpace{}
	default:
		// Anything not listed above is reported as EINVAL's counterpart.
		return &ErrInvalidEndpointState{}
	}
}

View File

@@ -0,0 +1,79 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package jenkins implements Jenkins's one_at_a_time, non-cryptographic hash
// functions created by Bob Jenkins.
//
// See https://en.wikipedia.org/wiki/Jenkins_hash_function#cite_note-dobbsx-1
package jenkins
import (
"hash"
)
// Sum32 holds the running state of Jenkins's one_at_a_time hash.
//
// Using the Sum32 type directly (rather than going through New32) avoids
// allocations.
type Sum32 uint32

// New32 returns a new 32-bit Jenkins's one_at_a_time hash.Hash.
//
// Its Sum method will lay the value out in big-endian byte order.
func New32() hash.Hash32 {
	return new(Sum32)
}

// Reset puts the hash back into its initial (zero) state.
func (s *Sum32) Reset() {
	*s = 0
}

// Sum32 returns the hash value.
//
// The finalization steps are applied to a copy, so the running state is left
// untouched and more data may still be written afterwards.
func (s *Sum32) Sum32() uint32 {
	h := uint32(*s)
	h += h << 3
	h ^= h >> 11
	h += h << 15
	return h
}

// Write mixes data into the running hash, one byte at a time.
//
// It always consumes all of data and never returns an error.
func (s *Sum32) Write(data []byte) (int, error) {
	h := uint32(*s)
	for i := range data {
		h += uint32(data[i])
		h += h << 10
		h ^= h >> 6
	}
	*s = Sum32(h)
	return len(data), nil
}

// Size returns the number of bytes Sum will return.
func (s *Sum32) Size() int {
	return 4
}

// BlockSize returns the hash's underlying block size.
func (s *Sum32) BlockSize() int {
	return 1
}

// Sum appends the current hash, in big-endian byte order, to in and returns
// the resulting slice.
//
// It does not change the underlying hash state.
func (s *Sum32) Sum(in []byte) []byte {
	v := s.Sum32()
	be := [4]byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
	return append(in, be[:]...)
}

View File

@@ -0,0 +1,3 @@
// automatically generated by stateify.
package jenkins

View File

@@ -0,0 +1,127 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// ARPProtocolNumber is the ARP network protocol number (0x0806, the
// EtherType assigned to ARP).
ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
// ARPSize is the size, in bytes, of an IPv4-over-Ethernet ARP packet: an
// 8-byte fixed part followed by the sender and target hardware and protocol
// addresses (28 assumes 6-byte Ethernet and 4-byte IPv4 addresses).
ARPSize = 28
)
// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header.
type ARPHardwareType uint16
// Typical ARP HardwareType values. Some of the constants have to be specific
// values as they are egressed on the wire in the HTYPE field of an ARP header.
const (
// ARPHardwareNone is the zero value of ARPHardwareType.
ARPHardwareNone ARPHardwareType = 0
// ARPHardwareEther specifically is the HTYPE for Ethernet as specified
// in the IANA list here:
//
// https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2
ARPHardwareEther ARPHardwareType = 1
// ARPHardwareLoopback identifies a loopback link.
// NOTE(review): presumably only used internally and never written to the
// HTYPE field — confirm against callers.
ARPHardwareLoopback ARPHardwareType = 2
)
// ARPOp is an ARP opcode.
type ARPOp uint16
// Typical ARP opcodes defined in RFC 826.
const (
// ARPRequest is the opcode of an ARP request.
ARPRequest ARPOp = 1
// ARPReply is the opcode of an ARP reply.
ARPReply ARPOp = 2
)
// ARP is an ARP packet stored in a byte array as described in RFC 826.
//
// The accessors index the underlying slice directly and perform no length
// checks of their own; IsValid reports whether the slice is long enough to
// hold an IPv4-over-Ethernet packet.
type ARP []byte

// Byte offsets of the fields of an IPv4-over-Ethernet ARP packet. The address
// offsets are chained, so each field starts where the previous one ends.
const (
	hTypeOffset                 = 0
	protocolOffset              = 2
	haAddressSizeOffset         = 4
	protoAddressSizeOffset      = 5
	opCodeOffset                = 6
	senderHAAddressOffset       = 8
	senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize
	targetHAAddressOffset       = senderProtocolAddressOffset + IPv4AddressSize
	targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize
)

// hardwareAddressType returns the big-endian hardware type field.
func (a ARP) hardwareAddressType() ARPHardwareType {
	raw := binary.BigEndian.Uint16(a[hTypeOffset:])
	return ARPHardwareType(raw)
}

// protocolAddressSpace returns the big-endian protocol type field.
func (a ARP) protocolAddressSpace() uint16 {
	return binary.BigEndian.Uint16(a[protocolOffset:])
}

// hardwareAddressSize returns the hardware address length field.
func (a ARP) hardwareAddressSize() int {
	return int(a[haAddressSizeOffset])
}

// protocolAddressSize returns the protocol address length field.
func (a ARP) protocolAddressSize() int {
	return int(a[protoAddressSizeOffset])
}

// Op is the ARP opcode.
func (a ARP) Op() ARPOp {
	raw := binary.BigEndian.Uint16(a[opCodeOffset:])
	return ARPOp(raw)
}

// SetOp sets the ARP opcode.
func (a ARP) SetOp(op ARPOp) {
	field := a[opCodeOffset:]
	binary.BigEndian.PutUint16(field, uint16(op))
}

// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet by
// filling in the hardware type, protocol type and both address-size fields.
func (a ARP) SetIPv4OverEthernet() {
	binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther))
	binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber))
	a[haAddressSizeOffset] = EthernetAddressSize
	a[protoAddressSizeOffset] = uint8(IPv4AddressSize)
}

// HardwareAddressSender is the link address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressSender() []byte {
	// The field ends exactly where the sender protocol address begins.
	return a[senderHAAddressOffset:senderProtocolAddressOffset]
}

// ProtocolAddressSender is the protocol address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressSender() []byte {
	// The field ends exactly where the target hardware address begins.
	return a[senderProtocolAddressOffset:targetHAAddressOffset]
}

// HardwareAddressTarget is the link address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressTarget() []byte {
	// The field ends exactly where the target protocol address begins.
	return a[targetHAAddressOffset:targetProtocolAddressOffset]
}

// ProtocolAddressTarget is the protocol address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressTarget() []byte {
	const end = targetProtocolAddressOffset + IPv4AddressSize
	return a[targetProtocolAddressOffset:end]
}

// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
//
// The length is checked first so that the field accessors below never index
// past the end of the slice.
func (a ARP) IsValid() bool {
	switch {
	case len(a) < ARPSize:
		return false
	case a.hardwareAddressType() != ARPHardwareEther:
		return false
	case a.protocolAddressSpace() != uint16(IPv4ProtocolNumber):
		return false
	case a.hardwareAddressSize() != EthernetAddressSize:
		return false
	default:
		return a.protocolAddressSize() == IPv4AddressSize
	}
}

View File

@@ -0,0 +1,107 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package header provides the implementation of the encoding and decoding of
// network protocol headers.
package header
import (
"encoding/binary"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// PseudoHeaderChecksum calculates the pseudo-header checksum for the given
// destination protocol and network address. Pseudo-headers are needed by
// transport layers when calculating their own checksum.
//
// The checksum is accumulated over, in order: the source address, the
// destination address, totalLen as a big-endian 16-bit value, and a zero byte
// followed by the protocol number.
func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 {
	sum := checksum.Checksum(srcAddr.AsSlice(), 0)
	sum = checksum.Checksum(dstAddr.AsSlice(), sum)

	// Length portion of the pseudo-header.
	var lenField [2]byte
	binary.BigEndian.PutUint16(lenField[:], totalLen)
	sum = checksum.Checksum(lenField[:], sum)

	// Protocol portion: a zero pad byte followed by the protocol number.
	protoField := []byte{0, uint8(protocol)}
	return checksum.Checksum(protoField, sum)
}
// checksumUpdate2ByteAlignedUint16 replaces the contribution of one uint16
// value in an already calculated checksum and returns the new checksum.
//
// The value MUST begin at a 2-byte boundary in the original buffer.
func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
	// Incremental update as described in RFC 1071 page 4: given the original
	// value m, the new value m' and the old checksum C, the new checksum is
	//
	//   C' = C + (-m) + m' = C + (m' - m)
	//
	// ^old plays the role of -m.
	if old != new {
		delta := checksum.Combine(new, ^old)
		xsum = checksum.Combine(xsum, delta)
	}
	return xsum
}
// checksumUpdate2ByteAlignedAddress updates an address in a calculated
// checksum.
//
// The addresses must have the same length and must contain an even number
// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
//
// It panics if the two addresses differ in length or if their length is not a
// multiple of two bytes.
func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
	const uint16Bytes = 2

	if old.BitLen() != new.BitLen() {
		panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", old.BitLen()/8, new.BitLen()/8))
	}

	if old.BitLen()%16 != 0 {
		// Report the actual byte count; the remainder of the bit length
		// modulo 16 carries no useful information for the reader.
		panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", old.BitLen()/8))
	}

	oldAddr := old.AsSlice()
	newAddr := new.AsSlice()

	// As per RFC 1071 page 4,
	//	(4)  Incremental Update
	//
	//        ...
	//
	//        To update the checksum, simply add the differences of the
	//        sixteen bit integers that have been changed.  To see why this
	//        works, observe that every 16-bit integer has an additive inverse
	//        and that addition is associative.  From this it follows that
	//        given the original value m, the new value m', and the old
	//        checksum C, the new checksum C' is:
	//
	//                C' = C + (-m) + m' = C + (m' - m)
	for len(oldAddr) != 0 {
		// Read the next big-endian 16-bit word of each address and apply
		// the incremental update for that word.
		xsum = checksumUpdate2ByteAlignedUint16(xsum, binary.BigEndian.Uint16(oldAddr), binary.BigEndian.Uint16(newAddr))
		oldAddr = oldAddr[uint16Bytes:]
		newAddr = newAddr[uint16Bytes:]
	}

	return xsum
}

View File

@@ -0,0 +1,18 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
// DatagramMaximumSize is the maximum supported size of a single datagram.
const DatagramMaximumSize = 0xffff // 65535, i.e. 64 KiB - 1.

View File

@@ -0,0 +1,192 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
)
// Byte offsets of the fields within an ethernet frame header.
const (
// dstMAC is the offset of the "MAC destination" field.
dstMAC = 0
// srcMAC is the offset of the "MAC source" field.
srcMAC = 6
// ethType is the offset of the "ethertype" field.
ethType = 12
)
// EthernetFields contains the fields of an ethernet frame header. It is used to
// describe the fields of a frame that needs to be encoded; Ethernet.Encode
// writes these values into the header.
type EthernetFields struct {
// SrcAddr is the "MAC source" field of an ethernet frame header.
SrcAddr tcpip.LinkAddress
// DstAddr is the "MAC destination" field of an ethernet frame header.
DstAddr tcpip.LinkAddress
// Type is the "ethertype" field of an ethernet frame header.
Type tcpip.NetworkProtocolNumber
}
// Ethernet represents an ethernet frame header stored in a byte array.
//
// The accessors slice the underlying bytes directly and do not check the
// length themselves, so they panic on a slice that is too short.
type Ethernet []byte
const (
// EthernetMinimumSize is the minimum size, in bytes, of a valid ethernet
// frame header: destination MAC, source MAC and ethertype.
EthernetMinimumSize = 14
// EthernetMaximumSize is the maximum size, in bytes, of a valid ethernet
// frame header.
// NOTE(review): presumably the minimum header plus a 4-byte 802.1Q VLAN
// tag — confirm.
EthernetMaximumSize = 18
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
// UnspecifiedEthernetAddress is the unspecified ethernet address
// (all bits set to 0).
UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
// EthernetBroadcastAddress is an ethernet address that addresses every node
// on a local link.
EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff")
// unicastMulticastFlagMask is the mask of the least significant bit in
// the first octet (in network byte order) of an ethernet address that
// determines whether the ethernet address is a unicast or multicast. If
// the masked bit is a 1, then the address is a multicast, unicast
// otherwise.
//
// See the IEEE Std 802-2001 document for more details. Specifically,
// section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf:
// "A 48-bit universal address consists of two parts. The first 24 bits
// correspond to the OUI as assigned by the IEEE, except that the
// assignee may set the LSB of the first octet to 1 for group addresses
// or set it to 0 for individual addresses."
unicastMulticastFlagMask = 1
// unicastMulticastFlagByteIdx is the byte that holds the
// unicast/multicast flag. See unicastMulticastFlagMask.
unicastMulticastFlagByteIdx = 0
)
const (
// EthernetProtocolAll is a catch-all for all protocols carried inside
// an ethernet frame. It is mainly used to create packet sockets that
// capture all traffic.
// NOTE(review): the value looks like Linux's ETH_P_ALL — confirm.
EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003
// EthernetProtocolPUP is the PARC Universal Packet protocol ethertype.
// NOTE(review): the value looks like Linux's ETH_P_PUP — confirm.
EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200
)
// Ethertypes holds the protocol numbers describing the payload of an ethernet
// frame. These types aren't necessarily supported by netstack, but can be used
// to catch all traffic of a type via packet endpoints.
//
// It is a package-level slice variable, so callers share one backing array.
var Ethertypes = []tcpip.NetworkProtocolNumber{
EthernetProtocolAll,
EthernetProtocolPUP,
}
// SourceAddress returns the "MAC source" field of the ethernet frame header.
//
// The bytes are converted to a tcpip.LinkAddress, so the result does not
// alias the header.
func (b Ethernet) SourceAddress() tcpip.LinkAddress {
	mac := b[srcMAC:][:EthernetAddressSize]
	return tcpip.LinkAddress(mac)
}
// DestinationAddress returns the "MAC destination" field of the ethernet frame
// header.
//
// The bytes are converted to a tcpip.LinkAddress, so the result does not
// alias the header.
func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
	mac := b[dstMAC:][:EthernetAddressSize]
	return tcpip.LinkAddress(mac)
}
// Type returns the "ethertype" field of the ethernet frame header, which is
// stored in big-endian byte order.
func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
	raw := binary.BigEndian.Uint16(b[ethType:])
	return tcpip.NetworkProtocolNumber(raw)
}
// Encode encodes all the fields of the ethernet frame header.
//
// The ethertype is written first (big-endian), followed by the source and
// then the destination address.
func (b Ethernet) Encode(e *EthernetFields) {
	typeField := b[ethType:]
	binary.BigEndian.PutUint16(typeField, uint16(e.Type))

	srcField := b[srcMAC:][:EthernetAddressSize]
	copy(srcField, e.SrcAddr)

	dstField := b[dstMAC:][:EthernetAddressSize]
	copy(dstField, e.DstAddr)
}
// IsMulticastEthernetAddress returns true if the address is a multicast
// ethernet address.
//
// An address of the wrong length is never considered multicast.
func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool {
	return len(addr) == EthernetAddressSize &&
		addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0
}
// IsValidUnicastEthernetAddress returns true if the address is a unicast
// ethernet address.
//
// Addresses of the wrong length, the unspecified (all-zero) address and
// addresses with the multicast bit set are all rejected.
func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
	switch {
	case len(addr) != EthernetAddressSize:
		return false
	case addr == UnspecifiedEthernetAddress:
		return false
	default:
		return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask == 0
	}
}
// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address
// for a multicast IPv4 address.
//
// addr MUST be a multicast IPv4 address.
func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress {
	// RFC 1112 Host Extensions for IP Multicasting
	//
	// 6.4. Extensions to an Ethernet Local Network Module:
	//
	//   An IP host group address is mapped to an Ethernet multicast
	//   address by placing the low-order 23-bits of the IP address
	//   into the low-order 23 bits of the Ethernet multicast address
	//   01-00-5E-00-00-00 (hex).
	ip := addr.As4()
	mac := [EthernetAddressSize]byte{
		0: 0x1,
		2: 0x5e,
		// Only the low 7 bits of the second IP octet are part of the
		// 23-bit group; the remaining 16 bits are the last two octets.
		3: ip[1] & 0x7F,
		4: ip[IPv4AddressSize-2],
		5: ip[IPv4AddressSize-1],
	}
	return tcpip.LinkAddress(mac[:])
}
// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address
// for a multicast IPv6 address.
//
// addr MUST be a multicast IPv6 address.
func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress {
	// RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
	//
	// 7. Address Mapping -- Multicast
	//
	//   An IPv6 packet with a multicast destination address DST,
	//   consisting of the sixteen octets DST[1] through DST[16], is
	//   transmitted to the Ethernet multicast address whose first
	//   two octets are the value 3333 hexadecimal and whose last
	//   four octets are the last four octets of DST.
	ip := addr.As16()

	// Take the trailing EthernetAddressSize bytes of the local copy and
	// overwrite the first two with 0x33; the last four are left as they are.
	mac := ip[IPv6AddressSize-EthernetAddressSize:]
	mac[0], mac[1] = 0x33, 0x33
	return tcpip.LinkAddress(mac)
}

View File

@@ -0,0 +1,73 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
// Byte offsets of the fields within a GUE header.
const (
	// typeHLen is the byte holding the type, control bit and header length.
	typeHLen = 0
	// encapProto is the byte holding the encapsulated protocol number.
	encapProto = 1
)

// GUEFields contains the fields of a GUE packet. It is used to describe the
// fields of a packet that needs to be encoded.
type GUEFields struct {
	// Type is the "type" field of the GUE header.
	Type uint8

	// Control is the "control" field of the GUE header.
	Control bool

	// HeaderLength is the "header length" field of the GUE header. It must
	// be at least 4 octets, and a multiple of 4 as well.
	HeaderLength uint8

	// Protocol is the "protocol" field of the GUE header. This is one of
	// the IPPROTO_* values.
	Protocol uint8
}

// GUE represents a Generic UDP Encapsulation header stored in a byte array, the
// fields are described in https://tools.ietf.org/html/draft-ietf-nvo3-gue-01.
type GUE []byte

const (
	// GUEMinimumSize is the minimum size of a valid GUE packet.
	GUEMinimumSize = 4
)

// TypeAndControl returns the GUE packet type (top 3 bits of the first byte,
// which includes the control bit).
func (b GUE) TypeAndControl() uint8 {
	first := b[typeHLen]
	return first >> 5
}

// HeaderLength returns the total length of the GUE header: 4 octets plus 4
// octets for every unit counted in the low 5 bits of the first byte.
func (b GUE) HeaderLength() uint8 {
	words := b[typeHLen] & 0x1f
	return 4 + 4*words
}

// Protocol returns the protocol field of the GUE header.
func (b GUE) Protocol() uint8 {
	return b[encapProto]
}

// Encode encodes all the fields of the GUE header.
//
// The first byte packs the type (shifted left by 6), the control bit (bit 5)
// and the header length expressed as (HeaderLength-4)/4. HeaderLength is not
// validated here.
func (b GUE) Encode(i *GUEFields) {
	first := i.Type<<6 | (i.HeaderLength-4)/4
	if i.Control {
		first |= 1 << 5
	}
	b[typeHLen] = first
	b[encapProto] = i.Protocol
}

View File

@@ -0,0 +1,120 @@
// automatically generated by stateify.
package header
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
// StateTypeName returns the fully qualified type name string for
// TCPSynOptions.
func (t *TCPSynOptions) StateTypeName() string {
	return "pkg/tcpip/header.TCPSynOptions"
}

// StateFields returns the names of the saved fields. The position of each name
// matches the index passed to Save and Load in StateSave and StateLoad.
func (t *TCPSynOptions) StateFields() []string {
	return []string{
		"MSS",
		"WS",
		"TS",
		"TSVal",
		"TSEcr",
		"SACKPermitted",
		"Flags",
	}
}

// beforeSave is a no-op hook invoked at the start of StateSave.
func (t *TCPSynOptions) beforeSave() {}

// StateSave writes every field of t to stateSinkObject, indexed by the field's
// position in StateFields.
//
// +checklocksignore
func (t *TCPSynOptions) StateSave(stateSinkObject state.Sink) {
	t.beforeSave()
	stateSinkObject.Save(0, &t.MSS)
	stateSinkObject.Save(1, &t.WS)
	stateSinkObject.Save(2, &t.TS)
	stateSinkObject.Save(3, &t.TSVal)
	stateSinkObject.Save(4, &t.TSEcr)
	stateSinkObject.Save(5, &t.SACKPermitted)
	stateSinkObject.Save(6, &t.Flags)
}

// afterLoad is a no-op hook; it ignores its context argument.
func (t *TCPSynOptions) afterLoad(context.Context) {}

// StateLoad reads every field of t from stateSourceObject, using the same
// indices as StateSave. ctx is not used in this body.
//
// +checklocksignore
func (t *TCPSynOptions) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &t.MSS)
	stateSourceObject.Load(1, &t.WS)
	stateSourceObject.Load(2, &t.TS)
	stateSourceObject.Load(3, &t.TSVal)
	stateSourceObject.Load(4, &t.TSEcr)
	stateSourceObject.Load(5, &t.SACKPermitted)
	stateSourceObject.Load(6, &t.Flags)
}
// StateTypeName returns the fully qualified type name string for SACKBlock.
func (r *SACKBlock) StateTypeName() string {
	return "pkg/tcpip/header.SACKBlock"
}

// StateFields returns the names of the saved fields. The position of each name
// matches the index passed to Save and Load in StateSave and StateLoad.
func (r *SACKBlock) StateFields() []string {
	return []string{
		"Start",
		"End",
	}
}

// beforeSave is a no-op hook invoked at the start of StateSave.
func (r *SACKBlock) beforeSave() {}

// StateSave writes Start and End to stateSinkObject, indexed by their position
// in StateFields.
//
// +checklocksignore
func (r *SACKBlock) StateSave(stateSinkObject state.Sink) {
	r.beforeSave()
	stateSinkObject.Save(0, &r.Start)
	stateSinkObject.Save(1, &r.End)
}

// afterLoad is a no-op hook; it ignores its context argument.
func (r *SACKBlock) afterLoad(context.Context) {}

// StateLoad reads Start and End from stateSourceObject, using the same indices
// as StateSave. ctx is not used in this body.
//
// +checklocksignore
func (r *SACKBlock) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &r.Start)
	stateSourceObject.Load(1, &r.End)
}
// StateTypeName returns the fully qualified type name string for TCPOptions.
func (t *TCPOptions) StateTypeName() string {
	return "pkg/tcpip/header.TCPOptions"
}

// StateFields returns the names of the saved fields. The position of each name
// matches the index passed to Save and Load in StateSave and StateLoad.
func (t *TCPOptions) StateFields() []string {
	return []string{
		"TS",
		"TSVal",
		"TSEcr",
		"SACKBlocks",
	}
}

// beforeSave is a no-op hook invoked at the start of StateSave.
func (t *TCPOptions) beforeSave() {}

// StateSave writes every field of t to stateSinkObject, indexed by the field's
// position in StateFields.
//
// +checklocksignore
func (t *TCPOptions) StateSave(stateSinkObject state.Sink) {
	t.beforeSave()
	stateSinkObject.Save(0, &t.TS)
	stateSinkObject.Save(1, &t.TSVal)
	stateSinkObject.Save(2, &t.TSEcr)
	stateSinkObject.Save(3, &t.SACKBlocks)
}

// afterLoad is a no-op hook; it ignores its context argument.
func (t *TCPOptions) afterLoad(context.Context) {}

// StateLoad reads every field of t from stateSourceObject, using the same
// indices as StateSave. ctx is not used in this body.
//
// +checklocksignore
func (t *TCPOptions) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &t.TS)
	stateSourceObject.Load(1, &t.TSVal)
	stateSourceObject.Load(2, &t.TSEcr)
	stateSourceObject.Load(3, &t.SACKBlocks)
}
// init registers TCPSynOptions, SACKBlock and TCPOptions with the state
// package. Each is passed as a typed nil pointer.
func init() {
	state.Register((*TCPSynOptions)(nil))
	state.Register((*SACKBlock)(nil))
	state.Register((*TCPOptions)(nil))
}

View File

@@ -0,0 +1,228 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// ICMPv4 represents an ICMPv4 header stored in a byte array.
//
// The accessor methods index the slice directly and do not check its length.
type ICMPv4 []byte

const (
	// ICMPv4PayloadOffset defines the start of ICMP payload.
	ICMPv4PayloadOffset = 8

	// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
	ICMPv4MinimumSize = 8

	// ICMPv4MinimumErrorPayloadSize is the smallest number of bytes of an
	// errant packet's transport layer that an ICMP error type packet should
	// attempt to send as per RFC 792 (see each type) and RFC 1122
	// section 3.2.2 which states:
	//
	//	Every ICMP error message includes the Internet header and at
	//	least the first 8 data octets of the datagram that triggered
	//	the error; more than 8 octets MAY be sent; this header and data
	//	MUST be unchanged from the received datagram.
	//
	// RFC 792 shows:
	//
	//	 0                   1                   2                   3
	//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
	//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	//	|     Type      |     Code      |          Checksum             |
	//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	//	|                             unused                            |
	//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	//	|      Internet Header + 64 bits of Original Data Datagram      |
	//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	ICMPv4MinimumErrorPayloadSize = 8

	// ICMPv4ProtocolNumber is the ICMP transport protocol number.
	ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1

	// icmpv4ChecksumOffset is the offset of the checksum field
	// in an ICMPv4 message.
	icmpv4ChecksumOffset = 2

	// icmpv4MTUOffset is the offset of the MTU field
	// in an ICMPv4FragmentationNeeded message.
	icmpv4MTUOffset = 6

	// icmpv4IdentOffset is the offset of the ident field
	// in an ICMPv4EchoRequest/Reply message.
	icmpv4IdentOffset = 4

	// icmpv4PointerOffset is the offset of the pointer field
	// in an ICMPv4ParamProblem message.
	icmpv4PointerOffset = 4

	// icmpv4SequenceOffset is the offset of the sequence field
	// in an ICMPv4EchoRequest/Reply message.
	icmpv4SequenceOffset = 6
)
// ICMPv4Type is the ICMP type field described in RFC 792. It is the first
// byte of an ICMPv4 header.
type ICMPv4Type byte

// ICMPv4Code is the ICMP code field described in RFC 792. It is the second
// byte of an ICMPv4 header and its meaning depends on the ICMPv4Type.
type ICMPv4Code byte

// Typical values of ICMPv4Type defined in RFC 792.
const (
	ICMPv4EchoReply      ICMPv4Type = 0
	ICMPv4DstUnreachable ICMPv4Type = 3
	ICMPv4SrcQuench      ICMPv4Type = 4
	ICMPv4Redirect       ICMPv4Type = 5
	ICMPv4Echo           ICMPv4Type = 8
	ICMPv4TimeExceeded   ICMPv4Type = 11
	ICMPv4ParamProblem   ICMPv4Type = 12
	ICMPv4Timestamp      ICMPv4Type = 13
	ICMPv4TimestampReply ICMPv4Type = 14
	ICMPv4InfoRequest    ICMPv4Type = 15
	ICMPv4InfoReply      ICMPv4Type = 16
)

// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
const (
	ICMPv4TTLExceeded       ICMPv4Code = 0
	ICMPv4ReassemblyTimeout ICMPv4Code = 1
)

// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792,
// RFC 1122 section 3.2.2.1 and RFC 1812 section 5.2.7.1.
const (
	ICMPv4NetUnreachable            ICMPv4Code = 0
	ICMPv4HostUnreachable           ICMPv4Code = 1
	ICMPv4ProtoUnreachable          ICMPv4Code = 2
	ICMPv4PortUnreachable           ICMPv4Code = 3
	ICMPv4FragmentationNeeded       ICMPv4Code = 4
	ICMPv4SourceRouteFailed         ICMPv4Code = 5
	ICMPv4DestinationNetworkUnknown ICMPv4Code = 6
	ICMPv4DestinationHostUnknown    ICMPv4Code = 7
	ICMPv4SourceHostIsolated        ICMPv4Code = 8
	ICMPv4NetProhibited             ICMPv4Code = 9
	ICMPv4HostProhibited            ICMPv4Code = 10
	ICMPv4NetUnreachableForTos      ICMPv4Code = 11
	ICMPv4HostUnreachableForTos     ICMPv4Code = 12
	ICMPv4AdminProhibited           ICMPv4Code = 13
	ICMPv4HostPrecedenceViolation   ICMPv4Code = 14
	ICMPv4PrecedenceCutInEffect     ICMPv4Code = 15
)

// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed.
const ICMPv4UnusedCode ICMPv4Code = 0
// Type returns the ICMP type, which is held in the first byte of the header.
func (b ICMPv4) Type() ICMPv4Type {
	raw := b[0]
	return ICMPv4Type(raw)
}
// SetType stores t in the first byte of the header.
func (b ICMPv4) SetType(t ICMPv4Type) {
	raw := byte(t)
	b[0] = raw
}
// Code returns the ICMP code, which is held in the second byte of the header.
// Its meaning depends on the value of Type.
func (b ICMPv4) Code() ICMPv4Code {
	raw := b[1]
	return ICMPv4Code(raw)
}
// SetCode stores c in the second byte of the header.
func (b ICMPv4) SetCode(c ICMPv4Code) {
	raw := byte(c)
	b[1] = raw
}
// Pointer reads the one-byte pointer field of a Parameter Problem packet.
func (b ICMPv4) Pointer() byte {
	ptr := b[icmpv4PointerOffset]
	return ptr
}
// SetPointer writes c into the one-byte pointer field of a Parameter Problem
// packet.
func (b ICMPv4) SetPointer(c byte) {
	b[icmpv4PointerOffset] = c
}
// Checksum reads the 16-bit, big-endian ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
	field := b[icmpv4ChecksumOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetChecksum stores cs in the ICMP checksum field via checksum.Put.
func (b ICMPv4) SetChecksum(cs uint16) {
	field := b[icmpv4ChecksumOffset:]
	checksum.Put(field, cs)
}
// SourcePort implements Transport.SourcePort.
//
// It always returns 0; the header bytes are not consulted.
func (ICMPv4) SourcePort() uint16 {
	return 0
}
// DestinationPort implements Transport.DestinationPort.
//
// It always returns 0; the header bytes are not consulted.
func (ICMPv4) DestinationPort() uint16 {
	return 0
}
// SetSourcePort implements Transport.SetSourcePort.
//
// It is a no-op: the argument is discarded and the header is left untouched.
func (ICMPv4) SetSourcePort(uint16) {
}
// SetDestinationPort implements Transport.SetDestinationPort.
//
// It is a no-op: the argument is discarded and the header is left untouched.
func (ICMPv4) SetDestinationPort(uint16) {
}
// Payload implements Transport.Payload. It returns the bytes that follow the
// first ICMPv4PayloadOffset bytes; the result shares storage with b.
func (b ICMPv4) Payload() []byte {
	rest := b[ICMPv4PayloadOffset:]
	return rest
}
// MTU reads the 16-bit, big-endian MTU field of an ICMPv4 message.
func (b ICMPv4) MTU() uint16 {
	field := b[icmpv4MTUOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetMTU writes mtu, big-endian, into the MTU field of an ICMPv4 message.
func (b ICMPv4) SetMTU(mtu uint16) {
	field := b[icmpv4MTUOffset:]
	binary.BigEndian.PutUint16(field, mtu)
}
// Ident reads the 16-bit, big-endian Ident field of an ICMPv4 message.
func (b ICMPv4) Ident() uint16 {
	field := b[icmpv4IdentOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetIdent writes ident, big-endian, into the Ident field of an ICMPv4
// message. The checksum is not adjusted.
func (b ICMPv4) SetIdent(ident uint16) {
	field := b[icmpv4IdentOffset:]
	binary.BigEndian.PutUint16(field, ident)
}
// SetIdentWithChecksumUpdate replaces the Ident field with new and then
// adjusts the stored checksum from the old and new Ident values rather than
// recomputing it over the whole message.
func (b ICMPv4) SetIdentWithChecksumUpdate(new uint16) {
	previous := b.Ident()
	b.SetIdent(new)

	// The update helper works on the complement of the stored checksum; the
	// result is complemented again before it is written back.
	sum := ^b.Checksum()
	sum = checksumUpdate2ByteAlignedUint16(sum, previous, new)
	b.SetChecksum(^sum)
}
// Sequence reads the 16-bit, big-endian Sequence field of an ICMPv4 message.
func (b ICMPv4) Sequence() uint16 {
	field := b[icmpv4SequenceOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetSequence writes sequence, big-endian, into the Sequence field of an
// ICMPv4 message. The checksum is not adjusted.
func (b ICMPv4) SetSequence(sequence uint16) {
	field := b[icmpv4SequenceOffset:]
	binary.BigEndian.PutUint16(field, sequence)
}
// ICMPv4Checksum computes the ICMP checksum for header h, starting from
// payloadCsum as the running sum for the payload.
//
// Bytes 2 and 3 of h hold the checksum field itself and are left out of the
// sum. The returned value is the complement of the accumulated sum.
func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 {
	// The checksum field occupies h[csumStart:csumEnd].
	const csumStart, csumEnd = 2, 4

	sum := checksum.Checksum(h[:csumStart], payloadCsum)
	return ^checksum.Checksum(h[csumEnd:], sum)
}

View File

@@ -0,0 +1,304 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// ICMPv6 represents an ICMPv6 header stored in a byte array.
//
// The accessor methods index the slice directly and do not check its length.
type ICMPv6 []byte

const (
	// ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the
	// sum of the size of the ICMPv6 Type, Code and Checksum fields, as
	// per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6
	// message body begins.
	ICMPv6HeaderSize = 4

	// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
	ICMPv6MinimumSize = 8

	// ICMPv6PayloadOffset is the offset of the payload in an
	// ICMP packet.
	ICMPv6PayloadOffset = 8

	// ICMPv6ProtocolNumber is the ICMP transport protocol number.
	ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58

	// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
	// neighbor solicitation packet.
	ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize

	// ICMPv6NeighborAdvertMinimumSize is the minimum size of a
	// neighbor advertisement packet.
	ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize

	// ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
	ICMPv6EchoMinimumSize = 8

	// ICMPv6ErrorHeaderSize is the size of an ICMP error packet header,
	// as per RFC 4443, Appendix A, item 4 and the errata.
	//
	//	... all ICMP error messages shall have exactly
	//	32 bits of type-specific data, so that receivers can reliably find
	//	the embedded invoking packet even when they don't recognize the
	//	ICMP message Type.
	ICMPv6ErrorHeaderSize = 8

	// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
	// destination unreachable packet.
	ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize

	// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
	// packet-too-big packet.
	ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize

	// ICMPv6ChecksumOffset is the offset of the checksum field
	// in an ICMPv6 message.
	ICMPv6ChecksumOffset = 2

	// icmpv6PointerOffset is the offset of the pointer
	// in an ICMPv6 Parameter problem message.
	icmpv6PointerOffset = 4

	// icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
	// PacketTooBig message.
	icmpv6MTUOffset = 4

	// icmpv6IdentOffset is the offset of the ident field
	// in an ICMPv6 Echo Request/Reply message.
	icmpv6IdentOffset = 4

	// icmpv6SequenceOffset is the offset of the sequence field
	// in an ICMPv6 Echo Request/Reply message.
	icmpv6SequenceOffset = 6

	// NDPHopLimit is the expected IP hop limit value of 255 for received
	// NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
	// 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
	// drop the NDP packet. All outgoing NDP packets must use this value for
	// their IP hop limit field.
	NDPHopLimit = 255
)
// ICMPv6Type is the ICMP type field described in RFC 4443.
type ICMPv6Type byte

// Values for use in the Type field of ICMPv6 packet from RFC 4443.
const (
	ICMPv6DstUnreachable ICMPv6Type = 1
	ICMPv6PacketTooBig   ICMPv6Type = 2
	ICMPv6TimeExceeded   ICMPv6Type = 3
	ICMPv6ParamProblem   ICMPv6Type = 4
	ICMPv6EchoRequest    ICMPv6Type = 128
	ICMPv6EchoReply      ICMPv6Type = 129

	// Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
	ICMPv6RouterSolicit   ICMPv6Type = 133
	ICMPv6RouterAdvert    ICMPv6Type = 134
	ICMPv6NeighborSolicit ICMPv6Type = 135
	ICMPv6NeighborAdvert  ICMPv6Type = 136
	ICMPv6RedirectMsg     ICMPv6Type = 137

	// Multicast Listener Discovery (MLD) messages, see RFC 2710.
	ICMPv6MulticastListenerQuery  ICMPv6Type = 130
	ICMPv6MulticastListenerReport ICMPv6Type = 131
	ICMPv6MulticastListenerDone   ICMPv6Type = 132

	// Multicast Listener Discovery Version 2 (MLDv2) messages, see RFC 3810.
	ICMPv6MulticastListenerV2Report ICMPv6Type = 143
)
// IsErrorType reports whether typ denotes an ICMPv6 error message.
//
// Per RFC 4443 section 2.1:
//
//	ICMPv6 messages are grouped into two classes: error messages and
//	informational messages. Error messages are identified as such by a
//	zero in the high-order bit of their message Type field values. Thus,
//	error messages have message types from 0 to 127; informational
//	messages have message types from 128 to 255.
func (typ ICMPv6Type) IsErrorType() bool {
	// A clear high-order bit is the same as a value of at most 127.
	return typ <= 127
}
// ICMPv6Code is the ICMP Code field described in RFC 4443. Its meaning depends
// on the message type it accompanies.
type ICMPv6Code byte

// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443
// section 3.1.
const (
	ICMPv6NetworkUnreachable ICMPv6Code = 0
	ICMPv6Prohibited         ICMPv6Code = 1
	ICMPv6BeyondScope        ICMPv6Code = 2
	ICMPv6AddressUnreachable ICMPv6Code = 3
	ICMPv6PortUnreachable    ICMPv6Code = 4
	ICMPv6Policy             ICMPv6Code = 5
	ICMPv6RejectRoute        ICMPv6Code = 6
)

// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3.
const (
	ICMPv6HopLimitExceeded  ICMPv6Code = 0
	ICMPv6ReassemblyTimeout ICMPv6Code = 1
)

// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
const (
	// ICMPv6ErroneousHeader indicates an erroneous header field was encountered.
	ICMPv6ErroneousHeader ICMPv6Code = 0

	// ICMPv6UnknownHeader indicates an unrecognized Next Header type was
	// encountered.
	ICMPv6UnknownHeader ICMPv6Code = 1

	// ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered.
	ICMPv6UnknownOption ICMPv6Code = 2
)

// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
// the code field. (Types not mentioned above.)
const ICMPv6UnusedCode ICMPv6Code = 0
// Type returns the ICMP type, which is held in the first byte of the header.
func (b ICMPv6) Type() ICMPv6Type {
	raw := b[0]
	return ICMPv6Type(raw)
}
// SetType stores t in the first byte of the header.
func (b ICMPv6) SetType(t ICMPv6Type) {
	raw := byte(t)
	b[0] = raw
}
// Code returns the ICMP code, which is held in the second byte of the header.
// Its meaning depends on the value of Type.
func (b ICMPv6) Code() ICMPv6Code {
	raw := b[1]
	return ICMPv6Code(raw)
}
// SetCode stores c in the second byte of the header.
func (b ICMPv6) SetCode(c ICMPv6Code) {
	raw := byte(c)
	b[1] = raw
}
// TypeSpecific reads the 32-bit, big-endian type-specific data field. It is
// located at the same offset as the Parameter Problem pointer.
func (b ICMPv6) TypeSpecific() uint32 {
	field := b[icmpv6PointerOffset:]
	return binary.BigEndian.Uint32(field)
}
// SetTypeSpecific writes val, big-endian, into the 32-bit type-specific data
// field. The checksum is not adjusted.
func (b ICMPv6) SetTypeSpecific(val uint32) {
	field := b[icmpv6PointerOffset:]
	binary.BigEndian.PutUint32(field, val)
}
// Checksum reads the 16-bit, big-endian ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
	field := b[ICMPv6ChecksumOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetChecksum stores cs in the ICMP checksum field via checksum.Put.
func (b ICMPv6) SetChecksum(cs uint16) {
	field := b[ICMPv6ChecksumOffset:]
	checksum.Put(field, cs)
}
// SourcePort implements Transport.SourcePort.
//
// It always returns 0; the header bytes are not consulted.
func (ICMPv6) SourcePort() uint16 {
	return 0
}
// DestinationPort implements Transport.DestinationPort.
//
// It always returns 0; the header bytes are not consulted.
func (ICMPv6) DestinationPort() uint16 {
	return 0
}
// SetSourcePort implements Transport.SetSourcePort.
//
// It is a no-op: the argument is discarded and the header is left untouched.
func (ICMPv6) SetSourcePort(uint16) {
}
// SetDestinationPort implements Transport.SetDestinationPort.
//
// It is a no-op: the argument is discarded and the header is left untouched.
func (ICMPv6) SetDestinationPort(uint16) {
}
// MTU reads the 32-bit, big-endian MTU field of an ICMPv6 message.
func (b ICMPv6) MTU() uint32 {
	field := b[icmpv6MTUOffset:]
	return binary.BigEndian.Uint32(field)
}
// SetMTU writes mtu, big-endian, into the 32-bit MTU field of an ICMPv6
// message. The checksum is not adjusted.
func (b ICMPv6) SetMTU(mtu uint32) {
	field := b[icmpv6MTUOffset:]
	binary.BigEndian.PutUint32(field, mtu)
}
// Ident reads the 16-bit, big-endian Ident field of an ICMPv6 message.
func (b ICMPv6) Ident() uint16 {
	field := b[icmpv6IdentOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetIdent writes ident, big-endian, into the Ident field of an ICMPv6
// message. The checksum is not adjusted.
func (b ICMPv6) SetIdent(ident uint16) {
	field := b[icmpv6IdentOffset:]
	binary.BigEndian.PutUint16(field, ident)
}
// SetIdentWithChecksumUpdate replaces the Ident field with new and then
// adjusts the stored checksum from the old and new Ident values rather than
// recomputing it over the whole message.
func (b ICMPv6) SetIdentWithChecksumUpdate(new uint16) {
	previous := b.Ident()
	b.SetIdent(new)

	// The update helper works on the complement of the stored checksum; the
	// result is complemented again before it is written back.
	sum := ^b.Checksum()
	sum = checksumUpdate2ByteAlignedUint16(sum, previous, new)
	b.SetChecksum(^sum)
}
// Sequence reads the 16-bit, big-endian Sequence field of an ICMPv6 message.
func (b ICMPv6) Sequence() uint16 {
	field := b[icmpv6SequenceOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetSequence writes sequence, big-endian, into the Sequence field of an
// ICMPv6 message. The checksum is not adjusted.
func (b ICMPv6) SetSequence(sequence uint16) {
	field := b[icmpv6SequenceOffset:]
	binary.BigEndian.PutUint16(field, sequence)
}
// MessageBody returns the message body as defined by RFC 4443 section 2.1:
// everything in the ICMPv6 buffer after the first ICMPv6HeaderSize bytes. The
// result shares storage with b.
func (b ICMPv6) MessageBody() []byte {
	body := b[ICMPv6HeaderSize:]
	return body
}
// Payload implements Transport.Payload. It returns the bytes that follow the
// first ICMPv6PayloadOffset bytes; the result shares storage with b.
func (b ICMPv6) Payload() []byte {
	rest := b[ICMPv6PayloadOffset:]
	return rest
}
// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum.
type ICMPv6ChecksumParams struct {
	// Header is the ICMPv6 header to checksum. ICMPv6Checksum sums all of it
	// except bytes 2 and 3, which hold the checksum field.
	Header ICMPv6

	// Src and Dst are the IPv6 addresses fed into the pseudo-header checksum.
	Src tcpip.Address
	Dst tcpip.Address

	// PayloadCsum is a checksum that ICMPv6Checksum combines into the running
	// sum; presumably the checksum of the payload that follows Header.
	PayloadCsum uint16

	// PayloadLen is added to len(Header) to form the length passed to the
	// pseudo-header checksum.
	PayloadLen int
}
// ICMPv6Checksum computes the ICMP checksum from the ICMPv6 header, the IPv6
// source and destination addresses and the payload checksum in params.
//
// The pseudo-header length is len(params.Header)+params.PayloadLen converted
// to uint16. Bytes 2 and 3 of the header hold the checksum field itself and
// are left out of the sum. The returned value is the complement of the
// accumulated sum.
func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 {
	// The checksum field occupies hdr[csumStart:csumEnd].
	const csumStart, csumEnd = 2, 4

	hdr := params.Header
	totalLen := uint16(len(hdr) + params.PayloadLen)

	sum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, totalLen)
	sum = checksum.Combine(sum, params.PayloadCsum)
	sum = checksum.Checksum(hdr[:csumStart], sum)
	return ^checksum.Checksum(hdr[csumEnd:], sum)
}
// UpdateChecksumPseudoHeaderAddress adjusts the stored checksum for a
// pseudo-header address that changed from old to new, without recomputing the
// checksum over the whole message.
func (b ICMPv6) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address) {
	// The update helper works on the complement of the stored checksum; the
	// result is complemented again before it is written back.
	sum := ^b.Checksum()
	sum = checksumUpdate2ByteAlignedAddress(sum, old, new)
	b.SetChecksum(^sum)
}

View File

@@ -0,0 +1,185 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// IGMP represents an IGMP header stored in a byte array.
//
// The accessor methods index the slice directly and do not check its length.
type IGMP []byte

// IGMP implements `Transport`.
var _ Transport = (*IGMP)(nil)

const (
	// IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes,
	// as per RFC 2236, Section 2, Page 2.
	IGMPMinimumSize = 8

	// IGMPQueryMinimumSize is the minimum size of a valid Membership Query
	// Message in bytes, as per RFC 2236, Section 2, Page 2.
	IGMPQueryMinimumSize = 8

	// IGMPReportMinimumSize is the minimum size of a valid Report Message in
	// bytes, as per RFC 2236, Section 2, Page 2.
	IGMPReportMinimumSize = 8

	// IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message
	// in bytes, as per RFC 2236, Section 2, Page 2.
	IGMPLeaveMessageMinimumSize = 8

	// IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page
	// 3.
	IGMPTTL = 1

	// igmpTypeOffset defines the offset of the type field in an IGMP message.
	igmpTypeOffset = 0

	// igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an
	// IGMP message.
	igmpMaxRespTimeOffset = 1

	// igmpChecksumOffset defines the offset of the checksum field in an IGMP
	// message.
	igmpChecksumOffset = 2

	// igmpGroupAddressOffset defines the offset of the Group Address field in an
	// IGMP message.
	igmpGroupAddressOffset = 4

	// IGMPProtocolNumber is IGMP's transport protocol number.
	IGMPProtocolNumber tcpip.TransportProtocolNumber = 2
)
// IGMPType is the IGMP type field as per RFC 2236.
type IGMPType byte

// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2.
// Descriptions below come from there.
const (
	// IGMPMembershipQuery indicates that the message type is Membership Query.
	//
	//	"There are two sub-types of Membership Query messages:
	//	 - General Query, used to learn which groups have members on an
	//	   attached network.
	//	 - Group-Specific Query, used to learn if a particular group
	//	   has any members on an attached network.
	//	 These two messages are differentiated by the Group Address, as
	//	 described in section 1.4 ."
	IGMPMembershipQuery IGMPType = 0x11

	// IGMPv1MembershipReport indicates that the message is a Membership Report
	// generated by a host using the IGMPv1 protocol: "an additional type of
	// message, for backwards-compatibility with IGMPv1"
	IGMPv1MembershipReport IGMPType = 0x12

	// IGMPv2MembershipReport indicates that the Message type is a Membership
	// Report generated by a host using the IGMPv2 protocol.
	IGMPv2MembershipReport IGMPType = 0x16

	// IGMPLeaveGroup indicates that the message type is a Leave Group
	// notification message.
	IGMPLeaveGroup IGMPType = 0x17

	// IGMPv3MembershipReport indicates that the message type is an IGMPv3
	// report.
	IGMPv3MembershipReport IGMPType = 0x22
)
// Type returns the IGMP type field of the message.
func (b IGMP) Type() IGMPType {
	raw := b[igmpTypeOffset]
	return IGMPType(raw)
}
// SetType stores t in the IGMP type field of the message.
func (b IGMP) SetType(t IGMPType) {
	raw := byte(t)
	b[igmpTypeOffset] = raw
}
// MaxRespTime returns the Max Response Time field as a duration.
//
// As per RFC 2236 section 2.2,
//
//	The Max Response Time field is meaningful only in Membership Query
//	messages, and specifies the maximum allowed time before sending a
//	responding report in units of 1/10 second. In all other messages, it
//	is set to zero by the sender and ignored by receivers.
func (b IGMP) MaxRespTime() time.Duration {
	deciseconds := uint16(b[igmpMaxRespTimeOffset])
	return DecisecondToDuration(deciseconds)
}
// SetMaxRespTime stores the raw byte m in the Max Response Time field. No unit
// conversion is applied.
func (b IGMP) SetMaxRespTime(m byte) {
	b[igmpMaxRespTimeOffset] = m
}
// Checksum reads the 16-bit, big-endian IGMP checksum field.
func (b IGMP) Checksum() uint16 {
	field := b[igmpChecksumOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetChecksum writes checksum, big-endian, into the IGMP checksum field.
func (b IGMP) SetChecksum(checksum uint16) {
	field := b[igmpChecksumOffset:]
	binary.BigEndian.PutUint16(field, checksum)
}
// GroupAddress returns the Group Address field as an IPv4 address built from
// the IPv4AddressSize bytes at the group address offset.
func (b IGMP) GroupAddress() tcpip.Address {
	raw := b[igmpGroupAddressOffset:][:IPv4AddressSize]
	return tcpip.AddrFrom4([4]byte(raw))
}
// SetGroupAddress copies the 4-byte form of address into the Group Address
// field.
//
// Panics if fewer than IPv4AddressSize bytes could be copied.
func (b IGMP) SetGroupAddress(address tcpip.Address) {
	src := address.As4()
	copied := copy(b[igmpGroupAddressOffset:], src[:])
	if copied == IPv4AddressSize {
		return
	}
	panic(fmt.Sprintf("copied %d bytes, expected %d", copied, IPv4AddressSize))
}
// SourcePort implements Transport.SourcePort.
//
// It always returns 0; the header bytes are not consulted.
func (IGMP) SourcePort() uint16 {
	return 0
}
// DestinationPort implements Transport.DestinationPort.
//
// It always returns 0; the header bytes are not consulted.
func (IGMP) DestinationPort() uint16 {
	return 0
}
// SetSourcePort implements Transport.SetSourcePort.
//
// It is a no-op: the argument is discarded and the header is left untouched.
func (IGMP) SetSourcePort(uint16) {
}
// SetDestinationPort implements Transport.SetDestinationPort.
//
// It is a no-op: the argument is discarded and the header is left untouched.
func (IGMP) SetDestinationPort(uint16) {
}
// Payload implements Transport.Payload.
//
// It always returns nil; the header bytes are not consulted.
func (IGMP) Payload() []byte {
	return nil
}
// IGMPCalculateChecksum computes the IGMP checksum over the whole of h and
// returns the complement of the sum.
//
// The checksum field inside h is zeroed while the sum is taken and restored
// before returning, so h is briefly modified during the call.
func IGMPCalculateChecksum(h IGMP) uint16 {
	saved := h.Checksum()

	// Sum with the checksum field zeroed so it does not contribute.
	h.SetChecksum(0)
	sum := checksum.Checksum(h, 0)
	h.SetChecksum(saved)

	return ^sum
}
// DecisecondToDuration converts ds, a count of tenths of a second, into the
// equivalent time.Duration.
func DecisecondToDuration(ds uint16) time.Duration {
	const decisecond = 100 * time.Millisecond
	return time.Duration(ds) * decisecond
}

View File

@@ -0,0 +1,502 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"bytes"
"encoding/binary"
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
var (
	// IGMPv3RoutersAddress is the address to send IGMPv3 reports to.
	//
	// As per RFC 3376 section 4.2.14,
	//
	//	Version 3 Reports are sent with an IP destination address of
	//	224.0.0.22, to which all IGMPv3-capable multicast routers listen.
	IGMPv3RoutersAddress = tcpip.AddrFrom4([4]byte{0xe0, 0x00, 0x00, 0x16})
)

const (
	// IGMPv3QueryMinimumSize is the minimum size of a valid IGMPv3 query,
	// as per RFC 3376 section 4.1.
	IGMPv3QueryMinimumSize = 12

	// Byte offsets of the fields of an IGMPv3 query, and the mask that
	// selects the QRV bits from the byte at igmpv3QueryResvSQRVOffset.
	igmpv3QueryMaxRespCodeOffset     = 1
	igmpv3QueryGroupAddressOffset    = 4
	igmpv3QueryResvSQRVOffset        = 8
	igmpv3QueryQRVMask               = 0b111
	igmpv3QueryQQICOffset            = 9
	igmpv3QueryNumberOfSourcesOffset = 10
	igmpv3QuerySourcesOffset         = 12
)
// IGMPv3Query is an IGMPv3 query message.
//
// As per RFC 3376 section 4.1,
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Type = 0x11  | Max Resp Code |           Checksum            |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                         Group Address                         |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	| Resv  |S| QRV |     QQIC      |     Number of Sources (N)     |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                       Source Address [1]                      |
//	+-                                                             -+
//	|                       Source Address [2]                      |
//	+-                              .                              -+
//	.                               .                               .
//	.                               .                               .
//	+-                                                             -+
//	|                       Source Address [N]                      |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type IGMPv3Query IGMP
// MaximumResponseCode returns the raw Max Resp Code byte of the query. Use
// IGMPv3MaximumResponseDelay to turn it into a duration.
func (i IGMPv3Query) MaximumResponseCode() uint8 {
	code := i[igmpv3QueryMaxRespCodeOffset]
	return code
}
// IGMPv3MaximumResponseDelay decodes an IGMPv3 Maximum Response Code into the
// Maximum Response Delay it stands for.
//
// As per RFC 3376 section 4.1.1,
//
//	The Max Resp Code field specifies the maximum time allowed before
//	sending a responding report.  The actual time allowed, called the Max
//	Resp Time, is represented in units of 1/10 second and is derived from
//	the Max Resp Code as follows:
//
//	If Max Resp Code < 128, Max Resp Time = Max Resp Code
//
//	If Max Resp Code >= 128, Max Resp Code represents a floating-point
//	value as follows:
//
//	    0 1 2 3 4 5 6 7
//	   +-+-+-+-+-+-+-+-+
//	   |1| exp | mant  |
//	   +-+-+-+-+-+-+-+-+
//
//	Max Resp Time = (mant | 0x10) << (exp + 3)
//
//	Small values of Max Resp Time allow IGMPv3 routers to tune the "leave
//	latency" (the time between the moment the last host leaves a group
//	and the moment the routing protocol is notified that there are no
//	more members).  Larger values, especially in the exponential range,
//	allow tuning of the burstiness of IGMP traffic on a network.
func IGMPv3MaximumResponseDelay(codeRaw uint8) time.Duration {
	const (
		floatFlag = 0x80 // set when the code uses the exp/mant encoding
		mantBits  = 4
		mantMask  = 1<<mantBits - 1
		expMask   = 0b111
		hiddenBit = 0x10
		expBias   = 3
	)

	// Widen first: the largest decoded value, 0x1f << 10, needs 15 bits.
	code := uint16(codeRaw)
	if code&floatFlag == 0 {
		return DecisecondToDuration(code)
	}

	mant := code & mantMask
	exp := (code >> mantBits) & expMask
	return DecisecondToDuration((mant | hiddenBit) << (exp + expBias))
}
// GroupAddress returns the query's group address, built from the
// IPv4AddressSize bytes at the group address offset.
func (i IGMPv3Query) GroupAddress() tcpip.Address {
	raw := i[igmpv3QueryGroupAddressOffset:][:IPv4AddressSize]
	return tcpip.AddrFrom4([4]byte(raw))
}
// QuerierRobustnessVariable returns the querier's robustness variable, taken
// from the low three bits of the Resv/S/QRV byte.
func (i IGMPv3Query) QuerierRobustnessVariable() uint8 {
	resvSQRV := i[igmpv3QueryResvSQRVOffset]
	return resvSQRV & igmpv3QueryQRVMask
}
// QuerierQueryInterval returns the querier's query interval, obtained by
// decoding the QQIC byte with the helper shared with MLDv2.
func (i IGMPv3Query) QuerierQueryInterval() time.Duration {
	qqic := i[igmpv3QueryQQICOffset]
	return mldv2AndIGMPv3QuerierQueryCodeToInterval(qqic)
}
// Sources returns an iterator over the source addresses in the query.
//
// The iterator covers the bytes after the fixed part of the query and is told
// to expect the count from the Number of Sources field, each entry being
// IPv4AddressSize bytes.
//
// Returns false if the message cannot hold the expected number of sources.
func (i IGMPv3Query) Sources() (AddressIterator, bool) {
	addrs := i[igmpv3QuerySourcesOffset:]
	count := binary.BigEndian.Uint16(i[igmpv3QueryNumberOfSourcesOffset:])
	return makeAddressIterator(addrs, count, IPv4AddressSize)
}
// IGMPv3ReportRecordType is the type of an IGMPv3 group record found in an
// IGMPv3 report, as per RFC 3376 section 4.2.12.
type IGMPv3ReportRecordType int

// IGMPv3 group record types, as per RFC 3376 section 4.2.12.
const (
	IGMPv3ReportRecordModeIsInclude       IGMPv3ReportRecordType = 1
	IGMPv3ReportRecordModeIsExclude       IGMPv3ReportRecordType = 2
	IGMPv3ReportRecordChangeToIncludeMode IGMPv3ReportRecordType = 3
	IGMPv3ReportRecordChangeToExcludeMode IGMPv3ReportRecordType = 4
	IGMPv3ReportRecordAllowNewSources     IGMPv3ReportRecordType = 5
	IGMPv3ReportRecordBlockOldSources     IGMPv3ReportRecordType = 6
)
// Sizes and byte offsets for a group address record inside an IGMPv3 report.
const (
	igmpv3ReportGroupAddressRecordMinimumSize       = 8
	igmpv3ReportGroupAddressRecordTypeOffset        = 0
	igmpv3ReportGroupAddressRecordAuxDataLenOffset  = 1
	// NOTE(review): presumably the number of bytes per unit of the Aux Data
	// Len field; nothing in this section uses it, so confirm at its callers.
	igmpv3ReportGroupAddressRecordAuxDataLenUnits        = 4
	igmpv3ReportGroupAddressRecordNumberOfSourcesOffset = 2
	igmpv3ReportGroupAddressRecordGroupAddressOffset    = 4
	igmpv3ReportGroupAddressRecordSourcesOffset         = 8
)
// IGMPv3ReportGroupAddressRecordSerializer is an IGMPv3 Group Record
// (multicast address record) serializer.
//
// As per RFC 3376 section 4.2, a Group Record has the following internal
// format, with 4-byte IPv4 addresses:
//
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Record Type  |  Aux Data Len |     Number of Sources (N)     |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                       Multicast Address                       |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                       Source Address [1]                      |
//	+-                                                             -+
//	|                       Source Address [2]                      |
//	+-                                                             -+
//	.                               .                               .
//	.                               .                               .
//	.                               .                               .
//	+-                                                             -+
//	|                       Source Address [N]                      |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                         Auxiliary Data                        .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//
// SerializeInto always writes an Aux Data Len of zero, so no auxiliary data
// is emitted.
type IGMPv3ReportGroupAddressRecordSerializer struct {
	// RecordType is written, truncated to one byte, as the Record Type.
	RecordType IGMPv3ReportRecordType

	// GroupAddress is written as the 4-byte Multicast Address.
	GroupAddress tcpip.Address

	// Sources are written in order as the 4-byte Source Address entries.
	Sources []tcpip.Address
}
// Length reports how many bytes the serialized record occupies: the fixed
// fields that precede the sources plus one IPv4 address per source.
func (s *IGMPv3ReportGroupAddressRecordSerializer) Length() int {
	sourcesLen := len(s.Sources) * IPv4AddressSize
	return igmpv3ReportGroupAddressRecordSourcesOffset + sourcesLen
}
// copyIPv4Address copies the bytes returned by src.As4() to the start of dst.
//
// Panics if the number of bytes copied is not IPv4AddressSize, e.g. because
// dst is too short.
func copyIPv4Address(dst []byte, src tcpip.Address) {
	addr := src.As4()
	copied := copy(dst, addr[:])
	if copied == IPv4AddressSize {
		return
	}
	panic(fmt.Sprintf("got copy(...) = %d, want = %d", copied, IPv4AddressSize))
}
// SerializeInto writes the record into b.
//
// Panics if b is too small to hold the record.
func (s *IGMPv3ReportGroupAddressRecordSerializer) SerializeInto(b []byte) {
	b[igmpv3ReportGroupAddressRecordTypeOffset] = byte(s.RecordType)
	// No auxiliary data is ever emitted by this serializer.
	b[igmpv3ReportGroupAddressRecordAuxDataLenOffset] = 0
	binary.BigEndian.PutUint16(b[igmpv3ReportGroupAddressRecordNumberOfSourcesOffset:], uint16(len(s.Sources)))
	copyIPv4Address(b[igmpv3ReportGroupAddressRecordGroupAddressOffset:], s.GroupAddress)

	// Sources are laid out back to back, starting at the sources offset.
	sourcesBytes := b[igmpv3ReportGroupAddressRecordSourcesOffset:]
	for idx, source := range s.Sources {
		copyIPv4Address(sourcesBytes[idx*IPv4AddressSize:], source)
	}
}
// Layout of an IGMPv3 report. All values are byte offsets from the start of
// the message.
const (
	// igmpv3ReportTypeOffset locates the one-byte message type.
	igmpv3ReportTypeOffset = 0
	// igmpv3ReportReserved1Offset locates the first (one-byte) reserved
	// field.
	igmpv3ReportReserved1Offset = 1
	// igmpv3ReportReserved2Offset locates the second (16-bit) reserved field.
	igmpv3ReportReserved2Offset = 4
	// igmpv3ReportNumberOfGroupAddressRecordsOffset locates the 16-bit
	// big-endian count of group address records.
	igmpv3ReportNumberOfGroupAddressRecordsOffset = 6
	// igmpv3ReportGroupAddressRecordsOffset locates the first group address
	// record.
	igmpv3ReportGroupAddressRecordsOffset = 8
)
// IGMPv3ReportSerializer is an IGMPv3 Membership Report serializer.
//
// As per RFC 3376 section 4.2,
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Type = 0x22  |    Reserved   |           Checksum            |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|           Reserved            |  Number of Group Records (M)  |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                        Group Record [1]                       .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                        Group Record [2]                       .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                               .                               |
//	.                               .                               .
//	|                               .                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                        Group Record [M]                       .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type IGMPv3ReportSerializer struct {
	// Records are the group address records, serialized in slice order.
	Records []IGMPv3ReportGroupAddressRecordSerializer
}
// Length reports how many bytes the serialized report occupies: the fixed
// report header followed by every record.
func (s *IGMPv3ReportSerializer) Length() int {
	total := igmpv3ReportGroupAddressRecordsOffset
	for idx := range s.Records {
		total += s.Records[idx].Length()
	}
	return total
}
// SerializeInto writes the report, including its checksum, into b.
//
// Panics if b is too small to hold the report.
func (s *IGMPv3ReportSerializer) SerializeInto(b []byte) {
	b[igmpv3ReportTypeOffset] = byte(IGMPv3MembershipReport)
	b[igmpv3ReportReserved1Offset] = 0
	binary.BigEndian.PutUint16(b[igmpv3ReportReserved2Offset:], 0)
	binary.BigEndian.PutUint16(b[igmpv3ReportNumberOfGroupAddressRecordsOffset:], uint16(len(s.Records)))

	// Each record is given exactly the window it asks for; the cursor then
	// moves past it.
	remaining := b[igmpv3ReportGroupAddressRecordsOffset:]
	for idx := range s.Records {
		rec := &s.Records[idx]
		recLen := rec.Length()
		rec.SerializeInto(remaining[:recLen])
		remaining = remaining[recLen:]
	}

	// The checksum is written last, once every record is in place. Note that
	// IGMPCalculateChecksum is handed the whole of b.
	binary.BigEndian.PutUint16(b[igmpChecksumOffset:], IGMPCalculateChecksum(b))
}
// IGMPv3ReportGroupAddressRecord is an IGMPv3 group record.
//
// As per RFC 3376 section 4.2, a Group Record has the following internal
// format (all addresses are 32-bit IPv4 addresses):
//
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Record Type  |  Aux Data Len |     Number of Sources (N)     |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                       Multicast Address                       |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                       Source Address [1]                      |
//	+-                                                             -+
//	|                       Source Address [2]                      |
//	+-                                                             -+
//	.                               .                               .
//	.                               .                               .
//	.                               .                               .
//	+-                                                             -+
//	|                       Source Address [N]                      |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                         Auxiliary Data                        .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//
// The accessor methods index the underlying slice directly and panic if it
// is too short for the field being read.
type IGMPv3ReportGroupAddressRecord []byte
// RecordType returns the type of this record, read from the record's first
// byte.
//
// Panics if the record is empty.
func (r IGMPv3ReportGroupAddressRecord) RecordType() IGMPv3ReportRecordType {
	return IGMPv3ReportRecordType(r[igmpv3ReportGroupAddressRecordTypeOffset])
}
// AuxDataLen returns the length, in bytes, of the auxiliary data in this
// record.
//
// The on-wire field counts units of
// igmpv3ReportGroupAddressRecordAuxDataLenUnits bytes, so the raw value is
// scaled before being returned.
func (r IGMPv3ReportGroupAddressRecord) AuxDataLen() int {
	return int(r[igmpv3ReportGroupAddressRecordAuxDataLenOffset]) * igmpv3ReportGroupAddressRecordAuxDataLenUnits
}
// numberOfSources returns the number of sources this record claims to hold,
// read as a big-endian 16-bit value. The count is not validated against the
// record's actual length here.
func (r IGMPv3ReportGroupAddressRecord) numberOfSources() uint16 {
	return binary.BigEndian.Uint16(r[igmpv3ReportGroupAddressRecordNumberOfSourcesOffset:])
}
// GroupAddress returns the multicast address this record targets.
//
// The address is built from the IPv4AddressSize bytes starting at the group
// address offset; the call panics if the record does not extend that far.
func (r IGMPv3ReportGroupAddressRecord) GroupAddress() tcpip.Address {
	return tcpip.AddrFrom4([4]byte(r[igmpv3ReportGroupAddressRecordGroupAddressOffset:][:IPv4AddressSize]))
}
// Sources returns an iterator over the source addresses in the record.
//
// The boolean result is false if the record is too short to hold the number
// of sources it claims to carry.
func (r IGMPv3ReportGroupAddressRecord) Sources() (AddressIterator, bool) {
	wantLen := int(r.numberOfSources()) * IPv4AddressSize
	sourcesBytes := r[igmpv3ReportGroupAddressRecordSourcesOffset:]
	if wantLen > len(sourcesBytes) {
		return AddressIterator{}, false
	}

	// Anything after the sources (i.e. auxiliary data) is excluded from the
	// iterator's view.
	return AddressIterator{
		addressSize: IPv4AddressSize,
		buf:         bytes.NewBuffer(sourcesBytes[:wantLen]),
	}, true
}
// IGMPv3Report is an IGMPv3 Membership Report.
//
// As per RFC 3376 section 4.2,
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Type = 0x22  |    Reserved   |           Checksum            |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|           Reserved            |  Number of Group Records (M)  |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                        Group Record [1]                       .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                        Group Record [2]                       .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                               .                               |
//	.                               .                               .
//	|                               .                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                        Group Record [M]                       .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type IGMPv3Report []byte
// Checksum returns the value stored in the report's checksum field, read as
// a big-endian 16-bit integer. The value is returned as stored; it is not
// verified here.
func (i IGMPv3Report) Checksum() uint16 {
	return binary.BigEndian.Uint16(i[igmpChecksumOffset:])
}
// IGMPv3ReportGroupAddressRecordIterator is an iterator over IGMPv3 Multicast
// Address Records.
type IGMPv3ReportGroupAddressRecordIterator struct {
	// recordsLeft is the number of records Next has yet to yield.
	recordsLeft uint16
	// buf holds the record bytes that have not been consumed yet.
	buf *bytes.Buffer
}
// IGMPv3ReportGroupAddressRecordIteratorNextDisposition enumerates the
// possible outcomes of IGMPv3ReportGroupAddressRecordIterator.Next.
type IGMPv3ReportGroupAddressRecordIteratorNextDisposition int

const (
	// IGMPv3ReportGroupAddressRecordIteratorNextOk indicates that a multicast
	// address record was yielded.
	IGMPv3ReportGroupAddressRecordIteratorNextOk IGMPv3ReportGroupAddressRecordIteratorNextDisposition = iota
	// IGMPv3ReportGroupAddressRecordIteratorNextDone indicates that the
	// iterator has been exhausted, i.e. every record announced by the report
	// has already been yielded.
	IGMPv3ReportGroupAddressRecordIteratorNextDone
	// IGMPv3ReportGroupAddressRecordIteratorNextErrBufferTooShort indicates
	// that the iterator expected another record, but the buffer ended
	// prematurely.
	IGMPv3ReportGroupAddressRecordIteratorNextErrBufferTooShort
)
// Next yields the next IGMPv3 Multicast Address Record together with a
// disposition describing whether a record was produced, the iterator is
// exhausted, or the remaining bytes are too short for the expected record.
func (it *IGMPv3ReportGroupAddressRecordIterator) Next() (IGMPv3ReportGroupAddressRecord, IGMPv3ReportGroupAddressRecordIteratorNextDisposition) {
	switch {
	case it.recordsLeft == 0:
		return IGMPv3ReportGroupAddressRecord{}, IGMPv3ReportGroupAddressRecordIteratorNextDone
	case it.buf.Len() < igmpv3ReportGroupAddressRecordMinimumSize:
		return IGMPv3ReportGroupAddressRecord{}, IGMPv3ReportGroupAddressRecordIteratorNextErrBufferTooShort
	}

	// Peek at the fixed fields (without consuming anything) to learn the full
	// size of the record.
	header := IGMPv3ReportGroupAddressRecord(it.buf.Bytes())
	auxLen := header.AuxDataLen()
	sourcesLen := int(header.numberOfSources()) * IPv4AddressSize
	recordLen := igmpv3ReportGroupAddressRecordMinimumSize + auxLen + sourcesLen

	// Buffer.Next hands back whatever is left when fewer than recordLen bytes
	// remain, so a short result means the record was truncated.
	record := it.buf.Next(recordLen)
	if len(record) < recordLen {
		return IGMPv3ReportGroupAddressRecord{}, IGMPv3ReportGroupAddressRecordIteratorNextErrBufferTooShort
	}

	it.recordsLeft--
	return IGMPv3ReportGroupAddressRecord(record), IGMPv3ReportGroupAddressRecordIteratorNextOk
}
// GroupAddressRecords returns an iterator over the report's IGMPv3 Multicast
// Address Records.
//
// The iterator is primed with the record count stored in the report and with
// every byte that follows the fixed report header.
func (i IGMPv3Report) GroupAddressRecords() IGMPv3ReportGroupAddressRecordIterator {
	numRecords := binary.BigEndian.Uint16(i[igmpv3ReportNumberOfGroupAddressRecordsOffset:])
	recordsBytes := i[igmpv3ReportGroupAddressRecordsOffset:]
	return IGMPv3ReportGroupAddressRecordIterator{
		recordsLeft: numRecords,
		buf:         bytes.NewBuffer(recordsBytes),
	}
}

View File

@@ -0,0 +1,130 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
	// MaxIPPacketSize is the maximum supported IP packet size, excluding
	// jumbograms. The maximum IPv4 packet size is 64k-1 (the total size must
	// fit in 16 bits). For IPv6, the maximum payload size (excluding
	// jumbograms) is also 64k-1, as it too must fit in 16 bits. The value
	// used is therefore 64k - 1 + 2 * m, where m is the minimum IPv6 header
	// size; this leaves room for some potential IP options.
	MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
)
// Transport offers generic methods to query and/or update the fields of the
// header of a transport protocol buffer.
type Transport interface {
	// SourcePort returns the value of the "source port" field.
	SourcePort() uint16
	// DestinationPort returns the value of the "destination port" field.
	DestinationPort() uint16
	// Checksum returns the value of the "checksum" field.
	Checksum() uint16
	// SetSourcePort sets the value of the "source port" field.
	SetSourcePort(uint16)
	// SetDestinationPort sets the value of the "destination port" field.
	SetDestinationPort(uint16)
	// SetChecksum sets the value of the "checksum" field.
	SetChecksum(uint16)
	// Payload returns the data carried in the transport buffer.
	Payload() []byte
}
// ChecksummableTransport is a Transport that supports checksumming, i.e. one
// whose checksum can be adjusted in step with a changed port or
// pseudo-header address.
type ChecksummableTransport interface {
	Transport
	// SetSourcePortWithChecksumUpdate sets the source port and updates
	// the checksum.
	//
	// The receiver's checksum must be a fully calculated checksum.
	SetSourcePortWithChecksumUpdate(port uint16)
	// SetDestinationPortWithChecksumUpdate sets the destination port and
	// updates the checksum.
	//
	// The receiver's checksum must be a fully calculated checksum.
	SetDestinationPortWithChecksumUpdate(port uint16)
	// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
	// updated address in the pseudo header.
	//
	// If fullChecksum is true, the receiver's checksum field is assumed to
	// hold a fully calculated checksum. Otherwise, it is assumed to hold a
	// partially calculated checksum which only reflects the pseudo header.
	UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
}
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
	// SourceAddress returns the value of the "source address" field.
	SourceAddress() tcpip.Address
	// DestinationAddress returns the value of the "destination address"
	// field.
	DestinationAddress() tcpip.Address
	// Checksum returns the value of the "checksum" field.
	Checksum() uint16
	// SetSourceAddress sets the value of the "source address" field.
	SetSourceAddress(tcpip.Address)
	// SetDestinationAddress sets the value of the "destination address"
	// field.
	SetDestinationAddress(tcpip.Address)
	// SetChecksum sets the value of the "checksum" field.
	SetChecksum(uint16)
	// TransportProtocol returns the number of the transport protocol
	// stored in the payload.
	TransportProtocol() tcpip.TransportProtocolNumber
	// Payload returns a byte slice containing the payload of the network
	// packet.
	Payload() []byte
	// TOS returns the values of the "type of service" and "flow label"
	// fields, in that order.
	TOS() (uint8, uint32)
	// SetTOS sets the values of the "type of service" (t) and "flow label"
	// (l) fields.
	SetTOS(t uint8, l uint32)
}
// ChecksummableNetwork is a Network that supports checksumming.
type ChecksummableNetwork interface {
	Network
	// SetSourceAddressWithChecksumUpdate sets the source address and updates
	// the checksum to reflect the new address.
	SetSourceAddressWithChecksumUpdate(tcpip.Address)
	// SetDestinationAddressWithChecksumUpdate sets the destination address
	// and updates the checksum to reflect the new address.
	SetDestinationAddressWithChecksumUpdate(tcpip.Address)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,597 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"crypto/sha256"
"encoding/binary"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
// Byte offsets of the fields of the fixed IPv6 header.
const (
	// versTCFL is the offset of the 32-bit word holding the version, traffic
	// class and flow label.
	versTCFL = 0
	// IPv6PayloadLenOffset is the offset of the PayloadLength field in
	// IPv6 header.
	IPv6PayloadLenOffset = 4
	// IPv6NextHeaderOffset is the offset of the NextHeader field in
	// IPv6 header.
	IPv6NextHeaderOffset = 6
	// hopLimit is the offset of the one-byte Hop Limit field.
	hopLimit = 7
	// v6SrcAddr is the offset of the source address.
	v6SrcAddr = 8
	// v6DstAddr is the offset of the destination address, which immediately
	// follows the source address.
	v6DstAddr = v6SrcAddr + IPv6AddressSize
	// IPv6FixedHeaderSize is the size of the fixed header.
	IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
// fields of a packet that needs to be encoded (see IPv6.Encode).
type IPv6Fields struct {
	// TrafficClass is the "traffic class" field of an IPv6 packet.
	TrafficClass uint8
	// FlowLabel is the "flow label" field of an IPv6 packet. Only the low 20
	// bits are encoded.
	FlowLabel uint32
	// PayloadLength is the "payload length" field of an IPv6 packet, including
	// the length of all extension headers.
	PayloadLength uint16
	// TransportProtocol is the transport layer protocol number. Serialized in
	// the last "next header" field of the IPv6 header + extension headers.
	TransportProtocol tcpip.TransportProtocolNumber
	// HopLimit is the "Hop Limit" field of an IPv6 packet.
	HopLimit uint8
	// SrcAddr is the "source ip address" of an IPv6 packet.
	SrcAddr tcpip.Address
	// DstAddr is the "destination ip address" of an IPv6 packet.
	DstAddr tcpip.Address
	// ExtensionHeaders are the extension headers following the IPv6 header.
	ExtensionHeaders IPv6ExtHdrSerializer
}
// IPv6 represents an IPv6 header stored in a byte slice.
//
// Most of the methods of IPv6 access the underlying slice without checking
// its bounds and may therefore panic with an 'index out of range' error.
// Always call IsValid() to validate an instance of IPv6 before using its
// other methods.
type IPv6 []byte
const (
	// IPv6MinimumSize is the minimum size of a valid IPv6 packet, i.e. the
	// size of the fixed header.
	IPv6MinimumSize = IPv6FixedHeaderSize
	// IPv6AddressSize is the size, in bytes, of an IPv6 address.
	IPv6AddressSize = 16
	// IPv6AddressSizeBits is the size, in bits, of an IPv6 address.
	IPv6AddressSizeBits = 128
	// IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per
	// RFC 8200 Section 4.5.
	IPv6MaximumPayloadSize = 65535
	// IPv6ProtocolNumber is IPv6's network protocol number.
	IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
	// IPv6Version is the version of the ipv6 protocol.
	IPv6Version = 6
	// IIDSize is the size of an interface identifier (IID), in bytes, as
	// defined by RFC 4291 section 2.5.1.
	IIDSize = 8
	// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
	// section 5:
	//   IPv6 requires that every link in the Internet have an MTU of 1280
	//   octets or greater. This is known as the IPv6 minimum link MTU.
	IPv6MinimumMTU = 1280
	// IIDOffsetInIPv6Address is the offset, in bytes, from the start
	// of an IPv6 address to the beginning of the interface identifier
	// (IID) for auto-generated addresses. That is, all bytes before
	// the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all
	// bytes including and after the IIDOffsetInIPv6Address-th byte are
	// for the IID.
	IIDOffsetInIPv6Address = 8
	// OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes
	// for the secret key used to generate an opaque interface identifier as
	// outlined by RFC 7217.
	OpaqueIIDSecretKeyMinBytes = 16
	// ipv6MulticastAddressScopeByteIdx is the index of the byte where the
	// scope (scop) field is located within a multicast IPv6 address, as per
	// RFC 4291 section 2.7.
	ipv6MulticastAddressScopeByteIdx = 1
	// ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
	// within the byte holding the field, as per RFC 4291 section 2.7.
	ipv6MulticastAddressScopeMask = 0xF
)
// Well-known IPv6 addresses.
var (
	// IPv6AllNodesMulticastAddress is a link-local multicast group that
	// all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets
	// destined to this address will reach all nodes on a link.
	//
	// The address is ff02::1.
	IPv6AllNodesMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
	// IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local
	// multicast group that all IPv6 routers MUST join, as per RFC 4291, section
	// 2.8. Packets destined to this address will reach the router on an
	// interface.
	//
	// The address is ff01::2.
	IPv6AllRoutersInterfaceLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
	// IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group
	// that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
	// destined to this address will reach all routers on a link.
	//
	// The address is ff02::2.
	IPv6AllRoutersLinkLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
	// IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group
	// that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
	// destined to this address will reach all routers in a site.
	//
	// The address is ff05::2.
	IPv6AllRoutersSiteLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
	// IPv6Loopback is the IPv6 Loopback address.
	//
	// The address is ::1.
	IPv6Loopback = tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
	// IPv6Any is the non-routable IPv6 "any" meta address. It is also
	// known as the unspecified address.
	//
	// The address is ::.
	IPv6Any = tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
// be contained within this subnet.
//
// It is built from the unspecified address with a prefix length of 0 (::/0).
var IPv6EmptySubnet = tcpip.AddressWithPrefix{
	Address:   IPv6Any,
	PrefixLen: 0,
}.Subnet()
// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined
// by RFC 4291 section 2.5.5.
//
// The prefix is ::ffff:0:0/96.
var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{
	Address:   tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00}),
	PrefixLen: 96,
}.Subnet()
// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
// by RFC 4291 section 2.5.6.
//
// The prefix is fe80::/64. Unlike the subnet variables declared alongside it,
// this is an AddressWithPrefix; call Subnet() on it to obtain the subnet.
var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{
	Address:   tcpip.AddrFrom16([16]byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
	PrefixLen: 64,
}
// PayloadLength returns the value of the "payload length" field of the ipv6
// header, read as a big-endian 16-bit integer.
//
// Panics if b is too short to contain the field.
func (b IPv6) PayloadLength() uint16 {
	return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
// HopLimit returns the value of the "Hop Limit" field of the ipv6 header.
//
// Panics if b is too short to contain the field.
func (b IPv6) HopLimit() uint8 {
	return b[hopLimit]
}
// NextHeader returns the value of the "next header" field of the ipv6 header.
//
// Panics if b is too short to contain the field.
func (b IPv6) NextHeader() uint8 {
	return b[IPv6NextHeaderOffset]
}
// TransportProtocol implements Network.TransportProtocol.
//
// The result is the fixed header's "next header" value converted as-is; no
// extension headers are walked here.
func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
	return tcpip.TransportProtocolNumber(b.NextHeader())
}
// Payload implements Network.Payload.
//
// The result covers the PayloadLength() bytes that follow the fixed header
// and aliases b. The call panics if b cannot be sliced to that length.
func (b IPv6) Payload() []byte {
	afterHeader := b[IPv6MinimumSize:]
	return afterHeader[:b.PayloadLength()]
}
// SourceAddress returns the "source address" field of the ipv6 header.
//
// The address bytes are copied out of b. Panics if b is too short to contain
// the field.
func (b IPv6) SourceAddress() tcpip.Address {
	return tcpip.AddrFrom16([16]byte(b[v6SrcAddr:][:IPv6AddressSize]))
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
//
// The address bytes are copied out of b. Panics if b is too short to contain
// the field.
func (b IPv6) DestinationAddress() tcpip.Address {
	return tcpip.AddrFrom16([16]byte(b[v6DstAddr:][:IPv6AddressSize]))
}
// SourceAddressSlice returns the "source address" field of the ipv6 header as a
// byte slice.
//
// The returned slice aliases b rather than copying it, so writes through it
// modify the header.
func (b IPv6) SourceAddressSlice() []byte {
	return []byte(b[v6SrcAddr:][:IPv6AddressSize])
}
// DestinationAddressSlice returns the "destination address" field of the ipv6
// header as a byte slice.
//
// The returned slice aliases b rather than copying it, so writes through it
// modify the header.
func (b IPv6) DestinationAddressSlice() []byte {
	return []byte(b[v6DstAddr:][:IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
// checksum, it just returns 0. The receiver is not read.
func (IPv6) Checksum() uint16 {
	return 0
}
// TOS returns the "traffic class" and "flow label" fields of the ipv6 header,
// in that order.
func (b IPv6) TOS() (uint8, uint32) {
	word := binary.BigEndian.Uint32(b[versTCFL:])
	// Shifting out the flow label leaves version+traffic class; narrowing to
	// 8 bits then drops the version.
	trafficClass := uint8(word >> 20)
	// The flow label is the low 20 bits of the word.
	flowLabel := word & 0xfffff
	return trafficClass, flowLabel
}
// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
//
// The version nibble of the same 32-bit word is (re)written as IPv6 at the
// same time. Bits of l above the low 20 are ignored.
func (b IPv6) SetTOS(t uint8, l uint32) {
	const flowLabelMask = 0xfffff

	version := uint32(IPv6Version) << 28
	trafficClass := uint32(t) << 20
	flowLabel := l & flowLabelMask
	binary.BigEndian.PutUint32(b[versTCFL:], version|trafficClass|flowLabel)
}
// SetPayloadLength sets the "payload length" field of the ipv6 header,
// storing payloadLength in big-endian byte order.
func (b IPv6) SetPayloadLength(payloadLength uint16) {
	binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
}
// SetSourceAddress sets the "source address" field of the ipv6 header.
//
// At most IPv6AddressSize bytes are copied; if addr.AsSlice() is shorter,
// only that many leading bytes of the field are overwritten.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
	copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice())
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
//
// At most IPv6AddressSize bytes are copied; if addr.AsSlice() is shorter,
// only that many leading bytes of the field are overwritten.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
	copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice())
}
// SetHopLimit sets the value of the "Hop Limit" field.
//
// Panics if b is too short to contain the field.
func (b IPv6) SetHopLimit(v uint8) {
	b[hopLimit] = v
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
//
// Panics if b is too short to contain the field.
func (b IPv6) SetNextHeader(v uint8) {
	b[IPv6NextHeaderOffset] = v
}
// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
// checksum, it is intentionally a no-op: the argument is ignored.
func (IPv6) SetChecksum(uint16) {
}
// Encode writes every field described by i into the ipv6 header, and
// serializes the extension headers into the bytes that follow the fixed
// header.
func (b IPv6) Encode(i *IPv6Fields) {
	// Everything past the fixed header is available to the extension headers.
	extHdrBuf := b[IPv6MinimumSize:]

	b.SetTOS(i.TrafficClass, i.FlowLabel)
	b.SetPayloadLength(i.PayloadLength)
	b.SetHopLimit(i.HopLimit)
	b.SetSourceAddress(i.SrcAddr)
	b.SetDestinationAddress(i.DstAddr)

	// Serialize reports the value to store in the fixed header's "next
	// header" field; its second result is not needed here.
	firstNextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdrBuf)
	b.SetNextHeader(firstNextHeader)
}
// IsValid performs basic validation on the packet: b must be able to hold the
// fixed header, the advertised payload length must fit within a packet of
// pktSize bytes, and the version field must be IPv6.
func (b IPv6) IsValid(pktSize int) bool {
	if len(b) < IPv6MinimumSize {
		return false
	}

	if payloadLen := int(b.PayloadLength()); payloadLen > pktSize-IPv6MinimumSize {
		return false
	}

	return IPVersion(b) == IPv6Version
}
// IsV4MappedAddress reports whether addr is an IPv4 mapped address, i.e. an
// IPv6-sized address whose prefix is 0:0:0:0:0:ffff::/96.
func IsV4MappedAddress(addr tcpip.Address) bool {
	return addr.BitLen() == IPv6AddressSizeBits && IPv4MappedIPv6Subnet.Contains(addr)
}
// IsV6MulticastAddress reports whether addr is an IPv6 multicast address
// (an IPv6-sized address whose first byte is FF).
func IsV6MulticastAddress(addr tcpip.Address) bool {
	return addr.BitLen() == IPv6AddressSizeBits && addr.As16()[0] == 0xff
}
// IsV6UnicastAddress reports whether addr is a valid IPv6 unicast (and
// specified) address. That is, it returns true if addr is IPv6-sized, is not
// the unspecified address and is not a multicast address.
func IsV6UnicastAddress(addr tcpip.Address) bool {
	switch {
	case addr.BitLen() != IPv6AddressSizeBits:
		// Not an IPv6 address at all.
		return false
	case addr == IPv6Any:
		// The unspecified address is not unicast.
		return false
	default:
		// Anything that does not start with FF is not multicast.
		return addr.As16()[0] != 0xff
	}
}
// solicitedNodeMulticastPrefix is the leading 13 bytes (104 bits) shared by
// solicited-node multicast addresses, i.e. ff02:0:0:0:0:1:ff00::/104. The
// remaining 3 bytes are taken from the address being solicited (see
// SolicitedNodeAddr).
var solicitedNodeMulticastPrefix = [13]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff}
// SolicitedNodeAddr computes the solicited-node multicast address for addr:
// the fixed solicited-node prefix followed by the last 3 bytes of addr. This
// is used for NDP and is described in RFC 4291. The argument must be a
// full-length IPv6 address.
func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
	full := addr.As16()

	var solicited [IPv6AddressSize]byte
	prefixLen := copy(solicited[:], solicitedNodeMulticastPrefix[:])
	copy(solicited[prefixLen:], full[len(full)-3:])
	return tcpip.AddrFrom16(solicited)
}
// IsSolicitedNodeAddr determines whether the address is a solicited-node
// multicast address, i.e. whether everything but its last 3 bytes equals the
// solicited-node multicast prefix.
//
// Like the other IPv6 address predicates in this file, it reports false for
// an address that is not IPv6-sized instead of handing such an address to
// As16.
func IsSolicitedNodeAddr(addr tcpip.Address) bool {
	if addr.BitLen() != IPv6AddressSizeBits {
		return false
	}
	addrBytes := addr.As16()
	return solicitedNodeMulticastPrefix == [13]byte(addrBytes[:len(addrBytes)-3])
}
// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
//
// buf MUST be at least 8 bytes, and linkAddr is indexed up to its sixth byte;
// shorter inputs cause an index-out-of-range panic, possibly after some bytes
// of buf have already been written.
func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) {
	// XOR with 2 inverts the second-lowest bit of the first octet.
	buf[0] = linkAddr[0] ^ 2
	buf[1] = linkAddr[1]
	buf[2] = linkAddr[2]
	// 0xFFFE is inserted between the two 3-byte halves of the link address.
	buf[3] = 0xFF
	buf[4] = 0xFE
	buf[5] = linkAddr[3]
	buf[6] = linkAddr[4]
	buf[7] = linkAddr[5]
}
// EthernetAddressToModifiedEUI64 returns the modified EUI-64 derived from a
// 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte {
	eui64 := [IIDSize]byte{}
	EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, eui64[:])
	return eui64
}
// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
// (MAC) address.
//
// The 48-bit MAC is converted to a modified EUI-64, which becomes the
// interface identifier behind the link-local prefix FE80::. The conversion is
// very nearly:
//
//	aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
//
// Note the capital A: the conversion aa->Aa involves a bit flip.
func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
	var lladdr [IPv6AddressSize]byte
	lladdr[0], lladdr[1] = 0xFE, 0x80

	// The IID occupies the tail of the address, from IIDOffsetInIPv6Address on.
	iid := lladdr[IIDOffsetInIPv6Address:]
	EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, iid)
	return tcpip.AddrFrom16(lladdr)
}
// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6
// link-local unicast address, as defined by RFC 4291 section 2.5.6.
//
// Concretely, the address must be IPv6-sized, start with the byte 0xfe, and
// have 10 as the top two bits of its second byte (fe80::/10).
func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool {
	if addr.BitLen() != IPv6AddressSizeBits {
		return false
	}

	octets := addr.As16()
	if octets[0] != 0xfe {
		return false
	}
	return octets[1]&0xc0 == 0x80
}
// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback
// address, as defined by RFC 4291 section 2.5.3.
//
// The check is a direct comparison against IPv6Loopback.
func IsV6LoopbackAddress(addr tcpip.Address) bool {
	return addr == IPv6Loopback
}
// IsV6LinkLocalMulticastAddress returns true iff the provided address is an
// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7.
//
// The multicast check runs first, so V6MulticastScope is only consulted for
// addresses already known to be IPv6 multicast.
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
	return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope
}
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
//
// The opaque IID is the leading IIDSize bytes of the SHA-256 hash of the
// concatenation of the prefix, the NIC's name, the DAD counter (DAD retry
// counter) and the secret key. The secret key SHOULD be at least
// OpaqueIIDSecretKeyMinBytes bytes and MUST be generated to a pseudo-random
// number. See RFC 4086 for randomness requirements for security.
//
// If buf has enough capacity for the IID (IIDSize bytes), a new underlying
// array for the buffer will not be allocated.
func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte {
	// As per RFC 7217 section 5, the opaque identifier can be generated as a
	// cryptographic hash of the concatenation of each of the function
	// parameters. Note, the optional Network_ID field is omitted.
	prefixID := prefix.ID()
	hashInputs := [][]byte{
		[]byte(prefixID.AsSlice()[:IIDOffsetInIPv6Address]),
		[]byte(nicName),
		{dadCounter},
		secretKey,
	}

	hasher := sha256.New()
	for _, input := range hashInputs {
		// hasher.Write never returns an error.
		hasher.Write(input)
	}

	var digestBuf [sha256.Size]byte
	digest := hasher.Sum(digestBuf[:0])
	return append(buf, digest[:IIDSize]...)
}
// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with
// an opaque IID.
func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address {
	var lladdr [IPv6AddressSize]byte
	lladdr[0], lladdr[1] = 0xFE, 0x80

	// Only the prefix half is handed over as the slice's length; the IID is
	// appended after it.
	prefixBytes := lladdr[:IIDOffsetInIPv6Address]
	full := AppendOpaqueInterfaceIdentifier(prefixBytes, IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)
	return tcpip.AddrFrom16([16]byte(full))
}
// IPv6AddressScope is the scope of an IPv6 address, as reported by
// ScopeForIPv6Address.
type IPv6AddressScope int

const (
	// LinkLocalScope indicates a link-local address.
	LinkLocalScope IPv6AddressScope = iota
	// GlobalScope indicates a global address. ScopeForIPv6Address reports
	// this scope for every address it does not classify as link-local.
	GlobalScope
)
// ScopeForIPv6Address returns the scope for an IPv6 address.
//
// Link-local multicast and link-local unicast addresses are LinkLocalScope;
// every other IPv6 address is GlobalScope. An address that is not IPv6-sized
// yields GlobalScope together with a *tcpip.ErrBadAddress.
func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) {
	if addr.BitLen() != IPv6AddressSizeBits {
		return GlobalScope, &tcpip.ErrBadAddress{}
	}

	if IsV6LinkLocalMulticastAddress(addr) || IsV6LinkLocalUnicastAddress(addr) {
		return LinkLocalScope, nil
	}
	return GlobalScope, nil
}
// InitialTempIID generates the initial temporary IID history value to generate
// temporary SLAAC addresses with.
//
// The value is the last IIDSize bytes of the SHA-256 hash of seed followed by
// the big-endian 32-bit NIC ID, and is written to the start of
// initialTempIIDHistory.
//
// Panics if initialTempIIDHistory is not at least IIDSize bytes.
func InitialTempIID(initialTempIIDHistory []byte, seed []byte, nicID tcpip.NICID) {
	var nicIDBytes [4]byte
	binary.BigEndian.PutUint32(nicIDBytes[:], uint32(nicID))

	hasher := sha256.New()
	// hasher.Write never returns an error.
	hasher.Write(seed)
	hasher.Write(nicIDBytes[:])

	var digestBuf [sha256.Size]byte
	digest := hasher.Sum(digestBuf[:0])

	copied := copy(initialTempIIDHistory, digest[sha256.Size-IIDSize:])
	if copied != IIDSize {
		panic(fmt.Sprintf("copied %d bytes, expected %d bytes", copied, IIDSize))
	}
}
// GenerateTempIPv6SLAACAddr derives a temporary SLAAC IPv6 address from the
// stable/permanent SLAAC address stableAddr.
//
// The address is built from SHA-256(tempIIDHistory || stable IID): the leading
// bytes of the digest replace the IID of stableAddr, and the trailing bytes
// overwrite tempIIDHistory so the next call yields a different temporary IID.
//
// Panics if tempIIDHistory is not at least IIDSize bytes.
func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix {
	tempAddr := stableAddr.As16()
	// iid aliases the IID portion of tempAddr, so writes to it update tempAddr.
	iid := tempAddr[IIDOffsetInIPv6Address:]

	hasher := sha256.New()
	hasher.Write(tempIIDHistory)
	hasher.Write(iid)
	var digest [sha256.Size]byte
	sum := hasher.Sum(digest[:0])

	// The rightmost 64 bits of the digest become the history for the next
	// iteration.
	if copied := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); copied != IIDSize {
		panic(fmt.Sprintf("copied %d bytes, expected %d bytes", copied, IIDSize))
	}

	// The leftmost 64 bits of the digest become the temporary IID.
	if copied := copy(iid, sum); copied != IIDSize {
		panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", copied, IIDSize))
	}

	var result tcpip.AddressWithPrefix
	result.Address = tcpip.AddrFrom16(tempAddr)
	result.PrefixLen = IIDOffsetInIPv6Address * 8
	return result
}
// IPv6MulticastScope is the scope value carried by a multicast IPv6 address,
// as defined by RFC 7346 section 2.
type IPv6MulticastScope uint8

// The IPv6 multicast scope values, as per RFC 7346 section 2:
//
//	+------+--------------------------+-------------------------+
//	| scop | NAME                     | REFERENCE               |
//	+------+--------------------------+-------------------------+
//	|  0   | Reserved                 | [RFC4291], RFC 7346     |
//	|  1   | Interface-Local scope    | [RFC4291], RFC 7346     |
//	|  2   | Link-Local scope         | [RFC4291], RFC 7346     |
//	|  3   | Realm-Local scope        | [RFC4291], RFC 7346     |
//	|  4   | Admin-Local scope        | [RFC4291], RFC 7346     |
//	|  5   | Site-Local scope         | [RFC4291], RFC 7346     |
//	|  6   | Unassigned               |                         |
//	|  7   | Unassigned               |                         |
//	|  8   | Organization-Local scope | [RFC4291], RFC 7346     |
//	|  9   | Unassigned               |                         |
//	|  A   | Unassigned               |                         |
//	|  B   | Unassigned               |                         |
//	|  C   | Unassigned               |                         |
//	|  D   | Unassigned               |                         |
//	|  E   | Global scope             | [RFC4291], RFC 7346     |
//	|  F   | Reserved                 | [RFC4291], RFC 7346     |
//	+------+--------------------------+-------------------------+
const (
	IPv6Reserved0MulticastScope         IPv6MulticastScope = 0x0
	IPv6InterfaceLocalMulticastScope    IPv6MulticastScope = 0x1
	IPv6LinkLocalMulticastScope         IPv6MulticastScope = 0x2
	IPv6RealmLocalMulticastScope        IPv6MulticastScope = 0x3
	IPv6AdminLocalMulticastScope        IPv6MulticastScope = 0x4
	IPv6SiteLocalMulticastScope         IPv6MulticastScope = 0x5
	IPv6OrganizationLocalMulticastScope IPv6MulticastScope = 0x8
	IPv6GlobalMulticastScope            IPv6MulticastScope = 0xE
	IPv6ReservedFMulticastScope         IPv6MulticastScope = 0xF
)
// V6MulticastScope extracts the scope field from the multicast address addr.
func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope {
	raw := addr.As16()
	// Mask the scope byte down to the scope bits.
	scopeBits := raw[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask
	return IPv6MulticastScope(scopeBits)
}

View File

@@ -0,0 +1,955 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
)
// IPv6ExtensionHeaderIdentifier is the value of a Next Header field that
// identifies an IPv6 extension header.
type IPv6ExtensionHeaderIdentifier uint8

const (
	// IPv6HopByHopOptionsExtHdrIdentifier identifies a Hop by Hop Options
	// extension header, as per RFC 8200 section 4.3.
	IPv6HopByHopOptionsExtHdrIdentifier = IPv6ExtensionHeaderIdentifier(0)

	// IPv6RoutingExtHdrIdentifier identifies a Routing extension header, as per
	// RFC 8200 section 4.4.
	IPv6RoutingExtHdrIdentifier = IPv6ExtensionHeaderIdentifier(43)

	// IPv6FragmentExtHdrIdentifier identifies a Fragment extension header, as
	// per RFC 8200 section 4.5.
	IPv6FragmentExtHdrIdentifier = IPv6ExtensionHeaderIdentifier(44)

	// IPv6DestinationOptionsExtHdrIdentifier identifies a Destination Options
	// extension header, as per RFC 8200 section 4.6.
	IPv6DestinationOptionsExtHdrIdentifier = IPv6ExtensionHeaderIdentifier(60)

	// IPv6NoNextHeaderIdentifier signifies the end of an IPv6 payload, as per
	// RFC 8200 section 4.7.
	IPv6NoNextHeaderIdentifier = IPv6ExtensionHeaderIdentifier(59)

	// IPv6UnknownExtHdrIdentifier is reserved by IANA.
	// https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header
	// "254 Use for experimentation and testing [RFC3692][RFC4727]"
	IPv6UnknownExtHdrIdentifier = IPv6ExtensionHeaderIdentifier(254)
)
const (
	// ipv6UnknownExtHdrOptionActionMask selects the bits of an option
	// identifier that encode the action to take when a node encounters an
	// unrecognized option.
	ipv6UnknownExtHdrOptionActionMask = 0xC0

	// ipv6UnknownExtHdrOptionActionShift is the number of least significant
	// bits to discard from a masked option identifier to obtain the action
	// value.
	ipv6UnknownExtHdrOptionActionShift = 6

	// ipv6RoutingExtHdrSegmentsLeftIdx is the index of the Segments Left field
	// within an IPv6RoutingExtHdr.
	ipv6RoutingExtHdrSegmentsLeftIdx = 1

	// IPv6FragmentExtHdrLength is the length of an IPv6 Fragment extension
	// header, in bytes.
	IPv6FragmentExtHdrLength = 8

	// ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the
	// Fragment Offset field within an IPv6FragmentExtHdr.
	ipv6FragmentExtHdrFragmentOffsetOffset = 0

	// ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment
	// Offset field within the 16-bit word that holds it.
	ipv6FragmentExtHdrFragmentOffsetShift = 3

	// ipv6FragmentExtHdrFlagsIdx is the index of the flags field within an
	// IPv6FragmentExtHdr.
	ipv6FragmentExtHdrFlagsIdx = 1

	// ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the
	// flags field of an IPv6FragmentExtHdr.
	ipv6FragmentExtHdrMFlagMask = 0x1

	// ipv6FragmentExtHdrIdentificationOffset is the offset to the
	// Identification field within an IPv6FragmentExtHdr.
	ipv6FragmentExtHdrIdentificationOffset = 2

	// ipv6ExtHdrLenBytesPerUnit is the number of bytes represented by one unit
	// of an extension header's Length field. For example, a Length field of 2
	// means 16 bytes follow the first 8 bytes of the header (see
	// ipv6ExtHdrLenBytesExcluded for why the first 8 bytes are not counted).
	ipv6ExtHdrLenBytesPerUnit = 8

	// ipv6ExtHdrLenBytesExcluded is the number of bytes following the Length
	// field that the Length field does not count.
	//
	// The Length field excludes the first 8 bytes of the header, and the Next
	// Header and Length fields occupy the first 2 of those, so at least 6 bytes
	// are expected after the Length field.
	//
	// This ensures that every extension header is at least 8 bytes.
	ipv6ExtHdrLenBytesExcluded = 6

	// IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the number of bytes
	// represented by one unit of a Fragment extension header's Fragment Offset
	// field. For example, a Fragment Offset of 2 indicates that the fragment's
	// payload starts at the 16th byte of the reassembled packet.
	IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
)
// padIPv6OptionsLength rounds length up to the next multiple of
// ipv6ExtHdrLenBytesPerUnit, giving the total length of IPv6 options once the
// 8-octet alignment required by RFC 8200 Section 4.2 is applied.
func padIPv6OptionsLength(length int) int {
	const unitMask = ipv6ExtHdrLenBytesPerUnit - 1
	return (length + unitMask) &^ unitMask
}
// padIPv6Option overwrites the whole of b with a single padding option chosen
// by len(b): nothing for an empty slice, Pad1 for one byte, and PadN (with a
// zeroed payload) for anything longer.
func padIPv6Option(b []byte) {
	if len(b) == 0 {
		// No padding needed.
		return
	}
	if len(b) == 1 {
		// Pad1 is a lone type byte with no Length or Data fields.
		b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier)
		return
	}

	// PadN: type, length, then a payload of zero bytes filling the rest of b.
	payload := b[ipv6ExtHdrOptionPayloadOffset:]
	clear(payload)
	b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier)
	b[ipv6ExtHdrOptionLengthOffset] = uint8(len(payload))
}
// ipv6OptionsAlignmentPadding returns how many padding bytes must precede an
// option placed at headerOffset so that it satisfies the alignment requirement
// [align]n + alignOffset.
//
// The bit arithmetic rounds up to a multiple of align, so align is expected to
// be a power of two.
func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
	// Shift the offset so the requirement becomes "multiple of align".
	relative := headerOffset - alignOffset
	roundedUp := (relative + align - 1) &^ (align - 1)
	return roundedUp - relative
}
// IPv6PayloadHeader is implemented by the various headers that can be found
// in an IPv6 payload.
//
// These headers include IPv6 extension headers or upper layer data.
type IPv6PayloadHeader interface {
// isIPv6PayloadHeader is unexported so that only types declared in this
// package can satisfy the interface.
isIPv6PayloadHeader()
// Release frees all resources held by the header.
Release()
}
// IPv6RawPayloadHeader is the remainder of an IPv6 payload after an iterator
// encounters a Next Header field it does not recognize as an IPv6 extension
// header. The caller is responsible for releasing the underlying buffer once
// it is no longer needed.
type IPv6RawPayloadHeader struct {
	// Identifier is the Next Header value that describes Buf.
	Identifier IPv6ExtensionHeaderIdentifier

	// Buf holds the remaining payload bytes.
	Buf buffer.Buffer
}

// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {}

// Release implements IPv6PayloadHeader.Release.
func (h IPv6RawPayloadHeader) Release() {
	h.Buf.Release()
}
// ipv6OptionsExtHdr is an IPv6 extension header whose body is a sequence of
// options.
type ipv6OptionsExtHdr struct {
	// buf holds the raw options data; it may be nil.
	buf *buffer.View
}

// Release implements IPv6PayloadHeader.Release.
//
// It is a no-op when the header holds no view.
func (h ipv6OptionsExtHdr) Release() {
	if h.buf == nil {
		return
	}
	h.buf.Release()
}

// Iter returns an iterator over the IPv6 extension header options held in the
// header's buffer.
func (h ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
	return IPv6OptionsExtHdrOptionsIterator{reader: h.buf}
}
// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header
// options.
//
// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last
// used, no changes to the underlying buffer may happen. Doing so may cause
// undefined and unexpected behaviour. It is fine to obtain an
// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then
// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator
// obtained before modification is no longer used.
type IPv6OptionsExtHdrOptionsIterator struct {
// reader is the view the options are read from. Next consumes bytes from it
// as options are parsed.
reader *buffer.View
// optionOffset is the number of bytes from the first byte of the
// options field to the beginning of the current option.
optionOffset uint32
// nextOptionOffset is the offset of the next option, measured the same way
// as optionOffset.
nextOptionOffset uint32
}
// OptionOffset returns the offset, in bytes from the first byte of the options
// field, of the option most recently examined by Next. Callers can use it to
// locate the option that caused a parsing error.
func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
return i.optionOffset
}
// IPv6OptionUnknownAction is the action a processing IPv6 node must take when
// it does not recognize an option, as outlined in RFC 8200 section 4.2.
type IPv6OptionUnknownAction int

const (
	// IPv6OptionUnknownActionSkip indicates that the unrecognized option must
	// be skipped and the node should continue processing the header.
	IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = iota

	// IPv6OptionUnknownActionDiscard indicates that the packet must be silently
	// discarded.
	IPv6OptionUnknownActionDiscard

	// IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be
	// discarded and the node must send an ICMP Parameter Problem, Code 2,
	// message to the packet's source, regardless of whether or not the packet's
	// Destination was a multicast address.
	IPv6OptionUnknownActionDiscardSendICMP

	// IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the
	// packet must be discarded and the node must send an ICMP Parameter
	// Problem, Code 2, message to the packet's source only if the packet's
	// Destination was not a multicast address.
	IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest
)
// IPv6ExtHdrOption is implemented by the various IPv6 extension header options
// returned by IPv6OptionsExtHdrOptionsIterator.Next.
type IPv6ExtHdrOption interface {
// UnknownAction returns the action to take in response to an unrecognized
// option.
UnknownAction() IPv6OptionUnknownAction
// isIPv6ExtHdrOption is used to "lock" this interface so it is not
// implemented by other packages.
isIPv6ExtHdrOption()
}
// IPv6ExtHdrOptionIdentifier is the Option Type value that identifies an IPv6
// extension header option.
type IPv6ExtHdrOptionIdentifier uint8

const (
	// ipv6Pad1ExtHdrOptionIdentifier identifies the padding option that
	// provides 1 byte of padding, as outlined in RFC 8200 section 4.2.
	ipv6Pad1ExtHdrOptionIdentifier = IPv6ExtHdrOptionIdentifier(0)

	// ipv6PadNExtHdrOptionIdentifier identifies the padding option that
	// provides a variable number of padding bytes, as outlined in RFC 8200
	// section 4.2.
	ipv6PadNExtHdrOptionIdentifier = IPv6ExtHdrOptionIdentifier(1)

	// ipv6RouterAlertHopByHopOptionIdentifier identifies the Router Alert Hop
	// by Hop option as defined in RFC 2711 section 2.1.
	ipv6RouterAlertHopByHopOptionIdentifier = IPv6ExtHdrOptionIdentifier(5)

	// ipv6ExtHdrOptionTypeOffset is the offset of the option type within an
	// extension header option as defined in RFC 8200 section 4.2.
	ipv6ExtHdrOptionTypeOffset = 0

	// ipv6ExtHdrOptionLengthOffset is the offset of the option length within an
	// extension header option as defined in RFC 8200 section 4.2.
	ipv6ExtHdrOptionLengthOffset = 1

	// ipv6ExtHdrOptionPayloadOffset is the offset of the option payload within
	// an extension header option as defined in RFC 8200 section 4.2.
	ipv6ExtHdrOptionPayloadOffset = 2
)
// ipv6UnknownActionFromIdentifier extracts, from the high bits of an extension
// header option identifier, the action to take when the identifier is unknown.
func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
	actionBits := id & ipv6UnknownExtHdrOptionActionMask
	return IPv6OptionUnknownAction(actionBits >> ipv6UnknownExtHdrOptionActionShift)
}
// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
// is malformed. Errors returned by the options iterator wrap it, so callers
// should test for it with errors.Is.
var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
// header option that is unknown by the parsing utilities.
type IPv6UnknownExtHdrOption struct {
// Identifier is the option's type value as read from the options data.
Identifier IPv6ExtHdrOptionIdentifier
// Data holds the option's payload bytes.
Data *buffer.View
}
// UnknownAction implements IPv6ExtHdrOption.UnknownAction.
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
return ipv6UnknownActionFromIdentifier(o.Identifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
// Next returns the next option in the options data.
//
// Padding options (Pad1 and PadN) are skipped and never returned. If the next
// item is not a known extension header option, IPv6UnknownExtHdrOption will be
// returned with the option identifier and data.
//
// The return is of the format (option, done, error). done will be true when
// Next is unable to return anything because the iterator has reached the end of
// the options data, or an error occurred.
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
	for {
		i.optionOffset = i.nextOptionOffset

		temp, err := i.reader.ReadByte()
		if err != nil {
			// If we can't read the first byte of a new option, then we know the
			// options buffer has been exhausted and we are done iterating.
			return nil, true, nil
		}
		id := IPv6ExtHdrOptionIdentifier(temp)

		// If the option identifier indicates the option is a Pad1 option, then we
		// know the option does not have Length and Data fields. End processing of
		// the Pad1 option and continue processing the buffer as a new option.
		if id == ipv6Pad1ExtHdrOptionIdentifier {
			i.nextOptionOffset = i.optionOffset + 1
			continue
		}

		length, err := i.reader.ReadByte()
		if err != nil {
			if err != io.EOF {
				// ReadByte should only ever return nil or io.EOF.
				panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err))
			}

			// We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
			// we start parsing an option; we expect the reader to contain enough
			// bytes for the whole option.
			return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
		}

		// Do we have enough bytes in the reader for the next option?
		if n := i.reader.Size(); n < int(length) {
			// Consume the remaining buffer.
			i.reader.TrimFront(i.reader.Size())

			// We return the same error as if we failed to read a non-padding option
			// so consumers of this iterator don't need to differentiate between
			// padding and non-padding options.
			return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
		}

		i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */

		switch id {
		case ipv6PadNExtHdrOptionIdentifier:
			// Special-case the variable length padding option to avoid a copy.
			i.reader.TrimFront(int(length))
			continue
		case ipv6RouterAlertHopByHopOptionIdentifier:
			var routerAlertValue [ipv6RouterAlertPayloadLength]byte
			if n, err := io.ReadFull(i.reader, routerAlertValue[:]); err != nil {
				switch err {
				case io.EOF, io.ErrUnexpectedEOF:
					return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
				default:
					return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
				}
			} else if n != int(length) {
				// The read filled routerAlertValue, so a mismatch means the option
				// advertised a length other than the router alert payload length.
				return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
			}
			return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
		default:
			data := buffer.NewView(int(length))
			if n, err := io.CopyN(data, i.reader, int64(length)); err != nil {
				if err == io.EOF {
					err = io.ErrUnexpectedEOF
				}
				// The view is not handed to the caller on failure, so release it
				// here rather than leaving it unreleased.
				data.Release()
				return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
			}
			// Ownership of data passes to the returned option.
			return &IPv6UnknownExtHdrOption{Identifier: id, Data: data}, false, nil
		}
	}
}
// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options
// extension header.
//
// Release and Iter are promoted from the embedded ipv6OptionsExtHdr.
type IPv6HopByHopOptionsExtHdr struct {
ipv6OptionsExtHdr
}
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {}
// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options
// extension header.
//
// Release and Iter are promoted from the embedded ipv6OptionsExtHdr.
type IPv6DestinationOptionsExtHdr struct {
ipv6OptionsExtHdr
}
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
// data as outlined in RFC 8200 section 4.4.
type IPv6RoutingExtHdr struct {
	// Buf holds the header specific data, i.e. the bytes that follow the Next
	// Header and Hdr Ext Len fields.
	Buf *buffer.View
}

// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {}

// Release implements IPv6PayloadHeader.Release.
//
// It is a no-op when the header holds no view, matching the behaviour of the
// options extension headers.
func (b IPv6RoutingExtHdr) Release() {
	if b.Buf != nil {
		b.Buf.Release()
	}
}

// SegmentsLeft returns the Segments Left field.
func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 {
	return b.Buf.AsSlice()[ipv6RoutingExtHdrSegmentsLeftIdx]
}
// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
// data as outlined in RFC 8200 section 4.5.
//
// Note, the buffer does not include the Next Header and Reserved fields.
type IPv6FragmentExtHdr [6]byte

// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {}

// Release implements IPv6PayloadHeader.Release.
//
// The header is a plain array and owns no resources, so this is a no-op.
func (IPv6FragmentExtHdr) Release() {}

// FragmentOffset returns the Fragment Offset field.
//
// This value indicates where the buffer following the Fragment extension header
// starts in the target (reassembled) packet.
func (b IPv6FragmentExtHdr) FragmentOffset() uint16 {
	// The offset shares a 16-bit word with the flag bits; shift those out.
	word := binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:])
	return word >> ipv6FragmentExtHdrFragmentOffsetShift
}

// More returns the More (M) flag.
//
// This indicates whether any fragments are expected to succeed b.
func (b IPv6FragmentExtHdr) More() bool {
	mFlag := b[ipv6FragmentExtHdrFlagsIdx] & ipv6FragmentExtHdrMFlagMask
	return mFlag != 0
}

// ID returns the Identification field.
//
// This value is used to uniquely identify the packet, between a
// source and destination.
func (b IPv6FragmentExtHdr) ID() uint32 {
	idField := b[ipv6FragmentExtHdrIdentificationOffset:]
	return binary.BigEndian.Uint32(idField)
}

// IsAtomic returns whether the fragment header indicates an atomic fragment. An
// atomic fragment is a fragment that contains all the data required to
// reassemble a full packet.
func (b IPv6FragmentExtHdr) IsAtomic() bool {
	if b.More() {
		return false
	}
	return b.FragmentOffset() == 0
}
// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
//
// The IPv6 payload may contain IPv6 extension headers before any upper layer
// data.
//
// Note, between when an IPv6PayloadIterator is obtained and last used, no
// changes to the payload may happen. Doing so may cause undefined and
// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate
// over the first few headers then modify the backing payload so long as the
// IPv6PayloadIterator obtained before modification is no longer used.
type IPv6PayloadIterator struct {
// The identifier of the next header to parse.
nextHdrIdentifier IPv6ExtensionHeaderIdentifier
// payload holds the bytes that have not been parsed yet.
payload buffer.Buffer
// Indicates to the iterator that it should return the remaining payload as a
// raw payload on the next call to Next.
forceRaw bool
// headerOffset is the offset of the beginning of the current extension
// header starting from the beginning of the fixed header.
headerOffset uint32
// parseOffset is the byte offset into the current extension header of the
// field we are currently examining. It can be added to the header offset
// if the absolute offset within the packet is required.
parseOffset uint32
// nextOffset is the offset of the next header.
nextOffset uint32
}
// HeaderOffset returns the offset to the start of the extension
// header most recently processed.
func (i IPv6PayloadIterator) HeaderOffset() uint32 {
return i.headerOffset
}
// ParseOffset returns the number of bytes successfully parsed, that is, the
// offset of the current extension header plus the offset of the field being
// examined within it.
func (i IPv6PayloadIterator) ParseOffset() uint32 {
return i.headerOffset + i.parseOffset
}
// MakeIPv6PayloadIterator creates an iterator over an IPv6 payload that may
// start with extension headers; if the payload cannot be parsed it is exposed
// as a raw payload instead. nextHdrIdentifier identifies the first header in
// payload. The iterator takes ownership of the payload.
func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.Buffer) IPv6PayloadIterator {
	var it IPv6PayloadIterator
	it.nextHdrIdentifier = nextHdrIdentifier
	it.payload = payload
	// The first header of the payload starts right after the fixed header.
	it.nextOffset = IPv6FixedHeaderSize
	return it
}
// Release frees the resources owned by the iterator by releasing the payload
// buffer it still holds.
func (i *IPv6PayloadIterator) Release() {
i.payload.Release()
}
// AsRawHeader returns the remaining payload of i as a raw header and
// optionally consumes the iterator.
//
// If consume is true, the iterator's payload is moved into the returned header
// and calls to Next after calling AsRawHeader on i will indicate that the
// iterator is done. If consume is false, the returned header holds a clone of
// the payload and i is left untouched. Either way the returned header takes
// ownership of its payload.
func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
	raw := IPv6RawPayloadHeader{Identifier: i.nextHdrIdentifier}

	if !consume {
		raw.Buf = i.payload.Clone()
		return raw
	}

	// Since we consume the iterator, hand over the payload as is.
	raw.Buf = i.payload

	// Mark i as done, but keep track of where we were for error reporting.
	*i = IPv6PayloadIterator{
		nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
		headerOffset:      i.headerOffset,
		nextOffset:        i.nextOffset,
	}
	return raw
}
// Next returns the next item in the payload.
//
// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader
// will be returned with the remaining bytes and next header identifier.
//
// The return is of the format (header, done, error). done will be true when
// Next is unable to return anything because the iterator has reached the end of
// the payload, or an error occurred.
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
	i.headerOffset = i.nextOffset
	i.parseOffset = 0

	// A previous non-initial fragment header forces the rest of the payload to
	// be returned raw, as the data following it may not be complete.
	if i.forceRaw {
		return i.AsRawHeader(true /* consume */), false, nil
	}

	switch current := i.nextHdrIdentifier; current {
	case IPv6NoNextHeaderIdentifier:
		// This indicates the end of the IPv6 payload.
		return nil, true, nil

	case IPv6HopByHopOptionsExtHdrIdentifier, IPv6DestinationOptionsExtHdrIdentifier:
		// Both headers carry options and differ only in the wrapper type.
		next, data, err := i.nextHeaderData(false /* fragmentHdr */, nil)
		if err != nil {
			return nil, true, err
		}
		i.nextHdrIdentifier = next
		opts := ipv6OptionsExtHdr{data}
		if current == IPv6HopByHopOptionsExtHdrIdentifier {
			return IPv6HopByHopOptionsExtHdr{opts}, false, nil
		}
		return IPv6DestinationOptionsExtHdr{opts}, false, nil

	case IPv6RoutingExtHdrIdentifier:
		next, data, err := i.nextHeaderData(false /* fragmentHdr */, nil)
		if err != nil {
			return nil, true, err
		}
		i.nextHdrIdentifier = next
		return IPv6RoutingExtHdr{data}, false, nil

	case IPv6FragmentExtHdrIdentifier:
		// The fragment header specific data is read straight into the header
		// value, so the returned view is ignored.
		var fragHdr IPv6FragmentExtHdr
		next, _, err := i.nextHeaderData(true /* fragmentHdr */, fragHdr[:])
		if err != nil {
			return nil, true, err
		}

		// If the packet is not the first fragment, do not attempt to parse
		// anything after the fragment extension header as the payload following
		// it should not contain any headers; the first fragment must hold all
		// the headers up to and including any upper layer headers, as per
		// RFC 8200 section 4.5.
		if fragHdr.FragmentOffset() != 0 {
			i.forceRaw = true
		}
		i.nextHdrIdentifier = next
		return fragHdr, false, nil

	default:
		// Not a known extension header: return the raw payload.
		return i.AsRawHeader(true /* consume */), false, nil
	}
}
// NextHeaderIdentifier returns the identifier of the header that the next call
// to Next will parse.
func (i *IPv6PayloadIterator) NextHeaderIdentifier() IPv6ExtensionHeaderIdentifier {
return i.nextHdrIdentifier
}
// nextHeaderData returns the extension header's Next Header field and raw data.
//
// fragmentHdr indicates that the extension header being parsed is the Fragment
// extension header so the Length field should be ignored as it is Reserved
// for the Fragment extension header.
//
// For the Fragment extension header, the header specific data is read into
// bytes and a nil view is returned; nextHeaderData panics if bytes is too short
// to hold that data. For all other headers bytes is unused and the header
// specific data is returned in a newly allocated view owned by the caller.
//
// As a side effect, i.parseOffset is advanced past the two leading fields and
// i.nextOffset is advanced to the start of the following header.
func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, *buffer.View, error) {
	rdr := i.payload.AsBufferReader()

	// The first byte of every extension header is its Next Header field.
	nextHdrIdentifier, err := rdr.ReadByte()
	if err != nil {
		return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
	}
	i.parseOffset++

	// The second byte is the Hdr Ext Len field, except for the Fragment
	// extension header where it is Reserved.
	length, err := rdr.ReadByte()
	if err != nil {
		if fragmentHdr {
			return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
		}
		return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
	}
	if fragmentHdr {
		// The Fragment extension header has a fixed size; ignore the Reserved
		// byte's value.
		length = 0
	}

	// Make parseOffset point to the first byte of the Extension Header
	// specific data.
	i.parseOffset++

	// length is in 8 byte chunks but doesn't include the first one.
	// See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
	// in section 4.8 for new extension headers at the top of page 24.
	//   [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
	//   units, not including the first 8 octets.
	//
	// Widen length before the arithmetic: computed in uint8, (length+1)*8 wraps
	// around for any length >= 31.
	i.nextOffset += (uint32(length) + 1) * ipv6ExtHdrLenBytesPerUnit

	bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
	if fragmentHdr {
		if n := len(bytes); n < bytesLen {
			panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
		}
		// Read exactly the header specific data, even if the caller supplied a
		// longer slice, so no bytes belonging to what follows are consumed.
		if n, err := io.ReadFull(&rdr, bytes[:bytesLen]); err != nil {
			return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
		}
		return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), nil, nil
	}

	v := buffer.NewView(bytesLen)
	if n, err := io.CopyN(v, &rdr, int64(bytesLen)); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		// The view is not handed to the caller on failure; release it here.
		v.Release()
		return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
	}
	return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), v, nil
}
// IPv6SerializableExtHdr provides serialization for IPv6 extension
// headers.
//
// All of its methods are unexported, so it can only be implemented by types
// declared in this package.
type IPv6SerializableExtHdr interface {
// identifier returns the assigned IPv6 header identifier for this extension
// header.
identifier() IPv6ExtensionHeaderIdentifier
// length returns the total serialized length in bytes of this extension
// header, including the common next header and length fields.
length() int
// serializeInto serializes the receiver into the provided byte
// buffer and with the provided nextHeader value.
//
// Note, the caller MUST provide a byte buffer with size of at least
// length. Implementers of this function may assume that the byte buffer
// is of sufficient size. serializeInto MAY panic if the provided byte
// buffer is not of sufficient size.
//
// serializeInto returns the number of bytes that was used to serialize the
// receiver. Implementers must only use the number of bytes required to
// serialize the receiver. Callers MAY provide a larger buffer than required
// to serialize into.
serializeInto(nextHeader uint8, b []byte) int
}
// Compile-time check that IPv6SerializableHopByHopExtHdr implements
// IPv6SerializableExtHdr.
var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop
// options extension header. It is the list of options to serialize, in order.
type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption
const (
// ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field
// in a hop by hop extension header as defined in RFC 8200 section 4.3.
ipv6HopByHopExtHdrNextHeaderOffset = 0
// ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop
// by hop extension header as defined in RFC 8200 section 4.3.
ipv6HopByHopExtHdrLengthOffset = 1
// ipv6HopByHopExtHdrOptionsOffset is the offset of the options in a hop by
// hop extension header as defined in RFC 8200 section 4.3.
ipv6HopByHopExtHdrOptionsOffset = 2
// ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet
// words in a hop by hop extension header's length field, as stated in RFC
// 8200 section 4.3:
// Length of the Hop-by-Hop Options header in 8-octet units,
// not including the first 8 octets.
ipv6HopByHopExtHdrUnaccountedLenWords = 1
)
// identifier implements IPv6SerializableExtHdr.
//
// It always returns IPv6HopByHopOptionsExtHdrIdentifier, independent of the
// options held by the receiver.
func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
return IPv6HopByHopOptionsExtHdrIdentifier
}
// length implements IPv6SerializableExtHdr.
//
// It returns the number of bytes serializeInto writes for h: the next header
// and length fields, every option (type, length and payload) together with
// the padding needed to satisfy each option's alignment requirement, and the
// trailing padding added by padIPv6OptionsLength.
func (h IPv6SerializableHopByHopExtHdr) length() int {
	// Option alignment is defined relative to the start of the extension
	// header, so the running offset must start after the next header and
	// length fields, exactly as serializeInto does. Starting from zero would
	// under-count padding for any option whose alignment does not divide
	// ipv6HopByHopExtHdrOptionsOffset.
	total := ipv6HopByHopExtHdrOptionsOffset
	for _, opt := range h {
		align, alignOffset := opt.alignment()
		total += ipv6OptionsAlignmentPadding(total, align, alignOffset)
		total += ipv6ExtHdrOptionPayloadOffset + int(opt.length())
	}
	// Add the trailing padding required for the whole header.
	return padIPv6OptionsLength(total)
}
// serializeInto implements IPv6SerializableExtHdr.
//
// It writes nextHeader and the header length into the first two bytes of b
// and serializes each option after them, inserting padding before an option
// when its alignment requires it and after the last option when
// padIPv6OptionsLength requires it. It returns the total number of bytes
// written.
//
// It panics if the serialized header does not fit in the 8-bit length field,
// and may panic if b is too small.
func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int {
	optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:]
	// totalLength is the offset from the start of the extension header; option
	// alignment is computed against it.
	totalLength := ipv6HopByHopExtHdrOptionsOffset
	for _, opt := range h {
		// Calculate alignment requirements and pad buffer if necessary.
		align, alignOffset := opt.alignment()
		if padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset); padLen != 0 {
			padIPv6Option(optBuffer[:padLen])
			totalLength += padLen
			optBuffer = optBuffer[padLen:]
		}
		payloadLen := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:])
		optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier())
		optBuffer[ipv6ExtHdrOptionLengthOffset] = payloadLen
		// Widen to int before adding the type and length fields: a payload of
		// 254 or 255 bytes would otherwise wrap around in uint8 arithmetic and
		// corrupt both the running length and the buffer cursor.
		optLen := ipv6ExtHdrOptionPayloadOffset + int(payloadLen)
		totalLength += optLen
		optBuffer = optBuffer[optLen:]
	}
	if padded := padIPv6OptionsLength(totalLength); padded != totalLength {
		padIPv6Option(optBuffer[:padded-totalLength])
		totalLength = padded
	}
	// The length field excludes the first 8 octets of the header.
	wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords
	if wordsLen > math.MaxUint8 {
		panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen))
	}
	b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader
	b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen)
	return totalLength
}
// IPv6SerializableHopByHopOption provides serialization for hop by hop options.
type IPv6SerializableHopByHopOption interface {
// identifier returns the option identifier of this Hop by Hop option.
identifier() IPv6ExtHdrOptionIdentifier
// length returns the *payload* size of the option (not considering the type
// and length fields).
length() uint8
// alignment returns the alignment requirements of this option.
//
// Alignment requirements take the form [align]n + offset as specified in
// RFC 8200 section 4.2. The alignment requirement is on the offset between
// the option type byte and the start of the hop by hop header.
//
// align must be a power of 2.
alignment() (align int, offset int)
// serializeInto serializes the receiver's payload into the provided byte
// buffer. The option type and length fields are not written by
// serializeInto.
//
// Note, the caller MUST provide a byte buffer with size of at least
// length(). Implementers of this function may assume that the byte buffer
// is of sufficient size. serializeInto MAY panic if the provided byte
// buffer is not of sufficient size.
//
// serializeInto will return the number of bytes that were used to
// serialize the receiver. Implementers must only use the number of
// bytes required to serialize the receiver. Callers MAY provide a
// larger buffer than required to serialize into.
serializeInto([]byte) uint8
}
// Compile-time check that *IPv6RouterAlertOption satisfies
// IPv6SerializableHopByHopOption.
var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil)

// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in
// RFC 2711 section 2.1.
type IPv6RouterAlertOption struct {
// Value is the router alert value; it is serialized as the option's
// big-endian 16-bit payload.
Value IPv6RouterAlertValue
}
// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option.
type IPv6RouterAlertValue uint16

const (
// IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener
// Discovery message as defined in RFC 2711 section 2.1.
IPv6RouterAlertMLD IPv6RouterAlertValue = 0
// IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as
// defined in RFC 2711 section 2.1.
IPv6RouterAlertRSVP IPv6RouterAlertValue = 1
// IPv6RouterAlertActiveNetworks indicates a datagram containing an Active
// Networks message as defined in RFC 2711 section 2.1.
IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2
// ipv6RouterAlertPayloadLength is the length in bytes of the Router Alert
// payload as defined in RFC 2711.
ipv6RouterAlertPayloadLength = 2
// ipv6RouterAlertAlignmentRequirement is the alignment requirement for the
// Router Alert option defined as 2n+0 in RFC 2711 section 2.1.
ipv6RouterAlertAlignmentRequirement = 2
// ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset
// requirement for the Router Alert option defined as 2n+0 in RFC 2711 section
// 2.1.
ipv6RouterAlertAlignmentOffsetRequirement = 0
)
// UnknownAction implements IPv6ExtHdrOption.
//
// The action is derived from the Router Alert option identifier by
// ipv6UnknownActionFromIdentifier.
func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction {
return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.
//
// The method has no body; it exists only so the type satisfies the interface.
func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {}
// identifier implements IPv6SerializableHopByHopOption.
//
// It always returns ipv6RouterAlertHopByHopOptionIdentifier.
func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier {
return ipv6RouterAlertHopByHopOptionIdentifier
}
// length implements IPv6SerializableHopByHopOption.
//
// The Router Alert payload has a fixed size of ipv6RouterAlertPayloadLength
// bytes.
func (*IPv6RouterAlertOption) length() uint8 {
return ipv6RouterAlertPayloadLength
}
// alignment implements IPv6SerializableHopByHopOption.
//
// It returns the (align, offset) pair describing the option's alignment
// requirement.
func (*IPv6RouterAlertOption) alignment() (int, int) {
// From RFC 2711 section 2.1:
//
//	Alignment requirement: 2n+0.
return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement
}
// serializeInto implements IPv6SerializableHopByHopOption.
//
// It stores the router alert value in network byte order at the start of b
// and reports the fixed payload length. It panics if b is shorter than two
// bytes.
func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 {
	value := uint16(o.Value)
	binary.BigEndian.PutUint16(b, value)
	return ipv6RouterAlertPayloadLength
}
// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers.
//
// The headers are serialized in slice order, each one carrying the identifier
// of the header that follows it as its next header value.
type IPv6ExtHdrSerializer []IPv6SerializableExtHdr
// Serialize writes every extension header in s into b, back to back and in
// slice order.
//
// b must be large enough to hold all the headers in s; see
// IPv6ExtHdrSerializer.Length for the required size. Serialize may panic if b
// is too small.
//
// Each header's Next Header value is the identifier of the header that follows
// it; the last header uses transportProtocol. Serialize returns the identifier
// of the first serialized header and the total number of bytes written. When s
// is empty it returns transportProtocol and zero, and b is left untouched.
func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) {
	transport := uint8(transportProtocol)
	if len(s) == 0 {
		return transport, 0
	}

	last := len(s) - 1
	written := 0
	for i := 0; i < last; i++ {
		// Every header but the last chains to its successor.
		n := s[i].serializeInto(uint8(s[i+1].identifier()), b)
		written += n
		b = b[n:]
	}
	// The final header chains to the transport protocol.
	written += s[last].serializeInto(transport, b)
	return uint8(s[0].identifier()), written
}
// Length reports how many bytes Serialize needs for all extension headers in
// s, i.e. the sum of each header's length. It is zero for an empty serializer.
func (s IPv6ExtHdrSerializer) Length() int {
	sum := 0
	for i := range s {
		sum += s[i].length()
	}
	return sum
}

View File

@@ -0,0 +1,158 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
)
// Byte offsets of the fields within an IPv6 fragment extension header.
const (
// nextHdrFrag is the offset of the next header field.
nextHdrFrag = 0
// fragOff is the offset of the 16-bit word holding the fragment offset.
fragOff = 2
// more is the offset of the byte holding the "more fragments" flag.
more = 3
// idV6 is the offset of the 32-bit identification field.
idV6 = 4
)
// Compile-time check that *IPv6SerializableFragmentExtHdr satisfies
// IPv6SerializableExtHdr.
var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil)

// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment
// extension header as defined in RFC 8200 section 4.5.
type IPv6SerializableFragmentExtHdr struct {
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
FragmentOffset uint16
// M is the "more" field of an IPv6 fragment.
M bool
// Identification is the "identification" field of an IPv6 fragment.
Identification uint32
}
// identifier implements IPv6SerializableExtHdr.
//
// It always returns IPv6FragmentHeader.
func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
return IPv6FragmentHeader
}
// length implements IPv6SerializableExtHdr.
//
// The fragment extension header has a fixed size of IPv6FragmentHeaderSize
// bytes.
func (h *IPv6SerializableFragmentExtHdr) length() int {
return IPv6FragmentHeaderSize
}
// serializeInto implements IPv6SerializableExtHdr.
//
// It writes the complete fragment extension header into the first
// IPv6FragmentHeaderSize bytes of b: the next header value, a zeroed reserved
// octet, the fragment offset together with the M flag, and the
// identification. It returns IPv6FragmentHeaderSize and panics if b is
// shorter than that.
func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
	// reservedOff is the offset of the 8-bit reserved field that sits between
	// the next header field and the fragment offset word.
	const reservedOff = nextHdrFrag + 1

	// Prevent too many bounds checks.
	_ = b[IPv6FragmentHeaderSize:]
	b[nextHdrFrag] = nextHeader
	// RFC 8200 section 4.5 requires the reserved field to be zero on
	// transmission. Clear it explicitly so that stale contents of a reused
	// buffer never end up on the wire.
	b[reservedOff] = 0
	binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift)
	if h.M {
		// The flag shares a byte with the low bits of the fragment offset word
		// written just above, so it is OR-ed in rather than assigned.
		b[more] |= ipv6FragmentExtHdrMFlagMask
	}
	binary.BigEndian.PutUint32(b[idV6:], h.Identification)
	return IPv6FragmentHeaderSize
}
// IPv6Fragment represents an IPv6 fragment header stored in a byte slice.
//
// Most of the methods of IPv6Fragment access the underlying slice without
// checking its bounds and can panic with 'index out of range'. Always call
// IsValid() to validate an instance of IPv6Fragment before using other
// methods.
type IPv6Fragment []byte
const (
// IPv6FragmentHeader is the number used to specify that the next header is
// a fragment header, per RFC 2460.
IPv6FragmentHeader = 44
// IPv6FragmentHeaderSize is the size of the fragment header in bytes.
IPv6FragmentHeaderSize = 8
)
// IsValid performs basic validation on the fragment header.
//
// It only checks that b holds at least IPv6FragmentHeaderSize bytes; the
// contents of the fields are not inspected.
func (b IPv6Fragment) IsValid() bool {
return len(b) >= IPv6FragmentHeaderSize
}
// NextHeader returns the value of the "next header" field of the ipv6 fragment.
//
// It panics if b is empty.
func (b IPv6Fragment) NextHeader() uint8 {
return b[nextHdrFrag]
}
// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
//
// The field occupies the upper 13 bits of the big-endian 16-bit word at
// fragOff, so the low three bits are shifted out.
func (b IPv6Fragment) FragmentOffset() uint16 {
	word := binary.BigEndian.Uint16(b[fragOff:])
	return word >> 3
}
// More returns the "more" field of the ipv6 fragment, i.e. whether the lowest
// bit of the byte at offset more is set.
func (b IPv6Fragment) More() bool {
	flag := b[more] & 1
	return flag != 0
}
// Payload implements Network.Payload.
//
// It returns the bytes that follow the fragment header. The result aliases
// b's backing array rather than copying it.
func (b IPv6Fragment) Payload() []byte {
return b[IPv6FragmentHeaderSize:]
}
// ID returns the value of the identifier field of the ipv6 fragment.
//
// The field is read as a big-endian 32-bit integer starting at offset idV6.
func (b IPv6Fragment) ID() uint32 {
return binary.BigEndian.Uint32(b[idV6:])
}
// TransportProtocol implements Network.TransportProtocol.
//
// It returns the fragment header's next header value converted to a
// tcpip.TransportProtocolNumber.
func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.NextHeader())
}
// The functions below have been added only to satisfy the Network interface.
// None of them is supported by IPv6Fragment; every one panics when called.

// Checksum is not supported by IPv6Fragment and always panics.
func (b IPv6Fragment) Checksum() uint16 {
panic("not supported")
}

// SourceAddress is not supported by IPv6Fragment and always panics.
func (b IPv6Fragment) SourceAddress() tcpip.Address {
panic("not supported")
}

// DestinationAddress is not supported by IPv6Fragment and always panics.
func (b IPv6Fragment) DestinationAddress() tcpip.Address {
panic("not supported")
}

// SetSourceAddress is not supported by IPv6Fragment and always panics.
func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
panic("not supported")
}

// SetDestinationAddress is not supported by IPv6Fragment and always panics.
func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
panic("not supported")
}

// SetChecksum is not supported by IPv6Fragment and always panics.
func (b IPv6Fragment) SetChecksum(uint16) {
panic("not supported")
}

// TOS is not supported by IPv6Fragment and always panics.
func (b IPv6Fragment) TOS() (uint8, uint32) {
panic("not supported")
}

// SetTOS is not supported by IPv6Fragment and always panics.
func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
panic("not supported")
}

View File

@@ -0,0 +1,103 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// MLDMinimumSize is the minimum size in bytes for an MLD message.
MLDMinimumSize = 20
// MLDHopLimit is the Hop Limit for all IPv6 packets with an MLD message, as
// per RFC 2710 section 3.
MLDHopLimit = 1
// mldMaximumResponseDelayOffset is the offset to the Maximum Response Delay
// field within MLD.
mldMaximumResponseDelayOffset = 0
// mldMulticastAddressOffset is the offset to the Multicast Address field
// within MLD.
mldMulticastAddressOffset = 4
)
// MLD is a Multicast Listener Discovery message in an ICMPv6 packet.
//
// MLD will only contain the body of an ICMPv6 packet.
//
// As per RFC 2710 section 3, MLD messages have the following format (MLD only
// holds the bytes after the first four bytes in the diagram below):
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|     Type      |     Code      |          Checksum             |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|     Maximum Response Delay    |          Reserved             |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	+                                                               +
//	|                                                               |
//	+                       Multicast Address                       +
//	|                                                               |
//	+                                                               +
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type MLD []byte
// MaximumResponseDelay returns the Maximum Response Delay as a duration.
//
// As per RFC 2710 section 3.4:
//
//	The Maximum Response Delay field is meaningful only in Query
//	messages, and specifies the maximum allowed delay before sending a
//	responding Report, in units of milliseconds. In all other messages,
//	it is set to zero by the sender and ignored by receivers.
func (m MLD) MaximumResponseDelay() time.Duration {
	millis := binary.BigEndian.Uint16(m[mldMaximumResponseDelayOffset:])
	return time.Duration(millis) * time.Millisecond
}
// SetMaximumResponseDelay stores maxRespDelayMS, a value in milliseconds, in
// the Maximum Response Delay field in network byte order.
func (m MLD) SetMaximumResponseDelay(maxRespDelayMS uint16) {
	field := m[mldMaximumResponseDelayOffset:]
	binary.BigEndian.PutUint16(field, maxRespDelayMS)
}
// MulticastAddress returns the Multicast Address field.
//
// As per RFC 2710 section 3.5:
//
//	In a Query message, the Multicast Address field is set to zero when
//	sending a General Query, and set to a specific IPv6 multicast address
//	when sending a Multicast-Address-Specific Query.
//
//	In a Report or Done message, the Multicast Address field holds a
//	specific IPv6 multicast address to which the message sender is
//	listening or is ceasing to listen, respectively.
func (m MLD) MulticastAddress() tcpip.Address {
	field := m[mldMulticastAddressOffset:][:IPv6AddressSize]
	return tcpip.AddrFrom16([16]byte(field))
}
// SetMulticastAddress copies multicastAddress into the Multicast Address
// field. It panics unless exactly IPv6AddressSize bytes are copied.
func (m MLD) SetMulticastAddress(multicastAddress tcpip.Address) {
	n := copy(m[mldMulticastAddressOffset:], multicastAddress.AsSlice())
	if n == IPv6AddressSize {
		return
	}
	panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, IPv6AddressSize))
}

View File

@@ -0,0 +1,541 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"bytes"
"encoding/binary"
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// MLDv2QueryMinimumSize is the minimum size in bytes for an MLDv2 query.
MLDv2QueryMinimumSize = 24
// mldv2QueryMaximumResponseCodeOffset is the offset to the Maximum Response
// Code field within MLDv2Query.
mldv2QueryMaximumResponseCodeOffset = 0
// mldv2QueryResvSQRVOffset is the offset to the byte within MLDv2Query that
// holds the Resv, S and QRV fields.
mldv2QueryResvSQRVOffset = 20
// mldv2QueryQRVMask selects the QRV bits from the byte at
// mldv2QueryResvSQRVOffset.
mldv2QueryQRVMask = 0b111
// mldv2QueryQQICOffset is the offset to the QQIC field within MLDv2Query.
mldv2QueryQQICOffset = 21
// mldv2QueryNumberOfSourcesOffset is the offset to the Number of Sources
// field within MLDv2Query.
mldv2QueryNumberOfSourcesOffset = 22
// MLDv2ReportMinimumSize is the minimum size in bytes of an MLDv2 report.
MLDv2ReportMinimumSize = 24
// mldv2QuerySourcesOffset is the offset to the Sources field within
// MLDv2Query.
mldv2QuerySourcesOffset = 24
)
var (
// MLDv2RoutersAddress is the address to send MLDv2 reports to (ff02::16).
//
// As per RFC 3810 section 5.2.14,
//
//	Version 2 Multicast Listener Reports are sent with an IP destination
//	address of FF02:0:0:0:0:0:0:16, to which all MLDv2-capable multicast
//	routers listen (see section 11 for IANA considerations related to
//	this special destination address).
MLDv2RoutersAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x16})
)
// MLDv2Query is a Multicast Listener Discovery Version 2 Query message in an
// ICMPv6 packet.
//
// MLDv2Query will only contain the body of an ICMPv6 packet.
//
// As per RFC 3810 section 5.1, MLDv2 Query messages have the following format
// (MLDv2Query only holds the bytes after the first four bytes in the diagram
// below):
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Type = 130   |      Code     |           Checksum            |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|    Maximum Response Code      |           Reserved            |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Multicast Address                       *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	| Resv  |S| QRV |     QQIC      |     Number of Sources (N)     |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [1]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-                                                             -+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [2]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-                              .                              -+
//	.                               .                               .
//	.                               .                               .
//	+-                                                             -+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [N]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type MLDv2Query MLD
// MaximumResponseCode returns the raw Maximum Response Code field. Use
// MLDv2MaximumResponseDelay to convert it into a duration.
func (m MLDv2Query) MaximumResponseCode() uint16 {
	field := m[mldv2QueryMaximumResponseCodeOffset:]
	return binary.BigEndian.Uint16(field)
}
// MLDv2MaximumResponseDelay converts an MLDv2 Maximum Response Code into the
// Maximum Response Delay it encodes.
//
// As per RFC 3810 section 5.1.3,
//
//	The Maximum Response Code field specifies the maximum time allowed
//	before sending a responding Report. The actual time allowed, called
//	the Maximum Response Delay, is represented in units of milliseconds,
//	and is derived from the Maximum Response Code as follows:
//
//	If Maximum Response Code < 32768,
//	   Maximum Response Delay = Maximum Response Code
//
//	If Maximum Response Code >=32768, Maximum Response Code represents a
//	floating-point value as follows:
//
//	    0 1 2 3 4 5 6 7 8 9 A B C D E F
//	   +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	   |1| exp |          mant         |
//	   +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//
//	Maximum Response Delay = (mant | 0x1000) << (exp+3)
//
//	Small values of Maximum Response Delay allow MLDv2 routers to tune
//	the "leave latency" (the time between the moment the last node on a
//	link ceases to listen to a specific multicast address and the moment
//	the routing protocol is notified that there are no more listeners for
//	that address). Larger values, especially in the exponential range,
//	allow the tuning of the burstiness of MLD traffic on a link.
func MLDv2MaximumResponseDelay(codeRaw uint16) time.Duration {
	const (
		// Codes at or above this value use the floating-point encoding.
		floatingPointThreshold = 1 << 15
		mantissaBits           = 12
		mantissaMask           = 1<<mantissaBits - 1
		exponentMask           = 0b111
		// impliedMantissaBit is the 0x1000 term of the RFC formula.
		impliedMantissaBit = 1 << mantissaBits
		exponentBias       = 3
	)

	if codeRaw < floatingPointThreshold {
		// Linear range: the code is the delay in milliseconds.
		return time.Duration(codeRaw) * time.Millisecond
	}

	exponent := (codeRaw >> mantissaBits) & exponentMask
	mantissa := time.Duration(codeRaw&mantissaMask) | impliedMantissaBit
	millis := mantissa << (exponent + exponentBias)
	return millis * time.Millisecond
}
// MulticastAddress returns the Multicast Address field of the query.
//
// As per RFC 2710 section 3.5:
//
//	In a Query message, the Multicast Address field is set to zero when
//	sending a General Query, and set to a specific IPv6 multicast address
//	when sending a Multicast-Address-Specific Query.
//
//	In a Report or Done message, the Multicast Address field holds a
//	specific IPv6 multicast address to which the message sender is
//	listening or is ceasing to listen, respectively.
func (m MLDv2Query) MulticastAddress() tcpip.Address {
	field := m[mldMulticastAddressOffset:][:IPv6AddressSize]
	return tcpip.AddrFrom16([16]byte(field))
}
// QuerierRobustnessVariable returns the querier's robustness variable, i.e.
// the QRV bits of the byte shared with the Resv and S fields.
func (m MLDv2Query) QuerierRobustnessVariable() uint8 {
	resvSQRV := m[mldv2QueryResvSQRVOffset]
	return resvSQRV & mldv2QueryQRVMask
}
// QuerierQueryInterval returns the querier's query interval.
//
// The interval is decoded from the QQIC byte by
// mldv2AndIGMPv3QuerierQueryCodeToInterval.
func (m MLDv2Query) QuerierQueryInterval() time.Duration {
return mldv2AndIGMPv3QuerierQueryCodeToInterval(m[mldv2QueryQQICOffset])
}
// Sources returns an iterator over the source addresses listed in the query.
//
// Returns false if the message cannot hold the expected number of sources.
func (m MLDv2Query) Sources() (AddressIterator, bool) {
	sources := m[mldv2QuerySourcesOffset:]
	count := binary.BigEndian.Uint16(m[mldv2QueryNumberOfSourcesOffset:])
	return makeAddressIterator(sources, count, IPv6AddressSize)
}
// MLDv2ReportRecordType is the type of an MLDv2 multicast address record
// found in an MLDv2 report, as per RFC 3810 section 5.2.12.
type MLDv2ReportRecordType int

// MLDv2 multicast address record types, as per RFC 3810 section 5.2.12.
//
// The numeric values are the ones written to and read from the Record Type
// byte of a multicast address record.
const (
MLDv2ReportRecordModeIsInclude MLDv2ReportRecordType = 1
MLDv2ReportRecordModeIsExclude MLDv2ReportRecordType = 2
MLDv2ReportRecordChangeToIncludeMode MLDv2ReportRecordType = 3
MLDv2ReportRecordChangeToExcludeMode MLDv2ReportRecordType = 4
MLDv2ReportRecordAllowNewSources MLDv2ReportRecordType = 5
MLDv2ReportRecordBlockOldSources MLDv2ReportRecordType = 6
)
// Layout of an MLDv2 multicast address record. Offsets are in bytes from the
// start of the record.
const (
// Size of a record that carries no sources and no auxiliary data.
mldv2ReportMulticastAddressRecordMinimumSize = 20
mldv2ReportMulticastAddressRecordTypeOffset = 0
mldv2ReportMulticastAddressRecordAuxDataLenOffset = 1
// Number of bytes represented by each unit of the Aux Data Len field.
mldv2ReportMulticastAddressRecordAuxDataLenUnits = 4
mldv2ReportMulticastAddressRecordNumberOfSourcesOffset = 2
mldv2ReportMulticastAddressRecordMulticastAddressOffset = 4
mldv2ReportMulticastAddressRecordSourcesOffset = 20
)
// MLDv2ReportMulticastAddressRecordSerializer is an MLDv2 Multicast Address
// Record serializer.
//
// As per RFC 3810 section 5.2, a Multicast Address Record has the following
// internal format:
//
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Record Type  | Aux Data Len  |     Number of Sources (N)     |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Multicast Address                       *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [1]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-                                                             -+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [2]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-                                                             -+
//	.                               .                               .
//	.                               .                               .
//	.                               .                               .
//	+-                                                             -+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [N]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                         Auxiliary Data                        .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type MLDv2ReportMulticastAddressRecordSerializer struct {
// RecordType is written to the Record Type field.
RecordType MLDv2ReportRecordType
// MulticastAddress is written to the Multicast Address field.
MulticastAddress tcpip.Address
// Sources are written, in order, as the record's source addresses.
Sources []tcpip.Address
}
// Length returns the number of bytes this serializer would occupy: the fixed
// part of the record followed by one IPv6 address per source.
func (s *MLDv2ReportMulticastAddressRecordSerializer) Length() int {
	sourcesLen := len(s.Sources) * IPv6AddressSize
	return mldv2ReportMulticastAddressRecordSourcesOffset + sourcesLen
}
// copyIPv6Address copies src into the start of dst.
//
// It panics unless exactly IPv6AddressSize bytes were copied, which happens
// when dst or src is shorter than an IPv6 address.
func copyIPv6Address(dst []byte, src tcpip.Address) {
	n := copy(dst, src.AsSlice())
	if n == IPv6AddressSize {
		return
	}
	panic(fmt.Sprintf("got copy(...) = %d, want = %d", n, IPv6AddressSize))
}
// SerializeInto writes the record into b: the record type, a zero auxiliary
// data length, the number of sources, the multicast address and then every
// source address.
//
// Panics if the buffer does not have enough space to fit the record.
func (s *MLDv2ReportMulticastAddressRecordSerializer) SerializeInto(b []byte) {
	b[mldv2ReportMulticastAddressRecordTypeOffset] = byte(s.RecordType)
	// No auxiliary data is ever serialized.
	b[mldv2ReportMulticastAddressRecordAuxDataLenOffset] = 0
	numSources := uint16(len(s.Sources))
	binary.BigEndian.PutUint16(b[mldv2ReportMulticastAddressRecordNumberOfSourcesOffset:], numSources)
	copyIPv6Address(b[mldv2ReportMulticastAddressRecordMulticastAddressOffset:], s.MulticastAddress)

	dst := b[mldv2ReportMulticastAddressRecordSourcesOffset:]
	for i := range s.Sources {
		copyIPv6Address(dst, s.Sources[i])
		dst = dst[IPv6AddressSize:]
	}
}
// Layout of an MLDv2 report body. Offsets are in bytes from the start of the
// body.
const (
// Offset of the 16-bit reserved field.
mldv2ReportReservedOffset = 0
// Offset of the 16-bit count of multicast address records.
mldv2ReportNumberOfMulticastAddressRecordsOffset = 2
// Offset of the first multicast address record.
mldv2ReportMulticastAddressRecordsOffset = 4
)
// MLDv2ReportSerializer is an MLD Version 2 Report serializer.
//
// As per RFC 3810 section 5.2,
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Type = 143   |    Reserved   |           Checksum            |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|           Reserved            |Nr of Mcast Address Records (M)|
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                  Multicast Address Record [1]                 .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                  Multicast Address Record [2]                 .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                               .                               |
//	.                               .                               .
//	|                               .                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                  Multicast Address Record [M]                 .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type MLDv2ReportSerializer struct {
// Records are the multicast address records, serialized in slice order.
Records []MLDv2ReportMulticastAddressRecordSerializer
}
// Length returns the number of bytes this serializer would occupy: the fixed
// part of the report followed by every record.
func (s *MLDv2ReportSerializer) Length() int {
	total := mldv2ReportMulticastAddressRecordsOffset
	for i := range s.Records {
		total += s.Records[i].Length()
	}
	return total
}
// SerializeInto writes the report into b: a zeroed reserved field, the number
// of records and then every record back to back.
//
// Panics if the buffer does not have enough space to fit the report.
func (s *MLDv2ReportSerializer) SerializeInto(b []byte) {
	binary.BigEndian.PutUint16(b[mldv2ReportReservedOffset:], 0)
	numRecords := uint16(len(s.Records))
	binary.BigEndian.PutUint16(b[mldv2ReportNumberOfMulticastAddressRecordsOffset:], numRecords)

	rest := b[mldv2ReportMulticastAddressRecordsOffset:]
	for i := range s.Records {
		record := &s.Records[i]
		recordLen := record.Length()
		// Hand each record exactly the bytes it owns.
		record.SerializeInto(rest[:recordLen])
		rest = rest[recordLen:]
	}
}
// MLDv2ReportMulticastAddressRecord is an MLDv2 record.
//
// As per RFC 3810 section 5.2, a Multicast Address Record has the following
// internal format:
//
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Record Type  | Aux Data Len  |     Number of Sources (N)     |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Multicast Address                       *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [1]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-                                                             -+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [2]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-                                                             -+
//	.                               .                               .
//	.                               .                               .
//	.                               .                               .
//	+-                                                             -+
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	*                       Source Address [N]                      *
//	|                                                               |
//	*                                                               *
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                         Auxiliary Data                        .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type MLDv2ReportMulticastAddressRecord []byte
// RecordType returns the type of this record.
//
// The value is the raw Record Type byte; it is not checked against the known
// MLDv2ReportRecordType constants.
func (r MLDv2ReportMulticastAddressRecord) RecordType() MLDv2ReportRecordType {
return MLDv2ReportRecordType(r[mldv2ReportMulticastAddressRecordTypeOffset])
}
// AuxDataLen returns the length in bytes of the auxiliary data in this record.
//
// The Aux Data Len field counts units of
// mldv2ReportMulticastAddressRecordAuxDataLenUnits bytes.
func (r MLDv2ReportMulticastAddressRecord) AuxDataLen() int {
	units := int(r[mldv2ReportMulticastAddressRecordAuxDataLenOffset])
	return units * mldv2ReportMulticastAddressRecordAuxDataLenUnits
}
// numberOfSources returns the number of sources in this record, as stated by
// the record's big-endian 16-bit Number of Sources field.
func (r MLDv2ReportMulticastAddressRecord) numberOfSources() uint16 {
return binary.BigEndian.Uint16(r[mldv2ReportMulticastAddressRecordNumberOfSourcesOffset:])
}
// MulticastAddress returns the multicast address this record targets.
func (r MLDv2ReportMulticastAddressRecord) MulticastAddress() tcpip.Address {
	field := r[mldv2ReportMulticastAddressRecordMulticastAddressOffset:][:IPv6AddressSize]
	return tcpip.AddrFrom16([16]byte(field))
}
// Sources returns an iterator over the source addresses in the record.
//
// Returns false if the record cannot hold the number of sources it declares.
func (r MLDv2ReportMulticastAddressRecord) Sources() (AddressIterator, bool) {
	wantLen := int(r.numberOfSources()) * IPv6AddressSize
	sources := r[mldv2ReportMulticastAddressRecordSourcesOffset:]
	if wantLen > len(sources) {
		return AddressIterator{}, false
	}
	// Trim to the declared sources so trailing auxiliary data is not iterated.
	it := AddressIterator{
		addressSize: IPv6AddressSize,
		buf:         bytes.NewBuffer(sources[:wantLen]),
	}
	return it, true
}
// MLDv2Report is an MLDv2 Report.
//
// As per RFC 3810 section 5.2,
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|  Type = 143   |    Reserved   |           Checksum            |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|           Reserved            |Nr of Mcast Address Records (M)|
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                  Multicast Address Record [1]                 .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                  Multicast Address Record [2]                 .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                               .                               |
//	.                               .                               .
//	|                               .                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                                                               |
//	.                                                               .
//	.                  Multicast Address Record [M]                 .
//	.                                                               .
//	|                                                               |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type MLDv2Report []byte
// MLDv2ReportMulticastAddressRecordIterator is an iterator over MLDv2 Multicast
// Address Records.
type MLDv2ReportMulticastAddressRecordIterator struct {
// recordsLeft is the number of records that have yet to be yielded.
recordsLeft uint16
// buf holds the bytes of the records that have not been consumed yet.
buf *bytes.Buffer
}
// MLDv2ReportMulticastAddressRecordIteratorNextDisposition enumerates the
// possible outcomes of MLDv2ReportMulticastAddressRecordIterator.Next.
type MLDv2ReportMulticastAddressRecordIteratorNextDisposition int

const (
// MLDv2ReportMulticastAddressRecordIteratorNextOk indicates that a multicast
// address record was yielded.
MLDv2ReportMulticastAddressRecordIteratorNextOk MLDv2ReportMulticastAddressRecordIteratorNextDisposition = iota
// MLDv2ReportMulticastAddressRecordIteratorNextDone indicates that the iterator
// has been exhausted.
MLDv2ReportMulticastAddressRecordIteratorNextDone
// MLDv2ReportMulticastAddressRecordIteratorNextErrBufferTooShort indicates
// that the iterator expected another record, but the buffer ended
// prematurely.
MLDv2ReportMulticastAddressRecordIteratorNextErrBufferTooShort
)
// Next returns the next MLDv2 Multicast Address Record together with a
// disposition describing the outcome.
//
// The record is only meaningful when the disposition is
// MLDv2ReportMulticastAddressRecordIteratorNextOk. The returned record aliases
// the iterator's buffer.
func (it *MLDv2ReportMulticastAddressRecordIterator) Next() (MLDv2ReportMulticastAddressRecord, MLDv2ReportMulticastAddressRecordIteratorNextDisposition) {
	switch {
	case it.recordsLeft == 0:
		return MLDv2ReportMulticastAddressRecord{}, MLDv2ReportMulticastAddressRecordIteratorNextDone
	case it.buf.Len() < mldv2ReportMulticastAddressRecordMinimumSize:
		return MLDv2ReportMulticastAddressRecord{}, MLDv2ReportMulticastAddressRecordIteratorNextErrBufferTooShort
	}

	// Peek at the fixed part of the record to learn its full size.
	header := MLDv2ReportMulticastAddressRecord(it.buf.Bytes())
	auxLen := header.AuxDataLen()
	sourcesLen := int(header.numberOfSources()) * IPv6AddressSize
	recordLen := mldv2ReportMulticastAddressRecordMinimumSize + auxLen + sourcesLen

	record := it.buf.Next(recordLen)
	if len(record) < recordLen {
		return MLDv2ReportMulticastAddressRecord{}, MLDv2ReportMulticastAddressRecordIteratorNextErrBufferTooShort
	}
	it.recordsLeft--
	return MLDv2ReportMulticastAddressRecord(record), MLDv2ReportMulticastAddressRecordIteratorNextOk
}
// MulticastAddressRecords returns an iterator of MLDv2 Multicast Address
// Records.
//
// The number of records is read as a big-endian 16-bit value from the report,
// and the iterator walks the bytes that start at the records offset.
func (m MLDv2Report) MulticastAddressRecords() MLDv2ReportMulticastAddressRecordIterator {
	recordCount := binary.BigEndian.Uint16(m[mldv2ReportNumberOfMulticastAddressRecordsOffset:])
	recordBytes := m[mldv2ReportMulticastAddressRecordsOffset:]

	var it MLDv2ReportMulticastAddressRecordIterator
	it.recordsLeft = recordCount
	it.buf = bytes.NewBuffer(recordBytes)
	return it
}

View File

@@ -0,0 +1,124 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"bytes"
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
// mldv2AndIGMPv3QuerierQueryCodeToInterval decodes a Querier's Query Interval
// Code (QQIC) into the Querier's Query Interval (QQI) that it represents.
//
// MLDv2 (RFC 3810 section 5.1.9) and IGMPv3 (RFC 3376 section 4.1.7) use the
// same encoding. The QQI is expressed in units of seconds and is derived from
// the QQIC as follows:
//
//	If QQIC < 128, QQI = QQIC
//
//	If QQIC >= 128, QQIC represents a floating-point value as follows:
//
//	     0 1 2 3 4 5 6 7
//	    +-+-+-+-+-+-+-+-+
//	    |1| exp | mant  |
//	    +-+-+-+-+-+-+-+-+
//
//	QQI = (mant | 0x10) << (exp + 3)
//
// Both RFCs further state that multicast routers that are not the current
// querier adopt the QQI value from the most recently received Query as their
// own [Query Interval] value, unless that most recently received QQI was zero,
// in which case the receiving routers use the default [Query Interval] value
// (RFC 3810 section 9.2, RFC 3376 section 8.2). This function only performs
// the decoding: a zero code simply yields a zero duration.
func mldv2AndIGMPv3QuerierQueryCodeToInterval(code uint8) time.Duration {
	const (
		// floatingPointFlag is the top bit; when clear the code is the
		// interval itself.
		floatingPointFlag = 0x80

		mantissaBits       = 4
		mantissaMask       = 1<<mantissaBits - 1
		impliedMantissaBit = 0x10

		exponentMask = 0b111
		exponentBias = 3
	)

	if code&floatingPointFlag == 0 {
		return time.Duration(code) * time.Second
	}

	mantissa := uint(code&mantissaMask) | impliedMantissaBit
	exponent := uint(code>>mantissaBits)&exponentMask + exponentBias
	return time.Duration(mantissa<<exponent) * time.Second
}
// MakeAddressIterator returns an AddressIterator that reads addresses of
// addressSize bytes each from buf.
func MakeAddressIterator(addressSize int, buf *bytes.Buffer) AddressIterator {
	var it AddressIterator
	it.addressSize = addressSize
	it.buf = buf
	return it
}
// AddressIterator is an iterator over fixed-size addresses stored
// back-to-back in a buffer.
//
// The size of each address is supplied when the iterator is created (see
// MakeAddressIterator), so the iterator is not inherently limited to IPv6
// addresses.
type AddressIterator struct {
// addressSize is the number of bytes Next consumes for each address.
addressSize int
// buf holds the addresses that have not been yielded yet.
buf *bytes.Buffer
}
// Done reports whether the iterator has been exhausted, i.e. whether its
// buffer has no unread bytes left.
func (it *AddressIterator) Done() bool {
	unread := it.buf.Len()
	return unread == 0
}
// Next returns the next address in the iterator.
//
// Returns false if the iterator has been exhausted.
//
// Panics if the buffer still holds bytes but fewer than one full address.
func (it *AddressIterator) Next() (tcpip.Address, bool) {
	var addr tcpip.Address
	if it.Done() {
		return addr, false
	}

	raw := it.buf.Next(it.addressSize)
	if got, want := len(raw), it.addressSize; got != want {
		panic(fmt.Sprintf("got len(buf.Next(%d)) = %d, want = %d", want, got, want))
	}

	addr = tcpip.AddrFromSlice(raw)
	return addr, true
}
// makeAddressIterator returns an AddressIterator over the first
// expectedAddresses addresses, each addressSize bytes long, held in b.
//
// Returns false if b is too short to hold that many addresses. Bytes in b
// beyond the expected addresses are not visible to the iterator.
func makeAddressIterator(b []byte, expectedAddresses uint16, addressSize int) (AddressIterator, bool) {
	needed := int(expectedAddresses) * addressSize
	if needed > len(b) {
		return AddressIterator{}, false
	}

	addresses := bytes.NewBuffer(b[:needed])
	return MakeAddressIterator(addressSize, addresses), true
}

View File

@@ -0,0 +1,110 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import "gvisor.dev/gvisor/pkg/tcpip"
// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will
// only contain the body of an ICMPv6 packet, so every offset below is relative
// to the start of that body.
//
// See RFC 4861 section 4.4 for more details.
type NDPNeighborAdvert []byte
const (
// NDPNAMinimumSize is the minimum size, in bytes, of a valid NDP Neighbor
// Advertisement message (body of an ICMPv6 packet).
NDPNAMinimumSize = 20
// ndpNATargetAddressOffset is the start of the Target Address
// field within an NDPNeighborAdvert.
ndpNATargetAddressOffset = 4
// ndpNAOptionsOffset is the start of the NDP options in an
// NDPNeighborAdvert. The options immediately follow the Target Address.
ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize
// ndpNAFlagsOffset is the offset of the flags byte within an
// NDPNeighborAdvert.
ndpNAFlagsOffset = 0
// ndpNARouterFlagMask is the mask of the Router Flag field in
// the flags byte within an NDPNeighborAdvert.
ndpNARouterFlagMask = (1 << 7)
// ndpNASolicitedFlagMask is the mask of the Solicited Flag field in
// the flags byte within an NDPNeighborAdvert.
ndpNASolicitedFlagMask = (1 << 6)
// ndpNAOverrideFlagMask is the mask of the Override Flag field in
// the flags byte within an NDPNeighborAdvert.
ndpNAOverrideFlagMask = (1 << 5)
)
// TargetAddress returns the address stored in the Target Address field.
func (b NDPNeighborAdvert) TargetAddress() tcpip.Address {
	field := b[ndpNATargetAddressOffset:]
	return tcpip.AddrFrom16Slice(field[:IPv6AddressSize])
}
// SetTargetAddress writes addr into the Target Address field.
func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) {
	field := b[ndpNATargetAddressOffset:][:IPv6AddressSize]
	src := addr.AsSlice()
	copy(field, src)
}
// RouterFlag reports whether the Router Flag is set.
func (b NDPNeighborAdvert) RouterFlag() bool {
	flags := b[ndpNAFlagsOffset]
	return flags&ndpNARouterFlagMask != 0
}
// SetRouterFlag sets the Router Flag when f is true and clears it otherwise.
// The other bits of the flags byte are left untouched.
func (b NDPNeighborAdvert) SetRouterFlag(f bool) {
	flags := b[ndpNAFlagsOffset] &^ ndpNARouterFlagMask
	if f {
		flags |= ndpNARouterFlagMask
	}
	b[ndpNAFlagsOffset] = flags
}
// SolicitedFlag reports whether the Solicited Flag is set.
func (b NDPNeighborAdvert) SolicitedFlag() bool {
	flags := b[ndpNAFlagsOffset]
	return flags&ndpNASolicitedFlagMask != 0
}
// SetSolicitedFlag sets the Solicited Flag when f is true and clears it
// otherwise. The other bits of the flags byte are left untouched.
func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) {
	flags := b[ndpNAFlagsOffset] &^ ndpNASolicitedFlagMask
	if f {
		flags |= ndpNASolicitedFlagMask
	}
	b[ndpNAFlagsOffset] = flags
}
// OverrideFlag reports whether the Override Flag is set.
func (b NDPNeighborAdvert) OverrideFlag() bool {
	flags := b[ndpNAFlagsOffset]
	return flags&ndpNAOverrideFlagMask != 0
}
// SetOverrideFlag sets the Override Flag when f is true and clears it
// otherwise. The other bits of the flags byte are left untouched.
func (b NDPNeighborAdvert) SetOverrideFlag(f bool) {
	flags := b[ndpNAFlagsOffset] &^ ndpNAOverrideFlagMask
	if f {
		flags |= ndpNAOverrideFlagMask
	}
	b[ndpNAFlagsOffset] = flags
}
// Options returns the bytes that follow the Target Address field, viewed as
// NDPOptions. The returned value shares memory with b.
func (b NDPNeighborAdvert) Options() NDPOptions {
	optionBytes := b[ndpNAOptionsOffset:]
	return NDPOptions(optionBytes)
}

View File

@@ -0,0 +1,52 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import "gvisor.dev/gvisor/pkg/tcpip"
// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only
// contain the body of an ICMPv6 packet, so every offset below is relative to
// the start of that body.
//
// See RFC 4861 section 4.3 for more details.
type NDPNeighborSolicit []byte
const (
// NDPNSMinimumSize is the minimum size, in bytes, of a valid NDP Neighbor
// Solicitation message (body of an ICMPv6 packet).
NDPNSMinimumSize = 20
// ndpNSTargetAddessOffset is the start of the Target Address
// field within an NDPNeighborSolicit.
//
// NOTE(review): the identifier is misspelt ("Addess"). It is kept as is
// because the accessor methods refer to it by this name.
ndpNSTargetAddessOffset = 4
// ndpNSOptionsOffset is the start of the NDP options in an
// NDPNeighborSolicit. The options immediately follow the Target Address.
ndpNSOptionsOffset = ndpNSTargetAddessOffset + IPv6AddressSize
)
// TargetAddress returns the address stored in the Target Address field.
func (b NDPNeighborSolicit) TargetAddress() tcpip.Address {
	field := b[ndpNSTargetAddessOffset:]
	return tcpip.AddrFrom16Slice(field[:IPv6AddressSize])
}
// SetTargetAddress writes addr into the Target Address field.
func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) {
	field := b[ndpNSTargetAddessOffset:][:IPv6AddressSize]
	src := addr.AsSlice()
	copy(field, src)
}
// Options returns the bytes that follow the Target Address field, viewed as
// NDPOptions. The returned value shares memory with b.
func (b NDPNeighborSolicit) Options() NDPOptions {
	optionBytes := b[ndpNSOptionsOffset:]
	return NDPOptions(optionBytes)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,204 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"fmt"
"time"
)
var _ fmt.Stringer = NDPRoutePreference(0)

// NDPRoutePreference is the preference values for default routers or
// more-specific routes.
//
// As per RFC 4191 section 2.1,
//
//	Default router preferences and preferences for more-specific routes
//	are encoded the same way.
//
//	Preference values are encoded as a two-bit signed integer, as
//	follows:
//
//	   01      High
//	   00      Medium (default)
//	   11      Low
//	   10      Reserved - MUST NOT be sent
//
//	Note that implementations can treat the value as a two-bit signed
//	integer.
//
//	Having just three values reinforces that they are not metrics and
//	more values do not appear to be necessary for reasonable scenarios.
type NDPRoutePreference uint8

// NOTE(review): only HighRoutePreference carries an explicit type below; the
// remaining three are untyped integer constants. They convert implicitly
// wherever an NDPRoutePreference is expected, but on their own they format as
// plain integers rather than through String.
const (
	// HighRoutePreference indicates a high preference, as per
	// RFC 4191 section 2.1.
	HighRoutePreference NDPRoutePreference = 0b01

	// MediumRoutePreference indicates a medium preference, as per
	// RFC 4191 section 2.1.
	//
	// This is the default preference value.
	MediumRoutePreference = 0b00

	// LowRoutePreference indicates a low preference, as per
	// RFC 4191 section 2.1.
	LowRoutePreference = 0b11

	// ReservedRoutePreference is a reserved preference value, as per
	// RFC 4191 section 2.1.
	//
	// It MUST NOT be sent.
	ReservedRoutePreference = 0b10
)

// String implements fmt.Stringer.
//
// Values outside the two-bit range are rendered as "NDPRoutePreference(N)".
func (p NDPRoutePreference) String() string {
	// The four defined values are exactly 0 through 3, so they can index a
	// small table directly.
	names := [...]string{
		MediumRoutePreference:   "MediumRoutePreference",
		HighRoutePreference:     "HighRoutePreference",
		ReservedRoutePreference: "ReservedRoutePreference",
		LowRoutePreference:      "LowRoutePreference",
	}
	if int(p) < len(names) {
		return names[p]
	}
	return fmt.Sprintf("NDPRoutePreference(%d)", p)
}
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
// the body of an ICMPv6 packet, so every offset below is relative to the start
// of that body (the Cur Hop Limit field in the diagram).
//
// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details.
type NDPRouterAdvert []byte
// As per RFC 4191 section 2.2,
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|     Type      |     Code      |          Checksum             |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	| Cur Hop Limit |M|O|H|Prf|Resvd|       Router Lifetime         |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                         Reachable Time                        |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|                          Retrans Timer                        |
//	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//	|   Options ...
//	+-+-+-+-+-+-+-+-+-+-+-+-
const (
// NDPRAMinimumSize is the minimum size, in bytes, of a valid NDP Router
// Advertisement message (body of an ICMPv6 packet).
NDPRAMinimumSize = 12
// ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field
// within an NDPRouterAdvert.
ndpRACurrHopLimitOffset = 0
// ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags
// within an NDPRouterAdvert.
ndpRAFlagsOffset = 1
// ndpRAManagedAddrConfFlagMask is the mask of the Managed Address
// Configuration flag within the bit-field/flags byte of an
// NDPRouterAdvert.
ndpRAManagedAddrConfFlagMask = (1 << 7)
// ndpRAOtherConfFlagMask is the mask of the Other Configuration flag
// within the bit-field/flags byte of an NDPRouterAdvert.
ndpRAOtherConfFlagMask = (1 << 6)
// ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router
// Preference) field within the flags byte of an NDPRouterAdvert.
ndpDefaultRouterPreferenceShift = 3
// ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router
// Preference) field within the flags byte of an NDPRouterAdvert. The field
// is two bits wide.
ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift)
// ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
// field within an NDPRouterAdvert.
ndpRARouterLifetimeOffset = 2
// ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time
// field within an NDPRouterAdvert.
ndpRAReachableTimeOffset = 4
// ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer
// field within an NDPRouterAdvert.
ndpRARetransTimerOffset = 8
// ndpRAOptionsOffset is the start of the NDP options in an
// NDPRouterAdvert.
ndpRAOptionsOffset = 12
)
// CurrHopLimit returns the byte stored in the Curr Hop Limit field.
func (b NDPRouterAdvert) CurrHopLimit() uint8 {
	hopLimit := b[ndpRACurrHopLimitOffset]
	return hopLimit
}
// ManagedAddrConfFlag reports whether the Managed Address Configuration flag
// is set.
func (b NDPRouterAdvert) ManagedAddrConfFlag() bool {
	flags := b[ndpRAFlagsOffset]
	return flags&ndpRAManagedAddrConfFlagMask != 0
}
// OtherConfFlag reports whether the Other Configuration flag is set.
func (b NDPRouterAdvert) OtherConfFlag() bool {
	flags := b[ndpRAFlagsOffset]
	return flags&ndpRAOtherConfFlagMask != 0
}
// DefaultRouterPreference returns the Default Router Preference (Prf) field,
// extracted from the flags byte.
func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference {
	prfBits := b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask
	return NDPRoutePreference(prfBits >> ndpDefaultRouterPreferenceShift)
}
// RouterLifetime returns the lifetime associated with the default router. A
// value of 0 means the source of the Router Advertisement is not a default
// router and SHOULD NOT appear on the default router list. Note, a value of 0
// only means that the router should not be used as a default router, it does
// not apply to other information contained in the Router Advertisement.
func (b NDPRouterAdvert) RouterLifetime() time.Duration {
	// The field holds a big-endian count of seconds, as per RFC 4861 section
	// 4.2.
	seconds := binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:])
	return time.Duration(seconds) * time.Second
}
// ReachableTime returns the time that a node assumes a neighbor is reachable
// after having received a reachability confirmation. A value of 0 means
// that it is unspecified by the source of the Router Advertisement message.
func (b NDPRouterAdvert) ReachableTime() time.Duration {
	// The field holds a big-endian count of milliseconds, as per RFC 4861
	// section 4.2.
	millis := binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:])
	return time.Duration(millis) * time.Millisecond
}
// RetransTimer returns the time between retransmitted Neighbor Solicitation
// messages. A value of 0 means that it is unspecified by the source of the
// Router Advertisement message.
func (b NDPRouterAdvert) RetransTimer() time.Duration {
	// The field holds a big-endian count of milliseconds, as per RFC 4861
	// section 4.2.
	millis := binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:])
	return time.Duration(millis) * time.Millisecond
}
// Options returns the bytes that follow the fixed Router Advertisement fields,
// viewed as NDPOptions. The returned value shares memory with b.
func (b NDPRouterAdvert) Options() NDPOptions {
	optionBytes := b[ndpRAOptionsOffset:]
	return NDPOptions(optionBytes)
}

View File

@@ -0,0 +1,36 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain
// the body of an ICMPv6 packet, so the offset below is relative to the start
// of that body.
//
// See RFC 4861 section 4.1 for more details.
type NDPRouterSolicit []byte
const (
// NDPRSMinimumSize is the minimum size, in bytes, of a valid NDP Router
// Solicitation message (body of an ICMPv6 packet).
NDPRSMinimumSize = 4
// ndpRSOptionsOffset is the start of the NDP options in an
// NDPRouterSolicit.
ndpRSOptionsOffset = 4
)
// Options returns the bytes from the options offset onwards, viewed as
// NDPOptions. The returned value shares memory with b.
func (b NDPRouterSolicit) Options() NDPOptions {
	optionBytes := b[ndpRSOptionsOffset:]
	return NDPOptions(optionBytes)
}

View File

@@ -0,0 +1,58 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type ndpOptionIdentifier"; DO NOT EDIT.
package header
import "strconv"
// _ is a compile-time guard emitted by stringer: it is never called. Each
// index expression below is only in range for the one-element array when the
// named constant still has the value the string tables were generated for.
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ndpSourceLinkLayerAddressOptionType-1]
_ = x[ndpTargetLinkLayerAddressOptionType-2]
_ = x[ndpPrefixInformationType-3]
_ = x[ndpNonceOptionType-14]
_ = x[ndpRecursiveDNSServerOptionType-25]
_ = x[ndpDNSSearchListOptionType-31]
}
// Name tables used by ndpOptionIdentifier.String. _ndpOptionIdentifier_name_0
// is the concatenation of the names for the contiguous values 1 through 3;
// the remaining tables each hold the name of a single value.
const (
_ndpOptionIdentifier_name_0 = "ndpSourceLinkLayerAddressOptionTypendpTargetLinkLayerAddressOptionTypendpPrefixInformationType"
_ndpOptionIdentifier_name_1 = "ndpNonceOptionType"
_ndpOptionIdentifier_name_2 = "ndpRecursiveDNSServerOptionType"
_ndpOptionIdentifier_name_3 = "ndpDNSSearchListOptionType"
)
// _ndpOptionIdentifier_index_0 holds the boundaries of each name inside
// _ndpOptionIdentifier_name_0: name k spans [index[k], index[k+1]).
var (
_ndpOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
)
// String returns the constant name for i, or "ndpOptionIdentifier(N)" when i
// is not one of the values the tables were generated for.
func (i ndpOptionIdentifier) String() string {
switch {
case 1 <= i && i <= 3:
// Rebase to zero so that i indexes the boundary table.
i -= 1
return _ndpOptionIdentifier_name_0[_ndpOptionIdentifier_index_0[i]:_ndpOptionIdentifier_index_0[i+1]]
case i == 14:
return _ndpOptionIdentifier_name_1
case i == 25:
return _ndpOptionIdentifier_name_2
case i == 31:
return _ndpOptionIdentifier_name_3
default:
return "ndpOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View File

@@ -0,0 +1,243 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package parse provides utilities to parse packets.
package parse
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// ARP populates pkt's network header with an ARP header by consuming
// header.ARPSize bytes into it.
//
// Returns true if the header was successfully parsed, in which case pkt's
// network protocol number is set to ARP. On failure the protocol number is
// left untouched.
func ARP(pkt *stack.PacketBuffer) bool {
	if _, ok := pkt.NetworkHeader().Consume(header.ARPSize); !ok {
		return false
	}
	pkt.NetworkProtocolNumber = header.ARPProtocolNumber
	return true
}
// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network
// header with the IPv4 header (including any options).
//
// On success pkt's network protocol number is set to IPv4 and pkt.Data is
// capped to the payload length implied by the header's Total Length field.
//
// Returns true if the header was successfully parsed.
func IPv4(pkt *stack.PacketBuffer) bool {
	fixedHdr, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
	if !ok {
		return false
	}

	// The header may carry options, so ask the header itself how long it is.
	hdrLen := int(header.IPv4(fixedHdr).HeaderLength())
	if hdrLen < header.IPv4MinimumSize {
		// TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
		// order for the packet to be valid. Figure out if we want to reject this
		// case.
		hdrLen = header.IPv4MinimumSize
	}

	fullHdr, ok := pkt.NetworkHeader().Consume(hdrLen)
	if !ok {
		return false
	}

	// A Total Length smaller than the header itself is invalid.
	payloadLen := int(header.IPv4(fullHdr).TotalLength()) - len(fullHdr)
	if payloadLen < 0 {
		return false
	}

	pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
	pkt.Data().CapLength(payloadLen)
	return true
}
// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network
// header with the IPv6 header, including the extension headers it walked.
//
// Returns:
//   - proto: the identifier of the payload that follows the extension
//     headers. It is only set when the traversal reaches a raw payload; it is
//     left as zero when the traversal stops early (iterator error, exhausted
//     iterator, or a non-atomic fragment header).
//   - fragID, fragOffset, fragMore: the values of the first non-atomic
//     fragment extension header encountered, or zero values if there is none.
//   - ok: false only if pkt.Data does not hold a full fixed IPv6 header.
//
// Panics if the computed header length cannot be consumed from pkt.Data.
func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) {
hdr, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
if !ok {
return 0, 0, 0, false, false
}
ipHdr := header.IPv6(hdr)
// Create a buffer to parse the packet. We don't plan to modify anything here.
// dataBuf consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
// - Any other payload data.
dataBuf := pkt.Data().ToBuffer()
dataBuf.TrimFront(header.IPv6MinimumSize)
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataBuf)
defer it.Release()
// Iterate over the IPv6 extensions to find their length.
var nextHdr tcpip.TransportProtocolNumber
var extensionsSize int64
traverseExtensions:
for {
extHdr, done, err := it.Next()
if err != nil {
// Stop at the first malformed extension. extensionsSize keeps its zero
// value, so only the fixed header is consumed below.
break
}
// If we exhaust the extension list, the entire packet is the IPv6 header
// and (possibly) extensions.
if done {
extensionsSize = dataBuf.Size()
break
}
switch extHdr := extHdr.(type) {
case header.IPv6FragmentExtHdr:
if extHdr.IsAtomic() {
// This fragment extension header indicates that this packet is an
// atomic fragment. An atomic fragment is a fragment that contains
// all the data required to reassemble a full packet. As per RFC 6946,
// atomic fragments must not interfere with "normal" fragmented traffic
// so we skip processing the fragment instead of feeding it through the
// reassembly process below.
continue
}
// Record the fragment fields from this header. The traversal always stops
// right after the first non-atomic fragment header, so the named results
// still hold their zero values here.
if fragID == 0 && fragOffset == 0 && !fragMore {
fragID = extHdr.ID()
fragOffset = extHdr.FragmentOffset()
fragMore = extHdr.More()
}
// Everything after the fragment header is taken as a raw payload; the
// extensions are whatever precedes it in dataBuf.
rawPayload := it.AsRawHeader(true /* consume */)
extensionsSize = dataBuf.Size() - rawPayload.Buf.Size()
rawPayload.Release()
extHdr.Release()
break traverseExtensions
case header.IPv6RawPayloadHeader:
// We've found the payload after any extensions.
extensionsSize = dataBuf.Size() - extHdr.Buf.Size()
nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
extHdr.Release()
break traverseExtensions
default:
extHdr.Release()
// Any other extension is a no-op, keep looping until we find the payload.
}
}
// Put the IPv6 header with extensions in pkt.NetworkHeader().
hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + int(extensionsSize))
if !ok {
panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data().Size()))
}
ipHdr = header.IPv6(hdr)
// Cap the remaining data to the Payload Length advertised by the header.
pkt.Data().CapLength(int(ipHdr.PayloadLength()))
pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
return nextHdr, fragID, fragOffset, fragMore, true
}
// UDP parses a UDP packet found in pkt.Data and populates pkt's transport
// header with the UDP header.
//
// pkt's transport protocol number is set to UDP whether or not the header
// could be consumed.
//
// Returns true if the header was successfully parsed.
func UDP(pkt *stack.PacketBuffer) bool {
	_, parsed := pkt.TransportHeader().Consume(header.UDPMinimumSize)

	pkt.TransportProtocolNumber = header.UDPProtocolNumber
	return parsed
}
// TCP parses a TCP packet found in pkt.Data and populates pkt's transport
// header with the TCP header (including options when the data offset is
// usable).
//
// Once the fixed header has been pulled up, pkt's transport protocol number is
// set to TCP whether or not the header could be consumed.
//
// Returns true if the header was successfully parsed.
func TCP(pkt *stack.PacketBuffer) bool {
	// The TCP header is variable length, so peek at the fixed part first.
	fixedHdr, ok := pkt.Data().PullUp(header.TCPMinimumSize)
	if !ok {
		return false
	}

	// If the header has options, consume those as well. A data offset that is
	// not larger than the fixed header, or that points past the available
	// data, is ignored and only the fixed header is consumed.
	//
	// TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
	// packets.
	hdrLen := header.TCPMinimumSize
	dataOffset := int(header.TCP(fixedHdr).DataOffset())
	if dataOffset > header.TCPMinimumSize && dataOffset <= pkt.Data().Size() {
		hdrLen = dataOffset
	}

	_, ok = pkt.TransportHeader().Consume(hdrLen)
	pkt.TransportProtocolNumber = header.TCPProtocolNumber
	return ok
}
// ICMPv4 populates the packet buffer's transport header with an ICMPv4 header,
// if present.
//
// Returns true if an ICMPv4 header was successfully parsed, in which case
// pkt's transport protocol number is set to ICMPv4. On failure the protocol
// number is left untouched.
func ICMPv4(pkt *stack.PacketBuffer) bool {
	_, parsed := pkt.TransportHeader().Consume(header.ICMPv4MinimumSize)
	if parsed {
		pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
	}
	return parsed
}
// ICMPv6 populates the packet buffer's transport header with an ICMPv6 header,
// if present.
//
// For NDP, redirect and MLD messages the whole of pkt.Data is consumed into
// the transport header; for every other message type only the minimum ICMPv6
// header is consumed.
//
// Returns true if an ICMPv6 header was successfully parsed.
func ICMPv6(pkt *stack.PacketBuffer) bool {
	minHdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize)
	if !ok {
		return false
	}

	switch header.ICMPv6(minHdr).Type() {
	case header.ICMPv6RouterSolicit,
		header.ICMPv6RouterAdvert,
		header.ICMPv6NeighborSolicit,
		header.ICMPv6NeighborAdvert,
		header.ICMPv6RedirectMsg,
		header.ICMPv6MulticastListenerQuery,
		header.ICMPv6MulticastListenerReport,
		header.ICMPv6MulticastListenerV2Report,
		header.ICMPv6MulticastListenerDone:
		total := pkt.Data().Size()
		if _, consumed := pkt.TransportHeader().Consume(total); !consumed {
			panic(fmt.Sprintf("expected to consume the full data of size = %d bytes into transport header", total))
		}
	default:
		// Destination unreachable, packet too big, time exceeded, parameter
		// problem, echo request and echo reply land here, as does any other
		// message type.
		if _, consumed := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize); !consumed {
			// Checked above if the packet buffer holds at least the minimum size for
			// an ICMPv6 packet.
			panic(fmt.Sprintf("expected to consume %d bytes", header.ICMPv6MinimumSize))
		}
	}

	pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
	return true
}

View File

@@ -0,0 +1,3 @@
// automatically generated by stateify.
package parse

View File

@@ -0,0 +1,726 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"github.com/google/btree"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
// These constants are the offsets, in bytes from the start of the TCP header,
// of the respective fields in the TCP header.
const (
TCPSrcPortOffset = 0
TCPDstPortOffset = 2
TCPSeqNumOffset = 4
TCPAckNumOffset = 8
TCPDataOffset = 12
TCPFlagsOffset = 13
TCPWinSizeOffset = 14
TCPChecksumOffset = 16
TCPUrgentPtrOffset = 18
)
// Limits on TCP option contents.
const (
// MaxWndScale is maximum allowed window scaling, as described in
// RFC 1323, section 2.3, page 11.
MaxWndScale = 14
// TCPMaxSACKBlocks is the maximum number of SACK blocks that can
// be encoded in a TCP option field.
TCPMaxSACKBlocks = 4
)
// TCPFlags is the dedicated type for TCP flags. Each bit corresponds to one of
// the TCPFlag* constants.
type TCPFlags uint8

// Intersects returns true iff there are flags common to both f and o.
func (f TCPFlags) Intersects(o TCPFlags) bool {
	common := f & o
	return common != 0
}

// Contains returns true iff all the flags in o are contained within f.
func (f TCPFlags) Contains(o TCPFlags) bool {
	// Any flag of o that is absent from f survives the AND NOT.
	missing := o &^ f
	return missing == 0
}

// String implements Stringer.String.
//
// The result is always eight characters long: position i holds the letter of
// the flag with bit i when that flag is set, and a space otherwise.
func (f TCPFlags) String() string {
	const letters = "FSRPAUEC"

	var out [len(letters)]byte
	for bit := range out {
		if f&(1<<uint(bit)) != 0 {
			out[bit] = letters[bit]
		} else {
			out[bit] = ' '
		}
	}
	return string(out[:])
}

// Flags that may be set in a TCP segment.
const (
	// TCPFlagFin is the FIN flag (bit 0).
	TCPFlagFin TCPFlags = 1 << iota
	// TCPFlagSyn is the SYN flag (bit 1).
	TCPFlagSyn
	// TCPFlagRst is the RST flag (bit 2).
	TCPFlagRst
	// TCPFlagPsh is the PSH flag (bit 3).
	TCPFlagPsh
	// TCPFlagAck is the ACK flag (bit 4).
	TCPFlagAck
	// TCPFlagUrg is the URG flag (bit 5).
	TCPFlagUrg
	// TCPFlagEce is the ECE flag (bit 6).
	TCPFlagEce
	// TCPFlagCwr is the CWR flag (bit 7).
	TCPFlagCwr
)
// Options that may be present in a TCP segment. Each value is the option's
// kind octet.
const (
// TCPOptionEOL is the End of Option List option.
TCPOptionEOL = 0
// TCPOptionNOP is the No-Operation option.
TCPOptionNOP = 1
// TCPOptionMSS is the Maximum Segment Size option.
TCPOptionMSS = 2
// TCPOptionWS is the Window Scale option.
TCPOptionWS = 3
// TCPOptionTS is the Timestamps option.
TCPOptionTS = 8
// TCPOptionSACKPermitted is the SACK Permitted option.
TCPOptionSACKPermitted = 4
// TCPOptionSACK is the SACK option.
TCPOptionSACK = 5
)
// Option Lengths.
//
// Each value is the total encoded length of the option in bytes, counting the
// kind and length octets.
const (
TCPOptionMSSLength = 4
TCPOptionTSLength = 10
TCPOptionWSLength = 3
TCPOptionSackPermittedLength = 2
)
// TCPFields contains the fields of a TCP packet. It is used to describe the
// fields of a packet that needs to be encoded.
//
// The TCP*Offset constants give the position of each of these fields within
// the encoded header.
type TCPFields struct {
// SrcPort is the "source port" field of a TCP packet.
SrcPort uint16
// DstPort is the "destination port" field of a TCP packet.
DstPort uint16
// SeqNum is the "sequence number" field of a TCP packet.
SeqNum uint32
// AckNum is the "acknowledgement number" field of a TCP packet.
AckNum uint32
// DataOffset is the "data offset" field of a TCP packet. It is the length of
// the TCP header in bytes.
DataOffset uint8
// Flags is the "flags" field of a TCP packet.
Flags TCPFlags
// WindowSize is the "window size" field of a TCP packet.
WindowSize uint16
// Checksum is the "checksum" field of a TCP packet.
Checksum uint16
// UrgentPointer is the "urgent pointer" field of a TCP packet.
UrgentPointer uint16
}
// TCPSynOptions is used to return the parsed TCP Options in a syn
// segment.
//
// +stateify savable
type TCPSynOptions struct {
// MSS is the maximum segment size provided by the peer in the SYN.
MSS uint16
// WS is the window scale option provided by the peer in the SYN.
//
// Set to -1 if no window scale option was provided.
WS int
// TS is true if the timestamp option was provided in the syn/syn-ack.
TS bool
// TSVal is the value of the TSVal field in the timestamp option.
TSVal uint32
// TSEcr is the value of the TSEcr field in the timestamp option.
TSEcr uint32
// SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
SACKPermitted bool
// Flags, if specified, are set on the outgoing SYN. The SYN flag is
// always set.
Flags TCPFlags
}
// SACKBlock represents a single contiguous SACK block, i.e. the half-open
// sequence number range [Start, End).
//
// +stateify savable
type SACKBlock struct {
// Start indicates the lowest sequence number in the block.
Start seqnum.Value
// End indicates the sequence number immediately following the last
// sequence number of this block.
End seqnum.Value
}
// Less reports whether r starts before b, i.e. r.Start < b.Start.
//
// b must hold a SACKBlock value; the unchecked type assertion panics for any
// other btree.Item.
func (r SACKBlock) Less(b btree.Item) bool {
	other := b.(SACKBlock)
	return r.Start.LessThan(other.Start)
}
// Contains reports whether the range covered by b lies entirely within r:
// r.Start <= b.Start and b.End <= r.End.
func (r SACKBlock) Contains(b SACKBlock) bool {
	if !r.Start.LessThanEq(b.Start) {
		return false
	}
	return b.End.LessThanEq(r.End)
}
// TCPOptions are used to parse and cache the TCP segment options for a non
// syn/syn-ack segment. ParseTCPOptions produces values of this type.
//
// +stateify savable
type TCPOptions struct {
// TS is true if the TimeStamp option is enabled.
TS bool
// TSVal is the value in the TSVal field of the segment.
TSVal uint32
// TSEcr is the value in the TSEcr field of the segment.
TSEcr uint32
// SACKBlocks are the SACK blocks specified in the segment. It is nil when
// the parsed options contained no SACK option.
SACKBlocks []SACKBlock
}
// TCP represents a TCP header stored in a byte array. The accessor methods
// index directly into the slice and do not check its length, so callers must
// supply a buffer large enough for the fields they touch.
type TCP []byte
const (
// TCPMinimumSize is the minimum size of a valid TCP packet.
TCPMinimumSize = 20
// TCPOptionsMaximumSize is the maximum size of TCP options.
TCPOptionsMaximumSize = 40
// TCPHeaderMaximumSize is the maximum header size of a TCP packet.
TCPHeaderMaximumSize = TCPMinimumSize + TCPOptionsMaximumSize
// TCPTotalHeaderMaximumSize is the maximum size of headers from all layers in
// a TCP packet. It is analogous to MAX_TCP_HEADER in Linux.
//
// TODO(b/319936470): Investigate why this needs to be at least 140 bytes. In
// Linux this value is at least 160, but in theory we should be able to use
// 138. In practice anything less than 140 starts to break GSO on gVNIC
// hardware.
TCPTotalHeaderMaximumSize = 160
// TCPProtocolNumber is TCP's transport protocol number.
TCPProtocolNumber tcpip.TransportProtocolNumber = 6
// TCPMinimumMSS is the minimum acceptable value for MSS. This is the
// same as the value TCP_MIN_MSS defined in net/tcp.h.
TCPMinimumMSS = IPv4MaximumHeaderSize + TCPHeaderMaximumSize + MinIPFragmentPayloadSize - IPv4MinimumSize - TCPMinimumSize
// TCPMinimumSendMSS is the minimum value for MSS in a sender. This is the
// same as the value TCP_MIN_SND_MSS in net/tcp.h.
TCPMinimumSendMSS = TCPOptionsMaximumSize + MinIPFragmentPayloadSize
// TCPMaximumMSS is the maximum acceptable value for MSS.
TCPMaximumMSS = 0xffff
// TCPDefaultMSS is the MSS value that should be used if an MSS option
// is not received from the peer. It's also the value returned by
// TCP_MAXSEG option for a socket in an unconnected state.
//
// Per RFC 1122, page 85: "If an MSS option is not received at
// connection setup, TCP MUST assume a default send MSS of 536."
TCPDefaultMSS = 536
)
// SourcePort decodes the big-endian 16-bit "source port" field of the TCP
// header.
func (b TCP) SourcePort() uint16 {
	field := b[TCPSrcPortOffset:]
	return binary.BigEndian.Uint16(field)
}
// DestinationPort decodes the big-endian 16-bit "destination port" field of
// the TCP header.
func (b TCP) DestinationPort() uint16 {
	field := b[TCPDstPortOffset:]
	return binary.BigEndian.Uint16(field)
}
// SequenceNumber decodes the big-endian 32-bit "sequence number" field of the
// TCP header.
func (b TCP) SequenceNumber() uint32 {
	field := b[TCPSeqNumOffset:]
	return binary.BigEndian.Uint32(field)
}
// AckNumber decodes the big-endian 32-bit "ack number" field of the TCP
// header.
func (b TCP) AckNumber() uint32 {
	field := b[TCPAckNumOffset:]
	return binary.BigEndian.Uint32(field)
}
// DataOffset returns the length of the TCP header in bytes. The header stores
// the length as a count of 4-byte words in the upper four bits of the data
// offset byte; this method converts that count to bytes.
func (b TCP) DataOffset() uint8 {
	words := b[TCPDataOffset] >> 4
	return words * 4
}
// Payload returns everything in the buffer after the TCP header, where the
// header length is taken from the data offset field.
func (b TCP) Payload() []byte {
	hdrLen := b.DataOffset()
	return b[hdrLen:]
}
// Flags returns the flags byte of the TCP header as a TCPFlags value.
func (b TCP) Flags() TCPFlags {
	raw := b[TCPFlagsOffset]
	return TCPFlags(raw)
}
// WindowSize decodes the big-endian 16-bit "window size" field of the TCP
// header.
func (b TCP) WindowSize() uint16 {
	field := b[TCPWinSizeOffset:]
	return binary.BigEndian.Uint16(field)
}
// Checksum decodes the big-endian 16-bit "checksum" field of the TCP header.
func (b TCP) Checksum() uint16 {
	field := b[TCPChecksumOffset:]
	return binary.BigEndian.Uint16(field)
}
// UrgentPointer decodes the big-endian 16-bit "urgent pointer" field of the
// TCP header.
func (b TCP) UrgentPointer() uint16 {
	field := b[TCPUrgentPtrOffset:]
	return binary.BigEndian.Uint16(field)
}
// SetSourcePort stores port, big-endian, in the "source port" field of the TCP
// header. The checksum is not adjusted.
func (b TCP) SetSourcePort(port uint16) {
	field := b[TCPSrcPortOffset:]
	binary.BigEndian.PutUint16(field, port)
}
// SetDestinationPort stores port, big-endian, in the "destination port" field
// of the TCP header. The checksum is not adjusted.
func (b TCP) SetDestinationPort(port uint16) {
	field := b[TCPDstPortOffset:]
	binary.BigEndian.PutUint16(field, port)
}
// SetChecksum writes xsum into the checksum field of the TCP header using
// checksum.Put.
func (b TCP) SetChecksum(xsum uint16) {
	field := b[TCPChecksumOffset:]
	checksum.Put(field, xsum)
}
// SetDataOffset sets the data offset field of the TCP header. headerLen is the
// TCP header length in bytes; it is converted to a count of 4-byte words and
// stored in the upper four bits of the data offset byte. The lower four bits
// of that byte are overwritten with zero.
func (b TCP) SetDataOffset(headerLen uint8) {
	words := headerLen / 4
	b[TCPDataOffset] = words << 4
}
// SetSequenceNumber stores seqNum, big-endian, in the sequence number field of
// the TCP header.
func (b TCP) SetSequenceNumber(seqNum uint32) {
	field := b[TCPSeqNumOffset:]
	binary.BigEndian.PutUint32(field, seqNum)
}
// SetAckNumber stores ackNum, big-endian, in the ack number field of the TCP
// header.
func (b TCP) SetAckNumber(ackNum uint32) {
	field := b[TCPAckNumOffset:]
	binary.BigEndian.PutUint32(field, ackNum)
}
// SetFlags sets the flags field of the TCP header.
//
// flags is written verbatim into the single flags byte. Note that this setter
// takes a raw uint8, whereas the Flags getter returns a TCPFlags value.
func (b TCP) SetFlags(flags uint8) {
b[TCPFlagsOffset] = flags
}
// SetWindowSize stores rcvwnd, big-endian, in the window size field of the TCP
// header.
func (b TCP) SetWindowSize(rcvwnd uint16) {
	field := b[TCPWinSizeOffset:]
	binary.BigEndian.PutUint16(field, rcvwnd)
}
// SetUrgentPointer stores urgentPointer, big-endian, in the urgent pointer
// field of the TCP header.
func (b TCP) SetUrgentPointer(urgentPointer uint16) {
	field := b[TCPUrgentPtrOffset:]
	binary.BigEndian.PutUint16(field, urgentPointer)
}
// CalculateChecksum folds the TCP header bytes (the first DataOffset() bytes
// of b) into partialChecksum and returns the result.
//
// partialChecksum is expected to already cover the network-layer
// pseudo-header and the segment data.
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
	hdr := b[:b.DataOffset()]
	return checksum.Checksum(hdr, partialChecksum)
}
// IsChecksumValid returns true iff the TCP header's checksum is valid.
//
// The pseudo-header checksum is built from src, dst and the segment length
// (header length plus payloadLength), combined with payloadChecksum and then
// with the header bytes. The segment is valid when the total is 0xffff.
func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
	segmentLen := uint16(b.DataOffset()) + payloadLength
	pseudo := PseudoHeaderChecksum(TCPProtocolNumber, src, dst, segmentLen)
	withPayload := checksum.Combine(pseudo, payloadChecksum)
	return b.CalculateChecksum(withPayload) == 0xffff
}
// Options returns the raw, unparsed option bytes of the segment: the part of
// the header between the fixed TCPMinimumSize bytes and DataOffset(). The
// result aliases b.
func (b TCP) Options() []byte {
	end := b.DataOffset()
	return b[TCPMinimumSize:end]
}
// ParsedOptions parses the segment's option bytes with ParseTCPOptions and
// returns the resulting TCPOptions. NOTE: nothing is cached on the header, so
// every call reparses the options; avoid calling it repeatedly.
func (b TCP) ParsedOptions() TCPOptions {
	raw := b.Options()
	return ParseTCPOptions(raw)
}
// encodeSubset writes the four fields that typically change between similar
// segments: sequence number, ack number, flags and window size. All other
// header bytes are left untouched.
func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) {
	b.SetSequenceNumber(seq)
	b.SetAckNumber(ack)
	b.SetFlags(uint8(flags))
	b.SetWindowSize(rcvwnd)
}
// Encode writes every field of t into the TCP header held in b.
func (b TCP) Encode(t *TCPFields) {
	// Same four fields, in the same order, that encodeSubset writes.
	b.SetSequenceNumber(t.SeqNum)
	b.SetAckNumber(t.AckNum)
	b.SetFlags(uint8(t.Flags))
	b.SetWindowSize(t.WindowSize)

	// Remaining fields.
	b.SetSourcePort(t.SrcPort)
	b.SetDestinationPort(t.DstPort)
	b.SetDataOffset(t.DataOffset)
	b.SetChecksum(t.Checksum)
	b.SetUrgentPointer(t.UrgentPointer)
}
// EncodePartial rewrites the sequence number, ack number, flags and window
// size of an existing header and stores a fresh checksum. It is useful when
// many similar segments are produced from one template header.
//
// partialChecksum is the checksum over everything that does not change
// between segments; length and the fields passed in are folded into it here,
// and the complement of the result is written to the checksum field.
func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) {
	// The flags byte sits at an odd offset inside the header, so checksumming
	// it in place would weight it incorrectly. Instead, lay the total length
	// and the flags out as two aligned 16-bit words and checksum those.
	var scratch [4]byte
	binary.BigEndian.PutUint16(scratch[:2], length)
	binary.BigEndian.PutUint16(scratch[2:], uint16(flags))
	sum := checksum.Checksum(scratch[:], partialChecksum)

	// Write the caller-supplied fields into the header.
	b.encodeSubset(seqnum, acknum, flags, rcvwnd)

	// Fold in the freshly written sequence/ack numbers (8 bytes) and the
	// window size (2 bytes).
	seqAck := b[TCPSeqNumOffset : TCPSeqNumOffset+8]
	sum = checksum.Checksum(seqAck, sum)
	window := b[TCPWinSizeOffset : TCPWinSizeOffset+2]
	sum = checksum.Checksum(window, sum)

	b.SetChecksum(^sum)
}
// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
//
// It replaces the source port with new and incrementally adjusts the stored
// checksum for the change instead of recomputing it.
func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
	prev := b.SourcePort()
	b.SetSourcePort(new)
	updated := checksumUpdate2ByteAlignedUint16(^b.Checksum(), prev, new)
	b.SetChecksum(^updated)
}
// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
//
// It replaces the destination port with new and incrementally adjusts the
// stored checksum for the change instead of recomputing it.
func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
	prev := b.DestinationPort()
	b.SetDestinationPort(new)
	updated := checksumUpdate2ByteAlignedUint16(^b.Checksum(), prev, new)
	b.SetChecksum(^updated)
}
// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
//
// It adjusts the stored checksum for a pseudo-header address change from old
// to new. When fullChecksum is true the stored value is complemented before
// the adjustment and complemented again afterwards; otherwise the stored value
// is adjusted as is.
func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
	if !fullChecksum {
		b.SetChecksum(checksumUpdate2ByteAlignedAddress(b.Checksum(), old, new))
		return
	}
	b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), old, new))
}
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP header.
// isAck indicates a SYN-ACK, in which case the timestamp echo reply is also
// recorded.
//
// Parsing stops at the first malformed option or at an end-of-list option;
// whatever was parsed up to that point is returned.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
	parsed := TCPSynOptions{
		// Per RFC 1122, page 85: "If an MSS option is not received at
		// connection setup, TCP MUST assume a default send MSS of 536."
		MSS: TCPDefaultMSS,
		// -1 signals that the peer sent no window scale option, which
		// means window scaling cannot be used on the receive end
		// either.
		WS: -1,
	}

	end := len(opts)
	pos := 0
	for pos < end {
		switch opts[pos] {
		case TCPOptionEOL:
			// End of option list: nothing further to parse.
			return parsed

		case TCPOptionNOP:
			pos++

		case TCPOptionMSS:
			if pos+4 > end || opts[pos+1] != 4 {
				return parsed
			}
			mss := binary.BigEndian.Uint16(opts[pos+2:])
			if mss == 0 {
				return parsed
			}
			// Never accept less than the minimum send MSS.
			if mss < TCPMinimumSendMSS {
				parsed.MSS = TCPMinimumSendMSS
			} else {
				parsed.MSS = mss
			}
			pos += 4

		case TCPOptionWS:
			if pos+3 > end || opts[pos+1] != 3 {
				return parsed
			}
			ws := int(opts[pos+2])
			if ws > MaxWndScale {
				ws = MaxWndScale
			}
			parsed.WS = ws
			pos += 3

		case TCPOptionTS:
			if pos+10 > end || opts[pos+1] != 10 {
				return parsed
			}
			parsed.TSVal = binary.BigEndian.Uint32(opts[pos+2:])
			if isAck {
				// Only a SYN-ACK carries a meaningful Timestamp Echo
				// Reply.
				parsed.TSEcr = binary.BigEndian.Uint32(opts[pos+6:])
			}
			parsed.TS = true
			pos += 10

		case TCPOptionSACKPermitted:
			if pos+2 > end || opts[pos+1] != 2 {
				return parsed
			}
			parsed.SACKPermitted = true
			pos += 2

		default:
			// Unknown option: skip over it using its length byte.
			if pos+2 > end {
				return parsed
			}
			optLen := int(opts[pos+1])
			// Stop if the length is too small to be valid or runs
			// past the end of the options.
			if optLen < 2 || pos+optLen > end {
				return parsed
			}
			pos += optLen
		}
	}
	return parsed
}
// ParseTCPOptions extracts the known options (timestamp and SACK) from the
// provided byte slice and returns them in a TCPOptions structure.
//
// Parsing stops at the first malformed option or at an end-of-list option;
// whatever was parsed up to that point is returned.
func ParseTCPOptions(b []byte) TCPOptions {
	var parsed TCPOptions

	end := len(b)
	pos := 0
	for pos < end {
		switch b[pos] {
		case TCPOptionEOL:
			// End of option list: nothing further to parse.
			return parsed

		case TCPOptionNOP:
			pos++

		case TCPOptionTS:
			if pos+10 > end || b[pos+1] != 10 {
				return parsed
			}
			parsed.TS = true
			parsed.TSVal = binary.BigEndian.Uint32(b[pos+2:])
			parsed.TSEcr = binary.BigEndian.Uint32(b[pos+6:])
			pos += 10

		case TCPOptionSACK:
			if pos+2 > end {
				// Malformed SACK block, just return and stop parsing.
				return parsed
			}
			optLen := int(b[pos+1])
			// The option body must fit and consist of whole 8-byte
			// blocks. Lengths 0 and 1 also fail the modulo test, since
			// Go's % keeps the sign of the dividend.
			if pos+optLen > end || (optLen-2)%8 != 0 {
				// Malformed SACK block, just return and stop parsing.
				return parsed
			}
			blocks := []SACKBlock{}
			for raw := b[pos+2 : pos+optLen]; len(raw) > 0; raw = raw[8:] {
				blocks = append(blocks, SACKBlock{
					Start: seqnum.Value(binary.BigEndian.Uint32(raw)),
					End:   seqnum.Value(binary.BigEndian.Uint32(raw[4:])),
				})
			}
			parsed.SACKBlocks = blocks
			pos += optLen

		default:
			// Unknown option: skip over it using its length byte.
			if pos+2 > end {
				return parsed
			}
			optLen := int(b[pos+1])
			// Stop if the length is too small to be valid or runs
			// past the end of the options.
			if optLen < 2 || pos+optLen > end {
				return parsed
			}
			pos += optLen
		}
	}
	return parsed
}
// EncodeMSSOption writes an MSS option carrying the low 16 bits of mss into b
// and returns the number of bytes written. If b is shorter than
// TCPOptionMSSLength nothing is written and 0 is returned.
func EncodeMSSOption(mss uint32, b []byte) int {
	if len(b) < TCPOptionMSSLength {
		return 0
	}
	b[0] = TCPOptionMSS
	b[1] = TCPOptionMSSLength
	// MSS value, most significant byte first.
	b[2] = byte(mss >> 8)
	b[3] = byte(mss)
	return TCPOptionMSSLength
}
// EncodeWSOption writes a window scale option carrying ws (truncated to one
// byte) into b and returns the number of bytes written. If b is shorter than
// TCPOptionWSLength nothing is written and 0 is returned.
func EncodeWSOption(ws int, b []byte) int {
	if len(b) < TCPOptionWSLength {
		return 0
	}
	b[0] = TCPOptionWS
	b[1] = TCPOptionWSLength
	b[2] = uint8(ws)
	return TCPOptionWSLength
}
// EncodeTSOption writes a timestamp option carrying tsVal and tsEcr into b and
// returns the number of bytes written. If b is shorter than TCPOptionTSLength
// nothing is written and 0 is returned.
func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
	if len(b) < TCPOptionTSLength {
		return 0
	}
	b[0] = TCPOptionTS
	b[1] = TCPOptionTSLength
	binary.BigEndian.PutUint32(b[2:], tsVal)
	binary.BigEndian.PutUint32(b[6:], tsEcr)
	return TCPOptionTSLength
}
// EncodeSACKPermittedOption writes a SACK-permitted option into b and returns
// the number of bytes written. If b is shorter than
// TCPOptionSackPermittedLength nothing is written and 0 is returned.
func EncodeSACKPermittedOption(b []byte) int {
	if len(b) < TCPOptionSackPermittedLength {
		return 0
	}
	b[0] = TCPOptionSACKPermitted
	b[1] = TCPOptionSackPermittedLength
	return TCPOptionSackPermittedLength
}
// EncodeSACKBlocks writes the given SACK blocks into b as a single TCP SACK
// option and returns the number of bytes written.
//
// At most TCPMaxSACKBlocks blocks are encoded, and fewer if b cannot hold them
// (each block needs 8 bytes after the 2-byte option header). It returns 0
// without writing anything when there are no blocks or no block fits.
func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
	if len(sackBlocks) == 0 {
		return 0
	}

	count := len(sackBlocks)
	if count > TCPMaxSACKBlocks {
		count = TCPMaxSACKBlocks
	}
	if fits := (len(b) - 2) / 8; fits < count {
		count = fits
	}
	if count == 0 {
		// There is not enough space in the provided buffer to add
		// any SACK blocks.
		return 0
	}

	b[0] = TCPOptionSACK
	b[1] = byte(count*8 + 2)
	out := b[2:]
	for _, blk := range sackBlocks[:count] {
		binary.BigEndian.PutUint32(out, uint32(blk.Start))
		binary.BigEndian.PutUint32(out[4:], uint32(blk.End))
		out = out[8:]
	}
	return int(b[1])
}
// EncodeNOP writes a single NOP option at the start of b and returns 1, or
// returns 0 without writing when b is empty.
func EncodeNOP(b []byte) int {
	if len(b) > 0 {
		b[0] = TCPOptionNOP
		return 1
	}
	return 0
}
// AddTCPOptionPadding pads the option buffer with TCPOptionNOP bytes so that
// the options end on a 4-byte boundary. Padding starts at offset; the number
// of padding bytes written (0 to 3) is returned. The options slice must have
// room for the padding bytes.
func AddTCPOptionPadding(options []byte, offset int) int {
	// Distance from offset up to the next multiple of four.
	pad := -offset & 3
	for k := 0; k < pad; k++ {
		options[offset+k] = TCPOptionNOP
	}
	return pad
}
// Acceptable checks if a segment that starts at segSeq and has length segLen is
// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
// before rcvAcc, according to the table on page 26 and 69 of RFC 793.
func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
	switch {
	case rcvNxt == rcvAcc:
		// Zero-sized window: only an empty segment exactly at rcvNxt is
		// acceptable.
		return segLen == 0 && segSeq == rcvNxt

	case segLen == 0:
		// The window's upper bound is extended by 1 because that is
		// Linux's behavior despite the RFC.
		return segSeq.InRange(rcvNxt, rcvAcc.Add(1))

	default:
		// Page 70 of RFC 793 allows segments that become acceptable once
		// the payload is trimmed, so any payload overlapping the receive
		// window is accepted. The RFC would use segSeq < rcvAcc; Linux
		// uses segSeq <= rcvAcc, and that behavior is kept here.
		return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc)
	}
}
// TCPValid reports whether hdr is a valid TCP header. It checks whether:
//   - The data offset is too small.
//   - The data offset is too large.
//   - The checksum is invalid.
//
// ok is false when the data offset is below TCPMinimumSize or beyond len(hdr);
// csum and csumValid are then zero values. Otherwise ok is true. With
// skipChecksumValidation set, csumValid is true, csum is 0 and payloadChecksum
// is never called; without it, csum is the header's stored checksum and
// csumValid is the result of verifying it.
//
// TCPValid corresponds to net/netfilter/nf_conntrack_proto_tcp.c:tcp_error.
func TCPValid(hdr TCP, payloadChecksum func() uint16, payloadSize uint16, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (csum uint16, csumValid, ok bool) {
	hdrLen := int(hdr.DataOffset())
	if hdrLen < TCPMinimumSize || hdrLen > len(hdr) {
		return 0, false, false
	}
	if skipChecksumValidation {
		return 0, true, true
	}
	stored := hdr.Checksum()
	valid := hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum(), payloadSize)
	return stored, valid, true
}

View File

@@ -0,0 +1,195 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"math"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// Byte offsets of the fields within the UDP header, used by the UDP accessor
// methods below.
const (
udpSrcPort = 0
udpDstPort = 2
udpLength = 4
udpChecksum = 6
)
const (
// UDPMaximumPacketSize is the largest possible UDP packet. It has the same
// value (0xffff) as UDPMaximumSize.
UDPMaximumPacketSize = 0xffff
)
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
//
// UDP.Encode writes every field of this struct into the header buffer.
type UDPFields struct {
// SrcPort is the "source port" field of a UDP packet.
SrcPort uint16
// DstPort is the "destination port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
Length uint16
// Checksum is the "checksum" field of a UDP packet.
Checksum uint16
}
// UDP represents a UDP header stored in a byte array. The accessor methods
// index directly into the slice and do not check its length, so callers must
// supply a buffer large enough for the fields they touch.
type UDP []byte
const (
// UDPMinimumSize is the minimum size of a valid UDP packet. It is also the
// size of the UDP header: Payload returns everything after this many bytes.
UDPMinimumSize = 8
// UDPMaximumSize is the maximum size of a valid UDP packet. The length field
// in the UDP header is 16 bits as per RFC 768.
UDPMaximumSize = math.MaxUint16
// UDPProtocolNumber is UDP's transport protocol number.
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
// SourcePort decodes the big-endian 16-bit "source port" field of the UDP
// header.
func (b UDP) SourcePort() uint16 {
	field := b[udpSrcPort:]
	return binary.BigEndian.Uint16(field)
}
// DestinationPort decodes the big-endian 16-bit "destination port" field of
// the UDP header.
func (b UDP) DestinationPort() uint16 {
	field := b[udpDstPort:]
	return binary.BigEndian.Uint16(field)
}
// Length decodes the big-endian 16-bit "length" field of the UDP header.
func (b UDP) Length() uint16 {
	field := b[udpLength:]
	return binary.BigEndian.Uint16(field)
}
// Payload returns everything in the buffer after the fixed UDPMinimumSize-byte
// header. The header's length field is not consulted.
func (b UDP) Payload() []byte {
	data := b[UDPMinimumSize:]
	return data
}
// Checksum decodes the big-endian 16-bit "checksum" field of the UDP header.
func (b UDP) Checksum() uint16 {
	field := b[udpChecksum:]
	return binary.BigEndian.Uint16(field)
}
// SetSourcePort stores port, big-endian, in the "source port" field of the UDP
// header. The checksum is not adjusted.
func (b UDP) SetSourcePort(port uint16) {
	field := b[udpSrcPort:]
	binary.BigEndian.PutUint16(field, port)
}
// SetDestinationPort stores port, big-endian, in the "destination port" field
// of the UDP header. The checksum is not adjusted.
func (b UDP) SetDestinationPort(port uint16) {
	field := b[udpDstPort:]
	binary.BigEndian.PutUint16(field, port)
}
// SetChecksum writes xsum into the "checksum" field of the UDP header using
// checksum.Put.
func (b UDP) SetChecksum(xsum uint16) {
	field := b[udpChecksum:]
	checksum.Put(field, xsum)
}
// SetLength stores length, big-endian, in the "length" field of the UDP
// header.
func (b UDP) SetLength(length uint16) {
	field := b[udpLength:]
	binary.BigEndian.PutUint16(field, length)
}
// CalculateChecksum folds the UDPMinimumSize header bytes into partialChecksum
// and returns the result. partialChecksum is expected to already cover the
// network-layer pseudo-header and the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
	hdr := b[:UDPMinimumSize]
	return checksum.Checksum(hdr, partialChecksum)
}
// IsChecksumValid returns true iff the UDP header's checksum is valid.
//
// The pseudo-header checksum is built from the two addresses and the header's
// length field, combined with payloadChecksum and then with the header bytes.
// The datagram is valid when the total is 0xffff.
func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
	// The addresses are handed over in (dst, src) order, exactly as the
	// original code does.
	pseudo := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length())
	withPayload := checksum.Combine(pseudo, payloadChecksum)
	return b.CalculateChecksum(withPayload) == 0xffff
}
// Encode writes every field of u into the UDP header held in b: source port,
// destination port, length and checksum, in that order.
func (b UDP) Encode(u *UDPFields) {
	srcPort, dstPort := u.SrcPort, u.DstPort
	b.SetSourcePort(srcPort)
	b.SetDestinationPort(dstPort)

	length, xsum := u.Length, u.Checksum
	b.SetLength(length)
	b.SetChecksum(xsum)
}
// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
//
// It replaces the source port with new and incrementally adjusts the stored
// checksum for the change instead of recomputing it.
func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
	prev := b.SourcePort()
	b.SetSourcePort(new)
	updated := checksumUpdate2ByteAlignedUint16(^b.Checksum(), prev, new)
	b.SetChecksum(^updated)
}
// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
//
// It replaces the destination port with new and incrementally adjusts the
// stored checksum for the change instead of recomputing it.
func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
	prev := b.DestinationPort()
	b.SetDestinationPort(new)
	updated := checksumUpdate2ByteAlignedUint16(^b.Checksum(), prev, new)
	b.SetChecksum(^updated)
}
// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
//
// It adjusts the stored checksum for a pseudo-header address change from old
// to new. When fullChecksum is true the stored value is complemented before
// the adjustment and complemented again afterwards; otherwise the stored value
// is adjusted as is.
func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
	if !fullChecksum {
		b.SetChecksum(checksumUpdate2ByteAlignedAddress(b.Checksum(), old, new))
		return
	}
	b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), old, new))
}
// UDPValid reports whether hdr is a valid UDP header. It checks whether:
//   - The length field is too small.
//   - The length field is too large.
//   - The checksum is invalid.
//
// lengthValid is false (and csumValid with it) when the header's length field
// is below UDPMinimumSize or exceeds payloadSize+UDPMinimumSize. When the
// length is fine, csumValid is true if skipChecksumValidation is set, if the
// packet is IPv4 with a zero checksum, or if the checksum verifies.
// payloadChecksum is only called when the checksum actually has to be
// verified.
//
// UDPValid corresponds to net/netfilter/nf_conntrack_proto_udp.c:udp_error.
func UDPValid(hdr UDP, payloadChecksum func() uint16, payloadSize uint16, netProto tcpip.NetworkProtocolNumber, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (lengthValid, csumValid bool) {
	// Compare as int. In uint16 arithmetic payloadSize+UDPMinimumSize wraps
	// around once payloadSize exceeds 65527, which would make the upper bound
	// tiny and reject headers whose length field is perfectly valid.
	length := int(hdr.Length())
	if length < UDPMinimumSize || length > int(payloadSize)+UDPMinimumSize {
		return false, false
	}

	if skipChecksumValidation {
		return true, true
	}

	// On IPv4, UDP checksum is optional, and a zero value means the transmitter
	// omitted the checksum generation, as per RFC 768:
	//
	//   An all zero transmitted checksum value means that the transmitter
	//   generated no checksum (for debugging or for higher level protocols that
	//   don't care).
	//
	// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
	//
	//   Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
	//   checksum is not optional.
	if netProto == IPv4ProtocolNumber && hdr.Checksum() == 0 {
		return true, true
	}

	return true, hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum())
}

View File

@@ -0,0 +1,94 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import "encoding/binary"
// These constants are declared in linux/virtio_net.h. The names keep the C
// spelling of that header rather than Go's MixedCaps.
const (
_VIRTIO_NET_HDR_F_NEEDS_CSUM = 1
_VIRTIO_NET_HDR_GSO_NONE = 0
_VIRTIO_NET_HDR_GSO_TCPV4 = 1
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
const (
// VirtioNetHeaderSize is the size of VirtioNetHeader in bytes: two 1-byte
// fields followed by four 2-byte fields, per the offsets declared below.
VirtioNetHeaderSize = 10
)
// Offsets for fields in the virtio net header. flags and gsoType are single
// bytes; the remaining fields are read and written as 16-bit values by the
// VirtioNetHeader methods.
const (
flags = 0
gsoType = 1
hdrLen = 2
gsoSize = 4
csumStart = 6
csumOffset = 8
)
// VirtioNetHeaderFields is the Go equivalent of the struct declared in
// linux/virtio_net.h.
//
// VirtioNetHeader.Encode writes every field of this struct into the header
// buffer; each field has a same-named getter on VirtioNetHeader.
type VirtioNetHeaderFields struct {
Flags uint8
GSOType uint8
HdrLen uint16
GSOSize uint16
CSumStart uint16
CSumOffset uint16
}
// VirtioNetHeader represents a virtio net header stored in a byte array. The
// accessor methods index directly into the slice and do not check its length.
type VirtioNetHeader []byte
// Flags returns the single-byte "flags" field of the virtio net header.
func (v VirtioNetHeader) Flags() uint8 {
	return v[flags]
}
// GSOType returns the single-byte "gsoType" field of the virtio net header.
func (v VirtioNetHeader) GSOType() uint8 {
	return v[gsoType]
}
// HdrLen returns the 16-bit "hdrLen" field of the virtio net header, decoded
// as big-endian.
func (v VirtioNetHeader) HdrLen() uint16 {
	field := v[hdrLen:]
	return binary.BigEndian.Uint16(field)
}
// GSOSize returns the 16-bit "gsoSize" field of the virtio net header, decoded
// as big-endian.
func (v VirtioNetHeader) GSOSize() uint16 {
	field := v[gsoSize:]
	return binary.BigEndian.Uint16(field)
}
// CSumStart returns the 16-bit "csumStart" field of the virtio net header,
// decoded as big-endian.
func (v VirtioNetHeader) CSumStart() uint16 {
	field := v[csumStart:]
	return binary.BigEndian.Uint16(field)
}
// CSumOffset returns the 16-bit "csumOffset" field of the virtio net header,
// decoded as big-endian.
func (v VirtioNetHeader) CSumOffset() uint16 {
	field := v[csumOffset:]
	return binary.BigEndian.Uint16(field)
}
// Encode writes every field of f into the virtio net header held in v. The two
// single-byte fields are stored directly; the 16-bit fields are stored
// big-endian.
func (v VirtioNetHeader) Encode(f *VirtioNetHeaderFields) {
	v[flags] = f.Flags
	v[gsoType] = f.GSOType

	be := binary.BigEndian
	be.PutUint16(v[hdrLen:], f.HdrLen)
	be.PutUint16(v[gsoSize:], f.GSOSize)
	be.PutUint16(v[csumStart:], f.CSumStart)
	be.PutUint16(v[csumOffset:], f.CSumOffset)
}

View File

@@ -0,0 +1,48 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tcp contains internal type definitions that are not expected to be
// used by anyone else outside pkg/tcpip.
package tcp
import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
// TSOffset is an offset applied to the value of the TCP Timestamp option's
// TSVal field. Create one with NewTSOffset.
//
// +stateify savable
type TSOffset struct {
// milliseconds is the offset that TSVal adds to the current time, itself
// expressed in milliseconds. The addition is uint32 arithmetic and wraps.
milliseconds uint32
}
// NewTSOffset returns a TSOffset that shifts timestamps by the given number of
// milliseconds.
func NewTSOffset(milliseconds uint32) TSOffset {
	var off TSOffset
	off.milliseconds = milliseconds
	return off
}
// TSVal returns the timestamp for now: the milliseconds elapsed between the
// zero MonotonicTime and now, truncated to uint32, plus the offset. The
// addition wraps around on overflow.
func (offset TSOffset) TSVal(now tcpip.MonotonicTime) uint32 {
	sinceZero := now.Sub(tcpip.MonotonicTime{})
	return uint32(sinceZero.Milliseconds()) + offset.milliseconds
}
// Elapsed returns how much time has passed between the echoed-back timestamp
// tsEcr and now. The difference is taken in uint32 millisecond arithmetic (so
// it wraps) before being converted to a time.Duration.
func (offset TSOffset) Elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
	deltaMillis := offset.TSVal(now) - tsEcr
	return time.Duration(deltaMillis) * time.Millisecond
}

View File

@@ -0,0 +1,38 @@
// automatically generated by stateify.
package tcp
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
// StateTypeName returns the name under which TSOffset is known to the state
// package.
func (offset *TSOffset) StateTypeName() string {
return "pkg/tcpip/internal/tcp.TSOffset"
}
// StateFields returns the names of the saved fields. The position of a name in
// the slice matches the index used for that field in StateSave and StateLoad.
func (offset *TSOffset) StateFields() []string {
return []string{
"milliseconds",
}
}
// beforeSave is called by StateSave before any field is saved. It is a no-op
// for TSOffset.
func (offset *TSOffset) beforeSave() {}
// StateSave calls beforeSave and then saves the milliseconds field at index 0
// of stateSinkObject.
//
// +checklocksignore
func (offset *TSOffset) StateSave(stateSinkObject state.Sink) {
offset.beforeSave()
stateSinkObject.Save(0, &offset.milliseconds)
}
// afterLoad is a no-op for TSOffset. Note that StateLoad, as written in this
// file, does not call it.
func (offset *TSOffset) afterLoad(context.Context) {}
// StateLoad restores the milliseconds field from index 0 of
// stateSourceObject. ctx is not used.
//
// +checklocksignore
func (offset *TSOffset) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &offset.milliseconds)
}
// init registers TSOffset with the state package.
func init() {
state.Register((*TSOffset)(nil))
}

View File

@@ -0,0 +1,93 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package hash contains utility functions for hashing.
package hash
import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// hashIV is a random 32-bit value drawn once at package initialization. The
// fragment hash functions in this file pass it to Hash3Words as the initial
// value.
var hashIV = RandN32(1)[0]
// RandN32 returns a slice of n random 32-bit numbers read from the rand
// package. It panics if the random source returns an error.
func RandN32(n int) []uint32 {
	raw := make([]byte, 4*n)
	if _, err := rand.Read(raw); err != nil {
		panic("unable to get random numbers: " + err.Error())
	}

	// Decode each consecutive 4-byte group as one little-endian word.
	words := make([]uint32, 0, n)
	for off := 0; off < len(raw); off += 4 {
		words = append(words, binary.LittleEndian.Uint32(raw[off:off+4]))
	}
	return words
}
// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted
// from linux.
//
// initval seeds the hash: it is offset by a fixed constant and added to each
// of a, b and c before mixing. The mixing steps below are order-dependent;
// each one xors one word with another and then subtracts a rotated copy of
// that other word. The final value of c is the hash.
func Hash3Words(a, b, c, initval uint32) uint32 {
const iv = 0xdeadbeef + (3 << 2)
initval += iv
a += initval
b += initval
c += initval
// Final mixing sequence: seven xor / subtract-rotate steps.
c ^= b
c -= rol32(b, 14)
a ^= c
a -= rol32(c, 11)
b ^= a
b -= rol32(a, 25)
c ^= b
c -= rol32(b, 16)
a ^= c
a -= rol32(c, 4)
b ^= a
b -= rol32(a, 14)
c ^= b
c -= rol32(b, 24)
return c
}
// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791.
//
// The three hashed words are: the packet ID in the upper 16 bits with the
// protocol number in the lower bits, the source address, and the destination
// address (both addresses read as little-endian 32-bit words). hashIV is the
// seed.
func IPv4FragmentHash(h header.IPv4) uint32 {
	idProto := uint32(h.ID())<<16 | uint32(h.Protocol())
	src := h.SourceAddress().As4()
	srcWord := binary.LittleEndian.Uint32(src[:])
	dst := h.DestinationAddress().As4()
	dstWord := binary.LittleEndian.Uint32(dst[:])
	return Hash3Words(idProto, srcWord, dstWord, hashIV)
}
// IPv6FragmentHash computes the hash of the IPv6 fragment from the fragment
// id and the first four bytes of the source and destination addresses (each
// read as a little-endian 32-bit word), seeded with hashIV.
//
// Unlike IPv4, the protocol is not used to compute the hash. RFC 2460
// (sec 4.5) is not very sharp on this aspect. As a reference, Linux also
// ignores the protocol when computing the hash (inet6_hash_frag).
func IPv6FragmentHash(h header.IPv6, id uint32) uint32 {
	src := h.SourceAddress().As16()
	srcWord := binary.LittleEndian.Uint32(src[:4])
	dst := h.DestinationAddress().As16()
	dstWord := binary.LittleEndian.Uint32(dst[:4])
	return Hash3Words(id, srcWord, dstWord, hashIV)
}
// rol32 rotates v left by shift bits. It is a true rotation for shift in
// [0, 31]: the bits shifted out on the left re-enter on the right.
func rol32(v, shift uint32) uint32 {
	high := v << shift
	// (-shift) & 31 is the complementary right-shift amount, and is 0 when
	// shift is 0.
	low := v >> ((-shift) & 31)
	return high | low
}

View File

@@ -0,0 +1,3 @@
// automatically generated by stateify.
package hash

View File

@@ -0,0 +1,374 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
// It is based on RFC 791, RFC 815 and RFC 8200.
package fragmentation
import (
"errors"
"fmt"
"time"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
	// HighFragThreshold is the threshold at which we start trimming old
	// fragmented packets. Linux uses a default value of 4 MB. See
	// net.ipv4.ipfrag_high_thresh for more information.
	HighFragThreshold = 4 << 20 // 4MB
	// LowFragThreshold is the threshold we trim down to once we have started
	// dropping older fragmented packets. It's important that we keep enough
	// room for newer packets to be re-assembled, so this needs to be
	// sufficiently lower than HighFragThreshold. Linux uses a default value of
	// 3 MB. See net.ipv4.ipfrag_low_thresh for more information.
	LowFragThreshold = 3 << 20 // 3MB
	// minBlockSize is the minimum block size for fragments. NewFragmentation
	// raises smaller block sizes to this value.
	minBlockSize = 1
)
var (
	// ErrInvalidArgs indicates to the caller that an invalid argument was
	// provided. Process wraps it when the fragment range or size fails
	// validation.
	ErrInvalidArgs = errors.New("invalid args")
	// ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
	// with another one.
	ErrFragmentOverlap = errors.New("overlapping fragments")
	// ErrFragmentConflict indicates that, during reassembly, some fragments are
	// in conflict with one another.
	ErrFragmentConflict = errors.New("conflicting fragments")
)
// FragmentID is the identifier for a fragment.
//
// It is the key type of the Fragmentation.reassemblers map, so fragments that
// carry equal FragmentIDs are handed to the same reassembler.
//
// +stateify savable
type FragmentID struct {
	// Source is the source address of the fragment.
	Source tcpip.Address
	// Destination is the destination address of the fragment.
	Destination tcpip.Address
	// ID is the identification value of the fragment.
	//
	// This is a uint32 because IPv6 uses a 32-bit identification value.
	ID uint32
	// Protocol is the protocol number for the packet.
	Protocol uint8
}
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
//
// +stateify savable
type Fragmentation struct {
	// mu is held while reassemblers, rList and memSize are accessed. It is
	// also the lock handed to releaseJob.
	mu sync.Mutex `state:"nosave"`
	// highLimit is the memSize value above which Process starts evicting
	// reassemblers.
	highLimit int
	// lowLimit is the memSize value eviction tries to get down to.
	lowLimit int
	// reassemblers holds the in-progress reassemblers keyed by fragment ID. It
	// is set to nil by Release.
	reassemblers map[FragmentID]*reassembler
	// rList links the same reassemblers; new ones are pushed to the front, so
	// the oldest is at the back.
	rList reassemblerList
	// memSize is the running total of memory reported by the reassemblers
	// currently tracked.
	memSize int
	// timeout is the maximum age of a reassembler before it is released as
	// timed out.
	timeout time.Duration
	// blockSize is the fragment block size, in bytes, that fragment offsets
	// and non-final fragment sizes must be a multiple of.
	blockSize uint16
	// clock provides the monotonic time used for reassembler ages and for
	// scheduling releaseJob.
	clock tcpip.Clock
	// releaseJob runs releaseReassemblersLocked to expire old reassemblers.
	releaseJob *tcpip.Job
	// timeoutHandler, if non-nil, is notified when a reassembler is released
	// because it timed out.
	timeoutHandler TimeoutHandler
}
// TimeoutHandler is consulted if a packet reassembly has timed out.
type TimeoutHandler interface {
	// OnReassemblyTimeout will be called with the first fragment (or nil, if the
	// first fragment has not been received) of a packet whose reassembly has
	// timed out.
	//
	// Fragmentation invokes it while releasing the reassembler and drops its
	// own reference to pkt right after the call returns.
	OnReassemblyTimeout(pkt *stack.PacketBuffer)
}
// NewFragmentation creates a new Fragmentation.
//
// blockSize is the fragment block size, in bytes. Values below minBlockSize
// are raised to minBlockSize.
//
// highMemoryLimit is the limit on the memory consumed by the fragments stored
// by Fragmentation (the overhead of internal data structures is not
// accounted). Reassemblers are evicted once this limit is exceeded.
//
// lowMemoryLimit is the level eviction aims for after highMemoryLimit has been
// exceeded. It is capped at highMemoryLimit and then floored at zero.
//
// reassemblingTimeout is the maximum time allowed to reassemble a packet.
// Expired reassemblers are released by a job scheduled on clock.
//
// timeoutHandler, if non-nil, is notified about reassemblies that timed out.
func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation {
	// Cap the low watermark at the high one first, then floor it at zero.
	low := lowMemoryLimit
	if low > highMemoryLimit {
		low = highMemoryLimit
	}
	if low < 0 {
		low = 0
	}

	size := blockSize
	if size < minBlockSize {
		size = minBlockSize
	}

	f := &Fragmentation{}
	f.reassemblers = make(map[FragmentID]*reassembler)
	f.highLimit = highMemoryLimit
	f.lowLimit = low
	f.timeout = reassemblingTimeout
	f.blockSize = size
	f.clock = clock
	f.timeoutHandler = timeoutHandler
	// The job takes f.mu itself before running the release function.
	f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
	return f
}
// Process processes an incoming fragment belonging to an ID and returns a
// complete packet and its protocol number when all the packets belonging to
// that ID have been received.
//
// [first, last] is the range of the fragment bytes.
//
// first must be a multiple of the block size f is configured with. The size
// of the fragment data must be a multiple of the block size, unless there are
// no fragments following this fragment (more set to false).
//
// proto is the protocol number marked in the fragment being processed. It has
// to be given here outside of the FragmentID struct because IPv6 should not use
// the protocol to identify a fragment.
//
// The returned values are, in order: the reassembled packet (nil until
// reassembly completes), the protocol number recorded from the first fragment,
// whether reassembly completed with this fragment, and an error. Argument
// validation failures wrap ErrInvalidArgs; failures reported by the
// reassembler are wrapped as well and cause the reassembler to be discarded.
func (f *Fragmentation) Process(
	id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (
	*stack.PacketBuffer, uint8, bool, error) {
	if first > last {
		return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
	}

	if first%f.blockSize != 0 {
		return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
	}

	// NOTE: this is uint16 arithmetic, so the range [0, 65535] wraps to a
	// fragment size of 0.
	fragmentSize := last - first + 1
	if more && fragmentSize%f.blockSize != 0 {
		return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
	}

	if l := pkt.Data().Size(); l != int(fragmentSize) {
		return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
	}

	f.mu.Lock()
	if f.reassemblers == nil {
		// Release() has already cleared the reassembler map. The lock must be
		// dropped before returning, otherwise every later call that takes f.mu
		// would block forever.
		f.mu.Unlock()
		return nil, 0, false, fmt.Errorf("Release() called before fragmentation processing could finish")
	}

	r, ok := f.reassemblers[id]
	if !ok {
		r = newReassembler(id, f.clock)
		f.reassemblers[id] = r
		wasEmpty := f.rList.Empty()
		f.rList.PushFront(r)
		if wasEmpty {
			// If we have just pushed a first reassembler into an empty list, we
			// should kickstart the release job. The release job will keep
			// rescheduling itself until the list becomes empty.
			f.releaseReassemblersLocked()
		}
	}
	f.mu.Unlock()

	// The reassembler has its own lock, so f.mu is not held while the fragment
	// is being processed.
	resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt)
	if err != nil {
		// We probably got an invalid sequence of fragments. Just
		// discard the reassembler and move on.
		f.mu.Lock()
		f.release(r, false /* timedOut */)
		f.mu.Unlock()
		return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
	}

	f.mu.Lock()
	f.memSize += memConsumed
	if done {
		f.release(r, false /* timedOut */)
	}
	// Evict reassemblers if we are consuming more memory than highLimit until
	// we reach lowLimit.
	if f.memSize > f.highLimit {
		for f.memSize > f.lowLimit {
			// The back of the list holds the oldest reassembler.
			tail := f.rList.Back()
			if tail == nil {
				break
			}
			f.release(tail, false /* timedOut */)
		}
	}
	f.mu.Unlock()
	return resPkt, firstFragmentProto, done, nil
}
// Release drops every reassembler that is still pending, together with the
// packet references it holds, and clears the reassembler map. No timeout
// notification is sent for the dropped reassemblers.
func (f *Fragmentation) Release() {
	f.mu.Lock()
	defer f.mu.Unlock()

	for _, pending := range f.reassemblers {
		f.release(pending, false /* timedOut */)
	}
	// A nil map is what Process checks to detect that Release has run.
	f.reassemblers = nil
}
// release removes r from f's bookkeeping and drops the packet references r
// holds. If timedOut is set and a timeout handler is configured, the handler is
// called with r's first fragment (possibly nil) before that reference is
// dropped.
//
// Callers in this file invoke it with f.mu held.
func (f *Fragmentation) release(r *reassembler, timedOut bool) {
	// A reassembler already marked done has been released before; doing it
	// again would delete it twice.
	if alreadyDone := r.checkDoneOrMark(); alreadyDone {
		return
	}

	delete(f.reassemblers, r.id)
	f.rList.Remove(r)

	if f.memSize -= r.memSize; f.memSize < 0 {
		log.Warningf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize)
		f.memSize = 0
	}

	if timedOut && f.timeoutHandler != nil {
		f.timeoutHandler.OnReassemblyTimeout(r.pkt)
	}

	if firstFragment := r.pkt; firstFragment != nil {
		firstFragment.DecRef()
		r.pkt = nil
	}

	// Every filled hole owns one reference to its fragment packet.
	for i := range r.holes {
		if fragment := r.holes[i].pkt; fragment != nil {
			fragment.DecRef()
		}
	}
	r.holes = nil
}
// releaseReassemblersLocked releases the reassemblers that have already
// expired. If an unexpired reassembler remains, it schedules the release job
// to run again when the oldest one is due.
//
// This function must be called with f.mu locked.
func (f *Fragmentation) releaseReassemblersLocked() {
	now := f.clock.NowMonotonic()

	// The back of the list is the oldest reassembler; an empty list ends the
	// loop.
	for oldest := f.rList.Back(); oldest != nil; oldest = f.rList.Back() {
		age := now.Sub(oldest.createdAt)
		if age < f.timeout {
			// Not expired yet: come back once the remaining time has passed.
			f.releaseJob.Schedule(f.timeout - age)
			return
		}
		// Expired: release it and look at the next-oldest one.
		f.release(oldest, true /* timedOut*/)
	}
}
// PacketFragmenter is the book-keeping struct for packet fragmentation.
type PacketFragmenter struct {
	// NOTE(review): transportHeader is not set by MakePacketFragmenter and not
	// read by the methods in this file; confirm whether it is still needed.
	transportHeader []byte
	// data is the not-yet-consumed fragmentable data (transport header
	// followed by the packet data).
	data buffer.Buffer
	// reserve is the number of header bytes reserved in each built fragment.
	reserve int
	// fragmentPayloadLen is the maximum number of bytes read into one fragment.
	fragmentPayloadLen int
	// fragmentCount is the total number of fragments to build.
	fragmentCount int
	// currentFragment is the number of fragments built so far.
	currentFragment int
	// fragmentOffset is the byte offset of the next fragment to build.
	fragmentOffset int
}
// MakePacketFragmenter prepares the struct needed for packet fragmentation.
//
// pkt is the packet to be fragmented; its transport header and data together
// form the fragmentable data.
//
// fragmentPayloadLen is the maximum number of bytes of fragmentable data a
// fragment can have. It is used as a divisor and therefore must not be zero.
//
// reserve is the number of bytes that should be reserved for the headers in
// each generated fragment.
func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter {
	// As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
	// repeated in each fragment. However we do not currently support any header
	// of that kind yet, so the following computation is valid for both IPv4 and
	// IPv6.
	// TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
	// supported for outbound packets, the fragmentable data should not include
	// these headers.
	var payload buffer.Buffer
	payload.Append(pkt.TransportHeader().View())
	body := pkt.Data().ToBuffer()
	payload.Merge(&body)

	// Number of fragments is the payload size divided by the per-fragment
	// payload length, rounded up.
	total := uint32(payload.Size())
	count := (total + fragmentPayloadLen - 1) / fragmentPayloadLen

	return PacketFragmenter{
		data:               payload,
		reserve:            reserve,
		fragmentPayloadLen: int(fragmentPayloadLen),
		fragmentCount:      int(count),
	}
}
// BuildNextFragment builds the next fragment and returns, in order: a packet
// holding the fragment's payload, the fragment's offset, the number of bytes
// copied into it, and whether more fragments remain to be built.
//
// It panics when called again after it has reported that no fragments are
// left.
//
// Note that the returned packet will not have its network and link headers
// populated, but space for them will be reserved. The transport header will be
// stored in the packet's data.
func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
	if pf.currentFragment >= pf.fragmentCount {
		panic("BuildNextFragment should not be called again after the last fragment was returned")
	}

	opts := stack.PacketBufferOptions{
		ReserveHeaderBytes: pf.reserve,
	}
	fragPkt := stack.NewPacketBuffer(opts)

	// Move up to one fragment's worth of payload out of the remaining data.
	n := fragPkt.Data().ReadFrom(&pf.data, pf.fragmentPayloadLen)

	start := pf.fragmentOffset
	pf.fragmentOffset = start + n
	pf.currentFragment++

	return fragPkt, start, n, pf.currentFragment != pf.fragmentCount
}
// RemainingFragmentCount reports how many fragments have not been built yet,
// i.e. the total fragment count minus the number already built.
func (pf *PacketFragmenter) RemainingFragmentCount() int {
	built := pf.currentFragment
	return pf.fragmentCount - built
}
// Release frees resources owned by the packet fragmenter by releasing the
// buffer that holds the remaining fragmentable data.
func (pf *PacketFragmenter) Release() {
	pf.data.Release()
}

View File

@@ -0,0 +1,246 @@
// automatically generated by stateify.
package fragmentation
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
// StateTypeName returns the fully qualified type name of FragmentID.
func (f *FragmentID) StateTypeName() string {
	return "pkg/tcpip/network/internal/fragmentation.FragmentID"
}

// StateFields lists the saved fields of FragmentID. A field's position in the
// returned slice matches the index used for it in StateSave and StateLoad.
func (f *FragmentID) StateFields() []string {
	return []string{
		"Source",
		"Destination",
		"ID",
		"Protocol",
	}
}

// beforeSave is a no-op hook called at the start of StateSave.
func (f *FragmentID) beforeSave() {}

// StateSave saves every field listed by StateFields to stateSinkObject.
//
// +checklocksignore
func (f *FragmentID) StateSave(stateSinkObject state.Sink) {
	f.beforeSave()
	stateSinkObject.Save(0, &f.Source)
	stateSinkObject.Save(1, &f.Destination)
	stateSinkObject.Save(2, &f.ID)
	stateSinkObject.Save(3, &f.Protocol)
}

// afterLoad is a no-op hook.
func (f *FragmentID) afterLoad(context.Context) {}

// StateLoad loads every field listed by StateFields from stateSourceObject.
//
// +checklocksignore
func (f *FragmentID) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &f.Source)
	stateSourceObject.Load(1, &f.Destination)
	stateSourceObject.Load(2, &f.ID)
	stateSourceObject.Load(3, &f.Protocol)
}
// StateTypeName returns the fully qualified type name of Fragmentation.
func (f *Fragmentation) StateTypeName() string {
	return "pkg/tcpip/network/internal/fragmentation.Fragmentation"
}

// StateFields lists the saved fields of Fragmentation. A field's position in
// the returned slice matches the index used for it in StateSave and StateLoad.
// The mu field is not listed and is not saved.
func (f *Fragmentation) StateFields() []string {
	return []string{
		"highLimit",
		"lowLimit",
		"reassemblers",
		"rList",
		"memSize",
		"timeout",
		"blockSize",
		"clock",
		"releaseJob",
		"timeoutHandler",
	}
}

// beforeSave is a no-op hook called at the start of StateSave.
func (f *Fragmentation) beforeSave() {}

// StateSave saves every field listed by StateFields to stateSinkObject.
//
// +checklocksignore
func (f *Fragmentation) StateSave(stateSinkObject state.Sink) {
	f.beforeSave()
	stateSinkObject.Save(0, &f.highLimit)
	stateSinkObject.Save(1, &f.lowLimit)
	stateSinkObject.Save(2, &f.reassemblers)
	stateSinkObject.Save(3, &f.rList)
	stateSinkObject.Save(4, &f.memSize)
	stateSinkObject.Save(5, &f.timeout)
	stateSinkObject.Save(6, &f.blockSize)
	stateSinkObject.Save(7, &f.clock)
	stateSinkObject.Save(8, &f.releaseJob)
	stateSinkObject.Save(9, &f.timeoutHandler)
}

// afterLoad is a no-op hook.
func (f *Fragmentation) afterLoad(context.Context) {}

// StateLoad loads every field listed by StateFields from stateSourceObject.
//
// +checklocksignore
func (f *Fragmentation) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &f.highLimit)
	stateSourceObject.Load(1, &f.lowLimit)
	stateSourceObject.Load(2, &f.reassemblers)
	stateSourceObject.Load(3, &f.rList)
	stateSourceObject.Load(4, &f.memSize)
	stateSourceObject.Load(5, &f.timeout)
	stateSourceObject.Load(6, &f.blockSize)
	stateSourceObject.Load(7, &f.clock)
	stateSourceObject.Load(8, &f.releaseJob)
	stateSourceObject.Load(9, &f.timeoutHandler)
}
// StateTypeName returns the fully qualified type name of hole.
func (h *hole) StateTypeName() string {
	return "pkg/tcpip/network/internal/fragmentation.hole"
}

// StateFields lists the saved fields of hole. A field's position in the
// returned slice matches the index used for it in StateSave and StateLoad.
func (h *hole) StateFields() []string {
	return []string{
		"first",
		"last",
		"filled",
		"final",
		"pkt",
	}
}

// beforeSave is a no-op hook called at the start of StateSave.
func (h *hole) beforeSave() {}

// StateSave saves every field listed by StateFields to stateSinkObject.
//
// +checklocksignore
func (h *hole) StateSave(stateSinkObject state.Sink) {
	h.beforeSave()
	stateSinkObject.Save(0, &h.first)
	stateSinkObject.Save(1, &h.last)
	stateSinkObject.Save(2, &h.filled)
	stateSinkObject.Save(3, &h.final)
	stateSinkObject.Save(4, &h.pkt)
}

// afterLoad is a no-op hook.
func (h *hole) afterLoad(context.Context) {}

// StateLoad loads every field listed by StateFields from stateSourceObject.
//
// +checklocksignore
func (h *hole) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &h.first)
	stateSourceObject.Load(1, &h.last)
	stateSourceObject.Load(2, &h.filled)
	stateSourceObject.Load(3, &h.final)
	stateSourceObject.Load(4, &h.pkt)
}
// StateTypeName returns the fully qualified type name of reassembler.
func (r *reassembler) StateTypeName() string {
	return "pkg/tcpip/network/internal/fragmentation.reassembler"
}

// StateFields lists the saved fields of reassembler. A field's position in the
// returned slice matches the index used for it in StateSave and StateLoad.
// The mu field is not listed and is not saved.
func (r *reassembler) StateFields() []string {
	return []string{
		"reassemblerEntry",
		"id",
		"memSize",
		"proto",
		"holes",
		"filled",
		"done",
		"createdAt",
		"pkt",
	}
}

// beforeSave is a no-op hook called at the start of StateSave.
func (r *reassembler) beforeSave() {}

// StateSave saves every field listed by StateFields to stateSinkObject.
//
// +checklocksignore
func (r *reassembler) StateSave(stateSinkObject state.Sink) {
	r.beforeSave()
	stateSinkObject.Save(0, &r.reassemblerEntry)
	stateSinkObject.Save(1, &r.id)
	stateSinkObject.Save(2, &r.memSize)
	stateSinkObject.Save(3, &r.proto)
	stateSinkObject.Save(4, &r.holes)
	stateSinkObject.Save(5, &r.filled)
	stateSinkObject.Save(6, &r.done)
	stateSinkObject.Save(7, &r.createdAt)
	stateSinkObject.Save(8, &r.pkt)
}

// afterLoad is a no-op hook.
func (r *reassembler) afterLoad(context.Context) {}

// StateLoad loads every field listed by StateFields from stateSourceObject.
//
// +checklocksignore
func (r *reassembler) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &r.reassemblerEntry)
	stateSourceObject.Load(1, &r.id)
	stateSourceObject.Load(2, &r.memSize)
	stateSourceObject.Load(3, &r.proto)
	stateSourceObject.Load(4, &r.holes)
	stateSourceObject.Load(5, &r.filled)
	stateSourceObject.Load(6, &r.done)
	stateSourceObject.Load(7, &r.createdAt)
	stateSourceObject.Load(8, &r.pkt)
}
// StateTypeName returns the fully qualified type name of reassemblerList.
func (l *reassemblerList) StateTypeName() string {
	return "pkg/tcpip/network/internal/fragmentation.reassemblerList"
}

// StateFields lists the saved fields of reassemblerList. A field's position in
// the returned slice matches the index used for it in StateSave and StateLoad.
func (l *reassemblerList) StateFields() []string {
	return []string{
		"head",
		"tail",
	}
}

// beforeSave is a no-op hook called at the start of StateSave.
func (l *reassemblerList) beforeSave() {}

// StateSave saves every field listed by StateFields to stateSinkObject.
//
// +checklocksignore
func (l *reassemblerList) StateSave(stateSinkObject state.Sink) {
	l.beforeSave()
	stateSinkObject.Save(0, &l.head)
	stateSinkObject.Save(1, &l.tail)
}

// afterLoad is a no-op hook.
func (l *reassemblerList) afterLoad(context.Context) {}

// StateLoad loads every field listed by StateFields from stateSourceObject.
//
// +checklocksignore
func (l *reassemblerList) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &l.head)
	stateSourceObject.Load(1, &l.tail)
}
// StateTypeName returns the fully qualified type name of reassemblerEntry.
func (e *reassemblerEntry) StateTypeName() string {
	return "pkg/tcpip/network/internal/fragmentation.reassemblerEntry"
}

// StateFields lists the saved fields of reassemblerEntry. A field's position
// in the returned slice matches the index used for it in StateSave and
// StateLoad.
func (e *reassemblerEntry) StateFields() []string {
	return []string{
		"next",
		"prev",
	}
}

// beforeSave is a no-op hook called at the start of StateSave.
func (e *reassemblerEntry) beforeSave() {}

// StateSave saves every field listed by StateFields to stateSinkObject.
//
// +checklocksignore
func (e *reassemblerEntry) StateSave(stateSinkObject state.Sink) {
	e.beforeSave()
	stateSinkObject.Save(0, &e.next)
	stateSinkObject.Save(1, &e.prev)
}

// afterLoad is a no-op hook.
func (e *reassemblerEntry) afterLoad(context.Context) {}

// StateLoad loads every field listed by StateFields from stateSourceObject.
//
// +checklocksignore
func (e *reassemblerEntry) StateLoad(ctx context.Context, stateSourceObject state.Source) {
	stateSourceObject.Load(0, &e.next)
	stateSourceObject.Load(1, &e.prev)
}
// init registers every savable type of this package with the state package.
func init() {
	state.Register((*FragmentID)(nil))
	state.Register((*Fragmentation)(nil))
	state.Register((*hole)(nil))
	state.Register((*reassembler)(nil))
	state.Register((*reassemblerList)(nil))
	state.Register((*reassemblerEntry)(nil))
}

View File

@@ -0,0 +1,185 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fragmentation
import (
"math"
"sort"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// hole describes one contiguous byte range of the packet being reassembled,
// either still missing or already filled by a fragment.
//
// +stateify savable
type hole struct {
	// first and last are the inclusive bounds of the byte range.
	first uint16
	last  uint16
	// filled reports whether a fragment has been stored for this range.
	filled bool
	// final marks the hole that a fragment without the "more" flag is allowed
	// to occupy.
	final bool
	// pkt is the fragment packet if hole is filled. We keep the whole pkt rather
	// than the fragmented payload to prevent binding to specific buffer types.
	pkt *stack.PacketBuffer
}
// reassembler collects the fragments of a single packet, identified by id,
// until the packet can be put back together.
//
// +stateify savable
type reassembler struct {
	// reassemblerEntry links the reassembler into a reassemblerList.
	reassemblerEntry
	// id is the fragment ID this reassembler was created for.
	id FragmentID
	// memSize is the sum of MemSize() of the fragments stored so far.
	memSize int
	// proto is the protocol number given with the fragment at offset 0.
	proto uint8
	// mu is held by process and checkDoneOrMark.
	mu sync.Mutex `state:"nosave"`
	// holes lists the filled and unfilled ranges of the packet.
	holes []hole
	// filled is the number of entries in holes that have been filled.
	filled int
	// done is set once the reassembler has been released; process ignores
	// fragments afterwards.
	done bool
	// createdAt is the monotonic time at which the reassembler was created.
	createdAt tcpip.MonotonicTime
	// pkt is the fragment at offset 0, or nil if it has not been received.
	pkt *stack.PacketBuffer
}
// newReassembler returns a reassembler for id, stamped with the current
// monotonic time of clock.
//
// It starts out with a single unfilled hole that spans the whole uint16 offset
// range and is marked final.
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
	whole := hole{
		first:  0,
		last:   math.MaxUint16,
		filled: false,
		final:  true,
	}
	return &reassembler{
		id:        id,
		createdAt: clock.NowMonotonic(),
		holes:     []hole{whole},
	}
}
// process adds the fragment covering bytes [first, last] to the reassembler.
//
// more tells whether further fragments follow this one, proto is the protocol
// number carried by the fragment, and pkt is the fragment itself.
//
// It returns, in order: the reassembled packet (nil until every hole is
// filled), the protocol recorded from the fragment at offset 0, whether
// reassembly completed, the memory newly accounted for pkt (0 for a duplicate),
// and an error. ErrFragmentOverlap is returned when the fragment only
// partially fits a hole; ErrFragmentConflict when it conflicts with the final
// fragment or matches no hole at all.
func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.done {
		// A concurrent goroutine might have already reassembled
		// the packet and released this reassembler while this goroutine
		// was waiting on the mutex. We don't have to do anything in this case.
		return nil, 0, false, 0, nil
	}

	var holeFound bool
	var memConsumed int
	for i := range r.holes {
		// currentHole keeps pointing at the element as it was when the
		// iteration started, even if the appends below grow r.holes.
		currentHole := &r.holes[i]

		// Skip holes the fragment does not intersect at all.
		if last < currentHole.first || currentHole.last < first {
			continue
		}
		// For IPv6, overlaps with an existing fragment are explicitly forbidden by
		// RFC 8200 section 4.5:
		//   If any of the fragments being reassembled overlap with any other
		//   fragments being reassembled for the same packet, reassembly of that
		//   packet must be abandoned and all the fragments that have been received
		//   for that packet must be discarded, and no ICMP error messages should be
		//   sent.
		//
		// It is not explicitly forbidden for IPv4, but to keep parity with Linux we
		// disallow it as well:
		// https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
		if first < currentHole.first || currentHole.last < last {
			// Incoming fragment only partially fits in the free hole.
			return nil, 0, false, 0, ErrFragmentOverlap
		}
		if !more {
			if !currentHole.final || currentHole.filled && currentHole.last != last {
				// We have another final fragment, which does not perfectly overlap.
				return nil, 0, false, 0, ErrFragmentConflict
			}
		}

		holeFound = true
		if currentHole.filled {
			// Incoming fragment is a duplicate.
			continue
		}

		// We are populating the current hole with the payload and creating a new
		// hole for any unfilled ranges on either end.
		if first > currentHole.first {
			r.holes = append(r.holes, hole{
				first:  currentHole.first,
				last:   first - 1,
				filled: false,
				final:  false,
			})
		}
		if last < currentHole.last && more {
			// The trailing remainder takes over the final flag from the hole
			// being split.
			r.holes = append(r.holes, hole{
				first:  last + 1,
				last:   currentHole.last,
				filled: false,
				final:  currentHole.final,
			})
			currentHole.final = false
		}
		memConsumed = pkt.MemSize()
		r.memSize += memConsumed
		// Update the current hole to precisely match the incoming fragment.
		r.holes[i] = hole{
			first:  first,
			last:   last,
			filled: true,
			final:  currentHole.final,
			pkt:    pkt.IncRef(),
		}
		r.filled++
		// For IPv6, it is possible to have different Protocol values between
		// fragments of a packet (because, unlike IPv4, the Protocol is not used to
		// identify a fragment). In this case, only the Protocol of the first
		// fragment must be used as per RFC 8200 Section 4.5.
		//
		// TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP
		// options received in the first fragment should be used - and they should
		// override options from following fragments.
		if first == 0 {
			if r.pkt != nil {
				r.pkt.DecRef()
			}
			r.pkt = pkt.IncRef()
			r.proto = proto
		}
		break
	}
	if !holeFound {
		// Incoming fragment is beyond end.
		return nil, 0, false, 0, ErrFragmentConflict
	}

	// Check if all the holes have been filled and we are ready to reassemble.
	if r.filled < len(r.holes) {
		return nil, 0, false, memConsumed, nil
	}

	// Order the fragments by offset, then merge them into a clone of the first
	// one.
	sort.Slice(r.holes, func(i, j int) bool {
		return r.holes[i].first < r.holes[j].first
	})

	resPkt := r.holes[0].pkt.Clone()
	for i := 1; i < len(r.holes); i++ {
		stack.MergeFragment(resPkt, r.holes[i].pkt)
	}
	return resPkt, r.proto, true /* done */, memConsumed, nil
}
// checkDoneOrMark marks the reassembler as done and reports whether it had
// already been marked before this call. The check and the update happen under
// r.mu.
func (r *reassembler) checkDoneOrMark() bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	wasDone := r.done
	r.done = true
	return wasDone
}

View File

@@ -0,0 +1,239 @@
package fragmentation
// reassemblerElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type reassemblerElementMapper struct{}

// linkerFor maps an Element to a Linker. Here both are *reassembler, so elem
// is returned unchanged.
//
// This default implementation should be inlined.
//
//go:nosplit
func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
// reassemblerList is an intrusive list. Entries can be added to or removed
// from the list in O(1) time and with no additional memory allocations.
//
// The zero value for reassemblerList is an empty list ready to use.
//
// To iterate over a list (where l is a reassemblerList):
//
//	for e := l.Front(); e != nil; e = e.Next() {
//		// do something with e.
//	}
//
// +stateify savable
type reassemblerList struct {
	head *reassembler
	tail *reassembler
}

// Reset resets list l to the empty state.
func (l *reassemblerList) Reset() {
	l.head = nil
	l.tail = nil
}

// Empty returns true iff the list is empty.
//
//go:nosplit
func (l *reassemblerList) Empty() bool {
	return l.head == nil
}

// Front returns the first element of list l or nil.
//
//go:nosplit
func (l *reassemblerList) Front() *reassembler {
	return l.head
}

// Back returns the last element of list l or nil.
//
//go:nosplit
func (l *reassemblerList) Back() *reassembler {
	return l.tail
}

// Len returns the number of elements in the list.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *reassemblerList) Len() (count int) {
	for e := l.Front(); e != nil; e = (reassemblerElementMapper{}.linkerFor(e)).Next() {
		count++
	}
	return count
}

// PushFront inserts the element e at the front of list l.
//
//go:nosplit
func (l *reassemblerList) PushFront(e *reassembler) {
	linker := reassemblerElementMapper{}.linkerFor(e)
	linker.SetNext(l.head)
	linker.SetPrev(nil)
	if l.head != nil {
		reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e)
	} else {
		// The list was empty, so e is also the last element.
		l.tail = e
	}

	l.head = e
}

// PushFrontList inserts list m at the start of list l, emptying m.
//
//go:nosplit
func (l *reassemblerList) PushFrontList(m *reassemblerList) {
	if l.head == nil {
		l.head = m.head
		l.tail = m.tail
	} else if m.head != nil {
		reassemblerElementMapper{}.linkerFor(l.head).SetPrev(m.tail)
		reassemblerElementMapper{}.linkerFor(m.tail).SetNext(l.head)

		l.head = m.head
	}
	m.head = nil
	m.tail = nil
}

// PushBack inserts the element e at the back of list l.
//
//go:nosplit
func (l *reassemblerList) PushBack(e *reassembler) {
	linker := reassemblerElementMapper{}.linkerFor(e)
	linker.SetNext(nil)
	linker.SetPrev(l.tail)
	if l.tail != nil {
		reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e)
	} else {
		// The list was empty, so e is also the first element.
		l.head = e
	}

	l.tail = e
}

// PushBackList inserts list m at the end of list l, emptying m.
//
//go:nosplit
func (l *reassemblerList) PushBackList(m *reassemblerList) {
	if l.head == nil {
		l.head = m.head
		l.tail = m.tail
	} else if m.head != nil {
		reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head)
		reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail)

		l.tail = m.tail
	}
	m.head = nil
	m.tail = nil
}

// InsertAfter inserts e after b.
//
//go:nosplit
func (l *reassemblerList) InsertAfter(b, e *reassembler) {
	bLinker := reassemblerElementMapper{}.linkerFor(b)
	eLinker := reassemblerElementMapper{}.linkerFor(e)

	a := bLinker.Next()

	eLinker.SetNext(a)
	eLinker.SetPrev(b)
	bLinker.SetNext(e)

	if a != nil {
		reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
	} else {
		// b was the last element, so e becomes the new tail.
		l.tail = e
	}
}

// InsertBefore inserts e before a.
//
//go:nosplit
func (l *reassemblerList) InsertBefore(a, e *reassembler) {
	aLinker := reassemblerElementMapper{}.linkerFor(a)
	eLinker := reassemblerElementMapper{}.linkerFor(e)

	b := aLinker.Prev()
	eLinker.SetNext(a)
	eLinker.SetPrev(b)
	aLinker.SetPrev(e)

	if b != nil {
		reassemblerElementMapper{}.linkerFor(b).SetNext(e)
	} else {
		// a was the first element, so e becomes the new head.
		l.head = e
	}
}

// Remove removes e from l. The next and prev links of e are cleared.
//
//go:nosplit
func (l *reassemblerList) Remove(e *reassembler) {
	linker := reassemblerElementMapper{}.linkerFor(e)
	prev := linker.Prev()
	next := linker.Next()

	if prev != nil {
		reassemblerElementMapper{}.linkerFor(prev).SetNext(next)
	} else if l.head == e {
		l.head = next
	}

	if next != nil {
		reassemblerElementMapper{}.linkerFor(next).SetPrev(prev)
	} else if l.tail == e {
		l.tail = prev
	}

	linker.SetNext(nil)
	linker.SetPrev(nil)
}
// reassemblerEntry is a default implementation of Linker. Users can add
// anonymous fields of this type to their structs to make them automatically
// implement the methods needed by reassemblerList.
//
// +stateify savable
type reassemblerEntry struct {
	next *reassembler
	prev *reassembler
}

// Next returns the entry that follows e in the list.
//
//go:nosplit
func (e *reassemblerEntry) Next() *reassembler {
	return e.next
}

// Prev returns the entry that precedes e in the list.
//
//go:nosplit
func (e *reassemblerEntry) Prev() *reassembler {
	return e.prev
}

// SetNext assigns elem as the entry that follows e in the list.
//
//go:nosplit
func (e *reassemblerEntry) SetNext(elem *reassembler) {
	e.next = elem
}

// SetPrev assigns elem as the entry that precedes e in the list.
//
//go:nosplit
func (e *reassemblerEntry) SetPrev(elem *reassembler) {
	e.prev = elem
}

View File

@@ -0,0 +1,304 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ip holds IPv4/IPv6 common utilities.
package ip
import (
"bytes"
"fmt"
"io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// extendRequest tracks whether an extension of the DAD process has been asked
// for and whether it has been applied.
type extendRequest int

const (
	// notRequested is the zero value: no extension has been requested.
	notRequested extendRequest = iota
	// requested means an extension is pending; the DAD timer turns it into
	// extended when it would otherwise have finished.
	requested
	// extended means the extension has been applied.
	extended
)
// dadState is the per-address state of an in-progress DAD run.
//
// +stateify savable
type dadState struct {
	// nonce is the nonce generated for the most recent DAD message; it is nil
	// until one has been generated.
	nonce []byte
	// extendRequest records whether the run should be, or has been, extended.
	extendRequest extendRequest
	// done points at the flag the timer callback checks before doing work.
	done *bool
	// timer drives the transmission of DAD messages.
	timer tcpip.Timer
	// completionHandlers are called with the result once the run finishes.
	completionHandlers []stack.DADCompletionHandler
}
// DADProtocol is a protocol whose core state machine can be represented by DAD.
type DADProtocol interface {
	// SendDADMessage attempts to send a DAD probe message for the given
	// address. The byte slice is the nonce to include; it is nil when nonces
	// are not in use.
	SendDADMessage(tcpip.Address, []byte) tcpip.Error
}
// DADOptions holds options for DAD.
//
// +stateify savable
type DADOptions struct {
	// Clock is used to create the timer that drives DAD message transmission.
	Clock tcpip.Clock
	// SecureRNG is the source of nonce bytes.
	//
	// TODO(b/341946753): Restore when netstack is savable.
	SecureRNG io.Reader `state:"nosave"`
	// NonceSize is the size of the nonce, in bytes. Zero disables nonces.
	NonceSize uint8
	// ExtendDADTransmits is the number of additional DAD messages to send when
	// the DAD process is extended. It must be non-zero if NonceSize is.
	ExtendDADTransmits uint8
	// Protocol sends the DAD messages.
	Protocol DADProtocol
	// NICID identifies the NIC in panic messages.
	NICID tcpip.NICID
}
// DAD performs duplicate address detection for addresses.
//
// +stateify savable
type DAD struct {
	// opts and configs are the values given to Init.
	opts    DADOptions
	configs stack.DADConfigurations

	// protocolMU is the lock taken by timer callbacks; the *Locked methods
	// expect the caller to hold it.
	protocolMU sync.Locker `state:"nosave"`
	// addresses holds the state of every address DAD is currently running
	// for. A nil map means Init has not been called.
	addresses map[tcpip.Address]dadState
}
// Init initializes the DAD state with the given lock, configurations and
// options.
//
// Must only be called once for the lifetime of d; Init will panic if it is
// called twice. It also panics if opts has a non-zero NonceSize but a zero
// ExtendDADTransmits.
//
// The lock will only be taken when timers fire.
func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts DADOptions) {
	switch {
	case d.addresses != nil:
		panic("attempted to initialize DAD state twice")
	case opts.NonceSize != 0 && opts.ExtendDADTransmits == 0:
		panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize))
	}

	configs.Validate()

	d.opts = opts
	d.configs = configs
	d.protocolMU = protocolMU
	d.addresses = make(map[tcpip.Address]dadState)
}
// CheckDuplicateAddressLocked performs DAD for an address, calling the
// completion handler once DAD resolves.
//
// If DAD is already performing for the provided address, h will be called when
// the currently running process completes.
//
// It returns stack.DADDisabled when no DAD transmits are configured,
// stack.DADStarting when a new DAD run was started for addr, and
// stack.DADAlreadyRunning when h was attached to a run already in progress.
//
// Precondition: d.protocolMU must be locked.
func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition {
	if d.configs.DupAddrDetectTransmits == 0 {
		return stack.DADDisabled
	}

	ret := stack.DADAlreadyRunning
	s, ok := d.addresses[addr]
	if !ok {
		ret = stack.DADStarting

		// remaining counts the DAD messages still to be sent; it is only
		// touched from the timer callback below.
		remaining := d.configs.DupAddrDetectTransmits

		// Protected by d.protocolMU.
		done := false

		s = dadState{
			done: &done,
			// The timer is armed with a zero delay and re-armed with
			// RetransmitTimer after each message sent successfully.
			timer: d.opts.Clock.AfterFunc(0, func() {
				dadDone := remaining == 0

				// First critical section: decide whether another message is
				// due and, if nonces are enabled, generate a fresh one.
				nonce, earlyReturn := func() ([]byte, bool) {
					d.protocolMU.Lock()
					defer d.protocolMU.Unlock()

					if done {
						return nil, true
					}

					s, ok := d.addresses[addr]
					if !ok {
						panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID))
					}

					// As per RFC 7527 section 4
					//
					//   If any probe is looped back within RetransTimer milliseconds
					//   after having sent DupAddrDetectTransmits NS(DAD) messages, the
					//   interface continues with another MAX_MULTICAST_SOLICIT number of
					//   NS(DAD) messages transmitted RetransTimer milliseconds apart.
					if dadDone && s.extendRequest == requested {
						dadDone = false
						remaining = d.opts.ExtendDADTransmits
						s.extendRequest = extended
					}

					if !dadDone && d.opts.NonceSize != 0 {
						if s.nonce == nil {
							s.nonce = make([]byte, d.opts.NonceSize)
						}

						if n, err := io.ReadFull(d.opts.SecureRNG, s.nonce); err != nil {
							panic(fmt.Sprintf("SecureRNG.Read(...): %s", err))
						} else if n != len(s.nonce) {
							panic(fmt.Sprintf("expected to read %d bytes from secure RNG, only read %d bytes", len(s.nonce), n))
						}
					}

					// s is a copy of the map value, so store it back.
					d.addresses[addr] = s
					return s.nonce, false
				}()
				if earlyReturn {
					return
				}

				// The message is sent without d.protocolMU held.
				var err tcpip.Error
				if !dadDone {
					err = d.opts.Protocol.SendDADMessage(addr, nonce)
				}

				d.protocolMU.Lock()
				defer d.protocolMU.Unlock()

				// The state may have changed while the lock was released.
				if done {
					return
				}

				s, ok := d.addresses[addr]
				if !ok {
					panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID))
				}

				if !dadDone && err == nil {
					remaining--
					s.timer.Reset(d.configs.RetransmitTimer)
					return
				}

				// At this point we know that either DAD has resolved or we hit an error
				// sending the last DAD message. Either way, clear the DAD state.
				//
				// NOTE(review): done was just checked to be false under this same
				// lock, so the assignment below changes nothing; confirm whether
				// true was intended.
				done = false
				s.timer.Stop()
				delete(d.addresses, addr)

				var res stack.DADResult = &stack.DADSucceeded{}
				if err != nil {
					res = &stack.DADError{Err: err}
				}
				for _, h := range s.completionHandlers {
					h(res)
				}
			}),
		}
	}

	// Whether the run is new or already in progress, h is notified when it
	// completes.
	s.completionHandlers = append(s.completionHandlers, h)
	d.addresses[addr] = s
	return ret
}
// ExtendIfNonceEqualLockedDisposition enumerates the possible results from
// ExtendIfNonceEqualLocked.
type ExtendIfNonceEqualLockedDisposition int
// The dispositions are numbered by iota in declaration order, starting with
// Extended at 0.
const (
// Extended indicates that the DAD process was extended.
Extended ExtendIfNonceEqualLockedDisposition = iota
// AlreadyExtended indicates that the DAD process was already extended.
//
// ExtendIfNonceEqualLocked returns it whenever the address's extendRequest
// is anything other than notRequested.
AlreadyExtended
// NoDADStateFound indicates that DAD state was not found for the address.
NoDADStateFound
// NonceDisabled indicates that nonce values are not sent with DAD messages.
NonceDisabled
// NonceNotEqual indicates that the nonce value passed and the nonce in the
// last sent DAD message are not equal.
//
// ExtendIfNonceEqualLocked also returns it when no nonce has been recorded
// for the address yet.
NonceNotEqual
)
// ExtendIfNonceEqualLocked requests an extension of the DAD process for addr
// when nonce matches the nonce carried by the most recently sent DAD message.
//
// The returned disposition tells the caller which case applied:
// NoDADStateFound when addr has no DAD state, NonceDisabled when nonces are
// not in use, AlreadyExtended when an extension was already requested,
// NonceNotEqual when the nonces differ (or none was sent yet) and Extended
// when the extension request was recorded.
//
// Precondition: d.protocolMU must be locked.
func (d *DAD) ExtendIfNonceEqualLocked(addr tcpip.Address, nonce []byte) ExtendIfNonceEqualLockedDisposition {
	state, found := d.addresses[addr]

	// The cases are evaluated top to bottom, so their order is significant.
	switch {
	case !found:
		return NoDADStateFound
	case d.opts.NonceSize == 0:
		return NonceDisabled
	case state.extendRequest != notRequested:
		return AlreadyExtended
	case state.nonce == nil || !bytes.Equal(state.nonce, nonce):
		// The first DAD message is sent asynchronously, so a nil nonce means
		// nothing has been sent yet and there is nothing to compare against.
		return NonceNotEqual
	}

	// As per RFC 7527 section 4
	//
	//   If any probe is looped back within RetransTimer milliseconds after having
	//   sent DupAddrDetectTransmits NS(DAD) messages, the interface continues
	//   with another MAX_MULTICAST_SOLICIT number of NS(DAD) messages transmitted
	//   RetransTimer milliseconds apart.
	//
	// The observed nonce equals the one we last sent, so we assume our own
	// probe was looped back and request an extension to the DAD process.
	state.extendRequest = requested
	// The map holds states by value, so the modified copy must be written back.
	d.addresses[addr] = state
	return Extended
}
// StopLocked aborts the DAD process running for addr, if there is one, and
// reports reason to every completion handler registered for it. It does
// nothing when no DAD state exists for addr.
//
// The handlers are called synchronously from this method.
//
// Precondition: d.protocolMU must be locked.
func (d *DAD) StopLocked(addr tcpip.Address, reason stack.DADResult) {
	if state, running := d.addresses[addr]; running {
		// Flag the process as finished before touching the timer; the timer
		// callback checks this flag and returns early when it is set.
		*state.done = true
		state.timer.Stop()
		delete(d.addresses, addr)

		for _, handler := range state.completionHandlers {
			handler(reason)
		}
	}
}
// SetConfigsLocked sets the DAD configurations.
//
// c.Validate() is called on the local copy of c before that copy is stored
// in d.configs.
// NOTE(review): presumably Validate adjusts invalid fields in place, so the
// stored value may differ from what the caller passed; confirm against
// stack.DADConfigurations.Validate.
//
// Precondition: d.protocolMU must be locked.
func (d *DAD) SetConfigsLocked(c stack.DADConfigurations) {
c.Validate()
d.configs = c
}

View File

@@ -0,0 +1,129 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ip
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
// ForwardingError represents an error that occurred while trying to forward
// a packet.
//
// Because isForwardingError is unexported, only types declared in this
// package can implement ForwardingError.
type ForwardingError interface {
// isForwardingError is a marker method; the implementations in this file
// have empty bodies.
isForwardingError()
// String (from fmt.Stringer) returns a short description of the error.
fmt.Stringer
}
// ErrTTLExceeded is the forwarding error reported when the TTL of a received
// packet has been exceeded.
type ErrTTLExceeded struct{}

// isForwardingError marks ErrTTLExceeded as a ForwardingError.
func (*ErrTTLExceeded) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrTTLExceeded) String() string {
	return "ttl exceeded"
}
// ErrOutgoingDeviceNoBufferSpace is the forwarding error reported when the
// outgoing device does not have enough space to hold a buffer.
type ErrOutgoingDeviceNoBufferSpace struct{}

// isForwardingError marks ErrOutgoingDeviceNoBufferSpace as a ForwardingError.
func (*ErrOutgoingDeviceNoBufferSpace) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrOutgoingDeviceNoBufferSpace) String() string {
	return "no device buffer space"
}
// ErrParameterProblem is the forwarding error reported when the received
// packet had a problem with an IP parameter.
type ErrParameterProblem struct{}

// isForwardingError marks ErrParameterProblem as a ForwardingError.
func (*ErrParameterProblem) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrParameterProblem) String() string {
	return "parameter problem"
}
// ErrInitializingSourceAddress is the forwarding error reported when the
// received packet had a source address that may only be used on the local
// network as part of initialization work.
type ErrInitializingSourceAddress struct{}

// isForwardingError marks ErrInitializingSourceAddress as a ForwardingError.
func (*ErrInitializingSourceAddress) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrInitializingSourceAddress) String() string {
	return "initializing source address"
}
// ErrLinkLocalSourceAddress is the forwarding error reported when the
// received packet had a link-local source address.
type ErrLinkLocalSourceAddress struct{}

// isForwardingError marks ErrLinkLocalSourceAddress as a ForwardingError.
func (*ErrLinkLocalSourceAddress) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrLinkLocalSourceAddress) String() string {
	return "link local source address"
}
// ErrLinkLocalDestinationAddress is the forwarding error reported when the
// received packet had a link-local destination address.
type ErrLinkLocalDestinationAddress struct{}

// isForwardingError marks ErrLinkLocalDestinationAddress as a ForwardingError.
func (*ErrLinkLocalDestinationAddress) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrLinkLocalDestinationAddress) String() string {
	return "link local destination address"
}
// ErrHostUnreachable is the forwarding error reported when the destination
// host could not be reached.
type ErrHostUnreachable struct{}

// isForwardingError marks ErrHostUnreachable as a ForwardingError.
func (*ErrHostUnreachable) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrHostUnreachable) String() string {
	return "no route to host"
}
// ErrMessageTooLong is the forwarding error reported when the packet was too
// big for the outgoing MTU.
//
// +stateify savable
type ErrMessageTooLong struct{}

// isForwardingError marks ErrMessageTooLong as a ForwardingError.
func (*ErrMessageTooLong) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrMessageTooLong) String() string {
	return "message too long"
}
// ErrNoMulticastPendingQueueBufferSpace is the forwarding error reported when
// a multicast packet could not be added to the pending packet queue due to
// insufficient buffer space.
//
// +stateify savable
type ErrNoMulticastPendingQueueBufferSpace struct{}

// isForwardingError marks ErrNoMulticastPendingQueueBufferSpace as a
// ForwardingError.
func (*ErrNoMulticastPendingQueueBufferSpace) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrNoMulticastPendingQueueBufferSpace) String() string {
	return "no buffer space"
}
// ErrUnexpectedMulticastInputInterface is the forwarding error reported when
// the interface that the packet arrived on did not match the route's expected
// input interface.
type ErrUnexpectedMulticastInputInterface struct{}

// isForwardingError marks ErrUnexpectedMulticastInputInterface as a
// ForwardingError.
func (*ErrUnexpectedMulticastInputInterface) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrUnexpectedMulticastInputInterface) String() string {
	return "unexpected input interface"
}
// ErrUnknownOutputEndpoint is the forwarding error reported when the output
// endpoint associated with a route could not be found.
type ErrUnknownOutputEndpoint struct{}

// isForwardingError marks ErrUnknownOutputEndpoint as a ForwardingError.
func (*ErrUnknownOutputEndpoint) isForwardingError() {
}

// String returns a short description of the error.
func (*ErrUnknownOutputEndpoint) String() string {
	return "unknown endpoint"
}
// ErrOther is the forwarding error reported when the packet could not be
// forwarded for a reason captured by the contained error.
type ErrOther struct {
	// Err is the underlying tcpip error that prevented forwarding.
	Err tcpip.Error
}

// isForwardingError marks ErrOther as a ForwardingError.
func (*ErrOther) isForwardingError() {
}

// String returns a description of the error that includes the wrapped Err.
func (e *ErrOther) String() string {
	return fmt.Sprintf("other tcpip error: %s", e.Err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,432 @@
// automatically generated by stateify.
package ip
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
// StateTypeName returns the package-qualified type name of dadState.
func (d *dadState) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.dadState"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (d *dadState) StateFields() []string {
return []string{
"nonce",
"extendRequest",
"done",
"timer",
"completionHandlers",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (d *dadState) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (d *dadState) StateSave(stateSinkObject state.Sink) {
d.beforeSave()
stateSinkObject.Save(0, &d.nonce)
stateSinkObject.Save(1, &d.extendRequest)
stateSinkObject.Save(2, &d.done)
stateSinkObject.Save(3, &d.timer)
stateSinkObject.Save(4, &d.completionHandlers)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (d *dadState) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (d *dadState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &d.nonce)
stateSourceObject.Load(1, &d.extendRequest)
stateSourceObject.Load(2, &d.done)
stateSourceObject.Load(3, &d.timer)
stateSourceObject.Load(4, &d.completionHandlers)
}
// StateTypeName returns the package-qualified type name of DADOptions.
func (d *DADOptions) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.DADOptions"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (d *DADOptions) StateFields() []string {
return []string{
"Clock",
"NonceSize",
"ExtendDADTransmits",
"Protocol",
"NICID",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (d *DADOptions) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (d *DADOptions) StateSave(stateSinkObject state.Sink) {
d.beforeSave()
stateSinkObject.Save(0, &d.Clock)
stateSinkObject.Save(1, &d.NonceSize)
stateSinkObject.Save(2, &d.ExtendDADTransmits)
stateSinkObject.Save(3, &d.Protocol)
stateSinkObject.Save(4, &d.NICID)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (d *DADOptions) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (d *DADOptions) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &d.Clock)
stateSourceObject.Load(1, &d.NonceSize)
stateSourceObject.Load(2, &d.ExtendDADTransmits)
stateSourceObject.Load(3, &d.Protocol)
stateSourceObject.Load(4, &d.NICID)
}
// StateTypeName returns the package-qualified type name of DAD.
func (d *DAD) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.DAD"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (d *DAD) StateFields() []string {
return []string{
"opts",
"configs",
"addresses",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (d *DAD) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (d *DAD) StateSave(stateSinkObject state.Sink) {
d.beforeSave()
stateSinkObject.Save(0, &d.opts)
stateSinkObject.Save(1, &d.configs)
stateSinkObject.Save(2, &d.addresses)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (d *DAD) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (d *DAD) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &d.opts)
stateSourceObject.Load(1, &d.configs)
stateSourceObject.Load(2, &d.addresses)
}
// StateTypeName returns the package-qualified type name of ErrMessageTooLong.
func (e *ErrMessageTooLong) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.ErrMessageTooLong"
}
// StateFields returns an empty list: no fields are saved for this type.
func (e *ErrMessageTooLong) StateFields() []string {
return []string{}
}
// beforeSave is a no-op hook that StateSave calls.
func (e *ErrMessageTooLong) beforeSave() {}
// StateSave only runs the beforeSave hook; nothing is written to
// stateSinkObject.
//
// +checklocksignore
func (e *ErrMessageTooLong) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (e *ErrMessageTooLong) afterLoad(context.Context) {}
// StateLoad does nothing: there are no fields to load. Both parameters are
// unused.
//
// +checklocksignore
func (e *ErrMessageTooLong) StateLoad(ctx context.Context, stateSourceObject state.Source) {
}
// StateTypeName returns the package-qualified type name of
// ErrNoMulticastPendingQueueBufferSpace.
func (e *ErrNoMulticastPendingQueueBufferSpace) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.ErrNoMulticastPendingQueueBufferSpace"
}
// StateFields returns an empty list: no fields are saved for this type.
func (e *ErrNoMulticastPendingQueueBufferSpace) StateFields() []string {
return []string{}
}
// beforeSave is a no-op hook that StateSave calls.
func (e *ErrNoMulticastPendingQueueBufferSpace) beforeSave() {}
// StateSave only runs the beforeSave hook; nothing is written to
// stateSinkObject.
//
// +checklocksignore
func (e *ErrNoMulticastPendingQueueBufferSpace) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (e *ErrNoMulticastPendingQueueBufferSpace) afterLoad(context.Context) {}
// StateLoad does nothing: there are no fields to load. Both parameters are
// unused.
//
// +checklocksignore
func (e *ErrNoMulticastPendingQueueBufferSpace) StateLoad(ctx context.Context, stateSourceObject state.Source) {
}
// StateTypeName returns the package-qualified type name of
// multicastGroupState.
func (m *multicastGroupState) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.multicastGroupState"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (m *multicastGroupState) StateFields() []string {
return []string{
"joins",
"transmissionLeft",
"lastToSendReport",
"delayedReportJob",
"queriedIncludeSources",
"deleteScheduled",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (m *multicastGroupState) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (m *multicastGroupState) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.joins)
stateSinkObject.Save(1, &m.transmissionLeft)
stateSinkObject.Save(2, &m.lastToSendReport)
stateSinkObject.Save(3, &m.delayedReportJob)
stateSinkObject.Save(4, &m.queriedIncludeSources)
stateSinkObject.Save(5, &m.deleteScheduled)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (m *multicastGroupState) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (m *multicastGroupState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.joins)
stateSourceObject.Load(1, &m.transmissionLeft)
stateSourceObject.Load(2, &m.lastToSendReport)
stateSourceObject.Load(3, &m.delayedReportJob)
stateSourceObject.Load(4, &m.queriedIncludeSources)
stateSourceObject.Load(5, &m.deleteScheduled)
}
// StateTypeName returns the package-qualified type name of
// GenericMulticastProtocolOptions.
func (g *GenericMulticastProtocolOptions) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.GenericMulticastProtocolOptions"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (g *GenericMulticastProtocolOptions) StateFields() []string {
return []string{
"Clock",
"Protocol",
"MaxUnsolicitedReportDelay",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (g *GenericMulticastProtocolOptions) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (g *GenericMulticastProtocolOptions) StateSave(stateSinkObject state.Sink) {
g.beforeSave()
stateSinkObject.Save(0, &g.Clock)
stateSinkObject.Save(1, &g.Protocol)
stateSinkObject.Save(2, &g.MaxUnsolicitedReportDelay)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (g *GenericMulticastProtocolOptions) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (g *GenericMulticastProtocolOptions) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &g.Clock)
stateSourceObject.Load(1, &g.Protocol)
stateSourceObject.Load(2, &g.MaxUnsolicitedReportDelay)
}
// StateTypeName returns the package-qualified type name of
// GenericMulticastProtocolState.
func (g *GenericMulticastProtocolState) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.GenericMulticastProtocolState"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (g *GenericMulticastProtocolState) StateFields() []string {
return []string{
"opts",
"memberships",
"robustnessVariable",
"queryInterval",
"mode",
"modeTimer",
"generalQueryV2Timer",
"stateChangedReportV2Timer",
"stateChangedReportV2TimerSet",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (g *GenericMulticastProtocolState) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (g *GenericMulticastProtocolState) StateSave(stateSinkObject state.Sink) {
g.beforeSave()
stateSinkObject.Save(0, &g.opts)
stateSinkObject.Save(1, &g.memberships)
stateSinkObject.Save(2, &g.robustnessVariable)
stateSinkObject.Save(3, &g.queryInterval)
stateSinkObject.Save(4, &g.mode)
stateSinkObject.Save(5, &g.modeTimer)
stateSinkObject.Save(6, &g.generalQueryV2Timer)
stateSinkObject.Save(7, &g.stateChangedReportV2Timer)
stateSinkObject.Save(8, &g.stateChangedReportV2TimerSet)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (g *GenericMulticastProtocolState) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (g *GenericMulticastProtocolState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &g.opts)
stateSourceObject.Load(1, &g.memberships)
stateSourceObject.Load(2, &g.robustnessVariable)
stateSourceObject.Load(3, &g.queryInterval)
stateSourceObject.Load(4, &g.mode)
stateSourceObject.Load(5, &g.modeTimer)
stateSourceObject.Load(6, &g.generalQueryV2Timer)
stateSourceObject.Load(7, &g.stateChangedReportV2Timer)
stateSourceObject.Load(8, &g.stateChangedReportV2TimerSet)
}
// StateTypeName returns the package-qualified type name of
// MultiCounterIPForwardingStats.
func (m *MultiCounterIPForwardingStats) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.MultiCounterIPForwardingStats"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (m *MultiCounterIPForwardingStats) StateFields() []string {
return []string{
"Unrouteable",
"ExhaustedTTL",
"InitializingSource",
"LinkLocalSource",
"LinkLocalDestination",
"PacketTooBig",
"HostUnreachable",
"ExtensionHeaderProblem",
"UnexpectedMulticastInputInterface",
"UnknownOutputEndpoint",
"NoMulticastPendingQueueBufferSpace",
"OutgoingDeviceNoBufferSpace",
"Errors",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (m *MultiCounterIPForwardingStats) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (m *MultiCounterIPForwardingStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.Unrouteable)
stateSinkObject.Save(1, &m.ExhaustedTTL)
stateSinkObject.Save(2, &m.InitializingSource)
stateSinkObject.Save(3, &m.LinkLocalSource)
stateSinkObject.Save(4, &m.LinkLocalDestination)
stateSinkObject.Save(5, &m.PacketTooBig)
stateSinkObject.Save(6, &m.HostUnreachable)
stateSinkObject.Save(7, &m.ExtensionHeaderProblem)
stateSinkObject.Save(8, &m.UnexpectedMulticastInputInterface)
stateSinkObject.Save(9, &m.UnknownOutputEndpoint)
stateSinkObject.Save(10, &m.NoMulticastPendingQueueBufferSpace)
stateSinkObject.Save(11, &m.OutgoingDeviceNoBufferSpace)
stateSinkObject.Save(12, &m.Errors)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (m *MultiCounterIPForwardingStats) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (m *MultiCounterIPForwardingStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.Unrouteable)
stateSourceObject.Load(1, &m.ExhaustedTTL)
stateSourceObject.Load(2, &m.InitializingSource)
stateSourceObject.Load(3, &m.LinkLocalSource)
stateSourceObject.Load(4, &m.LinkLocalDestination)
stateSourceObject.Load(5, &m.PacketTooBig)
stateSourceObject.Load(6, &m.HostUnreachable)
stateSourceObject.Load(7, &m.ExtensionHeaderProblem)
stateSourceObject.Load(8, &m.UnexpectedMulticastInputInterface)
stateSourceObject.Load(9, &m.UnknownOutputEndpoint)
stateSourceObject.Load(10, &m.NoMulticastPendingQueueBufferSpace)
stateSourceObject.Load(11, &m.OutgoingDeviceNoBufferSpace)
stateSourceObject.Load(12, &m.Errors)
}
// StateTypeName returns the package-qualified type name of
// MultiCounterIPStats.
func (m *MultiCounterIPStats) StateTypeName() string {
return "pkg/tcpip/network/internal/ip.MultiCounterIPStats"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (m *MultiCounterIPStats) StateFields() []string {
return []string{
"PacketsReceived",
"ValidPacketsReceived",
"DisabledPacketsReceived",
"InvalidDestinationAddressesReceived",
"InvalidSourceAddressesReceived",
"PacketsDelivered",
"PacketsSent",
"OutgoingPacketErrors",
"MalformedPacketsReceived",
"MalformedFragmentsReceived",
"IPTablesPreroutingDropped",
"IPTablesInputDropped",
"IPTablesForwardDropped",
"IPTablesOutputDropped",
"IPTablesPostroutingDropped",
"OptionTimestampReceived",
"OptionRecordRouteReceived",
"OptionRouterAlertReceived",
"OptionUnknownReceived",
"Forwarding",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (m *MultiCounterIPStats) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (m *MultiCounterIPStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.PacketsReceived)
stateSinkObject.Save(1, &m.ValidPacketsReceived)
stateSinkObject.Save(2, &m.DisabledPacketsReceived)
stateSinkObject.Save(3, &m.InvalidDestinationAddressesReceived)
stateSinkObject.Save(4, &m.InvalidSourceAddressesReceived)
stateSinkObject.Save(5, &m.PacketsDelivered)
stateSinkObject.Save(6, &m.PacketsSent)
stateSinkObject.Save(7, &m.OutgoingPacketErrors)
stateSinkObject.Save(8, &m.MalformedPacketsReceived)
stateSinkObject.Save(9, &m.MalformedFragmentsReceived)
stateSinkObject.Save(10, &m.IPTablesPreroutingDropped)
stateSinkObject.Save(11, &m.IPTablesInputDropped)
stateSinkObject.Save(12, &m.IPTablesForwardDropped)
stateSinkObject.Save(13, &m.IPTablesOutputDropped)
stateSinkObject.Save(14, &m.IPTablesPostroutingDropped)
stateSinkObject.Save(15, &m.OptionTimestampReceived)
stateSinkObject.Save(16, &m.OptionRecordRouteReceived)
stateSinkObject.Save(17, &m.OptionRouterAlertReceived)
stateSinkObject.Save(18, &m.OptionUnknownReceived)
stateSinkObject.Save(19, &m.Forwarding)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (m *MultiCounterIPStats) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (m *MultiCounterIPStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.PacketsReceived)
stateSourceObject.Load(1, &m.ValidPacketsReceived)
stateSourceObject.Load(2, &m.DisabledPacketsReceived)
stateSourceObject.Load(3, &m.InvalidDestinationAddressesReceived)
stateSourceObject.Load(4, &m.InvalidSourceAddressesReceived)
stateSourceObject.Load(5, &m.PacketsDelivered)
stateSourceObject.Load(6, &m.PacketsSent)
stateSourceObject.Load(7, &m.OutgoingPacketErrors)
stateSourceObject.Load(8, &m.MalformedPacketsReceived)
stateSourceObject.Load(9, &m.MalformedFragmentsReceived)
stateSourceObject.Load(10, &m.IPTablesPreroutingDropped)
stateSourceObject.Load(11, &m.IPTablesInputDropped)
stateSourceObject.Load(12, &m.IPTablesForwardDropped)
stateSourceObject.Load(13, &m.IPTablesOutputDropped)
stateSourceObject.Load(14, &m.IPTablesPostroutingDropped)
stateSourceObject.Load(15, &m.OptionTimestampReceived)
stateSourceObject.Load(16, &m.OptionRecordRouteReceived)
stateSourceObject.Load(17, &m.OptionRouterAlertReceived)
stateSourceObject.Load(18, &m.OptionUnknownReceived)
stateSourceObject.Load(19, &m.Forwarding)
}
// init passes a nil pointer of each type that has state methods in this file
// to state.Register.
func init() {
state.Register((*dadState)(nil))
state.Register((*DADOptions)(nil))
state.Register((*DAD)(nil))
state.Register((*ErrMessageTooLong)(nil))
state.Register((*ErrNoMulticastPendingQueueBufferSpace)(nil))
state.Register((*multicastGroupState)(nil))
state.Register((*GenericMulticastProtocolOptions)(nil))
state.Register((*GenericMulticastProtocolState)(nil))
state.Register((*MultiCounterIPForwardingStats)(nil))
state.Register((*MultiCounterIPStats)(nil))
}

View File

@@ -0,0 +1,214 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ip
import "gvisor.dev/gvisor/pkg/tcpip"
// LINT.IfChange(MultiCounterIPForwardingStats)
// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter
// may have several versions.
//
// The stateify-generated StateFields/StateSave/StateLoad methods for this
// type index the fields in this declaration order.
// NOTE(review): reordering fields presumably changes the saved layout;
// confirm before reordering.
//
// +stateify savable
type MultiCounterIPForwardingStats struct {
// Unrouteable is the number of IP packets received which were dropped
// because the netstack could not construct a route to their
// destination.
Unrouteable tcpip.MultiCounterStat
// ExhaustedTTL is the number of IP packets received which were dropped
// because their TTL was exhausted.
ExhaustedTTL tcpip.MultiCounterStat
// InitializingSource is the number of IP packets which were dropped
// because they contained a source address that may only be used on the local
// network as part of initialization work.
InitializingSource tcpip.MultiCounterStat
// LinkLocalSource is the number of IP packets which were dropped
// because they contained a link-local source address.
LinkLocalSource tcpip.MultiCounterStat
// LinkLocalDestination is the number of IP packets which were dropped
// because they contained a link-local destination address.
LinkLocalDestination tcpip.MultiCounterStat
// PacketTooBig is the number of IP packets which were dropped because they
// were too big for the outgoing MTU.
PacketTooBig tcpip.MultiCounterStat
// HostUnreachable is the number of IP packets received which could not be
// successfully forwarded due to an unresolvable next hop.
HostUnreachable tcpip.MultiCounterStat
// ExtensionHeaderProblem is the number of IP packets which were dropped
// because of a problem encountered when processing an IPv6 extension
// header.
ExtensionHeaderProblem tcpip.MultiCounterStat
// UnexpectedMulticastInputInterface is the number of multicast packets that
// were received on an interface that did not match the corresponding route's
// expected input interface.
UnexpectedMulticastInputInterface tcpip.MultiCounterStat
// UnknownOutputEndpoint is the number of packets that could not be forwarded
// because the output endpoint could not be found.
UnknownOutputEndpoint tcpip.MultiCounterStat
// NoMulticastPendingQueueBufferSpace is the number of multicast packets that
// were dropped due to insufficient buffer space in the pending packet queue.
NoMulticastPendingQueueBufferSpace tcpip.MultiCounterStat
// OutgoingDeviceNoBufferSpace is the number of packets that were dropped due
// to insufficient space in the outgoing device.
OutgoingDeviceNoBufferSpace tcpip.MultiCounterStat
// Errors is the number of IP packets received which could not be
// successfully forwarded.
Errors tcpip.MultiCounterStat
}
// Init wires every counter of m to the correspondingly named counters of a
// and b, so that each multi-counter tracks both.
//
// The calls are listed in struct declaration order; each one touches a
// different field of m.
func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
	m.Unrouteable.Init(a.Unrouteable, b.Unrouteable)
	m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL)
	m.InitializingSource.Init(a.InitializingSource, b.InitializingSource)
	m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource)
	m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination)
	m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig)
	m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable)
	m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem)
	m.UnexpectedMulticastInputInterface.Init(
		a.UnexpectedMulticastInputInterface,
		b.UnexpectedMulticastInputInterface,
	)
	m.UnknownOutputEndpoint.Init(a.UnknownOutputEndpoint, b.UnknownOutputEndpoint)
	m.NoMulticastPendingQueueBufferSpace.Init(
		a.NoMulticastPendingQueueBufferSpace,
		b.NoMulticastPendingQueueBufferSpace,
	)
	m.OutgoingDeviceNoBufferSpace.Init(a.OutgoingDeviceNoBufferSpace, b.OutgoingDeviceNoBufferSpace)
	m.Errors.Init(a.Errors, b.Errors)
}
// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats)
// LINT.IfChange(MultiCounterIPStats)
// MultiCounterIPStats holds IP statistics. Each counter may have several
// versions.
//
// The stateify-generated StateFields/StateSave/StateLoad methods for this
// type index the fields in this declaration order.
// NOTE(review): reordering fields presumably changes the saved layout;
// confirm before reordering.
//
// +stateify savable
type MultiCounterIPStats struct {
// PacketsReceived is the number of IP packets received from the link
// layer.
PacketsReceived tcpip.MultiCounterStat
// ValidPacketsReceived is the number of valid IP packets that reached the IP
// layer.
ValidPacketsReceived tcpip.MultiCounterStat
// DisabledPacketsReceived is the number of IP packets received from
// the link layer when the IP layer is disabled.
DisabledPacketsReceived tcpip.MultiCounterStat
// InvalidDestinationAddressesReceived is the number of IP packets
// received with an unknown or invalid destination address.
InvalidDestinationAddressesReceived tcpip.MultiCounterStat
// InvalidSourceAddressesReceived is the number of IP packets received
// with a source address that should never have been received on the
// wire.
InvalidSourceAddressesReceived tcpip.MultiCounterStat
// PacketsDelivered is the number of incoming IP packets successfully
// delivered to the transport layer.
PacketsDelivered tcpip.MultiCounterStat
// PacketsSent is the number of IP packets sent via WritePacket.
PacketsSent tcpip.MultiCounterStat
// OutgoingPacketErrors is the number of IP packets which failed to
// write to a link-layer endpoint.
OutgoingPacketErrors tcpip.MultiCounterStat
// MalformedPacketsReceived is the number of IP Packets that were
// dropped due to the IP packet header failing validation checks.
MalformedPacketsReceived tcpip.MultiCounterStat
// MalformedFragmentsReceived is the number of IP Fragments that were
// dropped due to the fragment failing validation checks.
MalformedFragmentsReceived tcpip.MultiCounterStat
// IPTablesPreroutingDropped is the number of IP packets dropped in the
// Prerouting chain.
IPTablesPreroutingDropped tcpip.MultiCounterStat
// IPTablesInputDropped is the number of IP packets dropped in the
// Input chain.
IPTablesInputDropped tcpip.MultiCounterStat
// IPTablesForwardDropped is the number of IP packets dropped in the
// Forward chain.
IPTablesForwardDropped tcpip.MultiCounterStat
// IPTablesOutputDropped is the number of IP packets dropped in the
// Output chain.
IPTablesOutputDropped tcpip.MultiCounterStat
// IPTablesPostroutingDropped is the number of IP packets dropped in
// the Postrouting chain.
IPTablesPostroutingDropped tcpip.MultiCounterStat
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option
// stats out of IPStats.
// OptionTimestampReceived is the number of Timestamp options seen.
OptionTimestampReceived tcpip.MultiCounterStat
// OptionRecordRouteReceived is the number of Record Route options
// seen.
OptionRecordRouteReceived tcpip.MultiCounterStat
// OptionRouterAlertReceived is the number of Router Alert options
// seen.
OptionRouterAlertReceived tcpip.MultiCounterStat
// OptionUnknownReceived is the number of unknown IP options seen.
OptionUnknownReceived tcpip.MultiCounterStat
// Forwarding collects stats related to IP forwarding.
Forwarding MultiCounterIPForwardingStats
}
// Init wires every counter of m, including the nested Forwarding stats, to
// the correspondingly named counters of a and b, so that each multi-counter
// tracks both.
func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
	// Nested forwarding statistics.
	m.Forwarding.Init(&a.Forwarding, &b.Forwarding)

	// Receive path.
	m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
	m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived)
	m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
	m.InvalidDestinationAddressesReceived.Init(
		a.InvalidDestinationAddressesReceived,
		b.InvalidDestinationAddressesReceived,
	)
	m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived)
	m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered)

	// Send path.
	m.PacketsSent.Init(a.PacketsSent, b.PacketsSent)
	m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors)

	// Validation failures.
	m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived)
	m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived)

	// Drops per IPTables chain.
	m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
	m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
	m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped)
	m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
	m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)

	// IP options seen.
	m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
	m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived)
	m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived)
	m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived)
}
// LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats)

View File

@@ -0,0 +1,137 @@
// automatically generated by stateify.
package multicast
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
// StateTypeName returns the package-qualified type name of RouteTable.
func (r *RouteTable) StateTypeName() string {
return "pkg/tcpip/network/internal/multicast.RouteTable"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (r *RouteTable) StateFields() []string {
return []string{
"installedRoutes",
"pendingRoutes",
"cleanupPendingRoutesTimer",
"isCleanupRoutineRunning",
"config",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (r *RouteTable) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (r *RouteTable) StateSave(stateSinkObject state.Sink) {
r.beforeSave()
stateSinkObject.Save(0, &r.installedRoutes)
stateSinkObject.Save(1, &r.pendingRoutes)
stateSinkObject.Save(2, &r.cleanupPendingRoutesTimer)
stateSinkObject.Save(3, &r.isCleanupRoutineRunning)
stateSinkObject.Save(4, &r.config)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (r *RouteTable) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (r *RouteTable) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &r.installedRoutes)
stateSourceObject.Load(1, &r.pendingRoutes)
stateSourceObject.Load(2, &r.cleanupPendingRoutesTimer)
stateSourceObject.Load(3, &r.isCleanupRoutineRunning)
stateSourceObject.Load(4, &r.config)
}
// StateTypeName returns the package-qualified type name of InstalledRoute.
func (r *InstalledRoute) StateTypeName() string {
return "pkg/tcpip/network/internal/multicast.InstalledRoute"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (r *InstalledRoute) StateFields() []string {
return []string{
"MulticastRoute",
"lastUsedTimestamp",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (r *InstalledRoute) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (r *InstalledRoute) StateSave(stateSinkObject state.Sink) {
r.beforeSave()
stateSinkObject.Save(0, &r.MulticastRoute)
stateSinkObject.Save(1, &r.lastUsedTimestamp)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (r *InstalledRoute) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (r *InstalledRoute) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &r.MulticastRoute)
stateSourceObject.Load(1, &r.lastUsedTimestamp)
}
// StateTypeName returns the package-qualified type name of PendingRoute.
func (p *PendingRoute) StateTypeName() string {
return "pkg/tcpip/network/internal/multicast.PendingRoute"
}
// StateFields returns the names of the saved fields. A name's position in
// the slice is the index StateSave and StateLoad use for that field.
func (p *PendingRoute) StateFields() []string {
return []string{
"packets",
"expiration",
}
}
// beforeSave is a no-op hook that StateSave calls before saving any field.
func (p *PendingRoute) beforeSave() {}
// StateSave saves every field named by StateFields to stateSinkObject, each
// under its StateFields index.
//
// +checklocksignore
func (p *PendingRoute) StateSave(stateSinkObject state.Sink) {
p.beforeSave()
stateSinkObject.Save(0, &p.packets)
stateSinkObject.Save(1, &p.expiration)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (p *PendingRoute) afterLoad(context.Context) {}
// StateLoad loads every field named by StateFields from stateSourceObject,
// using the same indices as StateSave. ctx is unused.
//
// +checklocksignore
func (p *PendingRoute) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &p.packets)
stateSourceObject.Load(1, &p.expiration)
}
// StateTypeName returns the package-qualified name that identifies Config.
func (c *Config) StateTypeName() string {
	const typeName = "pkg/tcpip/network/internal/multicast.Config"
	return typeName
}
// StateFields lists the Config field names; the position of each name matches
// the index used for that field by StateSave and StateLoad.
func (c *Config) StateFields() []string {
	fieldNames := []string{
		"MaxPendingQueueSize",
		"Clock",
	}
	return fieldNames
}
// beforeSave is intentionally empty; StateSave calls it before writing any
// field.
func (c *Config) beforeSave() {}
// StateSave runs beforeSave and then writes MaxPendingQueueSize (index 0) and
// Clock (index 1) to stateSinkObject.
//
// +checklocksignore
func (c *Config) StateSave(stateSinkObject state.Sink) {
c.beforeSave()
stateSinkObject.Save(0, &c.MaxPendingQueueSize)
stateSinkObject.Save(1, &c.Clock)
}
// afterLoad is intentionally empty and ignores its context argument.
func (c *Config) afterLoad(context.Context) {}
// StateLoad reads MaxPendingQueueSize (index 0) and Clock (index 1) from
// stateSourceObject. ctx is not used.
//
// +checklocksignore
func (c *Config) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &c.MaxPendingQueueSize)
stateSourceObject.Load(1, &c.Clock)
}
// init registers the RouteTable, InstalledRoute, PendingRoute and Config types
// with the state package by passing a nil pointer of each type.
func init() {
state.Register((*RouteTable)(nil))
state.Register((*InstalledRoute)(nil))
state.Register((*PendingRoute)(nil))
state.Register((*Config)(nil))
}

View File

@@ -0,0 +1,446 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package multicast contains utilities for supporting multicast routing.
package multicast
import (
"errors"
"fmt"
"sync"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// RouteTable represents a multicast routing table.
//
// Init must be called before the table is used.
//
// +stateify savable
type RouteTable struct {
// Internally, installed and pending routes are stored and locked separately.
// A couple of reasons for structuring the table this way:
//
// 1. We can avoid write locking installed routes when pending packets are
// being queued. In other words, the happy path of reading installed
// routes doesn't require an exclusive lock.
// 2. The cleanup process for expired routes only needs to operate on pending
// routes. Like above, a write lock on the installed routes can be
// avoided.
// 3. This structure is similar to the Linux implementation:
// https://github.com/torvalds/linux/blob/cffb2b72d3e/include/linux/mroute_base.h#L250
//
// The installedMu lock should typically be acquired before the pendingMu
// lock. This ensures that installed routes can continue to be read even when
// the pending routes are write locked.
installedMu sync.RWMutex `state:"nosave"`
// installedRoutes holds the installed routes, keyed by unicast source and
// multicast destination. A nil map means Init has not been called yet.
//
// Maintaining pointers ensures that the installed routes are exclusively
// locked only when a route is being installed.
// +checklocks:installedMu
installedRoutes map[stack.UnicastSourceAndMulticastDestination]*InstalledRoute
// pendingMu protects the pending route state declared below.
pendingMu sync.RWMutex `state:"nosave"`
// pendingRoutes holds, by value, the routes that have queued packets but no
// installed route yet.
// +checklocks:pendingMu
pendingRoutes map[stack.UnicastSourceAndMulticastDestination]PendingRoute
// cleanupPendingRoutesTimer is a timer that triggers a routine to remove
// pending routes that are expired. It is nil until the first packet is
// queued in a pending route.
// +checklocks:pendingMu
cleanupPendingRoutesTimer tcpip.Timer
// isCleanupRoutineRunning is set when the cleanup timer is started and
// cleared when the timer is stopped because no pending routes remain.
// +checklocks:pendingMu
isCleanupRoutineRunning bool
// config is the configuration that was passed to Init.
config Config
}
// Errors exported by the multicast package.
var (
	// ErrNoBufferSpace reports that the packet queue of a pending route has no
	// room left for another packet.
	ErrNoBufferSpace = errors.New("unable to queue packet, no buffer space available")

	// ErrMissingClock reports that the Config handed to the table has a nil
	// Clock even though a clock is mandatory.
	ErrMissingClock = errors.New("clock must not be nil")

	// ErrAlreadyInitialized reports that RouteTable.Init was invoked on a table
	// that had already been initialized.
	ErrAlreadyInitialized = errors.New("table is already initialized")
)
// InstalledRoute represents a route that is in the installed state.
//
// If a route is in the installed state, then it may be used to forward
// multicast packets.
//
// +stateify savable
type InstalledRoute struct {
// MulticastRoute is the embedded route definition supplied when the
// installed route was created.
stack.MulticastRoute
// lastUsedTimestampMu protects lastUsedTimestamp.
lastUsedTimestampMu sync.RWMutex `state:"nosave"`
// lastUsedTimestamp is the value reported by LastUsedTimestamp and advanced
// by SetLastUsedTimestamp.
// +checklocks:lastUsedTimestampMu
lastUsedTimestamp tcpip.MonotonicTime
}
// LastUsedTimestamp reports the monotonic time at which the route was last
// used or updated. The value is read while holding the read lock.
func (r *InstalledRoute) LastUsedTimestamp() tcpip.MonotonicTime {
	r.lastUsedTimestampMu.RLock()
	lastUsed := r.lastUsedTimestamp
	r.lastUsedTimestampMu.RUnlock()
	return lastUsed
}
// SetLastUsedTimestamp records monotonicTime as the moment the route was last
// used.
//
// The stored timestamp only ever moves forward: a monotonicTime that is not
// after the current value is ignored. Callers should invoke this every time
// the route is used to forward a packet.
func (r *InstalledRoute) SetLastUsedTimestamp(monotonicTime tcpip.MonotonicTime) {
	r.lastUsedTimestampMu.Lock()
	defer r.lastUsedTimestampMu.Unlock()

	if !monotonicTime.After(r.lastUsedTimestamp) {
		return
	}
	r.lastUsedTimestamp = monotonicTime
}
// PendingRoute represents a route that is in the "pending" state.
//
// A route is in the pending state if an installed route does not yet exist
// for the entry. For such routes, packets are added to an expiring queue until
// a route is installed.
//
// +stateify savable
type PendingRoute struct {
// packets is the queue of packets collected while the route is pending.
// Each entry holds a reference that releasePackets drops.
packets []*stack.PacketBuffer
// expiration is the timestamp at which the pending route should be expired.
//
// If this value is before the current time, then this pending route will
// be dropped.
expiration tcpip.MonotonicTime
}
// releasePackets drops one reference on every packet queued in the pending
// route. The queue itself is left untouched.
func (p *PendingRoute) releasePackets() {
	for i := range p.packets {
		p.packets[i].DecRef()
	}
}
// isExpired reports whether currentTime is strictly after the pending route's
// expiration timestamp.
func (p *PendingRoute) isExpired(currentTime tcpip.MonotonicTime) bool {
	pastDeadline := currentTime.After(p.expiration)
	return pastDeadline
}
const (
	// DefaultMaxPendingQueueSize is the default number of packets that the
	// queue of a single pending route may hold.
	//
	// Matches the Linux default queue size:
	// https://github.com/torvalds/linux/blob/26291c54e11/net/ipv6/ip6mr.c#L1186
	DefaultMaxPendingQueueSize = uint8(3)

	// DefaultPendingRouteExpiration is the default maximum lifetime of a
	// pending route.
	//
	// Matches the Linux default:
	// https://github.com/torvalds/linux/blob/26291c54e11/net/ipv6/ip6mr.c#L991
	DefaultPendingRouteExpiration = 10 * time.Second

	// DefaultCleanupInterval is the default period between runs of the routine
	// that expires pending routes.
	//
	// Matches the Linux default:
	// https://github.com/torvalds/linux/blob/26291c54e11/net/ipv6/ip6mr.c#L793
	DefaultCleanupInterval = 10 * time.Second
)
// Config represents the options for configuring a RouteTable.
//
// +stateify savable
type Config struct {
// MaxPendingQueueSize corresponds to the maximum number of queued packets
// for a pending route.
//
// If the caller attempts to queue a packet and the queue already contains
// MaxPendingQueueSize elements, then the packet will be rejected and should
// not be forwarded. With a value of zero every packet is therefore
// rejected.
MaxPendingQueueSize uint8
// Clock represents the clock that should be used to obtain the current time.
//
// This field is required and must have a non-nil value; RouteTable.Init
// returns ErrMissingClock otherwise.
Clock tcpip.Clock
}
// DefaultConfig builds the default table configuration: the pending queue is
// capped at DefaultMaxPendingQueueSize and time is read from clock.
func DefaultConfig(clock tcpip.Clock) Config {
	var cfg Config
	cfg.MaxPendingQueueSize = DefaultMaxPendingQueueSize
	cfg.Clock = clock
	return cfg
}
// Init prepares the RouteTable for use with the provided config.
//
// It returns ErrAlreadyInitialized when the table was initialized before and
// ErrMissingClock when config.Clock is nil; the table is left unchanged in
// both cases.
//
// Must be called before any other function on the table.
func (r *RouteTable) Init(config Config) error {
	r.installedMu.Lock()
	defer r.installedMu.Unlock()
	r.pendingMu.Lock()
	defer r.pendingMu.Unlock()

	switch {
	case r.installedRoutes != nil:
		return ErrAlreadyInitialized
	case config.Clock == nil:
		return ErrMissingClock
	}

	type routeKey = stack.UnicastSourceAndMulticastDestination
	r.installedRoutes = make(map[routeKey]*InstalledRoute)
	r.pendingRoutes = make(map[routeKey]PendingRoute)
	r.config = config
	return nil
}
// Close cleans up resources held by the table.
//
// It stops the cleanup timer (if one was ever created), marks the cleanup
// routine as not running, and removes every pending route after releasing the
// packets queued on it. Installed routes are not touched.
func (r *RouteTable) Close() {
	r.pendingMu.Lock()
	defer r.pendingMu.Unlock()

	if r.cleanupPendingRoutesTimer != nil {
		r.cleanupPendingRoutesTimer.Stop()
	}
	// Keep the flag in sync with the stopped timer. If it stayed set, a packet
	// queued after Close would find the routine "running" and never re-arm the
	// timer, so its pending route would not be expired.
	r.isCleanupRoutineRunning = false

	for key, route := range r.pendingRoutes {
		delete(r.pendingRoutes, key)
		route.releasePackets()
	}
}
// maybeStopCleanupRoutineLocked stops the pending routes cleanup routine when
// no pending routes remain.
//
// It returns true when the routine is not running once the call completes
// (it was already idle, or it has just been stopped) and false when it is
// left running because pending routes still exist.
//
// +checklocks:r.pendingMu
func (r *RouteTable) maybeStopCleanupRoutineLocked() bool {
	switch {
	case !r.isCleanupRoutineRunning:
		return true
	case len(r.pendingRoutes) != 0:
		return false
	}

	r.cleanupPendingRoutesTimer.Stop()
	r.isCleanupRoutineRunning = false
	return true
}
// cleanupPendingRoutes is the timer callback that expires pending routes.
//
// Every pending route whose expiration lies before the time sampled at entry
// is removed and its packets are released. The timer is re-armed for another
// DefaultCleanupInterval unless no pending routes remain.
func (r *RouteTable) cleanupPendingRoutes() {
	now := r.config.Clock.NowMonotonic()

	r.pendingMu.Lock()
	defer r.pendingMu.Unlock()

	for key, pending := range r.pendingRoutes {
		if !pending.isExpired(now) {
			continue
		}
		delete(r.pendingRoutes, key)
		pending.releasePackets()
	}

	if r.maybeStopCleanupRoutineLocked() {
		return
	}
	r.cleanupPendingRoutesTimer.Reset(DefaultCleanupInterval)
}
// newPendingRoute builds an empty pending route whose queue has capacity for
// MaxPendingQueueSize packets and which expires DefaultPendingRouteExpiration
// after the current monotonic time.
func (r *RouteTable) newPendingRoute() PendingRoute {
	queue := make([]*stack.PacketBuffer, 0, r.config.MaxPendingQueueSize)
	deadline := r.config.Clock.NowMonotonic().Add(DefaultPendingRouteExpiration)
	return PendingRoute{packets: queue, expiration: deadline}
}
// NewInstalledRoute wraps route in a new InstalledRoute whose last-used
// timestamp is the table clock's current monotonic time. The result is not
// added to the table.
func (r *RouteTable) NewInstalledRoute(route stack.MulticastRoute) *InstalledRoute {
	createdAt := r.config.Clock.NowMonotonic()
	return &InstalledRoute{MulticastRoute: route, lastUsedTimestamp: createdAt}
}
// GetRouteResult represents the result of calling GetRouteOrInsertPending.
//
// When GetRouteOrInsertPending reports failure, the zero GetRouteResult is
// returned and its fields carry no meaning.
type GetRouteResult struct {
// GetRouteResultState signals the result of calling GetRouteOrInsertPending.
GetRouteResultState GetRouteResultState
// InstalledRoute represents the existing installed route. This field will
// only be populated if the GetRouteResultState is InstalledRouteFound; it
// is nil otherwise.
InstalledRoute *InstalledRoute
}
// GetRouteResultState signals the result of calling GetRouteOrInsertPending.
type GetRouteResultState uint8

const (
	// InstalledRouteFound means a matching InstalledRoute exists.
	InstalledRouteFound GetRouteResultState = iota

	// PacketQueuedInPendingRoute means the packet was appended to the queue of
	// a pending route that already existed.
	PacketQueuedInPendingRoute

	// NoRouteFoundAndPendingInserted means no route existed and a pending
	// route was newly inserted into the RouteTable.
	NoRouteFoundAndPendingInserted
)

// String returns the constant's name for the known states and the decimal
// value for anything else.
func (e GetRouteResultState) String() string {
	names := [...]string{
		InstalledRouteFound:            "InstalledRouteFound",
		PacketQueuedInPendingRoute:     "PacketQueuedInPendingRoute",
		NoRouteFoundAndPendingInserted: "NoRouteFoundAndPendingInserted",
	}
	if int(e) < len(names) {
		return names[e]
	}
	return fmt.Sprintf("%d", uint8(e))
}
// GetRouteOrInsertPending looks up the installed route that matches key.
//
// When an installed route exists it is returned with the state
// InstalledRouteFound. Otherwise a clone of pkt is queued in the pending route
// for key, creating that pending route if necessary; the returned state tells
// whether the pending route already existed. The cleanup timer is started if
// it is not already running.
//
// The boolean result is false only when the pending route's queue is already
// at max capacity, in which case nothing is queued and the zero GetRouteResult
// is returned. Otherwise it is true.
func (r *RouteTable) GetRouteOrInsertPending(key stack.UnicastSourceAndMulticastDestination, pkt *stack.PacketBuffer) (GetRouteResult, bool) {
	r.installedMu.RLock()
	defer r.installedMu.RUnlock()

	if installed, found := r.installedRoutes[key]; found {
		return GetRouteResult{GetRouteResultState: InstalledRouteFound, InstalledRoute: installed}, true
	}

	r.pendingMu.Lock()
	defer r.pendingMu.Unlock()

	pending, resultState := r.getOrCreatePendingRouteRLocked(key)
	if queueLimit := int(r.config.MaxPendingQueueSize); len(pending.packets) >= queueLimit {
		// A full queue rejects the incoming packet, mirroring Linux:
		// https://github.com/torvalds/linux/blob/ae085d7f936/net/ipv4/ipmr.c#L1147
		return GetRouteResult{}, false
	}

	// The map stores pending routes by value, so write the updated copy back.
	pending.packets = append(pending.packets, pkt.Clone())
	r.pendingRoutes[key] = pending

	if r.isCleanupRoutineRunning {
		return GetRouteResult{GetRouteResultState: resultState}, true
	}

	// A pending route now exists but nothing is scheduled to expire it: create
	// the timer on first use, or re-arm the one created earlier.
	if r.cleanupPendingRoutesTimer != nil {
		r.cleanupPendingRoutesTimer.Reset(DefaultCleanupInterval)
	} else {
		r.cleanupPendingRoutesTimer = r.config.Clock.AfterFunc(DefaultCleanupInterval, r.cleanupPendingRoutes)
	}
	r.isCleanupRoutineRunning = true

	return GetRouteResult{GetRouteResultState: resultState}, true
}
// getOrCreatePendingRouteRLocked returns the pending route stored under key
// together with PacketQueuedInPendingRoute, or a freshly built pending route
// together with NoRouteFoundAndPendingInserted when none is stored.
//
// A freshly built route is not inserted into the table here; storing it is
// left to the caller.
//
// +checklocks:r.pendingMu
func (r *RouteTable) getOrCreatePendingRouteRLocked(key stack.UnicastSourceAndMulticastDestination) (PendingRoute, GetRouteResultState) {
	existing, found := r.pendingRoutes[key]
	if !found {
		return r.newPendingRoute(), NoRouteFoundAndPendingInserted
	}
	return existing, PacketQueuedInPendingRoute
}
// AddInstalledRoute stores route in the table under key, overwriting any
// installed route that already exists for that key, and removes the pending
// route for key if there is one.
//
// The packets that were queued while the route was pending are returned; the
// caller takes ownership of them and is responsible for forwarding and
// releasing them. Nil is returned when there was no pending route or when it
// had already expired (its packets are released here in that case).
func (r *RouteTable) AddInstalledRoute(key stack.UnicastSourceAndMulticastDestination, route *InstalledRoute) []*stack.PacketBuffer {
	r.installedMu.Lock()
	defer r.installedMu.Unlock()
	r.installedRoutes[key] = route

	r.pendingMu.Lock()
	pending, wasPending := r.pendingRoutes[key]
	delete(r.pendingRoutes, key)
	// The result is deliberately ignored: re-arming the timer is the cleanup
	// routine's job, so only a possible stop matters here.
	_ = r.maybeStopCleanupRoutineLocked()
	r.pendingMu.Unlock()

	// Cleanup only runs periodically, so the pending route may already be past
	// its expiration; such a route is treated as if it did not exist.
	if wasPending && !pending.isExpired(r.config.Clock.NowMonotonic()) {
		return pending.packets
	}
	pending.releasePackets()
	return nil
}
// RemoveInstalledRoute deletes the installed route stored under key, if any.
//
// It reports whether a route was present and therefore removed.
func (r *RouteTable) RemoveInstalledRoute(key stack.UnicastSourceAndMulticastDestination) bool {
	r.installedMu.Lock()
	defer r.installedMu.Unlock()

	_, existed := r.installedRoutes[key]
	// Deleting an absent key is a no-op, so no separate branch is needed.
	delete(r.installedRoutes, key)
	return existed
}
// RemoveAllInstalledRoutes empties the installed route map in place while
// holding the installed routes write lock. Pending routes are not affected.
func (r *RouteTable) RemoveAllInstalledRoutes() {
	r.installedMu.Lock()
	defer r.installedMu.Unlock()

	installed := r.installedRoutes
	for routeKey := range installed {
		delete(installed, routeKey)
	}
}
// GetLastUsedTimestamp looks up the installed route stored under key and
// returns the monotonic time at which it was last used or updated.
//
// The boolean result is false, and the timestamp is the zero value, when no
// installed route matches key.
func (r *RouteTable) GetLastUsedTimestamp(key stack.UnicastSourceAndMulticastDestination) (tcpip.MonotonicTime, bool) {
	r.installedMu.RLock()
	defer r.installedMu.RUnlock()

	installed, found := r.installedRoutes[key]
	if !found {
		return tcpip.MonotonicTime{}, false
	}
	return installed.LastUsedTimestamp(), true
}

View File

@@ -0,0 +1,821 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv4
import (
"fmt"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// icmpv4DestinationUnreachableSockError is the common base of the ICMPv4
// Destination Unreachable errors; it supplies the origin, ICMP type and a
// zero info value.
//
// +stateify savable
type icmpv4DestinationUnreachableSockError struct{}

// Origin implements tcpip.SockErrorCause.
func (*icmpv4DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin {
	origin := tcpip.SockExtErrorOriginICMP
	return origin
}

// Type implements tcpip.SockErrorCause.
func (*icmpv4DestinationUnreachableSockError) Type() uint8 {
	icmpType := header.ICMPv4DstUnreachable
	return uint8(icmpType)
}

// Info implements tcpip.SockErrorCause.
func (*icmpv4DestinationUnreachableSockError) Info() uint32 {
	const noInfo uint32 = 0
	return noInfo
}
var _ stack.TransportError = (*icmpv4DestinationHostUnreachableSockError)(nil)

// icmpv4DestinationHostUnreachableSockError is the ICMPv4 Destination
// Unreachable error with the Host Unreachable code.
//
// It indicates that a packet was not able to reach the destination host.
//
// +stateify savable
type icmpv4DestinationHostUnreachableSockError struct {
	icmpv4DestinationUnreachableSockError
}

// Code implements tcpip.SockErrorCause.
func (*icmpv4DestinationHostUnreachableSockError) Code() uint8 {
	code := header.ICMPv4HostUnreachable
	return uint8(code)
}

// Kind implements stack.TransportError.
func (*icmpv4DestinationHostUnreachableSockError) Kind() stack.TransportErrorKind {
	kind := stack.DestinationHostUnreachableTransportError
	return kind
}
var _ stack.TransportError = (*icmpv4DestinationNetUnreachableSockError)(nil)

// icmpv4DestinationNetUnreachableSockError is the ICMPv4 Destination
// Unreachable error with the Net Unreachable code.
//
// It indicates that a packet was not able to reach the destination network.
//
// +stateify savable
type icmpv4DestinationNetUnreachableSockError struct {
	icmpv4DestinationUnreachableSockError
}

// Code implements tcpip.SockErrorCause.
func (*icmpv4DestinationNetUnreachableSockError) Code() uint8 {
	code := header.ICMPv4NetUnreachable
	return uint8(code)
}

// Kind implements stack.TransportError.
func (*icmpv4DestinationNetUnreachableSockError) Kind() stack.TransportErrorKind {
	kind := stack.DestinationNetworkUnreachableTransportError
	return kind
}
var _ stack.TransportError = (*icmpv4DestinationPortUnreachableSockError)(nil)

// icmpv4DestinationPortUnreachableSockError is the ICMPv4 Destination
// Unreachable error with the Port Unreachable code.
//
// It indicates that a packet reached the destination host, but the transport
// protocol was not active on the destination port.
//
// +stateify savable
type icmpv4DestinationPortUnreachableSockError struct {
	icmpv4DestinationUnreachableSockError
}

// Code implements tcpip.SockErrorCause.
func (*icmpv4DestinationPortUnreachableSockError) Code() uint8 {
	code := header.ICMPv4PortUnreachable
	return uint8(code)
}

// Kind implements stack.TransportError.
func (*icmpv4DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind {
	kind := stack.DestinationPortUnreachableTransportError
	return kind
}
var _ stack.TransportError = (*icmpv4DestinationProtoUnreachableSockError)(nil)

// icmpv4DestinationProtoUnreachableSockError is the ICMPv4 Destination
// Unreachable error with the Protocol Unreachable code.
//
// It indicates that a packet reached the destination host, but the transport
// protocol was not reachable.
//
// +stateify savable
type icmpv4DestinationProtoUnreachableSockError struct {
	icmpv4DestinationUnreachableSockError
}

// Code implements tcpip.SockErrorCause.
func (*icmpv4DestinationProtoUnreachableSockError) Code() uint8 {
	code := header.ICMPv4ProtoUnreachable
	return uint8(code)
}

// Kind implements stack.TransportError.
func (*icmpv4DestinationProtoUnreachableSockError) Kind() stack.TransportErrorKind {
	kind := stack.DestinationProtoUnreachableTransportError
	return kind
}
var _ stack.TransportError = (*icmpv4SourceRouteFailedSockError)(nil)

// icmpv4SourceRouteFailedSockError is the ICMPv4 Destination Unreachable
// error reported when a source route failed.
//
// +stateify savable
type icmpv4SourceRouteFailedSockError struct {
	icmpv4DestinationUnreachableSockError
}

// Code implements tcpip.SockErrorCause.
func (*icmpv4SourceRouteFailedSockError) Code() uint8 {
	code := header.ICMPv4SourceRouteFailed
	return uint8(code)
}

// Kind implements stack.TransportError.
func (*icmpv4SourceRouteFailedSockError) Kind() stack.TransportErrorKind {
	kind := stack.SourceRouteFailedTransportError
	return kind
}
var _ stack.TransportError = (*icmpv4SourceHostIsolatedSockError)(nil)

// icmpv4SourceHostIsolatedSockError is the ICMPv4 Destination Unreachable
// error reported when the source host is isolated (not on the network).
//
// +stateify savable
type icmpv4SourceHostIsolatedSockError struct {
	icmpv4DestinationUnreachableSockError
}

// Code implements tcpip.SockErrorCause.
func (*icmpv4SourceHostIsolatedSockError) Code() uint8 {
	code := header.ICMPv4SourceHostIsolated
	return uint8(code)
}

// Kind implements stack.TransportError.
func (*icmpv4SourceHostIsolatedSockError) Kind() stack.TransportErrorKind {
	kind := stack.SourceHostIsolatedTransportError
	return kind
}
var _ stack.TransportError = (*icmpv4DestinationHostUnknownSockError)(nil)

// icmpv4DestinationHostUnknownSockError is the ICMPv4 Destination Unreachable
// error reported when the destination host is unknown/down.
//
// +stateify savable
type icmpv4DestinationHostUnknownSockError struct {
	icmpv4DestinationUnreachableSockError
}

// Code implements tcpip.SockErrorCause.
func (*icmpv4DestinationHostUnknownSockError) Code() uint8 {
	code := header.ICMPv4DestinationHostUnknown
	return uint8(code)
}

// Kind implements stack.TransportError. The "host unknown" code is surfaced
// as the "host down" transport error kind.
func (*icmpv4DestinationHostUnknownSockError) Kind() stack.TransportErrorKind {
	kind := stack.DestinationHostDownTransportError
	return kind
}
var _ stack.TransportError = (*icmpv4FragmentationNeededSockError)(nil)

// icmpv4FragmentationNeededSockError is the ICMPv4 Destination Unreachable
// error reported when fragmentation was required but the packet was set to
// not be fragmented.
//
// It indicates that a link exists on the path to the destination with an MTU
// that is too small to carry the packet.
//
// +stateify savable
type icmpv4FragmentationNeededSockError struct {
	icmpv4DestinationUnreachableSockError

	// mtu is the value reported through Info.
	mtu uint32
}

// Code implements tcpip.SockErrorCause.
func (*icmpv4FragmentationNeededSockError) Code() uint8 {
	code := header.ICMPv4FragmentationNeeded
	return uint8(code)
}

// Info implements tcpip.SockErrorCause. It overrides the embedded type's Info
// to expose the stored MTU instead of zero.
func (e *icmpv4FragmentationNeededSockError) Info() uint32 {
	reportedMTU := e.mtu
	return reportedMTU
}

// Kind implements stack.TransportError.
func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind {
	kind := stack.PacketTooBigTransportError
	return kind
}
// checkLocalAddress reports whether addr should be treated as one of this
// endpoint's own addresses: always when the NIC has spoofing enabled,
// otherwise only when an assigned address endpoint can be acquired for it.
func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool {
	if e.nic.Spoofing() {
		return true
	}

	addressEndpoint := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint, true /* readOnly */)
	return addressEndpoint != nil
}
// handleControl handles an ICMP error whose payload carries the headers of
// the original packet that triggered it. Those headers identify the transport
// endpoint that must be told about the error. pkt is expected to hold only
// that payload, not the enclosing ICMP packet.
//
// The error is silently dropped when the payload is shorter than a minimal
// IPv4 header, when the original source address is not a local one, when the
// full IPv4 header is not present, or when the original packet was a
// fragment at a non-zero offset.
func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.PacketBuffer) {
	raw, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
	if !ok {
		return
	}
	origHdr := header.IPv4(raw)

	// IsValid() is deliberately not used: ICMP only has to include the IP
	// header plus 8 bytes of the transport header, so the embedded packet is
	// most likely truncated and would fail that check.
	//
	// The original packet must have been sent from an address we own.
	srcAddr := origHdr.SourceAddress()
	if !e.checkLocalAddress(srcAddr) {
		return
	}

	hdrLen := int(origHdr.HeaderLength())
	switch {
	case pkt.Data().Size() < hdrLen:
		// The full IPv4 header (with options) is not available.
		return
	case origHdr.FragmentOffset() != 0:
		// A non-initial fragment does not carry the transport header.
		return
	}

	// Read everything still needed from the header before it is trimmed.
	transProto := origHdr.TransportProtocol()
	dstAddr := origHdr.DestinationAddress()

	// Strip the IP header so pkt starts at the transport header.
	if _, consumed := pkt.Data().Consume(hdrLen); !consumed {
		panic(fmt.Sprintf("could not consume the IP header of %d bytes", hdrLen))
	}

	e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, transProto, errInfo, pkt)
}
// handleICMP processes a received ICMPv4 packet.
//
// Packets shorter than the minimum ICMPv4 header or with a bad checksum are
// counted as invalid; an echo request with a bad checksum is still handed to
// the transport dispatcher. IP options, if present, are processed and
// rewritten in place before the message type is examined:
//
//   - Echo requests are delivered to the dispatcher and answered with an echo
//     reply built from a copy of the request, subject to rate limiting and to
//     a route back to the sender being available.
//   - Echo replies are delivered to the dispatcher.
//   - Destination Unreachable messages are mapped to a transport error and
//     passed to handleControl.
//   - The other recognized types only increment their receive counters, and
//     unrecognized types are counted as invalid.
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
h := header.ICMPv4(pkt.TransportHeader().Slice())
if len(h) < header.ICMPv4MinimumSize {
received.invalid.Increment()
return
}
// Only do in-stack processing if the checksum is correct.
if checksum.Checksum(h, pkt.Data().Checksum()) != 0xffff {
received.invalid.Increment()
// It's possible that a raw socket expects to receive this regardless
// of checksum errors. If it's an echo request we know it's safe because
// we are the only handler, however other types do not cope well with
// packets with checksum errors.
switch h.Type() {
case header.ICMPv4Echo:
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
}
return
}
iph := header.IPv4(pkt.NetworkHeader().Slice())
var newOptions header.IPv4Options
if opts := iph.Options(); len(opts) != 0 {
// RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip
// type ICMP packets):
// If a Record Route and/or Time Stamp option is received in an
// ICMP Echo Request, this option (these options) SHOULD be
// updated to include the current host and included in the IP
// header of the Echo Reply message, without "truncation".
// Thus, the recorded route will be for the entire round trip.
//
// So we need to let the option processor know how it should handle them.
var op optionsUsage
if h.Type() == header.ICMPv4Echo {
op = &optionUsageEcho{}
} else {
op = &optionUsageReceive{}
}
var optProblem *header.IPv4OptParameterProblem
newOptions, _, optProblem = e.processIPOptions(pkt, opts, op)
if optProblem != nil {
if optProblem.NeedICMP {
// The result of sending the Parameter Problem message is deliberately
// ignored; the packet is dropped either way.
_ = e.protocol.returnError(&icmpReasonParamProblem{
pointer: optProblem.Pointer,
}, pkt, true /* deliveredLocally */)
e.stats.ip.MalformedPacketsReceived.Increment()
}
return
}
// Overwrite the options in the received header with the processed ones.
copied := copy(opts, newOptions)
if copied != len(newOptions) {
panic(fmt.Sprintf("copied %d bytes of new options, expected %d bytes", copied, len(newOptions)))
}
for i := copied; i < len(opts); i++ {
// Pad with 0 (EOL). RFC 791 page 23 says "The padding is zero".
opts[i] = byte(header.IPv4OptionListEndType)
}
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
received.echoRequest.Increment()
// DeliverTransportPacket may modify pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
// be modifying fields. Both the ICMP header (with its type modified to
// EchoReply) and payload are reused in the reply packet.
//
// TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
// waiting endpoints. Consider moving responsibility for doing the copy to
// DeliverTransportPacket so that it is only done when needed.
replyData := stack.PayloadSince(pkt.TransportHeader())
defer replyData.Release()
ipHdr := header.IPv4(pkt.NetworkHeader().Slice())
localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast
// It's possible that a raw socket expects to receive this.
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
// Clear the local reference so pkt cannot be used by accident below.
pkt = nil
sent := e.stats.icmp.packetsSent
if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) {
sent.rateLimited.Increment()
return
}
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
localAddr := ipHdr.DestinationAddress()
if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) {
// Pass an empty local address to FindRoute instead of the
// broadcast/multicast one.
localAddr = tcpip.Address{}
}
r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
}
defer r.Release()
outgoingEP, ok := e.protocol.getEndpointForNIC(r.NICID())
if !ok {
// The outgoing NIC went away.
sent.dropped.Increment()
return
}
// Because IP and ICMP are so closely intertwined, we need to handcraft our
// IP header to be able to follow RFC 792. The wording on page 13 is as
// follows:
// IP Fields:
// Addresses
// The address of the source in an echo message will be the
// destination of the echo reply message. To form an echo reply
// message, the source and destination addresses are simply reversed,
// the type code changed to 0, and the checksum recomputed.
//
// This was interpreted by early implementors to mean that all options must
// be copied from the echo request IP header to the echo reply IP header
// and this behaviour is still relied upon by some applications.
//
// Create a copy of the IP header we received, options and all, and change
// the fields we need to alter.
//
// We need to produce the entire packet in the data segment in order to
// use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the
// total length and the header checksum so we don't need to set those here.
//
// Take the base of the incoming request IP header but replace the options.
replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions))
replyIPHdrView := buffer.NewView(int(replyHeaderLength))
replyIPHdrView.Write(iph[:header.IPv4MinimumSize])
replyIPHdrView.Write(newOptions)
replyIPHdr := header.IPv4(replyIPHdrView.AsSlice())
replyIPHdr.SetHeaderLength(replyHeaderLength)
replyIPHdr.SetSourceAddress(r.LocalAddress())
replyIPHdr.SetDestinationAddress(r.RemoteAddress())
replyIPHdr.SetTTL(r.DefaultTTL())
replyIPHdr.SetTotalLength(uint16(len(replyIPHdr) + len(replyData.AsSlice())))
// Zero the checksum field before computing the checksum over the header.
replyIPHdr.SetChecksum(0)
replyIPHdr.SetChecksum(^replyIPHdr.CalculateChecksum())
// Turn the copied request into a reply: change the type and recompute the
// ICMP checksum over the whole copied message.
replyICMPHdr := header.ICMPv4(replyData.AsSlice())
replyICMPHdr.SetType(header.ICMPv4EchoReply)
replyICMPHdr.SetChecksum(0)
replyICMPHdr.SetChecksum(^checksum.Checksum(replyData.AsSlice(), 0))
replyBuf := buffer.MakeWithView(replyIPHdrView)
replyBuf.Append(replyData.Clone())
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Payload: replyBuf,
})
defer replyPkt.DecRef()
// Populate the network/transport headers in the packet buffer so the
// ICMP packet goes through IPTables.
if ok := parse.IPv4(replyPkt); !ok {
panic("expected to parse IPv4 header we just created")
}
if ok := parse.ICMPv4(replyPkt); !ok {
panic("expected to parse ICMPv4 header we just created")
}
if err := outgoingEP.writePacket(r, replyPkt); err != nil {
sent.dropped.Increment()
return
}
sent.echoReply.Increment()
case header.ICMPv4EchoReply:
received.echoReply.Increment()
// ICMP sockets expect the ICMP header to be present, so we don't consume
// the ICMP header.
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
received.dstUnreachable.Increment()
mtu := h.MTU()
code := h.Code()
// Map the ICMP code to the transport error handed to handleControl. Codes
// not listed below are counted but otherwise ignored.
switch code {
case header.ICMPv4NetUnreachable,
header.ICMPv4DestinationNetworkUnknown,
header.ICMPv4NetUnreachableForTos,
header.ICMPv4NetProhibited:
e.handleControl(&icmpv4DestinationNetUnreachableSockError{}, pkt)
case header.ICMPv4HostUnreachable,
header.ICMPv4HostProhibited,
header.ICMPv4AdminProhibited,
header.ICMPv4HostUnreachableForTos,
header.ICMPv4HostPrecedenceViolation,
header.ICMPv4PrecedenceCutInEffect:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
case header.ICMPv4PortUnreachable:
e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt)
case header.ICMPv4FragmentationNeeded:
networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize)
if err != nil {
// An MTU that cannot be converted is reported as 0.
networkMTU = 0
}
e.handleControl(&icmpv4FragmentationNeededSockError{mtu: networkMTU}, pkt)
case header.ICMPv4ProtoUnreachable:
e.handleControl(&icmpv4DestinationProtoUnreachableSockError{}, pkt)
case header.ICMPv4SourceRouteFailed:
e.handleControl(&icmpv4SourceRouteFailedSockError{}, pkt)
case header.ICMPv4SourceHostIsolated:
e.handleControl(&icmpv4SourceHostIsolatedSockError{}, pkt)
case header.ICMPv4DestinationHostUnknown:
e.handleControl(&icmpv4DestinationHostUnknownSockError{}, pkt)
}
// The remaining recognized types are only counted.
case header.ICMPv4SrcQuench:
received.srcQuench.Increment()
case header.ICMPv4Redirect:
received.redirect.Increment()
case header.ICMPv4TimeExceeded:
received.timeExceeded.Increment()
case header.ICMPv4ParamProblem:
received.paramProblem.Increment()
case header.ICMPv4Timestamp:
received.timestamp.Increment()
case header.ICMPv4TimestampReply:
received.timestampReply.Increment()
case header.ICMPv4InfoRequest:
received.infoRequest.Increment()
case header.ICMPv4InfoReply:
received.infoReply.Increment()
default:
received.invalid.Increment()
}
}
// ======= ICMP Error packet generation =========
// icmpReason is a marker interface for IPv4 specific ICMP errors.
//
// returnError switches on the concrete type of an icmpReason to choose the
// ICMP type and code of the message it generates.
type icmpReason interface {
// isICMPReason is a no-op marker method; implementing it is what allows a
// type to be passed as an icmpReason.
isICMPReason()
}
// icmpReasonNetworkProhibited is an error where the destination network is
// prohibited.
//
// returnError reports it as Destination Unreachable with code
// header.ICMPv4NetProhibited.
type icmpReasonNetworkProhibited struct{}
// isICMPReason marks icmpReasonNetworkProhibited as an icmpReason.
func (*icmpReasonNetworkProhibited) isICMPReason() {}
// icmpReasonHostProhibited is an error where the destination host is
// prohibited.
//
// returnError reports it as Destination Unreachable with code
// header.ICMPv4HostProhibited.
type icmpReasonHostProhibited struct{}
// isICMPReason marks icmpReasonHostProhibited as an icmpReason.
func (*icmpReasonHostProhibited) isICMPReason() {}
// icmpReasonAdministrativelyProhibited is an error where the destination is
// administratively prohibited.
//
// returnError reports it as Destination Unreachable with code
// header.ICMPv4AdminProhibited.
type icmpReasonAdministrativelyProhibited struct{}
// isICMPReason marks icmpReasonAdministrativelyProhibited as an icmpReason.
func (*icmpReasonAdministrativelyProhibited) isICMPReason() {}
// icmpReasonPortUnreachable is an error where the transport protocol has no
// listener and no alternative means to inform the sender.
//
// returnError reports it as Destination Unreachable with code
// header.ICMPv4PortUnreachable.
type icmpReasonPortUnreachable struct{}
// isICMPReason marks icmpReasonPortUnreachable as an icmpReason.
func (*icmpReasonPortUnreachable) isICMPReason() {}
// icmpReasonProtoUnreachable is an error where the transport protocol is
// not supported.
//
// returnError reports it as Destination Unreachable with code
// header.ICMPv4ProtoUnreachable.
type icmpReasonProtoUnreachable struct{}
// isICMPReason marks icmpReasonProtoUnreachable as an icmpReason.
func (*icmpReasonProtoUnreachable) isICMPReason() {}
// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in
// transit to its final destination, as per RFC 792 page 6, Time Exceeded
// Message.
//
// returnError reports it as Time Exceeded with code header.ICMPv4TTLExceeded.
type icmpReasonTTLExceeded struct{}
// isICMPReason marks icmpReasonTTLExceeded as an icmpReason.
func (*icmpReasonTTLExceeded) isICMPReason() {}
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
//
// returnError reports it as Time Exceeded with code
// header.ICMPv4ReassemblyTimeout.
type icmpReasonReassemblyTimeout struct{}
// isICMPReason marks icmpReasonReassemblyTimeout as an icmpReason.
func (*icmpReasonReassemblyTimeout) isICMPReason() {}
// icmpReasonParamProblem is an error to use to request a Parameter Problem
// message to be sent.
type icmpReasonParamProblem struct {
// pointer is the value returnError writes into the generated ICMP header via
// SetPointer.
pointer byte
}
// isICMPReason marks icmpReasonParamProblem as an icmpReason.
func (*icmpReasonParamProblem) isICMPReason() {}
// icmpReasonNetworkUnreachable is an error in which the network specified in
// the internet destination field of the datagram is unreachable.
//
// returnError reports it as Destination Unreachable with code
// header.ICMPv4NetUnreachable.
type icmpReasonNetworkUnreachable struct{}
// isICMPReason marks icmpReasonNetworkUnreachable as an icmpReason.
func (*icmpReasonNetworkUnreachable) isICMPReason() {}
// icmpReasonFragmentationNeeded is an error where a packet requires
// fragmentation while also having the Don't Fragment flag set, as per RFC 792
// page 3, Destination Unreachable Message.
//
// returnError reports it as Destination Unreachable with code
// header.ICMPv4FragmentationNeeded.
type icmpReasonFragmentationNeeded struct{}
// isICMPReason marks icmpReasonFragmentationNeeded as an icmpReason.
func (*icmpReasonFragmentationNeeded) isICMPReason() {}
// icmpReasonHostUnreachable is an error in which the host specified in the
// internet destination field of the datagram is unreachable.
//
// returnError reports it as Destination Unreachable with code
// header.ICMPv4HostUnreachable.
type icmpReasonHostUnreachable struct{}
// isICMPReason marks icmpReasonHostUnreachable as an icmpReason.
func (*icmpReasonHostUnreachable) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
//
// reason selects the ICMP type and code of the generated message; a reason
// type that is not handled below causes a panic. deliveredLocally reports
// whether pkt was delivered to this host; when it is false the zero
// tcpip.Address, rather than pkt's destination address, is passed to FindRoute
// as the local address.
//
// A nil return does not imply that a message was sent: nil is also returned
// when the function decides not to respond (the triggering packet was
// broadcast or multicast, had an unspecified source, was a malformed or
// possibly-error ICMP packet, the reply was rate limited, or the MTU cannot
// carry the minimum payload). Errors from route lookup, a missing endpoint for
// the route's NIC, and WritePacket failures are returned to the caller.
func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer, deliveredLocally bool) tcpip.Error {
origIPHdr := header.IPv4(pkt.NetworkHeader().Slice())
origIPHdrSrc := origIPHdr.SourceAddress()
origIPHdrDst := origIPHdr.DestinationAddress()
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
// =========
// 4.3.2.7 When Not to Send ICMP Errors
//
// An ICMP error message MUST NOT be sent as the result of receiving:
//
// o An ICMP error message, or
//
// o A packet which fails the IP header validation tests described in
// Section [5.2.2] (except where that section specifically permits
// the sending of an ICMP error message), or
//
// o A packet destined to an IP broadcast or IP multicast address, or
//
// o A packet sent as a Link Layer broadcast or multicast, or
//
// o Any fragment of a datagram other then the first fragment (i.e., a
// packet for which the fragment offset in the IP header is nonzero).
//
// TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in
// response to a non-initial fragment, but it currently can not happen.
if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any {
return nil
}
// If the packet wasn't delivered locally, do not use the packet's destination
// address as the response's source address as we should not own the
// destination address of a packet we are forwarding.
localAddr := origIPHdrDst
if !deliveredLocally {
localAddr = tcpip.Address{}
}
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
p.mu.Lock()
// We retrieve an endpoint using the newly constructed route's NICID rather
// than the packet's NICID. The packet's NICID corresponds to the NIC on
// which it arrived, which isn't necessarily the same as the NIC on which it
// will be transmitted. On the other hand, the route's NIC *is* guaranteed
// to be the NIC on which the packet will be transmitted.
netEP, ok := p.eps[route.NICID()]
p.mu.Unlock()
if !ok {
return &tcpip.ErrNotConnected{}
}
transportHeader := pkt.TransportHeader().Slice()
// Don't respond to icmp error packets.
if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) {
// We need to decide to explicitly name the packets we can respond to or
// the ones we can not respond to. The decision is somewhat arbitrary and
// if problems arise this could be reversed. It was judged less of a breach
// of protocol to not respond to unknown non-error packets than to respond
// to unknown error packets so we take the first approach.
if len(transportHeader) < header.ICMPv4MinimumSize {
// The packet is malformed.
return nil
}
switch header.ICMPv4(transportHeader).Type() {
case
header.ICMPv4EchoReply,
header.ICMPv4Echo,
header.ICMPv4Timestamp,
header.ICMPv4TimestampReply,
header.ICMPv4InfoRequest,
header.ICMPv4InfoReply:
// Known non-error types: fall out of the switch and keep going.
default:
// Assume any type we don't know about may be an error type.
return nil
}
}
sent := netEP.stats.icmp.packetsSent
// Map the reason to the ICMP type and code to send, the sent-packet counter
// to increment on success, and the pointer value for the ICMP header (zero
// for everything except Parameter Problem).
icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) {
switch reason := reason.(type) {
case *icmpReasonNetworkProhibited:
return header.ICMPv4DstUnreachable, header.ICMPv4NetProhibited, sent.dstUnreachable, 0
case *icmpReasonHostProhibited:
return header.ICMPv4DstUnreachable, header.ICMPv4HostProhibited, sent.dstUnreachable, 0
case *icmpReasonAdministrativelyProhibited:
return header.ICMPv4DstUnreachable, header.ICMPv4AdminProhibited, sent.dstUnreachable, 0
case *icmpReasonPortUnreachable:
return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0
case *icmpReasonProtoUnreachable:
return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0
case *icmpReasonNetworkUnreachable:
return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0
case *icmpReasonHostUnreachable:
return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0
case *icmpReasonFragmentationNeeded:
return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0
case *icmpReasonTTLExceeded:
return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0
case *icmpReasonReassemblyTimeout:
return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0
case *icmpReasonParamProblem:
return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
}()
// A reply that is not allowed is counted as rate limited and dropped without
// reporting an error to the caller.
if !p.allowICMPReply(icmpType, icmpCode) {
sent.rateLimited.Increment()
return nil
}
// Now work out how much of the triggering packet we should return.
// As per RFC 1812 Section 4.3.2.3
//
// ICMP datagram SHOULD contain as much of the original
// datagram as possible without the length of the ICMP
// datagram exceeding 576 bytes.
//
// NOTE: The above RFC referenced is different from the original
// recommendation in RFC 1122 and RFC 792 where it mentioned that at
// least 8 bytes of the payload must be included. Today linux and other
// systems implement the RFC 1812 definition and not the original
// requirement. We treat 8 bytes as the minimum but will try send more.
mtu := int(route.MTU())
const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize
if mtu > maxIPData {
mtu = maxIPData
}
available := mtu - header.ICMPv4MinimumSize
// Give up silently if there is no room for the original IP header plus the
// minimum error payload.
if available < len(origIPHdr)+header.ICMPv4MinimumErrorPayloadSize {
return nil
}
payloadLen := len(origIPHdr) + len(transportHeader) + pkt.Data().Size()
if payloadLen > available {
payloadLen = available
}
// The buffers used by pkt may be used elsewhere in the system.
// For example, an AF_RAW or AF_PACKET socket may use what the transport
// protocol considers an unreachable destination. Thus we deep copy pkt to
// prevent multiple ownership and SR errors. The new copy is a vectorized
// view with the entire incoming IP packet reassembled and truncated as
// required. This is now the payload of the new ICMP packet and no longer
// considered a packet in its own right.
payload := buffer.MakeWithView(pkt.NetworkHeader().View())
payload.Append(pkt.TransportHeader().View())
// If the headers leave room, append as much of the data as fits; otherwise
// the headers alone already reach or exceed payloadLen and are truncated to
// it.
if dataCap := payloadLen - int(payload.Size()); dataCap > 0 {
buf := pkt.Data().ToBuffer()
buf.Truncate(int64(dataCap))
payload.Merge(&buf)
} else {
payload.Truncate(int64(payloadLen))
}
icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize,
Payload: payload,
})
defer icmpPkt.DecRef()
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
icmpHdr.SetCode(icmpCode)
icmpHdr.SetType(icmpType)
icmpHdr.SetPointer(pointer)
// The checksum is set last so that it covers the header fields written above
// together with the payload.
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().Checksum()))
if err := route.WritePacket(
stack.NetworkHeaderParams{
Protocol: header.ICMPv4ProtocolNumber,
TTL: route.DefaultTTL(),
TOS: stack.DefaultTOS,
},
icmpPkt,
); err != nil {
sent.dropped.Increment()
return err
}
counter.Increment()
return nil
}
// OnReassemblyTimeout implements fragmentation.TimeoutHandler.
//
// It answers a reassembly timeout with a Time Exceeded Message, as per
// RFC 792:
//
//	If a host reassembling a fragmented datagram cannot complete the
//	reassembly due to missing fragments within its time limit it discards the
//	datagram, and it may send a time exceeded message.
//
//	If fragment zero is not available then no time exceeded need be sent at
//	all.
func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
	// A nil pkt means there is nothing to quote back to the sender, so no
	// message is generated.
	if pkt == nil {
		return
	}

	// Sending the message is best effort; the result is deliberately ignored.
	_ = p.returnError(&icmpReasonReassemblyTimeout{}, pkt, true /* deliveredLocally */)
}

View File

@@ -0,0 +1,654 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv4
import (
"fmt"
"math"
"time"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
// v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18
// See the note on protocolModeV1Compatibility for more detail. It is the
// delay used when scheduling igmpState.igmpV1Job.
v1RouterPresentTimeout = 400 * time.Second
// v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router
// will send General Queries with the Max Response Time set to 0. This MUST
// be interpreted as a value of 100 (10 seconds)."
//
// Note that the Max Response Time field is a value in units of deciseconds.
v1MaxRespTime = 10 * time.Second
// UnsolicitedReportIntervalMax is the maximum delay between sending
// unsolicited IGMP reports.
//
// Obtained from RFC 2236 Section 8.10, Page 19.
UnsolicitedReportIntervalMax = 10 * time.Second
)
// protocolMode is the IGMP operating mode stored in igmpState.mode. It decides
// which report type SendReport emits and whether SendLeave sends anything.
type protocolMode int
const (
// protocolModeV2OrV3 is the initial mode: IGMPv2 membership reports and
// Leave Group messages are sent.
protocolModeV2OrV3 protocolMode = iota
// protocolModeV1 is the mode selected by setVersion(IGMPVersion1): IGMPv1
// membership reports are sent and Leave Group messages are skipped.
protocolModeV1
// protocolModeV1Compatibility is for maintaining compatibility with IGMPv1
// Routers.
//
// Per RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
// Membership Reports in response to its Queries, and will not pay
// attention to Version 2 Membership Reports. Therefore, a state variable
// MUST be kept for each interface, describing whether the multicast
// Querier on that interface is running IGMPv1 or IGMPv2. This variable
// MUST be based upon whether or not an IGMPv1 query was heard in the last
// [Version 1 Router Present Timeout] seconds".
protocolModeV1Compatibility
)
// IGMPVersion is the forced version of IGMP.
type IGMPVersion int
const (
// The zero value is deliberately left unnamed so that the named versions
// start at 1.
_ IGMPVersion = iota
// IGMPVersion1 indicates IGMPv1.
IGMPVersion1
// IGMPVersion2 indicates IGMPv2. Note that IGMP may still fallback to V1
// compatibility mode as required by IGMPv2.
IGMPVersion2
// IGMPVersion3 indicates IGMPv3. Note that IGMP may still fallback to V2
// compatibility mode as required by IGMPv3.
IGMPVersion3
)
// IGMPEndpoint is a network endpoint that supports IGMP.
//
// It exposes the forced IGMP version of the endpoint for reading and writing.
type IGMPEndpoint interface {
// SetIGMPVersion sets the IGMP version.
//
// Returns the previous IGMP version.
SetIGMPVersion(IGMPVersion) IGMPVersion
// GetIGMPVersion returns the IGMP version.
GetIGMPVersion() IGMPVersion
}
// IGMPOptions holds options for IGMP.
//
// +stateify savable
type IGMPOptions struct {
// Enabled indicates whether IGMP will be performed.
//
// When enabled, IGMP may transmit IGMP report and leave messages when
// joining and leaving multicast groups respectively, and handle incoming
// IGMP packets.
//
// This field is ignored and is always assumed to be false for interfaces
// without neighbouring nodes (e.g. loopback). igmpState.Enabled additionally
// requires the owning endpoint itself to be enabled.
Enabled bool
}
// Compile-time check that *igmpState implements ip.MulticastGroupProtocol.
var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
// igmpState is the per-interface IGMP state.
//
// igmpState.init() MUST be called after creating an IGMP state.
//
// +stateify savable
type igmpState struct {
// The IPv4 endpoint this igmpState is for.
ep *endpoint
// genericMulticastProtocol holds the protocol-agnostic multicast group state
// that the methods of igmpState delegate to.
genericMulticastProtocol ip.GenericMulticastProtocolState
// mode is used to configure the version of IGMP to perform.
mode protocolMode
// igmpV1Job is scheduled when this interface receives an IGMPv1 style
// message (a query with a zero maximum response time); upon expiration mode
// is reset to protocolModeV2OrV3.
// igmpV1Job may not be nil once igmpState is initialized.
igmpV1Job *tcpip.Job
}
// Enabled implements ip.MulticastGroupProtocol.
//
// IGMP is performed only when the protocol option is switched on, the NIC is
// not a loopback interface, and the endpoint itself is enabled.
func (igmp *igmpState) Enabled() bool {
	ep := igmp.ep
	if !ep.protocol.options.IGMP.Enabled {
		return false
	}
	// Loopback interfaces have no neighbouring nodes, so there is nobody to
	// perform IGMP with.
	if ep.nic.IsLoopback() {
		return false
	}
	return ep.Enabled()
}
// SendReport implements ip.MulticastGroupProtocol.
//
// The report is addressed to the group itself. An IGMPv2 report is sent in
// protocolModeV2OrV3 and an IGMPv1 report in the two V1 modes.
//
// +checklocksread:igmp.ep.mu
func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
	var reportType header.IGMPType
	switch mode := igmp.mode; mode {
	case protocolModeV1, protocolModeV1Compatibility:
		reportType = header.IGMPv1MembershipReport
	case protocolModeV2OrV3:
		reportType = header.IGMPv2MembershipReport
	default:
		panic(fmt.Sprintf("unrecognized mode = %d", mode))
	}

	return igmp.writePacket(groupAddress, groupAddress, reportType)
}
// SendLeave implements ip.MulticastGroupProtocol.
//
// As per RFC 2236 Section 6, Page 8: "If the interface state says the
// Querier is running IGMPv1, this action SHOULD be skipped. If the flag
// saying we were the last host to report is cleared, this action MAY be
// skipped."
//
// +checklocksread:igmp.ep.mu
func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
	switch mode := igmp.mode; mode {
	case protocolModeV1, protocolModeV1Compatibility:
		// The querier is (or is assumed to be) IGMPv1: nothing is sent.
		return nil
	case protocolModeV2OrV3:
		// Leave messages go to the all-routers group, not to the group being
		// left.
		_, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
		return err
	default:
		panic(fmt.Sprintf("unrecognized mode = %d", mode))
	}
}
// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
//
// As per RFC 2236 section 6 page 10,
//
//	The all-systems group (address 224.0.0.1) is handled as a special
//	case. The host starts in Idle Member state for that group on every
//	interface, never transitions to another state, and never sends a
//	report for that group.
func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
	isAllSystems := groupAddress == header.IPv4AllSystems
	return !isAllSystems
}
// igmpv3ReportBuilder accumulates IGMPv3 group address records and sends them
// as one or more IGMPv3 membership reports.
//
// It implements ip.MulticastGroupProtocolV2ReportBuilder.
type igmpv3ReportBuilder struct {
// igmp is the state whose endpoint the reports are sent through.
igmp *igmpState
// records holds the records added so far via AddRecord.
records []header.IGMPv3ReportGroupAddressRecordSerializer
}
// AddRecord implements ip.MulticastGroupProtocolV2ReportBuilder.
//
// The protocol-agnostic record type is translated to its IGMPv3 equivalent
// and a record for groupAddress, with no sources, is queued for Send. An
// unknown record type causes a panic.
func (b *igmpv3ReportBuilder) AddRecord(genericRecordType ip.MulticastGroupProtocolV2ReportRecordType, groupAddress tcpip.Address) {
	// Sources is left at its zero value (nil).
	record := header.IGMPv3ReportGroupAddressRecordSerializer{
		GroupAddress: groupAddress,
	}

	switch genericRecordType {
	case ip.MulticastGroupProtocolV2ReportRecordModeIsInclude:
		record.RecordType = header.IGMPv3ReportRecordModeIsInclude
	case ip.MulticastGroupProtocolV2ReportRecordModeIsExclude:
		record.RecordType = header.IGMPv3ReportRecordModeIsExclude
	case ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode:
		record.RecordType = header.IGMPv3ReportRecordChangeToIncludeMode
	case ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode:
		record.RecordType = header.IGMPv3ReportRecordChangeToExcludeMode
	case ip.MulticastGroupProtocolV2ReportRecordAllowNewSources:
		record.RecordType = header.IGMPv3ReportRecordAllowNewSources
	case ip.MulticastGroupProtocolV2ReportRecordBlockOldSources:
		record.RecordType = header.IGMPv3ReportRecordBlockOldSources
	default:
		panic(fmt.Sprintf("unrecognied genericRecordType = %d", genericRecordType))
	}

	b.records = append(b.records, record)
}
// Send implements ip.MulticastGroupProtocolV2ReportBuilder.
//
// The queued records are split across as many IGMPv3 membership reports as
// needed to fit the endpoint's MTU (less the Router Alert option) and each
// report is sent to header.IGMPv3RoutersAddress.
//
// sent is true only if every report was written with a specified source
// address. err is the first error returned by any write; later reports are
// still attempted after a failure. With no queued records Send returns
// (false, nil) without sending anything.
//
// +checklocksread:b.igmp.ep.mu
func (b *igmpv3ReportBuilder) Send() (sent bool, err tcpip.Error) {
	if len(b.records) == 0 {
		return false, err
	}

	options := header.IPv4OptionsSerializer{
		&header.IPv4SerializableRouterAlertOption{},
	}
	mtu := int(b.igmp.ep.MTU()) - int(options.Length())

	allSentWithSpecifiedAddress := true
	var firstErr tcpip.Error
	for records := b.records; len(records) != 0; {
		// Greedily take as many leading records as fit in the space left.
		spaceLeft := mtu
		maxRecords := 0
		for ; maxRecords < len(records); maxRecords++ {
			tmp := spaceLeft - records[maxRecords].Length()
			if tmp > 0 {
				spaceLeft = tmp
			} else {
				break
			}
		}

		// If not even one record fits, no progress can be made. Stop instead
		// of looping forever on the same records, and report that not
		// everything was sent.
		if maxRecords == 0 {
			allSentWithSpecifiedAddress = false
			break
		}

		serializer := header.IGMPv3ReportSerializer{Records: records[:maxRecords]}
		records = records[maxRecords:]

		igmpView := buffer.NewViewSize(serializer.Length())
		serializer.SerializeInto(igmpView.AsSlice())
		if sentWithSpecifiedAddress, err := b.igmp.writePacketInner(
			igmpView,
			b.igmp.ep.stats.igmp.packetsSent.v3MembershipReport,
			options,
			header.IGMPv3RoutersAddress,
		); err != nil {
			// Remember only the first failure, but keep sending the remaining
			// reports.
			if firstErr == nil {
				firstErr = err
			}
			allSentWithSpecifiedAddress = false
		} else if !sentWithSpecifiedAddress {
			allSentWithSpecifiedAddress = false
		}
	}

	return allSentWithSpecifiedAddress, firstErr
}
// NewReportV2Builder implements ip.MulticastGroupProtocol.
//
// It returns a fresh igmpv3ReportBuilder, with no records, bound to igmp.
func (igmp *igmpState) NewReportV2Builder() ip.MulticastGroupProtocolV2ReportBuilder {
return &igmpv3ReportBuilder{igmp: igmp}
}
// V2QueryMaxRespCodeToV2Delay implements ip.MulticastGroupProtocol.
//
// code must fit in 8 bits; larger values cause a panic. The conversion itself
// is delegated to header.IGMPv3MaximumResponseDelay.
func (*igmpState) V2QueryMaxRespCodeToV2Delay(code uint16) time.Duration {
	if code <= math.MaxUint8 {
		return header.IGMPv3MaximumResponseDelay(uint8(code))
	}
	panic(fmt.Sprintf("got IGMPv3 MaxRespCode = %d, want <= %d", code, math.MaxUint8))
}
// V2QueryMaxRespCodeToV1Delay implements ip.MulticastGroupProtocol.
//
// It converts the Max Resp Code of a query into a delay for the older
// (IGMPv1/IGMPv2) protocol. The IGMP Max Response Time field is expressed in
// units of 1/10 second (RFC 2236 Section 2.2), so code is interpreted as a
// count of deciseconds, not milliseconds.
func (*igmpState) V2QueryMaxRespCodeToV1Delay(code uint16) time.Duration {
	const decisecond = time.Second / 10
	return time.Duration(code) * decisecond
}
// init sets up an igmpState struct, and is required to be called before using
// a new igmpState.
//
// It binds the state to ep, initializes the generic multicast protocol state
// with ep's lock and the stack's clock and insecure RNG, and creates
// igmpV1Job.
//
// Must only be called once for the lifetime of igmp.
func (igmp *igmpState) init(ep *endpoint) {
igmp.ep = ep
igmp.genericMulticastProtocol.Init(&ep.mu, ip.GenericMulticastProtocolOptions{
Rand: ep.protocol.stack.InsecureRNG(),
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
})
// As per RFC 2236 Page 9, "No IGMPv1 Router Present" is the initial state.
igmp.mode = protocolModeV2OrV3
// When the job fires, the mode is reset to the initial mode. The job is
// created with ep.mu as its lock.
igmp.igmpV1Job = tcpip.NewJob(ep.protocol.stack.Clock(), &ep.mu, func() {
igmp.mode = protocolModeV2OrV3
})
}
// isSourceIPValidLocked reports whether src is an acceptable source address
// for an IGMP message of the given type.
//
// Membership Queries are always accepted. Every other message type is
// accepted only if src lies in the subnet of one of the endpoint's primary
// addresses.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) isSourceIPValidLocked(src tcpip.Address, messageType header.IGMPType) bool {
	// RFC 2236 does not require the IGMP implementation to check the source IP
	// for Membership Query messages.
	if messageType == header.IGMPMembershipQuery {
		return true
	}

	// As per RFC 2236 section 10,
	//
	//   Ignore the Report if you cannot identify the source address of the
	//   packet as belonging to a subnet assigned to the interface on which the
	//   packet was received.
	//
	//   Ignore the Leave message if you cannot identify the source address of
	//   the packet as belonging to a subnet assigned to the interface on which
	//   the packet was received.
	//
	// Note: this rule applies to both V1 and V2 Membership Reports.
	onLocalSubnet := false
	igmp.ep.addressableEndpointState.ForEachPrimaryEndpoint(func(addrEP stack.AddressEndpoint) bool {
		subnet := addrEP.Subnet()
		if !subnet.Contains(src) {
			// Keep iterating over the remaining primary endpoints.
			return true
		}
		onLocalSubnet = true
		// A match was found; stop iterating.
		return false
	})
	return onLocalSubnet
}
// isPacketValidLocked reports whether pkt passes the IP-level checks required
// of an IGMP message of the given type: it must carry the Router Alert
// option, have the IGMP TTL, and have an acceptable source address.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) isPacketValidLocked(pkt *stack.PacketBuffer, messageType header.IGMPType, hasRouterAlertOption bool) bool {
	// We can safely assume that the IP header is valid if we got this far.
	ipHdr := header.IPv4(pkt.NetworkHeader().Slice())

	// As per RFC 2236 section 2,
	//
	//   All IGMP messages described in this document are sent with IP TTL 1, and
	//   contain the IP Router Alert option [RFC 2113] in their IP header.
	if hasRouterAlertOption && ipHdr.TTL() == header.IGMPTTL {
		return igmp.isSourceIPValidLocked(ipHdr.SourceAddress(), messageType)
	}
	return false
}
// handleIGMP handles an IGMP packet.
//
// The whole payload is pulled up, checked for minimum size and checksum, and
// then dispatched on the IGMP message type. Each recognized type bumps its
// received counter before the per-type validation runs, so a packet that fails
// validation is counted both under its type and as invalid.
// hasRouterAlertOption tells whether the IP header carried the Router Alert
// option; messages without it fail validation.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer, hasRouterAlertOption bool) {
received := igmp.ep.stats.igmp.packetsReceived
hdr, ok := pkt.Data().PullUp(pkt.Data().Size())
if !ok {
received.invalid.Increment()
return
}
h := header.IGMP(hdr)
if len(h) < header.IGMPMinimumSize {
received.invalid.Increment()
return
}
// As per RFC 1071 section 1.3,
//
// To check a checksum, the 1's complement sum is computed over the
// same set of octets, including the checksum field. If the result
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
if pkt.Data().Checksum() != 0xFFFF {
received.checksumErrors.Increment()
return
}
// isValid combines the per-message minimum size check with the IP-level
// checks in isPacketValidLocked.
isValid := func(minimumSize int) bool {
return len(hdr) >= minimumSize && igmp.isPacketValidLocked(pkt, h.Type(), hasRouterAlertOption)
}
switch h.Type() {
case header.IGMPMembershipQuery:
received.membershipQuery.Increment()
// A query at least as long as the IGMPv3 minimum is handled as an IGMPv3
// query; anything shorter is handled as an older-style query.
if len(h) >= header.IGMPv3QueryMinimumSize {
if isValid(header.IGMPv3QueryMinimumSize) {
igmp.handleMembershipQueryV3(header.IGMPv3Query(h))
} else {
received.invalid.Increment()
}
return
} else if !isValid(header.IGMPQueryMinimumSize) {
received.invalid.Increment()
return
}
igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime())
case header.IGMPv1MembershipReport:
received.v1MembershipReport.Increment()
if !isValid(header.IGMPReportMinimumSize) {
received.invalid.Increment()
return
}
igmp.handleMembershipReport(h.GroupAddress())
case header.IGMPv2MembershipReport:
received.v2MembershipReport.Increment()
if !isValid(header.IGMPReportMinimumSize) {
received.invalid.Increment()
return
}
igmp.handleMembershipReport(h.GroupAddress())
case header.IGMPLeaveGroup:
received.leaveGroup.Increment()
if !isValid(header.IGMPLeaveMessageMinimumSize) {
received.invalid.Increment()
return
}
// As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or
// Report, are ignored in all states"
//
// A valid Leave Group message is therefore only counted; no further
// action is taken.
default:
// As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should
// be silently ignored. New message types may be used by newer versions of
// IGMP, by multicast routing protocols, or other uses."
received.unrecognized.Increment()
}
}
// resetV1Present cancels the pending igmpV1Job and, if the state is currently
// in IGMPv1 compatibility mode, returns it to protocolModeV2OrV3. The other
// modes are left untouched.
func (igmp *igmpState) resetV1Present() {
	igmp.igmpV1Job.Cancel()

	switch mode := igmp.mode; mode {
	case protocolModeV1Compatibility:
		igmp.mode = protocolModeV2OrV3
	case protocolModeV2OrV3, protocolModeV1:
		// Nothing to reset.
	default:
		panic(fmt.Sprintf("unrecognized mode = %d", mode))
	}
}
// handleMembershipQuery handles a membership query.
//
// As per RFC 2236 Section 6, Page 10: a query with a zero maximum response
// time indicates that an IGMPv1 router is present. When IGMP is enabled, such
// a query (re)starts the IGMPv1-router-present timer and switches to IGMPv1
// compatibility mode, unless IGMPv1 has been forced, and the response time is
// replaced by v1MaxRespTime.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
	delay := maxRespTime
	if delay == 0 && igmp.Enabled() {
		switch mode := igmp.mode; mode {
		case protocolModeV1:
			// IGMPv1 is forced; there is no compatibility timer to manage.
		case protocolModeV2OrV3, protocolModeV1Compatibility:
			igmp.igmpV1Job.Cancel()
			igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
			igmp.mode = protocolModeV1Compatibility
		default:
			panic(fmt.Sprintf("unrecognized mode = %d", mode))
		}
		delay = v1MaxRespTime
	}

	igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, delay)
}
// handleMembershipQueryV3 handles an IGMPv3 membership query.
//
// The query is dropped silently if its source list cannot be read; otherwise
// its fields are forwarded to the generic multicast protocol state.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) handleMembershipQueryV3(igmpHdr header.IGMPv3Query) {
	sources, ok := igmpHdr.Sources()
	if !ok {
		return
	}

	var (
		group         = igmpHdr.GroupAddress()
		maxRespCode   = uint16(igmpHdr.MaximumResponseCode())
		robustness    = igmpHdr.QuerierRobustnessVariable()
		queryInterval = igmpHdr.QuerierQueryInterval()
	)
	igmp.genericMulticastProtocol.HandleQueryV2Locked(group, maxRespCode, sources, robustness, queryInterval)
}
// handleMembershipReport handles a membership report.
//
// The report is passed straight to the generic multicast protocol state for
// groupAddress.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
igmp.genericMulticastProtocol.HandleReportLocked(groupAddress)
}
// writePacket assembles and sends an IGMP packet.
//
// It builds a message of header.IGMPReportMinimumSize bytes with the given
// type and group address, picks the sent-packet counter matching igmpType
// (panicking on any type other than a v1/v2 report or a leave), and hands the
// message to writePacketInner with the Router Alert IP option for delivery to
// destAddress.
//
// +checklocksread:igmp.ep.mu
func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, tcpip.Error) {
	view := buffer.NewViewSize(header.IGMPReportMinimumSize)
	msg := header.IGMP(view.AsSlice())
	msg.SetType(igmpType)
	msg.SetGroupAddress(groupAddress)
	// The checksum is computed last, over the fields written above.
	msg.SetChecksum(header.IGMPCalculateChecksum(msg))

	packetsSent := igmp.ep.stats.igmp.packetsSent
	var sentCounter tcpip.MultiCounterStat
	switch igmpType {
	case header.IGMPLeaveGroup:
		sentCounter = packetsSent.leaveGroup
	case header.IGMPv2MembershipReport:
		sentCounter = packetsSent.v2MembershipReport
	case header.IGMPv1MembershipReport:
		sentCounter = packetsSent.v1MembershipReport
	default:
		panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
	}

	routerAlert := header.IPv4OptionsSerializer{
		&header.IPv4SerializableRouterAlertOption{},
	}
	return igmp.writePacketInner(view, sentCounter, routerAlert, destAddress)
}
// writePacketInner wraps buf in a packet, prepends an IPv4 header carrying
// options, and writes it to the link address derived from destAddress.
//
// It returns (false, nil) without sending if no outgoing primary address can
// be acquired for destAddress, (false, err) if the NIC write fails (the
// dropped counter is incremented), and (true, nil) on success (reportStat is
// incremented). A failure to add the IP header causes a panic.
//
// +checklocksread:igmp.ep.mu
func (igmp *igmpState) writePacketInner(buf *buffer.View, reportStat tcpip.MultiCounterStat, options header.IPv4OptionsSerializer, destAddress tcpip.Address) (bool, tcpip.Error) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()),
Payload: buffer.MakeWithView(buf),
})
defer pkt.DecRef()
addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, tcpip.Address{} /* srcHint */, false /* allowExpired */)
if addressEndpoint == nil {
return false, nil
}
// Only the address value is needed, so the reference on the address
// endpoint is dropped as soon as it has been copied out.
localAddr := addressEndpoint.AddressWithPrefix().Address
addressEndpoint.DecRef()
addressEndpoint = nil
if err := igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.IGMPProtocolNumber,
TTL: header.IGMPTTL,
TOS: stack.DefaultTOS,
}, options); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
sentStats := igmp.ep.stats.igmp.packetsSent
if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), pkt); err != nil {
sentStats.dropped.Increment()
return false, err
}
reportStat.Increment()
return true, nil
}
// joinGroup handles adding a new group to the membership map, setting up the
// IGMP state for the group, and sending and scheduling the required
// messages.
//
// All of the work is delegated to the generic multicast protocol state.
// joinGroup itself returns nothing, so no error is reported to the caller.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress)
}
// isInGroup returns true if the specified group has been joined locally.
//
// The answer comes from the generic multicast protocol state.
//
// +checklocksread:igmp.ep.mu
func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Leave Group message
// if required.
//
// It returns *tcpip.ErrBadLocalAddress if the group was not joined.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) tcpip.Error {
	// LeaveGroupLocked returns false only if the group was not joined.
	if left := igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress); !left {
		return &tcpip.ErrBadLocalAddress{}
	}
	return nil
}
// softLeaveAll leaves all groups from the perspective of IGMP, but remains
// joined locally.
//
// This is done by asking the generic multicast protocol state to make every
// group a non-member.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) softLeaveAll() {
igmp.genericMulticastProtocol.MakeAllNonMemberLocked()
}
// initializeAll attempts to initialize the IGMP state for each group that has
// been joined locally.
//
// The work is delegated to the generic multicast protocol state.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) initializeAll() {
igmp.genericMulticastProtocol.InitializeGroupsLocked()
}
// sendQueuedReports attempts to send any reports that are queued for sending.
//
// The work is delegated to the generic multicast protocol state.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) sendQueuedReports() {
igmp.genericMulticastProtocol.SendQueuedReportsLocked()
}
// setVersion sets the IGMP version and returns the version that was in effect
// before the call.
//
// Any pending IGMPv1-router-present job is cancelled first. IGMPv1 and IGMPv2
// both map to V1 of the generic multicast protocol; IGMPv3 maps to its V2. An
// unknown version causes a panic.
//
// +checklocks:igmp.ep.mu
func (igmp *igmpState) setVersion(v IGMPVersion) IGMPVersion {
	oldMode := igmp.mode
	igmp.igmpV1Job.Cancel()

	var (
		genericV1 bool
		newMode   protocolMode
	)
	switch v {
	case IGMPVersion1:
		genericV1, newMode = true, protocolModeV1
	case IGMPVersion2:
		genericV1, newMode = true, protocolModeV2OrV3
	case IGMPVersion3:
		genericV1, newMode = false, protocolModeV2OrV3
	default:
		panic(fmt.Sprintf("unrecognized version = %d", v))
	}

	oldGenericV1 := igmp.genericMulticastProtocol.SetV1ModeLocked(genericV1)
	igmp.mode = newMode
	return toIGMPVersion(oldMode, oldGenericV1)
}
// toIGMPVersion maps an internal protocol mode, together with whether the
// generic multicast protocol is in its V1 mode, to the public IGMPVersion.
//
// protocolModeV1 is always IGMPVersion1. The other two modes are IGMPVersion2
// when genericV1 is set and IGMPVersion3 otherwise. An unknown mode causes a
// panic.
func toIGMPVersion(mode protocolMode, genericV1 bool) IGMPVersion {
	switch {
	case mode == protocolModeV1:
		return IGMPVersion1
	case mode == protocolModeV2OrV3 || mode == protocolModeV1Compatibility:
		if !genericV1 {
			return IGMPVersion3
		}
		return IGMPVersion2
	}
	panic(fmt.Sprintf("unrecognized mode = %d", mode))
}
// getVersion returns the IGMP version.
//
// The version is derived by toIGMPVersion from the current mode and the
// generic multicast protocol's V1-mode flag.
//
// +checklocksread:igmp.ep.mu
func (igmp *igmpState) getVersion() IGMPVersion {
return toIGMPVersion(igmp.mode, igmp.genericMulticastProtocol.GetV1ModeLocked())
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,785 @@
// automatically generated by stateify.
package ipv4
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
// The methods below are stateify-generated save/restore hooks for
// icmpv4DestinationUnreachableSockError. The type has no saved fields, so
// StateSave only runs beforeSave and StateLoad does nothing.

// StateTypeName returns the type name string used for this type.
func (i *icmpv4DestinationUnreachableSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4DestinationUnreachableSockError"
}
// StateFields returns the (empty) list of saved field names.
func (i *icmpv4DestinationUnreachableSockError) StateFields() []string {
return []string{}
}
// beforeSave is a no-op hook invoked at the start of StateSave.
func (i *icmpv4DestinationUnreachableSockError) beforeSave() {}
// +checklocksignore
func (i *icmpv4DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
}
// afterLoad is a no-op hook.
func (i *icmpv4DestinationUnreachableSockError) afterLoad(context.Context) {}
// +checklocksignore
func (i *icmpv4DestinationUnreachableSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
}
// StateTypeName returns the package-qualified name of
// icmpv4DestinationHostUnreachableSockError.
func (i *icmpv4DestinationHostUnreachableSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4DestinationHostUnreachableSockError"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (i *icmpv4DestinationHostUnreachableSockError) StateFields() []string {
return []string{
"icmpv4DestinationUnreachableSockError",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (i *icmpv4DestinationHostUnreachableSockError) beforeSave() {}
// StateSave runs beforeSave and then saves the single field at index 0.
//
// +checklocksignore
func (i *icmpv4DestinationHostUnreachableSockError) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
}
// afterLoad is a no-op hook.
func (i *icmpv4DestinationHostUnreachableSockError) afterLoad(context.Context) {}
// StateLoad loads the single field from index 0.
//
// +checklocksignore
func (i *icmpv4DestinationHostUnreachableSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
}
// StateTypeName returns the package-qualified name of
// icmpv4DestinationNetUnreachableSockError.
func (i *icmpv4DestinationNetUnreachableSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4DestinationNetUnreachableSockError"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (i *icmpv4DestinationNetUnreachableSockError) StateFields() []string {
return []string{
"icmpv4DestinationUnreachableSockError",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (i *icmpv4DestinationNetUnreachableSockError) beforeSave() {}
// StateSave runs beforeSave and then saves the single field at index 0.
//
// +checklocksignore
func (i *icmpv4DestinationNetUnreachableSockError) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
}
// afterLoad is a no-op hook.
func (i *icmpv4DestinationNetUnreachableSockError) afterLoad(context.Context) {}
// StateLoad loads the single field from index 0.
//
// +checklocksignore
func (i *icmpv4DestinationNetUnreachableSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
}
// StateTypeName returns the package-qualified name of
// icmpv4DestinationPortUnreachableSockError.
func (i *icmpv4DestinationPortUnreachableSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4DestinationPortUnreachableSockError"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (i *icmpv4DestinationPortUnreachableSockError) StateFields() []string {
return []string{
"icmpv4DestinationUnreachableSockError",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (i *icmpv4DestinationPortUnreachableSockError) beforeSave() {}
// StateSave runs beforeSave and then saves the single field at index 0.
//
// +checklocksignore
func (i *icmpv4DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
}
// afterLoad is a no-op hook.
func (i *icmpv4DestinationPortUnreachableSockError) afterLoad(context.Context) {}
// StateLoad loads the single field from index 0.
//
// +checklocksignore
func (i *icmpv4DestinationPortUnreachableSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
}
// StateTypeName returns the package-qualified name of
// icmpv4DestinationProtoUnreachableSockError.
func (i *icmpv4DestinationProtoUnreachableSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4DestinationProtoUnreachableSockError"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (i *icmpv4DestinationProtoUnreachableSockError) StateFields() []string {
return []string{
"icmpv4DestinationUnreachableSockError",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (i *icmpv4DestinationProtoUnreachableSockError) beforeSave() {}
// StateSave runs beforeSave and then saves the single field at index 0.
//
// +checklocksignore
func (i *icmpv4DestinationProtoUnreachableSockError) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
}
// afterLoad is a no-op hook.
func (i *icmpv4DestinationProtoUnreachableSockError) afterLoad(context.Context) {}
// StateLoad loads the single field from index 0.
//
// +checklocksignore
func (i *icmpv4DestinationProtoUnreachableSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
}
// StateTypeName returns the package-qualified name of
// icmpv4SourceRouteFailedSockError.
func (i *icmpv4SourceRouteFailedSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4SourceRouteFailedSockError"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (i *icmpv4SourceRouteFailedSockError) StateFields() []string {
return []string{
"icmpv4DestinationUnreachableSockError",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (i *icmpv4SourceRouteFailedSockError) beforeSave() {}
// StateSave runs beforeSave and then saves the single field at index 0.
//
// +checklocksignore
func (i *icmpv4SourceRouteFailedSockError) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
}
// afterLoad is a no-op hook.
func (i *icmpv4SourceRouteFailedSockError) afterLoad(context.Context) {}
// StateLoad loads the single field from index 0.
//
// +checklocksignore
func (i *icmpv4SourceRouteFailedSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
}
// StateTypeName returns the package-qualified name of
// icmpv4SourceHostIsolatedSockError.
func (i *icmpv4SourceHostIsolatedSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4SourceHostIsolatedSockError"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (i *icmpv4SourceHostIsolatedSockError) StateFields() []string {
return []string{
"icmpv4DestinationUnreachableSockError",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (i *icmpv4SourceHostIsolatedSockError) beforeSave() {}
// StateSave runs beforeSave and then saves the single field at index 0.
//
// +checklocksignore
func (i *icmpv4SourceHostIsolatedSockError) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
}
// afterLoad is a no-op hook.
func (i *icmpv4SourceHostIsolatedSockError) afterLoad(context.Context) {}
// StateLoad loads the single field from index 0.
//
// +checklocksignore
func (i *icmpv4SourceHostIsolatedSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
}
// StateTypeName returns the package-qualified name of
// icmpv4DestinationHostUnknownSockError.
func (i *icmpv4DestinationHostUnknownSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4DestinationHostUnknownSockError"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (i *icmpv4DestinationHostUnknownSockError) StateFields() []string {
return []string{
"icmpv4DestinationUnreachableSockError",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (i *icmpv4DestinationHostUnknownSockError) beforeSave() {}
// StateSave runs beforeSave and then saves the single field at index 0.
//
// +checklocksignore
func (i *icmpv4DestinationHostUnknownSockError) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
}
// afterLoad is a no-op hook.
func (i *icmpv4DestinationHostUnknownSockError) afterLoad(context.Context) {}
// StateLoad loads the single field from index 0.
//
// +checklocksignore
func (i *icmpv4DestinationHostUnknownSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
}
// StateTypeName returns the package-qualified name of
// icmpv4FragmentationNeededSockError.
func (e *icmpv4FragmentationNeededSockError) StateTypeName() string {
return "pkg/tcpip/network/ipv4.icmpv4FragmentationNeededSockError"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (e *icmpv4FragmentationNeededSockError) StateFields() []string {
return []string{
"icmpv4DestinationUnreachableSockError",
"mtu",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (e *icmpv4FragmentationNeededSockError) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-1 in StateFields order.
//
// +checklocksignore
func (e *icmpv4FragmentationNeededSockError) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
stateSinkObject.Save(0, &e.icmpv4DestinationUnreachableSockError)
stateSinkObject.Save(1, &e.mtu)
}
// afterLoad is a no-op hook.
func (e *icmpv4FragmentationNeededSockError) afterLoad(context.Context) {}
// StateLoad loads fields 0-1 in StateFields order.
//
// +checklocksignore
func (e *icmpv4FragmentationNeededSockError) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.icmpv4DestinationUnreachableSockError)
stateSourceObject.Load(1, &e.mtu)
}
// StateTypeName returns the package-qualified name of IGMPOptions.
func (i *IGMPOptions) StateTypeName() string {
return "pkg/tcpip/network/ipv4.IGMPOptions"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (i *IGMPOptions) StateFields() []string {
return []string{
"Enabled",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (i *IGMPOptions) beforeSave() {}
// StateSave runs beforeSave and then saves the single field at index 0.
//
// +checklocksignore
func (i *IGMPOptions) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.Enabled)
}
// afterLoad is a no-op hook.
func (i *IGMPOptions) afterLoad(context.Context) {}
// StateLoad loads the single field from index 0.
//
// +checklocksignore
func (i *IGMPOptions) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.Enabled)
}
// StateTypeName returns the package-qualified name of igmpState.
func (igmp *igmpState) StateTypeName() string {
return "pkg/tcpip/network/ipv4.igmpState"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (igmp *igmpState) StateFields() []string {
return []string{
"ep",
"genericMulticastProtocol",
"mode",
"igmpV1Job",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (igmp *igmpState) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-3 in StateFields order.
//
// +checklocksignore
func (igmp *igmpState) StateSave(stateSinkObject state.Sink) {
igmp.beforeSave()
stateSinkObject.Save(0, &igmp.ep)
stateSinkObject.Save(1, &igmp.genericMulticastProtocol)
stateSinkObject.Save(2, &igmp.mode)
stateSinkObject.Save(3, &igmp.igmpV1Job)
}
// afterLoad is a no-op hook.
func (igmp *igmpState) afterLoad(context.Context) {}
// StateLoad loads fields 0-3 in StateFields order.
//
// +checklocksignore
func (igmp *igmpState) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &igmp.ep)
stateSourceObject.Load(1, &igmp.genericMulticastProtocol)
stateSourceObject.Load(2, &igmp.mode)
stateSourceObject.Load(3, &igmp.igmpV1Job)
}
// StateTypeName returns the package-qualified name of endpoint.
func (e *endpoint) StateTypeName() string {
return "pkg/tcpip/network/ipv4.endpoint"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (e *endpoint) StateFields() []string {
return []string{
"nic",
"dispatcher",
"protocol",
"stats",
"enabled",
"forwarding",
"multicastForwarding",
"addressableEndpointState",
"igmp",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (e *endpoint) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-8 in StateFields order.
//
// +checklocksignore
func (e *endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
stateSinkObject.Save(0, &e.nic)
stateSinkObject.Save(1, &e.dispatcher)
stateSinkObject.Save(2, &e.protocol)
stateSinkObject.Save(3, &e.stats)
stateSinkObject.Save(4, &e.enabled)
stateSinkObject.Save(5, &e.forwarding)
stateSinkObject.Save(6, &e.multicastForwarding)
stateSinkObject.Save(7, &e.addressableEndpointState)
stateSinkObject.Save(8, &e.igmp)
}
// afterLoad is a no-op hook.
func (e *endpoint) afterLoad(context.Context) {}
// StateLoad loads fields 0-8 in StateFields order.
//
// +checklocksignore
func (e *endpoint) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.nic)
stateSourceObject.Load(1, &e.dispatcher)
stateSourceObject.Load(2, &e.protocol)
stateSourceObject.Load(3, &e.stats)
stateSourceObject.Load(4, &e.enabled)
stateSourceObject.Load(5, &e.forwarding)
stateSourceObject.Load(6, &e.multicastForwarding)
stateSourceObject.Load(7, &e.addressableEndpointState)
stateSourceObject.Load(8, &e.igmp)
}
// StateTypeName returns the package-qualified name of protocol.
func (p *protocol) StateTypeName() string {
return "pkg/tcpip/network/ipv4.protocol"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (p *protocol) StateFields() []string {
return []string{
"stack",
"eps",
"icmpRateLimitedTypes",
"defaultTTL",
"ids",
"hashIV",
"idTS",
"fragmentation",
"options",
"multicastRouteTable",
"multicastForwardingDisp",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (p *protocol) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-10 in StateFields order.
//
// +checklocksignore
func (p *protocol) StateSave(stateSinkObject state.Sink) {
p.beforeSave()
stateSinkObject.Save(0, &p.stack)
stateSinkObject.Save(1, &p.eps)
stateSinkObject.Save(2, &p.icmpRateLimitedTypes)
stateSinkObject.Save(3, &p.defaultTTL)
stateSinkObject.Save(4, &p.ids)
stateSinkObject.Save(5, &p.hashIV)
stateSinkObject.Save(6, &p.idTS)
stateSinkObject.Save(7, &p.fragmentation)
stateSinkObject.Save(8, &p.options)
stateSinkObject.Save(9, &p.multicastRouteTable)
stateSinkObject.Save(10, &p.multicastForwardingDisp)
}
// afterLoad is a no-op hook.
func (p *protocol) afterLoad(context.Context) {}
// StateLoad loads fields 0-10 in StateFields order.
//
// +checklocksignore
func (p *protocol) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &p.stack)
stateSourceObject.Load(1, &p.eps)
stateSourceObject.Load(2, &p.icmpRateLimitedTypes)
stateSourceObject.Load(3, &p.defaultTTL)
stateSourceObject.Load(4, &p.ids)
stateSourceObject.Load(5, &p.hashIV)
stateSourceObject.Load(6, &p.idTS)
stateSourceObject.Load(7, &p.fragmentation)
stateSourceObject.Load(8, &p.options)
stateSourceObject.Load(9, &p.multicastRouteTable)
stateSourceObject.Load(10, &p.multicastForwardingDisp)
}
// StateTypeName returns the package-qualified name of Options.
func (o *Options) StateTypeName() string {
return "pkg/tcpip/network/ipv4.Options"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (o *Options) StateFields() []string {
return []string{
"IGMP",
"AllowExternalLoopbackTraffic",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (o *Options) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-1 in StateFields order.
//
// +checklocksignore
func (o *Options) StateSave(stateSinkObject state.Sink) {
o.beforeSave()
stateSinkObject.Save(0, &o.IGMP)
stateSinkObject.Save(1, &o.AllowExternalLoopbackTraffic)
}
// afterLoad is a no-op hook.
func (o *Options) afterLoad(context.Context) {}
// StateLoad loads fields 0-1 in StateFields order.
//
// +checklocksignore
func (o *Options) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &o.IGMP)
stateSourceObject.Load(1, &o.AllowExternalLoopbackTraffic)
}
// StateTypeName returns the package-qualified name of Stats.
func (s *Stats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.Stats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (s *Stats) StateFields() []string {
return []string{
"IP",
"IGMP",
"ICMP",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (s *Stats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-2 in StateFields order.
//
// +checklocksignore
func (s *Stats) StateSave(stateSinkObject state.Sink) {
s.beforeSave()
stateSinkObject.Save(0, &s.IP)
stateSinkObject.Save(1, &s.IGMP)
stateSinkObject.Save(2, &s.ICMP)
}
// afterLoad is a no-op hook.
func (s *Stats) afterLoad(context.Context) {}
// StateLoad loads fields 0-2 in StateFields order.
//
// +checklocksignore
func (s *Stats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &s.IP)
stateSourceObject.Load(1, &s.IGMP)
stateSourceObject.Load(2, &s.ICMP)
}
// StateTypeName returns the package-qualified name of sharedStats.
func (s *sharedStats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.sharedStats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (s *sharedStats) StateFields() []string {
return []string{
"localStats",
"ip",
"icmp",
"igmp",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (s *sharedStats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-3 in StateFields order.
//
// +checklocksignore
func (s *sharedStats) StateSave(stateSinkObject state.Sink) {
s.beforeSave()
stateSinkObject.Save(0, &s.localStats)
stateSinkObject.Save(1, &s.ip)
stateSinkObject.Save(2, &s.icmp)
stateSinkObject.Save(3, &s.igmp)
}
// afterLoad is a no-op hook.
func (s *sharedStats) afterLoad(context.Context) {}
// StateLoad loads fields 0-3 in StateFields order.
//
// +checklocksignore
func (s *sharedStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &s.localStats)
stateSourceObject.Load(1, &s.ip)
stateSourceObject.Load(2, &s.icmp)
stateSourceObject.Load(3, &s.igmp)
}
// StateTypeName returns the package-qualified name of
// multiCounterICMPv4PacketStats.
func (m *multiCounterICMPv4PacketStats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.multiCounterICMPv4PacketStats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (m *multiCounterICMPv4PacketStats) StateFields() []string {
return []string{
"echoRequest",
"echoReply",
"dstUnreachable",
"srcQuench",
"redirect",
"timeExceeded",
"paramProblem",
"timestamp",
"timestampReply",
"infoRequest",
"infoReply",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (m *multiCounterICMPv4PacketStats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-10 in StateFields order.
//
// +checklocksignore
func (m *multiCounterICMPv4PacketStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.echoRequest)
stateSinkObject.Save(1, &m.echoReply)
stateSinkObject.Save(2, &m.dstUnreachable)
stateSinkObject.Save(3, &m.srcQuench)
stateSinkObject.Save(4, &m.redirect)
stateSinkObject.Save(5, &m.timeExceeded)
stateSinkObject.Save(6, &m.paramProblem)
stateSinkObject.Save(7, &m.timestamp)
stateSinkObject.Save(8, &m.timestampReply)
stateSinkObject.Save(9, &m.infoRequest)
stateSinkObject.Save(10, &m.infoReply)
}
// afterLoad is a no-op hook.
func (m *multiCounterICMPv4PacketStats) afterLoad(context.Context) {}
// StateLoad loads fields 0-10 in StateFields order.
//
// +checklocksignore
func (m *multiCounterICMPv4PacketStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.echoRequest)
stateSourceObject.Load(1, &m.echoReply)
stateSourceObject.Load(2, &m.dstUnreachable)
stateSourceObject.Load(3, &m.srcQuench)
stateSourceObject.Load(4, &m.redirect)
stateSourceObject.Load(5, &m.timeExceeded)
stateSourceObject.Load(6, &m.paramProblem)
stateSourceObject.Load(7, &m.timestamp)
stateSourceObject.Load(8, &m.timestampReply)
stateSourceObject.Load(9, &m.infoRequest)
stateSourceObject.Load(10, &m.infoReply)
}
// StateTypeName returns the package-qualified name of
// multiCounterICMPv4SentPacketStats.
func (m *multiCounterICMPv4SentPacketStats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.multiCounterICMPv4SentPacketStats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (m *multiCounterICMPv4SentPacketStats) StateFields() []string {
return []string{
"multiCounterICMPv4PacketStats",
"dropped",
"rateLimited",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (m *multiCounterICMPv4SentPacketStats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-2 in StateFields order.
//
// +checklocksignore
func (m *multiCounterICMPv4SentPacketStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.multiCounterICMPv4PacketStats)
stateSinkObject.Save(1, &m.dropped)
stateSinkObject.Save(2, &m.rateLimited)
}
// afterLoad is a no-op hook.
func (m *multiCounterICMPv4SentPacketStats) afterLoad(context.Context) {}
// StateLoad loads fields 0-2 in StateFields order.
//
// +checklocksignore
func (m *multiCounterICMPv4SentPacketStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.multiCounterICMPv4PacketStats)
stateSourceObject.Load(1, &m.dropped)
stateSourceObject.Load(2, &m.rateLimited)
}
// StateTypeName returns the package-qualified name of
// multiCounterICMPv4ReceivedPacketStats.
func (m *multiCounterICMPv4ReceivedPacketStats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.multiCounterICMPv4ReceivedPacketStats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (m *multiCounterICMPv4ReceivedPacketStats) StateFields() []string {
return []string{
"multiCounterICMPv4PacketStats",
"invalid",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (m *multiCounterICMPv4ReceivedPacketStats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-1 in StateFields order.
//
// +checklocksignore
func (m *multiCounterICMPv4ReceivedPacketStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.multiCounterICMPv4PacketStats)
stateSinkObject.Save(1, &m.invalid)
}
// afterLoad is a no-op hook.
func (m *multiCounterICMPv4ReceivedPacketStats) afterLoad(context.Context) {}
// StateLoad loads fields 0-1 in StateFields order.
//
// +checklocksignore
func (m *multiCounterICMPv4ReceivedPacketStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.multiCounterICMPv4PacketStats)
stateSourceObject.Load(1, &m.invalid)
}
// StateTypeName returns the package-qualified name of
// multiCounterICMPv4Stats.
func (m *multiCounterICMPv4Stats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.multiCounterICMPv4Stats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (m *multiCounterICMPv4Stats) StateFields() []string {
return []string{
"packetsSent",
"packetsReceived",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (m *multiCounterICMPv4Stats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-1 in StateFields order.
//
// +checklocksignore
func (m *multiCounterICMPv4Stats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.packetsSent)
stateSinkObject.Save(1, &m.packetsReceived)
}
// afterLoad is a no-op hook.
func (m *multiCounterICMPv4Stats) afterLoad(context.Context) {}
// StateLoad loads fields 0-1 in StateFields order.
//
// +checklocksignore
func (m *multiCounterICMPv4Stats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.packetsSent)
stateSourceObject.Load(1, &m.packetsReceived)
}
// StateTypeName returns the package-qualified name of
// multiCounterIGMPPacketStats.
func (m *multiCounterIGMPPacketStats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.multiCounterIGMPPacketStats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (m *multiCounterIGMPPacketStats) StateFields() []string {
return []string{
"membershipQuery",
"v1MembershipReport",
"v2MembershipReport",
"v3MembershipReport",
"leaveGroup",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (m *multiCounterIGMPPacketStats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-4 in StateFields order.
//
// +checklocksignore
func (m *multiCounterIGMPPacketStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.membershipQuery)
stateSinkObject.Save(1, &m.v1MembershipReport)
stateSinkObject.Save(2, &m.v2MembershipReport)
stateSinkObject.Save(3, &m.v3MembershipReport)
stateSinkObject.Save(4, &m.leaveGroup)
}
// afterLoad is a no-op hook.
func (m *multiCounterIGMPPacketStats) afterLoad(context.Context) {}
// StateLoad loads fields 0-4 in StateFields order.
//
// +checklocksignore
func (m *multiCounterIGMPPacketStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.membershipQuery)
stateSourceObject.Load(1, &m.v1MembershipReport)
stateSourceObject.Load(2, &m.v2MembershipReport)
stateSourceObject.Load(3, &m.v3MembershipReport)
stateSourceObject.Load(4, &m.leaveGroup)
}
// StateTypeName returns the package-qualified name of
// multiCounterIGMPSentPacketStats.
func (m *multiCounterIGMPSentPacketStats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.multiCounterIGMPSentPacketStats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (m *multiCounterIGMPSentPacketStats) StateFields() []string {
return []string{
"multiCounterIGMPPacketStats",
"dropped",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (m *multiCounterIGMPSentPacketStats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-1 in StateFields order.
//
// +checklocksignore
func (m *multiCounterIGMPSentPacketStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.multiCounterIGMPPacketStats)
stateSinkObject.Save(1, &m.dropped)
}
// afterLoad is a no-op hook.
func (m *multiCounterIGMPSentPacketStats) afterLoad(context.Context) {}
// StateLoad loads fields 0-1 in StateFields order.
//
// +checklocksignore
func (m *multiCounterIGMPSentPacketStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.multiCounterIGMPPacketStats)
stateSourceObject.Load(1, &m.dropped)
}
// StateTypeName returns the package-qualified name of
// multiCounterIGMPReceivedPacketStats.
func (m *multiCounterIGMPReceivedPacketStats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.multiCounterIGMPReceivedPacketStats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (m *multiCounterIGMPReceivedPacketStats) StateFields() []string {
return []string{
"multiCounterIGMPPacketStats",
"invalid",
"checksumErrors",
"unrecognized",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (m *multiCounterIGMPReceivedPacketStats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-3 in StateFields order.
//
// +checklocksignore
func (m *multiCounterIGMPReceivedPacketStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.multiCounterIGMPPacketStats)
stateSinkObject.Save(1, &m.invalid)
stateSinkObject.Save(2, &m.checksumErrors)
stateSinkObject.Save(3, &m.unrecognized)
}
// afterLoad is a no-op hook.
func (m *multiCounterIGMPReceivedPacketStats) afterLoad(context.Context) {}
// StateLoad loads fields 0-3 in StateFields order.
//
// +checklocksignore
func (m *multiCounterIGMPReceivedPacketStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.multiCounterIGMPPacketStats)
stateSourceObject.Load(1, &m.invalid)
stateSourceObject.Load(2, &m.checksumErrors)
stateSourceObject.Load(3, &m.unrecognized)
}
// StateTypeName returns the package-qualified name of multiCounterIGMPStats.
func (m *multiCounterIGMPStats) StateTypeName() string {
return "pkg/tcpip/network/ipv4.multiCounterIGMPStats"
}
// StateFields names the saved fields; each position is the index used by
// StateSave and StateLoad.
func (m *multiCounterIGMPStats) StateFields() []string {
return []string{
"packetsSent",
"packetsReceived",
}
}
// beforeSave is a no-op hook run at the start of StateSave.
func (m *multiCounterIGMPStats) beforeSave() {}
// StateSave runs beforeSave and then saves fields 0-1 in StateFields order.
//
// +checklocksignore
func (m *multiCounterIGMPStats) StateSave(stateSinkObject state.Sink) {
m.beforeSave()
stateSinkObject.Save(0, &m.packetsSent)
stateSinkObject.Save(1, &m.packetsReceived)
}
// afterLoad is a no-op hook.
func (m *multiCounterIGMPStats) afterLoad(context.Context) {}
// StateLoad loads fields 0-1 in StateFields order.
//
// +checklocksignore
func (m *multiCounterIGMPStats) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &m.packetsSent)
stateSourceObject.Load(1, &m.packetsReceived)
}
// init registers a nil pointer of each type that has State* methods in this
// file with the state package, one state.Register call per type.
func init() {
state.Register((*icmpv4DestinationUnreachableSockError)(nil))
state.Register((*icmpv4DestinationHostUnreachableSockError)(nil))
state.Register((*icmpv4DestinationNetUnreachableSockError)(nil))
state.Register((*icmpv4DestinationPortUnreachableSockError)(nil))
state.Register((*icmpv4DestinationProtoUnreachableSockError)(nil))
state.Register((*icmpv4SourceRouteFailedSockError)(nil))
state.Register((*icmpv4SourceHostIsolatedSockError)(nil))
state.Register((*icmpv4DestinationHostUnknownSockError)(nil))
state.Register((*icmpv4FragmentationNeededSockError)(nil))
state.Register((*IGMPOptions)(nil))
state.Register((*igmpState)(nil))
state.Register((*endpoint)(nil))
state.Register((*protocol)(nil))
state.Register((*Options)(nil))
state.Register((*Stats)(nil))
state.Register((*sharedStats)(nil))
state.Register((*multiCounterICMPv4PacketStats)(nil))
state.Register((*multiCounterICMPv4SentPacketStats)(nil))
state.Register((*multiCounterICMPv4ReceivedPacketStats)(nil))
state.Register((*multiCounterICMPv4Stats)(nil))
state.Register((*multiCounterIGMPPacketStats)(nil))
state.Register((*multiCounterIGMPSentPacketStats)(nil))
state.Register((*multiCounterIGMPReceivedPacketStats)(nil))
state.Register((*multiCounterIGMPStats)(nil))
}

View File

@@ -0,0 +1,203 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv4
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// Compile-time check that *Stats satisfies stack.IPNetworkEndpointStats.
var _ stack.IPNetworkEndpointStats = (*Stats)(nil)
// Stats holds statistics related to the IPv4 protocol family.
//
// +stateify savable
type Stats struct {
// IP holds IPv4 statistics.
IP tcpip.IPStats
// IGMP holds IGMP statistics.
IGMP tcpip.IGMPStats
// ICMP holds ICMPv4 statistics.
ICMP tcpip.ICMPv4Stats
}
// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
//
// It has an empty body and serves only as a marker method.
func (*Stats) IsNetworkEndpointStats() {}
// IPStats implements stack.IPNetworkEndpointStats.
//
// It returns a pointer to the IP field of s rather than a copy.
func (s *Stats) IPStats() *tcpip.IPStats {
return &s.IP
}
// sharedStats groups a local Stats value with multi-counter views of the IP,
// ICMPv4 and IGMP statistics.
//
// NOTE(review): the multi-counters are presumably initialised elsewhere so
// that they update localStats together with a second set of counters; the
// construction site is not in this file, so confirm there.
//
// +stateify savable
type sharedStats struct {
localStats Stats
ip ip.MultiCounterIPStats
icmp multiCounterICMPv4Stats
igmp multiCounterIGMPStats
}
// LINT.IfChange(multiCounterICMPv4PacketStats)

// multiCounterICMPv4PacketStats holds one tcpip.MultiCounterStat for each
// counter of tcpip.ICMPv4PacketStats that init reads.
//
// +stateify savable
type multiCounterICMPv4PacketStats struct {
echoRequest tcpip.MultiCounterStat
echoReply tcpip.MultiCounterStat
dstUnreachable tcpip.MultiCounterStat
srcQuench tcpip.MultiCounterStat
redirect tcpip.MultiCounterStat
timeExceeded tcpip.MultiCounterStat
paramProblem tcpip.MultiCounterStat
timestamp tcpip.MultiCounterStat
timestampReply tcpip.MultiCounterStat
infoRequest tcpip.MultiCounterStat
infoReply tcpip.MultiCounterStat
}
// init calls Init on every field with the same-named counters from a and b.
//
// NOTE(review): presumably this makes each multi-counter update both a and b;
// confirm against tcpip.MultiCounterStat.
func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) {
m.echoRequest.Init(a.EchoRequest, b.EchoRequest)
m.echoReply.Init(a.EchoReply, b.EchoReply)
m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable)
m.srcQuench.Init(a.SrcQuench, b.SrcQuench)
m.redirect.Init(a.Redirect, b.Redirect)
m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded)
m.paramProblem.Init(a.ParamProblem, b.ParamProblem)
m.timestamp.Init(a.Timestamp, b.Timestamp)
m.timestampReply.Init(a.TimestampReply, b.TimestampReply)
m.infoRequest.Init(a.InfoRequest, b.InfoRequest)
m.infoReply.Init(a.InfoReply, b.InfoReply)
}
// LINT.ThenChange(../../tcpip.go:ICMPv4PacketStats)
// LINT.IfChange(multiCounterICMPv4SentPacketStats)

// multiCounterICMPv4SentPacketStats extends multiCounterICMPv4PacketStats
// with multi-counters for the Dropped and RateLimited counters of
// tcpip.ICMPv4SentPacketStats.
//
// +stateify savable
type multiCounterICMPv4SentPacketStats struct {
multiCounterICMPv4PacketStats
dropped tcpip.MultiCounterStat
rateLimited tcpip.MultiCounterStat
}
// init initialises the embedded per-type counters from the ICMPv4PacketStats
// of a and b, then the dropped and rateLimited counters.
func (m *multiCounterICMPv4SentPacketStats) init(a, b *tcpip.ICMPv4SentPacketStats) {
m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats)
m.dropped.Init(a.Dropped, b.Dropped)
m.rateLimited.Init(a.RateLimited, b.RateLimited)
}
// LINT.ThenChange(../../tcpip.go:ICMPv4SentPacketStats)
// LINT.IfChange(multiCounterICMPv4ReceivedPacketStats)

// multiCounterICMPv4ReceivedPacketStats extends multiCounterICMPv4PacketStats
// with a multi-counter for the Invalid counter of
// tcpip.ICMPv4ReceivedPacketStats.
//
// +stateify savable
type multiCounterICMPv4ReceivedPacketStats struct {
multiCounterICMPv4PacketStats
invalid tcpip.MultiCounterStat
}
// init initialises the embedded per-type counters from the ICMPv4PacketStats
// of a and b, then the invalid counter.
func (m *multiCounterICMPv4ReceivedPacketStats) init(a, b *tcpip.ICMPv4ReceivedPacketStats) {
m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats)
m.invalid.Init(a.Invalid, b.Invalid)
}
// LINT.ThenChange(../../tcpip.go:ICMPv4ReceivedPacketStats)
// LINT.IfChange(multiCounterICMPv4Stats)

// multiCounterICMPv4Stats pairs the sent and received ICMPv4 multi-counter
// sets, mirroring the PacketsSent and PacketsReceived fields of
// tcpip.ICMPv4Stats that init reads.
//
// +stateify savable
type multiCounterICMPv4Stats struct {
packetsSent multiCounterICMPv4SentPacketStats
packetsReceived multiCounterICMPv4ReceivedPacketStats
}
// init initialises packetsSent and packetsReceived from the matching fields
// of a and b.
func (m *multiCounterICMPv4Stats) init(a, b *tcpip.ICMPv4Stats) {
m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
}
// LINT.ThenChange(../../tcpip.go:ICMPv4Stats)
// LINT.IfChange(multiCounterIGMPPacketStats)

// multiCounterIGMPPacketStats holds one tcpip.MultiCounterStat for each
// counter of tcpip.IGMPPacketStats that init reads.
//
// +stateify savable
type multiCounterIGMPPacketStats struct {
membershipQuery tcpip.MultiCounterStat
v1MembershipReport tcpip.MultiCounterStat
v2MembershipReport tcpip.MultiCounterStat
v3MembershipReport tcpip.MultiCounterStat
leaveGroup tcpip.MultiCounterStat
}
// init calls Init on every field with the same-named counters from a and b.
//
// NOTE(review): presumably this makes each multi-counter update both a and b;
// confirm against tcpip.MultiCounterStat.
func (m *multiCounterIGMPPacketStats) init(a, b *tcpip.IGMPPacketStats) {
m.membershipQuery.Init(a.MembershipQuery, b.MembershipQuery)
m.v1MembershipReport.Init(a.V1MembershipReport, b.V1MembershipReport)
m.v2MembershipReport.Init(a.V2MembershipReport, b.V2MembershipReport)
m.v3MembershipReport.Init(a.V3MembershipReport, b.V3MembershipReport)
m.leaveGroup.Init(a.LeaveGroup, b.LeaveGroup)
}
// LINT.ThenChange(../../tcpip.go:IGMPPacketStats)
// LINT.IfChange(multiCounterIGMPSentPacketStats)

// multiCounterIGMPSentPacketStats extends multiCounterIGMPPacketStats with a
// multi-counter for the Dropped counter of tcpip.IGMPSentPacketStats.
//
// +stateify savable
type multiCounterIGMPSentPacketStats struct {
multiCounterIGMPPacketStats
dropped tcpip.MultiCounterStat
}
// init initialises the embedded per-type counters from the IGMPPacketStats of
// a and b, then the dropped counter.
func (m *multiCounterIGMPSentPacketStats) init(a, b *tcpip.IGMPSentPacketStats) {
m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats)
m.dropped.Init(a.Dropped, b.Dropped)
}
// LINT.ThenChange(../../tcpip.go:IGMPSentPacketStats)
// LINT.IfChange(multiCounterIGMPReceivedPacketStats)

// multiCounterIGMPReceivedPacketStats extends multiCounterIGMPPacketStats
// with multi-counters for the Invalid, ChecksumErrors and Unrecognized
// counters of tcpip.IGMPReceivedPacketStats.
//
// +stateify savable
type multiCounterIGMPReceivedPacketStats struct {
multiCounterIGMPPacketStats
invalid tcpip.MultiCounterStat
checksumErrors tcpip.MultiCounterStat
unrecognized tcpip.MultiCounterStat
}
// init initialises the embedded per-type counters from the IGMPPacketStats of
// a and b, then the invalid, checksumErrors and unrecognized counters.
func (m *multiCounterIGMPReceivedPacketStats) init(a, b *tcpip.IGMPReceivedPacketStats) {
m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats)
m.invalid.Init(a.Invalid, b.Invalid)
m.checksumErrors.Init(a.ChecksumErrors, b.ChecksumErrors)
m.unrecognized.Init(a.Unrecognized, b.Unrecognized)
}
// LINT.ThenChange(../../tcpip.go:IGMPReceivedPacketStats)
// LINT.IfChange(multiCounterIGMPStats)

// multiCounterIGMPStats pairs the sent and received IGMP multi-counter sets,
// mirroring the PacketsSent and PacketsReceived fields of tcpip.IGMPStats
// that init reads.
//
// +stateify savable
type multiCounterIGMPStats struct {
packetsSent multiCounterIGMPSentPacketStats
packetsReceived multiCounterIGMPReceivedPacketStats
}
// init initialises packetsSent and packetsReceived from the matching fields
// of a and b.
func (m *multiCounterIGMPStats) init(a, b *tcpip.IGMPStats) {
m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
}
// LINT.ThenChange(../../tcpip.go:IGMPStats)

View File

@@ -0,0 +1,40 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
package ipv6
import "strconv"
// _ is a compile-time guard. x has length 1, so each index expression below
// only compiles while the named constant still has the value subtracted from
// it (1, 2 and 3 respectively).
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[DHCPv6NoConfiguration-1]
_ = x[DHCPv6ManagedAddress-2]
_ = x[DHCPv6OtherConfigurations-3]
}
// _DHCPv6ConfigurationFromNDPRA_name is the three constant names concatenated
// without separators.
const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations"
// _DHCPv6ConfigurationFromNDPRA_index holds the byte offsets that delimit each
// name in _DHCPv6ConfigurationFromNDPRA_name: name k spans index[k]:index[k+1].
var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66}
// String returns the constant name for i, or the fallback form
// "DHCPv6ConfigurationFromNDPRA(N)" with N the decimal value of i when i
// does not correspond to an entry in the generated name table.
func (i DHCPv6ConfigurationFromNDPRA) String() string {
	// Table slots are zero-based; a value of 1 maps to the first slot.
	slot := i - 1
	slots := DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index) - 1)
	if slot >= 0 && slot < slots {
		start := _DHCPv6ConfigurationFromNDPRA_index[slot]
		end := _DHCPv6ConfigurationFromNDPRA_index[slot+1]
		return _DHCPv6ConfigurationFromNDPRA_name[start:end]
	}
	return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")"
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,478 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv6
import (
"fmt"
"time"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
// UnsolicitedReportIntervalMax is the maximum delay between sending
// unsolicited MLD reports. It is set to 10 seconds.
//
// Obtained from RFC 2710 Section 7.10.
UnsolicitedReportIntervalMax = 10 * time.Second
)
// MLDVersion is the forced version of MLD.
//
// Named values start at 1; the zero value is skipped in the constant block
// below and has no name.
type MLDVersion int
const (
_ MLDVersion = iota
// MLDVersion1 indicates MLDv1.
MLDVersion1
// MLDVersion2 indicates MLDv2. Note that MLD may still fallback to V1
// compatibility mode as required by MLDv2.
MLDVersion2
)
// MLDEndpoint is a network endpoint that supports MLD.
type MLDEndpoint interface {
// SetMLDVersion sets the MLD version.
//
// Returns the previous MLD version.
SetMLDVersion(MLDVersion) MLDVersion
// GetMLDVersion returns the MLD version.
GetMLDVersion() MLDVersion
}
// MLDOptions holds options for MLD.
//
// The zero value leaves MLD disabled.
//
// +stateify savable
type MLDOptions struct {
// Enabled indicates whether MLD will be performed.
//
// When enabled, MLD may transmit MLD report and done messages when
// joining and leaving multicast groups respectively, and handle incoming
// MLD packets.
//
// This field is ignored and is always assumed to be false for interfaces
// without neighbouring nodes (e.g. loopback).
Enabled bool
}
// Compile-time check that *mldState implements ip.MulticastGroupProtocol.
var _ ip.MulticastGroupProtocol = (*mldState)(nil)
// mldState is the per-interface MLD state.
//
// mldState.init MUST be called to initialize the MLD state.
//
// +stateify savable
type mldState struct {
// The IPv6 endpoint this mldState is for.
ep *endpoint
// genericMulticastProtocol is the ip.GenericMulticastProtocolState that the
// methods of mldState delegate to; it is set up by init.
genericMulticastProtocol ip.GenericMulticastProtocolState
}
// Enabled implements ip.MulticastGroupProtocol.
//
// MLD is performed only when it is switched on in the protocol options, the
// interface is not a loopback interface, and the endpoint itself is enabled.
func (mld *mldState) Enabled() bool {
	ep := mld.ep
	if !ep.protocol.options.MLD.Enabled {
		return false
	}
	// Loopback interfaces have no neighbouring nodes, so there is nobody to
	// perform MLD with.
	if ep.nic.IsLoopback() {
		return false
	}
	return ep.Enabled()
}
// SendReport implements ip.MulticastGroupProtocol.
//
// The MLDv1 report is addressed to the group being reported, so groupAddress
// is used both as the destination and as the multicast address carried in the
// message. The results are those of writePacket.
//
// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
}
// SendLeave implements ip.MulticastGroupProtocol.
//
// The Done message for groupAddress is sent to the link-local all-routers
// multicast address. Only the error from writePacket is returned; its boolean
// result is discarded.
//
// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
_, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
return err
}
// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
//
// It reports false for the link-scope all-nodes address and for any group
// whose multicast scope is reserved (0) or interface/node-local (1), and true
// for every other group.
func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
	// As per RFC 2710 section 5 page 10,
	//
	//   The link-scope all-nodes address (FF02::1) is handled as a special
	//   case. The node starts in Idle Listener state for that address on
	//   every interface, never transitions to another state, and never sends
	//   a Report or Done for that address.
	//
	//   MLD messages are never sent for multicast addresses whose scope is 0
	//   (reserved) or 1 (node-local).
	if groupAddress == header.IPv6AllNodesMulticastAddress {
		return false
	}

	switch header.V6MulticastScope(groupAddress) {
	case header.IPv6Reserved0MulticastScope, header.IPv6InterfaceLocalMulticastScope:
		return false
	default:
		return true
	}
}
// mldv2ReportBuilder accumulates MLDv2 multicast address records and sends
// them through the mldState it was created for.
type mldv2ReportBuilder struct {
// mld is the MLD state used to write the report packets.
mld *mldState
// records are the address records queued by AddRecord, in insertion order.
records []header.MLDv2ReportMulticastAddressRecordSerializer
}
// AddRecord implements ip.MulticastGroupProtocolV2ReportBuilder.
//
// It translates the protocol-independent record type into its MLDv2
// equivalent and queues a record for groupAddress with no source list. The
// record is only transmitted once Send is called.
//
// AddRecord panics if genericRecordType is not one of the known
// ip.MulticastGroupProtocolV2ReportRecord* values.
func (b *mldv2ReportBuilder) AddRecord(genericRecordType ip.MulticastGroupProtocolV2ReportRecordType, groupAddress tcpip.Address) {
	var recordType header.MLDv2ReportRecordType
	switch genericRecordType {
	case ip.MulticastGroupProtocolV2ReportRecordModeIsInclude:
		recordType = header.MLDv2ReportRecordModeIsInclude
	case ip.MulticastGroupProtocolV2ReportRecordModeIsExclude:
		recordType = header.MLDv2ReportRecordModeIsExclude
	case ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode:
		recordType = header.MLDv2ReportRecordChangeToIncludeMode
	case ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode:
		recordType = header.MLDv2ReportRecordChangeToExcludeMode
	case ip.MulticastGroupProtocolV2ReportRecordAllowNewSources:
		recordType = header.MLDv2ReportRecordAllowNewSources
	case ip.MulticastGroupProtocolV2ReportRecordBlockOldSources:
		recordType = header.MLDv2ReportRecordBlockOldSources
	default:
		panic(fmt.Sprintf("unrecognized genericRecordType = %d", genericRecordType))
	}

	b.records = append(b.records, header.MLDv2ReportMulticastAddressRecordSerializer{
		RecordType:       recordType,
		MulticastAddress: groupAddress,
		Sources:          nil,
	})
}
// Send implements ip.MulticastGroupProtocolV2ReportBuilder.
//
// The queued records are split across as many MLDv2 reports as needed so that
// each report's records fit in the endpoint's MTU less the hop-by-hop
// extension header. Sending continues after a failed write so the remaining
// records still get a chance to go out.
//
// sent is true only if every report was written successfully with a specified
// source address. err is the first error returned by a write, or nil if every
// write succeeded or there were no records to send.
func (b *mldv2ReportBuilder) Send() (sent bool, err tcpip.Error) {
	if len(b.records) == 0 {
		return false, nil
	}

	extensionHeaders := header.IPv6ExtHdrSerializer{
		header.IPv6SerializableHopByHopExtHdr{
			&header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
		},
	}
	mtu := int(b.mld.ep.MTU()) - extensionHeaders.Length()

	allSentWithSpecifiedAddress := true
	var firstErr tcpip.Error
	for records := b.records; len(records) != 0; {
		// Greedily take records while they still fit in the space budget.
		//
		// NOTE(review): if the very first record does not fit, maxRecords
		// stays 0 and this loop makes no progress — confirm the MTU can never
		// be that small here.
		spaceLeft := mtu
		maxRecords := 0
		for ; maxRecords < len(records); maxRecords++ {
			tmp := spaceLeft - records[maxRecords].Length()
			if tmp <= 0 {
				break
			}
			spaceLeft = tmp
		}

		serializer := header.MLDv2ReportSerializer{Records: records[:maxRecords]}
		records = records[maxRecords:]

		icmpView := buffer.NewViewSize(header.ICMPv6HeaderSize + serializer.Length())
		icmp := header.ICMPv6(icmpView.AsSlice())
		serializer.SerializeInto(icmp.MessageBody())

		sentWithSpecifiedAddress, writeErr := b.mld.writePacketInner(
			icmpView,
			header.ICMPv6MulticastListenerV2Report,
			b.mld.ep.stats.icmp.packetsSent.multicastListenerReportV2,
			extensionHeaders,
			header.MLDv2RoutersAddress,
		)
		if writeErr != nil {
			// Remember only the first failure; later ones are dropped.
			if firstErr == nil {
				firstErr = writeErr
			}
			allSentWithSpecifiedAddress = false
			continue
		}
		if !sentWithSpecifiedAddress {
			allSentWithSpecifiedAddress = false
		}
	}

	return allSentWithSpecifiedAddress, firstErr
}
// NewReportV2Builder implements ip.MulticastGroupProtocol.
//
// The returned builder starts with no records and sends through mld.
func (mld *mldState) NewReportV2Builder() ip.MulticastGroupProtocolV2ReportBuilder {
return &mldv2ReportBuilder{mld: mld}
}
// V2QueryMaxRespCodeToV2Delay implements ip.MulticastGroupProtocol.
//
// The conversion is delegated to header.MLDv2MaximumResponseDelay.
func (*mldState) V2QueryMaxRespCodeToV2Delay(code uint16) time.Duration {
return header.MLDv2MaximumResponseDelay(code)
}
// V2QueryMaxRespCodeToV1Delay implements ip.MulticastGroupProtocol.
//
// The code is interpreted directly as a number of milliseconds.
func (*mldState) V2QueryMaxRespCodeToV1Delay(code uint16) time.Duration {
return time.Duration(code) * time.Millisecond
}
// init sets up an mldState struct, and is required to be called before using
// a new mldState.
//
// It records ep and initializes the generic multicast protocol state with ep's
// RWMutex, the stack's insecure RNG and clock, mld itself as the protocol
// implementation, and UnsolicitedReportIntervalMax as the maximum unsolicited
// report delay.
//
// Must only be called once for the lifetime of mld.
func (mld *mldState) init(ep *endpoint) {
mld.ep = ep
mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
Rand: ep.protocol.stack.InsecureRNG(),
Clock: ep.protocol.stack.Clock(),
Protocol: mld,
MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
})
}
// handleMulticastListenerQuery handles a query message.
//
// The multicast address and maximum response delay carried in mldHdr are
// forwarded to the generic multicast protocol state.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
}
// handleMulticastListenerQueryV2 handles a V2 query message.
//
// The query is silently ignored when its source list cannot be obtained from
// mldHdr; otherwise its fields are forwarded to the generic multicast protocol
// state.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) handleMulticastListenerQueryV2(mldHdr header.MLDv2Query) {
	if sources, ok := mldHdr.Sources(); ok {
		mld.genericMulticastProtocol.HandleQueryV2Locked(
			mldHdr.MulticastAddress(),
			mldHdr.MaximumResponseCode(),
			sources,
			mldHdr.QuerierRobustnessVariable(),
			mldHdr.QuerierQueryInterval(),
		)
	}
}
// handleMulticastListenerReport handles a report message.
//
// The multicast address carried in mldHdr is forwarded to the generic
// multicast protocol state.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress())
}
// joinGroup handles joining a new group and sending and scheduling the required
// messages.
//
// joinGroup has no return value; the join is delegated entirely to the generic
// multicast protocol state.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
mld.genericMulticastProtocol.JoinGroupLocked(groupAddress)
}
// isInGroup returns true if the specified group has been joined locally.
//
// The answer comes from the generic multicast protocol state.
//
// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Done message, if
// required.
//
// It returns *tcpip.ErrBadLocalAddress when the group was not joined, and nil
// otherwise.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) leaveGroup(groupAddress tcpip.Address) tcpip.Error {
	// LeaveGroupLocked reports false only if the group was not joined.
	if left := mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress); !left {
		return &tcpip.ErrBadLocalAddress{}
	}
	return nil
}
// softLeaveAll leaves all groups from the perspective of MLD, but remains
// joined locally.
//
// It delegates to MakeAllNonMemberLocked on the generic multicast protocol
// state.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) softLeaveAll() {
mld.genericMulticastProtocol.MakeAllNonMemberLocked()
}
// initializeAll attempts to initialize the MLD state for each group that has
// been joined locally.
//
// It delegates to InitializeGroupsLocked on the generic multicast protocol
// state.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) initializeAll() {
mld.genericMulticastProtocol.InitializeGroupsLocked()
}
// sendQueuedReports attempts to send any reports that are queued for sending.
//
// It delegates to SendQueuedReportsLocked on the generic multicast protocol
// state.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) sendQueuedReports() {
mld.genericMulticastProtocol.SendQueuedReportsLocked()
}
// setVersion sets the MLD version and returns the version that was in effect
// before the call.
//
// It panics if v is neither MLDVersion1 nor MLDVersion2.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) setVersion(v MLDVersion) MLDVersion {
	// The generic protocol state only tracks whether V1 mode is forced.
	var forceV1 bool
	switch v {
	case MLDVersion1:
		forceV1 = true
	case MLDVersion2:
		forceV1 = false
	default:
		panic(fmt.Sprintf("unrecognized version = %d", v))
	}
	return toMLDVersion(mld.genericMulticastProtocol.SetV1ModeLocked(forceV1))
}
// toMLDVersion converts the generic protocol's V1-mode flag into an
// MLDVersion: MLDVersion1 when v1Generic is set, MLDVersion2 otherwise.
func toMLDVersion(v1Generic bool) MLDVersion {
	version := MLDVersion2
	if v1Generic {
		version = MLDVersion1
	}
	return version
}
// getVersion returns the MLD version.
//
// The generic protocol state's V1-mode flag is translated with toMLDVersion.
//
// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) getVersion() MLDVersion {
return toMLDVersion(mld.genericMulticastProtocol.GetV1ModeLocked())
}
// writePacket assembles and sends an MLD packet.
//
// It builds a minimum-size MLD message carrying groupAddress, picks the sent
// counter matching mldType and hands the message to writePacketInner with a
// hop-by-hop Router Alert (MLD) extension header, addressed to destAddress.
//
// mldType must be a Report or Done message type; any other value panics.
//
// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, tcpip.Error) {
	var counter tcpip.MultiCounterStat
	switch sent := mld.ep.stats.icmp.packetsSent; mldType {
	case header.ICMPv6MulticastListenerReport:
		counter = sent.multicastListenerReport
	case header.ICMPv6MulticastListenerDone:
		counter = sent.multicastListenerDone
	default:
		panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
	}

	view := buffer.NewViewSize(header.ICMPv6HeaderSize + header.MLDMinimumSize)
	body := header.ICMPv6(view.AsSlice()).MessageBody()
	header.MLD(body).SetMulticastAddress(groupAddress)

	routerAlert := header.IPv6ExtHdrSerializer{
		header.IPv6SerializableHopByHopExtHdr{
			&header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
		},
	}
	return mld.writePacketInner(view, mldType, counter, routerAlert, destAddress)
}
// writePacketInner finalizes and transmits an MLD message whose ICMPv6 payload
// has already been laid out in buf.
//
// It sets the ICMPv6 type and checksum, wraps buf in a packet buffer, adds the
// IPv6 header (with extensionHeaders and a hop limit of header.MLDHopLimit) and
// writes the packet to the Ethernet multicast address derived from
// destAddress. On success reportStat is incremented; if the write fails, the
// ICMP "dropped" sent counter is incremented instead and the error is
// returned.
//
// The returned bool reports whether the packet was sent with a source address
// other than the unspecified address.
//
// It panics if the IP header cannot be added.
//
// NOTE(review): this calls getLinkLocalAddressRLocked and its caller in this
// file requires mld.ep.mu to be read locked — presumably the same precondition
// applies here; confirm.
func (mld *mldState) writePacketInner(buf *buffer.View, mldType header.ICMPv6Type, reportStat tcpip.MultiCounterStat, extensionHeaders header.IPv6ExtHdrSerializer, destAddress tcpip.Address) (bool, tcpip.Error) {
icmp := header.ICMPv6(buf.AsSlice())
icmp.SetType(mldType)
// As per RFC 2710 section 3,
//
// All MLD messages described in this document are sent with a link-local
// IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert
// option in a Hop-by-Hop Options header.
//
// However, this would cause problems with Duplicate Address Detection with
// the first address as MLD snooping switches may not send multicast traffic
// that DAD depends on to the node performing DAD without the MLD report, as
// documented in RFC 4862 section 5.4.2:
//
// Note that when a node joins a multicast address, it typically sends a
// Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810]
// for the multicast address. In the case of Duplicate Address
// Detection, the MLD report message is required in order to inform MLD-
// snooping switches, rather than routers, to forward multicast packets.
// In the above description, the delay for joining the multicast address
// thus means delaying transmission of the corresponding MLD report
// message. Since the MLD specifications do not request a random delay
// to avoid race conditions, just delaying Neighbor Solicitation would
// cause congestion by the MLD report messages. The congestion would
// then prevent the MLD-snooping switches from working correctly and, as
// a result, prevent Duplicate Address Detection from working. The
// requirement to include the delay for the MLD report in this case
// avoids this scenario. [RFC3590] also talks about some interaction
// issues between Duplicate Address Detection and MLD, and specifies
// which source address should be used for the MLD report in this case.
//
// As per RFC 3590 section 4, we should still send out MLD reports with an
// unspecified source address if we do not have an assigned link-local
// address to use as the source address to ensure DAD works as expected on
// networks with MLD snooping switches:
//
// MLD Report and Done messages are sent with a link-local address as
// the IPv6 source address, if a valid address is available on the
// interface. If a valid link-local address is not available (e.g., one
// has not been configured), the message is sent with the unspecified
// address (::) as the IPv6 source address.
//
// Once a valid link-local address is available, a node SHOULD generate
// new MLD Report messages for all multicast addresses joined on the
// interface.
//
// Routers receiving an MLD Report or Done message with the unspecified
// address as the IPv6 source address MUST silently discard the packet
// without taking any action on the packets contents.
//
// Snooping switches MUST manage multicast forwarding state based on MLD
// Report and Done messages sent with the unspecified address as the
// IPv6 source address.
localAddress := mld.ep.getLinkLocalAddressRLocked()
// An empty address means no link-local address is available; fall back to
// the unspecified address as described above.
if localAddress.BitLen() == 0 {
localAddress = header.IPv6Any
}
// The checksum covers the source and destination addresses, so it is
// computed only once the source address has been settled.
icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmp,
Src: localAddress,
Dst: destAddress,
}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(),
Payload: buffer.MakeWithView(buf),
})
defer pkt.DecRef()
if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.MLDHopLimit,
}, extensionHeaders); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), pkt); err != nil {
mld.ep.stats.icmp.packetsSent.dropped.Increment()
return false, err
}
reportStat.Increment()
return localAddress != header.IPv6Any, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,145 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv6
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// Compile-time check that *Stats implements stack.IPNetworkEndpointStats.
var _ stack.IPNetworkEndpointStats = (*Stats)(nil)
// Stats holds statistics related to the IPv6 protocol family.
//
// +stateify savable
type Stats struct {
// IP holds IPv6 statistics.
IP tcpip.IPStats
// ICMP holds ICMPv6 statistics.
ICMP tcpip.ICMPv6Stats
// UnhandledRouterAdvertisements is the number of Router Advertisements that
// were observed but not handled.
UnhandledRouterAdvertisements *tcpip.StatCounter
}
// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
//
// It is a marker method with an empty body.
func (*Stats) IsNetworkEndpointStats() {}
// IPStats implements stack.IPNetworkEndpointStats.
//
// It returns a pointer to the IP field of s.
func (s *Stats) IPStats() *tcpip.IPStats {
return &s.IP
}
// sharedStats bundles a local Stats value with multi-counter views of the IP
// and ICMPv6 statistics.
//
// NOTE(review): presumably the multi-counters tie localStats to a second,
// wider set of counters — confirm where they are initialized.
//
// +stateify savable
type sharedStats struct {
localStats Stats
ip ip.MultiCounterIPStats
icmp multiCounterICMPv6Stats
}
// LINT.IfChange(multiCounterICMPv6PacketStats)
// multiCounterICMPv6PacketStats holds one tcpip.MultiCounterStat per ICMPv6
// message type, mirroring the fields of tcpip.ICMPv6PacketStats.
//
// +stateify savable
type multiCounterICMPv6PacketStats struct {
echoRequest tcpip.MultiCounterStat
echoReply tcpip.MultiCounterStat
dstUnreachable tcpip.MultiCounterStat
packetTooBig tcpip.MultiCounterStat
timeExceeded tcpip.MultiCounterStat
paramProblem tcpip.MultiCounterStat
routerSolicit tcpip.MultiCounterStat
routerAdvert tcpip.MultiCounterStat
neighborSolicit tcpip.MultiCounterStat
neighborAdvert tcpip.MultiCounterStat
redirectMsg tcpip.MultiCounterStat
multicastListenerQuery tcpip.MultiCounterStat
multicastListenerReport tcpip.MultiCounterStat
multicastListenerReportV2 tcpip.MultiCounterStat
multicastListenerDone tcpip.MultiCounterStat
}
// init initializes every multi-counter in m with the matching pair of counters
// taken from a and b.
func (m *multiCounterICMPv6PacketStats) init(a, b *tcpip.ICMPv6PacketStats) {
m.echoRequest.Init(a.EchoRequest, b.EchoRequest)
m.echoReply.Init(a.EchoReply, b.EchoReply)
m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable)
m.packetTooBig.Init(a.PacketTooBig, b.PacketTooBig)
m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded)
m.paramProblem.Init(a.ParamProblem, b.ParamProblem)
m.routerSolicit.Init(a.RouterSolicit, b.RouterSolicit)
m.routerAdvert.Init(a.RouterAdvert, b.RouterAdvert)
m.neighborSolicit.Init(a.NeighborSolicit, b.NeighborSolicit)
m.neighborAdvert.Init(a.NeighborAdvert, b.NeighborAdvert)
m.redirectMsg.Init(a.RedirectMsg, b.RedirectMsg)
m.multicastListenerQuery.Init(a.MulticastListenerQuery, b.MulticastListenerQuery)
m.multicastListenerReport.Init(a.MulticastListenerReport, b.MulticastListenerReport)
m.multicastListenerReportV2.Init(a.MulticastListenerReportV2, b.MulticastListenerReportV2)
m.multicastListenerDone.Init(a.MulticastListenerDone, b.MulticastListenerDone)
}
// LINT.ThenChange(../../tcpip.go:ICMPv6PacketStats)
// LINT.IfChange(multiCounterICMPv6SentPacketStats)
// multiCounterICMPv6SentPacketStats extends the per-message-type counters with
// counters for sent packets that were dropped or rate limited, mirroring
// tcpip.ICMPv6SentPacketStats.
//
// +stateify savable
type multiCounterICMPv6SentPacketStats struct {
multiCounterICMPv6PacketStats
dropped tcpip.MultiCounterStat
rateLimited tcpip.MultiCounterStat
}
// init initializes the embedded per-message-type counters and the dropped and
// rateLimited counters with the matching pairs of counters from a and b.
func (m *multiCounterICMPv6SentPacketStats) init(a, b *tcpip.ICMPv6SentPacketStats) {
m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats)
m.dropped.Init(a.Dropped, b.Dropped)
m.rateLimited.Init(a.RateLimited, b.RateLimited)
}
// LINT.ThenChange(../../tcpip.go:ICMPv6SentPacketStats)
// LINT.IfChange(multiCounterICMPv6ReceivedPacketStats)
// multiCounterICMPv6ReceivedPacketStats extends the per-message-type counters
// with counters for received packets that were unrecognized, invalid, or
// router-only packets dropped by a host, mirroring
// tcpip.ICMPv6ReceivedPacketStats.
//
// +stateify savable
type multiCounterICMPv6ReceivedPacketStats struct {
multiCounterICMPv6PacketStats
unrecognized tcpip.MultiCounterStat
invalid tcpip.MultiCounterStat
routerOnlyPacketsDroppedByHost tcpip.MultiCounterStat
}
// init initializes the embedded per-message-type counters and the
// receive-specific counters with the matching pairs of counters from a and b.
func (m *multiCounterICMPv6ReceivedPacketStats) init(a, b *tcpip.ICMPv6ReceivedPacketStats) {
m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats)
m.unrecognized.Init(a.Unrecognized, b.Unrecognized)
m.invalid.Init(a.Invalid, b.Invalid)
m.routerOnlyPacketsDroppedByHost.Init(a.RouterOnlyPacketsDroppedByHost, b.RouterOnlyPacketsDroppedByHost)
}
// LINT.ThenChange(../../tcpip.go:ICMPv6ReceivedPacketStats)
// LINT.IfChange(multiCounterICMPv6Stats)
// multiCounterICMPv6Stats groups the sent and received ICMPv6 multi-counters,
// mirroring tcpip.ICMPv6Stats.
//
// +stateify savable
type multiCounterICMPv6Stats struct {
packetsSent multiCounterICMPv6SentPacketStats
packetsReceived multiCounterICMPv6ReceivedPacketStats
}
// init initializes the sent and received counter groups from the matching
// groups of a and b.
func (m *multiCounterICMPv6Stats) init(a, b *tcpip.ICMPv6Stats) {
m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
}
// LINT.ThenChange(../../tcpip.go:ICMPv6Stats)

View File

@@ -0,0 +1,152 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ports
// Flags represents the type of port reservation.
//
// Bits converts a Flags value to its BitFlags form and BitFlags.ToFlags
// converts back.
//
// +stateify savable
type Flags struct {
// MostRecent represents UDP SO_REUSEADDR.
MostRecent bool
// LoadBalanced indicates SO_REUSEPORT.
//
// LoadBalanced takes precedence over MostRecent.
LoadBalanced bool
// TupleOnly represents TCP SO_REUSEADDR.
TupleOnly bool
}
// Bits converts the Flags to their bitset form: each field that is set
// contributes its corresponding *Flag bit to the result.
func (f Flags) Bits() BitFlags {
	mapping := [...]struct {
		set bool
		bit BitFlags
	}{
		{f.MostRecent, MostRecentFlag},
		{f.LoadBalanced, LoadBalancedFlag},
		{f.TupleOnly, TupleOnlyFlag},
	}

	bits := BitFlags(0)
	for _, m := range mapping {
		if m.set {
			bits |= m.bit
		}
	}
	return bits
}
// Effective returns the effective behavior of a flag config.
//
// LoadBalanced takes precedence over MostRecent, so MostRecent is cleared in
// the result whenever LoadBalanced is set. The receiver is not modified.
func (f Flags) Effective() Flags {
	if f.LoadBalanced {
		f.MostRecent = false
	}
	return f
}
// BitFlags is a bitset representation of Flags.
type BitFlags uint32
const (
// MostRecentFlag represents Flags.MostRecent.
MostRecentFlag BitFlags = 1 << iota
// LoadBalancedFlag represents Flags.LoadBalanced.
LoadBalancedFlag
// TupleOnlyFlag represents Flags.TupleOnly.
TupleOnlyFlag
// nextFlag is the value that the next added flag will have.
//
// It is used to calculate FlagMask below. It is also the number of
// valid flag states.
//
// With the three flags above, nextFlag is 8.
nextFlag
// FlagMask is a bit mask for BitFlags.
//
// With the three flags above, FlagMask is 7 (all flag bits set).
FlagMask = nextFlag - 1
// MultiBindFlagMask contains the flags that allow binding the same
// tuple multiple times.
MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
)
// ToFlags converts the bitset into a Flags struct: each field is true exactly
// when its corresponding bit is set in f.
func (f BitFlags) ToFlags() Flags {
	has := func(bit BitFlags) bool { return f&bit != 0 }
	return Flags{
		MostRecent:   has(MostRecentFlag),
		LoadBalanced: has(LoadBalancedFlag),
		TupleOnly:    has(TupleOnlyFlag),
	}
}
// FlagCounter counts how many references each flag combination has.
//
// +stateify savable
type FlagCounter struct {
// refs stores the count for each possible flag combination (0 through
// FlagMask), indexed by the combination's BitFlags value.
refs [nextFlag]int
}
// AddRef increases the reference count for a specific flag combination.
//
// flags is used directly as an index into the counter array, so it must not
// exceed FlagMask; a larger value panics with an out-of-range index.
func (c *FlagCounter) AddRef(flags BitFlags) {
c.refs[flags]++
}
// DropRef decreases the reference count for a specific flag combination.
//
// flags is used directly as an index into the counter array, so it must not
// exceed FlagMask; a larger value panics with an out-of-range index. The count
// is not checked against zero before it is decremented.
func (c *FlagCounter) DropRef(flags BitFlags) {
c.refs[flags]--
}
// TotalRefs calculates the total number of references for all flag
// combinations, i.e. the sum of every per-combination count.
func (c FlagCounter) TotalRefs() int {
	sum := 0
	for idx := 0; idx < len(c.refs); idx++ {
		sum += c.refs[idx]
	}
	return sum
}
// FlagRefs returns the number of references with all specified flags, i.e.
// the sum of the counts of every flag combination that is a superset of
// flags.
func (c FlagCounter) FlagRefs(flags BitFlags) int {
	matched := 0
	for combo := BitFlags(0); combo < nextFlag; combo++ {
		if combo&flags != flags {
			continue
		}
		matched += c.refs[combo]
	}
	return matched
}
// AllRefsHave returns if all references have all specified flags.
//
// Flag combinations with a count of zero or less are ignored, so an empty
// counter yields true.
func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
	for combo, count := range c.refs {
		if count <= 0 {
			continue
		}
		if BitFlags(combo)&flags != flags {
			return false
		}
	}
	return true
}
// SharedFlags returns the set of flags shared by all references.
//
// It starts from FlagMask and intersects it with every flag combination that
// has a positive count, so an empty counter yields FlagMask.
func (c FlagCounter) SharedFlags() BitFlags {
	shared := FlagMask
	for combo := range c.refs {
		if c.refs[combo] > 0 {
			shared &= BitFlags(combo)
		}
	}
	return shared
}

View File

@@ -0,0 +1,498 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ports provides PortManager that manages allocating, reserving and
// releasing ports.
package ports
import (
"math"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
// firstEphemeral is the first port of the ephemeral range that a new
// PortManager starts with.
firstEphemeral = 16000
)
var (
// anyIPAddress is the zero-value address; reservations and destinations
// compare against it to detect the wildcard ("any") address.
anyIPAddress = tcpip.Address{}
)
// Reservation describes a port reservation.
type Reservation struct {
// Networks is a list of network protocols to which the reservation
// applies. Can be IPv4, IPv6, or both.
Networks []tcpip.NetworkProtocolNumber
// Transport is the transport protocol to which the reservation applies.
Transport tcpip.TransportProtocolNumber
// Addr is the address of the local endpoint.
Addr tcpip.Address
// Port is the local port number.
//
// ReservePort treats a zero Port as a request to pick an ephemeral port.
Port uint16
// Flags describe features of the reservation.
Flags Flags
// BindToDevice is the NIC to which the reservation applies.
//
// Zero means the reservation is not bound to a specific device.
BindToDevice tcpip.NICID
// Dest is the destination address.
Dest tcpip.FullAddress
}
// dst returns the reservation's destination address and port as a
// destination value suitable for use as a map key.
func (rs Reservation) dst() destination {
	return destination{
		addr: rs.Dest.Addr,
		port: rs.Dest.Port,
	}
}
// portDescriptor identifies a port by network protocol, transport protocol and
// port number. It is the key of PortManager.allocatedPorts.
//
// +stateify savable
type portDescriptor struct {
network tcpip.NetworkProtocolNumber
transport tcpip.TransportProtocolNumber
port uint16
}
// destination is the remote address and port of a reservation. It is the key
// of destToCounter.
//
// +stateify savable
type destination struct {
addr tcpip.Address
port uint16
}
// destToCounter maps each destination to the FlagCounter that represents
// endpoints to that destination.
//
// destToCounter is never empty. When it has no elements, it is removed from
// the map that references it.
//
// The FlagCounter is stored by value, so an entry must be written back after
// it is modified.
type destToCounter map[destination]FlagCounter
// intersectionFlags calculates the intersection of flag bit values which affect
// the specified destination.
//
// An entry with exactly the reservation's destination contributes all of its
// shared flags. Any other entry contributes only its TupleOnly bit, and only
// when either its address or the reservation's destination address is the
// wildcard address.
//
// If no destinations are present, all flag values are returned as there are no
// entries to limit possible flag values of a new entry.
//
// In addition to the intersection, the number of intersecting refs is
// returned.
func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
	var (
		shared  = FlagMask
		matches int
		target  = res.dst()
	)
	for dest, counter := range dc {
		switch {
		case dest == target:
			shared &= counter.SharedFlags()
			matches++
		case dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress:
			// Wildcard destinations affect all destinations, but only for
			// TupleOnly: every other bit is forced on so that it passes
			// through the intersection untouched.
			shared &= ^TupleOnlyFlag | counter.SharedFlags()
			matches++
		}
	}
	return shared, matches
}
// deviceToDest maps NICs to destinations for which there are port reservations.
//
// Key 0 holds the reservations that are not bound to a specific device.
//
// deviceToDest is never empty. When it has no elements, it is removed from the
// map that references it.
type deviceToDest map[tcpip.NICID]destToCounter
// isAvailable checks whether binding is possible by device. If not binding to
// a device, check against all FlagCounters. If binding to a specific device,
// check against the unspecified device and the provided device.
//
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
//
// portSpecified is false when the port is being chosen automatically. In that
// case a TCP reservation is refused as soon as there is a relevant existing
// reservation, even if the reuse flags would otherwise allow sharing.
func (dd deviceToDest) isAvailable(res Reservation, portSpecified bool) bool {
flagBits := res.Flags.Bits()
if res.BindToDevice == 0 {
// Not bound to a device: the new reservation must be compatible with
// the reservations on every device.
intersection := FlagMask
for _, dest := range dd {
flags, count := dest.intersectionFlags(res)
if count == 0 {
continue
}
intersection &= flags
if intersection&flagBits == 0 {
// Can't bind because the (addr,port) was
// previously bound without reuse.
return false
}
}
// This check is not conditional on any destination having matched
// above.
if !portSpecified && res.Transport == header.TCPProtocolNumber {
return false
}
return true
}
// Bound to a device: check the unbound (device 0) reservations first, then
// the reservations on the requested device. The intersection carries over
// from the first check into the second.
intersection := FlagMask
if dests, ok := dd[0]; ok {
var count int
intersection, count = dests.intersectionFlags(res)
if count > 0 {
if intersection&flagBits == 0 {
return false
}
if !portSpecified && res.Transport == header.TCPProtocolNumber {
return false
}
}
}
if dests, ok := dd[res.BindToDevice]; ok {
flags, count := dests.intersectionFlags(res)
intersection &= flags
if count > 0 {
if intersection&flagBits == 0 {
return false
}
if !portSpecified && res.Transport == header.TCPProtocolNumber {
return false
}
}
}
return true
}
// addrToDevice maps IP addresses to NICs that have port reservations.
//
// The wildcard ("any") address is stored under the zero-value address key.
type addrToDevice map[tcpip.Address]deviceToDest
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
func (ad addrToDevice) isAvailable(res Reservation, portSpecified bool) bool {
	if res.Addr != anyIPAddress {
		// A specific address can only conflict with reservations on the
		// "any" address or on that same address.
		for _, addr := range [...]tcpip.Address{anyIPAddress, res.Addr} {
			devices, found := ad[addr]
			if found && !devices.isAvailable(res, portSpecified) {
				return false
			}
		}
		return true
	}

	// Binding to the "any" address: there must be no conflict with any
	// address that already has reservations.
	for _, devices := range ad {
		if !devices.isAvailable(res, portSpecified) {
			return false
		}
	}
	return true
}
// PortManager manages allocating, reserving and releasing ports.
//
// Use NewPortManager to create one: allocatedPorts must be a non-nil map
// before ports can be reserved.
//
// +stateify savable
type PortManager struct {
// mu protects allocatedPorts.
// LOCK ORDERING: mu > ephemeralMu.
mu sync.RWMutex `state:"nosave"`
// allocatedPorts is a nesting of maps that ultimately map Reservations
// to FlagCounters describing whether the Reservation is valid and can
// be reused.
allocatedPorts map[portDescriptor]addrToDevice
// ephemeralMu protects firstEphemeral and numEphemeral.
ephemeralMu sync.RWMutex `state:"nosave"`
// firstEphemeral is the first port of the ephemeral port range.
firstEphemeral uint16
// numEphemeral is the number of ports in the ephemeral port range.
numEphemeral uint16
}
// NewPortManager creates new PortManager.
//
// The manager starts with no reservations and an ephemeral range that runs
// from firstEphemeral (16000) through math.MaxUint16 inclusive.
func NewPortManager() *PortManager {
return &PortManager{
allocatedPorts: make(map[portDescriptor]addrToDevice),
firstEphemeral: firstEphemeral,
numEphemeral: math.MaxUint16 - firstEphemeral + 1,
}
}
// PortTester indicates whether the passed in port is suitable. Returning an
// error causes the function to which the PortTester is passed to return that
// error.
//
// Returning good == false with a nil error rejects only that port.
type PortTester func(port uint16) (good bool, err tcpip.Error)
// PickEphemeralPort randomly chooses a starting point and iterates over all
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
//
// The ephemeral range is read under ephemeralMu, which is released before any
// port is tested, so testPort runs without that lock held.
func (pm *PortManager) PickEphemeralPort(rng rand.RNG, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock()
return pickEphemeralPort(rng.Uint32(), firstEphemeral, numEphemeral, testPort)
}
// pickEphemeralPort starts at the offset specified from the first port and
// iterates over the number of ports specified by count, letting the caller
// decide whether a given port is suitable for its needs. It stops when a port
// is found or an error occurs.
//
// offset may be any value; it is reduced modulo count before use so that the
// scan visits each of the count ports exactly once, wrapping around the end of
// the range.
//
// It returns the first port for which testPort reports true, the first error
// returned by testPort, or *tcpip.ErrNoPortAvailable if no port was accepted
// (including when count is zero).
func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
	if count == 0 {
		return 0, &tcpip.ErrNoPortAvailable{}
	}

	// This implements Algorithm 1 as per RFC 6056 Section 3.3.1.
	//
	// Reduce offset up front: adding i to a raw offset close to
	// math.MaxUint32 would wrap around and break the sequence, retesting some
	// ports and skipping others.
	n := uint32(count)
	start := offset % n
	for i := uint32(0); i < n; i++ {
		candidate := uint16(uint32(first) + (start+i)%n)
		ok, testErr := testPort(candidate)
		if testErr != nil {
			return 0, testErr
		}
		if ok {
			return candidate, nil
		}
	}

	return 0, &tcpip.ErrNoPortAvailable{}
}
// ReservePort marks a port/IP combination as reserved so that it cannot be
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
//
// An optional PortTester can be passed in which if provided will be used to
// test if the picked port can be used. The function should return true if the
// port is safe to use, false otherwise.
//
// When a specific port is requested and it is either already in use or
// rejected by testPort, *tcpip.ErrPortInUse is returned. An error returned by
// testPort is passed through. A reservation that testPort rejects or fails on
// is released again.
func (pm *PortManager) ReservePort(rng rand.RNG, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
	pm.mu.Lock()
	defer pm.mu.Unlock()

	// vet runs the optional tester against a port that has just been
	// reserved, and undoes the reservation unless the tester accepts it.
	vet := func(p uint16) (bool, tcpip.Error) {
		if testPort == nil {
			return true, nil
		}
		ok, testErr := testPort(p)
		if testErr == nil && ok {
			return true, nil
		}
		pm.releasePortLocked(res)
		return false, testErr
	}

	if res.Port == 0 {
		// A port wasn't specified, so try to find one.
		return pm.PickEphemeralPort(rng, func(candidate uint16) (bool, tcpip.Error) {
			res.Port = candidate
			if !pm.reserveSpecificPortLocked(res, false /* portSpecified */) {
				return false, nil
			}
			return vet(candidate)
		})
	}

	// A port is specified, so just try to reserve it for all network
	// protocols.
	if !pm.reserveSpecificPortLocked(res, true /* portSpecified */) {
		return 0, &tcpip.ErrPortInUse{}
	}
	ok, testErr := vet(res.Port)
	switch {
	case testErr != nil:
		return 0, testErr
	case !ok:
		return 0, &tcpip.ErrPortInUse{}
	}
	return res.Port, nil
}
// reserveSpecificPortLocked attempts to reserve res.Port on every network
// protocol listed in res.Networks. It reports false, without modifying any
// state, if an existing allocation on any of those protocols says the port is
// unavailable; otherwise it adds a reference for res on each protocol and
// reports true.
func (pm *PortManager) reserveSpecificPortLocked(res Reservation, portSpecified bool) bool {
	// First pass: check availability everywhere before touching anything.
	for _, netProto := range res.Networks {
		key := portDescriptor{netProto, res.Transport, res.Port}
		addrs, found := pm.allocatedPorts[key]
		if found && !addrs.isAvailable(res, portSpecified) {
			return false
		}
	}

	// Second pass: record the reservation, creating the intermediate maps on
	// demand.
	bits, dest := res.Flags.Bits(), res.dst()
	for _, netProto := range res.Networks {
		key := portDescriptor{netProto, res.Transport, res.Port}
		byAddr, found := pm.allocatedPorts[key]
		if !found {
			byAddr = make(addrToDevice)
			pm.allocatedPorts[key] = byAddr
		}
		byDev, found := byAddr[res.Addr]
		if !found {
			byDev = make(deviceToDest)
			byAddr[res.Addr] = byDev
		}
		byDest := byDev[res.BindToDevice]
		if byDest == nil {
			byDest = make(destToCounter)
		}
		// The counter is read, updated and written back into the map.
		refs := byDest[dest]
		refs.AddRef(bits)
		byDest[dest] = refs
		byDev[res.BindToDevice] = byDest
	}
	return true
}
// ReserveTuple adds a reservation for the tuple described by res on every
// network protocol in res.Networks. It reports false, leaving the reference
// counts as they were, if on any protocol the tuple is already referenced and
// shares none of res's flags.
func (pm *PortManager) ReserveTuple(res Reservation) bool {
	bits, dest := res.Flags.Bits(), res.dst()

	pm.mu.Lock()
	defer pm.mu.Unlock()

	// Rolling back a complete reservation is simpler than rolling back a
	// partial one, so a conflict is only recorded here; the loop still adds
	// a reference on every protocol and the whole lot is undone afterwards.
	conflict := false
	for _, netProto := range res.Networks {
		key := portDescriptor{netProto, res.Transport, res.Port}
		byAddr, found := pm.allocatedPorts[key]
		if !found {
			byAddr = make(addrToDevice)
			pm.allocatedPorts[key] = byAddr
		}
		byDev, found := byAddr[res.Addr]
		if !found {
			byDev = make(deviceToDest)
			byAddr[res.Addr] = byDev
		}
		byDest := byDev[res.BindToDevice]
		if byDest == nil {
			byDest = make(destToCounter)
		}

		refs := byDest[dest]
		if refs.TotalRefs() != 0 && refs.SharedFlags()&bits == 0 {
			// Tuple already exists.
			conflict = true
		}
		refs.AddRef(bits)
		byDest[dest] = refs
		byDev[res.BindToDevice] = byDest
	}

	if !conflict {
		return true
	}
	// releasePortLocked decrements the counts rather than zeroing them, so
	// it reverses exactly the increments made above.
	pm.releasePortLocked(res)
	return false
}
// ReleasePort drops the reservation described by res so that the port/IP
// combination can be reserved by other endpoints. It acquires pm.mu and
// delegates to releasePortLocked.
func (pm *PortManager) ReleasePort(res Reservation) {
	pm.mu.Lock()
	defer pm.mu.Unlock()

	pm.releasePortLocked(res)
}
// releasePortLocked drops one reference for res on every network protocol in
// res.Networks. Protocols with no matching entry are skipped. When a counter
// reaches zero references its entry is deleted, and each enclosing map that
// becomes empty as a result is deleted in turn.
func (pm *PortManager) releasePortLocked(res Reservation) {
	dest := res.dst()
	for _, netProto := range res.Networks {
		key := portDescriptor{netProto, res.Transport, res.Port}
		byAddr, found := pm.allocatedPorts[key]
		if !found {
			continue
		}
		byDev, found := byAddr[res.Addr]
		if !found {
			continue
		}
		byDest, found := byDev[res.BindToDevice]
		if !found {
			continue
		}
		refs, found := byDest[dest]
		if !found {
			continue
		}

		refs.DropRef(res.Flags.Bits())
		if refs.TotalRefs() > 0 {
			// Still referenced: store the decremented counter back.
			byDest[dest] = refs
			continue
		}

		// Last reference gone: prune empty maps from the innermost level
		// outwards, stopping at the first level that still has entries.
		delete(byDest, dest)
		if len(byDest) == 0 {
			delete(byDev, res.BindToDevice)
			if len(byDev) == 0 {
				delete(byAddr, res.Addr)
				if len(byAddr) == 0 {
					delete(pm.allocatedPorts, key)
				}
			}
		}
	}
}
// PortRange returns the first and last port (both inclusive) of the ephemeral
// port range used for UDP and TCP on both IPv4 and IPv6.
func (pm *PortManager) PortRange() (uint16, uint16) {
	pm.ephemeralMu.RLock()
	first, count := pm.firstEphemeral, pm.numEphemeral
	pm.ephemeralMu.RUnlock()

	return first, first + count - 1
}
// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range to
// [start, end] (inclusive).
//
// It returns *tcpip.ErrInvalidPortRange if start > end, or if the range is the
// full [0, 65535]: the number of ports is stored in a uint16, so a size of
// 65536 would wrap to zero and leave no ephemeral ports available.
func (pm *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error {
	if start > end {
		return &tcpip.ErrInvalidPortRange{}
	}
	// With start <= end, span is 0xffff only for start == 0 and
	// end == 65535, the one case where span+1 overflows uint16.
	span := end - start
	if span == 0xffff {
		return &tcpip.ErrInvalidPortRange{}
	}
	pm.ephemeralMu.Lock()
	defer pm.ephemeralMu.Unlock()
	pm.firstEphemeral = start
	pm.numEphemeral = span + 1
	return nil
}

View File

@@ -0,0 +1,163 @@
// automatically generated by stateify.
package ports
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
// StateTypeName returns the fully qualified type name string for Flags.
func (f *Flags) StateTypeName() string {
return "pkg/tcpip/ports.Flags"
}
// StateFields lists the saved fields of Flags. The position of each name is
// the slot index that StateSave and StateLoad use for that field.
func (f *Flags) StateFields() []string {
return []string{
"MostRecent",
"LoadBalanced",
"TupleOnly",
}
}
// beforeSave is a no-op hook that StateSave calls before writing any field.
func (f *Flags) beforeSave() {}
// StateSave calls beforeSave and then writes MostRecent, LoadBalanced and
// TupleOnly to slots 0, 1 and 2 of stateSinkObject.
//
// +checklocksignore
func (f *Flags) StateSave(stateSinkObject state.Sink) {
f.beforeSave()
stateSinkObject.Save(0, &f.MostRecent)
stateSinkObject.Save(1, &f.LoadBalanced)
stateSinkObject.Save(2, &f.TupleOnly)
}
// afterLoad is a no-op hook; StateLoad below does not invoke it.
func (f *Flags) afterLoad(context.Context) {}
// StateLoad reads MostRecent, LoadBalanced and TupleOnly back from slots 0, 1
// and 2 of stateSourceObject. ctx is not used.
//
// +checklocksignore
func (f *Flags) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &f.MostRecent)
stateSourceObject.Load(1, &f.LoadBalanced)
stateSourceObject.Load(2, &f.TupleOnly)
}
// StateTypeName returns the fully qualified type name string for FlagCounter.
func (c *FlagCounter) StateTypeName() string {
return "pkg/tcpip/ports.FlagCounter"
}
// StateFields lists the saved fields of FlagCounter; refs is the only one and
// occupies slot 0.
func (c *FlagCounter) StateFields() []string {
return []string{
"refs",
}
}
// beforeSave is a no-op hook that StateSave calls before writing any field.
func (c *FlagCounter) beforeSave() {}
// StateSave calls beforeSave and then writes refs to slot 0 of
// stateSinkObject.
//
// +checklocksignore
func (c *FlagCounter) StateSave(stateSinkObject state.Sink) {
c.beforeSave()
stateSinkObject.Save(0, &c.refs)
}
// afterLoad is a no-op hook; StateLoad below does not invoke it.
func (c *FlagCounter) afterLoad(context.Context) {}
// StateLoad reads refs back from slot 0 of stateSourceObject. ctx is not used.
//
// +checklocksignore
func (c *FlagCounter) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &c.refs)
}
// StateTypeName returns the fully qualified type name string for
// portDescriptor.
func (p *portDescriptor) StateTypeName() string {
return "pkg/tcpip/ports.portDescriptor"
}
// StateFields lists the saved fields of portDescriptor. The position of each
// name is the slot index that StateSave and StateLoad use for that field.
func (p *portDescriptor) StateFields() []string {
return []string{
"network",
"transport",
"port",
}
}
// beforeSave is a no-op hook that StateSave calls before writing any field.
func (p *portDescriptor) beforeSave() {}
// StateSave calls beforeSave and then writes network, transport and port to
// slots 0, 1 and 2 of stateSinkObject.
//
// +checklocksignore
func (p *portDescriptor) StateSave(stateSinkObject state.Sink) {
p.beforeSave()
stateSinkObject.Save(0, &p.network)
stateSinkObject.Save(1, &p.transport)
stateSinkObject.Save(2, &p.port)
}
// afterLoad is a no-op hook; StateLoad below does not invoke it.
func (p *portDescriptor) afterLoad(context.Context) {}
// StateLoad reads network, transport and port back from slots 0, 1 and 2 of
// stateSourceObject. ctx is not used.
//
// +checklocksignore
func (p *portDescriptor) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &p.network)
stateSourceObject.Load(1, &p.transport)
stateSourceObject.Load(2, &p.port)
}
// StateTypeName returns the fully qualified type name string for destination.
func (d *destination) StateTypeName() string {
return "pkg/tcpip/ports.destination"
}
// StateFields lists the saved fields of destination. The position of each name
// is the slot index that StateSave and StateLoad use for that field.
func (d *destination) StateFields() []string {
return []string{
"addr",
"port",
}
}
// beforeSave is a no-op hook that StateSave calls before writing any field.
func (d *destination) beforeSave() {}
// StateSave calls beforeSave and then writes addr and port to slots 0 and 1 of
// stateSinkObject.
//
// +checklocksignore
func (d *destination) StateSave(stateSinkObject state.Sink) {
d.beforeSave()
stateSinkObject.Save(0, &d.addr)
stateSinkObject.Save(1, &d.port)
}
// afterLoad is a no-op hook; StateLoad below does not invoke it.
func (d *destination) afterLoad(context.Context) {}
// StateLoad reads addr and port back from slots 0 and 1 of stateSourceObject.
// ctx is not used.
//
// +checklocksignore
func (d *destination) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &d.addr)
stateSourceObject.Load(1, &d.port)
}
// StateTypeName returns the fully qualified type name string for PortManager.
func (pm *PortManager) StateTypeName() string {
return "pkg/tcpip/ports.PortManager"
}
// StateFields lists the saved fields of PortManager. The position of each name
// is the slot index that StateSave and StateLoad use for that field.
func (pm *PortManager) StateFields() []string {
return []string{
"allocatedPorts",
"firstEphemeral",
"numEphemeral",
}
}
// beforeSave is a no-op hook that StateSave calls before writing any field.
func (pm *PortManager) beforeSave() {}
// StateSave calls beforeSave and then writes allocatedPorts, firstEphemeral
// and numEphemeral to slots 0, 1 and 2 of stateSinkObject. It does not take
// pm.mu or pm.ephemeralMu itself.
//
// +checklocksignore
func (pm *PortManager) StateSave(stateSinkObject state.Sink) {
pm.beforeSave()
stateSinkObject.Save(0, &pm.allocatedPorts)
stateSinkObject.Save(1, &pm.firstEphemeral)
stateSinkObject.Save(2, &pm.numEphemeral)
}
// afterLoad is a no-op hook; StateLoad below does not invoke it.
func (pm *PortManager) afterLoad(context.Context) {}
// StateLoad reads allocatedPorts, firstEphemeral and numEphemeral back from
// slots 0, 1 and 2 of stateSourceObject. ctx is not used.
//
// +checklocksignore
func (pm *PortManager) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &pm.allocatedPorts)
stateSourceObject.Load(1, &pm.firstEphemeral)
stateSourceObject.Load(2, &pm.numEphemeral)
}
// init registers every savable type of this package with the state package,
// passing a typed nil pointer for each.
func init() {
state.Register((*Flags)(nil))
state.Register((*FlagCounter)(nil))
state.Register((*portDescriptor)(nil))
state.Register((*destination)(nil))
state.Register((*PortManager)(nil))
}

View File

@@ -0,0 +1,239 @@
package tcpip
// RouteElementMapper maps a RouteList element to the object that carries its
// list links. In this instantiation element and linker are both *Route, so the
// mapping is the identity.
type RouteElementMapper struct{}

// linkerFor returns the linker for elem, which is elem itself.
//
// This default implementation should be inlined.
//
//go:nosplit
func (RouteElementMapper) linkerFor(elem *Route) *Route {
	return elem
}
// RouteList is an intrusive doubly linked list of *Route. The list itself
// stores only the head and tail pointers; the next/prev links are kept on the
// elements, so entries can be added or removed in O(1) time and with no
// additional memory allocations.
//
// The zero value for RouteList is an empty list ready to use.
//
// To iterate over a list (where l is a RouteList):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type RouteList struct {
// head is the first element, or nil when the list is empty.
head *Route
// tail is the last element, or nil when the list is empty.
tail *Route
}
// Reset empties l by clearing its head and tail pointers. The links stored on
// the former elements are left untouched.
func (l *RouteList) Reset() {
	l.head, l.tail = nil, nil
}
// Empty reports whether l has no elements, i.e. whether its head is nil.
//
//go:nosplit
func (l *RouteList) Empty() bool {
	return l.head == nil
}
// Front returns the first element of l, or nil if l is empty.
//
//go:nosplit
func (l *RouteList) Front() *Route {
	return l.head
}
// Back returns the last element of l, or nil if l is empty.
//
//go:nosplit
func (l *RouteList) Back() *Route {
	return l.tail
}
// Len counts the elements of l by walking it from head to tail.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *RouteList) Len() (count int) {
	var mapper RouteElementMapper
	cur := l.head
	for cur != nil {
		count++
		cur = mapper.linkerFor(cur).Next()
	}
	return count
}
// PushFront makes e the new first element of l. If l was empty, e also becomes
// its last element.
//
//go:nosplit
func (l *RouteList) PushFront(e *Route) {
	var mapper RouteElementMapper
	oldHead := l.head

	node := mapper.linkerFor(e)
	node.SetNext(oldHead)
	node.SetPrev(nil)

	if oldHead == nil {
		l.tail = e
	} else {
		mapper.linkerFor(oldHead).SetPrev(e)
	}
	l.head = e
}
// PushFrontList moves every element of m to the front of l, preserving their
// order, and leaves m empty.
//
//go:nosplit
func (l *RouteList) PushFrontList(m *RouteList) {
	switch {
	case l.head == nil:
		// l is empty: adopt m's chain wholesale.
		l.head, l.tail = m.head, m.tail
	case m.head != nil:
		// Splice m's tail onto l's current head.
		RouteElementMapper{}.linkerFor(l.head).SetPrev(m.tail)
		RouteElementMapper{}.linkerFor(m.tail).SetNext(l.head)
		l.head = m.head
	}
	m.head, m.tail = nil, nil
}
// PushBack makes e the new last element of l. If l was empty, e also becomes
// its first element.
//
//go:nosplit
func (l *RouteList) PushBack(e *Route) {
	var mapper RouteElementMapper
	oldTail := l.tail

	node := mapper.linkerFor(e)
	node.SetNext(nil)
	node.SetPrev(oldTail)

	if oldTail == nil {
		l.head = e
	} else {
		mapper.linkerFor(oldTail).SetNext(e)
	}
	l.tail = e
}
// PushBackList moves every element of m to the end of l, preserving their
// order, and leaves m empty.
//
//go:nosplit
func (l *RouteList) PushBackList(m *RouteList) {
	switch {
	case l.head == nil:
		// l is empty: adopt m's chain wholesale.
		l.head, l.tail = m.head, m.tail
	case m.head != nil:
		// Splice m's head onto l's current tail.
		RouteElementMapper{}.linkerFor(l.tail).SetNext(m.head)
		RouteElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
		l.tail = m.tail
	}
	m.head, m.tail = nil, nil
}
// InsertAfter links e into l immediately after b. If b was the last element,
// e becomes the new tail.
//
//go:nosplit
func (l *RouteList) InsertAfter(b, e *Route) {
	var mapper RouteElementMapper
	pred, node := mapper.linkerFor(b), mapper.linkerFor(e)

	succ := pred.Next()
	node.SetNext(succ)
	node.SetPrev(b)
	pred.SetNext(e)

	if succ == nil {
		l.tail = e
		return
	}
	mapper.linkerFor(succ).SetPrev(e)
}
// InsertBefore links e into l immediately before a. If a was the first
// element, e becomes the new head.
//
//go:nosplit
func (l *RouteList) InsertBefore(a, e *Route) {
	var mapper RouteElementMapper
	succ, node := mapper.linkerFor(a), mapper.linkerFor(e)

	pred := succ.Prev()
	node.SetNext(a)
	node.SetPrev(pred)
	succ.SetPrev(e)

	if pred == nil {
		l.head = e
		return
	}
	mapper.linkerFor(pred).SetNext(e)
}
// Remove unlinks e from l and clears e's own next/prev links. The head and
// tail of l are only updated when they currently point at e.
//
//go:nosplit
func (l *RouteList) Remove(e *Route) {
	var mapper RouteElementMapper
	node := mapper.linkerFor(e)
	before, after := node.Prev(), node.Next()

	switch {
	case before != nil:
		mapper.linkerFor(before).SetNext(after)
	case l.head == e:
		l.head = after
	}

	switch {
	case after != nil:
		mapper.linkerFor(after).SetPrev(before)
	case l.tail == e:
		l.tail = before
	}

	node.SetNext(nil)
	node.SetPrev(nil)
}
// RouteEntry stores the next/prev links of a RouteList element and provides
// the accessor methods the list uses. Users can add an anonymous field of this
// type to their struct to give it those methods automatically.
//
// +stateify savable
type RouteEntry struct {
	next *Route
	prev *Route
}

// Next returns the element that follows e in the list, or nil if there is
// none.
//
//go:nosplit
func (e *RouteEntry) Next() *Route {
	return e.next
}

// Prev returns the element that precedes e in the list, or nil if there is
// none.
//
//go:nosplit
func (e *RouteEntry) Prev() *Route {
	return e.prev
}

// SetNext records elem as the element that follows e in the list.
//
//go:nosplit
func (e *RouteEntry) SetNext(elem *Route) {
	e.next = elem
}

// SetPrev records elem as the element that precedes e in the list.
//
//go:nosplit
func (e *RouteEntry) SetPrev(elem *Route) {
	e.prev = elem
}

View File

@@ -0,0 +1,62 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package seqnum defines the types and methods for TCP sequence numbers such
// that they fit in 32-bit words and work properly when overflows occur.
package seqnum
// Value is a 32-bit sequence number. All arithmetic on it wraps modulo 2^32.
type Value uint32

// Size is the length of a sequence number window.
type Size uint32

// LessThan reports whether v is before w, i.e. whether the wrapped difference
// v-w is negative when read as a signed 32-bit integer.
func (v Value) LessThan(w Value) bool {
	delta := int32(v - w)
	return delta < 0
}

// LessThanEq reports whether v equals w or v is before w, i.e. v <= w.
func (v Value) LessThanEq(w Value) bool {
	return v == w || v.LessThan(w)
}

// InRange reports whether v lies in the half-open range [a, b), i.e.
// a <= v < b. Both distances are measured from a with wrapping arithmetic.
func (v Value) InRange(a, b Value) bool {
	offset, span := v-a, b-a
	return offset < span
}

// InWindow reports whether v lies in the window that starts at first and
// spans size sequence numbers.
func (v Value) InWindow(first Value, size Size) bool {
	end := first.Add(size)
	return v.InRange(first, end)
}

// Add returns the sequence number that follows the window [v, v+s).
func (v Value) Add(s Size) Value {
	return Value(uint32(v) + uint32(s))
}

// Size returns the size of the window [v, w).
func (v Value) Size(w Value) Size {
	return Size(uint32(w) - uint32(v))
}

// UpdateForward advances v in place so that it becomes v + s.
func (v *Value) UpdateForward(s Size) {
	*v = v.Add(s)
}

View File

@@ -0,0 +1,3 @@
// automatically generated by stateify.
package seqnum

View File

@@ -0,0 +1,239 @@
package tcpip
// sockErrorElementMapper maps a sockErrorList element to the object that
// carries its list links. In this instantiation element and linker are both
// *SockError, so the mapping is the identity.
type sockErrorElementMapper struct{}

// linkerFor returns the linker for elem, which is elem itself.
//
// This default implementation should be inlined.
//
//go:nosplit
func (sockErrorElementMapper) linkerFor(elem *SockError) *SockError {
	return elem
}
// sockErrorList is an intrusive doubly linked list of *SockError. The list
// itself stores only the head and tail pointers; the next/prev links are kept
// on the elements, so entries can be added or removed in O(1) time and with no
// additional memory allocations.
//
// The zero value for sockErrorList is an empty list ready to use.
//
// To iterate over a list (where l is a sockErrorList):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type sockErrorList struct {
// head is the first element, or nil when the list is empty.
head *SockError
// tail is the last element, or nil when the list is empty.
tail *SockError
}
// Reset empties l by clearing its head and tail pointers. The links stored on
// the former elements are left untouched.
func (l *sockErrorList) Reset() {
	l.head, l.tail = nil, nil
}
// Empty reports whether l has no elements, i.e. whether its head is nil.
//
//go:nosplit
func (l *sockErrorList) Empty() bool {
	return l.head == nil
}
// Front returns the first element of l, or nil if l is empty.
//
//go:nosplit
func (l *sockErrorList) Front() *SockError {
	return l.head
}
// Back returns the last element of l, or nil if l is empty.
//
//go:nosplit
func (l *sockErrorList) Back() *SockError {
	return l.tail
}
// Len counts the elements of l by walking it from head to tail.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *sockErrorList) Len() (count int) {
	var mapper sockErrorElementMapper
	cur := l.head
	for cur != nil {
		count++
		cur = mapper.linkerFor(cur).Next()
	}
	return count
}
// PushFront makes e the new first element of l. If l was empty, e also becomes
// its last element.
//
//go:nosplit
func (l *sockErrorList) PushFront(e *SockError) {
	var mapper sockErrorElementMapper
	oldHead := l.head

	node := mapper.linkerFor(e)
	node.SetNext(oldHead)
	node.SetPrev(nil)

	if oldHead == nil {
		l.tail = e
	} else {
		mapper.linkerFor(oldHead).SetPrev(e)
	}
	l.head = e
}
// PushFrontList moves every element of m to the front of l, preserving their
// order, and leaves m empty.
//
//go:nosplit
func (l *sockErrorList) PushFrontList(m *sockErrorList) {
	switch {
	case l.head == nil:
		// l is empty: adopt m's chain wholesale.
		l.head, l.tail = m.head, m.tail
	case m.head != nil:
		// Splice m's tail onto l's current head.
		sockErrorElementMapper{}.linkerFor(l.head).SetPrev(m.tail)
		sockErrorElementMapper{}.linkerFor(m.tail).SetNext(l.head)
		l.head = m.head
	}
	m.head, m.tail = nil, nil
}
// PushBack makes e the new last element of l. If l was empty, e also becomes
// its first element.
//
//go:nosplit
func (l *sockErrorList) PushBack(e *SockError) {
	var mapper sockErrorElementMapper
	oldTail := l.tail

	node := mapper.linkerFor(e)
	node.SetNext(nil)
	node.SetPrev(oldTail)

	if oldTail == nil {
		l.head = e
	} else {
		mapper.linkerFor(oldTail).SetNext(e)
	}
	l.tail = e
}
// PushBackList moves every element of m to the end of l, preserving their
// order, and leaves m empty.
//
//go:nosplit
func (l *sockErrorList) PushBackList(m *sockErrorList) {
	switch {
	case l.head == nil:
		// l is empty: adopt m's chain wholesale.
		l.head, l.tail = m.head, m.tail
	case m.head != nil:
		// Splice m's head onto l's current tail.
		sockErrorElementMapper{}.linkerFor(l.tail).SetNext(m.head)
		sockErrorElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
		l.tail = m.tail
	}
	m.head, m.tail = nil, nil
}
// InsertAfter links e into l immediately after b. If b was the last element,
// e becomes the new tail.
//
//go:nosplit
func (l *sockErrorList) InsertAfter(b, e *SockError) {
	var mapper sockErrorElementMapper
	pred, node := mapper.linkerFor(b), mapper.linkerFor(e)

	succ := pred.Next()
	node.SetNext(succ)
	node.SetPrev(b)
	pred.SetNext(e)

	if succ == nil {
		l.tail = e
		return
	}
	mapper.linkerFor(succ).SetPrev(e)
}
// InsertBefore links e into l immediately before a. If a was the first
// element, e becomes the new head.
//
//go:nosplit
func (l *sockErrorList) InsertBefore(a, e *SockError) {
	var mapper sockErrorElementMapper
	succ, node := mapper.linkerFor(a), mapper.linkerFor(e)

	pred := succ.Prev()
	node.SetNext(a)
	node.SetPrev(pred)
	succ.SetPrev(e)

	if pred == nil {
		l.head = e
		return
	}
	mapper.linkerFor(pred).SetNext(e)
}
// Remove unlinks e from l and clears e's own next/prev links. The head and
// tail of l are only updated when they currently point at e.
//
//go:nosplit
func (l *sockErrorList) Remove(e *SockError) {
	var mapper sockErrorElementMapper
	node := mapper.linkerFor(e)
	before, after := node.Prev(), node.Next()

	switch {
	case before != nil:
		mapper.linkerFor(before).SetNext(after)
	case l.head == e:
		l.head = after
	}

	switch {
	case after != nil:
		mapper.linkerFor(after).SetPrev(before)
	case l.tail == e:
		l.tail = before
	}

	node.SetNext(nil)
	node.SetPrev(nil)
}
// sockErrorEntry stores the next/prev links of a sockErrorList element and
// provides the accessor methods the list uses. Users can add an anonymous
// field of this type to their struct to give it those methods automatically.
//
// +stateify savable
type sockErrorEntry struct {
	next *SockError
	prev *SockError
}

// Next returns the element that follows e in the list, or nil if there is
// none.
//
//go:nosplit
func (e *sockErrorEntry) Next() *SockError {
	return e.next
}

// Prev returns the element that precedes e in the list, or nil if there is
// none.
//
//go:nosplit
func (e *sockErrorEntry) Prev() *SockError {
	return e.prev
}

// SetNext records elem as the element that follows e in the list.
//
//go:nosplit
func (e *sockErrorEntry) SetNext(elem *SockError) {
	e.next = elem
}

// SetPrev records elem as the element that precedes e in the list.
//
//go:nosplit
func (e *sockErrorEntry) SetPrev(elem *SockError) {
	e.prev = elem
}

View File

@@ -0,0 +1,758 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcpip
import (
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/sync"
)
// SocketOptionsHandler holds methods that help define endpoint specific
// behavior for socket level socket options. These must be implemented by
// endpoints to get notified when socket level options are set.
//
// SocketOptions stores the option values itself and calls into its handler
// from the corresponding setters (for example SetReuseAddress calls
// OnReuseAddressSet after storing the new value).
type SocketOptionsHandler interface {
// OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint.
OnReuseAddressSet(v bool)
// OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint.
OnReusePortSet(v bool)
// OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint.
OnKeepAliveSet(v bool)
// OnDelayOptionSet is invoked when TCP_NODELAY is set for an endpoint.
// Note that v will be the inverse of the TCP_NODELAY option.
OnDelayOptionSet(v bool)
// OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint.
OnCorkOptionSet(v bool)
// LastError is invoked when SO_ERROR is read for an endpoint.
LastError() Error
// UpdateLastError updates the endpoint specific last error field.
UpdateLastError(err Error)
// HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
HasNIC(v int32) bool
// OnSetSendBufferSize is invoked when the send buffer size for an endpoint is
// changed. The handler is invoked with the new value for the socket send
// buffer size. It also returns the newly set value.
OnSetSendBufferSize(v int64) (newSz int64)
// OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE. The
// handler can optionally return a callback which will be called after
// the buffer size is updated to newSz.
OnSetReceiveBufferSize(v, oldSz int64) (newSz int64, postSet func())
// WakeupWriters is invoked when the send buffer size for an endpoint is
// changed. The handler notifies the writers if the send buffer size is
// increased with setsockopt(2) for TCP endpoints.
WakeupWriters()
// GetAcceptConn returns true if the socket is a TCP socket and is in
// listening state.
GetAcceptConn() bool
}
// DefaultSocketOptionsHandler is an embeddable, stateless type that provides
// a do-nothing implementation of every SocketOptionsHandler method.
// Notification callbacks are empty, queries return nil or false, and the
// buffer size hooks hand back the requested value unchanged.
type DefaultSocketOptionsHandler struct{}

// Compile-time check that *DefaultSocketOptionsHandler satisfies
// SocketOptionsHandler.
var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil)

// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet as a
// no-op.
func (*DefaultSocketOptionsHandler) OnReuseAddressSet(_ bool) {}

// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet as a no-op.
func (*DefaultSocketOptionsHandler) OnReusePortSet(_ bool) {}

// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet as a no-op.
func (*DefaultSocketOptionsHandler) OnKeepAliveSet(_ bool) {}

// OnDelayOptionSet implements SocketOptionsHandler.OnDelayOptionSet as a
// no-op.
func (*DefaultSocketOptionsHandler) OnDelayOptionSet(_ bool) {}

// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet as a no-op.
func (*DefaultSocketOptionsHandler) OnCorkOptionSet(_ bool) {}

// LastError implements SocketOptionsHandler.LastError. It always returns nil.
func (*DefaultSocketOptionsHandler) LastError() Error {
	return nil
}

// UpdateLastError implements SocketOptionsHandler.UpdateLastError as a no-op.
func (*DefaultSocketOptionsHandler) UpdateLastError(_ Error) {}

// HasNIC implements SocketOptionsHandler.HasNIC. It always returns false.
func (*DefaultSocketOptionsHandler) HasNIC(_ int32) bool {
	return false
}

// OnSetSendBufferSize implements SocketOptionsHandler.OnSetSendBufferSize. It
// returns v unchanged.
func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
	return v
}

// OnSetReceiveBufferSize implements
// SocketOptionsHandler.OnSetReceiveBufferSize. It returns v unchanged and no
// post-set callback.
func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64, postSet func()) {
	return v, nil
}

// WakeupWriters implements SocketOptionsHandler.WakeupWriters as a no-op.
func (*DefaultSocketOptionsHandler) WakeupWriters() {}

// GetAcceptConn implements SocketOptionsHandler.GetAcceptConn. It always
// returns false.
func (*DefaultSocketOptionsHandler) GetAcceptConn() bool {
	return false
}
// StackHandler holds methods to access the stack options. These must be
// implemented by the stack.
//
// Both methods take the option to fill in as an argument and return only an
// Error.
type StackHandler interface {
// Option allows retrieving stack wide options.
Option(option any) Error
// TransportProtocolOption allows retrieving individual protocol level
// option values.
TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) Error
}
// SocketOptions contains all the variables which store values for SOL_SOCKET,
// SOL_IP, SOL_IPV6 and SOL_TCP level options.
//
// InitHandler must be called before the options are used; it sets handler,
// stackHandler and the two buffer-limit callbacks.
//
// +stateify savable
type SocketOptions struct {
// handler receives the notifications issued by the setters.
handler SocketOptionsHandler
// stackHandler is initialized at the creation time and will not change.
stackHandler StackHandler `state:"manual"`
// These fields are accessed and modified using atomic operations.
// broadcastEnabled determines whether datagram sockets are allowed to
// send packets to a broadcast address.
broadcastEnabled atomicbitops.Uint32
// passCredEnabled determines whether SCM_CREDENTIALS socket control
// messages are enabled.
passCredEnabled atomicbitops.Uint32
// noChecksumEnabled determines whether UDP checksum is disabled while
// transmitting for this socket.
noChecksumEnabled atomicbitops.Uint32
// reuseAddressEnabled determines whether Bind() should allow reuse of
// local address.
reuseAddressEnabled atomicbitops.Uint32
// reusePortEnabled determines whether to permit multiple sockets to be
// bound to an identical socket address.
reusePortEnabled atomicbitops.Uint32
// keepAliveEnabled determines whether TCP keepalive is enabled for this
// socket.
keepAliveEnabled atomicbitops.Uint32
// multicastLoopEnabled determines whether multicast packets sent over a
// non-loopback interface will be looped back.
multicastLoopEnabled atomicbitops.Uint32
// receiveTOSEnabled is used to specify if the TOS ancillary message is
// passed with incoming packets.
receiveTOSEnabled atomicbitops.Uint32
// receiveTTLEnabled is used to specify if the TTL ancillary message is passed
// with incoming packets.
receiveTTLEnabled atomicbitops.Uint32
// receiveHopLimitEnabled is used to specify if the HopLimit ancillary message
// is passed with incoming packets.
receiveHopLimitEnabled atomicbitops.Uint32
// receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary
// message is passed with incoming packets.
receiveTClassEnabled atomicbitops.Uint32
// receivePacketInfoEnabled is used to specify if more information is
// provided with incoming IPv4 packets.
receivePacketInfoEnabled atomicbitops.Uint32
// receiveIPv6PacketInfoEnabled is used to specify if more information is
// provided with incoming IPv6 packets.
receiveIPv6PacketInfoEnabled atomicbitops.Uint32
// hdrIncludedEnabled is used to indicate for a raw endpoint that all packets
// being written have an IP header and the endpoint should not attach an IP
// header.
hdrIncludedEnabled atomicbitops.Uint32
// v6OnlyEnabled is used to determine whether an IPv6 socket is to be
// restricted to sending and receiving IPv6 packets only.
v6OnlyEnabled atomicbitops.Uint32
// quickAckEnabled is used to represent the value of TCP_QUICKACK option.
// It currently does not have any effect on the TCP endpoint.
quickAckEnabled atomicbitops.Uint32
// delayOptionEnabled is used to specify if data should be sent out immediately
// by the transport protocol. For TCP, it determines if the Nagle algorithm
// is on or off.
delayOptionEnabled atomicbitops.Uint32
// corkOptionEnabled is used to specify if data should be held until segments
// are full by the TCP transport protocol.
corkOptionEnabled atomicbitops.Uint32
// receiveOriginalDstAddress is used to specify if the original destination of
// the incoming packet should be returned as an ancillary message.
receiveOriginalDstAddress atomicbitops.Uint32
// ipv4RecvErrEnabled determines whether extended reliable error message
// passing is enabled for IPv4.
ipv4RecvErrEnabled atomicbitops.Uint32
// ipv6RecvErrEnabled determines whether extended reliable error message
// passing is enabled for IPv6.
ipv6RecvErrEnabled atomicbitops.Uint32
// errQueue is the per-socket error queue. It is protected by errQueueMu.
errQueueMu sync.Mutex `state:"nosave"`
errQueue sockErrorList
// bindToDevice determines the device to which the socket is bound.
bindToDevice atomicbitops.Int32
// getSendBufferLimits provides the handler to get the min, default and max
// size for send buffer. It is initialized at the creation time and will not
// change.
getSendBufferLimits GetSendBufferLimits `state:"manual"`
// sendBufferSize determines the send buffer size for this socket.
sendBufferSize atomicbitops.Int64
// getReceiveBufferLimits provides the handler to get the min, default and
// max size for receive buffer. It is initialized at the creation time and
// will not change.
getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"`
// receiveBufferSize determines the receive buffer size for this socket.
receiveBufferSize atomicbitops.Int64
// mu protects the access to the below fields.
// NOTE(review): rcvlowat below is an atomic type, so it may not actually
// need mu; confirm against its accessors.
mu sync.Mutex `state:"nosave"`
// linger determines the amount of time the socket should linger before
// close. We currently implement this option for TCP socket only.
linger LingerOption
// rcvlowat specifies the minimum number of bytes which should be
// received to indicate the socket as readable.
rcvlowat atomicbitops.Int32
}
// InitHandler wires up the callbacks SocketOptions depends on: the endpoint's
// option handler, the stack handler, and the two functions that supply send
// and receive buffer limits. It must be called before the socket options
// utility is used.
func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits, getReceiveBufferLimits GetReceiveBufferLimits) {
	so.handler, so.stackHandler = handler, stack
	so.getSendBufferLimits, so.getReceiveBufferLimits = getSendBufferLimits, getReceiveBufferLimits
}
// storeAtomicBool stores v into addr as a uint32: 1 for true, 0 for false.
func storeAtomicBool(addr *atomicbitops.Uint32, v bool) {
	if v {
		addr.Store(1)
		return
	}
	addr.Store(0)
}
// SetLastError records err as the socket's last error by forwarding it to the
// handler's UpdateLastError.
func (so *SocketOptions) SetLastError(err Error) {
	so.handler.UpdateLastError(err)
}
// GetBroadcast reports whether the SO_BROADCAST option is enabled.
func (so *SocketOptions) GetBroadcast() bool {
	return so.broadcastEnabled.Load() != 0
}
// SetBroadcast enables or disables the SO_BROADCAST option.
func (so *SocketOptions) SetBroadcast(v bool) {
	storeAtomicBool(&so.broadcastEnabled, v)
}
// GetPassCred reports whether the SO_PASSCRED option is enabled.
func (so *SocketOptions) GetPassCred() bool {
	return so.passCredEnabled.Load() != 0
}
// SetPassCred enables or disables the SO_PASSCRED option.
func (so *SocketOptions) SetPassCred(v bool) {
	storeAtomicBool(&so.passCredEnabled, v)
}
// GetNoChecksum reports whether the SO_NO_CHECK option is enabled.
func (so *SocketOptions) GetNoChecksum() bool {
	return so.noChecksumEnabled.Load() != 0
}
// SetNoChecksum enables or disables the SO_NO_CHECK option.
func (so *SocketOptions) SetNoChecksum(v bool) {
	storeAtomicBool(&so.noChecksumEnabled, v)
}
// GetReuseAddress reports whether the SO_REUSEADDR option is enabled.
func (so *SocketOptions) GetReuseAddress() bool {
	return so.reuseAddressEnabled.Load() != 0
}
// SetReuseAddress enables or disables the SO_REUSEADDR option and then
// notifies the handler of the new value.
func (so *SocketOptions) SetReuseAddress(v bool) {
	storeAtomicBool(&so.reuseAddressEnabled, v)
	so.handler.OnReuseAddressSet(v)
}
// GetReusePort reports whether the SO_REUSEPORT option is set.
func (so *SocketOptions) GetReusePort() bool {
	raw := so.reusePortEnabled.Load()
	return raw != 0
}
// SetReusePort stores v as the value of the SO_REUSEPORT option and then
// notifies the handler of the new value.
func (so *SocketOptions) SetReusePort(v bool) {
	flag := &so.reusePortEnabled
	storeAtomicBool(flag, v)
	so.handler.OnReusePortSet(v)
}
// GetKeepAlive reports whether the SO_KEEPALIVE option is set.
func (so *SocketOptions) GetKeepAlive() bool {
	raw := so.keepAliveEnabled.Load()
	return raw != 0
}
// SetKeepAlive stores v as the value of the SO_KEEPALIVE option and then
// notifies the handler of the new value.
func (so *SocketOptions) SetKeepAlive(v bool) {
	flag := &so.keepAliveEnabled
	storeAtomicBool(flag, v)
	so.handler.OnKeepAliveSet(v)
}
// GetMulticastLoop reports whether the IP_MULTICAST_LOOP option is set.
func (so *SocketOptions) GetMulticastLoop() bool {
	raw := so.multicastLoopEnabled.Load()
	return raw != 0
}
// SetMulticastLoop stores v as the value of the IP_MULTICAST_LOOP option.
func (so *SocketOptions) SetMulticastLoop(v bool) {
	flag := &so.multicastLoopEnabled
	storeAtomicBool(flag, v)
}
// GetReceiveTOS reports whether the IP_RECVTOS option is set.
func (so *SocketOptions) GetReceiveTOS() bool {
	raw := so.receiveTOSEnabled.Load()
	return raw != 0
}
// SetReceiveTOS stores v as the value of the IP_RECVTOS option.
func (so *SocketOptions) SetReceiveTOS(v bool) {
	flag := &so.receiveTOSEnabled
	storeAtomicBool(flag, v)
}
// GetReceiveTTL reports whether the IP_RECVTTL option is set.
func (so *SocketOptions) GetReceiveTTL() bool {
	raw := so.receiveTTLEnabled.Load()
	return raw != 0
}
// SetReceiveTTL stores v as the value of the IP_RECVTTL option.
func (so *SocketOptions) SetReceiveTTL(v bool) {
	flag := &so.receiveTTLEnabled
	storeAtomicBool(flag, v)
}
// GetReceiveHopLimit reports whether the IP_RECVHOPLIMIT option is set.
func (so *SocketOptions) GetReceiveHopLimit() bool {
	raw := so.receiveHopLimitEnabled.Load()
	return raw != 0
}
// SetReceiveHopLimit stores v as the value of the IP_RECVHOPLIMIT option.
func (so *SocketOptions) SetReceiveHopLimit(v bool) {
	flag := &so.receiveHopLimitEnabled
	storeAtomicBool(flag, v)
}
// GetReceiveTClass reports whether the IPV6_RECVTCLASS option is set.
func (so *SocketOptions) GetReceiveTClass() bool {
	raw := so.receiveTClassEnabled.Load()
	return raw != 0
}
// SetReceiveTClass stores v as the value of the IPV6_RECVTCLASS option.
func (so *SocketOptions) SetReceiveTClass(v bool) {
	flag := &so.receiveTClassEnabled
	storeAtomicBool(flag, v)
}
// GetReceivePacketInfo reports whether the IP_PKTINFO option is set.
func (so *SocketOptions) GetReceivePacketInfo() bool {
	raw := so.receivePacketInfoEnabled.Load()
	return raw != 0
}
// SetReceivePacketInfo stores v as the value of the IP_PKTINFO option.
func (so *SocketOptions) SetReceivePacketInfo(v bool) {
	flag := &so.receivePacketInfoEnabled
	storeAtomicBool(flag, v)
}
// GetIPv6ReceivePacketInfo reports whether the IPV6_RECVPKTINFO option is
// set.
func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool {
	raw := so.receiveIPv6PacketInfoEnabled.Load()
	return raw != 0
}
// SetIPv6ReceivePacketInfo stores v as the value of the IPV6_RECVPKTINFO
// option.
func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) {
	flag := &so.receiveIPv6PacketInfoEnabled
	storeAtomicBool(flag, v)
}
// GetHeaderIncluded reports whether the IP_HDRINCL option is set.
func (so *SocketOptions) GetHeaderIncluded() bool {
	raw := so.hdrIncludedEnabled.Load()
	return raw != 0
}
// SetHeaderIncluded stores v as the value of the IP_HDRINCL option.
func (so *SocketOptions) SetHeaderIncluded(v bool) {
	flag := &so.hdrIncludedEnabled
	storeAtomicBool(flag, v)
}
// GetV6Only reports whether the IPV6_V6ONLY option is set.
func (so *SocketOptions) GetV6Only() bool {
	raw := so.v6OnlyEnabled.Load()
	return raw != 0
}
// SetV6Only stores v as the value of the IPV6_V6ONLY option.
//
// Preconditions: the backing TCP or UDP endpoint must be in initial state.
func (so *SocketOptions) SetV6Only(v bool) {
	flag := &so.v6OnlyEnabled
	storeAtomicBool(flag, v)
}
// GetQuickAck reports whether the TCP_QUICKACK option is set.
func (so *SocketOptions) GetQuickAck() bool {
	raw := so.quickAckEnabled.Load()
	return raw != 0
}
// SetQuickAck stores v as the value of the TCP_QUICKACK option.
func (so *SocketOptions) SetQuickAck(v bool) {
	flag := &so.quickAckEnabled
	storeAtomicBool(flag, v)
}
// GetDelayOption returns the delay flag, which is the inverted value of the
// TCP_NODELAY option.
func (so *SocketOptions) GetDelayOption() bool {
	raw := so.delayOptionEnabled.Load()
	return raw != 0
}
// SetDelayOption stores v as the delay flag (the inverted value of the
// TCP_NODELAY option) and then notifies the handler of the new value.
func (so *SocketOptions) SetDelayOption(v bool) {
	flag := &so.delayOptionEnabled
	storeAtomicBool(flag, v)
	so.handler.OnDelayOptionSet(v)
}
// GetCorkOption reports whether the TCP_CORK option is set.
func (so *SocketOptions) GetCorkOption() bool {
	raw := so.corkOptionEnabled.Load()
	return raw != 0
}
// SetCorkOption stores v as the value of the TCP_CORK option and then
// notifies the handler of the new value.
func (so *SocketOptions) SetCorkOption(v bool) {
	flag := &so.corkOptionEnabled
	storeAtomicBool(flag, v)
	so.handler.OnCorkOptionSet(v)
}
// GetReceiveOriginalDstAddress reports whether the IP(V6)_RECVORIGDSTADDR
// option is set.
func (so *SocketOptions) GetReceiveOriginalDstAddress() bool {
	raw := so.receiveOriginalDstAddress.Load()
	return raw != 0
}
// SetReceiveOriginalDstAddress stores v as the value of the
// IP(V6)_RECVORIGDSTADDR option.
func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
	flag := &so.receiveOriginalDstAddress
	storeAtomicBool(flag, v)
}
// GetIPv4RecvError reports whether the IP_RECVERR option is set.
func (so *SocketOptions) GetIPv4RecvError() bool {
	raw := so.ipv4RecvErrEnabled.Load()
	return raw != 0
}
// SetIPv4RecvError stores v as the value of the IP_RECVERR option. When the
// option is being turned off, the error queue is reset.
func (so *SocketOptions) SetIPv4RecvError(v bool) {
	storeAtomicBool(&so.ipv4RecvErrEnabled, v)
	if v {
		return
	}
	so.pruneErrQueue()
}
// GetIPv6RecvError reports whether the IPV6_RECVERR option is set.
func (so *SocketOptions) GetIPv6RecvError() bool {
	raw := so.ipv6RecvErrEnabled.Load()
	return raw != 0
}
// SetIPv6RecvError stores v as the value of the IPV6_RECVERR option. When the
// option is being turned off, the error queue is reset.
func (so *SocketOptions) SetIPv6RecvError(v bool) {
	storeAtomicBool(&so.ipv6RecvErrEnabled, v)
	if v {
		return
	}
	so.pruneErrQueue()
}
// GetLastError returns the value for the SO_ERROR option, as reported by the
// handler.
func (so *SocketOptions) GetLastError() Error {
	lastErr := so.handler.LastError()
	return lastErr
}
// GetOutOfBandInline returns the value for the SO_OOBINLINE option, which is
// unconditionally true.
func (*SocketOptions) GetOutOfBandInline() bool {
	const oobInline = true
	return oobInline
}
// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not
// support disabling this option, so the argument is ignored and the call is
// a no-op.
func (*SocketOptions) SetOutOfBandInline(bool) {}
// GetLinger returns a copy of the SO_LINGER option value, read under so.mu.
func (so *SocketOptions) GetLinger() LingerOption {
	so.mu.Lock()
	defer so.mu.Unlock()
	return so.linger
}
// SetLinger stores linger as the SO_LINGER option value, written under so.mu.
func (so *SocketOptions) SetLinger(linger LingerOption) {
	so.mu.Lock()
	defer so.mu.Unlock()
	so.linger = linger
}
// SockErrOrigin enumerates the possible origins of a socket error.
type SockErrOrigin uint8

const (
	// SockExtErrorOriginNone represents an unknown error origin.
	SockExtErrorOriginNone SockErrOrigin = iota

	// SockExtErrorOriginLocal indicates a local error.
	SockExtErrorOriginLocal

	// SockExtErrorOriginICMP indicates an IPv4 ICMP error.
	SockExtErrorOriginICMP

	// SockExtErrorOriginICMP6 indicates an IPv6 ICMP error.
	SockExtErrorOriginICMP6
)

// IsICMPErr reports whether origin is one of the two ICMP origins
// (SockExtErrorOriginICMP or SockExtErrorOriginICMP6).
func (origin SockErrOrigin) IsICMPErr() bool {
	switch origin {
	case SockExtErrorOriginICMP, SockExtErrorOriginICMP6:
		return true
	default:
		return false
	}
}
// SockErrorCause is the cause of a socket error. Implementations describe
// where the error came from and carry origin-specific detail.
type SockErrorCause interface {
	// Origin is the source of the error.
	Origin() SockErrOrigin

	// Type is the origin specific type of error.
	Type() uint8

	// Code is the origin and type specific error code.
	Code() uint8

	// Info is any extra information about the error.
	Info() uint32
}
// LocalSockError is a socket error that originated from the local host. It
// carries only the extra info value; its type and code are always zero.
//
// +stateify savable
type LocalSockError struct {
	// info is the value returned by Info.
	info uint32
}

// Origin implements SockErrorCause. It always returns
// SockExtErrorOriginLocal.
func (*LocalSockError) Origin() SockErrOrigin {
	return SockExtErrorOriginLocal
}

// Type implements SockErrorCause. It always returns 0.
func (*LocalSockError) Type() uint8 {
	return 0
}

// Code implements SockErrorCause. It always returns 0.
func (*LocalSockError) Code() uint8 {
	return 0
}

// Info implements SockErrorCause. It returns the info value the error was
// constructed with.
func (l *LocalSockError) Info() uint32 {
	return l.info
}
// SockError represents a queue entry in the per-socket error queue.
//
// +stateify savable
type SockError struct {
	// sockErrorEntry is embedded; presumably it provides the linkage used by
	// the error queue — confirm against the sockErrorEntry definition.
	sockErrorEntry

	// Err is the error caused by the errant packet.
	Err Error
	// Cause is the detailed cause of the error.
	Cause SockErrorCause

	// Payload is the errant packet's payload.
	Payload *buffer.View
	// Dst is the original destination address of the errant packet.
	Dst FullAddress
	// Offender is the original sender address of the errant packet.
	Offender FullAddress
	// NetProto is the network protocol being used to transmit the packet.
	NetProto NetworkProtocolNumber
}
// pruneErrQueue resets the error queue while holding so.errQueueMu.
func (so *SocketOptions) pruneErrQueue() {
	so.errQueueMu.Lock()
	defer so.errQueueMu.Unlock()
	so.errQueue.Reset()
}
// DequeueErr removes the socket extended error at the front of the error
// queue and returns it. It returns nil if the queue is empty.
func (so *SocketOptions) DequeueErr() *SockError {
	so.errQueueMu.Lock()
	defer so.errQueueMu.Unlock()

	front := so.errQueue.Front()
	if front == nil {
		return nil
	}
	so.errQueue.Remove(front)
	return front
}
// PeekErr returns the error at the front of the error queue without removing
// it. It returns nil if the error queue is empty.
func (so *SocketOptions) PeekErr() *SockError {
	so.errQueueMu.Lock()
	defer so.errQueueMu.Unlock()

	front := so.errQueue.Front()
	return front
}
// QueueErr inserts the error at the back of the error queue. The insertion
// is performed while holding so.errQueueMu.
//
// Preconditions: so.GetIPv4RecvError() or so.GetIPv6RecvError() is true.
func (so *SocketOptions) QueueErr(err *SockError) {
	so.errQueueMu.Lock()
	defer so.errQueueMu.Unlock()
	so.errQueue.PushBack(err)
}
// QueueLocalErr wraps err in a SockError whose cause is a LocalSockError
// carrying info, and queues it via QueueErr. The Offender field is left at
// its zero value.
func (so *SocketOptions) QueueLocalErr(err Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload *buffer.View) {
	cause := &LocalSockError{info: info}
	sockErr := &SockError{
		Err:      err,
		Cause:    cause,
		Payload:  payload,
		Dst:      dst,
		NetProto: net,
	}
	so.QueueErr(sockErr)
}
// GetBindToDevice returns the value for the SO_BINDTODEVICE option.
func (so *SocketOptions) GetBindToDevice() int32 {
	nic := so.bindToDevice.Load()
	return nic
}
// SetBindToDevice sets the value for the SO_BINDTODEVICE option. A
// bindToDevice of zero removes the socket device binding; a non-zero value
// must name a NIC known to the handler, otherwise *ErrUnknownDevice is
// returned and the stored value is left unchanged.
func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
	if bindToDevice == 0 || so.handler.HasNIC(bindToDevice) {
		so.bindToDevice.Store(bindToDevice)
		return nil
	}
	return &ErrUnknownDevice{}
}
// GetSendBufferSize returns the value for the SO_SNDBUF option.
func (so *SocketOptions) GetSendBufferSize() int64 {
	size := so.sendBufferSize.Load()
	return size
}
// SendBufferLimits returns the [min, max) range of allowable send buffer
// sizes, as obtained from the configured limits provider.
func (so *SocketOptions) SendBufferLimits() (min, max int64) {
	opt := so.getSendBufferLimits(so.stackHandler)
	min, max = int64(opt.Min), int64(opt.Max)
	return min, max
}
// SetSendBufferSize sets the value for the SO_SNDBUF option. When notify is
// true the handler is first asked for the size to store, and writers are
// woken up after the new size has been stored. When notify is false the
// given size is stored as is.
func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
	if !notify {
		so.sendBufferSize.Store(sendBufferSize)
		return
	}
	sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize)
	so.sendBufferSize.Store(sendBufferSize)
	so.handler.WakeupWriters()
}
// GetReceiveBufferSize returns the value for the SO_RCVBUF option.
func (so *SocketOptions) GetReceiveBufferSize() int64 {
	size := so.receiveBufferSize.Load()
	return size
}
// ReceiveBufferLimits returns the [min, max) range of allowable receive
// buffer sizes, as obtained from the configured limits provider.
func (so *SocketOptions) ReceiveBufferLimits() (min, max int64) {
	opt := so.getReceiveBufferLimits(so.stackHandler)
	min, max = int64(opt.Min), int64(opt.Max)
	return min, max
}
// SetReceiveBufferSize sets the value of the SO_RCVBUF option. When notify
// is true the handler is given the requested and current sizes and returns
// the size to store together with an optional callback, which is invoked
// after the store. When notify is false the given size is stored as is.
func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) {
	if !notify {
		so.receiveBufferSize.Store(receiveBufferSize)
		return
	}

	oldSz := so.receiveBufferSize.Load()
	newSz, postSet := so.handler.OnSetReceiveBufferSize(receiveBufferSize, oldSz)
	so.receiveBufferSize.Store(newSz)
	if postSet != nil {
		postSet()
	}
}
// GetRcvlowat returns the value for the SO_RCVLOWAT option. It currently
// always returns the default value of 1 rather than the stored value.
func (so *SocketOptions) GetRcvlowat() int32 {
	// TODO(b/226603727): Return so.rcvlowat after adding complete support
	// for SO_RCVLOWAT option. For now, return the default value of 1.
	const defaultRcvlowat int32 = 1
	return defaultRcvlowat
}
// SetRcvlowat stores rcvlowat as the value of the SO_RCVLOWAT option. It
// always returns nil.
func (so *SocketOptions) SetRcvlowat(rcvlowat int32) Error {
	target := &so.rcvlowat
	target.Store(rcvlowat)
	return nil
}
// GetAcceptConn returns the value for the SO_ACCEPTCONN option, as reported
// by the handler.
func (so *SocketOptions) GetAcceptConn() bool {
	accepting := so.handler.GetAcceptConn()
	return accepting
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// addressStateRWMutex is sync.RWMutex with the correctness validator.
//
// Every non-bypass lock operation registers the acquisition with the locking
// package before taking the underlying mutex, and every non-bypass unlock
// operation releases the underlying mutex before unregistering it.
type addressStateRWMutex struct {
	mu sync.RWMutex
}

// addressStatelockNames is a list of user-friendly lock names.
// Populated in init.
var addressStatelockNames []string

// addressStatelockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within addressStatelockNames.
// Values are specified using the "consts" field of go_template_instance.
type addressStatelockNameIndex int

// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()

// Lock locks m.
// +checklocksignore
func (m *addressStateRWMutex) Lock() {
	// Register with the validator first, using -1 as the (non-nested) index.
	locking.AddGLock(addressStateprefixIndex, -1)
	m.mu.Lock()
}

// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *addressStateRWMutex) NestedLock(i addressStatelockNameIndex) {
	locking.AddGLock(addressStateprefixIndex, int(i))
	m.mu.Lock()
}

// Unlock unlocks m.
// +checklocksignore
func (m *addressStateRWMutex) Unlock() {
	m.mu.Unlock()
	locking.DelGLock(addressStateprefixIndex, -1)
}

// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *addressStateRWMutex) NestedUnlock(i addressStatelockNameIndex) {
	m.mu.Unlock()
	locking.DelGLock(addressStateprefixIndex, int(i))
}

// RLock locks m for reading.
// +checklocksignore
func (m *addressStateRWMutex) RLock() {
	// Read locks are registered with the validator the same way as write
	// locks.
	locking.AddGLock(addressStateprefixIndex, -1)
	m.mu.RLock()
}

// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *addressStateRWMutex) RUnlock() {
	m.mu.RUnlock()
	locking.DelGLock(addressStateprefixIndex, -1)
}

// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *addressStateRWMutex) RLockBypass() {
	m.mu.RLock()
}

// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *addressStateRWMutex) RUnlockBypass() {
	m.mu.RUnlock()
}

// DowngradeLock atomically unlocks m for writing and locks it for reading.
// The validator is not consulted.
// +checklocksignore
func (m *addressStateRWMutex) DowngradeLock() {
	m.mu.DowngradeLock()
}

// addressStateprefixIndex is the mutex class passed to the validator for
// every addressStateRWMutex. It is assigned in init.
var addressStateprefixIndex *locking.MutexClass

// DO NOT REMOVE: The following function is automatically replaced.
func addressStateinitLockNames() {}

// init runs addressStateinitLockNames and then registers the mutex class for
// addressStateRWMutex with the lock names.
func init() {
	addressStateinitLockNames()
	addressStateprefixIndex = locking.NewMutexClass(reflect.TypeOf(addressStateRWMutex{}), addressStatelockNames)
}

View File

@@ -0,0 +1,142 @@
package stack
import (
"context"
"fmt"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/refs"
)
// addressStateenableLogging indicates whether reference-related events should
// be logged (with stack traces). This is false by default and should only be
// set to true for debugging purposes, as it can generate an extremely large
// amount of output and drastically degrade performance.
const addressStateenableLogging = false

// addressStateobj is used to customize logging. Note that we use a pointer to
// T so that we do not copy the entire object when passed as a format
// parameter. Only its type is used (see RefType); the value stays nil.
var addressStateobj *addressState
// addressStateRefs implements refs.RefCounter. It keeps a reference count
// using atomic operations and calls the destructor when the count reaches
// zero.
//
// NOTE: Do not introduce additional fields to the Refs struct. It is used by
// many filesystem objects, and we want to keep it as small as possible (i.e.,
// the same size as using an int64 directly) to avoid taking up extra cache
// space. In general, this template should not be extended at the cost of
// performance. If it does not offer enough flexibility for a particular object
// (example: b/187877947), we should implement the RefCounter/CheckedObject
// interfaces manually.
//
// +stateify savable
type addressStateRefs struct {
	// refCount is composed of two fields:
	//
	//	[32-bit speculative references]:[32-bit real references]
	//
	// Speculative references are used for TryIncRef, to avoid a CompareAndSwap
	// loop. See IncRef, DecRef and TryIncRef for details of how these fields are
	// used.
	refCount atomicbitops.Int64
}
// InitRefs sets r's count to a single reference and registers r with the
// refs package, which activates leak checking when that is enabled.
func (r *addressStateRefs) InitRefs() {
	const initialRefs = 1
	r.refCount.RacyStore(initialRefs)
	refs.Register(r)
}
// RefType implements refs.CheckedObject.RefType. It returns the %T
// formatting of addressStateobj with the first character (the pointer '*')
// dropped.
func (r *addressStateRefs) RefType() string {
	ptrTypeName := fmt.Sprintf("%T", addressStateobj)
	return ptrTypeName[1:]
}
// LeakMessage implements refs.CheckedObject.LeakMessage. The message names
// the ref type, the address of r and the current reference count.
func (r *addressStateRefs) LeakMessage() string {
	typeName, count := r.RefType(), r.ReadRefs()
	return fmt.Sprintf("[%s %p] reference count of %d instead of 0", typeName, r, count)
}
// LogRefs implements refs.CheckedObject.LogRefs. It reports the value of
// addressStateenableLogging.
func (r *addressStateRefs) LogRefs() bool {
	const logging = addressStateenableLogging
	return logging
}
// ReadRefs returns the current value of the reference count. The returned
// count is inherently racy and is unsafe to use without external
// synchronization.
func (r *addressStateRefs) ReadRefs() int64 {
	count := r.refCount.Load()
	return count
}
// IncRef implements refs.RefCounter.IncRef. It panics if the count after the
// increment is 1 or less, i.e. if the object had no references beforehand.
//
//go:nosplit
func (r *addressStateRefs) IncRef() {
	count := r.refCount.Add(1)
	if addressStateenableLogging {
		refs.LogIncRef(r, count)
	}
	if count > 1 {
		return
	}
	panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType()))
}
// TryIncRef implements refs.TryRefCounter.TryIncRef.
//
// To do this safely without a loop, a speculative reference is first acquired
// on the object. This allows multiple concurrent TryIncRef calls to distinguish
// other TryIncRef calls from genuine references held.
//
// It returns false, leaving the count as it found it, when the real
// reference count (the low 32 bits) is zero; otherwise it returns true with
// one real reference added.
//
//go:nosplit
func (r *addressStateRefs) TryIncRef() bool {
	// A speculative reference is one unit in the upper 32 bits of refCount.
	const speculativeRef = 1 << 32
	// int32(v) truncates to the low 32 bits, i.e. the real reference count.
	if v := r.refCount.Add(speculativeRef); int32(v) == 0 {
		// No real references remain: undo the speculative reference and fail.
		r.refCount.Add(-speculativeRef)
		return false
	}

	// Convert the speculative reference into a real one in a single add.
	v := r.refCount.Add(-speculativeRef + 1)
	if addressStateenableLogging {
		refs.LogTryIncRef(r, v)
	}
	return true
}
// DecRef implements refs.RefCounter.DecRef. When the count reaches zero, r is
// unregistered from the refs package and destroy, if non-nil, is called. It
// panics if the count becomes negative.
//
// Note that speculative references are counted here. Since they were added
// prior to real references reaching zero, they will successfully convert to
// real references. In other words, we see speculative references only in the
// following case:
//
//	A: TryIncRef [speculative increase => sees non-negative references]
//	B: DecRef [real decrease]
//	A: TryIncRef [transform speculative to real]
//
//go:nosplit
func (r *addressStateRefs) DecRef(destroy func()) {
	count := r.refCount.Add(-1)
	if addressStateenableLogging {
		refs.LogDecRef(r, count)
	}
	if count < 0 {
		panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType()))
	}
	if count != 0 {
		return
	}

	refs.Unregister(r)
	if destroy != nil {
		destroy()
	}
}
// afterLoad re-registers r with the refs package if it still has a positive
// reference count.
func (r *addressStateRefs) afterLoad(context.Context) {
	if r.ReadRefs() <= 0 {
		return
	}
	refs.Register(r)
}

View File

@@ -0,0 +1,950 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// sanitize normalizes lifetimes in place: when Deprecated is set,
// PreferredUntil is reset to the zero tcpip.MonotonicTime.
func (lifetimes *AddressLifetimes) sanitize() {
	if !lifetimes.Deprecated {
		return
	}
	lifetimes.PreferredUntil = tcpip.MonotonicTime{}
}
// Compile-time check that *AddressableEndpointState satisfies
// AddressableEndpoint.
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)

// AddressableEndpointState is an implementation of an AddressableEndpoint.
//
// +stateify savable
type AddressableEndpointState struct {
	// networkEndpoint and options are set by Init.
	networkEndpoint NetworkEndpoint
	options         AddressableEndpointStateOptions

	// Lock ordering (from outer to inner lock ordering):
	//
	// AddressableEndpointState.mu
	//   addressState.mu
	mu addressableEndpointStateRWMutex `state:"nosave"`
	// endpoints maps each address to its state.
	//
	// +checklocks:mu
	endpoints map[tcpip.Address]*addressState
	// primary is the ordered list of primary address states.
	//
	// +checklocks:mu
	primary []*addressState
}
// AddressableEndpointStateOptions contains options used to configure an
// AddressableEndpointState. It is supplied to AddressableEndpointState.Init.
//
// +stateify savable
type AddressableEndpointStateOptions struct {
	// HiddenWhileDisabled determines whether addresses should be returned to
	// callers while the NetworkEndpoint this AddressableEndpointState belongs
	// to is disabled.
	HiddenWhileDisabled bool
}
// Init initializes the AddressableEndpointState with networkEndpoint and
// options, and allocates an empty address map.
//
// Must be called before calling any other function on a.
func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint, options AddressableEndpointStateOptions) {
	a.options = options
	a.networkEndpoint = networkEndpoint

	a.mu.Lock()
	a.endpoints = make(map[tcpip.Address]*addressState)
	a.mu.Unlock()
}
// OnNetworkEndpointEnabledChanged must be called every time the
// NetworkEndpoint this AddressableEndpointState belongs to is enabled or
// disabled so that any AddressDispatchers can be notified of the NIC enabled
// change.
//
// Each address state is notified under its own write lock while a.mu is held
// for reading.
func (a *AddressableEndpointState) OnNetworkEndpointEnabledChanged() {
	a.mu.RLock()
	defer a.mu.RUnlock()

	for _, addrState := range a.endpoints {
		addrState.mu.Lock()
		addrState.notifyChangedLocked()
		addrState.mu.Unlock()
	}
}
// GetAddress returns the AddressEndpoint for the passed address.
//
// GetAddress does not increment the address's reference count or check if the
// address is considered bound to the endpoint.
//
// Returns nil if the passed address is not associated with the endpoint.
func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
	a.mu.RLock()
	defer a.mu.RUnlock()

	if addrState, found := a.endpoints[addr]; found {
		return addrState
	}
	// Return an untyped nil so callers can compare the interface to nil.
	return nil
}
// ForEachEndpoint calls f for each address while holding a.mu for reading.
//
// Once f returns false, f will no longer be called.
func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
	a.mu.RLock()
	defer a.mu.RUnlock()

	for _, addrState := range a.endpoints {
		if keepGoing := f(addrState); !keepGoing {
			break
		}
	}
}
// ForEachPrimaryEndpoint calls f for each primary address, in primary list
// order, while holding a.mu for reading.
//
// Once f returns false, f will no longer be called.
func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
	a.mu.RLock()
	defer a.mu.RUnlock()

	for _, addrState := range a.primary {
		if keepGoing := f(addrState); !keepGoing {
			break
		}
	}
}
// releaseAddressState is like releaseAddressStateLocked but acquires a.mu
// for writing itself.
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.releaseAddressStateLocked(addrState)
}
// releaseAddressStateLocked removes addrState from a's address state
// (primary and endpoints list).
//
// +checklocks:a.mu
func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) {
	primary := a.primary
	for idx := range primary {
		if primary[idx] != addrState {
			continue
		}
		// Shift the tail left over the removed entry, then clear the vacated
		// last slot so the backing array does not keep a stale pointer.
		last := len(primary) - 1
		copy(primary[idx:], primary[idx+1:])
		primary[last] = nil
		a.primary = primary[:last]
		break
	}
	delete(a.endpoints, addrState.addr.Address)
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint. It adds addr
// with the Permanent kind via AddAndAcquireAddress.
func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) {
	ep, err := a.AddAndAcquireAddress(addr, properties, Permanent)
	return ep, err
}
// AddAndAcquireTemporaryAddress adds a temporary address whose only
// non-default property is the primary endpoint behavior peb.
//
// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// The temporary address's endpoint is acquired and returned.
func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
	properties := AddressProperties{PEB: peb}
	return a.AddAndAcquireAddress(addr, properties, Temporary)
}
// AddAndAcquireAddress adds an address with the specified kind, holding a.mu
// for writing for the duration of the call.
//
// Returns *tcpip.ErrDuplicateAddress if the address exists.
func (a *AddressableEndpointState) AddAndAcquireAddress(addr tcpip.AddressWithPrefix, properties AddressProperties, kind AddressKind) (AddressEndpoint, tcpip.Error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	addrState, err := a.addAndAcquireAddressLocked(addr, properties, kind)
	if addrState != nil {
		return addrState, err
	}

	// addrState is a nil *addressState here. Returning it directly would wrap
	// it in an AddressEndpoint interface value that holds a type but a nil
	// pointer, and such an interface value compares as non-nil (see
	// https://golang.org/doc/faq#nil_error). Return an untyped nil instead so
	// that callers can test the result against nil.
	return nil, err
}
// addAndAcquireAddressLocked adds, acquires and returns a permanent or
// temporary address.
//
// If the addressable endpoint already has the address in a non-permanent state,
// and addAndAcquireAddressLocked is adding a permanent address, that address is
// promoted in place and its properties set to the properties provided. If the
// address already exists in any other state, then *tcpip.ErrDuplicateAddress is
// returned, regardless the kind of address that is being added.
//
// It panics if kind is PermanentExpired or not a known kind, or if
// properties.PEB is not a recognized primary endpoint behaviour.
//
// +checklocks:a.mu
func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, kind AddressKind) (*addressState, tcpip.Error) {
	var permanent bool
	switch kind {
	case PermanentExpired:
		panic(fmt.Sprintf("cannot add address %s in PermanentExpired state", addr))
	case Permanent, PermanentTentative:
		permanent = true
	case Temporary:
		// Nothing to do; permanent stays false.
	default:
		panic(fmt.Sprintf("unknown address kind: %d", kind))
	}

	// attemptAddToPrimary is false when the address is already in the primary
	// address list.
	attemptAddToPrimary := true
	addrState, ok := a.endpoints[addr.Address]
	if ok {
		if !permanent {
			// We are adding a non-permanent address but the address exists. No need
			// to go any further since we can only promote existing temporary/expired
			// addresses to permanent.
			return nil, &tcpip.ErrDuplicateAddress{}
		}

		// Inspect the existing state under its read lock; the write lock is
		// taken further below.
		addrState.mu.RLock()
		if addrState.refs.ReadRefs() == 0 {
			panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr))
		}
		isPermanent := addrState.kind.IsPermanent()
		addrState.mu.RUnlock()

		if isPermanent {
			// We are adding a permanent address but a permanent address already
			// exists.
			return nil, &tcpip.ErrDuplicateAddress{}
		}

		// We now promote the address.
		for i, s := range a.primary {
			if s == addrState {
				switch properties.PEB {
				case CanBePrimaryEndpoint:
					// The address is already in the primary address list.
					attemptAddToPrimary = false
				case FirstPrimaryEndpoint:
					if i == 0 {
						// The address is already first in the primary address list.
						attemptAddToPrimary = false
					} else {
						// Remove it here; attemptAddToPrimary stays true so it is
						// re-inserted at the front further below.
						a.primary = append(a.primary[:i], a.primary[i+1:]...)
					}
				case NeverPrimaryEndpoint:
					// Remove it; the NeverPrimaryEndpoint case below does not
					// add it back.
					a.primary = append(a.primary[:i], a.primary[i+1:]...)
				default:
					panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
				}
				break
			}
		}
		// This is the reference acquired on behalf of the caller.
		addrState.refs.IncRef()
	} else {
		addrState = &addressState{
			addressableEndpointState: a,
			addr:                     addr,
			temporary:                properties.Temporary,
			// Cache the subnet in addrState to avoid calls to addr.Subnet() as that
			// results in allocations on every call.
			subnet: addr.Subnet(),
		}
		// InitRefs starts the count at one, which is the reference returned
		// to the caller.
		addrState.refs.InitRefs()
		a.endpoints[addr.Address] = addrState
		// We never promote an address to temporary - it can only be added as such.
		// If we are actually adding a permanent address, it is promoted below.
		addrState.kind = Temporary
	}

	// At this point we have an address we are either promoting from an expired or
	// temporary address to permanent, promoting an expired address to temporary,
	// or we are adding a new temporary or permanent address.
	//
	// NOTE(review): given the early return above for non-permanent adds of an
	// existing address, the "expired address to temporary" case does not look
	// reachable from this function — confirm.
	//
	// The address MUST be write locked at this point.
	addrState.mu.Lock()
	defer addrState.mu.Unlock()

	if permanent {
		if addrState.kind.IsPermanent() {
			panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr))
		}

		// Primary addresses are biased by 1.
		addrState.refs.IncRef()
		addrState.kind = kind
	}
	addrState.configType = properties.ConfigType
	// Sanitize a local copy so the caller's properties are left untouched.
	lifetimes := properties.Lifetimes
	lifetimes.sanitize()
	addrState.lifetimes = lifetimes
	addrState.disp = properties.Disp

	if attemptAddToPrimary {
		switch properties.PEB {
		case NeverPrimaryEndpoint:
		case CanBePrimaryEndpoint:
			a.primary = append(a.primary, addrState)
		case FirstPrimaryEndpoint:
			if cap(a.primary) == len(a.primary) {
				// No spare capacity: a new slice has to be allocated anyway.
				a.primary = append([]*addressState{addrState}, a.primary...)
			} else {
				// Shift all the endpoints by 1 to make room for the new address at the
				// front. We could have just created a new slice but this saves
				// allocations when the slice has capacity for the new address.
				primaryCount := len(a.primary)
				a.primary = append(a.primary, nil)
				if n := copy(a.primary[1:], a.primary); n != primaryCount {
					panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount))
				}
				a.primary[0] = addrState
			}
		default:
			panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
		}
	}

	addrState.notifyChangedLocked()
	return addrState, nil
}
// RemovePermanentAddress implements AddressableEndpoint. It acquires a.mu for
// writing and delegates to removePermanentAddressLocked.
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
	a.mu.Lock()
	defer a.mu.Unlock()
	return a.removePermanentAddressLocked(addr)
}
// removePermanentAddressLocked is like RemovePermanentAddress but with locking
// requirements.
//
// Returns *tcpip.ErrBadLocalAddress if addr is not associated with a. The
// removal is performed with AddressRemovalManualAction as the reason.
//
// +checklocks:a.mu
func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) tcpip.Error {
	if addrState, found := a.endpoints[addr]; found {
		return a.removePermanentEndpointLocked(addrState, AddressRemovalManualAction)
	}
	return &tcpip.ErrBadLocalAddress{}
}
// RemovePermanentEndpoint removes the passed endpoint if it is associated with
// a and permanent.
//
// Returns *tcpip.ErrInvalidEndpointState if ep is not an *addressState or
// belongs to a different AddressableEndpointState.
func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint, reason AddressRemovalReason) tcpip.Error {
	addrState, isAddrState := ep.(*addressState)
	if !isAddrState {
		return &tcpip.ErrInvalidEndpointState{}
	}
	if addrState.addressableEndpointState != a {
		return &tcpip.ErrInvalidEndpointState{}
	}

	a.mu.Lock()
	defer a.mu.Unlock()
	return a.removePermanentEndpointLocked(addrState, reason)
}
// removePermanentEndpointLocked is like RemovePermanentEndpoint but with
// locking requirements.
//
// Returns *tcpip.ErrBadLocalAddress if addrState is not of a permanent kind.
// Otherwise addrState is marked removed with the given reason and one
// reference to it is dropped.
//
// +checklocks:a.mu
func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState, reason AddressRemovalReason) tcpip.Error {
	if kind := addrState.GetKind(); !kind.IsPermanent() {
		return &tcpip.ErrBadLocalAddress{}
	}

	addrState.remove(reason)
	a.decAddressRefLocked(addrState)
	return nil
}
// decAddressRef decrements the address's reference count and releases it once
// the reference count hits 0. It acquires a.mu for writing and delegates to
// decAddressRefLocked.
func (a *AddressableEndpointState) decAddressRef(addrState *addressState) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.decAddressRefLocked(addrState)
}
// decAddressRefLocked is like decAddressRef but with locking requirements.
//
// It panics if the last reference is dropped while the address is still of a
// permanent kind.
//
// +checklocks:a.mu
func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) {
	var lastRef bool
	addrState.refs.DecRef(func() {
		lastRef = true
	})

	if lastRef {
		addrState.mu.Lock()
		defer addrState.mu.Unlock()

		// A non-expired permanent address must not have its reference count
		// dropped to 0.
		if addrState.kind.IsPermanent() {
			panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.kind))
		}
		a.releaseAddressStateLocked(addrState)
	}
}
// SetDeprecated implements stack.AddressableEndpoint.
//
// Returns *tcpip.ErrBadLocalAddress if addr is not associated with a.
func (a *AddressableEndpointState) SetDeprecated(addr tcpip.Address, deprecated bool) tcpip.Error {
	a.mu.RLock()
	defer a.mu.RUnlock()

	if addrState, found := a.endpoints[addr]; found {
		addrState.SetDeprecated(deprecated)
		return nil
	}
	return &tcpip.ErrBadLocalAddress{}
}
// SetLifetimes implements stack.AddressableEndpoint.
//
// It forwards lifetimes to the endpoint registered for addr, or returns
// *tcpip.ErrBadLocalAddress if a does not track addr.
func (a *AddressableEndpointState) SetLifetimes(addr tcpip.Address, lifetimes AddressLifetimes) tcpip.Error {
	a.mu.RLock()
	defer a.mu.RUnlock()

	if state, found := a.endpoints[addr]; found {
		state.SetLifetimes(lifetimes)
		return nil
	}
	return &tcpip.ErrBadLocalAddress{}
}
// MainAddress implements AddressableEndpoint.
//
// Only primary addresses of kind Permanent are eligible, and only while the
// network endpoint is enabled or a is not configured to hide addresses while
// disabled. The zero tcpip.AddressWithPrefix is returned when no address is
// eligible.
func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
	a.mu.RLock()
	defer a.mu.RUnlock()

	isMainCandidate := func(state *addressState) bool {
		kind := state.GetKind()
		switch kind {
		case PermanentTentative, PermanentExpired, Temporary:
			return false
		case Permanent:
			return a.networkEndpoint.Enabled() || !a.options.HiddenWhileDisabled
		}
		panic(fmt.Sprintf("unknown address kind: %d", kind))
	}

	acquired := a.acquirePrimaryAddressRLocked(tcpip.Address{}, tcpip.Address{} /* srcHint */, isMainCandidate)
	if acquired == nil {
		return tcpip.AddressWithPrefix{}
	}

	result := acquired.AddressWithPrefix()
	// Note that acquired must have a ref count >=2 here, because its ref
	// count must be >=1 in order to be found and the ref count was
	// incremented when a reference was acquired. The only way for the ref
	// count to drop below 2 is for the endpoint to be removed, which requires
	// a write lock; so we're guaranteed to be able to decrement the ref count
	// and not need to remove the endpoint from a.primary.
	acquired.decRefMustNotFree()
	return result
}
// acquirePrimaryAddressRLocked returns an acquired primary address that is
// valid according to isValid.
//
// When remoteAddr is a non-zero IPv4-sized address, the candidate equal to a
// non-zero srcHint is preferred, otherwise the candidate with the longest
// prefix in common with remoteAddr. If that selection yields nothing (or a
// reference cannot be taken), or for any other remoteAddr, the first valid
// non-deprecated address in a.primary is returned, falling back to the first
// valid deprecated one. Returns nil if no address could be acquired.
//
// The returned address has had its reference count incremented via
// TryIncRef; the caller owns that reference.
//
// +checklocksread:a.mu
func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(remoteAddr, srcHint tcpip.Address, isValid func(*addressState) bool) *addressState {
// TODO: Move this out into IPv4-specific code.
// IPv6 handles source IP selection elsewhere. We have to do source
// selection only for IPv4, in which case ep is never deprecated. Thus
// we don't have to worry about refcounts.
if remoteAddr.Len() == header.IPv4AddressSize && remoteAddr != (tcpip.Address{}) {
var best *addressState
var bestLen uint8
for _, state := range a.primary {
if !isValid(state) {
continue
}
// Source hint takes precedence over prefix matching.
if state.addr.Address == srcHint && srcHint != (tcpip.Address{}) {
best = state
break
}
// Otherwise keep the candidate sharing the longest prefix with remoteAddr;
// ties keep the earlier candidate.
stateLen := state.addr.Address.MatchingPrefix(remoteAddr)
if best == nil || bestLen < stateLen {
best = state
bestLen = stateLen
}
}
// If no reference can be taken on best, fall through to the generic
// selection below.
if best != nil && best.TryIncRef() {
return best
}
}
// Generic selection: prefer a non-deprecated address, remembering the first
// deprecated one as a fallback.
var deprecatedEndpoint *addressState
for _, ep := range a.primary {
if !isValid(ep) {
continue
}
if !ep.Deprecated() {
if ep.TryIncRef() {
// ep is not deprecated, so return it immediately.
//
// If we kept track of a deprecated endpoint, decrement its reference
// count since it was incremented when we decided to keep track of it.
if deprecatedEndpoint != nil {
// Note that when deprecatedEndpoint was found, its ref count
// must have necessarily been >=1, and after incrementing it
// must be >=2. The only way for the ref count to drop below 2 is
// for the endpoint to be removed, which requires a write lock;
// so we're guaranteed to be able to decrement the ref count
// and not need to remove the endpoint from a.primary.
deprecatedEndpoint.decRefMustNotFree()
}
return ep
}
} else if deprecatedEndpoint == nil && ep.TryIncRef() {
// We prefer an endpoint that is not deprecated, but we keep track of
// ep in case a doesn't have any non-deprecated endpoints.
//
// If we end up finding a more preferred endpoint, ep's reference count
// will be decremented.
deprecatedEndpoint = ep
}
}
// Either nil or a deprecated endpoint on which a reference is held.
return deprecatedEndpoint
}
// AcquireAssignedAddressOrMatching returns an address endpoint that is
// considered assigned to the addressable endpoint.
//
// If the address is an exact match with an existing address, that address is
// returned. Otherwise, if f is provided, f is called with each address and
// the address that f returns true for is returned.
//
// If there is no matching address, a temporary address will be returned if
// allowTemp is true.
//
// If readOnly is true, the address will be returned without an extra reference.
// In this case it is not safe to modify the endpoint, only read attributes like
// subnet.
//
// Regardless how the address was obtained, it will be acquired before it is
// returned.
//
// NOTE(review): the two preceding paragraphs conflict for readOnly callers;
// the code below takes no reference on lookup hits when readOnly is true and
// drops the reference taken for a newly added temporary address.
//
// Returns nil when no address matches and allowTemp is false, or when an
// exact match exists but is not assigned according to IsAssigned(allowTemp).
func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior, readOnly bool) AddressEndpoint {
// lookup must be called with a.mu held (read or write). It returns the
// matching address state, or nil.
lookup := func() *addressState {
if addrState, ok := a.endpoints[localAddr]; ok {
// An exact match that is not assigned ends the search; f is not consulted.
if !addrState.IsAssigned(allowTemp) {
return nil
}
if !readOnly && !addrState.TryIncRef() {
panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
}
return addrState
}
if f != nil {
for _, addrState := range a.endpoints {
if addrState.IsAssigned(allowTemp) && f(addrState) {
// Skip candidates on which a reference cannot be taken.
if !readOnly && !addrState.TryIncRef() {
continue
}
return addrState
}
}
}
return nil
}
// Avoid exclusive lock on mu unless we need to add a new address.
a.mu.RLock()
ep := lookup()
a.mu.RUnlock()
if ep != nil {
return ep
}
if !allowTemp {
return nil
}
// Acquire state lock in exclusive mode as we need to add a new temporary
// endpoint.
a.mu.Lock()
defer a.mu.Unlock()
// Do the lookup again in case another goroutine added the address in the time
// we released and acquired the lock.
ep = lookup()
if ep != nil {
return ep
}
// Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, Temporary)
if err != nil {
// addAndAcquireAddressLocked only returns an error if the address is
// already assigned but we just checked above if the address exists so we
// expect no error.
panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err))
}
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
// a value V.
//
// An interface value is nil only if the V and T are both unset, (T=nil, V is
// not set), In particular, a nil interface will always hold a nil type. If we
// store a nil pointer of type *int inside an interface value, the inner type
// will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
// an interface value will therefore be non-nil even when the pointer value V
// inside is nil.
//
// Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
// we need to explicitly return nil below if ep is (a typed) nil.
if ep == nil {
return nil
}
// Read-only callers do not keep the reference taken when the address was
// added, so drop it again before returning.
if readOnly {
if ep.addressableEndpointState == a {
// Checklocks doesn't understand that ep.addressableEndpointState.mu is
// a.mu, which is already locked above. We need to use checklocksignore
// to appease the analyzer.
ep.addressableEndpointState.decAddressRefLocked(ep) // +checklocksignore
} else {
ep.DecRef()
}
}
return ep
}
// AcquireAssignedAddress implements AddressableEndpoint.
//
// It is AcquireAssignedAddressOrMatching without a matching function, so only
// an exact match on localAddr (or a new temporary address, when allowTemp is
// true) can be returned.
func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior, readOnly bool) AddressEndpoint {
return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB, readOnly)
}
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
//
// It acquires a primary address via acquirePrimaryAddressRLocked, considering
// only addresses for which IsAssigned(allowExpired) holds. The returned
// endpoint carries the reference taken by that helper; nil is returned when
// no address could be acquired.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, srcHint tcpip.Address, allowExpired bool) AddressEndpoint {
	// The selection helper only requires a.mu to be held for reading (it is
	// annotated +checklocksread and is called under RLock by MainAddress), so
	// do not serialize concurrent callers behind the exclusive lock.
	a.mu.RLock()
	defer a.mu.RUnlock()

	ep := a.acquirePrimaryAddressRLocked(remoteAddr, srcHint, func(ep *addressState) bool {
		return ep.IsAssigned(allowExpired)
	})
	// From https://golang.org/doc/faq#nil_error:
	//
	// Under the covers, interfaces are implemented as two elements, a type T and
	// a value V.
	//
	// An interface value is nil only if the V and T are both unset, (T=nil, V is
	// not set), In particular, a nil interface will always hold a nil type. If we
	// store a nil pointer of type *int inside an interface value, the inner type
	// will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
	// an interface value will therefore be non-nil even when the pointer value V
	// inside is nil.
	//
	// Since acquirePrimaryAddressRLocked returns a nil value with a non-nil
	// type, we need to explicitly return nil below if ep is (a typed) nil.
	if ep == nil {
		return nil
	}
	return ep
}
// PrimaryAddresses implements AddressableEndpoint.
//
// It returns the addresses of all primary endpoints of kind Permanent. The
// result is nil when a is configured to hide addresses while disabled and the
// network endpoint is not enabled. Panics on an unknown address kind.
func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix {
	a.mu.RLock()
	defer a.mu.RUnlock()

	if a.options.HiddenWhileDisabled && !a.networkEndpoint.Enabled() {
		return nil
	}

	var addrs []tcpip.AddressWithPrefix
	for _, ep := range a.primary {
		kind := ep.GetKind()
		switch kind {
		case Permanent:
			addrs = append(addrs, ep.AddressWithPrefix())
		case PermanentTentative, PermanentExpired, Temporary:
			// Don't include tentative, expired or temporary endpoints
			// to avoid confusion and prevent the caller from using
			// those.
		default:
			panic(fmt.Sprintf("address %s has unknown kind %d", ep.AddressWithPrefix(), kind))
		}
	}
	return addrs
}
// PermanentAddresses implements AddressableEndpoint.
//
// It returns the address of every endpoint tracked by a whose kind reports
// IsPermanent. The order follows map iteration and is therefore unspecified.
func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix {
	a.mu.RLock()
	defer a.mu.RUnlock()

	var permanent []tcpip.AddressWithPrefix
	for _, state := range a.endpoints {
		if state.GetKind().IsPermanent() {
			permanent = append(permanent, state.AddressWithPrefix())
		}
	}
	return permanent
}
// Cleanup forcefully leaves all groups and removes all permanent addresses.
//
// NOTE(review): only address removal is visible in this function; group
// handling, if any, happens elsewhere.
//
// Every endpoint tracked by a is passed to removePermanentEndpointLocked with
// reason AddressRemovalInterfaceRemoved. Panics on any error other than
// *tcpip.ErrBadLocalAddress.
func (a *AddressableEndpointState) Cleanup() {
	a.mu.Lock()
	defer a.mu.Unlock()

	for _, ep := range a.endpoints {
		err := a.removePermanentEndpointLocked(ep, AddressRemovalInterfaceRemoved)
		if err == nil {
			continue
		}
		// removePermanentEndpointLocked returns *tcpip.ErrBadLocalAddress if ep is
		// not a permanent address; such endpoints are simply left alone.
		if _, notPermanent := err.(*tcpip.ErrBadLocalAddress); notPermanent {
			continue
		}
		panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err))
	}
}
// Compile-time check that *addressState implements AddressEndpoint.
var _ AddressEndpoint = (*addressState)(nil)
// addressState holds state for an address.
//
// +stateify savable
type addressState struct {
// addressableEndpointState is the AddressableEndpointState that owns this
// address.
addressableEndpointState *AddressableEndpointState
// addr is the address together with its prefix length, as returned by
// AddressWithPrefix.
addr tcpip.AddressWithPrefix
// subnet is the value returned by Subnet.
subnet tcpip.Subnet
// temporary is the value returned by Temporary.
temporary bool
// Lock ordering (from outer to inner lock ordering):
//
// AddressableEndpointState.mu
// addressState.mu
mu addressStateRWMutex `state:"nosave"`
// refs is the reference count for this address.
refs addressStateRefs
// checklocks:mu
kind AddressKind
// checklocks:mu
configType AddressConfigType
// lifetimes holds this address' lifetimes.
//
// Invariant: if lifetimes.deprecated is true, then lifetimes.PreferredUntil
// must be the zero value. Note that the converse does not need to be
// upheld!
//
// checklocks:mu
lifetimes AddressLifetimes
// The enclosing mutex must be write-locked before calling methods on the
// dispatcher.
//
// checklocks:mu
disp AddressDispatcher
}
// AddressWithPrefix implements AddressEndpoint.
//
// It returns a.addr without taking a.mu.
func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
}
// Subnet implements AddressEndpoint.
//
// It returns a.subnet without taking a.mu.
func (a *addressState) Subnet() tcpip.Subnet {
return a.subnet
}
// GetKind implements AddressEndpoint.
//
// The kind is read while holding a.mu for reading.
func (a *addressState) GetKind() AddressKind {
	a.mu.RLock()
	kind := a.kind
	a.mu.RUnlock()
	return kind
}
// SetKind implements AddressEndpoint.
//
// Setting the kind to PermanentExpired notifies the dispatcher of removal with
// reason AddressRemovalManualAction. Any other actual change of kind triggers
// a change notification, but only while the owning network endpoint is
// enabled.
func (a *addressState) SetKind(kind AddressKind) {
	a.mu.Lock()
	defer a.mu.Unlock()

	previous := a.kind
	a.kind = kind

	switch {
	case kind == PermanentExpired:
		a.notifyRemovedLocked(AddressRemovalManualAction)
	case previous != kind && a.addressableEndpointState.networkEndpoint.Enabled():
		a.notifyChangedLocked()
	}
}
// notifyRemovedLocked notifies integrators of address removal.
//
// The registered dispatcher, if any, is told about the removal and then
// cleared, so it receives no further notifications from this address.
//
// +checklocks:a.mu
func (a *addressState) notifyRemovedLocked(reason AddressRemovalReason) {
	if a.disp == nil {
		return
	}
	a.disp.OnRemoved(reason)
	a.disp = nil
}
// remove marks the address as PermanentExpired and notifies the registered
// dispatcher, if any, that the address was removed for the given reason.
func (a *addressState) remove(reason AddressRemovalReason) {
a.mu.Lock()
defer a.mu.Unlock()
a.kind = PermanentExpired
a.notifyRemovedLocked(reason)
}
// IsAssigned implements AddressEndpoint.
//
// Permanent and Temporary addresses are assigned, PermanentTentative ones are
// not, and PermanentExpired ones are assigned only if allowExpired is true.
// Panics on an unknown address kind.
func (a *addressState) IsAssigned(allowExpired bool) bool {
	kind := a.GetKind()
	switch kind {
	case Permanent, Temporary:
		return true
	case PermanentExpired:
		return allowExpired
	case PermanentTentative:
		return false
	}
	panic(fmt.Sprintf("address %s has unknown kind %d", a.AddressWithPrefix(), kind))
}
// TryIncRef implements AddressEndpoint.
//
// It reports the result of a.refs.TryIncRef.
func (a *addressState) TryIncRef() bool {
return a.refs.TryIncRef()
}
// DecRef implements AddressEndpoint.
//
// The decrement goes through the owning AddressableEndpointState, which
// acquires its mu for writing in decAddressRef.
func (a *addressState) DecRef() {
a.addressableEndpointState.decAddressRef(a)
}
// decRefMustNotFree decreases the reference count with the guarantee that the
// reference count will be greater than 0 after the decrement.
//
// Panics if this decrement causes the destroy callback passed to refs.DecRef
// to run. No lock is acquired by this function.
func (a *addressState) decRefMustNotFree() {
a.refs.DecRef(func() {
panic(fmt.Sprintf("cannot decrease addressState %s without freeing the endpoint", a.addr))
})
}
// ConfigType implements AddressEndpoint.
//
// The configuration type is read while holding a.mu for reading.
func (a *addressState) ConfigType() AddressConfigType {
	a.mu.RLock()
	configType := a.configType
	a.mu.RUnlock()
	return configType
}
// notifyChangedLocked notifies integrators of address property changes.
//
// Nothing happens without a registered dispatcher. While the owning network
// endpoint is disabled the address is reported as AddressDisabled. While it is
// enabled, Permanent maps to AddressAssigned and PermanentTentative to
// AddressTentative; Temporary and PermanentExpired addresses are not reported.
// Panics on an unrecognized kind.
//
// +checklocks:a.mu
func (a *addressState) notifyChangedLocked() {
	if a.disp == nil {
		return
	}

	state := AddressDisabled
	if enabled := a.addressableEndpointState.networkEndpoint.Enabled(); enabled {
		switch a.kind {
		case Temporary, PermanentExpired:
			return
		case PermanentTentative:
			state = AddressTentative
		case Permanent:
			state = AddressAssigned
		default:
			panic(fmt.Sprintf("unrecognized address kind = %d", a.kind))
		}
	}
	a.disp.OnChanged(a.lifetimes, state)
}
// SetDeprecated implements AddressEndpoint.
//
// Deprecating an address also resets its PreferredUntil lifetime to the zero
// value. The dispatcher is notified only if the deprecated flag actually
// changed.
func (a *addressState) SetDeprecated(d bool) {
	a.mu.Lock()
	defer a.mu.Unlock()

	flipped := a.lifetimes.Deprecated != d
	a.lifetimes.Deprecated = d
	if d {
		a.lifetimes.PreferredUntil = tcpip.MonotonicTime{}
	}
	if flipped {
		a.notifyChangedLocked()
	}
}
// Deprecated implements AddressEndpoint.
//
// The flag is read from a.lifetimes while holding a.mu for reading.
func (a *addressState) Deprecated() bool {
	a.mu.RLock()
	deprecated := a.lifetimes.Deprecated
	a.mu.RUnlock()
	return deprecated
}
// SetLifetimes implements AddressEndpoint.
//
// lifetimes is sanitized before being stored. The dispatcher is notified only
// if the sanitized value differs from the one previously stored.
func (a *addressState) SetLifetimes(lifetimes AddressLifetimes) {
	a.mu.Lock()
	defer a.mu.Unlock()

	lifetimes.sanitize()
	differs := a.lifetimes != lifetimes
	a.lifetimes = lifetimes
	if differs {
		a.notifyChangedLocked()
	}
}
// Lifetimes implements AddressEndpoint.
//
// A copy of the stored lifetimes is returned; it is read while holding a.mu
// for reading.
func (a *addressState) Lifetimes() AddressLifetimes {
	a.mu.RLock()
	current := a.lifetimes
	a.mu.RUnlock()
	return current
}
// Temporary implements AddressEndpoint.
//
// It returns a.temporary without taking a.mu.
func (a *addressState) Temporary() bool {
return a.temporary
}
// RegisterDispatcher implements AddressEndpoint.
//
// A nil disp is ignored. Otherwise disp replaces any previously registered
// dispatcher and is immediately sent a change notification via
// notifyChangedLocked.
func (a *addressState) RegisterDispatcher(disp AddressDispatcher) {
	a.mu.Lock()
	defer a.mu.Unlock()

	if disp == nil {
		return
	}
	a.disp = disp
	a.notifyChangedLocked()
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// addressableEndpointStateRWMutex is sync.RWMutex with the correctness
// validator.
type addressableEndpointStateRWMutex struct {
// mu is the underlying lock that the wrapper methods operate on.
mu sync.RWMutex
}
// addressableEndpointStatelockNames is a list of user-friendly lock names.
// Populated in init.
var addressableEndpointStatelockNames []string
// addressableEndpointStatelockNameIndex is used as an index passed to
// NestedLock and NestedUnlock, referring to an index within
// addressableEndpointStatelockNames.
// Values are specified using the "consts" field of go_template_instance.
type addressableEndpointStatelockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
//
// The acquisition is registered with the lock validator (locking.AddGLock,
// sub-index -1) before the underlying mutex is locked.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) Lock() {
locking.AddGLock(addressableEndpointStateprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
//
// i is passed to the lock validator as the sub-index for this acquisition.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) NestedLock(i addressableEndpointStatelockNameIndex) {
locking.AddGLock(addressableEndpointStateprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
//
// The underlying mutex is released first, then the lock validator entry
// (sub-index -1) is removed.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(addressableEndpointStateprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
//
// i is the sub-index whose lock validator entry is removed after unlocking.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) NestedUnlock(i addressableEndpointStatelockNameIndex) {
m.mu.Unlock()
locking.DelGLock(addressableEndpointStateprefixIndex, int(i))
}
// RLock locks m for reading.
//
// The acquisition is registered with the lock validator (sub-index -1) before
// the underlying mutex is read-locked.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) RLock() {
locking.AddGLock(addressableEndpointStateprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
//
// The read lock is released first, then the lock validator entry (sub-index
// -1) is removed.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(addressableEndpointStateprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
//
// Pair with RUnlockBypass, which likewise skips the validator.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
//
// The validator is not executed.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks m for writing and locks it for reading.
//
// The validator is not touched by this call.
// +checklocksignore
func (m *addressableEndpointStateRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
// addressableEndpointStateprefixIndex is the validator mutex class for
// addressableEndpointStateRWMutex. It is assigned in init.
var addressableEndpointStateprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func addressableEndpointStateinitLockNames() {}
// init populates the lock names and registers the mutex class with the lock
// validator.
func init() {
addressableEndpointStateinitLockNames()
addressableEndpointStateprefixIndex = locking.NewMutexClass(reflect.TypeOf(addressableEndpointStateRWMutex{}), addressableEndpointStatelockNames)
}

View File

@@ -0,0 +1,229 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Compile-time check that *BridgeEndpoint implements NetworkLinkEndpoint.
var _ NetworkLinkEndpoint = (*BridgeEndpoint)(nil)
// bridgePort ties a NIC to the bridge it has been added to. AddNIC attaches
// it to the NIC's link endpoint in place of the NIC itself.
type bridgePort struct {
// bridge is the bridge this port belongs to.
bridge *BridgeEndpoint
// nic is the NIC behind this port.
nic *nic
}
// ParseHeader implements stack.LinkEndpoint.
//
// It consumes header.EthernetMinimumSize bytes as the packet's link header
// and reports whether that succeeded.
func (p *bridgePort) ParseHeader(pkt *PacketBuffer) bool {
	if _, consumed := pkt.LinkHeader().Consume(header.EthernetMinimumSize); !consumed {
		return false
	}
	return true
}
// DeliverNetworkPacket implements stack.NetworkDispatcher.
//
// A packet received on this port is copied to every other port of the bridge
// and then handed to the bridge's own dispatcher, if one is attached.
func (p *bridgePort) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
bridge := p.bridge
bridge.mu.RLock()
// Send the packet to all other ports.
for _, port := range bridge.ports {
if p == port {
continue
}
// Each port gets its own packet buffer built from pkt's contents, with
// header room reserved for that port's NIC.
newPkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: int(port.nic.MaxHeaderLength()),
Payload: pkt.ToBuffer(),
})
port.nic.writeRawPacket(newPkt)
// Drop the reference held on the copy by this function.
newPkt.DecRef()
}
// Snapshot the dispatcher while the lock is still held.
d := bridge.dispatcher
bridge.mu.RUnlock()
if d != nil {
// The dispatcher may acquire Stack.mu in DeliverNetworkPacket(), which is
// ordered above bridge.mu. So call DeliverNetworkPacket() without holding
// bridge.mu to avoid circular locking.
d.DeliverNetworkPacket(protocol, pkt)
}
}
// DeliverLinkPacket is a no-op for bridge ports.
func (p *bridgePort) DeliverLinkPacket(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
}
// NewBridgeEndpoint creates a new bridge endpoint.
//
// The endpoint starts with the given mtu, a link address obtained from
// tcpip.GetRandMacAddr and no ports.
func NewBridgeEndpoint(mtu uint32) *BridgeEndpoint {
	return &BridgeEndpoint{
		mtu:   mtu,
		addr:  tcpip.GetRandMacAddr(),
		ports: make(map[tcpip.NICID]*bridgePort),
	}
}
// BridgeEndpoint is a bridge endpoint.
//
// It forwards packets between the NICs added through AddNIC and delivers
// received packets to its own dispatcher.
type BridgeEndpoint struct {
// mu protects the fields annotated with +checklocks:mu below.
mu bridgeRWMutex
// ports holds the bridge's ports, keyed by the ID of the NIC behind each.
// +checklocks:mu
ports map[tcpip.NICID]*bridgePort
// dispatcher is the dispatcher set by Attach; nil until then.
// +checklocks:mu
dispatcher NetworkDispatcher
// addr is the bridge's link address.
// +checklocks:mu
addr tcpip.LinkAddress
// NOTE(review): attached is not read or written in the visible code.
// +checklocks:mu
attached bool
// mtu is the configured MTU; MTU() reports it minus
// header.EthernetMinimumSize.
// +checklocks:mu
mtu uint32
// maxHeaderLength is raised by AddNIC to the largest MaxHeaderLength seen
// among added NICs.
maxHeaderLength atomicbitops.Uint32
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
//
// Every packet in pkts is copied once per bridge port and written out through
// that port's NIC. The return value is always the number of packets in pkts
// and a nil error.
func (b *BridgeEndpoint) WritePackets(pkts PacketBufferList) (int, tcpip.Error) {
	b.mu.RLock()
	defer b.mu.RUnlock()

	batch := pkts.AsSlice()
	for _, port := range b.ports {
		for _, src := range batch {
			// In order to properly loop back to the inbound side we must create a
			// fresh packet that only contains the underlying payload with no headers
			// or struct fields set.
			clone := NewPacketBuffer(PacketBufferOptions{
				Payload:            src.ToBuffer(),
				ReserveHeaderBytes: int(port.nic.MaxHeaderLength()),
			})
			clone.EgressRoute = src.EgressRoute
			clone.NetworkProtocolNumber = src.NetworkProtocolNumber
			port.nic.writePacket(clone)
			clone.DecRef()
		}
	}
	return len(batch), nil
}
// AddNIC adds the specified NIC to the bridge.
//
// A new port is attached to the NIC's link endpoint as its dispatcher and
// recorded under the NIC's ID. The bridge's maximum header length is raised
// if the NIC needs more header room. Always returns nil.
func (b *BridgeEndpoint) AddNIC(n *nic) tcpip.Error {
	b.mu.Lock()
	defer b.mu.Unlock()

	added := &bridgePort{nic: n, bridge: b}
	n.NetworkLinkEndpoint.Attach(added)
	b.ports[n.id] = added

	if current := b.maxHeaderLength.Load(); current < uint32(n.MaxHeaderLength()) {
		b.maxHeaderLength.Store(uint32(n.MaxHeaderLength()))
	}
	return nil
}
// DelNIC removes the specified NIC from the bridge.
//
// The NIC's port is dropped from the bridge and the NIC is re-attached to its
// own link endpoint as the dispatcher. Always returns nil.
func (b *BridgeEndpoint) DelNIC(nic *nic) tcpip.Error {
b.mu.Lock()
defer b.mu.Unlock()
delete(b.ports, nic.id)
nic.NetworkLinkEndpoint.Attach(nic)
return nil
}
// MTU implements stack.LinkEndpoint.MTU.
//
// It returns the configured MTU minus header.EthernetMinimumSize, or 0 if the
// configured MTU does not exceed that size.
func (b *BridgeEndpoint) MTU() uint32 {
	b.mu.RLock()
	defer b.mu.RUnlock()

	if b.mtu <= header.EthernetMinimumSize {
		return 0
	}
	return b.mtu - header.EthernetMinimumSize
}
// SetMTU implements stack.LinkEndpoint.SetMTU.
//
// mtu is stored as-is; MTU() subtracts the Ethernet header size when
// reporting it.
func (b *BridgeEndpoint) SetMTU(mtu uint32) {
	b.mu.Lock()
	b.mtu = mtu
	b.mu.Unlock()
}
// MaxHeaderLength implements stack.LinkEndpoint.
//
// The value is loaded atomically and truncated to uint16.
func (b *BridgeEndpoint) MaxHeaderLength() uint16 {
	hdrLen := b.maxHeaderLength.Load()
	return uint16(hdrLen)
}
// LinkAddress implements stack.LinkEndpoint.LinkAddress.
//
// It returns the bridge's current link address.
func (b *BridgeEndpoint) LinkAddress() tcpip.LinkAddress {
	// This is a pure read of b.addr, so a read lock suffices, as in MTU and
	// IsAttached; concurrent readers no longer serialize on the write lock.
	b.mu.RLock()
	defer b.mu.RUnlock()
	return b.addr
}
// SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress.
//
// addr replaces the bridge's link address while b.mu is held for writing.
func (b *BridgeEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
	b.mu.Lock()
	b.addr = addr
	b.mu.Unlock()
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
//
// The bridge always reports the same fixed set: CapabilityRXChecksumOffload,
// CapabilitySaveRestore and CapabilityResolutionRequired.
func (b *BridgeEndpoint) Capabilities() LinkEndpointCapabilities {
return CapabilityRXChecksumOffload | CapabilitySaveRestore | CapabilityResolutionRequired
}
// Attach implements stack.LinkEndpoint.Attach.
//
// Besides storing dispatcher, it clears the Primary field of every current
// port's NIC and replaces the port table with an empty one.
//
// NOTE(review): this resets the ports on every call, including a first attach
// after NICs were added; confirm that callers rely on this.
func (b *BridgeEndpoint) Attach(dispatcher NetworkDispatcher) {
b.mu.Lock()
defer b.mu.Unlock()
for _, p := range b.ports {
p.nic.Primary = nil
}
b.dispatcher = dispatcher
b.ports = make(map[tcpip.NICID]*bridgePort)
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
//
// The bridge counts as attached while a non-nil dispatcher is set.
func (b *BridgeEndpoint) IsAttached() bool {
	b.mu.RLock()
	attached := b.dispatcher != nil
	b.mu.RUnlock()
	return attached
}
// Wait implements stack.LinkEndpoint.Wait.
//
// It is a no-op and returns immediately.
func (b *BridgeEndpoint) Wait() {
}
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
//
// The bridge always reports header.ARPHardwareEther.
func (b *BridgeEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareEther
}
// AddHeader implements stack.LinkEndpoint.AddHeader.
//
// It is a no-op: the bridge itself adds no link header to pkt.
func (b *BridgeEndpoint) AddHeader(pkt *PacketBuffer) {
}
// ParseHeader implements stack.LinkEndpoint.ParseHeader.
//
// It always reports success without inspecting the packet.
func (b *BridgeEndpoint) ParseHeader(*PacketBuffer) bool {
return true
}
// Close implements stack.LinkEndpoint.Close.
//
// It is a no-op.
func (b *BridgeEndpoint) Close() {}
// SetOnCloseAction implements stack.LinkEndpoint.SetOnCloseAction.
//
// It is a no-op: the supplied function is discarded.
func (b *BridgeEndpoint) SetOnCloseAction(func()) {}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// bridgeRWMutex is sync.RWMutex with the correctness validator.
type bridgeRWMutex struct {
// mu is the underlying lock that the wrapper methods operate on.
mu sync.RWMutex
}
// bridgelockNames is a list of user-friendly lock names.
// Populated in init.
var bridgelockNames []string
// bridgelockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within bridgelockNames.
// Values are specified using the "consts" field of go_template_instance.
type bridgelockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
//
// The acquisition is registered with the lock validator (locking.AddGLock,
// sub-index -1) before the underlying mutex is locked.
// +checklocksignore
func (m *bridgeRWMutex) Lock() {
locking.AddGLock(bridgeprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
//
// i is passed to the lock validator as the sub-index for this acquisition.
// +checklocksignore
func (m *bridgeRWMutex) NestedLock(i bridgelockNameIndex) {
locking.AddGLock(bridgeprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
//
// The underlying mutex is released first, then the lock validator entry
// (sub-index -1) is removed.
// +checklocksignore
func (m *bridgeRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(bridgeprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
//
// i is the sub-index whose lock validator entry is removed after unlocking.
// +checklocksignore
func (m *bridgeRWMutex) NestedUnlock(i bridgelockNameIndex) {
m.mu.Unlock()
locking.DelGLock(bridgeprefixIndex, int(i))
}
// RLock locks m for reading.
//
// The acquisition is registered with the lock validator (sub-index -1) before
// the underlying mutex is read-locked.
// +checklocksignore
func (m *bridgeRWMutex) RLock() {
locking.AddGLock(bridgeprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
//
// The read lock is released first, then the lock validator entry (sub-index
// -1) is removed.
// +checklocksignore
func (m *bridgeRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(bridgeprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
//
// Pair with RUnlockBypass, which likewise skips the validator.
// +checklocksignore
func (m *bridgeRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
//
// The validator is not executed.
// +checklocksignore
func (m *bridgeRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks m for writing and locks it for reading.
//
// The validator is not touched by this call.
// +checklocksignore
func (m *bridgeRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
// bridgeprefixIndex is the validator mutex class for bridgeRWMutex. It is
// assigned in init.
var bridgeprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func bridgeinitLockNames() {}
// init populates the lock names and registers the mutex class with the lock
// validator.
func init() {
bridgeinitLockNames()
bridgeprefixIndex = locking.NewMutexClass(reflect.TypeOf(bridgeRWMutex{}), bridgelockNames)
}

View File

@@ -0,0 +1,98 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// bucketRWMutex is sync.RWMutex with the correctness validator.
type bucketRWMutex struct {
// mu is the underlying lock that the wrapper methods operate on.
mu sync.RWMutex
}
// bucketlockNames is a list of user-friendly lock names.
// Populated in init.
var bucketlockNames []string
// bucketlockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within bucketlockNames.
// Values are specified using the "consts" field of go_template_instance.
type bucketlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
const (
// bucketLockOthertuple is index 0, matching the "otherTuple" entry that
// bucketinitLockNames stores in bucketlockNames.
bucketLockOthertuple = bucketlockNameIndex(0)
)
const ()
// Lock locks m.
//
// The acquisition is registered with the lock validator (locking.AddGLock,
// sub-index -1) before the underlying mutex is locked.
// +checklocksignore
func (m *bucketRWMutex) Lock() {
locking.AddGLock(bucketprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
//
// i is passed to the lock validator as the sub-index for this acquisition.
// +checklocksignore
func (m *bucketRWMutex) NestedLock(i bucketlockNameIndex) {
locking.AddGLock(bucketprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
//
// The underlying mutex is released first, then the lock validator entry
// (sub-index -1) is removed.
// +checklocksignore
func (m *bucketRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(bucketprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
//
// i is the sub-index whose lock validator entry is removed after unlocking.
// +checklocksignore
func (m *bucketRWMutex) NestedUnlock(i bucketlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(bucketprefixIndex, int(i))
}
// RLock locks m for reading.
//
// The acquisition is registered with the lock validator (sub-index -1) before
// the underlying mutex is read-locked.
// +checklocksignore
func (m *bucketRWMutex) RLock() {
locking.AddGLock(bucketprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
//
// The read lock is released first, then the lock validator entry (sub-index
// -1) is removed.
// +checklocksignore
func (m *bucketRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(bucketprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
//
// Pair with RUnlockBypass, which likewise skips the validator.
// +checklocksignore
func (m *bucketRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
//
// The validator is not executed.
// +checklocksignore
func (m *bucketRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks m for writing and locks it for reading.
//
// The validator is not touched by this call.
// +checklocksignore
func (m *bucketRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
// bucketprefixIndex is the validator mutex class for bucketRWMutex. It is
// assigned in init.
var bucketprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func bucketinitLockNames() { bucketlockNames = []string{"otherTuple"} }
// init populates the lock names and registers the mutex class with the lock
// validator.
func init() {
bucketinitLockNames()
bucketprefixIndex = locking.NewMutexClass(reflect.TypeOf(bucketRWMutex{}), bucketlockNames)
}

View File

@@ -0,0 +1,64 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// cleanupEndpointsMutex is sync.Mutex with the correctness validator.
type cleanupEndpointsMutex struct {
// mu is the underlying lock that the wrapper methods operate on.
mu sync.Mutex
}
// cleanupEndpointsprefixIndex is the validator mutex class for
// cleanupEndpointsMutex. It is assigned in init.
var cleanupEndpointsprefixIndex *locking.MutexClass
// cleanupEndpointslockNames is a list of user-friendly lock names.
// Populated in init.
var cleanupEndpointslockNames []string
// cleanupEndpointslockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within cleanupEndpointslockNames.
// Values are specified using the "consts" field of go_template_instance.
type cleanupEndpointslockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
//
// The acquisition is registered with the lock validator (locking.AddGLock,
// sub-index -1) before the underlying mutex is locked.
// +checklocksignore
func (m *cleanupEndpointsMutex) Lock() {
locking.AddGLock(cleanupEndpointsprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
//
// i is passed to the lock validator as the sub-index for this acquisition.
// +checklocksignore
func (m *cleanupEndpointsMutex) NestedLock(i cleanupEndpointslockNameIndex) {
locking.AddGLock(cleanupEndpointsprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
//
// The lock validator entry (sub-index -1) is removed before the underlying
// mutex is released.
// +checklocksignore
func (m *cleanupEndpointsMutex) Unlock() {
locking.DelGLock(cleanupEndpointsprefixIndex, -1)
m.mu.Unlock()
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
//
// The lock validator entry for sub-index i is removed before the underlying
// mutex is released.
// +checklocksignore
func (m *cleanupEndpointsMutex) NestedUnlock(i cleanupEndpointslockNameIndex) {
locking.DelGLock(cleanupEndpointsprefixIndex, int(i))
m.mu.Unlock()
}
// DO NOT REMOVE: The following function is automatically replaced.
func cleanupEndpointsinitLockNames() {}
// init populates the lock names and registers the mutex class with the lock
// validator.
func init() {
cleanupEndpointsinitLockNames()
cleanupEndpointsprefixIndex = locking.NewMutexClass(reflect.TypeOf(cleanupEndpointsMutex{}), cleanupEndpointslockNames)
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// connRWMutex is sync.RWMutex with the correctness validator.
type connRWMutex struct {
// mu is the underlying lock that the wrapper methods operate on.
mu sync.RWMutex
}
// connlockNames is a list of user-friendly lock names.
// Populated in init.
var connlockNames []string
// connlockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within connlockNames.
// Values are specified using the "consts" field of go_template_instance.
type connlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
//
// The acquisition is registered with the lock validator (locking.AddGLock,
// sub-index -1) before the underlying mutex is locked.
// +checklocksignore
func (m *connRWMutex) Lock() {
locking.AddGLock(connprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
//
// i is passed to the lock validator as the sub-index for this acquisition.
// +checklocksignore
func (m *connRWMutex) NestedLock(i connlockNameIndex) {
locking.AddGLock(connprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
//
// The underlying mutex is released first, then the lock validator entry
// (sub-index -1) is removed.
// +checklocksignore
func (m *connRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(connprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
//
// i is the sub-index whose lock validator entry is removed after unlocking.
// +checklocksignore
func (m *connRWMutex) NestedUnlock(i connlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(connprefixIndex, int(i))
}
// RLock locks m for reading.
//
// The acquisition is registered with the lock validator (sub-index -1) before
// the underlying mutex is read-locked.
// +checklocksignore
func (m *connRWMutex) RLock() {
locking.AddGLock(connprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
//
// The read lock is released first, then the lock validator entry (sub-index
// -1) is removed.
// +checklocksignore
func (m *connRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(connprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *connRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *connRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *connRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var connprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func conninitLockNames() {}
func init() {
conninitLockNames()
connprefixIndex = locking.NewMutexClass(reflect.TypeOf(connRWMutex{}), connlockNames)
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// connTrackRWMutex is sync.RWMutex with the correctness validator.
//
// Except for the *Bypass methods and DowngradeLock, every method reports the
// operation to the locking package (locking.AddGLock or locking.DelGLock) in
// addition to operating on the wrapped mutex.
type connTrackRWMutex struct {
mu sync.RWMutex
}
// connTracklockNames is a list of user-friendly lock names.
// Populated in init.
var connTracklockNames []string
// connTracklockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within connTracklockNames.
// Values are specified using the "consts" field of go_template_instance.
type connTracklockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m for writing. The acquisition is reported to the validator
// (with index -1) before the underlying mutex is locked.
// +checklocksignore
func (m *connTrackRWMutex) Lock() {
locking.AddGLock(connTrackprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// i is passed to locking.AddGLock in place of the -1 used by Lock.
// +checklocksignore
func (m *connTrackRWMutex) NestedLock(i connTracklockNameIndex) {
locking.AddGLock(connTrackprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m. The underlying mutex is released first and the validator
// entry (index -1) is removed afterwards.
// +checklocksignore
func (m *connTrackRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(connTrackprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// i is passed to locking.DelGLock in place of the -1 used by Unlock.
// +checklocksignore
func (m *connTrackRWMutex) NestedUnlock(i connTracklockNameIndex) {
m.mu.Unlock()
locking.DelGLock(connTrackprefixIndex, int(i))
}
// RLock locks m for reading. It is reported to the validator with the same
// class and index (-1) as Lock.
// +checklocksignore
func (m *connTrackRWMutex) RLock() {
locking.AddGLock(connTrackprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *connTrackRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(connTrackprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *connTrackRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call. Like RLockBypass, it does
// not execute the validator.
// +checklocksignore
func (m *connTrackRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks m for writing and locks it for reading.
// It only delegates to the wrapped mutex; no validator call is made.
// +checklocksignore
func (m *connTrackRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
// connTrackprefixIndex is the locking.MutexClass created for connTrackRWMutex
// in init. It is the class argument of every AddGLock/DelGLock call above.
var connTrackprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func connTrackinitLockNames() {}
// init runs connTrackinitLockNames (a no-op as written here) and then creates
// the mutex class for connTrackRWMutex from its reflected type and
// connTracklockNames.
func init() {
connTrackinitLockNames()
connTrackprefixIndex = locking.NewMutexClass(reflect.TypeOf(connTrackRWMutex{}), connTracklockNames)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// endpointsByNICRWMutex is sync.RWMutex with the correctness validator.
//
// Except for the *Bypass methods and DowngradeLock, every method reports the
// operation to the locking package (locking.AddGLock or locking.DelGLock) in
// addition to operating on the wrapped mutex.
type endpointsByNICRWMutex struct {
mu sync.RWMutex
}
// endpointsByNIClockNames is a list of user-friendly lock names.
// Populated in init.
var endpointsByNIClockNames []string
// endpointsByNIClockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within endpointsByNIClockNames.
// Values are specified using the "consts" field of go_template_instance.
type endpointsByNIClockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m for writing. The acquisition is reported to the validator
// (with index -1) before the underlying mutex is locked.
// +checklocksignore
func (m *endpointsByNICRWMutex) Lock() {
locking.AddGLock(endpointsByNICprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// i is passed to locking.AddGLock in place of the -1 used by Lock.
// +checklocksignore
func (m *endpointsByNICRWMutex) NestedLock(i endpointsByNIClockNameIndex) {
locking.AddGLock(endpointsByNICprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m. The underlying mutex is released first and the validator
// entry (index -1) is removed afterwards.
// +checklocksignore
func (m *endpointsByNICRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(endpointsByNICprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// i is passed to locking.DelGLock in place of the -1 used by Unlock.
// +checklocksignore
func (m *endpointsByNICRWMutex) NestedUnlock(i endpointsByNIClockNameIndex) {
m.mu.Unlock()
locking.DelGLock(endpointsByNICprefixIndex, int(i))
}
// RLock locks m for reading. It is reported to the validator with the same
// class and index (-1) as Lock.
// +checklocksignore
func (m *endpointsByNICRWMutex) RLock() {
locking.AddGLock(endpointsByNICprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *endpointsByNICRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(endpointsByNICprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *endpointsByNICRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call. Like RLockBypass, it does
// not execute the validator.
// +checklocksignore
func (m *endpointsByNICRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks m for writing and locks it for reading.
// It only delegates to the wrapped mutex; no validator call is made.
// +checklocksignore
func (m *endpointsByNICRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
// endpointsByNICprefixIndex is the locking.MutexClass created for
// endpointsByNICRWMutex in init. It is the class argument of every
// AddGLock/DelGLock call above.
var endpointsByNICprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func endpointsByNICinitLockNames() {}
// init runs endpointsByNICinitLockNames (a no-op as written here) and then
// creates the mutex class for endpointsByNICRWMutex from its reflected type
// and endpointsByNIClockNames.
func init() {
endpointsByNICinitLockNames()
endpointsByNICprefixIndex = locking.NewMutexClass(reflect.TypeOf(endpointsByNICRWMutex{}), endpointsByNIClockNames)
}

View File

@@ -0,0 +1,599 @@
// Copyright 2022 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package gro implements generic receive offload.
package gro
import (
"bytes"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// TODO(b/256037250): Enable by default.
// TODO(b/256037250): We parse headers here. We should save those headers in
// PacketBuffers so they don't have to be re-parsed later.
// TODO(b/256037250): I still see the occasional SACK block in the zero-loss
// benchmark, which should not happen.
// TODO(b/256037250): Some dispatchers, e.g. XDP and RecvMmsg, can receive
// multiple packets at a time. Even if the GRO interval is 0, there is an
// opportunity for coalescing.
// TODO(b/256037250): We're doing some header parsing here, which presents the
// opportunity to skip it later.
// TODO(b/256037250): Can we pass a packet list up the stack too?
// groNBuckets is the number of GRO buckets. A packet's bucket is chosen by
// masking its hash with groNBucketsMask, so every bucket is only reachable
// while this stays a power of two.
const groNBuckets = 8

// groNBucketsMask reduces a hash value to a valid bucket index.
const groNBucketsMask = groNBuckets - 1

// groBucketSize is the size of each GRO bucket.
const groBucketSize = 8

// groMaxPacketSize is the maximum size of a GRO'd packet.
const groMaxPacketSize = 64 << 10 // 64 KiB (65536 bytes).
// A groBucket holds packets that are undergoing GRO.
type groBucket struct {
// count is the number of packets in the bucket.
count int
// packets is the linked list of packets. insert pushes to the back, so the
// front is the oldest packet.
packets groPacketList
// packetsPrealloc and allocIdxs are used to preallocate and reuse
// groPacket structs and avoid allocation.
//
// allocIdxs[count:] holds the packetsPrealloc indices that are currently
// free: insert takes allocIdxs[count] before incrementing count, and the
// remove methods store the freed index at allocIdxs[count] after
// decrementing it.
packetsPrealloc [groBucketSize]groPacket
allocIdxs [groBucketSize]int
}
// full reports whether every preallocated packet slot of the bucket is in use.
func (gb *groBucket) full() bool {
	free := groBucketSize - gb.count
	return free == 0
}
// insert stores pkt, together with its parsed headers, in a free preallocated
// slot and appends that slot to the back of the bucket's packet list. The
// bucket must not be full: the slot is looked up at allocIdxs[count].
func (gb *groBucket) insert(pkt *stack.PacketBuffer, ipHdr []byte, tcpHdr header.TCP) {
	slot := &gb.packetsPrealloc[gb.allocIdxs[gb.count]]

	// Rebuild the slot from scratch, carrying over only its immutable idx.
	fresh := groPacket{idx: slot.idx}
	fresh.pkt = pkt
	fresh.ipHdr = ipHdr
	fresh.tcpHdr = tcpHdr
	fresh.initialLength = pkt.Data().Size() // pkt.Data() contains network header.
	*slot = fresh

	gb.packets.PushBack(slot)
	gb.count++
}
// removeOldest detaches the packet at the front of gb (the oldest one), frees
// its slot and returns the PacketBuffer it was holding. gb must not be empty.
func (gb *groBucket) removeOldest() *stack.PacketBuffer {
	oldest := gb.packets.Front()
	// Grab the buffer first: removeOne resets the groPacket.
	buf := oldest.pkt
	gb.removeOne(oldest)
	return buf
}
// removeOne unlinks pkt from gb, returns its slot to the free-index stack and
// resets pkt to its zero value (idx excepted).
func (gb *groBucket) removeOne(pkt *groPacket) {
	freed := pkt.idx
	gb.packets.Remove(pkt)
	pkt.reset()

	// Push the freed slot index onto the free stack that starts at count.
	gb.count--
	gb.allocIdxs[gb.count] = freed
}
// findGROPacket4 searches the bucket for an in-progress packet belonging to
// the same IPv4/TCP flow as the given headers. It returns that packet (nil if
// there is none) and whether it has to be flushed instead of merged into
// because of differences between the two sets of headers.
func (gb *groBucket) findGROPacket4(pkt *stack.PacketBuffer, ipHdr header.IPv4, tcpHdr header.TCP) (*groPacket, bool) {
	for cand := gb.packets.Front(); cand != nil; cand = cand.Next() {
		candIP := header.IPv4(cand.ipHdr)

		// A flow is identified by its addresses and ports.
		sameFlow := ipHdr.SourceAddress() == candIP.SourceAddress() &&
			ipHdr.DestinationAddress() == candIP.DestinationAddress() &&
			tcpHdr.SourcePort() == cand.tcpHdr.SourcePort() &&
			tcpHdr.DestinationPort() == cand.tcpHdr.DestinationPort()
		if !sameFlow {
			continue
		}

		// Same flow: from here on we always return cand and only decide
		// whether it must be flushed.
		tos, _ := ipHdr.TOS()
		candTOS, _ := candIP.TOS()
		switch {
		case ipHdr.TTL() != candIP.TTL() || tos != candTOS:
			// IP-level mismatch.
			return cand, true
		case shouldFlushTCP(cand, tcpHdr):
			// TCP-level mismatch.
			return cand, true
		}

		// There's an upper limit on coalesced packet size.
		merged := pkt.Data().Size() - header.IPv4MinimumSize - int(tcpHdr.DataOffset()) + cand.pkt.Data().Size()
		return cand, merged >= groMaxPacketSize
	}
	return nil, false
}
// findGROPacket6 searches the bucket for an in-progress packet belonging to
// the same IPv6/TCP flow as the given headers. It returns that packet (nil if
// there is none) and whether it has to be flushed instead of merged into
// because of differences between the two sets of headers.
func (gb *groBucket) findGROPacket6(pkt *stack.PacketBuffer, ipHdr header.IPv6, tcpHdr header.TCP) (*groPacket, bool) {
	for cand := gb.packets.Front(); cand != nil; cand = cand.Next() {
		candIP := header.IPv6(cand.ipHdr)
		if ipHdr.SourceAddress() != candIP.SourceAddress() || ipHdr.DestinationAddress() != candIP.DestinationAddress() {
			continue
		}

		// The headers must be the same except for:
		// - Traffic class, a difference of which causes a flush.
		// - Hop limit, a difference of which causes a flush.
		// - Length, which is checked later.
		// - Version, which is checked by an earlier call to IsValid().
		trafficClass, flowLabel := ipHdr.TOS()
		candTrafficClass, candFlowLabel := candIP.TOS()

		// Unlike IPv4, IPv6 packets with extension headers can be
		// coalesced, provided the extension bytes are identical.
		sameFlow := flowLabel == candFlowLabel &&
			ipHdr.NextHeader() == candIP.NextHeader() &&
			bytes.Equal(ipHdr[header.IPv6MinimumSize:], candIP[header.IPv6MinimumSize:]) &&
			tcpHdr.SourcePort() == cand.tcpHdr.SourcePort() &&
			tcpHdr.DestinationPort() == cand.tcpHdr.DestinationPort()
		if !sameFlow {
			continue
		}

		// Same flow: from here on we always return cand and only decide
		// whether it must be flushed.
		if shouldFlushTCP(cand, tcpHdr) {
			return cand, true
		}
		if trafficClass != candTrafficClass || ipHdr.HopLimit() != candIP.HopLimit() {
			return cand, true
		}

		// This limit is artificial for IPv6 -- we could allow even
		// larger packets via jumbograms.
		merged := pkt.Data().Size() - len(ipHdr) - int(tcpHdr.DataOffset()) + cand.pkt.Data().Size()
		return cand, merged >= groMaxPacketSize
	}
	return nil, false
}
// found decides what to do with the incoming pkt given the result of a
// findGROPacket4/findGROPacket6 lookup in this bucket.
//
// groPkt is the matching in-progress packet (nil if there is none) and
// flushGROPkt says that groPkt must be delivered rather than merged into.
// ipHdr and tcpHdr are pkt's parsed headers. updateIPHdr is called after a
// merge with groPkt's IP header and the number of TCP payload bytes that were
// appended.
//
// Depending on the headers, pkt is merged into groPkt, delivered immediately
// through gd, or inserted into the bucket as a new in-progress packet (after
// taking a reference on it).
func (gb *groBucket) found(gd *GRO, groPkt *groPacket, flushGROPkt bool, pkt *stack.PacketBuffer, ipHdr []byte, tcpHdr header.TCP, updateIPHdr func([]byte, int)) {
// Flush groPkt or merge the packets.
// Sizes and flags are captured up front because the merge path below trims
// pkt's headers.
pktSize := pkt.Data().Size()
flags := tcpHdr.Flags()
dataOff := tcpHdr.DataOffset()
tcpPayloadSize := pkt.Data().Size() - len(ipHdr) - int(dataOff)
if flushGROPkt {
// Flush the existing GRO packet.
// Note: this pkt shadows the incoming pkt for the rest of the branch.
pkt := groPkt.pkt
gb.removeOne(groPkt)
gd.handlePacket(pkt)
pkt.DecRef()
// From here on the incoming packet is treated as having no match.
groPkt = nil
} else if groPkt != nil {
// Merge pkt in to GRO packet.
pkt.Data().TrimFront(len(ipHdr) + int(dataOff))
groPkt.pkt.Data().Merge(pkt.Data())
// Update the IP total length.
updateIPHdr(groPkt.ipHdr, tcpPayloadSize)
// Add flags from the packet to the GRO packet.
groPkt.tcpHdr.SetFlags(uint8(groPkt.tcpHdr.Flags() | (flags & (header.TCPFlagFin | header.TCPFlagPsh))))
pkt = nil
}
// Flush if the packet isn't the same size as the previous packets or
// if certain flags are set. The reason for checking size equality is:
// - If the packet is smaller than the others, this is likely the end
// of some message. Peers will send MSS-sized packets until they have
// insufficient data to do so.
// - If the packet is larger than the others, this packet is either
// malformed, a local GSO packet, or has already been handled by host
// GRO.
flush := header.TCPFlags(flags)&(header.TCPFlagUrg|header.TCPFlagPsh|header.TCPFlagRst|header.TCPFlagSyn|header.TCPFlagFin) != 0
flush = flush || tcpPayloadSize == 0
if groPkt != nil {
flush = flush || pktSize != groPkt.initialLength
}
// At this point groPkt is non-nil only if a merge took place above.
switch {
case flush && groPkt != nil:
// A merge occurred and we need to flush groPkt.
pkt := groPkt.pkt
gb.removeOne(groPkt)
gd.handlePacket(pkt)
pkt.DecRef()
case flush && groPkt == nil:
// No merge occurred and the incoming packet needs to be flushed.
gd.handlePacket(pkt)
case !flush && groPkt == nil:
// New flow and we don't need to flush. Insert pkt into GRO.
if gb.full() {
// Head is always the oldest packet
toFlush := gb.removeOldest()
gb.insert(pkt.IncRef(), ipHdr, tcpHdr)
gd.handlePacket(toFlush)
toFlush.DecRef()
} else {
gb.insert(pkt.IncRef(), ipHdr, tcpHdr)
}
default:
// A merge occurred and we don't need to flush anything.
}
}
// A groPacket is a packet undergoing GRO. It may be several packets coalesced
// together.
type groPacket struct {
	// groPacketEntry is an intrusive list.
	groPacketEntry

	// pkt is the coalesced packet.
	pkt *stack.PacketBuffer

	// ipHdr is the IP (v4 or v6) header for the coalesced packet.
	ipHdr []byte

	// tcpHdr is the TCP header for the coalesced packet.
	tcpHdr header.TCP

	// initialLength is the length of the first packet in the flow. It is
	// used as a best-effort guess at MSS: senders will send MSS-sized
	// packets until they run out of data, so we coalesce as long as
	// packets are the same size.
	initialLength int

	// idx is the groPacket's index in its bucket packetsPrealloc. It is
	// immutable.
	idx int
}

// reset clears every mutable field of the groPacket, list links included,
// leaving only idx intact.
func (pk *groPacket) reset() {
	keep := pk.idx
	*pk = groPacket{}
	pk.idx = keep
}

// payloadSize is the payload size of the coalesced packet, which does not
// include the network or transport headers.
func (pk *groPacket) payloadSize() int {
	size := pk.pkt.Data().Size()
	size -= len(pk.ipHdr)
	size -= int(pk.tcpHdr.DataOffset())
	return size
}
// GRO coalesces incoming packets to increase throughput.
type GRO struct {
// enabled is set by Init. When false, Enqueue hands every packet straight
// to Dispatcher.
enabled bool
// buckets holds the packets that are currently undergoing GRO.
buckets [groNBuckets]groBucket
// Dispatcher receives every packet that leaves GRO. Init does not set it.
Dispatcher stack.NetworkDispatcher
}
// Init records whether GRO is enabled and prepares every bucket so that each
// preallocated groPacket knows its own slot index and all slots are listed as
// free.
func (gd *GRO) Init(enabled bool) {
	gd.enabled = enabled
	for b := range gd.buckets {
		for slot := 0; slot < groBucketSize; slot++ {
			gd.buckets[b].allocIdxs[slot] = slot
			gd.buckets[b].packetsPrealloc[slot].idx = slot
		}
	}
}
// Enqueue hands pkt to GRO. IPv4 and IPv6 packets go through the GRO dispatch
// path when GRO is enabled; everything else is delivered directly. This does
// not flush packets; Flush() must be called explicitly for that.
//
// pkt.NetworkProtocolNumber and pkt.RXChecksumValidated must be set.
func (gd *GRO) Enqueue(pkt *stack.PacketBuffer) {
	if gd.enabled {
		switch pkt.NetworkProtocolNumber {
		case header.IPv4ProtocolNumber:
			gd.dispatch4(pkt)
			return
		case header.IPv6ProtocolNumber:
			gd.dispatch6(pkt)
			return
		}
	}
	// GRO is disabled or the protocol is not one we coalesce.
	gd.handlePacket(pkt)
}
// dispatch4 handles an IPv4 packet. Packets that are not coalesced -- too
// short to hold the headers, fragments, packets with IP options, non-TCP
// packets, packets with a malformed TCP data offset, or packets that fail
// checksum validation -- are delivered immediately via handlePacket.
// Everything else is hashed to a bucket and merged, flushed or queued by
// groBucket.found.
func (gd *GRO) dispatch4(pkt *stack.PacketBuffer) {
// Immediately get the IPv4 and TCP headers. We need a way to hash the
// packet into its bucket, which requires addresses and ports. Linux
// simply gets a hash passed by hardware, but we're not so lucky.
// We only GRO TCP packets. The check for the transport protocol number
// is done below so that we can PullUp both the IP and TCP headers
// together.
hdrBytes, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize)
if !ok {
gd.handlePacket(pkt)
return
}
ipHdr := header.IPv4(hdrBytes)
// We don't handle fragments. That should be the vast majority of
// traffic, and simplifies handling.
if ipHdr.FragmentOffset() != 0 || ipHdr.Flags()&header.IPv4FlagMoreFragments != 0 {
gd.handlePacket(pkt)
return
}
// We only handle TCP packets without IP options.
if ipHdr.HeaderLength() != header.IPv4MinimumSize || tcpip.TransportProtocolNumber(ipHdr.Protocol()) != header.TCPProtocolNumber {
gd.handlePacket(pkt)
return
}
tcpHdr := header.TCP(hdrBytes[header.IPv4MinimumSize:])
ipHdr = ipHdr[:header.IPv4MinimumSize]
dataOff := tcpHdr.DataOffset()
if dataOff < header.TCPMinimumSize {
// Malformed packet: will be handled further up the stack.
gd.handlePacket(pkt)
return
}
// Pull up again now that the full TCP header length (including options)
// is known.
hdrBytes, ok = pkt.Data().PullUp(header.IPv4MinimumSize + int(dataOff))
if !ok {
// Malformed packet: will be handled further up the stack.
gd.handlePacket(pkt)
return
}
tcpHdr = header.TCP(hdrBytes[header.IPv4MinimumSize:])
// If either checksum is bad, flush the packet. Since we don't know
// what bits were flipped, we can't identify this packet with a flow.
if !pkt.RXChecksumValidated {
if !ipHdr.IsValid(pkt.Data().Size()) || !ipHdr.IsChecksumValid() {
gd.handlePacket(pkt)
return
}
payloadChecksum := pkt.Data().ChecksumAtOffset(header.IPv4MinimumSize + int(dataOff))
tcpPayloadSize := pkt.Data().Size() - header.IPv4MinimumSize - int(dataOff)
if !tcpHdr.IsChecksumValid(ipHdr.SourceAddress(), ipHdr.DestinationAddress(), payloadChecksum, uint16(tcpPayloadSize)) {
gd.handlePacket(pkt)
return
}
// We've validated the checksum, no reason for others to do it
// again.
pkt.RXChecksumValidated = true
}
// Now we can get the bucket for the packet.
bucket := &gd.buckets[gd.bucketForPacket4(ipHdr, tcpHdr)&groNBucketsMask]
groPkt, flushGROPkt := bucket.findGROPacket4(pkt, ipHdr, tcpHdr)
bucket.found(gd, groPkt, flushGROPkt, pkt, ipHdr, tcpHdr, updateIPv4Hdr)
}
// dispatch6 handles an IPv6 packet. It walks the extension headers to find
// where the transport header starts, and delivers the packet immediately via
// handlePacket if it is too short, the extension headers fail to parse, the
// transport protocol is not TCP, the TCP data offset is malformed, or
// validation of an unvalidated packet fails. Everything else is hashed to a
// bucket and merged, flushed or queued by groBucket.found.
func (gd *GRO) dispatch6(pkt *stack.PacketBuffer) {
// Immediately get the IPv6 and TCP headers. We need a way to hash the
// packet into its bucket, which requires addresses and ports. Linux
// simply gets a hash passed by hardware, but we're not so lucky.
hdrBytes, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
if !ok {
gd.handlePacket(pkt)
return
}
ipHdr := header.IPv6(hdrBytes)
// Getting the IP header (+ extension headers) size is a bit of a pain
// on IPv6.
transProto := tcpip.TransportProtocolNumber(ipHdr.NextHeader())
buf := pkt.Data().ToBuffer()
buf.TrimFront(header.IPv6MinimumSize)
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(transProto), buf)
// ipHdrSize stays at the fixed header size unless the loop below stops at
// a header it does not skip over.
ipHdrSize := int(header.IPv6MinimumSize)
for {
// Record the identifier of the header about to be returned by Next.
transProto = tcpip.TransportProtocolNumber(it.NextHeaderIdentifier())
extHdr, done, err := it.Next()
if err != nil {
gd.handlePacket(pkt)
return
}
if done {
break
}
switch extHdr.(type) {
// We can GRO these, so just skip over them.
case header.IPv6HopByHopOptionsExtHdr:
case header.IPv6RoutingExtHdr:
case header.IPv6DestinationOptionsExtHdr:
default:
// This is either a TCP header or something we can't handle.
ipHdrSize = int(it.HeaderOffset())
done = true
}
// Release every header returned by Next, including the one we stop at.
extHdr.Release()
if done {
break
}
}
hdrBytes, ok = pkt.Data().PullUp(ipHdrSize + header.TCPMinimumSize)
if !ok {
gd.handlePacket(pkt)
return
}
// From here on ipHdr covers the fixed header plus the skipped extension
// headers.
ipHdr = header.IPv6(hdrBytes[:ipHdrSize])
// We only handle TCP packets.
if transProto != header.TCPProtocolNumber {
gd.handlePacket(pkt)
return
}
tcpHdr := header.TCP(hdrBytes[ipHdrSize:])
dataOff := tcpHdr.DataOffset()
if dataOff < header.TCPMinimumSize {
// Malformed packet: will be handled further up the stack.
gd.handlePacket(pkt)
return
}
// Pull up again now that the full TCP header length (including options)
// is known.
hdrBytes, ok = pkt.Data().PullUp(ipHdrSize + int(dataOff))
if !ok {
// Malformed packet: will be handled further up the stack.
gd.handlePacket(pkt)
return
}
tcpHdr = header.TCP(hdrBytes[ipHdrSize:])
// If either checksum is bad, flush the packet. Since we don't know
// what bits were flipped, we can't identify this packet with a flow.
if !pkt.RXChecksumValidated {
if !ipHdr.IsValid(pkt.Data().Size()) {
gd.handlePacket(pkt)
return
}
payloadChecksum := pkt.Data().ChecksumAtOffset(ipHdrSize + int(dataOff))
tcpPayloadSize := pkt.Data().Size() - ipHdrSize - int(dataOff)
if !tcpHdr.IsChecksumValid(ipHdr.SourceAddress(), ipHdr.DestinationAddress(), payloadChecksum, uint16(tcpPayloadSize)) {
gd.handlePacket(pkt)
return
}
// We've validated the checksum, no reason for others to do it
// again.
pkt.RXChecksumValidated = true
}
// Now we can get the bucket for the packet.
bucket := &gd.buckets[gd.bucketForPacket6(ipHdr, tcpHdr)&groNBucketsMask]
groPkt, flushGROPkt := bucket.findGROPacket6(pkt, ipHdr, tcpHdr)
bucket.found(gd, groPkt, flushGROPkt, pkt, ipHdr, tcpHdr, updateIPv6Hdr)
}
// bucketForPacket4 returns an unmasked hash for an IPv4/TCP flow: the sum of
// the TCP ports and of every byte of the source and destination addresses.
// Callers reduce it to a bucket index themselves.
func (gd *GRO) bucketForPacket4(ipHdr header.IPv4, tcpHdr header.TCP) int {
	// TODO(b/256037250): Use jenkins or checksum. Write a test to print
	// distribution.
	hash := int(tcpHdr.SourcePort()) + int(tcpHdr.DestinationPort())

	src := ipHdr.SourceAddress()
	for _, b := range src.AsSlice() {
		hash += int(b)
	}

	dst := ipHdr.DestinationAddress()
	for _, b := range dst.AsSlice() {
		hash += int(b)
	}
	return hash
}
// bucketForPacket6 returns an unmasked hash for an IPv6/TCP flow: the sum of
// the TCP ports and of every byte of the source and destination addresses.
// Callers reduce it to a bucket index themselves.
func (gd *GRO) bucketForPacket6(ipHdr header.IPv6, tcpHdr header.TCP) int {
	// TODO(b/256037250): Use jenkins or checksum. Write a test to print
	// distribution.
	hash := int(tcpHdr.SourcePort()) + int(tcpHdr.DestinationPort())

	src := ipHdr.SourceAddress()
	for _, b := range src.AsSlice() {
		hash += int(b)
	}

	dst := ipHdr.DestinationAddress()
	for _, b := range dst.AsSlice() {
		hash += int(b)
	}
	return hash
}
// Flush sends all packets up the stack, emptying every bucket.
//
// Packets are delivered bucket by bucket, oldest first within a bucket, and
// the reference GRO held on each one is dropped after delivery.
func (gd *GRO) Flush() {
	for i := range gd.buckets {
		bucket := &gd.buckets[i]
		// Drain from the front rather than walking the list with Next():
		// removing a groPacket clears its links, so a Next() call made after
		// the removal returns nil and would end the walk after one packet.
		for !bucket.packets.Empty() {
			pkt := bucket.removeOldest()
			gd.handlePacket(pkt)
			pkt.DecRef()
		}
	}
}
// handlePacket delivers pkt to gd.Dispatcher under the packet's own network
// protocol number.
func (gd *GRO) handlePacket(pkt *stack.PacketBuffer) {
	proto := pkt.NetworkProtocolNumber
	gd.Dispatcher.DeliverNetworkPacket(proto, pkt)
}
// String implements fmt.Stringer. It prints one line per bucket with the
// bucket's packet count followed by the data size of each queued packet.
func (gd *GRO) String() string {
	var out bytes.Buffer
	out.WriteString("GRO state: \n")
	for i := range gd.buckets {
		b := &gd.buckets[i]
		fmt.Fprintf(&out, "bucket %d: %d packets: ", i, b.count)
		for p := b.packets.Front(); p != nil; p = p.Next() {
			fmt.Fprintf(&out, "%d, ", p.pkt.Data().Size())
		}
		out.WriteString("\n")
	}
	return out.String()
}
// shouldFlushTCP reports whether the incoming TCP header tcpHdr is
// incompatible with the held groPkt, in which case groPkt should be flushed
// rather than merged into.
func shouldFlushTCP(groPkt *groPacket, tcpHdr header.TCP) bool {
	incoming := tcpHdr.Flags()
	held := groPkt.tcpHdr.Flags()
	dataOff := tcpHdr.DataOffset()

	switch {
	case incoming&header.TCPFlagCwr != 0:
		// Congestion control is occurring.
		return true
	case (incoming^held)&^(header.TCPFlagCwr|header.TCPFlagFin|header.TCPFlagPsh) != 0:
		// The flags differ in something other than CWR, FIN and PSH.
		return true
	case tcpHdr.AckNumber() != groPkt.tcpHdr.AckNumber():
		// The ACKs don't match.
		return true
	case dataOff != groPkt.tcpHdr.DataOffset():
		// The TCP headers have different lengths.
		return true
	case groPkt.tcpHdr.SequenceNumber()+uint32(groPkt.payloadSize()) != tcpHdr.SequenceNumber():
		// The incoming packet doesn't start at the expected sequence number.
		return true
	}

	// The options, including timestamps, must be identical.
	return !bytes.Equal(tcpHdr[header.TCPMinimumSize:], groPkt.tcpHdr[header.TCPMinimumSize:])
}
// updateIPv4Hdr grows the total-length field of the IPv4 header stored in
// ipHdrBytes by newBytes. The addition is performed in uint16.
func updateIPv4Hdr(ipHdrBytes []byte, newBytes int) {
	hdr := header.IPv4(ipHdrBytes)
	grown := hdr.TotalLength() + uint16(newBytes)
	hdr.SetTotalLength(grown)
}
// updateIPv6Hdr grows the payload-length field of the IPv6 header stored in
// ipHdrBytes by newBytes. The addition is performed in uint16.
func updateIPv6Hdr(ipHdrBytes []byte, newBytes int) {
	hdr := header.IPv6(ipHdrBytes)
	grown := hdr.PayloadLength() + uint16(newBytes)
	hdr.SetPayloadLength(grown)
}

View File

@@ -0,0 +1,239 @@
package gro
// groPacketElementMapper provides an identity mapping by default.
//
// It can be replaced by a struct that maps elements to linker objects when
// the two differ. No such mapper is typically needed if Linker is left as is,
// Element is left as is, or Linker and Element are the same type.
type groPacketElementMapper struct{}

// linkerFor maps an Element to a Linker. Here the element is its own linker,
// so the argument is returned unchanged.
//
// This default implementation should be inlined.
//
//go:nosplit
func (groPacketElementMapper) linkerFor(elem *groPacket) *groPacket {
	return elem
}
// groPacketList is an intrusive doubly linked list of groPacket elements.
// Elements can be added or removed in O(1) time without any additional memory
// allocation, because the links live inside the elements themselves.
//
// The zero value is an empty list ready to use.
//
// To iterate over a list (where l is a groPacketList):
//
//	for e := l.Front(); e != nil; e = e.Next() {
//		// do something with e.
//	}
//
// +stateify savable
type groPacketList struct {
	head *groPacket
	tail *groPacket
}

// Reset puts list l back into the empty state. The elements themselves are
// not touched.
func (l *groPacketList) Reset() {
	l.head, l.tail = nil, nil
}

// Empty returns true iff the list is empty.
//
//go:nosplit
func (l *groPacketList) Empty() bool {
	return l.head == nil
}

// Front returns the first element of list l or nil.
//
//go:nosplit
func (l *groPacketList) Front() *groPacket {
	return l.head
}

// Back returns the last element of list l or nil.
//
//go:nosplit
func (l *groPacketList) Back() *groPacket {
	return l.tail
}

// Len returns the number of elements in the list.
//
// NOTE: This is an O(n) operation.
//
//go:nosplit
func (l *groPacketList) Len() (count int) {
	e := l.Front()
	for e != nil {
		count++
		e = e.Next()
	}
	return count
}

// PushFront inserts the element e at the front of list l.
//
//go:nosplit
func (l *groPacketList) PushFront(e *groPacket) {
	e.SetNext(l.head)
	e.SetPrev(nil)
	if l.head == nil {
		// The list was empty, so e is also the last element.
		l.tail = e
	} else {
		l.head.SetPrev(e)
	}
	l.head = e
}

// PushFrontList inserts list m at the start of list l, emptying m.
//
//go:nosplit
func (l *groPacketList) PushFrontList(m *groPacketList) {
	switch {
	case l.head == nil:
		// l is empty: simply adopt m's elements.
		l.head = m.head
		l.tail = m.tail
	case m.head != nil:
		// Splice m's tail onto l's head.
		l.head.SetPrev(m.tail)
		m.tail.SetNext(l.head)
		l.head = m.head
	}
	m.head = nil
	m.tail = nil
}

// PushBack inserts the element e at the back of list l.
//
//go:nosplit
func (l *groPacketList) PushBack(e *groPacket) {
	e.SetNext(nil)
	e.SetPrev(l.tail)
	if l.tail == nil {
		// The list was empty, so e is also the first element.
		l.head = e
	} else {
		l.tail.SetNext(e)
	}
	l.tail = e
}

// PushBackList inserts list m at the end of list l, emptying m.
//
//go:nosplit
func (l *groPacketList) PushBackList(m *groPacketList) {
	switch {
	case l.head == nil:
		// l is empty: simply adopt m's elements.
		l.head = m.head
		l.tail = m.tail
	case m.head != nil:
		// Splice m's head onto l's tail.
		l.tail.SetNext(m.head)
		m.head.SetPrev(l.tail)
		l.tail = m.tail
	}
	m.head = nil
	m.tail = nil
}

// InsertAfter inserts e after b.
//
//go:nosplit
func (l *groPacketList) InsertAfter(b, e *groPacket) {
	after := b.Next()
	e.SetNext(after)
	e.SetPrev(b)
	b.SetNext(e)
	if after == nil {
		// b was the last element, so e becomes the new tail.
		l.tail = e
		return
	}
	after.SetPrev(e)
}

// InsertBefore inserts e before a.
//
//go:nosplit
func (l *groPacketList) InsertBefore(a, e *groPacket) {
	before := a.Prev()
	e.SetNext(a)
	e.SetPrev(before)
	a.SetPrev(e)
	if before == nil {
		// a was the first element, so e becomes the new head.
		l.head = e
		return
	}
	before.SetNext(e)
}

// Remove removes e from l and clears e's links.
//
//go:nosplit
func (l *groPacketList) Remove(e *groPacket) {
	prev, next := e.Prev(), e.Next()

	switch {
	case prev != nil:
		prev.SetNext(next)
	case l.head == e:
		l.head = next
	}

	switch {
	case next != nil:
		next.SetPrev(prev)
	case l.tail == e:
		l.tail = prev
	}

	e.SetNext(nil)
	e.SetPrev(nil)
}
// groPacketEntry is a default implementation of Linker. A struct that embeds
// it anonymously automatically gains the methods groPacketList needs.
//
// +stateify savable
type groPacketEntry struct {
	next *groPacket
	prev *groPacket
}

// Next returns the entry that follows e in the list.
//
//go:nosplit
func (e *groPacketEntry) Next() *groPacket { return e.next }

// Prev returns the entry that precedes e in the list.
//
//go:nosplit
func (e *groPacketEntry) Prev() *groPacket { return e.prev }

// SetNext assigns elem as the entry that follows e in the list.
//
//go:nosplit
func (e *groPacketEntry) SetNext(elem *groPacket) { e.next = elem }

// SetPrev assigns elem as the entry that precedes e in the list.
//
//go:nosplit
func (e *groPacketEntry) SetPrev(elem *groPacket) { e.prev = elem }

View File

@@ -0,0 +1,70 @@
// automatically generated by stateify.
package gro
import (
"context"
"gvisor.dev/gvisor/pkg/state"
)
// StateTypeName returns the name under which groPacketList is registered with
// the state package.
func (l *groPacketList) StateTypeName() string {
return "pkg/tcpip/stack/gro.groPacketList"
}
// StateFields lists the saved fields. Their positions match the indices
// passed to Save and Load below (0 is head, 1 is tail).
func (l *groPacketList) StateFields() []string {
return []string{
"head",
"tail",
}
}
// beforeSave is a no-op hook called at the start of StateSave.
func (l *groPacketList) beforeSave() {}
// StateSave writes head and tail to stateSinkObject under indices 0 and 1.
// +checklocksignore
func (l *groPacketList) StateSave(stateSinkObject state.Sink) {
l.beforeSave()
stateSinkObject.Save(0, &l.head)
stateSinkObject.Save(1, &l.tail)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (l *groPacketList) afterLoad(context.Context) {}
// StateLoad reads head and tail back from stateSourceObject using the same
// indices that StateSave wrote them under.
// +checklocksignore
func (l *groPacketList) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &l.head)
stateSourceObject.Load(1, &l.tail)
}
// StateTypeName returns the name under which groPacketEntry is registered
// with the state package.
func (e *groPacketEntry) StateTypeName() string {
return "pkg/tcpip/stack/gro.groPacketEntry"
}
// StateFields lists the saved fields. Their positions match the indices
// passed to Save and Load below (0 is next, 1 is prev).
func (e *groPacketEntry) StateFields() []string {
return []string{
"next",
"prev",
}
}
// beforeSave is a no-op hook called at the start of StateSave.
func (e *groPacketEntry) beforeSave() {}
// StateSave writes next and prev to stateSinkObject under indices 0 and 1.
// +checklocksignore
func (e *groPacketEntry) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
stateSinkObject.Save(0, &e.next)
stateSinkObject.Save(1, &e.prev)
}
// afterLoad is a no-op hook; StateLoad below does not call it.
func (e *groPacketEntry) afterLoad(context.Context) {}
// StateLoad reads next and prev back from stateSourceObject using the same
// indices that StateSave wrote them under.
// +checklocksignore
func (e *groPacketEntry) StateLoad(ctx context.Context, stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.next)
stateSourceObject.Load(1, &e.prev)
}
// init registers the savable types of this file with the state package. Typed
// nil pointers are passed, so no instance is allocated.
func init() {
state.Register((*groPacketList)(nil))
state.Register((*groPacketEntry)(nil))
}

View File

@@ -0,0 +1,40 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type headerType ."; DO NOT EDIT.
package stack
import "strconv"
// _ is never called; it exists only as a compile-time check that the
// headerType constants still have the values the lookup tables were generated
// for. x has a single element, so each x[c-N] below only compiles when the
// constant c equals N.
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[virtioNetHeader-0]
_ = x[linkHeader-1]
_ = x[networkHeader-2]
_ = x[transportHeader-3]
_ = x[numHeaderType-4]
}
const _headerType_name = "virtioNetHeaderlinkHeadernetworkHeadertransportHeadernumHeaderType"
var _headerType_index = [...]uint8{0, 10, 23, 38, 51}
// String returns the constant name for i by slicing _headerType_name with the
// offsets in _headerType_index. Values outside the table are rendered as
// "headerType(N)".
func (i headerType) String() string {
	if i >= 0 && i < headerType(len(_headerType_index)-1) {
		start, end := _headerType_index[i], _headerType_index[i+1]
		return _headerType_name[start:end]
	}
	return "headerType(" + strconv.FormatInt(int64(i), 10) + ")"
}

View File

@@ -0,0 +1,41 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type Hook ."; DO NOT EDIT.
package stack
import "strconv"
// _ is a compile-time guard: each index expression is only valid while the
// named constant still has the value it had when this file was generated
// (Prerouting = 0 through NumHooks = 5).
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Prerouting-0]
_ = x[Input-1]
_ = x[Forward-2]
_ = x[Output-3]
_ = x[Postrouting-4]
_ = x[NumHooks-5]
}
// _Hook_name is the concatenation of all Hook constant names, in value order.
const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks"
// _Hook_index holds the byte offset at which each name starts in _Hook_name,
// plus a final entry for the end of the string.
var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47}
// String returns the constant name for i by slicing _Hook_name with the
// offsets in _Hook_index. Values past the end of the table are rendered as
// "Hook(N)".
func (i Hook) String() string {
	if i < Hook(len(_Hook_index)-1) {
		start, end := _Hook_index[i], _Hook_index[i+1]
		return _Hook_name[start:end]
	}
	return "Hook(" + strconv.FormatInt(int64(i), 10) + ")"
}

View File

@@ -0,0 +1,75 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
)
// Defaults passed to rate.NewLimiter by NewICMPRateLimiter.
const (
// icmpLimit is the default maximum number of ICMP messages permitted by this
// rate limiter.
// NOTE(review): rate.Limit is expressed in events per second, so this is
// presumably 1000 messages/second — confirm against golang.org/x/time/rate.
icmpLimit = 1000
// icmpBurst is the default number of ICMP messages that can be sent in a single
// burst.
icmpBurst = 50
)
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
//
// It wraps a rate.Limiter and drives it with timestamps taken from clock
// rather than from the wall clock.
//
// +stateify savable
type ICMPRateLimiter struct {
// TODO(b/341946753): Restore when netstack is savable.
// limiter is excluded from saved state by the nosave tag.
limiter *rate.Limiter `state:"nosave"`
// clock supplies the "now" passed to the limiter's *At/AllowN methods.
clock tcpip.Clock
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
// at which ICMP messages are generated by the stack. The returned limiter
// does not apply limits to any ICMP types by default.
//
// The underlying limiter starts with the icmpLimit rate and icmpBurst burst
// size, and takes its notion of time from clock.
func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter {
	l := &ICMPRateLimiter{clock: clock}
	l.limiter = rate.NewLimiter(icmpLimit, icmpBurst)
	return l
}
// SetLimit installs limit as the limiter's new rate, effective as of the
// current time reported by the limiter's clock.
func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) {
	now := l.clock.Now()
	l.limiter.SetLimitAt(now, limit)
}
// Limit returns the maximum overall event rate, as reported by the underlying
// rate.Limiter.
func (l *ICMPRateLimiter) Limit() rate.Limit {
return l.limiter.Limit()
}
// SetBurst installs burst as the limiter's new burst size, effective as of
// the current time reported by the limiter's clock.
func (l *ICMPRateLimiter) SetBurst(burst int) {
	now := l.clock.Now()
	l.limiter.SetBurstAt(now, burst)
}
// Burst returns the maximum burst size, as reported by the underlying
// rate.Limiter.
func (l *ICMPRateLimiter) Burst() int {
return l.limiter.Burst()
}
// Allow reports whether one ICMP message may be sent now, where "now" is the
// current time reported by the limiter's clock.
func (l *ICMPRateLimiter) Allow() bool {
	const messages = 1
	now := l.clock.Now()
	return l.limiter.AllowN(now, messages)
}

View File

@@ -0,0 +1,717 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"context"
"fmt"
"math/rand"
"reflect"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// TableID identifies a specific table. It is used as an index into the
// per-IP-version table arrays of IPTables.
type TableID int
// Each value identifies a specific table.
const (
NATID TableID = iota
MangleID
FilterID
// NumTables is the number of tables; it is the array length of the
// per-IP-version table arrays and is not itself a valid table id.
NumTables
)
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow. It is stored in Table.BuiltinChains and Table.Underflows in place
// of a rule index.
const HookUnset = -1
// reaperDelay is how long to wait before starting to reap connections. It is
// the initial interval passed to startReaper.
const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
//
// The IPv4 and IPv6 table sets are built identically apart from the filter
// constructor (EmptyFilter4 / EmptyFilter6) and the protocol number. In every
// table, BuiltinChains and Underflows hold indices into Rules, and the last
// rule is an ErrorTarget that no chain entry points at.
//
// Note that the rand parameter shadows the math/rand package inside this
// function. It is used both to draw the conntrack seed and as the conntrack's
// random source.
func DefaultTables(clock tcpip.Clock, rand *rand.Rand) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
// One ACCEPT rule per NAT hook (Prerouting, Input, Output,
// Postrouting) followed by the trailing error rule.
Rules: []Rule{
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
},
MangleID: {
Rules: []Rule{
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
// Input, Forward and Postrouting are not listed here, so they take
// the zero value 0 (rule 0, an ACCEPT) rather than HookUnset.
// CheckPostrouting in this file runs the mangle table, so that
// entry is reachable.
BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
Underflows: [NumHooks]int{
Prerouting: 0,
Input: HookUnset,
Forward: HookUnset,
Output: 1,
Postrouting: HookUnset,
},
},
FilterID: {
// One ACCEPT rule per filter hook (Input, Forward, Output) followed
// by the trailing error rule.
Rules: []Rule{
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
{Filter: EmptyFilter4(), Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
Input: 0,
Forward: 1,
Output: 2,
Postrouting: HookUnset,
},
Underflows: [NumHooks]int{
Prerouting: HookUnset,
Input: 0,
Forward: 1,
Output: 2,
Postrouting: HookUnset,
},
},
},
// The IPv6 tables mirror the IPv4 tables above.
v6Tables: [NumTables]Table{
NATID: {
Rules: []Rule{
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
},
MangleID: {
Rules: []Rule{
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
// As for IPv4: unlisted hooks take the zero value 0, not HookUnset.
BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
Underflows: [NumHooks]int{
Prerouting: 0,
Input: HookUnset,
Forward: HookUnset,
Output: 1,
Postrouting: HookUnset,
},
},
FilterID: {
Rules: []Rule{
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
{Filter: EmptyFilter6(), Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
Input: 0,
Forward: 1,
Output: 2,
Postrouting: HookUnset,
},
Underflows: [NumHooks]int{
Prerouting: HookUnset,
Input: 0,
Forward: 1,
Output: 2,
Postrouting: HookUnset,
},
},
},
// Only the seed, clock and random source are set here; replaceTable calls
// connections.init() the first time a table is actually changed.
connections: ConnTrack{
seed: rand.Uint32(),
clock: clock,
rand: rand,
},
}
}
// EmptyFilterTable returns a Table with an empty, non-nil rule slice in which
// the Prerouting and Postrouting entries of both BuiltinChains and Underflows
// are HookUnset. All other hook entries are left at the zero value.
func EmptyFilterTable() Table {
	var chains [NumHooks]int
	chains[Prerouting] = HookUnset
	chains[Postrouting] = HookUnset

	// Arrays are copied on assignment, so the two fields do not alias.
	return Table{
		Rules:         []Rule{},
		BuiltinChains: chains,
		Underflows:    chains,
	}
}
// EmptyNATTable returns a Table with an empty, non-nil rule slice in which the
// Forward entry of both BuiltinChains and Underflows is HookUnset. All other
// hook entries are left at the zero value.
func EmptyNATTable() Table {
	var chains [NumHooks]int
	chains[Forward] = HookUnset

	// Arrays are copied on assignment, so the two fields do not alias.
	return Table{
		Rules:         []Rule{},
		BuiltinChains: chains,
		Underflows:    chains,
	}
}
// GetTable returns a copy of the table with the given id for the requested IP
// version (IPv6 when ipv6 is true, IPv4 otherwise). The lookup is done under
// the read lock. It panics when an invalid id is provided.
func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
	it.mu.RLock()
	defer it.mu.RUnlock()

	tbl := it.getTableRLocked(id, ipv6)
	return tbl
}
// getTableRLocked returns the table with the given id from the IPv6 table set
// when ipv6 is true and from the IPv4 set otherwise. Indexing with an
// out-of-range id panics.
//
// +checklocksread:it.mu
func (it *IPTables) getTableRLocked(id TableID, ipv6 bool) Table {
	if !ipv6 {
		return it.v4Tables[id]
	}
	return it.v6Tables[id]
}
// ReplaceTable replaces the table identified by id for the given IP version
// (IPv6 when ipv6 is true, IPv4 otherwise). It panics when an invalid id is
// provided.
//
// If iptables has never been modified and table is identical to the current
// one, the call is a no-op; use ForceReplaceTable to bypass that check.
func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) {
it.replaceTable(id, table, ipv6, false /* force */)
}
// ForceReplaceTable replaces the table identified by id for the given IP
// version. It panics when an invalid id is provided. It enables iptables even
// when the inserted table is all conditionless ACCEPT, skipping our
// optimization that disables iptables until they're modified.
func (it *IPTables) ForceReplaceTable(id TableID, table Table, ipv6 bool) {
it.replaceTable(id, table, ipv6, true /* force */)
}
// replaceTable stores table under id for the selected IP version while holding
// the write lock.
//
// On the first real modification (it.modified still false) it initializes the
// conntrack table and starts the reaper. While unmodified, replacing a table
// with an identical one returns early without enabling anything, unless force
// is set. An out-of-range id panics.
func (it *IPTables) replaceTable(id TableID, table Table, ipv6, force bool) {
	it.mu.Lock()
	defer it.mu.Unlock()

	// Select the destination slot once; it is used both for the equality
	// check and for the final store.
	slot := &it.v4Tables[id]
	if ipv6 {
		slot = &it.v6Tables[id]
	}

	if !it.modified {
		// Don't do anything if the table is identical.
		if reflect.DeepEqual(table, *slot) && !force {
			return
		}
		// iptables is being enabled: set up conntrack and its reaper.
		it.connections.init()
		it.startReaper(reaperDelay)
	}

	it.modified = true
	*slot = table
}
// A chainVerdict is what a table decides should be done with a packet. It is
// the result type of checkChain.
type chainVerdict int
const (
// chainAccept indicates the packet should continue through netstack.
chainAccept chainVerdict = iota
// chainDrop indicates the packet should be dropped.
chainDrop
// chainReturn indicates the packet should return to the calling chain
// or the underflow rule of a builtin chain.
chainReturn
)
// checkTable describes one table to run for a hook. The Check* methods fill in
// fn and tableID; shouldSkipOrPopulateTables fills in table.
type checkTable struct {
// fn is the function used to run the packet through the table.
fn checkTableFn
// tableID selects which table to fetch.
tableID TableID
// table is the snapshot of the table fetched under the read lock.
table Table
}
// shouldSkipOrPopulateTables returns true iff IPTables should be skipped.
//
// If IPTables should not be skipped, tables will be updated with the
// specified table.
//
// IPTables is skipped when the packet is neither IPv4 nor IPv6, or when no
// table has been modified yet.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// it does not allocate. We check recursively for heap allocations, but not for:
// - Stack splitting, which can allocate.
// - Calls to interfaces, which can allocate.
// - Calls to dynamic functions, which can allocate.
//
// +checkescape:hard
func (it *IPTables) shouldSkipOrPopulateTables(tables []checkTable, pkt *PacketBuffer) bool {
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
default:
// IPTables only supports IPv4/IPv6.
return true
}
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
return true
}
// Snapshot each requested table (for the packet's IP version) while the
// read lock is held; entries are updated in place through the slice.
for i := range tables {
table := &tables[i]
table.table = it.getTableRLocked(table.tableID, pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber)
}
return false
}
// CheckPrerouting performs the prerouting hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// The mangle table is run first, then the NAT table. The packet's conntrack
// tuple is looked up (and stored in pkt.tuple) before any table runs.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
tables := [...]checkTable{
{
fn: check,
tableID: MangleID,
},
{
fn: checkNAT,
tableID: NATID,
},
}
// Fast path: unsupported protocol or iptables never modified.
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
pkt.tuple = it.connections.getConnAndUpdate(pkt, false /* skipChecksumValidation */)
// The first table to reject the packet ends processing.
for _, table := range tables {
if !table.fn(it, table.table, Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) {
return false
}
}
return true
}
// CheckInput performs the input hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// The NAT table is run first, then the filter table. If both accept and the
// packet carries a conntrack tuple, the tuple is detached from the packet and
// the result of finalizing its connection is returned.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
tables := [...]checkTable{
{
fn: checkNAT,
tableID: NATID,
},
{
fn: check,
tableID: FilterID,
},
}
// Fast path: unsupported protocol or iptables never modified.
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
// The first table to reject the packet ends processing; pkt.tuple is left
// untouched in that case.
for _, table := range tables {
if !table.fn(it, table.table, Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) {
return false
}
}
// NOTE(review): finalize presumably commits the connection to the
// conntrack table; confirm against the conn implementation.
if t := pkt.tuple; t != nil {
pkt.tuple = nil
return t.conn.finalize()
}
return true
}
// CheckForward performs the forward hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// Only the filter table is run for this hook, and this method does not look
// up or finalize a conntrack tuple.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
tables := [...]checkTable{
{
fn: check,
tableID: FilterID,
},
}
// Fast path: unsupported protocol or iptables never modified.
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
for _, table := range tables {
if !table.fn(it, table.table, Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) {
return false
}
}
return true
}
// CheckOutput performs the output hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// The mangle, NAT and filter tables are run in that order. The packet's
// conntrack tuple is looked up (and stored in pkt.tuple) before any table
// runs.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
tables := [...]checkTable{
{
fn: check,
tableID: MangleID,
},
{
fn: checkNAT,
tableID: NATID,
},
{
fn: check,
tableID: FilterID,
},
}
// Fast path: unsupported protocol or iptables never modified.
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
// We don't need to validate the checksum in the Output path: we can assume
// we calculate it correctly, plus checksumming may be deferred due to GSO.
pkt.tuple = it.connections.getConnAndUpdate(pkt, true /* skipChecksumValidation */)
// The first table to reject the packet ends processing.
for _, table := range tables {
if !table.fn(it, table.table, Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) {
return false
}
}
return true
}
// CheckPostrouting performs the postrouting hook on the packet.
//
// Returns true iff the packet may continue traversing the stack; the packet
// must be dropped if false is returned.
//
// The mangle table is run first, then the NAT table. If both accept and the
// packet carries a conntrack tuple, the tuple is detached from the packet and
// the result of finalizing its connection is returned.
//
// Precondition: The packet's network and transport header must be set.
//
// This is called in the hot path even when iptables are disabled, so we ensure
// that it does not allocate. Note that called functions (e.g.
// getConnAndUpdate) can allocate.
// TODO(b/233951539): checkescape fails on arm sometimes. Fix and re-add.
func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
tables := [...]checkTable{
{
fn: check,
tableID: MangleID,
},
{
fn: checkNAT,
tableID: NATID,
},
}
// Fast path: unsupported protocol or iptables never modified.
if it.shouldSkipOrPopulateTables(tables[:], pkt) {
return true
}
// The first table to reject the packet ends processing; pkt.tuple is left
// untouched in that case.
for _, table := range tables {
if !table.fn(it, table.table, Postrouting, pkt, r, addressEP, "" /* inNicName */, outNicName) {
return false
}
}
// NOTE(review): finalize presumably commits the connection to the
// conntrack table; confirm against the conn implementation.
if t := pkt.tuple; t != nil {
pkt.tuple = nil
return t.conn.finalize()
}
return true
}
// checkTableFn is the signature of the per-table check functions (check and
// checkNAT) stored in checkTable.fn. It returns true if the packet may
// continue and false if it must be dropped.
//
// Note: this used to omit the *IPTables parameter, but doing so caused
// unnecessary allocations.
type checkTableFn func(it *IPTables, table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool
// checkNAT is a checkTableFn that forwards to the IPTables.checkNAT method,
// taking the receiver as an explicit first argument.
func checkNAT(it *IPTables, table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
return it.checkNAT(table, hook, pkt, r, addressEP, inNicName, outNicName)
}
// checkNAT runs the packet through the NAT table.
//
// If the packet has a conntrack tuple whose connection handles the packet,
// the table is not consulted at all. Otherwise the table's rules are run via
// check; if they accept and the packet has a tuple, a no-op NAT is attempted
// for the hook's NAT direction unless that NAT was already done on the packet.
//
// It panics for the Forward hook and for unknown hooks once the NAT direction
// has to be determined. See check for the meaning of the return value.
func (it *IPTables) checkNAT(table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
	tuple := pkt.tuple
	if tuple != nil && tuple.conn.handlePacket(pkt, hook, r) {
		return true
	}

	if !it.check(table, hook, pkt, r, addressEP, inNicName, outNicName) {
		return false
	}

	if tuple == nil {
		return true
	}

	// Prerouting and Output perform destination NAT; Input and Postrouting
	// perform source NAT.
	var isDNAT, alreadyNATed bool
	switch hook {
	case Prerouting, Output:
		isDNAT, alreadyNATed = true, pkt.dnatDone
	case Input, Postrouting:
		isDNAT, alreadyNATed = false, pkt.snatDone
	case Forward:
		panic("should not attempt NAT in forwarding")
	default:
		panic(fmt.Sprintf("unhandled hook = %d", hook))
	}

	// Make sure the connection is NATed.
	//
	// If the packet was already NATed, the connection must be NATed.
	if !alreadyNATed {
		tuple.conn.maybePerformNoopNAT(pkt, hook, r, isDNAT)
	}
	return true
}
// check is a checkTableFn that forwards to the IPTables.check method, taking
// the receiver as an explicit first argument.
func check(it *IPTables, table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
return it.check(table, hook, pkt, r, addressEP, inNicName, outNicName)
}
// check runs the packet through the rules in the specified table for the
// hook. It returns true if the packet should continue to traverse through the
// network stack or tables, or false when it must be dropped.
//
// Traversal starts at the hook's builtin chain entry. If that chain returns,
// the hook's underflow rule decides; the underflow must yield accept or drop,
// anything else panics.
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) check(table Table, hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
	startIdx := table.BuiltinChains[hook]
	verdict := it.checkChain(hook, pkt, table, startIdx, r, addressEP, inNicName, outNicName)

	// If the table returns Accept, move on to the next table.
	if verdict == chainAccept {
		return true
	}
	// The Drop verdict is final.
	if verdict == chainDrop {
		return false
	}
	if verdict != chainReturn {
		panic(fmt.Sprintf("Unknown verdict %v.", verdict))
	}

	// Any Return from a built-in chain means we have to call the underflow.
	underflowRule := table.Rules[table.Underflows[hook]]
	ruleVerdict, _ := underflowRule.Target.Action(pkt, hook, r, addressEP)
	switch ruleVerdict {
	case RuleAccept:
		return true
	case RuleDrop:
		return false
	case RuleJump, RuleReturn:
		panic("Underflows should only return RuleAccept or RuleDrop.")
	default:
		panic(fmt.Sprintf("Unknown verdict: %d", ruleVerdict))
	}
}
// beforeSave is invoked by stateify.
//
// It stops the reaper timer, if one was ever started, and then acquires the
// connection table's lock so that others cannot modify the table. The lock is
// not released by this method.
func (it *IPTables) beforeSave() {
	// Ensure the reaper exits cleanly. The reaper only exists once
	// startReaper has run (on the first table replacement or after a load),
	// so it is still nil for an IPTables that was never modified.
	if it.reaper != nil {
		it.reaper.Stop()
	}
	// Prevent others from modifying the connection table.
	it.connections.mu.Lock()
}
// afterLoad is invoked by stateify. It restarts the connection reaper with the
// default initial delay, regardless of whether the tables were modified before
// the save. The context argument is unused.
func (it *IPTables) afterLoad(context.Context) {
it.startReaper(reaperDelay)
}
// startReaper periodically reaps timed out connections.
//
// It schedules a timer on the conntrack clock that first fires after interval
// and stores it in it.reaper. Each firing calls reapUnused, which returns the
// bucket to resume from and the next interval, and then re-arms the same
// timer with that interval.
func (it *IPTables) startReaper(interval time.Duration) {
// bucket and interval are captured by the closure and carried from one
// firing to the next.
bucket := 0
it.reaper = it.connections.clock.AfterFunc(interval, func() {
bucket, interval = it.connections.reapUnused(bucket, interval)
it.reaper.Reset(interval)
})
}
// checkChain walks table.Rules starting at ruleIdx until a rule produces a
// verdict for the chain. A jump to the immediately following rule simply
// continues the walk; any other jump recurses into the target chain, and a
// return from that chain resumes the walk at the next rule. Unknown verdicts
// panic.
//
// Preconditions:
//   - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
//   - pkt.NetworkHeader is not nil.
//
// NOTE(review): the preconditions mention IPv4 only, but the callers in this
// file also run IPv6 packets through the tables; confirm the IPv6 wording.
func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict {
	for idx := ruleIdx; idx < len(table.Rules); idx++ {
		verdict, jumpTo := it.checkRule(hook, pkt, table, idx, r, addressEP, inNicName, outNicName)
		switch verdict {
		case RuleAccept:
			return chainAccept
		case RuleDrop:
			return chainDrop
		case RuleReturn:
			return chainReturn
		case RuleJump:
			// "Jumping" to the next rule just means we're continuing on
			// down the list.
			if jumpTo == idx+1 {
				continue
			}
			switch sub := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); sub {
			case chainAccept:
				return chainAccept
			case chainDrop:
				return chainDrop
			case chainReturn:
				// The jumped-to chain returned; carry on with the next
				// rule of this chain.
				continue
			default:
				panic(fmt.Sprintf("Unknown verdict: %d", sub))
			}
		default:
			panic(fmt.Sprintf("Unknown verdict: %d", verdict))
		}
	}

	// We got through the entire table without a decision. Default to DROP
	// for safety.
	return chainDrop
}
// checkRule evaluates the rule at ruleIdx against pkt. If the rule's header
// filter or any of its matchers does not match, it returns RuleJump with the
// index of the next rule. A matcher may also request an immediate drop. When
// everything matches, the rule's target is run and its result returned.
//
// Preconditions:
//   - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
//   - pkt.NetworkHeader is not nil.
func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) {
	rule := table.Rules[ruleIdx]
	nextRule := ruleIdx + 1

	// Check whether the packet matches the IP header filter.
	if !rule.Filter.match(pkt, hook, inNicName, outNicName) {
		return RuleJump, nextRule
	}

	// Go through each rule matcher. If they all match, run the rule target.
	for _, m := range rule.Matchers {
		switch matched, hotdrop := m.Match(hook, pkt, inNicName, outNicName); {
		case hotdrop:
			return RuleDrop, 0
		case !matched:
			// Continue on to the next rule.
			return RuleJump, nextRule
		}
	}

	// All the matchers matched, so run the target.
	return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
//
// When iptables has never been modified there is no connection tracking to
// consult, and ErrNotConnected is returned without a lookup.
func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
	it.mu.RLock()
	defer it.mu.RUnlock()

	if it.modified {
		return it.connections.originalDst(epID, netProto, transProto)
	}
	return tcpip.Address{}, 0, &tcpip.ErrNotConnected{}
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// ipTablesRWMutex is sync.RWMutex with the correctness validator: its methods
// report lock and unlock operations to the locking package under the
// ipTablesprefixIndex mutex class.
type ipTablesRWMutex struct {
// mu is the underlying mutex that actually provides exclusion.
mu sync.RWMutex
}
// ipTableslockNames is a list of user-friendly lock names.
// Populated in init (via ipTablesinitLockNames) and passed to
// locking.NewMutexClass.
var ipTableslockNames []string
// ipTableslockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within ipTableslockNames.
// Values are specified using the "consts" field of go_template_instance.
type ipTableslockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
// The block is empty in this instantiation: no lock name index constants are
// defined for ipTablesRWMutex.
const ()
// Lock locks m for writing. The acquisition is reported to the lock validator
// (with index -1) before the underlying mutex is taken.
// +checklocksignore
func (m *ipTablesRWMutex) Lock() {
locking.AddGLock(ipTablesprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held. The
// acquisition is reported to the lock validator under index i.
// +checklocksignore
func (m *ipTablesRWMutex) NestedLock(i ipTableslockNameIndex) {
locking.AddGLock(ipTablesprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m. The underlying mutex is released first and the release is
// then reported to the lock validator (with index -1).
// +checklocksignore
func (m *ipTablesRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(ipTablesprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// The release is reported to the lock validator under index i, which should
// match the index given to NestedLock.
// +checklocksignore
func (m *ipTablesRWMutex) NestedUnlock(i ipTableslockNameIndex) {
m.mu.Unlock()
locking.DelGLock(ipTablesprefixIndex, int(i))
}
// RLock locks m for reading. The acquisition is reported to the lock validator
// (with index -1) before the underlying mutex is read-locked.
// +checklocksignore
func (m *ipTablesRWMutex) RLock() {
locking.AddGLock(ipTablesprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call. The underlying read lock is released
// first and the release is then reported to the lock validator.
// +checklocksignore
func (m *ipTablesRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(ipTablesprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator. It must be
// paired with RUnlockBypass, which likewise skips the validator.
// +checklocksignore
func (m *ipTablesRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call, without executing the
// validator.
// +checklocksignore
func (m *ipTablesRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks m for writing and locks it for reading. It
// delegates to the underlying mutex and does not notify the validator.
// +checklocksignore
func (m *ipTablesRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
// ipTablesprefixIndex is the validator's mutex class for ipTablesRWMutex. It
// is created in init and passed to every AddGLock/DelGLock call above.
var ipTablesprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
// In this instantiation it is a no-op, so ipTableslockNames stays nil.
func ipTablesinitLockNames() {}
// init populates the lock names and then registers the mutex class for
// ipTablesRWMutex with the locking package.
func init() {
ipTablesinitLockNames()
ipTablesprefixIndex = locking.NewMutexClass(reflect.TypeOf(ipTablesRWMutex{}), ipTableslockNames)
}

View File

@@ -0,0 +1,493 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"math"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// AcceptTarget accepts packets. It is the target of every non-error rule in
// the default tables.
//
// +stateify savable
type AcceptTarget struct {
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action. It ignores all arguments and always returns
// RuleAccept with a jump index of 0.
func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
// DropTarget drops packets without sending any response.
//
// +stateify savable
type DropTarget struct {
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action. It ignores all arguments and always returns
// RuleDrop with a jump index of 0.
func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
// RejectIPv4WithHandler handles rejecting a packet. RejectIPv4Target delegates
// to it to send the rejection response.
type RejectIPv4WithHandler interface {
// SendRejectionError sends an error packet in response to the packet.
//
// rejectWith selects the kind of error to send. inputHook reports whether
// the packet was rejected at the Input hook; RejectIPv4Target.Action passes
// hook == Input for it.
SendRejectionError(pkt *PacketBuffer, rejectWith RejectIPv4WithICMPType, inputHook bool) tcpip.Error
}
// RejectIPv4WithICMPType indicates the type of ICMP error that should be sent.
type RejectIPv4WithICMPType int
// The types of errors that may be returned when rejecting IPv4 packets.
// The zero value is deliberately skipped, so an unset RejectWith does not
// match any named type.
const (
_ RejectIPv4WithICMPType = iota
RejectIPv4WithICMPNetUnreachable
RejectIPv4WithICMPHostUnreachable
RejectIPv4WithICMPPortUnreachable
RejectIPv4WithICMPNetProhibited
RejectIPv4WithICMPHostProhibited
RejectIPv4WithICMPAdminProhibited
)
// RejectIPv4Target drops packets and sends back an error packet in response to the
// matched packet.
//
// +stateify savable
type RejectIPv4Target struct {
// Handler sends the rejection response.
Handler RejectIPv4WithHandler
// RejectWith is the error type passed to Handler.
RejectWith RejectIPv4WithICMPType
}
// Action implements Target.Action.
//
// For the Input, Forward and Output hooks it asks the handler to send the
// configured rejection error and returns RuleDrop. It panics for the
// Prerouting and Postrouting hooks, which REJECT does not support, and for
// unknown hooks.
func (rt *RejectIPv4Target) Action(pkt *PacketBuffer, hook Hook, _ *Route, _ AddressableEndpoint) (RuleVerdict, int) {
	switch hook {
	case Input, Forward, Output:
		// There is nothing reasonable for us to do in response to an error here;
		// we already drop the packet.
		_ = rt.Handler.SendRejectionError(pkt, rt.RejectWith, hook == Input)
		return RuleDrop, 0
	case Prerouting, Postrouting:
		panic(fmt.Sprintf("%s hook not supported for REJECT", hook))
	default:
		panic(fmt.Sprintf("unhandled hook = %s", hook))
	}
}
// RejectIPv6WithHandler handles rejecting a packet. RejectIPv6Target delegates
// to it to send the rejection response.
type RejectIPv6WithHandler interface {
	// SendRejectionError sends an error packet in response to the packet.
	//
	// rejectWith selects the kind of error to send. inputHook reports whether
	// the packet was rejected at the Input hook; RejectIPv6Target.Action passes
	// hook == Input for it.
	SendRejectionError(pkt *PacketBuffer, rejectWith RejectIPv6WithICMPType, inputHook bool) tcpip.Error
}
// RejectIPv6WithICMPType indicates the type of ICMP error that should be sent.
type RejectIPv6WithICMPType int
// The types of errors that may be returned when rejecting IPv6 packets.
// The zero value is deliberately skipped, so an unset RejectWith does not
// match any named type.
const (
_ RejectIPv6WithICMPType = iota
RejectIPv6WithICMPNoRoute
RejectIPv6WithICMPAddrUnreachable
RejectIPv6WithICMPPortUnreachable
RejectIPv6WithICMPAdminProhibited
)
// RejectIPv6Target drops packets and sends back an error packet in response to the
// matched packet.
//
// +stateify savable
type RejectIPv6Target struct {
// Handler sends the rejection response.
Handler RejectIPv6WithHandler
// RejectWith is the error type passed to Handler.
RejectWith RejectIPv6WithICMPType
}
// Action implements Target.Action.
//
// For the Input, Forward and Output hooks it asks the handler to send the
// configured rejection error and returns RuleDrop. It panics for the
// Prerouting and Postrouting hooks, which REJECT does not support, and for
// unknown hooks.
func (rt *RejectIPv6Target) Action(pkt *PacketBuffer, hook Hook, _ *Route, _ AddressableEndpoint) (RuleVerdict, int) {
	switch hook {
	case Input, Forward, Output:
		// There is nothing reasonable for us to do in response to an error here;
		// we already drop the packet.
		_ = rt.Handler.SendRejectionError(pkt, rt.RejectWith, hook == Input)
		return RuleDrop, 0
	case Prerouting, Postrouting:
		panic(fmt.Sprintf("%s hook not supported for REJECT", hook))
	default:
		panic(fmt.Sprintf("unhandled hook = %s", hook))
	}
}
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable. The default tables place one as the last rule of
// every table.
//
// +stateify savable
type ErrorTarget struct {
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action. It emits a debug-level log message and
// returns RuleDrop; all arguments are ignored.
func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
// UserChainTarget marks a rule as the beginning of a user chain. It is a
// marker only: its Action must never be invoked.
//
// +stateify savable
type UserChainTarget struct {
// Name is the chain name.
Name string
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action. It always panics, because a user chain
// marker is not meant to be executed as a target.
func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
// ReturnTarget returns from the current chain. If the chain is a built-in, the
// hook's underflow should be called.
//
// +stateify savable
type ReturnTarget struct {
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action. It ignores all arguments and always returns
// RuleReturn with a jump index of 0.
func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
// DNATTarget modifies the destination port/IP of packets. It is only valid at
// the Prerouting and Output hooks; its Action panics elsewhere.
//
// +stateify savable
type DNATTarget struct {
// The new destination address for packets.
//
// Immutable.
Addr tcpip.Address
// The new destination port for packets.
//
// Immutable.
Port uint16
// NetworkProtocol is the network protocol the target is used with.
// Action panics if it differs from the packet's protocol.
//
// Immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
// ChangeAddress indicates whether we should check addresses.
//
// Immutable.
ChangeAddress bool
// ChangePort indicates whether we should check ports.
//
// Immutable.
ChangePort bool
}
// Action implements Target.Action.
//
// It panics when the rule's network protocol differs from the packet's, and
// when called on any hook other than Prerouting or Output. Otherwise the
// verdict is whatever dnatAction returns.
func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
	// The rule and the packet must agree on the network protocol.
	if want, got := rt.NetworkProtocol, pkt.NetworkProtocolNumber; want != got {
		panic(fmt.Sprintf(
			"DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
			want, got))
	}

	switch hook {
	case Input, Forward, Postrouting:
		panic(fmt.Sprintf("%s not supported for DNAT", hook))
	case Prerouting, Output:
		// Supported hooks; carry on with the rewrite below.
	default:
		panic(fmt.Sprintf("%s unrecognized", hook))
	}

	return dnatAction(pkt, hook, r, rt.Port, rt.Addr, rt.ChangePort, rt.ChangeAddress)
}
// RedirectTarget redirects the packet to this machine by modifying the
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
// forwarded).
//
// +stateify savable
type RedirectTarget struct {
// Port is the destination port that redirected packets are rewritten to. It
// is immutable.
Port uint16
// NetworkProtocol is the network protocol the target is used with. Action
// panics if it differs from the packet's NetworkProtocolNumber. It is
// immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
//
// The packet's destination is rewritten to rt.Port on a local address: the
// loopback address (127.0.0.1 or ::1) on the Output hook, or the main address
// of addressEP on the Prerouting hook. It panics on any other hook and when
// the rule's network protocol differs from the packet's.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
	// The rule and the packet must agree on the network protocol.
	if want, got := rt.NetworkProtocol, pkt.NetworkProtocolNumber; want != got {
		panic(fmt.Sprintf(
			"RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
			want, got))
	}

	var redirectTo tcpip.Address
	switch {
	case hook == Prerouting:
		// addressEP is expected to be set for the prerouting hook.
		redirectTo = addressEP.MainAddress().Address
	case hook != Output:
		panic("redirect target is supported only on output and prerouting hooks")
	case pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber:
		// Output hook, IPv4: use the IPv4 loopback address.
		redirectTo = tcpip.AddrFrom4([4]byte{127, 0, 0, 1})
	default:
		// Output hook, any other protocol: use the IPv6 loopback address.
		redirectTo = header.IPv6Loopback
	}

	return dnatAction(pkt, hook, r, rt.Port, redirectTo, true /* changePort */, true /* changeAddress */)
}
// SNATTarget modifies the source port/IP in the outgoing packets.
//
// +stateify savable
type SNATTarget struct {
// Addr is the address handed to snatAction as the new source address.
Addr tcpip.Address
// Port is the port handed to snatAction as the new source port. For TCP and
// UDP packets, snatAction treats a value of 0 as "pick a port from a range
// derived from the packet's original source port".
Port uint16
// NetworkProtocol is the network protocol the target is used with. It
// is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
// ChangeAddress is passed to snatAction as its changeAddress argument.
// NOTE(review): presumably it selects whether the source address is
// rewritten rather than merely "checked" - confirm against performNAT.
//
// Immutable.
ChangeAddress bool
// ChangePort is passed to snatAction as its changePort argument.
// NOTE(review): presumably it selects whether the source port is rewritten
// rather than merely "checked" - confirm against performNAT.
//
// Immutable.
ChangePort bool
}
// dnatAction performs destination NAT on pkt by delegating to natAction with
// a port range that contains only port.
func dnatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, changePort, changeAddress bool) (RuleVerdict, int) {
	// DNAT always targets exactly one port.
	singlePort := portOrIdentRange{start: port, size: 1}
	return natAction(pkt, hook, r, singlePort, address, true /* dnat */, changePort, changeAddress)
}
// targetPortRangeForTCPAndUDP returns the range of source ports a TCP or UDP
// packet may be NAT-ed to when no explicit port was requested.
//
// As per iptables(8),
//
//	If no port range is specified, then source ports below 512 will be
//	mapped to other ports below 512: those between 512 and 1023 inclusive
//	will be mapped to ports below 1024, and other ports will be mapped to
//	1024 or above.
func targetPortRangeForTCPAndUDP(originalSrcPort uint16) portOrIdentRange {
	if originalSrcPort >= 1024 {
		// Unprivileged source ports stay in 1024..65535.
		return portOrIdentRange{start: 1024, size: math.MaxUint16 - 1023}
	}
	if originalSrcPort >= 512 {
		// 512..1023 may map anywhere in 1..1023.
		return portOrIdentRange{start: 1, size: 1023}
	}
	// Ports below 512 stay in 1..511.
	return portOrIdentRange{start: 1, size: 511}
}
// snatAction performs source NAT on pkt by choosing the set of candidate
// ports (or ICMP idents) and delegating to natAction.
//
// The candidates default to exactly port. For ICMPv4/ICMPv6 any 16-bit ident
// is allowed regardless of port. For TCP and UDP, a port of 0 selects the
// range targetPortRangeForTCPAndUDP derives from the packet's source port.
func snatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, changePort, changeAddress bool) (RuleVerdict, int) {
	candidates := portOrIdentRange{start: port, size: 1}

	proto := pkt.TransportProtocolNumber
	switch {
	case proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber:
		// Allow NAT-ing to any 16-bit value for ICMP's Ident field to match Linux
		// behaviour.
		//
		// https://github.com/torvalds/linux/blob/58e1100fdc5990b0cc0d4beaf2562a92e621ac7d/net/netfilter/nf_nat_core.c#L391
		candidates = portOrIdentRange{start: 0, size: math.MaxUint16 + 1}
	case port != 0:
		// An explicit port was requested; keep the single-port range.
	case proto == header.UDPProtocolNumber:
		srcPort := header.UDP(pkt.TransportHeader().Slice()).SourcePort()
		candidates = targetPortRangeForTCPAndUDP(srcPort)
	case proto == header.TCPProtocolNumber:
		srcPort := header.TCP(pkt.TransportHeader().Slice()).SourcePort()
		candidates = targetPortRangeForTCPAndUDP(srcPort)
	}

	return natAction(pkt, hook, r, candidates, address, false /* dnat */, changePort, changeAddress)
}
// natAction applies NAT to pkt through the connection its tuple belongs to.
//
// It returns RuleDrop when the packet's network or transport header is empty
// or when the packet has no tuple; otherwise it calls performNAT on the
// tuple's connection and returns RuleAccept.
func natAction(pkt *PacketBuffer, hook Hook, r *Route, portsOrIdents portOrIdentRange, address tcpip.Address, dnat, changePort, changeAddress bool) (RuleVerdict, int) {
	// Both headers must be set for NAT to be possible.
	if len(pkt.NetworkHeader().Slice()) == 0 {
		return RuleDrop, 0
	}
	if len(pkt.TransportHeader().Slice()) == 0 {
		return RuleDrop, 0
	}

	tuple := pkt.tuple
	if tuple == nil {
		return RuleDrop, 0
	}

	tuple.conn.performNAT(pkt, hook, r, portsOrIdents, address, dnat, changePort, changeAddress)
	return RuleAccept, 0
}
// Action implements Target.Action.
//
// It panics when the rule's network protocol differs from the packet's, and
// when called on any hook other than Postrouting or Input. Otherwise the
// verdict is whatever snatAction returns.
func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
	// The rule and the packet must agree on the network protocol.
	if want, got := st.NetworkProtocol, pkt.NetworkProtocolNumber; want != got {
		panic(fmt.Sprintf(
			"SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
			want, got))
	}

	switch hook {
	case Prerouting, Output, Forward:
		panic(fmt.Sprintf("%s not supported", hook))
	case Postrouting, Input:
		// Supported hooks; carry on with the rewrite below.
	default:
		panic(fmt.Sprintf("%s unrecognized", hook))
	}

	return snatAction(pkt, hook, r, st.Port, st.Addr, st.ChangePort, st.ChangeAddress)
}
// MasqueradeTarget modifies the source port/IP in the outgoing packets. Unlike
// SNATTarget it carries no address or port of its own: its Action obtains the
// source address from the AddressableEndpoint it is given.
//
// +stateify savable
type MasqueradeTarget struct {
// NetworkProtocol is the network protocol the target is used with. It
// is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
//
// It is valid only on the Postrouting hook and panics on any other hook, as
// well as when the rule's network protocol differs from the packet's. The new
// source address is an outgoing primary address acquired from addressEP for
// the packet's destination; if none is available the packet is dropped.
func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
	// The rule and the packet must agree on the network protocol.
	if want, got := mt.NetworkProtocol, pkt.NetworkProtocolNumber; want != got {
		panic(fmt.Sprintf(
			"MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
			want, got))
	}

	switch hook {
	case Prerouting, Input, Forward, Output:
		panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
	case Postrouting:
		// The only supported hook; carry on below.
	default:
		panic(fmt.Sprintf("%s unrecognized", hook))
	}

	// addressEP is expected to be set for the postrouting hook.
	remote := pkt.Network().DestinationAddress()
	srcEP := addressEP.AcquireOutgoingPrimaryAddress(remote, tcpip.Address{} /* srcHint */, false /* allowExpired */)
	if srcEP == nil {
		// No address exists that we can use as a source address.
		return RuleDrop, 0
	}

	// Copy the address out before releasing the reference on the endpoint.
	address := srcEP.AddressWithPrefix().Address
	srcEP.DecRef()

	return snatAction(pkt, hook, r, 0 /* port */, address, true /* changePort */, true /* changeAddress */)
}
// rewritePacket rewrites the source or destination fields of a packet's
// transport and network headers in place.
//
// updateSRCFields selects which side is rewritten: the source port/ident and
// source address when true, the destination ones when false.
//
// For header.ChecksummableTransport headers the port is set to
// newPortOrIdent, with a checksum update only when fullChecksum is true; when
// updatePseudoHeader is true the transport checksum is also adjusted for the
// address change. For ICMPv4/ICMPv6 echo messages the ident is set to
// newPortOrIdent only in the direction that matches the message type (echo
// request when rewriting the source, echo reply when rewriting the
// destination); ICMPv6 additionally always updates its pseudo-header
// checksum. Finally the network header's address is set to newAddr, with a
// checksum update if the header supports one.
//
// It panics on any other transport header type and on non-echo ICMP types.
func rewritePacket(n header.Network, t header.Transport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPortOrIdent uint16, newAddr tcpip.Address) {
	switch t := t.(type) {
	case header.ChecksummableTransport:
		if updateSRCFields {
			if fullChecksum {
				t.SetSourcePortWithChecksumUpdate(newPortOrIdent)
			} else {
				t.SetSourcePort(newPortOrIdent)
			}
		} else {
			if fullChecksum {
				t.SetDestinationPortWithChecksumUpdate(newPortOrIdent)
			} else {
				t.SetDestinationPort(newPortOrIdent)
			}
		}

		if updatePseudoHeader {
			// The old address must be read before the network header is
			// rewritten at the end of this function.
			var oldAddr tcpip.Address
			if updateSRCFields {
				oldAddr = n.SourceAddress()
			} else {
				oldAddr = n.DestinationAddress()
			}

			t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum)
		}
	case header.ICMPv4:
		switch icmpType := t.Type(); icmpType {
		case header.ICMPv4Echo:
			if updateSRCFields {
				t.SetIdentWithChecksumUpdate(newPortOrIdent)
			}
		case header.ICMPv4EchoReply:
			if !updateSRCFields {
				t.SetIdentWithChecksumUpdate(newPortOrIdent)
			}
		default:
			panic(fmt.Sprintf("unexpected ICMPv4 type = %d", icmpType))
		}
	case header.ICMPv6:
		switch icmpType := t.Type(); icmpType {
		case header.ICMPv6EchoRequest:
			if updateSRCFields {
				t.SetIdentWithChecksumUpdate(newPortOrIdent)
			}
		case header.ICMPv6EchoReply:
			if !updateSRCFields {
				t.SetIdentWithChecksumUpdate(newPortOrIdent)
			}
		default:
			// This is the ICMPv6 branch, so report the type as ICMPv6.
			panic(fmt.Sprintf("unexpected ICMPv6 type = %d", icmpType))
		}

		// The old address must be read before the network header is
		// rewritten below.
		var oldAddr tcpip.Address
		if updateSRCFields {
			oldAddr = n.SourceAddress()
		} else {
			oldAddr = n.DestinationAddress()
		}

		t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr)
	default:
		panic(fmt.Sprintf("unhandled transport = %#v", t))
	}

	// Rewrite the network-layer address, keeping the header checksum in sync
	// when the network header has one.
	if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok {
		if updateSRCFields {
			checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr)
		} else {
			checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr)
		}
	} else if updateSRCFields {
		n.SetSourceAddress(newAddr)
	} else {
		n.SetDestinationAddress(newAddr)
	}
}

View File

@@ -0,0 +1,387 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"strings"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// A Hook specifies one of the hooks built into the network stack.
//
//	Userspace app          Userspace app
//	      ^                      |
//	      |                      v
//	   [Input]               [Output]
//	      ^                      |
//	      |                      v
//	      |                   routing
//	      |                      |
//	      |                      v
//	----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
//
// Hook values start at zero and are consecutive, so they can index arrays of
// length NumHooks (such as Table.BuiltinChains and Table.Underflows).
type Hook uint
const (
// Prerouting happens before a packet is routed to applications or to
// be forwarded.
Prerouting Hook = iota
// Input happens before a packet reaches an application.
Input
// Forward happens once it's decided that a packet should be forwarded
// to another host.
Forward
// Output happens after a packet is written by an application to be
// sent out.
Output
// Postrouting happens just before a packet goes out on the wire.
Postrouting
// NumHooks is the total number of hooks. It is not itself a valid hook.
NumHooks
)
// A RuleVerdict is what a rule decides should be done with a packet. The zero
// value is RuleAccept.
type RuleVerdict int
const (
// RuleAccept indicates the packet should continue through netstack.
RuleAccept RuleVerdict = iota
// RuleDrop indicates the packet should be dropped.
RuleDrop
// RuleJump indicates the packet should jump to another chain. The rule
// index to jump to is the second value returned by Target.Action.
RuleJump
// RuleReturn indicates the packet should return to the previous chain.
RuleReturn
)
// IPTables holds all the tables for a netstack.
//
// +stateify savable
type IPTables struct {
// connections holds the ConnTrack state used by this IPTables.
// NOTE(review): presumably the connection-tracking table consulted for NAT;
// confirm against conntrack.go.
connections ConnTrack
// reaper is a timer. NOTE(review): presumably it periodically cleans up
// stale tracked connections - confirm where it is scheduled.
reaper tcpip.Timer
// mu protects the fields annotated with +checklocks:mu below. It is not
// saved.
mu ipTablesRWMutex `state:"nosave"`
// v4Tables maps tableIDs to IPv4 tables. It holds builtin tables only,
// not user tables.
//
// mu protects the array of tables, but not the tables themselves.
// +checklocks:mu
v4Tables [NumTables]Table
// v6Tables maps tableIDs to IPv6 tables. It holds builtin tables only,
// not user tables.
//
// mu protects the array of tables, but not the tables themselves.
// +checklocks:mu
v6Tables [NumTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
//
// +checklocks:mu
modified bool
}
// Modified reports whether iptables has been modified at least once. The
// answer may be stale by the time the caller sees it, so it is inherently
// racy and intended for use only in tests.
func (it *IPTables) Modified() bool {
	it.mu.Lock()
	wasModified := it.modified
	it.mu.Unlock()

	return wasModified
}
// VisitTargets walks every rule of every IPv4 table and then every IPv6
// table, replacing each rule's target with transform(target). The tables are
// locked for the duration of the walk.
func (it *IPTables) VisitTargets(transform func(Target) Target) {
	it.mu.Lock()
	defer it.mu.Unlock()

	// rewriteAll replaces the target of each rule in place.
	rewriteAll := func(rules []Rule) {
		for i := range rules {
			rules[i].Target = transform(rules[i].Target)
		}
	}

	for tid := range it.v4Tables {
		rewriteAll(it.v4Tables[tid].Rules)
	}
	for tid := range it.v6Tables {
		rewriteAll(it.v6Tables[tid].Rules)
	}
}
// A Table defines a set of chains and hooks into the network stack.
//
// It is a list of Rules, entry points (BuiltinChains), and error handlers
// (Underflows). As packets traverse netstack, they hit hooks. When a packet
// hits a hook, iptables compares it to Rules starting from that hook's entry
// point. So if a packet hits the Input hook, we look up the corresponding
// entry point in BuiltinChains and jump to that point.
//
// If the Rule doesn't match the packet, iptables continues to the next Rule.
// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept
// or RuleDrop) that causes the packet to stop traversing iptables. It can also
// jump to other rules or perform custom actions based on Rule.Target.
//
// Underflow Rules are invoked when a chain returns without reaching a verdict.
//
// +stateify savable
type Table struct {
// Rules holds the rules that make up the table.
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
// It is indexed by Hook; an entry equal to HookUnset means the table has
// no chain for that hook (see ValidHooks).
BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
// It is indexed by Hook.
Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table: bit h
// is set when BuiltinChains[h] holds something other than HookUnset.
func (table *Table) ValidHooks() uint32 {
	var bitmap uint32
	for hook := range table.BuiltinChains {
		if table.BuiltinChains[hook] == HookUnset {
			continue
		}
		bitmap |= 1 << hook
	}
	return bitmap
}
// A Rule is a packet processing rule. It consists of two pieces. First it
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
// applies to any packet. Second, it contains a target, the action to invoke
// on packets that the rule applies to.
//
// +stateify savable
type Rule struct {
// Filter holds basic IP filtering fields common to every rule.
Filter IPHeaderFilter
// Matchers is the list of matchers for this rule.
Matchers []Matcher
// Target is the action to invoke if all the matchers match the packet.
Target Target
}
// IPHeaderFilter performs basic IP header matching common to every rule.
//
// The Src address also determines which network protocol the filter applies
// to (see NetworkProtocol), so filters should be created with EmptyFilter4 or
// EmptyFilter6 rather than as a zero value.
//
// +stateify savable
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
// CheckProtocol determines whether the Protocol field should be
// checked during matching.
CheckProtocol bool
// Dst matches the destination IP address.
Dst tcpip.Address
// DstMask masks bits of the destination IP address when comparing with
// Dst.
DstMask tcpip.Address
// DstInvert inverts the meaning of the destination IP check, i.e. when
// true the filter will match packets that fail the destination
// comparison.
DstInvert bool
// Src matches the source IP address.
Src tcpip.Address
// SrcMask masks bits of the source IP address when comparing with Src.
SrcMask tcpip.Address
// SrcInvert inverts the meaning of the source IP check, i.e. when true the
// filter will match packets that fail the source comparison.
SrcInvert bool
// InputInterface matches the name of the incoming interface for the packet.
// An empty name matches any interface, and a name ending in '+' matches any
// interface whose name starts with the part before the '+'.
InputInterface string
// InputInterfaceMask masks the characters of the interface name when
// comparing with InputInterface.
//
// NOTE(review): match in this file does not consult this field.
InputInterfaceMask string
// InputInterfaceInvert inverts the meaning of incoming interface check,
// i.e. when true the filter will match packets that fail the incoming
// interface comparison.
InputInterfaceInvert bool
// OutputInterface matches the name of the outgoing interface for the packet.
// An empty name matches any interface, and a name ending in '+' matches any
// interface whose name starts with the part before the '+'.
OutputInterface string
// OutputInterfaceMask masks the characters of the interface name when
// comparing with OutputInterface.
//
// NOTE(review): match in this file does not consult this field.
OutputInterfaceMask string
// OutputInterfaceInvert inverts the meaning of outgoing interface check,
// i.e. when true the filter will match packets that fail the outgoing
// interface comparison.
OutputInterfaceInvert bool
}
// EmptyFilter4 returns an initialized IPv4 header filter whose addresses and
// masks are all the zero IPv4 address.
func EmptyFilter4() IPHeaderFilter {
	var filter IPHeaderFilter
	filter.Dst = tcpip.AddrFrom4([4]byte{})
	filter.DstMask = tcpip.AddrFrom4([4]byte{})
	filter.Src = tcpip.AddrFrom4([4]byte{})
	filter.SrcMask = tcpip.AddrFrom4([4]byte{})
	return filter
}
// EmptyFilter6 returns an initialized IPv6 header filter whose addresses and
// masks are all the zero IPv6 address.
func EmptyFilter6() IPHeaderFilter {
	var filter IPHeaderFilter
	filter.Dst = tcpip.AddrFrom16([16]byte{})
	filter.DstMask = tcpip.AddrFrom16([16]byte{})
	filter.Src = tcpip.AddrFrom16([16]byte{})
	filter.SrcMask = tcpip.AddrFrom16([16]byte{})
	return filter
}
// match returns whether pkt matches the filter.
//
// The transport protocol (when CheckProtocol is set) and both addresses are
// compared first. The interface names are then checked according to hook:
// the input interface for Prerouting and Input, the output interface for
// Output, both for Forward, and neither for Postrouting. It panics on an
// unknown network protocol or hook.
//
// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
// or IPv6 header length.
func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
	// Pull the fields to compare out of the network header.
	var (
		transProto       tcpip.TransportProtocolNumber
		srcAddr, dstAddr tcpip.Address
	)
	switch proto := pkt.NetworkProtocolNumber; proto {
	case header.IPv4ProtocolNumber:
		ip := header.IPv4(pkt.NetworkHeader().Slice())
		transProto = ip.TransportProtocol()
		dstAddr = ip.DestinationAddress()
		srcAddr = ip.SourceAddress()
	case header.IPv6ProtocolNumber:
		ip := header.IPv6(pkt.NetworkHeader().Slice())
		transProto = ip.TransportProtocol()
		dstAddr = ip.DestinationAddress()
		srcAddr = ip.SourceAddress()
	default:
		panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto))
	}

	// Transport protocol.
	if fl.CheckProtocol && transProto != fl.Protocol {
		return false
	}

	// Addresses: destination first, then source.
	if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) {
		return false
	}
	if !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) {
		return false
	}

	// Interfaces, depending on which hook the packet is at.
	switch hook {
	case Prerouting, Input:
		return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
	case Output:
		return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
	case Forward:
		return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) &&
			matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
	case Postrouting:
		return true
	default:
		panic(fmt.Sprintf("unknown hook: %d", hook))
	}
}
// matchIfName reports whether the NIC name nicName satisfies the filter's
// interface name ifName.
//
// An empty ifName matches every interface, regardless of invert. An ifName
// ending in '+' matches every NIC whose name starts with the part before the
// '+'; any other ifName must equal nicName exactly. When invert is true the
// result of a non-empty comparison is negated.
func matchIfName(nicName string, ifName string, invert bool) bool {
	if ifName == "" {
		// If the interface name is omitted in the filter, any interface will match.
		return true
	}

	var matched bool
	if prefix := strings.TrimSuffix(ifName, "+"); len(prefix) < len(ifName) {
		// Wildcard form: compare only the part before the trailing '+'.
		matched = strings.HasPrefix(nicName, prefix)
	} else {
		matched = ifName == nicName
	}

	if invert {
		return !matched
	}
	return matched
}
// NetworkProtocol returns the protocol (IPv4 or IPv6) to which the header
// filter applies, determined by the bit length of the filter's Src address.
// It panics if Src is neither an IPv4 nor an IPv6 address.
func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber {
	bits := fl.Src.BitLen()
	if bits == header.IPv4AddressSizeBits {
		return header.IPv4ProtocolNumber
	}
	if bits == header.IPv6AddressSizeBits {
		return header.IPv6ProtocolNumber
	}
	panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src))
}
// filterAddress returns whether addr matches the filter: addr is masked
// byte-by-byte with mask and compared against filterAddr, and the outcome is
// negated when invert is true. addr and mask must be at least as long as
// filterAddr.
func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
	addrBytes := addr.AsSlice()
	maskBytes := mask.AsSlice()
	filterBytes := filterAddr.AsSlice()

	for i, want := range filterBytes {
		if addrBytes[i]&maskBytes[i] != want {
			// Mismatch: the filter matches only if it is inverted.
			return invert
		}
	}
	// Every byte matched: the filter matches unless it is inverted.
	return !invert
}
// A Matcher is the interface for matching packets. A Rule holds zero or more
// Matchers in addition to its IPHeaderFilter.
type Matcher interface {
// Match returns whether the packet matches and whether the packet
// should be "hotdropped", i.e. dropped immediately. This is usually
// used for suspicious packets.
//
// inputInterfaceName and outputInterfaceName name the interfaces the
// packet arrived on and is leaving through.
//
// Precondition: packet.NetworkHeader is set.
Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the returned verdict is
// RuleJump, it also returns the index of the rule to jump to; for other
// verdicts the targets in this package return 0 as the index.
Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// multiPortEndpointRWMutex is sync.RWMutex with the correctness validator.
// It wraps a sync.RWMutex and reports lock and unlock operations to the
// locking package, except in the *Bypass and DowngradeLock methods.
type multiPortEndpointRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
//
// It is passed to locking.NewMutexClass in this file's init function.
var multiPortEndpointlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type multiPortEndpointlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
//
// The acquisition is reported to the locking package (with index -1) before
// the underlying mutex is taken.
// +checklocksignore
func (m *multiPortEndpointRWMutex) Lock() {
locking.AddGLock(multiPortEndpointprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
//
// i is reported to the locking package in place of the -1 that Lock uses.
// +checklocksignore
func (m *multiPortEndpointRWMutex) NestedLock(i multiPortEndpointlockNameIndex) {
locking.AddGLock(multiPortEndpointprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
//
// The release is reported to the locking package after the underlying mutex
// has been unlocked.
// +checklocksignore
func (m *multiPortEndpointRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(multiPortEndpointprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
//
// i should be the index that was passed to the matching NestedLock call.
// +checklocksignore
func (m *multiPortEndpointRWMutex) NestedUnlock(i multiPortEndpointlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(multiPortEndpointprefixIndex, int(i))
}
// RLock locks m for reading.
//
// It is reported to the locking package in the same way as Lock.
// +checklocksignore
func (m *multiPortEndpointRWMutex) RLock() {
locking.AddGLock(multiPortEndpointprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *multiPortEndpointRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(multiPortEndpointprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
//
// It must be paired with RUnlockBypass, not RUnlock.
// +checklocksignore
func (m *multiPortEndpointRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *multiPortEndpointRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
//
// No call into the locking package is made here.
// +checklocksignore
func (m *multiPortEndpointRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
// multiPortEndpointprefixIndex is the mutex class that every
// multiPortEndpointRWMutex operation is reported under. It is created in
// init.
var multiPortEndpointprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func multiPortEndpointinitLockNames() {}
// init runs the lock-name initializer and then registers the mutex class for
// multiPortEndpointRWMutex with the locking package.
func init() {
multiPortEndpointinitLockNames()
multiPortEndpointprefixIndex = locking.NewMutexClass(reflect.TypeOf(multiPortEndpointRWMutex{}), multiPortEndpointlockNames)
}

View File

@@ -0,0 +1,314 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
// NeighborCacheSize is the size of the neighborCache. Exceeding this size will
// result in the least recently used entry being evicted.
//
// Only dynamic entries count towards this limit; static entries are tracked
// separately and are not bounded.
const NeighborCacheSize = 512 // max entries per interface
// NeighborStats holds metrics for the neighbor table.
type NeighborStats struct {
// UnreachableEntryLookups counts the number of lookups performed on an
// entry in Unreachable state.
UnreachableEntryLookups *tcpip.StatCounter
}
// dynamicCacheEntry is the bookkeeping for the dynamic (non-static) entries
// of a neighborCache: their LRU order and how many of them there are.
//
// +stateify savable
type dynamicCacheEntry struct {
// lru orders the dynamic entries by recency of use: entries are pushed to
// the front when created or looked up, and evicted from the back.
lru neighborEntryList
// count tracks the amount of dynamic entries in the cache. This is
// needed since static entries do not count towards the LRU cache
// eviction strategy.
count uint16
}
// neighborCacheMu bundles the neighbor cache's mutex with the state the
// neighborCache methods access while holding it.
//
// +stateify savable
type neighborCacheMu struct {
neighborCacheRWMutex `state:"nosave"`
// cache maps a neighbor's IP address to its entry. It holds both dynamic
// and static entries.
cache map[tcpip.Address]*neighborEntry
// dynamic tracks the LRU order and number of the dynamic entries in cache.
dynamic dynamicCacheEntry
}
// neighborCache maps IP addresses to link addresses. It uses the Least
// Recently Used (LRU) eviction strategy to implement a bounded cache for
// dynamically acquired entries. It contains the state machine and configuration
// for running Neighbor Unreachability Detection (NUD).
//
// There are two types of entries in the neighbor cache:
//  1. Dynamic entries are discovered automatically by neighbor discovery
//     protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm
//     reachability with the device once the entry's state becomes Stale.
//  2. Static entries are explicitly added by a user and have no expiration.
//     Their state is always Static. The amount of static entries stored in the
//     cache is unbounded.
//
// +stateify savable
type neighborCache struct {
// nic is the interface this cache belongs to.
nic *nic
// state holds the NUD configuration returned by config and changed by
// setConfig; it is also handed to every entry created by this cache.
state *NUDState
// linkRes is the link address resolver supplied to init.
linkRes LinkAddressResolver
// mu holds the lock together with the entries it protects.
mu neighborCacheMu
}
// getOrCreateEntry returns the cache entry for remoteAddr, creating one if
// needed. The returned entry is always fresh in the cache: it is reachable
// through the map and, if dynamic, sits at the front of the LRU list.
//
// An existing entry is returned as is. Otherwise a new dynamic entry is
// created; when the cache already holds NeighborCacheSize dynamic entries,
// the least recently used one is removed first to make room.
func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntry {
	n.mu.Lock()
	defer n.mu.Unlock()

	if existing, found := n.mu.cache[remoteAddr]; found {
		existing.mu.RLock()
		// Static entries are not part of the LRU list.
		if existing.mu.neigh.State != Static {
			n.mu.dynamic.lru.Remove(existing)
			n.mu.dynamic.lru.PushFront(existing)
		}
		existing.mu.RUnlock()
		return existing
	}

	// The entry that needs to be created must be dynamic since all static
	// entries are directly added to the cache via addStaticEntry.
	created := newNeighborEntry(n, remoteAddr, n.state)

	if n.mu.dynamic.count == NeighborCacheSize {
		// The cache is full: evict the least recently used dynamic entry.
		victim := n.mu.dynamic.lru.Back()
		victim.mu.Lock()
		delete(n.mu.cache, victim.mu.neigh.Addr)
		n.mu.dynamic.lru.Remove(victim)
		n.mu.dynamic.count--
		victim.removeLocked()
		victim.mu.Unlock()
	}

	n.mu.cache[remoteAddr] = created
	n.mu.dynamic.lru.PushFront(created)
	n.mu.dynamic.count++
	return created
}
// entry looks up neighbor information matching the remote address, and returns
// it if readily available.
//
// Returns ErrWouldBlock if the link address is not readily available, along
// with a notification channel for the caller to block on. Triggers address
// resolution asynchronously.
//
// If onResolve is provided, it will be called either immediately, if resolution
// is not required, or when address resolution is complete, with the resolved
// link address and whether resolution succeeded. After any callbacks have been
// called, the returned notification channel is closed.
//
// NB: if a callback is provided, it should not call into the neighbor cache.
//
// If specified, the local address must be an address local to the interface the
// neighbor cache belongs to. The local address is the source address of a
// packet prompting NUD/link address resolution.
func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (*neighborEntry, <-chan struct{}, tcpip.Error) {
	e := n.getOrCreateEntry(remoteAddr)
	e.mu.Lock()
	defer e.mu.Unlock()

	switch state := e.mu.neigh.State; state {
	case Stale, Reachable, Static, Delay, Probe:
		// A stale entry is still usable, but the queued packet must be
		// signalled to the entry first.
		if state == Stale {
			e.handlePacketQueuedLocked(localAddr)
		}

		// As per RFC 4861 section 7.3.3:
		//  "Neighbor Unreachability Detection operates in parallel with the sending
		//   of packets to a neighbor. While reasserting a neighbor's reachability,
		//   a node continues sending packets to that neighbor using the cached
		//   link-layer address."
		if onResolve != nil {
			onResolve(LinkResolutionResult{LinkAddress: e.mu.neigh.LinkAddr, Err: nil})
		}
		return e, nil, nil

	case Unknown, Incomplete, Unreachable:
		// Defer the callback until resolution finishes.
		if onResolve != nil {
			e.mu.onResolve = append(e.mu.onResolve, onResolve)
		}
		// Address resolution needs to be initiated if no notification
		// channel exists yet.
		if e.mu.done == nil {
			e.mu.done = make(chan struct{})
		}
		e.handlePacketQueuedLocked(localAddr)
		return e, e.mu.done, &tcpip.ErrWouldBlock{}

	default:
		panic(fmt.Sprintf("Invalid cache entry state: %s", state))
	}
}
// entries returns a snapshot of every entry in the neighbor cache, dynamic
// and static alike. The order of the result follows map iteration and is
// therefore unspecified.
func (n *neighborCache) entries() []NeighborEntry {
	n.mu.RLock()
	defer n.mu.RUnlock()

	snapshot := make([]NeighborEntry, 0, len(n.mu.cache))
	for _, e := range n.mu.cache {
		// Copy the neighbor information while holding the entry's read lock.
		e.mu.RLock()
		neigh := e.mu.neigh
		e.mu.RUnlock()

		snapshot = append(snapshot, neigh)
	}
	return snapshot
}
// addStaticEntry adds a static entry to the neighbor cache, mapping an IP
// address to a link address.
//
// What happens depends on what the cache already holds for addr:
//   - nothing: a new static entry is added;
//   - a dynamic entry: it is removed and replaced with a new static entry;
//   - a static entry with a different link address: it is updated in place
//     and a change event is dispatched;
//   - a static entry with the same link address: nothing happens.
func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
	n.mu.Lock()
	defer n.mu.Unlock()

	if old, found := n.mu.cache[addr]; found {
		old.mu.Lock()
		switch {
		case old.mu.neigh.State != Static:
			// Dynamic entry found with the same address: take it out of the
			// LRU bookkeeping; it is removed and replaced below.
			n.mu.dynamic.lru.Remove(old)
			n.mu.dynamic.count--
		case old.mu.neigh.LinkAddr == linkAddr:
			// Static entry found with the same address and link address.
			old.mu.Unlock()
			return
		default:
			// Static entry found with the same address but different link address.
			old.mu.neigh.LinkAddr = linkAddr
			old.dispatchChangeEventLocked()
			old.mu.Unlock()
			return
		}
		old.removeLocked()
		old.mu.Unlock()
	}

	static := newStaticNeighborEntry(n, addr, linkAddr, n.state)
	n.mu.cache[addr] = static

	static.mu.Lock()
	defer static.mu.Unlock()
	static.dispatchAddEventLocked()
}
// removeEntry removes the entry for addr, dynamic or static, from the
// neighbor cache. It reports whether an entry was found and deleted.
func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
	n.mu.Lock()
	defer n.mu.Unlock()

	target, found := n.mu.cache[addr]
	if !found {
		return false
	}

	target.mu.Lock()
	defer target.mu.Unlock()

	// Only dynamic entries take part in the LRU bookkeeping.
	if isDynamic := target.mu.neigh.State != Static; isDynamic {
		n.mu.dynamic.lru.Remove(target)
		n.mu.dynamic.count--
	}

	target.removeLocked()
	delete(n.mu.cache, target.mu.neigh.Addr)
	return true
}
// clear removes all dynamic and static entries from the neighbor cache and
// resets the LRU bookkeeping.
func (n *neighborCache) clear() {
	n.mu.Lock()
	defer n.mu.Unlock()

	// Give every entry the chance to run its removal logic first.
	for _, e := range n.mu.cache {
		e.mu.Lock()
		e.removeLocked()
		e.mu.Unlock()
	}

	// Then drop all references: the map contents, the LRU list and the
	// dynamic entry count.
	clear(n.mu.cache)
	n.mu.dynamic.count = 0
	n.mu.dynamic.lru = neighborEntryList{}
}
// config returns the NUD configuration currently held by the cache's
// NUDState.
func (n *neighborCache) config() NUDConfigurations {
return n.state.Config()
}
// setConfig changes the NUD configuration.
//
// If config contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
func (n *neighborCache) setConfig(config NUDConfigurations) {
// Sanitize the local copy before it is stored in the NUDState.
config.resetInvalidFields()
n.state.SetConfig(config)
}
// handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3.
// The entry for remoteAddr is fetched, or created if absent, and then told
// about the probe's link address.
//
// Validation of the probe is expected to be handled by the caller.
func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
	probed := n.getOrCreateEntry(remoteAddr)

	probed.mu.Lock()
	defer probed.mu.Unlock()
	probed.handleProbeLocked(remoteLinkAddr)
}
// handleConfirmation processes a neighbor confirmation as defined by
// RFC 4861 section 7.2.5.
//
// The caller is expected to have validated the confirmation already.
func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
	// The cache lock is only needed for the lookup; it is released before the
	// entry's own lock is taken.
	n.mu.RLock()
	target, known := n.mu.cache[addr]
	n.mu.RUnlock()

	if !known {
		// The confirmation SHOULD be silently discarded if the recipient did not
		// initiate any communication with the target. This is indicated if there is
		// no matching entry for the remote address.
		n.nic.stats.neighbor.droppedConfirmationForNoninitiatedNeighbor.Increment()
		return
	}

	target.mu.Lock()
	target.handleConfirmationLocked(linkAddr, flags)
	target.mu.Unlock()
}
// init overwrites n with a fresh cache bound to nic that resolves link
// addresses through r, then allocates an empty entry map pre-sized to
// NeighborCacheSize.
func (n *neighborCache) init(nic *nic, r LinkAddressResolver) {
	nud := NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.insecureRNG)
	*n = neighborCache{
		nic:     nic,
		state:   nud,
		linkRes: r,
	}

	entries := make(map[tcpip.Address]*neighborEntry, NeighborCacheSize)
	n.mu.Lock()
	n.mu.cache = entries
	n.mu.Unlock()
}

View File

@@ -0,0 +1,96 @@
package stack
import (
"reflect"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/sync/locking"
)
// neighborCacheRWMutex is sync.RWMutex with the correctness validator.
//
// Its locking methods report acquisitions and releases to the locking package
// before/after operating on the wrapped mutex; the *Bypass methods skip that
// reporting.
type neighborCacheRWMutex struct {
mu sync.RWMutex
}
// neighborCachelockNames is a list of user-friendly lock names.
// Populated in init.
var neighborCachelockNames []string
// neighborCachelockNameIndex is used as an index passed to NestedLock and
// NestedUnlock, referring to an index within neighborCachelockNames.
// Values are specified using the "consts" field of go_template_instance.
type neighborCachelockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m for writing.
//
// The acquisition is reported to the lock validator (locking.AddGLock with
// index -1) before the wrapped mutex is taken.
// +checklocksignore
func (m *neighborCacheRWMutex) Lock() {
locking.AddGLock(neighborCacheprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
//
// i is forwarded to the lock validator as the lock-name index for this
// acquisition, in place of the -1 that Lock passes.
// +checklocksignore
func (m *neighborCacheRWMutex) NestedLock(i neighborCachelockNameIndex) {
locking.AddGLock(neighborCacheprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
//
// The wrapped mutex is released first; the release is then reported to the
// lock validator (locking.DelGLock with index -1), mirroring Lock.
// +checklocksignore
func (m *neighborCacheRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(neighborCacheprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
//
// i is forwarded to the lock validator after the wrapped mutex is released;
// presumably it should match the index given to NestedLock (the validator's
// contract is not visible here).
// +checklocksignore
func (m *neighborCacheRWMutex) NestedUnlock(i neighborCachelockNameIndex) {
m.mu.Unlock()
locking.DelGLock(neighborCacheprefixIndex, int(i))
}
// RLock locks m for reading.
//
// The acquisition is reported to the lock validator before the wrapped mutex
// is read-locked, using the same class and index (-1) as Lock.
// +checklocksignore
func (m *neighborCacheRWMutex) RLock() {
locking.AddGLock(neighborCacheprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
//
// The wrapped mutex is read-unlocked first, then the release is reported to
// the lock validator.
// +checklocksignore
func (m *neighborCacheRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(neighborCacheprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
//
// It must be paired with RUnlockBypass rather than RUnlock, since RUnlock
// reports a release that this method never registered.
// +checklocksignore
func (m *neighborCacheRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call, likewise without executing
// the validator.
// +checklocksignore
func (m *neighborCacheRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks m for writing and locks it for reading.
//
// The lock validator is not called here; the call is forwarded directly to
// the wrapped mutex.
// +checklocksignore
func (m *neighborCacheRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
// neighborCacheprefixIndex is the validator mutex class for
// neighborCacheRWMutex. It is assigned in init and passed to every
// locking.AddGLock/DelGLock call made by the methods above.
var neighborCacheprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func neighborCacheinitLockNames() {}
// init registers the neighborCacheRWMutex lock class with the validator.
// neighborCacheinitLockNames runs first so that neighborCachelockNames is in
// its final state when it is handed to locking.NewMutexClass.
func init() {
neighborCacheinitLockNames()
neighborCacheprefixIndex = locking.NewMutexClass(reflect.TypeOf(neighborCacheRWMutex{}), neighborCachelockNames)
}

View File

@@ -0,0 +1,646 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
// immediateDuration is a duration of zero for scheduling work that needs to
// be done immediately but asynchronously to avoid deadlock. It is passed to
// the clock's AfterFunc when the first probe of a resolution attempt is
// scheduled.
immediateDuration time.Duration = 0
)
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
// Addr is the neighbor's network-layer address; entries are keyed by it in
// the neighbor cache.
Addr tcpip.Address
// LinkAddr is the neighbor's link-layer address as last learned or
// configured. It is left at its zero value until then.
LinkAddr tcpip.LinkAddress
// State is the entry's current NUD state.
State NeighborState
// UpdatedAt is the monotonic time of the entry's most recent state change
// (or of its creation, for static entries).
UpdatedAt tcpip.MonotonicTime
}
// NeighborState defines the state of a NeighborEntry within the Neighbor
// Unreachability Detection state machine, as per RFC 4861 section 7.3.2 and
// RFC 7048.
//
// The zero value is Unknown.
type NeighborState uint8
const (
// Unknown means reachability has not been verified yet. This is the initial
// state of entries that have been created automatically by the Neighbor
// Unreachability Detection state machine.
Unknown NeighborState = iota
// Incomplete means that there is an outstanding request to resolve the
// address.
//
// This state is entered only through handlePacketQueuedLocked;
// setStateLocked panics if asked to transition to it.
Incomplete
// Reachable means the path to the neighbor is functioning properly for both
// receive and transmit paths.
Reachable
// Stale means reachability to the neighbor is unknown, but packets are still
// able to be transmitted to the possibly stale link address.
Stale
// Delay means reachability to the neighbor is unknown and pending
// confirmation from an upper-level protocol like TCP, but packets are still
// able to be transmitted to the possibly stale link address.
Delay
// Probe means a reachability confirmation is actively being sought by
// periodically retransmitting reachability probes until a reachability
// confirmation is received, or until the maximum number of probes has been
// sent.
Probe
// Static describes entries that have been explicitly added by the user. They
// do not expire and are not deleted until explicitly removed.
Static
// Unreachable means reachability confirmation failed; the maximum number of
// reachability probes has been sent and no replies have been received.
//
// TODO(gvisor.dev/issue/5472): Add the following sentence when we implement
// RFC 7048: "Packets continue to be sent to the neighbor while
// re-attempting to resolve the address."
Unreachable
)
// timer pairs a scheduled clock timer with a cancellation flag.
//
// The zero value means no action is scheduled; cancelTimerLocked resets the
// entry's timer to it.
type timer struct {
// done indicates to the timer that the timer was stopped. The timer
// callbacks check it after taking the entry's lock and return early when it
// is set, so a callback that was already running when the timer got
// cancelled does not act on an outdated state.
done *bool
// timer is the underlying scheduled timer; nil when nothing is scheduled.
timer tcpip.Timer
}
// neighborEntry implements a neighbor entry's individual node behavior, as per
// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in
// parallel with the sending of packets to a neighbor, necessitating the
// entry's lock to be acquired for all operations.
type neighborEntry struct {
neighborEntryEntry
// cache is the neighbor cache that owns this entry; it gives access to the
// NIC, stack clock, NUD dispatcher and link address resolver.
cache *neighborCache
// nudState points to the Neighbor Unreachability Detection configuration.
nudState *NUDState
// mu guards every field declared inside it.
mu struct {
neighborEntryRWMutex
// neigh is the entry's externally visible state; it is what gets handed
// to the NUD dispatcher.
neigh NeighborEntry
// done is closed when address resolution is complete. It is nil iff the
// entry is incomplete and resolution is not yet in progress.
done chan struct{}
// onResolve is called with the result of address resolution. The list is
// emptied once the callbacks have been invoked.
onResolve []func(LinkResolutionResult)
// isRouter records the IsRouter flag of the last confirmation applied to
// this entry.
isRouter bool
// timer is the currently scheduled state-machine action, if any.
timer timer
}
}
// newNeighborEntry builds a neighbor cache entry for remoteAddr in the
// default Unknown state. The entry leaves Unknown once
// `handlePacketQueuedLocked` or `handleProbeLocked` is called on it.
func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry {
	entry := &neighborEntry{cache: cache, nudState: nudState}

	initial := NeighborEntry{
		Addr:  remoteAddr,
		State: Unknown,
	}

	entry.mu.Lock()
	entry.mu.neigh = initial
	entry.mu.Unlock()
	return entry
}
// newStaticNeighborEntry builds a neighbor cache entry that maps addr to
// linkAddr and starts out in the Static state, stamped with the stack clock's
// current monotonic time. Such an entry only leaves Static through a direct
// call to `setStateLocked`.
func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
	now := cache.nic.stack.clock.NowMonotonic()
	created := &neighborEntry{cache: cache, nudState: state}

	created.mu.Lock()
	created.mu.neigh = NeighborEntry{
		Addr:      addr,
		LinkAddr:  linkAddr,
		State:     Static,
		UpdatedAt: now,
	}
	created.mu.Unlock()
	return created
}
// notifyCompletionLocked notifies those waiting for address resolution, with
// the link address if resolution completed successfully.
//
// Every registered onResolve callback is invoked synchronously with the
// entry's current link address and err, and the callback list is cleared. If
// a done channel exists it is closed and cleared, and the packets queued on it
// are dequeued asynchronously.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) {
	// Snapshot the link address while e.mu is held. The dequeue below runs
	// later on another goroutine without the entry's lock, so it must not read
	// e.mu.neigh from there.
	linkAddr := e.mu.neigh.LinkAddr

	res := LinkResolutionResult{LinkAddress: linkAddr, Err: err}
	for _, callback := range e.mu.onResolve {
		callback(res)
	}
	e.mu.onResolve = nil

	if ch := e.mu.done; ch != nil {
		close(ch)
		e.mu.done = nil
		// Dequeue the pending packets asynchronously to not hold up the current
		// goroutine as writing packets may be a costly operation.
		//
		// At the time of writing, when writing packets, a neighbor's link address
		// is resolved (which ends up obtaining the entry's lock) while holding the
		// link resolution queue's lock. Dequeuing packets asynchronously avoids a
		// lock ordering violation.
		//
		// NB: this is equivalent to spawning a goroutine directly using the go
		// keyword but allows tests that use manual clocks to deterministically
		// wait for this work to complete.
		e.cache.nic.stack.clock.AfterFunc(immediateDuration, func() {
			e.cache.nic.linkResQueue.dequeue(ch, linkAddr, err)
		})
	}
}
// dispatchAddEventLocked tells the stack's NUD dispatcher, when one is
// configured, that this entry has been added.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
	dispatcher := e.cache.nic.stack.nudDisp
	if dispatcher == nil {
		return
	}
	dispatcher.OnNeighborAdded(e.cache.nic.id, e.mu.neigh)
}
// dispatchChangeEventLocked tells the stack's NUD dispatcher, when one is
// configured, that this entry's state or link-layer address has changed.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
	dispatcher := e.cache.nic.stack.nudDisp
	if dispatcher == nil {
		return
	}
	dispatcher.OnNeighborChanged(e.cache.nic.id, e.mu.neigh)
}
// dispatchRemoveEventLocked tells the stack's NUD dispatcher, when one is
// configured, that this entry has been removed.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
	dispatcher := e.cache.nic.stack.nudDisp
	if dispatcher == nil {
		return
	}
	dispatcher.OnNeighborRemoved(e.cache.nic.id, e.mu.neigh)
}
// cancelTimerLocked cancels the currently scheduled action, if there is one;
// otherwise it does nothing. Entries in Unknown, Stale, or Static state do not
// have a scheduled action.
//
// Besides stopping the timer, the shared done flag is raised so that a
// callback that has already started sees the cancellation once it obtains the
// entry's lock.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) cancelTimerLocked() {
	pending := e.mu.timer
	if pending.timer == nil {
		return
	}

	pending.timer.Stop()
	*pending.done = true
	e.mu.timer = timer{}
}
// removeLocked prepares the entry for removal: it reports the removal to the
// NUD dispatcher, resets the entry to Unknown, cancels any scheduled action
// and completes any pending resolution with tcpip.ErrAborted.
//
// It does not touch the owning cache's map or LRU list; that is left to the
// caller.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) removeLocked() {
// The remove event is dispatched before the state is reset, so the
// dispatcher sees the entry's last state together with the removal time.
e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.NowMonotonic()
e.dispatchRemoveEventLocked()
// Set state to unknown to invalidate this entry if it's cached in a Route.
e.setStateLocked(Unknown)
// setStateLocked has already cancelled the timer at this point, so this
// call finds nothing scheduled.
e.cancelTimerLocked()
// TODO(https://gvisor.dev/issues/5583): test the case where this function is
// called during resolution; that can happen in at least these scenarios:
//
// - manual address removal during resolution
//
// - neighbor cache eviction during resolution
e.notifyCompletionLocked(&tcpip.ErrAborted{})
}
// setStateLocked transitions the entry to the specified state immediately.
//
// Any previously scheduled action is cancelled, the new state and update time
// are recorded, and the action belonging to the new state is scheduled:
//
//   - Reachable: move to Stale after the NUD state's ReachableTime.
//   - Delay: move to Probe after DelayFirstProbeTime.
//   - Probe: send up to MaxUnicastProbes link address requests, one every
//     RetransmitTimer, then move to Unreachable with tcpip.ErrTimeout.
//   - Unreachable, Unknown, Stale, Static: nothing is scheduled.
//
// Transitioning to Incomplete through this function panics. No NUD dispatcher
// event is emitted here; callers dispatch the add/change event themselves.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) setStateLocked(next NeighborState) {
e.cancelTimerLocked()
prev := e.mu.neigh.State
e.mu.neigh.State = next
e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.NowMonotonic()
config := e.nudState.Config()
switch next {
case Incomplete:
panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.mu.neigh, prev))
case Reachable:
// Protected by e.mu.
done := false
e.mu.timer = timer{
done: &done,
timer: e.cache.nic.stack.Clock().AfterFunc(e.nudState.ReachableTime(), func() {
e.mu.Lock()
defer e.mu.Unlock()
if done {
// The timer was stopped because the entry changed state.
return
}
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
}),
}
case Delay:
// Protected by e.mu.
done := false
e.mu.timer = timer{
done: &done,
timer: e.cache.nic.stack.Clock().AfterFunc(config.DelayFirstProbeTime, func() {
e.mu.Lock()
defer e.mu.Unlock()
if done {
// The timer was stopped because the entry changed state.
return
}
e.setStateLocked(Probe)
e.dispatchChangeEventLocked()
}),
}
case Probe:
// Protected by e.mu.
done := false
// remaining counts the probes still allowed; it is only touched by the
// timer callback below.
remaining := config.MaxUnicastProbes
// The addresses are captured now, under the lock, because the probe is
// sent before the callback acquires e.mu.
addr := e.mu.neigh.Addr
linkAddr := e.mu.neigh.LinkAddr
// Send a probe in another goroutine to free this thread of execution
// for finishing the state transition. This is necessary to escape the
// currently held lock so we can send the probe message without holding
// a shared lock.
e.mu.timer = timer{
done: &done,
timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() {
// Once the probe budget is used up, the next firing reports a timeout
// instead of sending another request.
var err tcpip.Error = &tcpip.ErrTimeout{}
if remaining != 0 {
err = e.cache.linkRes.LinkAddressRequest(addr, tcpip.Address{} /* localAddr */, linkAddr)
}
e.mu.Lock()
defer e.mu.Unlock()
if done {
// The timer was stopped because the entry changed state.
return
}
if err != nil {
e.setStateLocked(Unreachable)
e.notifyCompletionLocked(err)
e.dispatchChangeEventLocked()
return
}
remaining--
// Re-arm this same timer for the next probe.
e.mu.timer.timer.Reset(config.RetransmitTimer)
}),
}
case Unreachable:
// Nothing is scheduled for Unreachable.
case Unknown, Stale, Static:
// Do nothing
default:
panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next))
}
}
// handlePacketQueuedLocked advances the state machine according to a packet
// being queued for outgoing transmission.
//
// From Unknown or Unreachable the entry becomes Incomplete and link address
// resolution starts: up to MaxMulticastProbes requests are sent, one every
// RetransmitTimer, after which the entry becomes Unreachable with
// tcpip.ErrTimeout. From Stale the entry moves to Delay. All other states are
// left untouched.
//
// localAddr is passed to the link address resolver as the local address of
// each request.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.mu.neigh.State {
case Unknown, Unreachable:
// The state is set directly here because setStateLocked refuses
// transitions to Incomplete.
prev := e.mu.neigh.State
e.mu.neigh.State = Incomplete
e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.NowMonotonic()
switch prev {
case Unknown:
e.dispatchAddEventLocked()
case Unreachable:
e.dispatchChangeEventLocked()
e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment()
}
config := e.nudState.Config()
// Protected by e.mu.
done := false
// remaining counts the probes still allowed; it is only touched by the
// timer callback below.
remaining := config.MaxMulticastProbes
addr := e.mu.neigh.Addr
// Send a probe in another goroutine to free this thread of execution
// for finishing the state transition. This is necessary to escape the
// currently held lock so we can send the probe message without holding
// a shared lock.
e.mu.timer = timer{
done: &done,
timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() {
// Once the probe budget is used up, the next firing reports a timeout
// instead of sending another request.
var err tcpip.Error = &tcpip.ErrTimeout{}
if remaining != 0 {
// As per RFC 4861 section 7.2.2:
//
// If the source address of the packet prompting the solicitation is
// the same as one of the addresses assigned to the outgoing interface,
// that address SHOULD be placed in the IP Source Address of the
// outgoing solicitation.
//
err = e.cache.linkRes.LinkAddressRequest(addr, localAddr, "" /* linkAddr */)
}
e.mu.Lock()
defer e.mu.Unlock()
if done {
// The timer was stopped because the entry changed state.
return
}
if err != nil {
e.setStateLocked(Unreachable)
e.notifyCompletionLocked(err)
e.dispatchChangeEventLocked()
return
}
remaining--
// Re-arm this same timer for the next probe.
e.mu.timer.timer.Reset(config.RetransmitTimer)
}),
}
case Stale:
e.setStateLocked(Delay)
e.dispatchChangeEventLocked()
case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or
// Neighbor Solicitation for ARP or NDP, respectively).
//
// remoteLinkAddr is the link address carried by the probe. Depending on the
// current state the entry adopts it and becomes Stale; Static entries are
// never modified.
//
// Follows the logic defined in RFC 4861 section 7.2.3.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// Probes MUST be silently discarded if the target address is tentative, does
// not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
// checks MUST be done by the NetworkEndpoint.
switch e.mu.neigh.State {
case Unknown:
// First time this neighbor is heard from: report it as added rather than
// changed.
e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.dispatchAddEventLocked()
case Incomplete:
// "If an entry already exists, and the cached link-layer address
// differs from the one in the received Source Link-Layer option, the
// cached address should be replaced by the received address, and the
// entry's reachability state MUST be set to STALE."
// - RFC 4861 section 7.2.3
e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
// The probe supplied the link address, so pending resolution completes
// successfully.
e.notifyCompletionLocked(nil)
e.dispatchChangeEventLocked()
case Reachable, Delay, Probe:
// Only a changed link address invalidates what is known about the
// neighbor.
if e.mu.neigh.LinkAddr != remoteLinkAddr {
e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
}
case Stale:
// Already Stale: record and report a changed link address without a
// state transition.
if e.mu.neigh.LinkAddr != remoteLinkAddr {
e.mu.neigh.LinkAddr = remoteLinkAddr
e.dispatchChangeEventLocked()
}
case Unreachable:
// TODO(gvisor.dev/issue/5472): Do not change the entry if the link
// address is the same, as per RFC 7048.
e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
case Static:
// Do nothing
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
// handleConfirmationLocked processes an incoming neighbor confirmation
// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively).
//
// linkAddr is the link address carried by the confirmation (possibly empty)
// and flags are its Solicited/Override/IsRouter flags. Entries in the Unknown,
// Unreachable and Static states ignore confirmations.
//
// Follows the state machine defined by RFC 4861 section 7.2.5.
//
// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
// should be deployed where preventing access to the broadcast segment might
// not be possible. SEND uses RSA key pairs to produce Cryptographically
// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
// claimed source of an NDP message is the owner of the claimed address.
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
	switch e.mu.neigh.State {
	case Incomplete:
		if len(linkAddr) == 0 {
			// "If the link layer has addresses and no Target Link-Layer Address
			// option is included, the receiving node SHOULD silently discard the
			// received advertisement." - RFC 4861 section 7.2.5
			e.cache.nic.stats.neighbor.droppedInvalidLinkAddressConfirmations.Increment()
			break
		}

		e.mu.neigh.LinkAddr = linkAddr
		if flags.Solicited {
			e.setStateLocked(Reachable)
		} else {
			e.setStateLocked(Stale)
		}
		e.dispatchChangeEventLocked()
		e.mu.isRouter = flags.IsRouter
		e.notifyCompletionLocked(nil)

		// "Note that the Override flag is ignored if the entry is in the
		// INCOMPLETE state." - RFC 4861 section 7.2.5

	case Reachable, Stale, Delay, Probe:
		// A confirmation without a link address never counts as a different
		// address.
		isLinkAddrDifferent := len(linkAddr) != 0 && e.mu.neigh.LinkAddr != linkAddr

		if isLinkAddrDifferent {
			if !flags.Override {
				// The new address is not adopted; a Reachable entry is only
				// demoted to Stale.
				if e.mu.neigh.State == Reachable {
					e.setStateLocked(Stale)
					e.dispatchChangeEventLocked()
				}
				break
			}

			e.mu.neigh.LinkAddr = linkAddr

			if !flags.Solicited {
				if e.mu.neigh.State != Stale {
					e.setStateLocked(Stale)
					e.dispatchChangeEventLocked()
				} else {
					// Notify the LinkAddr change, even though NUD state hasn't changed.
					e.dispatchChangeEventLocked()
				}
				break
			}
		}

		if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
			wasReachable := e.mu.neigh.State == Reachable
			// Set state to Reachable again to refresh timers.
			e.setStateLocked(Reachable)
			e.notifyCompletionLocked(nil)
			if !wasReachable {
				e.dispatchChangeEventLocked()
			}
		}

		if e.mu.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.mu.neigh.Addr) {
			// "In those cases where the IsRouter flag changes from TRUE to FALSE as
			// a result of this update, the node MUST remove that router from the
			// Default Router List and update the Destination Cache entries for all
			// destinations using that neighbor as a router as specified in Section
			// 7.3.3. This is needed to detect when a node that is used as a router
			// stops forwarding packets due to being configured as a host."
			//  - RFC 4861 section 7.2.5
			//
			// TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6
			// here.
			ep := e.cache.nic.getNetworkEndpoint(header.IPv6ProtocolNumber)
			if ep == nil {
				// The message has no format directives, so it is passed to panic
				// directly instead of going through fmt.Sprintf.
				panic("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")
			}

			if ndpEP, ok := ep.(NDPEndpoint); ok {
				ndpEP.InvalidateDefaultRouter(e.mu.neigh.Addr)
			}
		}
		e.mu.isRouter = flags.IsRouter

	case Unknown, Unreachable, Static:
		// Do nothing

	default:
		panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
	}
}
// handleUpperLevelConfirmation processes an incoming upper-level protocol
// (e.g. TCP acknowledgements) reachability confirmation.
//
// An entry that is already Reachable only has its reachability timer reset,
// which is done under the read lock. Entries in Stale, Delay or Probe are
// moved to Reachable under the write lock. Other states are unaffected.
//
// Unlike the *Locked methods, this method acquires e.mu itself.
func (e *neighborEntry) handleUpperLevelConfirmation() {
// tryHandleConfirmation reports whether the entry must be transitioned to
// Reachable. The caller must hold e.mu, for reading at least.
tryHandleConfirmation := func() bool {
switch e.mu.neigh.State {
case Stale, Delay, Probe:
return true
case Reachable:
// Avoid setStateLocked; Timer.Reset is cheaper.
//
// Note that setting the timer does not need to be protected by the
// entry's write lock since we do not modify the timer pointer, but the
// time the timer should fire. The timer should have internal locks to
// synchronize timer resets changes with the clock.
e.mu.timer.timer.Reset(e.nudState.ReachableTime())
return false
case Unknown, Incomplete, Unreachable, Static:
// Do nothing
return false
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
// Fast path: check under the read lock only.
e.mu.RLock()
needsTransition := tryHandleConfirmation()
e.mu.RUnlock()
if !needsTransition {
return
}
// We need to transition the neighbor to Reachable so take the write lock and
// perform the transition, but only if we still need the transition since the
// state could have changed since we dropped the read lock above.
e.mu.Lock()
defer e.mu.Unlock()
// This needsTransition shadows the one above and holds the re-evaluated
// result.
if needsTransition := tryHandleConfirmation(); needsTransition {
e.setStateLocked(Reachable)
e.dispatchChangeEventLocked()
}
}
// getRemoteLinkAddress reports the entry's link address together with a flag
// telling whether that address is valid. The address counts as valid in the
// Reachable, Static, Delay and Probe states; in every other known state the
// empty address and false are returned.
func (e *neighborEntry) getRemoteLinkAddress() (tcpip.LinkAddress, bool) {
	e.mu.RLock()
	defer e.mu.RUnlock()

	switch state := e.mu.neigh.State; state {
	case Unknown, Incomplete, Unreachable, Stale:
		return "", false
	case Reachable, Static, Delay, Probe:
		return e.mu.neigh.LinkAddr, true
	default:
		panic(fmt.Sprintf("invalid state for neighbor entry %v: %v", e.mu.neigh, state))
	}
}

Some files were not shown because too many files have changed in this diff Show More