diff --git a/flake.nix b/flake.nix index 7e4ab44..35ce91c 100644 --- a/flake.nix +++ b/flake.nix @@ -48,12 +48,16 @@ type = bool; default = true; }; - proto = mkOption {type = proto;}; - dest = mkOption {type = str;}; - hostProto = mkOption {type = proto;}; - port = mkOption {type = port;}; + forwards = mkOption { + type = listOf (submodule { + options = { + proto = mkOption {type = proto;}; + port = mkOption {type = port;}; + dest = mkOption {type = str;}; + }; + }); + }; }; - config.hostProto = mkDefault config.proto; })); }; }; @@ -65,10 +69,7 @@ }; in lib.mapAttrs' (hostname: { - proto, - hostProto, - port, - dest, + forwards, enable, ... }: let @@ -78,7 +79,7 @@ value = { inherit enable; script = '' - TS_AUTHKEY=$(cat $RUNTIME_DIRECTORY/authkey) ${lib.getExe self.packages.${pkgs.system}.default} ${hostProto} ${hostname} ${toString port} ${proto} ${dest} + TS_AUTHKEY=$(cat $RUNTIME_DIRECTORY/authkey) ${lib.getExe self.packages.${pkgs.system}.default} ${hostname} ${lib.concatMapStringsSep " " ({proto, port, dest}: "${proto}:${proto}:${dest}") forwards} ''; wantedBy = ["multi-user.target"]; serviceConfig = { diff --git a/main.go b/main.go index b3bae3d..627654b 100644 --- a/main.go +++ b/main.go @@ -5,30 +5,45 @@ import ( "log" "net" "os" + "strings" "ben.soroos.net/tsnet-proxy/netforward" "tailscale.com/tsnet" ) -var ( - host_proto = os.Args[1] - hostname = os.Args[2] - host_port = os.Args[3] - dst_proto = os.Args[4] - dst_addr = os.Args[5] -) +var hostname = os.Args[1] -type Dialer struct{} + +type Dialer struct{ + proto string; + addr string; +} + +type Forward struct{ + proto string; + port string; + dst string; +} func (dialer Dialer) Dial() (net.Conn, error) { - return net.Dial(dst_proto, dst_addr) + return net.Dial(dialer.proto, dialer.addr) +} + +func (forward Forward) Run(server *tsnet.Server, finish chan error) { + ln, err := server.Listen(forward.proto, fmt.Sprint(":", forward.port)) + defer ln.Close() + if err != nil { + finish <- err + } + err = netforward.Forward(Dialer { proto: forward.proto, addr: forward.dst }, ln) + if err != nil { + finish <- err + } } func main() { - if !(host_proto != "" && hostname != "" && host_port != "" && dst_proto != "" && dst_addr != "") { - fmt.Println("Usage: tsnet-proxy PROTO HOSTNAME PORT DST-PROTO DST") - fmt.Println("Where PROTO is one of tcp, udp, unix") - fmt.Println("Got: ", host_proto, hostname, host_port, dst_proto, dst_addr) + if !(hostname != "") { + fmt.Println("Usage: tsnet-proxy HOSTNAME FORWARDS...") os.Exit(1) } @@ -42,11 +57,14 @@ func main() { } defer s.Close() - ln, err := s.Listen(host_proto, fmt.Sprint(":", host_port)) - if err != nil { - log.Fatal(err) + err_chan := make(chan error) + for _, arg := range os.Args[2:] { + args := strings.SplitN(arg, ":", 3) + proto, port, dst := args[0], args[1], args[2] + go func() { + Forward { proto, port, dst }.Run(s, err_chan); + }(); } - defer ln.Close() - - netforward.Forward(Dialer {}, ln) + err := <- err_chan + log.Fatal(err) }