diff --git a/README.rst b/README.rst index b21156c..b41c240 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ Motivation ========== sshd implementation in Go, for the sole purpose of restricting the ports that -clients can request using direct-tcpip. +clients can request using direct-tcpip and tcpip-forward / forwarded-tcpip. OpenSSH refuses to merge patches to support this, but there is a fork of OpenSSH with patches that achieve something similar to this. [1] diff --git a/sshd.go b/sshd.go index dd39beb..b0ef979 100644 --- a/sshd.go +++ b/sshd.go @@ -18,18 +18,31 @@ import ( "golang.org/x/crypto/ssh" ) +// TODO: Use defer where useful + var ( authorisedKeys map[string]string - /* Global listeners, we keep a global state for cancel-tcpip-forward */ - globalListens map[string]net.Listener - listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections") hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") verbose = flag.Bool("verbose", false, "Enable verbose mode") ) +type sshClient struct { + Name string + Conn *ssh.ServerConn + Listeners map[string]net.Listener + AllowedLocalPorts []uint32 + AllowedRemotePorts []uint32 +} + +type bindInfo struct { + Bound string + Port uint32 + Addr string +} + /* RFC4254 7.2 */ type directTCPPayload struct { Addr string // To connect to @@ -62,8 +75,6 @@ type tcpIpForwardCancelPayload struct { func main() { flag.Parse() - globalListens = make(map[string]net.Listener) - config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if ports, found := authorisedKeys[string(key.Marshal())]; found { @@ -101,6 +112,23 @@ func main() { return } + client := sshClient{"TODO FIXME XXX", sshConn, make(map[string]net.Listener), nil, nil} + + go func() { + err := client.Conn.Wait() + if *verbose { + log.Printf("SSH connection closed for client %s: %s", client.Name, err) + } + // TODO: Make this safe? Is it impossible for cancel code to be + // running at this point? + for bind, listener := range client.Listeners { + if *verbose { + log.Printf("Closing listener bound to %s", bind) + } + listener.Close() + } + }() + allowedPorts := sshConn.Permissions.CriticalOptions["ports"] if *verbose { @@ -110,27 +138,30 @@ func main() { // Parsing a second time should not error, so we can ignore the error // safely ports, _ := parsePorts(allowedPorts) + // TODO: Don't share same port/host limit + client.AllowedLocalPorts = ports + client.AllowedRemotePorts = ports - go handleRequest(sshConn, reqs) + go handleRequest(&client, reqs) - // Accept all channels - go handleChannels(chans, ports) + // Accept all channels (TODO: Pass client) + go handleChannels(&client, chans) }() } } -func handleChannels(chans <-chan ssh.NewChannel, ports []uint32) { +func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) { for c := range chans { - go handleChannel(c, ports) + go handleChannel(client, c) } } -func handleChannel(newChannel ssh.NewChannel, ports []uint32) { +func handleChannel(client *sshClient, newChannel ssh.NewChannel) { if *verbose { log.Println("Channel type:", newChannel.ChannelType()) } if t := newChannel.ChannelType(); t == "direct-tcpip" { - handleDirect(newChannel, ports) + handleDirect(client, newChannel) return } @@ -145,12 +176,12 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { d := make([]byte, 4096) c.Read(d) }() - return */ + return } -func handleDirect(newChannel ssh.NewChannel, ports []uint32) { +func handleDirect(client *sshClient, newChannel ssh.NewChannel) { var payload directTCPPayload if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil { log.Printf("Could not unmarshal extra data: %s\n", err) @@ -166,7 +197,7 @@ func handleDirect(newChannel ssh.NewChannel, ports []uint32) { } ok := false - for _, port := range ports { + for _, port := range client.AllowedLocalPorts { if payload.Port == port { ok = true break @@ -203,11 +234,12 @@ func handleDirect(newChannel ssh.NewChannel, ports []uint32) { serve(connection, rconn) } -func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { +func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) { var payload tcpIpForwardPayload if err := ssh.Unmarshal(req.Payload, &payload); err != nil { log.Println("Unable to unmarshal payload") req.Reply(false, []byte{}) + return nil, nil, fmt.Errorf("Unable to parse payload") } log.Println("Request:", req.Type, req.WantReply, payload) @@ -217,7 +249,7 @@ func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { if payload.Addr != "localhost" { log.Printf("Payload address is not \"localhost\"") req.Reply(false, []byte{}) - return + return nil, nil, fmt.Errorf("Address is not permitted") } // TODO: Check port @@ -233,84 +265,76 @@ func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { if err != nil { log.Printf("Listen failed for %s", bind) req.Reply(false, []byte{}) - return + return nil, nil, err } - globalListens[bind] = ln - // Tell client everything is OK reply := tcpIpForwardPayloadReply{lport} req.Reply(true, ssh.Marshal(&reply)) - // Ensure that we get notified when the client connection is (unexpectedly) - // closed - go func() { - err := conn.Wait() - if *verbose { - log.Printf("SSH connection closed: %s. Stopping listen", err) - } - ln.Close() - delete(globalListens, bind) + return ln, &bindInfo{bind, lport, laddr}, nil - // We don't close existing connections - }() +} +func handleListener(client *sshClient, bindinfo *bindInfo, listener net.Listener) { // Start listening for connections - go func() { - for { - lconn, err := ln.Accept() - if err != nil { - neterr := err.(net.Error) - if neterr.Timeout() { - log.Println("Accept failed with timeout:", err) - continue - } - if neterr.Temporary() { - log.Println("Accept failed with temporary:", err) - continue - } - - break + for { + lconn, err := listener.Accept() + if err != nil { + neterr := err.(net.Error) + if neterr.Timeout() { + log.Println("Accept failed with timeout:", err) + continue + } + if neterr.Temporary() { + log.Println("Accept failed with temporary:", err) + continue } - // TODO: Sep function? - go func() { - remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr) - raddr := remotetcpaddr.IP.String() - rport := uint32(remotetcpaddr.Port) - - payload := forwardedTCPPayload{laddr, lport, raddr, uint32(rport)} - mpayload := ssh.Marshal(&payload) - - // Open channel with client - c, requests, err := conn.OpenChannel("forwarded-tcpip", mpayload) - if err != nil { - log.Printf("Error: %s", err) - log.Println("Unable to get channel. Hanging up requesting party!") - lconn.Close() - return - } - go ssh.DiscardRequests(requests) - - serve(c, lconn) - }() + break } - }() + + // TODO: I don't think a goroutine is required here + go handleForwardTcpIp(client, bindinfo, lconn) + } } -func handleTcpIPForwardCancel(req *ssh.Request) { +func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { + remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr) + raddr := remotetcpaddr.IP.String() + rport := uint32(remotetcpaddr.Port) + + payload := forwardedTCPPayload{bindinfo.Addr, bindinfo.Port, raddr, uint32(rport)} + mpayload := ssh.Marshal(&payload) + + // Open channel with client + c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload) + if err != nil { + log.Printf("Error: %s", err) + log.Println("Unable to get channel. Hanging up requesting party!") + lconn.Close() + return + } + go ssh.DiscardRequests(requests) + + serve(c, lconn) +} + +func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { + if *verbose { + log.Println("Cancel called by client", client) + } var payload tcpIpForwardCancelPayload if err := ssh.Unmarshal(req.Payload, &payload); err != nil { log.Println("Unable to unmarshal cancel payload") req.Reply(false, []byte{}) } - //bound := fmt.Sprintf(":%d", payload.Port) bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) - if listener, found := globalListens[bound]; found { + if listener, found := client.Listeners[bound]; found { listener.Close() - delete(globalListens, bound) + delete(client.Listeners, bound) req.Reply(true, []byte{}) } @@ -401,7 +425,7 @@ func loadAuthorisedKeys(authorisedkeys string) { } } -func handleRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) { +func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { for req := range reqs { if *verbose { log.Println("Out of band request:", req.Type, req.WantReply) @@ -409,10 +433,16 @@ func handleRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) { // RFC4254: 7.1 for forwarding if req.Type == "tcpip-forward" { - handleTcpIpForward(sshConn, req) + listener, bindinfo, err := handleTcpIpForward(client, req) + if err != nil { + continue + } + + client.Listeners[bindinfo.Bound] = listener + go handleListener(client, bindinfo, listener) continue } else if req.Type == "cancel-tcpip-forward" { - handleTcpIPForwardCancel(req) + handleTcpIPForwardCancel(client, req) continue } else { // Discard everything else