Files
tsnet-proxy/vendor/github.com/akutz/memconn/memconn_provider.go
2024-11-01 17:43:06 +00:00

246 lines
5.7 KiB
Go

package memconn
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
)
type (
	// Provider is used to track named MemConn objects.
	//
	// The zero value is ready to use: the internal maps are allocated
	// lazily the first time they are written.
	Provider struct {
		nets      networkMap    // network-name translations set by MapNetwork
		listeners listenerCache // active listeners, keyed by address name
	}

	// listenerCache guards the set of listeners registered on a Provider.
	listenerCache struct {
		sync.RWMutex
		cache map[string]*Listener
	}

	// networkMap guards the from->to network translations of a Provider.
	networkMap struct {
		sync.RWMutex
		cache map[string]string
	}
)
// MapNetwork registers a translation for the network argument passed to
// this Provider's Dial and Listen functions: a later call made with the
// "from" network is handled as if it had been made with "to".
//
// For example, after MapNetwork("tcp", "memu") a call to
// Dial("tcp", "address") behaves like Dial("memu", "address").
//
// Passing an empty "to", as in MapNetwork("tcp", ""), drops any
// translation previously registered for "tcp".
func (p *Provider) MapNetwork(from, to string) {
	p.nets.Lock()
	defer p.nets.Unlock()

	if p.nets.cache == nil {
		p.nets.cache = make(map[string]string)
	}

	switch to {
	case "":
		// An empty target clears the existing translation, if any.
		delete(p.nets.cache, from)
	default:
		p.nets.cache[from] = to
	}
}
// mapNetwork returns the translation registered for network via
// MapNetwork, or network itself when no translation exists.
func (p *Provider) mapNetwork(network string) string {
	p.nets.RLock()
	mapped, found := p.nets.cache[network]
	p.nets.RUnlock()

	if !found {
		return network
	}
	return mapped
}
// Listen begins listening at address for the specified network.
//
// Known networks are "memb" (memconn buffered) and "memu" (memconn
// unbuffered). The network is first translated using any mapping
// registered with MapNetwork.
//
// When the specified address is already in use on the specified
// network an error is returned.
//
// When the translated network is not a memconn network the operation
// defers to net.Listen with the translated network name.
func (p *Provider) Listen(network, address string) (net.Listener, error) {
	switch mapped := p.mapNetwork(network); mapped {
	case networkMemb, networkMemu:
		// The caller's (untranslated) network name is recorded on the
		// address, as ListenMem does for a caller-supplied laddr.
		l, err := p.ListenMem(network, &Addr{Name: address, network: network})
		if err != nil {
			// Return a true nil interface instead of a nil *Listener
			// wrapped in a non-nil net.Listener.
			return nil, err
		}
		return l, nil
	default:
		return net.Listen(mapped, address)
	}
}
// ListenMem begins listening at laddr.
//
// Known networks are "memb" (memconn buffered) and "memu" (memconn
// unbuffered); network is translated via MapNetwork before this check.
// Any other network yields a *net.OpError with the "unknown network"
// error.
//
// If laddr is nil then ListenMem listens on "localhost" on the
// specified network. A non-nil laddr is modified in place: its network
// is set to the caller's network value.
//
// An error is returned if a listener is already registered under
// laddr.Name. The listener is removed from the Provider once its done
// channel is closed.
func (p *Provider) ListenMem(network string, laddr *Addr) (*Listener, error) {
	switch p.mapNetwork(network) {
	case networkMemb, networkMemu:
		// If laddr is not specified then set it to the reserved name
		// "localhost".
		if laddr == nil {
			laddr = &Addr{Name: addrLocalhost, network: network}
		} else {
			laddr.network = network
		}
	default:
		opErr := &net.OpError{
			Net: network,
			Op:  "listen",
			Err: errors.New("unknown network"),
		}
		// Only set the address fields for a non-nil laddr. A nil *Addr
		// stored in a net.Addr interface makes the interface non-nil,
		// and OpError.Error would then call String on a nil *Addr.
		if laddr != nil {
			opErr.Addr = laddr
			opErr.Source = laddr
		}
		return nil, opErr
	}

	p.listeners.Lock()
	defer p.listeners.Unlock()

	if p.listeners.cache == nil {
		p.listeners.cache = map[string]*Listener{}
	}

	// Snapshot the key now. laddr may be owned by the caller, so reading
	// laddr.Name later (when the listener closes) could see a changed
	// value and delete the wrong entry.
	name := laddr.Name

	if _, ok := p.listeners.cache[name]; ok {
		return nil, &net.OpError{
			Addr:   laddr,
			Source: laddr,
			Net:    network,
			Op:     "listen",
			Err:    errors.New("addr unavailable"),
		}
	}

	l := &Listener{
		addr: *laddr,
		done: make(chan struct{}),
		rmvd: make(chan struct{}),
		rcvr: make(chan *Conn, 1),
	}

	// Start a goroutine that removes the listener from the cache once
	// l.done is closed, then signals completion by closing l.rmvd.
	go func() {
		<-l.done
		p.listeners.Lock()
		defer p.listeners.Unlock()
		delete(p.listeners.cache, name)
		close(l.rmvd)
	}()

	p.listeners.cache[name] = l
	return l, nil
}
// Dial opens a connection to the named address on network.
//
// Known networks are "memb" (memconn buffered) and "memu" (memconn
// unbuffered). For any other network the call is handed to net.Dial.
//
// Dial is DialContext without a context.
func (p *Provider) Dial(network, address string) (net.Conn, error) {
	// Deliberately nil: DialContext checks for a nil context and then
	// dials without one on non-memconn networks.
	var ctx context.Context
	return p.DialContext(ctx, network, address)
}
// DialMem opens a memconn connection from laddr to raddr.
//
// Known networks are "memb" (memconn buffered) and "memu" (memconn
// unbuffered).
//
// A nil laddr is replaced by an address generated from
// time.Now().UnixNano(); client addresses need not be unique.
//
// A nil raddr means the "localhost" endpoint on the given network.
//
// DialMem is DialMemContext without a context.
func (p *Provider) DialMem(network string, laddr, raddr *Addr) (*Conn, error) {
	// No context is supplied, so no deadline or cancellation is requested.
	// NOTE(review): how a nil ctx is treated is decided by Listener.dial.
	var ctx context.Context
	return p.DialMemContext(ctx, network, laddr, raddr)
}
// DialContext dials a named connection using a
// Go context to provide timeout behavior.
//
// The network is first translated using any mapping registered with
// MapNetwork. Memconn networks ("memb", "memu") are dialed in memory via
// DialMemContext. Any other translated network is dialed with net.Dial
// when ctx is nil, or with a zero net.Dialer's DialContext otherwise.
//
// Please see Dial for more information.
func (p *Provider) DialContext(
	ctx context.Context,
	network, address string) (net.Conn, error) {

	switch mapped := p.mapNetwork(network); mapped {
	case networkMemb, networkMemu:
		conn, err := p.DialMemContext(
			ctx, network, nil, &Addr{
				Name:    address,
				network: network,
			})
		if err != nil {
			// Return a true nil interface instead of a nil *Conn wrapped
			// in a non-nil net.Conn.
			return nil, err
		}
		return conn, nil
	default:
		if ctx == nil {
			return net.Dial(mapped, address)
		}
		return (&net.Dialer{}).DialContext(ctx, mapped, address)
	}
}
// DialMemContext dials a named connection using a
// Go context to provide timeout behavior.
//
// The network is translated via MapNetwork and must resolve to "memb" or
// "memu". Any other network yields a *net.OpError with the "unknown
// network" error.
//
// Defaults:
//   - A nil laddr is replaced by an address named after the current Unix
//     time in nanoseconds (not guaranteed unique).
//   - A nil raddr defaults to the "localhost" endpoint.
//
// Non-nil laddr and raddr are modified in place. Both get their network
// set to the caller's network, and raddr is then updated to the network
// of the listener it resolves to.
//
// If no listener is registered under raddr.Name, an "unknown remote
// address" *net.OpError is returned.
//
// Please see DialMem for more information.
func (p *Provider) DialMemContext(
	ctx context.Context,
	network string,
	laddr, raddr *Addr) (*Conn, error) {

	switch p.mapNetwork(network) {
	case networkMemb, networkMemu:
		// If laddr is not specified then create one with the current
		// epoch in nanoseconds. This value need not be unique.
		if laddr == nil {
			laddr = &Addr{
				Name:    fmt.Sprintf("%d", time.Now().UnixNano()),
				network: network,
			}
		} else {
			laddr.network = network
		}
		if raddr == nil {
			raddr = &Addr{Name: addrLocalhost, network: network}
		} else {
			raddr.network = network
		}
	default:
		opErr := &net.OpError{
			Net: network,
			Op:  "dial",
			Err: errors.New("unknown network"),
		}
		// Only set non-nil addresses. A nil *Addr stored in a net.Addr
		// interface makes the interface non-nil, and OpError.Error would
		// then call String on a nil *Addr.
		if raddr != nil {
			opErr.Addr = raddr
		}
		if laddr != nil {
			opErr.Source = laddr
		}
		return nil, opErr
	}

	// Hold the read lock only for the lookup. l.dial may block (presumably
	// until the listener accepts or ctx is done; NOTE(review): confirm
	// against Listener.dial). Keeping the lock across that wait would block
	// ListenMem and listener cleanup, which both need the write lock. A
	// waiting writer in turn blocks every later RLock, i.e. all other dials.
	p.listeners.RLock()
	l, ok := p.listeners.cache[raddr.Name]
	p.listeners.RUnlock()

	if !ok {
		return nil, &net.OpError{
			Addr:   raddr,
			Source: laddr,
			Net:    network,
			Op:     "dial",
			Err:    errors.New("unknown remote address"),
		}
	}

	// Update the provided raddr with the actual network type used
	// by the listener.
	raddr.network = l.addr.network
	return l.dial(ctx, network, *laddr, *raddr)
}