246 lines
5.7 KiB
Go
246 lines
5.7 KiB
Go
package memconn
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Provider is used to track named MemConn objects.
|
|
type Provider struct {
|
|
nets networkMap
|
|
listeners listenerCache
|
|
}
|
|
|
|
type listenerCache struct {
|
|
sync.RWMutex
|
|
cache map[string]*Listener
|
|
}
|
|
|
|
type networkMap struct {
|
|
sync.RWMutex
|
|
cache map[string]string
|
|
}
|
|
|
|
// MapNetwork enables mapping the network value provided to this Provider's
|
|
// Dial and Listen functions from the specified "from" value to the
|
|
// specified "to" value.
|
|
//
|
|
// For example, calling MapNetwork("tcp", "memu") means a subsequent
|
|
// Dial("tcp", "address") gets translated to Dial("memu", "address").
|
|
//
|
|
// Calling MapNetwork("tcp", "") removes any previous translation for
|
|
// the "tcp" network.
|
|
func (p *Provider) MapNetwork(from, to string) {
|
|
p.nets.Lock()
|
|
defer p.nets.Unlock()
|
|
if p.nets.cache == nil {
|
|
p.nets.cache = map[string]string{}
|
|
}
|
|
if to == "" {
|
|
delete(p.nets.cache, from)
|
|
return
|
|
}
|
|
p.nets.cache[from] = to
|
|
}
|
|
|
|
func (p *Provider) mapNetwork(network string) string {
|
|
p.nets.RLock()
|
|
defer p.nets.RUnlock()
|
|
if to, ok := p.nets.cache[network]; ok {
|
|
return to
|
|
}
|
|
return network
|
|
}
|
|
|
|
// Listen begins listening at address for the specified network.
|
|
//
|
|
// Known networks are "memb" (memconn buffered) and "memu" (memconn unbuffered).
|
|
//
|
|
// When the specified address is already in use on the specified
|
|
// network an error is returned.
|
|
//
|
|
// When the provided network is unknown the operation defers to
|
|
// net.Dial.
|
|
func (p *Provider) Listen(network, address string) (net.Listener, error) {
|
|
switch p.mapNetwork(network) {
|
|
case networkMemb, networkMemu:
|
|
return p.ListenMem(
|
|
network, &Addr{Name: address, network: network})
|
|
default:
|
|
return net.Listen(network, address)
|
|
}
|
|
}
|
|
|
|
// ListenMem begins listening at laddr.
|
|
//
|
|
// Known networks are "memb" (memconn buffered) and "memu" (memconn unbuffered).
|
|
//
|
|
// If laddr is nil then ListenMem listens on "localhost" on the
|
|
// specified network.
|
|
func (p *Provider) ListenMem(network string, laddr *Addr) (*Listener, error) {
|
|
|
|
switch p.mapNetwork(network) {
|
|
case networkMemb, networkMemu:
|
|
// If laddr is not specified then set it to the reserved name
|
|
// "localhost".
|
|
if laddr == nil {
|
|
laddr = &Addr{Name: addrLocalhost, network: network}
|
|
} else {
|
|
laddr.network = network
|
|
}
|
|
default:
|
|
return nil, &net.OpError{
|
|
Addr: laddr,
|
|
Source: laddr,
|
|
Net: network,
|
|
Op: "listen",
|
|
Err: errors.New("unknown network"),
|
|
}
|
|
}
|
|
|
|
p.listeners.Lock()
|
|
defer p.listeners.Unlock()
|
|
|
|
if p.listeners.cache == nil {
|
|
p.listeners.cache = map[string]*Listener{}
|
|
}
|
|
|
|
if _, ok := p.listeners.cache[laddr.Name]; ok {
|
|
return nil, &net.OpError{
|
|
Addr: laddr,
|
|
Source: laddr,
|
|
Net: network,
|
|
Op: "listen",
|
|
Err: errors.New("addr unavailable"),
|
|
}
|
|
}
|
|
|
|
l := &Listener{
|
|
addr: *laddr,
|
|
done: make(chan struct{}),
|
|
rmvd: make(chan struct{}),
|
|
rcvr: make(chan *Conn, 1),
|
|
}
|
|
|
|
// Start a goroutine that removes the listener from
|
|
// the cache once the listener is closed.
|
|
go func() {
|
|
<-l.done
|
|
p.listeners.Lock()
|
|
defer p.listeners.Unlock()
|
|
delete(p.listeners.cache, laddr.Name)
|
|
close(l.rmvd)
|
|
}()
|
|
|
|
p.listeners.cache[laddr.Name] = l
|
|
return l, nil
|
|
}
|
|
|
|
// Dial dials a named connection.
|
|
//
|
|
// Known networks are "memb" (memconn buffered) and "memu" (memconn unbuffered).
|
|
//
|
|
// When the provided network is unknown the operation defers to
|
|
// net.Dial.
|
|
func (p *Provider) Dial(network, address string) (net.Conn, error) {
|
|
return p.DialContext(nil, network, address)
|
|
}
|
|
|
|
// DialMem dials a named connection.
|
|
//
|
|
// Known networks are "memb" (memconn buffered) and "memu" (memconn unbuffered).
|
|
//
|
|
// If laddr is nil then a new address is generated using
|
|
// time.Now().UnixNano(). Please note that client addresses are
|
|
// not required to be unique.
|
|
//
|
|
// If raddr is nil then the "localhost" endpoint is used on the
|
|
// specified network.
|
|
func (p *Provider) DialMem(
|
|
network string, laddr, raddr *Addr) (*Conn, error) {
|
|
|
|
return p.DialMemContext(nil, network, laddr, raddr)
|
|
}
|
|
|
|
// DialContext dials a named connection using a
|
|
// Go context to provide timeout behavior.
|
|
//
|
|
// Please see Dial for more information.
|
|
func (p *Provider) DialContext(
|
|
ctx context.Context,
|
|
network, address string) (net.Conn, error) {
|
|
|
|
switch p.mapNetwork(network) {
|
|
case networkMemb, networkMemu:
|
|
return p.DialMemContext(
|
|
ctx, network, nil, &Addr{
|
|
Name: address,
|
|
network: network,
|
|
})
|
|
default:
|
|
if ctx == nil {
|
|
return net.Dial(network, address)
|
|
}
|
|
return (&net.Dialer{}).DialContext(ctx, network, address)
|
|
}
|
|
}
|
|
|
|
// DialMemContext dials a named connection using a
|
|
// Go context to provide timeout behavior.
|
|
//
|
|
// Please see DialMem for more information.
|
|
func (p *Provider) DialMemContext(
|
|
ctx context.Context,
|
|
network string,
|
|
laddr, raddr *Addr) (*Conn, error) {
|
|
|
|
switch p.mapNetwork(network) {
|
|
case networkMemb, networkMemu:
|
|
// If laddr is not specified then create one with the current
|
|
// epoch in nanoseconds. This value need not be unique.
|
|
if laddr == nil {
|
|
laddr = &Addr{
|
|
Name: fmt.Sprintf("%d", time.Now().UnixNano()),
|
|
network: network,
|
|
}
|
|
} else {
|
|
laddr.network = network
|
|
}
|
|
if raddr == nil {
|
|
raddr = &Addr{Name: addrLocalhost, network: network}
|
|
} else {
|
|
raddr.network = network
|
|
}
|
|
default:
|
|
return nil, &net.OpError{
|
|
Addr: raddr,
|
|
Source: laddr,
|
|
Net: network,
|
|
Op: "dial",
|
|
Err: errors.New("unknown network"),
|
|
}
|
|
}
|
|
|
|
p.listeners.RLock()
|
|
defer p.listeners.RUnlock()
|
|
|
|
if l, ok := p.listeners.cache[raddr.Name]; ok {
|
|
// Update the provided raddr with the actual network type used
|
|
// by the listener.
|
|
raddr.network = l.addr.network
|
|
return l.dial(ctx, network, *laddr, *raddr)
|
|
}
|
|
|
|
return nil, &net.OpError{
|
|
Addr: raddr,
|
|
Source: laddr,
|
|
Net: network,
|
|
Op: "dial",
|
|
Err: errors.New("unknown remote address"),
|
|
}
|
|
}
|