diff --git a/TODO b/TODO index 95dd7d1..f3489c2 100644 --- a/TODO +++ b/TODO @@ -1,4 +1,9 @@ * Make sure to not run this as root (setuid doesn't work well), so use NET capabilities * Check assertions and TODOs. -* Put ports to be allowed in options field, not in comments. +* Look if/where we want to set deadlines on open sockets +* Go through all log.Println calls, and make sure they are unique(?) and + sensible, and are not too verbose, and/or hidden behind *verbose +* FILTER for forwarded ports +* Change format of authorized_keys to allow for both forwarded and direct filtering * Put device identifier in comments. +* Add some client identifier to log messages diff --git a/sshd.go b/sshd.go index 07016cd..bfbe4f9 100644 --- a/sshd.go +++ b/sshd.go @@ -4,10 +4,6 @@ package main // Merlijn B. W. Wajer // (C) 2017 -// Trivial parts taken from: -// * https://blog.gopheracademy.com/advent-2015/ssh-server-in-go/ -// * https://github.com/tg123/sshpiper/commit/9db468b52dfc2cbe936efb7bef0fd5b88e0c1649 - import ( "flag" "fmt" @@ -25,15 +21,49 @@ import ( var ( authorisedKeys map[string]string + /* Global listeners, we keep a global state for cancel-tcpip-forward */ + globalListens map[string]net.Listener + listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections") hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") verbose = flag.Bool("verbose", false, "Enable verbose mode") ) +/* RFC4254 7.2 */ +type directTCPPayload struct { + Addr string // To connect to + Port uint32 + OriginAddr string + OriginPort uint32 +} + +type forwardedTCPPayload struct { + Addr string // Is connected to + Port uint32 + OriginAddr string + OriginPort uint32 +} + +type tcpIpForwardPayload struct { + Addr string + Port uint32 +} + +type tcpIpForwardPayloadReply struct { + Port uint32 +} + +type tcpIpForwardCancelPayload struct { + Addr string + Port uint32 +} + func main() { flag.Parse() + globalListens = make(map[string]net.Listener) + config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if ports, found := authorisedKeys[string(key.Marshal())]; found { @@ -46,17 +76,7 @@ func main() { }, } - privateBytes, err := ioutil.ReadFile(*hostkey) - if err != nil { - log.Fatal("Failed to load private key (./id_rsa)") - } - - private, err := ssh.ParsePrivateKey(privateBytes) - if err != nil { - log.Fatal("Failed to parse private key") - } - - config.AddHostKey(private) + loadHostKeys(config) loadAuthorisedKeys(*authorisedkeys) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", *listenport)) @@ -85,11 +105,31 @@ func main() { log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts) } - // Parsing a second time should not error + // Parsing a second time should not error, so we can ignore the error + // safely ports, _ := parsePorts(allowedPorts) - // Discard all global out-of-band Requests - go ssh.DiscardRequests(reqs) + // Handle global out-of-band Requests + go func() { + for req := range reqs { + if *verbose { + log.Println("Out of band request:", req.Type, req.WantReply) + } + + // RFC4254: 7.1 for forwarding + if req.Type == "tcpip-forward" { + handleTcpIpForward(sshConn, req) + continue + } else if req.Type == "cancel-tcpip-forward" { + handleTcpIPForwardCancel(req) + continue + } else { + // Discard everything else + req.Reply(false, []byte{}) + } + } + }() + // Accept all channels go handleChannels(chans, ports) } @@ -101,23 +141,32 @@ func handleChannels(chans <-chan ssh.NewChannel, ports []uint32) { } } -/* RFC4254 7.2 */ -type directTCPPayload struct { - Addr string - Port uint32 - OriginAddr string - OriginPort uint32 -} - func handleChannel(newChannel ssh.NewChannel, ports []uint32) { if *verbose { log.Println("Channel type:", newChannel.ChannelType()) } - if t := newChannel.ChannelType(); t != "direct-tcpip" { - newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted")) + if t := newChannel.ChannelType(); t == "direct-tcpip" { + handleDirect(newChannel, ports) return } + newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted")) + /* + // TODO: USE THIS ONLY FOR USING SSH ESCAPE SEQUENCES + c, _, err := newChannel.Accept() + if err != nil { + log.Fatal(err) + } + go func() { + d := make([]byte, 4096) + c.Read(d) + }() + return + */ + +} + +func handleDirect(newChannel ssh.NewChannel, ports []uint32) { var payload directTCPPayload if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil { log.Printf("Could not unmarshal extra data: %s\n", err) @@ -127,6 +176,7 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { } if payload.Addr != "localhost" { + log.Printf("Tried to connect to prohibited host: %s", payload.Addr) newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr")) return } @@ -141,22 +191,22 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { if !ok { newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad port")) - log.Printf("Tried to forward prohibited port: %d", payload.Port) + log.Printf("Tried to connect to prohibited port: %d", payload.Port) return } - // At this point, we have the opportunity to reject the client's + // At this point, we have the opportunity to reject the clients // request for another logical connection connection, requests, err := newChannel.Accept() - _ = requests // TODO: Think we can just ignore these if err != nil { log.Printf("Could not accept channel (%s)", err) return } + go ssh.DiscardRequests(requests) addr := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) if *verbose { - log.Println("Going to dial:", addr) + log.Println("Dialing:", addr) } rconn, err := net.Dial("tcp", addr) @@ -166,21 +216,134 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { return } + serve(connection, rconn) +} + +func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { + var payload tcpIpForwardPayload + if err := ssh.Unmarshal(req.Payload, &payload); err != nil { + log.Println("Unable to unmarshal payload") + req.Reply(false, []byte{}) + } + + log.Println("Request:", req.Type, req.WantReply, payload) + + log.Printf("Request to listen on %s:%d", payload.Addr, payload.Port) + + if payload.Addr != "localhost" { + log.Printf("Payload address is not \"localhost\"") + req.Reply(false, []byte{}) + return + } + + // TODO: Check port + + laddr := payload.Addr + lport := payload.Port + + // TODO: We currently bind to localhost:port, and not to :port + // Need to figure out what we want - perhaps just part of policy + bind := fmt.Sprintf("%s:%d", laddr, lport) + ln, err := net.Listen("tcp", bind) + if err != nil { + log.Printf("Listen failed for %s", bind) + req.Reply(false, []byte{}) + return + } + + globalListens[bind] = ln + + // Tell client everything is OK + reply := tcpIpForwardPayloadReply{lport} + req.Reply(true, ssh.Marshal(&reply)) + + // Ensure that we get notified when the client connection is (unexpectedly) + // closed + go func() { + err := conn.Wait() + if *verbose { + log.Printf("SSH connection closed: %s. Stopping listen", err) + } + ln.Close() + delete(globalListens, bind) + + // We don't close existing connections + }() + + // Start listening for connections + go func() { + for { + lconn, err := ln.Accept() + if err != nil { + log.Println("Accept failed") + break + } + + go func() { + remoteaddr := lconn.RemoteAddr().String() + + p_index := strings.LastIndex(remoteaddr, ":") + raddr := remoteaddr[:p_index] + rport, err := strconv.ParseUint(remoteaddr[p_index+1:], 10, 32) + if err != nil { + log.Printf("Unable to parse RemoteAddr! (%s)", err) + lconn.Close() + return + } + + payload := forwardedTCPPayload{laddr, lport, raddr, uint32(rport)} + mpayload := ssh.Marshal(&payload) + + // Open channel with client + c, requests, err := conn.OpenChannel("forwarded-tcpip", mpayload) + if err != nil { + log.Printf("Error: %s", err) + log.Println("Unable to get channel. Hanging up requesting party!") + lconn.Close() + return + } + go ssh.DiscardRequests(requests) + + serve(c, lconn) + }() + } + }() +} + +func handleTcpIPForwardCancel(req *ssh.Request) { + var payload tcpIpForwardCancelPayload + if err := ssh.Unmarshal(req.Payload, &payload); err != nil { + log.Println("Unable to unmarshal cancel payload") + req.Reply(false, []byte{}) + } + + bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) + + if listener, found := globalListens[bound]; found { + listener.Close() + delete(globalListens, bound) + req.Reply(true, []byte{}) + } + + req.Reply(false, []byte{}) +} + +func serve(cssh ssh.Channel, conn net.Conn) { close := func() { - connection.Close() - rconn.Close() + cssh.Close() + conn.Close() if *verbose { - log.Printf("Session closed") + log.Printf("Channel closed") } } var once sync.Once go func() { - io.Copy(connection, rconn) + io.Copy(cssh, conn) once.Do(close) }() go func() { - io.Copy(rconn, connection) + io.Copy(conn, cssh) once.Do(close) }() } @@ -197,6 +360,20 @@ func parsePorts(portstr string) (p []uint32, err error) { return } +func loadHostKeys(config *ssh.ServerConfig) { + privateBytes, err := ioutil.ReadFile(*hostkey) + if err != nil { + log.Fatal("Failed to load private key (./id_rsa)") + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key") + } + + config.AddHostKey(private) +} + func loadAuthorisedKeys(authorisedkeys string) { authorisedKeys = map[string]string{} authorisedKeysBytes, err := ioutil.ReadFile(authorisedkeys) @@ -213,13 +390,13 @@ func loadAuthorisedKeys(authorisedkeys string) { log.Println("Options:", options) if len(options) != 1 { - log.Fatal(fmt.Errorf("Only one option is accepted: 'ports=...'")) + log.Fatal(fmt.Errorf("Only one option is accepted: \"ports=...\"")) } option := options[0] if !strings.HasPrefix(option, "ports=") { - log.Fatal(fmt.Errorf("Options does not start with 'ports='")) + log.Fatal(fmt.Errorf("Options does not start with \"ports=\"")) } ports := option[len("ports="):]