395 lines
8.7 KiB
Go
395 lines
8.7 KiB
Go
package rtnetlink
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"net"
|
|
|
|
"github.com/jsimonetti/rtnetlink/internal/unix"
|
|
"github.com/mdlayher/netlink"
|
|
)
|
|
|
|
var (
|
|
// errInvalidRuleMessage is returned when a RuleMessage is malformed.
|
|
errInvalidRuleMessage = errors.New("rtnetlink RuleMessage is invalid or too short")
|
|
|
|
// errInvalidRuleAttribute is returned when a RuleMessage contains an unknown attribute.
|
|
errInvalidRuleAttribute = errors.New("rtnetlink RuleMessage contains an unknown Attribute")
|
|
)
|
|
|
|
var _ Message = &RuleMessage{}
|
|
|
|
// A RuleMessage is a route netlink link message.
|
|
type RuleMessage struct {
|
|
// Address family
|
|
Family uint8
|
|
|
|
// Length of destination prefix
|
|
DstLength uint8
|
|
|
|
// Length of source prefix
|
|
SrcLength uint8
|
|
|
|
// Rule TOS
|
|
TOS uint8
|
|
|
|
// Routing table identifier
|
|
Table uint8
|
|
|
|
// Rule action
|
|
Action uint8
|
|
|
|
// Rule flags
|
|
Flags uint32
|
|
|
|
// Attributes List
|
|
Attributes *RuleAttributes
|
|
}
|
|
|
|
// MarshalBinary marshals a LinkMessage into a byte slice.
|
|
func (m *RuleMessage) MarshalBinary() ([]byte, error) {
|
|
b := make([]byte, 12)
|
|
|
|
// fib_rule_hdr
|
|
b[0] = m.Family
|
|
b[1] = m.DstLength
|
|
b[2] = m.SrcLength
|
|
b[3] = m.TOS
|
|
b[4] = m.Table
|
|
b[7] = m.Action
|
|
nativeEndian.PutUint32(b[8:12], m.Flags)
|
|
|
|
if m.Attributes != nil {
|
|
ae := netlink.NewAttributeEncoder()
|
|
ae.ByteOrder = nativeEndian
|
|
err := m.Attributes.encode(ae)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
a, err := ae.Encode()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return append(b, a...), nil
|
|
}
|
|
|
|
return b, nil
|
|
}
|
|
|
|
// UnmarshalBinary unmarshals the contents of a byte slice into a LinkMessage.
|
|
func (m *RuleMessage) UnmarshalBinary(b []byte) error {
|
|
l := len(b)
|
|
if l < 12 {
|
|
return errInvalidRuleMessage
|
|
}
|
|
m.Family = b[0]
|
|
m.DstLength = b[1]
|
|
m.SrcLength = b[2]
|
|
m.TOS = b[3]
|
|
m.Table = b[4]
|
|
// b[5] and b[6] are reserved fields
|
|
m.Action = b[7]
|
|
m.Flags = nativeEndian.Uint32(b[8:12])
|
|
|
|
if l > 12 {
|
|
m.Attributes = &RuleAttributes{}
|
|
ad, err := netlink.NewAttributeDecoder(b[12:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ad.ByteOrder = nativeEndian
|
|
return m.Attributes.decode(ad)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// rtMessage is an empty method to sattisfy the Message interface.
|
|
func (*RuleMessage) rtMessage() {}
|
|
|
|
// RuleService is used to retrieve rtnetlink family information.
|
|
type RuleService struct {
|
|
c *Conn
|
|
}
|
|
|
|
func (r *RuleService) execute(m Message, family uint16, flags netlink.HeaderFlags) ([]RuleMessage, error) {
|
|
msgs, err := r.c.Execute(m, family, flags)
|
|
|
|
rules := make([]RuleMessage, len(msgs))
|
|
for i := range msgs {
|
|
rules[i] = *msgs[i].(*RuleMessage)
|
|
}
|
|
|
|
return rules, err
|
|
}
|
|
|
|
// Add new rule
|
|
func (r *RuleService) Add(req *RuleMessage) error {
|
|
flags := netlink.Request | netlink.Create | netlink.Acknowledge | netlink.Excl
|
|
_, err := r.c.Execute(req, unix.RTM_NEWRULE, flags)
|
|
|
|
return err
|
|
}
|
|
|
|
// Replace or add new rule
|
|
func (r *RuleService) Replace(req *RuleMessage) error {
|
|
flags := netlink.Request | netlink.Create | netlink.Replace | netlink.Acknowledge
|
|
_, err := r.c.Execute(req, unix.RTM_NEWRULE, flags)
|
|
|
|
return err
|
|
}
|
|
|
|
// Delete existing rule
|
|
func (r *RuleService) Delete(req *RuleMessage) error {
|
|
flags := netlink.Request | netlink.Acknowledge
|
|
_, err := r.c.Execute(req, unix.RTM_DELRULE, flags)
|
|
|
|
return err
|
|
}
|
|
|
|
// Get Rule(s)
|
|
func (r *RuleService) Get(req *RuleMessage) ([]RuleMessage, error) {
|
|
flags := netlink.Request | netlink.DumpFiltered
|
|
return r.execute(req, unix.RTM_GETRULE, flags)
|
|
}
|
|
|
|
// List all rules
|
|
func (r *RuleService) List() ([]RuleMessage, error) {
|
|
flags := netlink.Request | netlink.Dump
|
|
return r.execute(&RuleMessage{}, unix.RTM_GETRULE, flags)
|
|
}
|
|
|
|
// RuleAttributes contains all attributes for a rule.
|
|
type RuleAttributes struct {
|
|
Src, Dst *net.IP
|
|
IIFName, OIFName *string
|
|
Goto *uint32
|
|
Priority *uint32
|
|
FwMark, FwMask *uint32
|
|
SrcRealm *uint16
|
|
DstRealm *uint16
|
|
TunID *uint64
|
|
Table *uint32
|
|
L3MDev *uint8
|
|
Protocol *uint8
|
|
IPProto *uint8
|
|
SuppressPrefixLen *uint32
|
|
SuppressIFGroup *uint32
|
|
UIDRange *RuleUIDRange
|
|
SPortRange *RulePortRange
|
|
DPortRange *RulePortRange
|
|
}
|
|
|
|
// unmarshalBinary unmarshals the contents of a byte slice into a RuleMessage.
|
|
func (r *RuleAttributes) decode(ad *netlink.AttributeDecoder) error {
|
|
for ad.Next() {
|
|
switch ad.Type() {
|
|
case unix.FRA_UNSPEC:
|
|
// unused
|
|
continue
|
|
case unix.FRA_DST:
|
|
r.Dst = &net.IP{}
|
|
ad.Do(decodeIP(r.Dst))
|
|
case unix.FRA_SRC:
|
|
r.Src = &net.IP{}
|
|
ad.Do(decodeIP(r.Src))
|
|
case unix.FRA_IIFNAME:
|
|
v := ad.String()
|
|
r.IIFName = &v
|
|
case unix.FRA_GOTO:
|
|
v := ad.Uint32()
|
|
r.Goto = &v
|
|
case unix.FRA_UNUSED2:
|
|
// unused
|
|
continue
|
|
case unix.FRA_PRIORITY:
|
|
v := ad.Uint32()
|
|
r.Priority = &v
|
|
case unix.FRA_UNUSED3:
|
|
// unused
|
|
continue
|
|
case unix.FRA_UNUSED4:
|
|
// unused
|
|
continue
|
|
case unix.FRA_UNUSED5:
|
|
// unused
|
|
continue
|
|
case unix.FRA_FWMARK:
|
|
v := ad.Uint32()
|
|
r.FwMark = &v
|
|
case unix.FRA_FLOW:
|
|
dst32 := ad.Uint32()
|
|
src32 := uint32(dst32 >> 16)
|
|
src32 &= 0xFFFF
|
|
dst32 &= 0xFFFF
|
|
src16 := uint16(src32)
|
|
dst16 := uint16(dst32)
|
|
r.SrcRealm = &src16
|
|
r.DstRealm = &dst16
|
|
case unix.FRA_TUN_ID:
|
|
v := ad.Uint64()
|
|
r.TunID = &v
|
|
case unix.FRA_SUPPRESS_IFGROUP:
|
|
v := ad.Uint32()
|
|
r.SuppressIFGroup = &v
|
|
case unix.FRA_SUPPRESS_PREFIXLEN:
|
|
v := ad.Uint32()
|
|
r.SuppressPrefixLen = &v
|
|
case unix.FRA_TABLE:
|
|
v := ad.Uint32()
|
|
r.Table = &v
|
|
case unix.FRA_FWMASK:
|
|
v := ad.Uint32()
|
|
r.FwMask = &v
|
|
case unix.FRA_OIFNAME:
|
|
v := ad.String()
|
|
r.OIFName = &v
|
|
case unix.FRA_PAD:
|
|
// unused
|
|
continue
|
|
case unix.FRA_L3MDEV:
|
|
v := ad.Uint8()
|
|
r.L3MDev = &v
|
|
case unix.FRA_UID_RANGE:
|
|
r.UIDRange = &RuleUIDRange{}
|
|
err := r.UIDRange.unmarshalBinary(ad.Bytes())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case unix.FRA_PROTOCOL:
|
|
v := ad.Uint8()
|
|
r.Protocol = &v
|
|
case unix.FRA_IP_PROTO:
|
|
v := ad.Uint8()
|
|
r.IPProto = &v
|
|
case unix.FRA_SPORT_RANGE:
|
|
r.SPortRange = &RulePortRange{}
|
|
err := r.SPortRange.unmarshalBinary(ad.Bytes())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case unix.FRA_DPORT_RANGE:
|
|
r.DPortRange = &RulePortRange{}
|
|
err := r.DPortRange.unmarshalBinary(ad.Bytes())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return errInvalidRuleAttribute
|
|
}
|
|
}
|
|
return ad.Err()
|
|
}
|
|
|
|
// MarshalBinary marshals a RuleAttributes into a byte slice.
|
|
func (r *RuleAttributes) encode(ae *netlink.AttributeEncoder) error {
|
|
if r.Table != nil {
|
|
ae.Uint32(unix.FRA_TABLE, *r.Table)
|
|
}
|
|
if r.Protocol != nil {
|
|
ae.Uint8(unix.FRA_PROTOCOL, *r.Protocol)
|
|
}
|
|
if r.Src != nil {
|
|
ae.Do(unix.FRA_SRC, encodeIP(*r.Src))
|
|
}
|
|
if r.Dst != nil {
|
|
ae.Do(unix.FRA_DST, encodeIP(*r.Dst))
|
|
}
|
|
if r.IIFName != nil {
|
|
ae.String(unix.FRA_IIFNAME, *r.IIFName)
|
|
}
|
|
if r.OIFName != nil {
|
|
ae.String(unix.FRA_OIFNAME, *r.OIFName)
|
|
}
|
|
if r.Goto != nil {
|
|
ae.Uint32(unix.FRA_GOTO, *r.Goto)
|
|
}
|
|
if r.Priority != nil {
|
|
ae.Uint32(unix.FRA_PRIORITY, *r.Priority)
|
|
}
|
|
if r.FwMark != nil {
|
|
ae.Uint32(unix.FRA_FWMARK, *r.FwMark)
|
|
}
|
|
if r.FwMask != nil {
|
|
ae.Uint32(unix.FRA_FWMASK, *r.FwMask)
|
|
}
|
|
if r.DstRealm != nil {
|
|
value := uint32(*r.DstRealm)
|
|
if r.SrcRealm != nil {
|
|
value |= (uint32(*r.SrcRealm&0xFFFF) << 16)
|
|
}
|
|
ae.Uint32(unix.FRA_FLOW, value)
|
|
}
|
|
if r.TunID != nil {
|
|
ae.Uint64(unix.FRA_TUN_ID, *r.TunID)
|
|
}
|
|
if r.L3MDev != nil {
|
|
ae.Uint8(unix.FRA_L3MDEV, *r.L3MDev)
|
|
}
|
|
if r.IPProto != nil {
|
|
ae.Uint8(unix.FRA_IP_PROTO, *r.IPProto)
|
|
}
|
|
if r.SuppressIFGroup != nil {
|
|
ae.Uint32(unix.FRA_SUPPRESS_IFGROUP, *r.SuppressIFGroup)
|
|
}
|
|
if r.SuppressPrefixLen != nil {
|
|
ae.Uint32(unix.FRA_SUPPRESS_PREFIXLEN, *r.SuppressPrefixLen)
|
|
}
|
|
if r.UIDRange != nil {
|
|
data, err := marshalRuleUIDRange(*r.UIDRange)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ae.Bytes(unix.FRA_UID_RANGE, data)
|
|
}
|
|
if r.SPortRange != nil {
|
|
data, err := marshalRulePortRange(*r.SPortRange)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ae.Bytes(unix.FRA_SPORT_RANGE, data)
|
|
}
|
|
if r.DPortRange != nil {
|
|
data, err := marshalRulePortRange(*r.DPortRange)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ae.Bytes(unix.FRA_DPORT_RANGE, data)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RulePortRange defines start and end ports for a rule
|
|
type RulePortRange struct {
|
|
Start, End uint16
|
|
}
|
|
|
|
func (r *RulePortRange) unmarshalBinary(data []byte) error {
|
|
b := bytes.NewReader(data)
|
|
return binary.Read(b, nativeEndian, r)
|
|
}
|
|
|
|
func marshalRulePortRange(s RulePortRange) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
err := binary.Write(&buf, nativeEndian, s)
|
|
return buf.Bytes(), err
|
|
}
|
|
|
|
// RuleUIDRange defines the start and end for UID matches
|
|
type RuleUIDRange struct {
|
|
Start, End uint16
|
|
}
|
|
|
|
func (r *RuleUIDRange) unmarshalBinary(data []byte) error {
|
|
b := bytes.NewReader(data)
|
|
return binary.Read(b, nativeEndian, r)
|
|
}
|
|
|
|
func marshalRuleUIDRange(s RuleUIDRange) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
err := binary.Write(&buf, nativeEndian, s)
|
|
return buf.Bytes(), err
|
|
}
|