diff --git a/.gitignore b/.gitignore index c88dabf..ae59a6e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ hostkeys/ .DS_Store +Makefile.local.mk \ No newline at end of file diff --git a/Makefile b/Makefile index b68e66d..c4d3bdc 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,9 @@ +-include Makefile.local.mk + export SSH_LISTEN?=:2222 export SSH_HOSTKEYS?=hostkeys export SSH_AUTHKEYS?=authkeys -export SSH_HOST?=localhost -export SSH_PORT?=2222 -export SSH_OPTS?=-R 0.0.0.0:1234:localhost:3000 run: go run . @@ -15,6 +14,3 @@ genkeys: ssh-keygen -q -N "" -t ecdsa -f $(SSH_HOSTKEYS)/ecdsa ssh-keygen -q -N "" -t ed25519 -f $(SSH_HOSTKEYS)/ed25519 rm -f $(SSH_HOSTKEYS)/*.pub - -forward: - ssh -T $(SSH_HOST) -p $(SSH_PORT) $(SSH_OPTS) diff --git a/go.mod b/go.mod index 628211f..6f073f2 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,8 @@ go 1.15 require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/gliderlabs/ssh v0.3.2 + github.com/soheilhy/cmux v0.1.5 // indirect + github.com/tjarratt/babble v0.0.0-20210505082055-cbca2a4833c1 // indirect + go.uber.org/multierr v1.7.0 // indirect golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf // indirect ) diff --git a/go.sum b/go.sum index 9f3ff6f..e0f98f2 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,40 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gliderlabs/ssh v0.3.2 h1:gcfd1Aj/9RQxvygu4l3sak711f/5+VOwBw9C/7+N4EI= github.com/gliderlabs/ssh v0.3.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= +github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tjarratt/babble v0.0.0-20210505082055-cbca2a4833c1 h1:j8whCiEmvLCXI3scVn+YnklCU8mwJ9ZJ4/DGAKqQbRE= +github.com/tjarratt/babble v0.0.0-20210505082055-cbca2a4833c1/go.mod h1:O5hBrCGqzfb+8WyY8ico2AyQau7XQwAfEQeEQ5/5V9E= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= +go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf h1:B2n+Zi5QeYRDAEodEu72OS36gmTWjgpXr2+cWcBW90o= golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 0d6f75b..3ff24bf 100644 --- a/main.go +++ b/main.go @@ -5,22 +5,48 @@ import ( "fmt" "io/ioutil" "log" + "net" + "net/http" + "net/http/httputil" "os" + "os/signal" "path/filepath" + "strconv" + "strings" + "sync" + "time" "github.com/gliderlabs/ssh" + "github.com/soheilhy/cmux" + "github.com/tjarratt/babble" + "go.uber.org/multierr" +) + +const ( + domainName = "prox.int" + domainSuffix = ".prox.int" + portStart uint32 = 4000 + portEnd uint32 = 4999 + bindHost = "[::1]" ) func main() { - run() + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + defer stop() + + run(ctx) } -func run() { +func run(ctx context.Context) { + lis, err := net.Listen("tcp", envMust("SSH_LISTEN")) + if err != nil { + log.Fatal(err.Error()) + } + var opts []ssh.Option opts = append( opts, ssh.NoPty(), - optRemoteAllow(), ) files, err := ioutil.ReadDir(envMust("SSH_HOSTKEYS")) @@ -28,19 +54,30 @@ func run() { log.Fatal(err) } + srv := &server{ + bindHost: envDefault("SSH_HOST", bindHost), + portStart: portStart, + portEnd: portEnd, + domainName: envDefault("SSH_DOMAIN", domainName), + domainSuffix: envDefault("SSH_DOMAIN_SUFFIX", domainSuffix), + } + for _, f := range files { opts = append(opts, ssh.HostKeyFile(filepath.Join(envMust("SSH_HOSTKEYS"), f.Name()))) } + opts = append(opts, srv.optAuthUser()...) - if authKeys := os.Getenv("SSH_AUTHKEYS"); authKeys != "" { - opts = append(opts, optPubkeyAllow(authKeys)) + http.HandleFunc("/", srv.handleHTTP) + mux := New(lis, srv.serveHTTP(ctx), srv.serveSSH(ctx, opts...)) + + listen := mux.Listener.Addr().String() + if idx := strings.LastIndex(listen, ":"); idx >= 0 { + if i, err := strconv.Atoi(listen[idx+1:]); err == nil { + srv.listenPort = uint32(i) + } } - log.Fatal(ssh.ListenAndServe( - envMust("SSH_LISTEN"), - newSession(context.Background()), - opts..., - )) + log.Fatal(mux.Serve(ctx)) } func envMust(s string) string { @@ -48,57 +85,336 @@ func envMust(s string) string { if v == "" { log.Fatal("missing env ", s) } + log.Println("env", s, "==", v) + return v +} +func envDefault(s, d string) string { + v := os.Getenv(s) + if v == "" { + return d + } + log.Println("env", s, "==", v) return v } -func newSession(ctx context.Context) func(ssh.Session) { +func (srv *server) newSession(ctx context.Context) func(ssh.Session) { return func(s ssh.Session) { if _, err := fmt.Fprintf(s, "Hello %s\n", s.User()); err != nil { return } - <-ctx.Done() - } -} - -func optRemoteAllow() ssh.Option { - return func(cfg *ssh.Server) error { - hdlr := ssh.ForwardedTCPHandler{} - - if cfg.RequestHandlers == nil { - cfg.RequestHandlers = make(map[string]ssh.RequestHandler, 2) - } - - cfg.RequestHandlers["tcpip-forward"] = hdlr.HandleSSHRequest - cfg.RequestHandlers["cancel-tcpip-forward"] = hdlr.HandleSSHRequest - - cfg.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { - log.Println("Allow Remote:", bindHost, bindPort) - - return true - } - - return nil - } -} - -func optPubkeyAllow(path string) ssh.Option { - return ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { - files, err := ioutil.ReadDir(path) - if err != nil { - log.Fatal(err) - } - - for _, f := range files { - fname := filepath.Join(path, f.Name()) - data, _ := ioutil.ReadFile(fname) - allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data) - if ssh.KeysEqual(key, allowed) { - log.Println("Authorized:", fname) - return true + if u, ok := srv.GetUserByName(s.User()); ok { + host := fmt.Sprintf("%v:%v", u.bindHost, u.bindPort) + director := func(req *http.Request) { + if h := req.Header.Get("X-Forwarded-Host"); h == "" { + req.Header.Set("X-Forwarded-Host", req.Host) + } + req.Header.Set("X-Origin-Host", host) + req.URL.Scheme = "http" + req.URL.Host = host } + u.proxy = &httputil.ReverseProxy{Director: director} + fmt.Fprintf(s, "Created HTTP listener at: %v%v\n", u.name, srv.domainSuffix) } + <-ctx.Done() + + if _, err := fmt.Fprintf(s, "Goodbye! %s\n", s.User()); err != nil { + return + } + } +} + +type server struct { + listenPort uint32 + domainName string + domainSuffix string + bindHost string + + portStart uint32 + portEnd uint32 + portNext uint32 + + ports sync.Map + users sync.Map +} + +type user struct { + name string + pubkey ssh.PublicKey + bindHost string + bindPort uint32 + ctx ssh.Context + proxy http.Handler +} + +func (srv *server) AddUser(pubkey ssh.PublicKey) *user { + u := &user{} + u.pubkey = pubkey + + babbler := babble.NewBabbler() + u.name = strings.ToLower(babbler.Babble()) + + u.bindPort = srv.nextPort() + u.bindHost = srv.bindHost + + srv.users.Store(u.name, u) + + return u +} +func (srv *server) nextPort() uint32 { + if srv.portNext < srv.portStart || srv.portNext > srv.portEnd { + srv.portNext = srv.portStart + } + + defer func() { srv.portNext++ }() + + return srv.portNext +} + +func (srv *server) GetUserByPort(port uint32) (*user, bool) { + if u, ok := srv.ports.Load(port); ok { + if u, ok := u.(*user); ok { + return u, true + } + } + return nil, false +} +func (srv *server) GetUserByName(name string) (*user, bool) { + if u, ok := srv.users.Load(name); ok { + if u, ok := u.(*user); ok { + return u, true + } + } + return nil, false +} +func (srv *server) ListUsers() []*user { + var lis []*user + srv.users.Range(func(key, value interface{}) bool { + if u, ok := value.(*user); ok { + lis = append(lis, u) + return true + } else { + fmt.Println(key, value) + + } return false }) + + return lis +} +func (srv *server) ListConnectedUsers() []*user { + var lis []*user + srv.ports.Range(func(key, value interface{}) bool { + if u, ok := value.(*user); ok { + lis = append(lis, u) + return true + } + return false + }) + + return lis +} +func (srv *server) optAuthUser() []ssh.Option { + return []ssh.Option{ + ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { + u, ok := srv.GetUserByName(ctx.User()) + if !ok { + log.Println("user not found", ctx.User()) + return false + } + + if ssh.KeysEqual(key, u.pubkey) { + log.Println("User: ", ctx.User(), "Authorized:", u.bindHost, u.bindPort, ctx.ClientVersion(), ctx.SessionID(), ctx.LocalAddr(), ctx.RemoteAddr()) + u.ctx = ctx + srv.ports.Store(u.bindPort, u) + return true + } + + return false + }), + func(cfg *ssh.Server) error { + hdlr := ssh.ForwardedTCPHandler{} + + if cfg.RequestHandlers == nil { + cfg.RequestHandlers = make(map[string]ssh.RequestHandler, 2) + } + + cfg.RequestHandlers["tcpip-forward"] = hdlr.HandleSSHRequest + cfg.RequestHandlers["cancel-tcpip-forward"] = hdlr.HandleSSHRequest + + cfg.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + u, ok := srv.GetUserByPort(bindPort) + if !ok { + log.Println("User port", bindPort, "not authorized.") + return false + } + + if u.ctx.SessionID() != ctx.SessionID() { + log.Println("Port", bindPort, "in use by", u.name, u.ctx.SessionID()) + return false + } + + if bindHost != strings.Trim(u.bindHost, "[]") || bindPort != u.bindPort { + log.Println("User", ctx.User(), "Not Allowed: ", bindHost, bindPort, ctx.ClientVersion(), ctx.SessionID(), ctx.LocalAddr(), ctx.RemoteAddr()) + return false + } + + log.Println("User", ctx.User(), "Allow Remote:", bindHost, bindPort, ctx.ClientVersion(), ctx.SessionID(), ctx.LocalAddr(), ctx.RemoteAddr()) + return true + } + + return nil + }, + } +} + +func (srv *server) serveSSH(ctx context.Context, opts ...ssh.Option) func(l net.Listener) error { + return func(l net.Listener) error { + return ssh.Serve( + l, + srv.newSession(ctx), + opts..., + ) + } +} +func (srv *server) serveHTTP(ctx context.Context) func(net.Listener) error { + s := &http.Server{ + ReadTimeout: 2500 * time.Millisecond, + WriteTimeout: 5 * time.Second, + Handler: http.DefaultServeMux, + } + + go func(ctx context.Context) { + <-ctx.Done() + s.Shutdown(context.Background()) + }(ctx) + + return s.Serve +} + +func (srv *server) handleHTTP(rw http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.Host, srv.domainSuffix) { + name := strings.TrimSuffix(r.Host, srv.domainSuffix) + u, ok := srv.GetUserByName(name) + if !ok || u.proxy == nil { + fmt.Fprintln(rw, "NOT FOUND", name) + } + u.proxy.ServeHTTP(rw, r) + return + } + + if r.Method == http.MethodPost { + pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.FormValue("pub"))) + if err != nil { + rw.WriteHeader(400) + fmt.Fprintln(rw, "ERR READING KEY") + return + } + u := srv.AddUser(pubkey) + rw.WriteHeader(201) + fmt.Fprintf(rw, `ssh -T -p %v %v@%v -R "%v:%v:localhost:$LOCAL_PORT" -i $PRIV_KEY`, srv.listenPort, u.name, srv.domainName, u.bindHost, u.bindPort) + return + } + + fmt.Fprintln(rw, "Hello!", r.Host, r.URL) + for _, u := range srv.ListUsers() { + fmt.Fprintln(rw, "User:", u.name) + } + for _, u := range srv.ListConnectedUsers() { + fmt.Fprintln(rw, "Conn:", u.name) + } +} + +// serverMux is mux server which will multiplex a listener to serve an http +// server using the http.DefaultServeMux handler, as well as a grpc server +// to serve a protobuf generated grpc.serverMux +type serverMux struct { + Listener net.Listener + HTTP func(net.Listener) error + SSH func(net.Listener) error + ServeMux *http.ServeMux +} + +func New(lis net.Listener, http, ssh func(net.Listener) error) *serverMux { + return &serverMux{ + Listener: lis, + HTTP: http, + SSH: ssh, + } +} + +// Serve begins serving a multiplexed server. Any errors returned before the stop +// signal is given indicate a failure of the server to start or an unexpected shutdown. +// Serve closes the listener once Shutdown has been triggered +func (m *serverMux) Serve(ctx context.Context) error { + errChanSSH := make(chan error) + errChanHTTP := make(chan error) + + defer func() { + err := multierr.Combine(m.Listener.Close(), <-errChanSSH, <-errChanHTTP) + if err != nil { + log.Println(err) + } + }() + + mux := cmux.New(m.Listener) + httpL := mux.Match(cmux.HTTP1Fast()) + sshL := mux.Match(cmux.Any()) + + go func() { + defer close(errChanSSH) + if err := m.SSH(sshL); err != nil { + switch err { + case cmux.ErrServerClosed: + log.Println("shutting down SSH Server") + default: + errChanSSH <- fmt.Errorf("failed to start SSH: %w", err) + } + } + }() + + go func() { + defer close(errChanHTTP) + if err := m.HTTP(httpL); err != nil { + switch err { + case cmux.ErrServerClosed: + log.Println("shutting down HTTP Server") + default: + errChanHTTP <- fmt.Errorf("failed to start HTTP: %w", err) + } + } + }() + + errChan := make(chan error) + go func() { + defer close(errChan) + err := mux.Serve() + if err != nil { + if strings.Contains(err.Error(), "use of closed network connection") { + log.Println("shutting down mux server") + } else { + errChan <- fmt.Errorf("failed to start server multiplexing: %w", err) + } + } + }() + + log.Println("server started: multiplexed http/1, http/2", + "address", m.Listener.Addr().String(), + "multiplexed", "true", + ) + + defer mux.Close() + + select { + case <-ctx.Done(): + log.Println("stopping multiplexed server gracefully") + return nil + case err := <-errChanSSH: + return err + case err := <-errChanHTTP: + return err + case err := <-errChan: + return err + } }