Files
tsnet-proxy/vendor/github.com/mdlayher/netlink/nltest/nltest.go
2024-11-01 17:43:06 +00:00

208 lines
6.0 KiB
Go

// Package nltest provides utilities for netlink testing.
package nltest
import (
"fmt"
"io"
"os"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nlenc"
)
// PID is the netlink header PID value assigned by nltest. Dial passes it to
// netlink.NewConn when building the test connection.
const PID = 1
// MustMarshalAttributes is a variant of netlink.MarshalAttributes that panics
// instead of returning an error. It is meant for building test fixtures, where
// a marshaling failure indicates a mistake in the test itself.
func MustMarshalAttributes(attrs []netlink.Attribute) []byte {
	out, err := netlink.MarshalAttributes(attrs)
	if err == nil {
		return out
	}
	panic(fmt.Sprintf("failed to marshal attributes to binary: %v", err))
}
// Multipart turns msgs into a netlink multi-part message.
//
// Every message gets the netlink.Multi header flag. The header type of the
// final message is replaced with netlink.Done. The messages are modified in
// place, and the same slice is returned with a nil error.
//
// Slices with fewer than two messages are returned unchanged.
func Multipart(msgs []netlink.Message) ([]netlink.Message, error) {
	n := len(msgs)
	if n >= 2 {
		for i := range msgs {
			msgs[i].Header.Flags |= netlink.Multi
		}
		// The trailing message both carries Multi and terminates the series.
		msgs[n-1].Header.Type = netlink.Done
	}
	return msgs, nil
}
// Error returns a netlink error to the caller with the specified error
// number, in the body of the specified request message.
//
// reqs[0] is used as a template, and any further messages are ignored. On a
// copy of it, Error:
//   - sets the header type to netlink.Error,
//   - grows the header length by 4 bytes,
//   - prepends the negated error number (as an int32 encoded by
//     nlenc.Int32Bytes) to the data.
//
// If reqs is empty, there is no request to echo back, so a descriptive error
// is returned instead of panicking with an index out of range. This happens
// when a Func is invoked with the nil request that emulates a multicast
// receive.
func Error(number int, reqs []netlink.Message) ([]netlink.Message, error) {
	if len(reqs) == 0 {
		return nil, fmt.Errorf("nltest: Error(%d) called with no request messages", number)
	}

	req := reqs[0]
	// Account for the 4-byte error code prepended to the payload below.
	req.Header.Length += 4
	req.Header.Type = netlink.Error

	errno := -1 * int32(number)
	req.Data = append(nlenc.Int32Bytes(errno), req.Data...)

	return []netlink.Message{req}, nil
}
// A Func is a function that can be used to test netlink.Conn interactions.
// The function can choose to return zero or more netlink messages, or an
// error if needed.
//
// For a netlink request/response interaction, a request req is populated by
// netlink.Conn.Send (or a batch by SendMessages) and passed to the function.
//
// For multicast interactions, when netlink.Conn.Receive is called with no
// replies or error pending, the function is called with a nil request.
//
// If a Func returns an error, the socket's Receive passes it on to the caller.
// Errors that isSyscallError recognizes are first wrapped with
// os.NewSyscallError("recvmsg", ...) to mimic the Linux implementation.
// If no messages and io.EOF are returned, no messages and no error will be
// returned to the caller, simulating a multi-part message with no data.
type Func func(req []netlink.Message) ([]netlink.Message, error)
// Dial returns a netlink.Conn backed by an in-memory test socket that routes
// every request through fn instead of the kernel. The connection uses PID as
// its netlink PID and should be closed as usual when no longer needed.
func Dial(fn Func) *netlink.Conn {
	return netlink.NewConn(&socket{fn: fn}, PID)
}
// CheckRequest returns a Func that validates each incoming request message
// against the expected header type and flags at the same index, then hands
// the request on to fn.
//
// types and flags are parallel slices: types[i] and flags[i] describe reqs[i].
// Their lengths must be equal when CheckRequest is called. They must also
// match the number of request messages each time the returned Func runs.
// Either mismatch is treated as a programming error and panics.
//
// A zero entry in types or flags disables that particular check. For example,
// if types[0] is 0, any reqs[0].Header.Type is accepted.
//
// A failed check returns an error describing the received and expected
// values, and fn is not called.
func CheckRequest(types []netlink.HeaderType, flags []netlink.HeaderFlags, fn Func) Func {
	if nt, nf := len(types), len(flags); nt != nf {
		panicf("nltest: CheckRequest called with mismatched types and flags slice lengths: %d != %d",
			nt, nf)
	}

	return func(req []netlink.Message) ([]netlink.Message, error) {
		if nt, nr := len(types), len(req); nt != nr {
			panicf("nltest: CheckRequest function invoked types/flags and request message slice lengths: %d != %d",
				nt, nr)
		}

		for i, m := range req {
			h := m.Header

			if wantType := types[i]; wantType != 0 && h.Type != wantType {
				return nil, fmt.Errorf("nltest: unexpected netlink header type: %s, want: %s", h.Type, wantType)
			}

			if wantFlags := flags[i]; wantFlags != 0 && h.Flags != wantFlags {
				return nil, fmt.Errorf("nltest: unexpected netlink header flags: %s, want: %s", h.Flags, wantFlags)
			}
		}

		return fn(req)
	}
}
// A socket is a netlink.Socket used for testing. Instead of talking to the
// kernel, it hands each outgoing request to fn and buffers fn's reply so that
// a later Receive call can return it.
type socket struct {
	// fn produces the reply for every request sent on this socket. It is also
	// called with a nil request to emulate multicast receives.
	fn Func
	// msgs holds reply messages from fn that Receive has not yet returned.
	msgs []netlink.Message
	// err is the error fn returned for the most recent Send/SendMessages.
	// Receive reports it to the caller.
	err error
}
// Close satisfies netlink.Socket. The test socket owns no resources, so there
// is nothing to release and it always succeeds.
func (c *socket) Close() error {
	return nil
}
// SendMessages passes the whole batch to fn as a single request.
//
// The messages fn returns are appended to any replies still pending, and the
// error fn returns replaces whatever error was stored before. Outcomes are
// delivered through Receive, so SendMessages itself always reports success.
func (c *socket) SendMessages(messages []netlink.Message) error {
	replies, err := c.fn(messages)
	c.err = err
	c.msgs = append(c.msgs, replies...)
	return nil
}
// Send passes m to fn as a one-message request.
//
// Unlike SendMessages, fn's reply replaces (rather than extends) any pending
// messages, and its error replaces the stored one. The outcome is reported by
// Receive, so Send itself always returns nil.
func (c *socket) Send(m netlink.Message) error {
	req := []netlink.Message{m}
	replies, err := c.fn(req)
	c.msgs = replies
	c.err = err
	return nil
}
// Receive returns the next batch of replies queued by Send or SendMessages.
//
// When no replies are queued, the stored error decides the result:
//   - nil: emulate a multicast message by calling fn with a nil request. An
//     io.EOF from fn is reported as "no messages, no error".
//   - io.EOF: emulate an empty multi-part response; return nothing.
//   - a system call error (per isSyscallError): wrap it with
//     os.NewSyscallError("recvmsg", ...) as the Linux netlink.Conn does.
//   - anything else: return it unchanged.
//
// The stored error is not cleared on these paths, so it is reported again on
// every subsequent Receive until the next Send or SendMessages.
//
// When queued replies contain a multi-part message that is not yet finished,
// every reply except the last is returned now. The last one is held back so
// that netlink.Conn's follow-up Receive drains it. Otherwise all queued
// replies and the stored error are returned and both are reset.
func (c *socket) Receive() ([]netlink.Message, error) {
	if len(c.msgs) == 0 {
		if c.err == nil {
			out, err := c.fn(nil)
			if err == io.EOF {
				return out, nil
			}
			return out, err
		}
		if c.err == io.EOF {
			return nil, nil
		}
		if isSyscallError(c.err) {
			return nil, os.NewSyscallError("recvmsg", c.err)
		}
		return nil, c.err
	}

	pendingMulti := false
	for _, m := range c.msgs {
		if m.Header.Flags&netlink.Multi != 0 && m.Header.Type != netlink.Done {
			pendingMulti = true
			break
		}
	}

	if !pendingMulti {
		out, err := c.msgs, c.err
		c.msgs, c.err = nil, nil
		return out, err
	}

	// Split off the final ("done") message so the next Receive delivers it.
	n := len(c.msgs)
	head, tail := c.msgs[:n-1], c.msgs[n-1]
	c.msgs = []netlink.Message{tail}
	return head, c.err
}
// panicf panics with a string built from format and a, as fmt.Sprintf would
// format them.
func panicf(format string, a ...interface{}) {
	msg := fmt.Sprintf(format, a...)
	panic(msg)
}