sshfwd/server.go

380 lines
8.9 KiB
Go

package main
import (
"context"
"crypto/sha256"
"embed"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"net/http/httputil"
"strings"
"sync"
"text/template"
"time"
"github.com/gliderlabs/ssh"
"github.com/wolfeidau/humanhash"
)
var (
// files is the embedded filesystem containing the pages/, layouts/ and
// assets/ trees selected by the embed directive below.
//go:embed pages/* layouts/* assets/*
files embed.FS
// templates maps a page file name (e.g. "home.go.tpl") to its parsed
// template. It is nil until loadTemplates has run.
templates map[string]*template.Template
)
// user is one registered account: a name derived from a public key plus the
// forwarding address assigned to it and its current connection state.
type user struct {
// Name is the account name; addUser derives it from the public key.
Name string
// Pubkey is the key stored at registration; optAuthUser compares login
// keys against it.
Pubkey ssh.PublicKey
// BindHost and BindPort are the forwarding address assigned by addUser.
BindHost string
BindPort uint32
// ctx is the SSH connection context recorded on successful authentication
// and cleared by disconnectUser; nil means the user is not connected.
ctx ssh.Context
// proxy is the reverse proxy installed by newSession and used by
// handleHTTP; it is nil while the user is disconnected.
proxy http.Handler
// LastLogin is set when the user is created and refreshed on each
// successful authentication.
LastLogin time.Time
}
// Active reports whether the user currently has an SSH connection context
// attached, i.e. has authenticated and not yet been disconnected.
func (u *user) Active() bool {
	connected := u.ctx != nil
	return connected
}
// String renders a multi-line, human-readable dump of the user: name, pointer
// identity, public key, bind address, connection state and last login time.
func (u *user) String() string {
	connected := u.ctx != nil

	var out strings.Builder
	out.WriteString(fmt.Sprintln("User: ", u.Name))
	out.WriteString(fmt.Sprintf(" Ptr: %p\n", u))
	out.WriteString(fmt.Sprintf(" Pubkey: %x\n", u.Pubkey))
	out.WriteString(fmt.Sprintln(" Host: ", u.BindHost))
	out.WriteString(fmt.Sprintln(" Port: ", u.BindPort))
	out.WriteString(fmt.Sprintf(" Active: %t\n", connected))
	out.WriteString(fmt.Sprintln(" LastLog:", u.LastLogin))
	return out.String()
}
// server holds the static configuration and the runtime registries shared by
// the SSH and HTTP handlers in this file.
type server struct {
// listenPort is the port printed in the generated ssh command and passed
// to the home page template.
listenPort uint32
// domainName is the host printed in the generated ssh command and passed
// to the home page template.
domainName string
// domainSuffix is the Host suffix that handleHTTP strips to obtain a user
// name; it is also appended to the user name in the session banner.
domainSuffix string
// bindHost is copied into each new user's BindHost by addUser.
bindHost string
// portStart and portEnd are the inclusive bounds of the ports handed out
// by nextPort; portNext is the next candidate.
portStart uint32
portEnd uint32
portNext uint32
// ports maps a bound port (uint32) to the *user authenticated on it.
ports sync.Map
// users maps a user name (string) to its *user record.
users sync.Map
}
// String renders a multi-line, human-readable summary of the server
// configuration: domain, listen port, suffix, bind host and port allocation
// state.
func (srv *server) String() string {
	lines := []string{
		fmt.Sprintln("Server: ", srv.domainName),
		fmt.Sprintln(" Port: ", srv.listenPort),
		fmt.Sprintln(" Suffix: ", srv.domainSuffix),
		fmt.Sprintln(" BindHost: ", srv.bindHost),
		fmt.Sprintf(" PortRange: %d-%d\n", srv.portStart, srv.portEnd),
		fmt.Sprintln(" NextPort: ", srv.portNext),
	}
	return strings.Join(lines, "")
}
// User Operations
// addUser registers the account for pubkey and returns its record. The
// account name is the human-readable fingerprint of the key, lower-cased and
// passed through filterName (declared elsewhere in the package; presumably a
// regexp stripping disallowed characters — confirm). If a record with that
// name already exists it is returned unchanged.
//
// The record is fully populated before it is published in srv.users, so other
// goroutines never observe a user without a key or bind address.
func (srv *server) addUser(pubkey ssh.PublicKey) *user {
	name := fingerprintHuman(pubkey)
	name = strings.ToLower(name)
	name = filterName.ReplaceAllString(name, "")

	// Fast path: an existing registration must not consume another port.
	if existing, ok := srv.users.Load(name); ok {
		return existing.(*user)
	}

	u := &user{
		Name:      name,
		Pubkey:    pubkey,
		BindHost:  srv.bindHost,
		BindPort:  srv.nextPort(),
		LastLogin: time.Now(),
	}

	// A concurrent registration may have won between Load and here; keep the
	// winner so every caller shares one record (the port taken above is then
	// simply skipped).
	if existing, loaded := srv.users.LoadOrStore(name, u); loaded {
		return existing.(*user)
	}
	return u
}
// disconnectUser marks the named user as offline: it clears the stored
// connection context and reverse proxy and releases the user's entry in the
// port registry. Unknown names are ignored.
func (srv *server) disconnectUser(name string) {
	u, found := srv.getUserByName(name)
	if !found {
		return
	}
	u.ctx = nil
	u.proxy = nil
	srv.ports.Delete(u.BindPort)
}
// getUserByPort looks up the user registered on port in the port registry.
// Every hit is logged; an entry that is not a *user is logged and reported as
// a miss.
func (srv *server) getUserByPort(port uint32) (*user, bool) {
	entry, found := srv.ports.Load(port)
	if !found {
		return nil, false
	}
	log.Printf("%d %T %s", port, entry, entry)

	u, isUser := entry.(*user)
	if !isUser {
		log.Println("port not found", port, isUser)
		return nil, false
	}
	return u, true
}
// getUserByName looks up the user record stored under name. An entry that is
// not a *user is logged and reported as a miss; a plain miss is silent.
func (srv *server) getUserByName(name string) (*user, bool) {
	entry, found := srv.users.Load(name)
	if !found {
		return nil, false
	}

	u, isUser := entry.(*user)
	if !isUser {
		log.Println("user not found", name, isUser)
		return nil, false
	}
	return u, true
}
// listUsers returns every registered user, in the iteration order of the
// underlying sync.Map. Iteration stops at the first entry that is not a
// *user; that entry is printed to stdout.
func (srv *server) listUsers() []*user {
	var found []*user
	collect := func(key, value any) bool {
		u, isUser := value.(*user)
		if !isUser {
			fmt.Println(key, value)
			return false
		}
		found = append(found, u)
		return true
	}
	srv.users.Range(collect)
	return found
}
// listPorts returns a snapshot of the port registry as a plain map from port
// to user. Iteration stops at the first value that is not a *user; that entry
// is printed to stdout. Keys are asserted to be uint32 without a check.
func (srv *server) listPorts() map[uint32]*user {
	snapshot := make(map[uint32]*user)
	collect := func(key, value any) bool {
		u, isUser := value.(*user)
		if !isUser {
			fmt.Println(key, value)
			return false
		}
		snapshot[key.(uint32)] = u
		return true
	}
	srv.ports.Range(collect)
	return snapshot
}
// nextPort hands out the next port in [portStart, portEnd], wrapping back to
// portStart once the range is exhausted (or when portNext is outside the
// range, as it is before the first call). It does no locking of its own.
func (srv *server) nextPort() uint32 {
	outOfRange := srv.portNext < srv.portStart || srv.portNext > srv.portEnd
	if outOfRange {
		srv.portNext = srv.portStart
	}
	port := srv.portNext
	srv.portNext++
	return port
}
// SSH Operations
// serveSSH returns a function that serves SSH on the given listener, using
// the session handler built by newSession and the supplied server options.
func (srv *server) serveSSH(ctx context.Context, opts ...ssh.Option) func(l net.Listener) error {
	serve := func(l net.Listener) error {
		handler := srv.newSession(ctx)
		return ssh.Serve(l, handler, opts...)
	}
	return serve
}
// newSession returns the SSH session handler. For each session it greets the
// user, installs a reverse proxy on the user's record that forwards HTTP
// requests to localhost:<BindPort> (echoing a dump of every proxied request
// back over the SSH session), then blocks until either the server context or
// the session context is done. On exit the user is disconnected and a
// farewell line is written.
func (srv *server) newSession(ctx context.Context) func(ssh.Session) {
	return func(s ssh.Session) {
		_, err := fmt.Fprintf(s, "Hello %s\n", s.User())
		if err != nil {
			return
		}

		u, found := srv.getUserByName(s.User())
		if found {
			host := fmt.Sprintf("%v:%v", "localhost", u.BindPort)
			u.proxy = &httputil.ReverseProxy{
				Director: func(req *http.Request) {
					// Preserve the public host name the client asked for
					// unless something upstream already recorded it.
					if req.Header.Get("X-Forwarded-Host") == "" {
						req.Header.Set("X-Forwarded-Host", req.Host)
					}
					req.Header.Set("X-Origin-Host", host)
					req.URL.Scheme = "http"
					req.URL.Host = host

					// Bodies are only included in the dump for POST and PUT.
					withBody := req.Method == http.MethodPost || req.Method == http.MethodPut
					dump, dumpErr := httputil.DumpRequest(req, withBody)
					if dumpErr != nil {
						fmt.Println(dumpErr)
					}
					fmt.Fprintln(s, string(dump))
				},
			}
			fmt.Fprintf(s, "Created HTTP listener at: %v%v\n\n", u.Name, srv.domainSuffix)
		}

		// Hold the session open until the server stops or the client leaves.
		select {
		case <-ctx.Done():
			log.Println("server shutting down")
		case <-s.Context().Done():
			log.Println("user", s.User(), "disconnected")
		}

		srv.disconnectUser(s.User())
		// The farewell is best effort; nothing follows it either way.
		_, _ = fmt.Fprintf(s, "Goodbye! %s\n", s.User())
	}
}
// optAuthUser returns the SSH server options that implement authentication
// and reverse-port-forwarding policy:
//
//   - a public-key handler that accepts a login only when the user name is
//     registered, the offered key equals the registered key, and the user's
//     port is not already claimed in the port registry;
//   - a server option that installs the tcpip-forward request handlers and a
//     callback allowing a forward only for "localhost" on the port owned by
//     the requesting session.
func (srv *server) optAuthUser() []ssh.Option {
	return []ssh.Option{
		ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
			u, ok := srv.getUserByName(ctx.User())
			if !ok {
				log.Println("user not found", ctx.User())
				return false
			}
			if !ssh.KeysEqual(key, u.Pubkey) {
				return false
			}

			// Claim the port BEFORE touching the user record. Updating
			// u.ctx first would let a rejected duplicate login overwrite the
			// context of the session that is already connected, which then
			// fails the session-ID check in the forwarding callback below.
			if _, loaded := srv.ports.LoadOrStore(u.BindPort, u); loaded {
				log.Println("User:", ctx.User(), "already connected!")
				return false
			}

			u.ctx = ctx
			u.LastLogin = time.Now()
			log.Println("User:", ctx.User(), "Authorized:", u.BindHost, u.BindPort, ctx.ClientVersion(), ctx.SessionID(), ctx.LocalAddr(), ctx.RemoteAddr())
			return true
		}),
		func(cfg *ssh.Server) error {
			hdlr := ssh.ForwardedTCPHandler{}
			if cfg.RequestHandlers == nil {
				cfg.RequestHandlers = make(map[string]ssh.RequestHandler, 2)
			}
			cfg.RequestHandlers["tcpip-forward"] = hdlr.HandleSSHRequest
			cfg.RequestHandlers["cancel-tcpip-forward"] = hdlr.HandleSSHRequest
			cfg.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
				u, ok := srv.getUserByPort(bindPort)
				if !ok {
					log.Println("User port", bindPort, "not authorized.")
					return false
				}

				// Read the owner once: disconnectUser may clear u.ctx at any
				// time, and a nil context must not be dereferenced.
				owner := u.ctx
				if owner == nil {
					log.Println("Port", bindPort, "has no active session for", u.Name)
					return false
				}
				if owner.SessionID() != ctx.SessionID() {
					log.Println("Port", bindPort, "in use by", u.Name, owner.SessionID())
					return false
				}
				if bindHost != "localhost" || bindPort != u.BindPort {
					log.Println("User", ctx.User(), "Not Allowed: ", bindHost, bindPort, ctx.ClientVersion(), ctx.SessionID(), ctx.LocalAddr(), ctx.RemoteAddr())
					return false
				}
				log.Println("User", ctx.User(), "Allow Remote:", bindHost, bindPort, ctx.ClientVersion(), ctx.SessionID(), ctx.LocalAddr(), ctx.RemoteAddr())
				return true
			}
			return nil
		},
	}
}
// HTTP Operations
// serveHTTP builds an HTTP server around http.DefaultServeMux and returns its
// Serve method. A background goroutine waits for ctx to be cancelled and then
// shuts the server down; the Shutdown error is discarded.
func (srv *server) serveHTTP(ctx context.Context) func(net.Listener) error {
	baseContext := func(net.Listener) context.Context { return ctx }

	httpSrv := &http.Server{
		Handler:      http.DefaultServeMux,
		BaseContext:  baseContext,
		ReadTimeout:  2500 * time.Millisecond,
		WriteTimeout: 5 * time.Second,
	}

	go func() {
		<-ctx.Done()
		httpSrv.Shutdown(context.Background())
	}()

	return httpSrv.Serve
}
// handleHTTP is the single HTTP entry point. It dispatches in this order:
//
//  1. Host ends with the configured domain suffix: the prefix is treated as a
//     user name and the request is forwarded through that user's reverse
//     proxy. Unknown or disconnected users get a 404.
//  2. POST: the "pub" form value is parsed as an authorized_keys line, the
//     user is registered, and the ssh command to connect is written in the
//     body of a redirect to "/".
//  3. Paths under /assets/: served from the embedded assets directory.
//  4. Anything else: the home page template.
func (srv *server) handleHTTP(rw http.ResponseWriter, r *http.Request) {
	if strings.HasSuffix(r.Host, srv.domainSuffix) {
		name := strings.TrimSuffix(r.Host, srv.domainSuffix)

		// Read the proxy once: disconnectUser may clear u.proxy between a
		// nil check and the call.
		var proxy http.Handler
		if u, ok := srv.getUserByName(name); ok {
			proxy = u.proxy
		}
		if proxy == nil {
			// Must return here; falling through would dereference a nil
			// user or nil proxy. http.Error also forces text/plain so the
			// Host-derived name is never sniffed as markup.
			http.Error(rw, "NOT FOUND "+name, http.StatusNotFound)
			return
		}
		proxy.ServeHTTP(rw, r)
		return
	}

	if r.Method == http.MethodPost {
		pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.FormValue("pub")))
		if err != nil {
			rw.WriteHeader(http.StatusBadRequest)
			fmt.Fprintln(rw, "ERR READING KEY", err)
			return
		}
		u := srv.addUser(pubkey)
		rw.Header().Set("Location", "/")
		rw.WriteHeader(http.StatusFound)
		fmt.Fprintf(rw, `ssh -T -p %v %v@%v -R "%v:%v:localhost:$LOCAL_PORT" -i $PRIV_KEY`+"\n", srv.listenPort, u.Name, srv.domainName, u.BindHost, u.BindPort)
		return
	}

	if strings.HasPrefix(r.URL.Path, "/assets/") {
		assetFS, err := fs.Sub(files, "assets")
		if err != nil {
			log.Println(err)
			http.Error(rw, "assets unavailable", http.StatusInternalServerError)
			return
		}
		http.StripPrefix("/assets/", http.FileServer(http.FS(assetFS))).ServeHTTP(rw, r)
		return
	}

	// A missing template (loadTemplates not run, or page absent) would
	// otherwise panic inside Execute on a nil *Template.
	t := templates["home.go.tpl"]
	if t == nil {
		log.Println(`template "home.go.tpl" is not loaded`)
		http.Error(rw, "home page unavailable", http.StatusInternalServerError)
		return
	}
	err := t.Execute(rw, map[string]any{
		"Users":      srv.listUsers(),
		"Ports":      srv.listPorts(),
		"ListenPort": srv.listenPort,
		"DomainName": srv.domainName,
	})
	if err != nil {
		log.Println(err)
	}
}
// fingerprintHuman turns a public key into a short human-readable name by
// hashing the marshalled key with SHA-256 and passing the digest to
// humanhash.Humanize with a word count of 3. The error from Humanize is
// discarded, so on failure the result is whatever string Humanize returned
// (presumably empty — confirm against the library).
func fingerprintHuman(pubKey ssh.PublicKey) string {
	digest := sha256.Sum256(pubKey.Marshal())
	words, _ := humanhash.Humanize(digest[:], 3)
	return words
}
// funcMap is the set of helper functions registered on every page template by
// loadTemplates. It is currently empty.
var funcMap = map[string]any{}
// loadTemplates parses every file in the embedded pages/ directory, together
// with all layouts/*.go.tpl files, and stores the results in the package-level
// templates map keyed by page file name. It is a no-op once templates has been
// populated.
//
// Templates are collected in a local map and published only after every page
// parsed successfully, so a failed load leaves templates nil and a later call
// retries instead of silently reporting success with a partial set.
func loadTemplates() error {
	if templates != nil {
		return nil
	}

	entries, err := fs.ReadDir(files, "pages")
	if err != nil {
		return fmt.Errorf("reading pages directory: %w", err)
	}

	loaded := make(map[string]*template.Template, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()

		// The template is named after the page file so that ParseFS binds
		// the page's content to it; layouts are parsed alongside as
		// associated templates.
		pt, err := template.New(name).Funcs(funcMap).ParseFS(files, "pages/"+name, "layouts/*.go.tpl")
		if err != nil {
			err = fmt.Errorf("parsing template %q: %w", name, err)
			log.Println(err)
			return err
		}
		loaded[name] = pt
	}

	templates = loaded
	return nil
}