package main // Copyright: // Merlijn B. W. Wajer // (C) 2017 import ( "flag" "fmt" "io" "io/ioutil" "log" "net" "strconv" "strings" "sync" "time" "golang.org/x/crypto/ssh" ) var ( // Contains a mapping of authorised keys to permissions for said key authorisedKeys map[string]deviceInfo listenaddr = flag.String("listenaddr", "0.0.0.0", "Addr to listen on for incoming ssh connections") listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections") hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") verbose = flag.Bool("verbose", false, "Enable verbose mode") debug = flag.Bool("debug", false, "Enable debug mode") // Currently the timeouts are not separate for read and write deadlines. // This could be done, but I currently don't really see a reason for this. maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout") directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout") forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout") // Mutex protecting 'authorisedKeys' map authmutex sync.Mutex ) // Structure that holds all information for each connection/client type sshClient struct { Name string // We keep track of the normal Conn as well so that we have access to the // SetDeadline() methods Conn net.Conn SshConn *ssh.ServerConn // Listener sockets opened by the client Listeners map[string]net.Listener AllowedLocalPorts []uint32 AllowedRemotePorts []uint32 // This indicates that a client is shutting down. When a client is stopping, // we do not allow new listening requests, to prevent a listener connection // being opened just after we closed all of them. Stopping bool ListenMutex sync.Mutex } // Structure containing what address/port we should bind on, for forwarded-tcpip // connections type bindInfo struct { Bound string Port uint32 Addr string } // Information parsed from the authorized_keys file type deviceInfo struct { LocalPorts string RemotePorts string Comment string } /* RFC4254 7.2 */ type directTCPPayload struct { Addr string // To connect to Port uint32 OriginAddr string OriginPort uint32 } type forwardedTCPPayload struct { Addr string // Is connected to Port uint32 OriginAddr string OriginPort uint32 } type tcpIpForwardPayload struct { Addr string Port uint32 } type tcpIpForwardPayloadReply struct { Port uint32 } type tcpIpForwardCancelPayload struct { Addr string Port uint32 } // Function that can be used to implement calls to SetDeadline() after // read/writes in copyTimeout() type TimeoutFunc func() func main() { flag.Parse() config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { authmutex.Lock() defer authmutex.Unlock() if clientinfo, found := authorisedKeys[string(key.Marshal())]; found { return &ssh.Permissions{ CriticalOptions: map[string]string{"name": clientinfo.Comment, "localports": clientinfo.LocalPorts, "remoteports": clientinfo.RemotePorts}, }, nil } return nil, fmt.Errorf("Unknown public key\n") }, } loadHostKeys(config) loadAuthorisedKeys(*authorisedkeys) registerReloadSignal() bind := fmt.Sprintf("[%s]:%d", *listenaddr, *listenport) listener, err := net.Listen("tcp", bind) if err != nil { log.Fatalf("Failed to listen on %s (%s)", listenport, err) } // Accept all connections log.Printf("Listening on %d...", *listenport) for { tcpConn, err := listener.Accept() if err != nil { log.Printf("Failed to accept incoming connection (%s)", err) continue } tcpConn.SetDeadline(time.Now().Add(*maintimeout)) // We perform the ssh handshake in a goroutine so the handshake cannot // block incoming connections. go func() { sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config) if err != nil { log.Printf("Failed to handshake: %s (rip: %v)", err, tcpConn.RemoteAddr()) return } client := sshClient{sshConn.Permissions.CriticalOptions["name"], tcpConn, sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}} allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"] allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"] if *verbose { log.Printf("[%s] Connection from %s (%s). Allowed local ports: %s remote ports: %s", client.Name, sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedLocalPorts, allowedRemotePorts) } // Parsing a second time should not error, so we can ignore the error // safely client.AllowedLocalPorts, _ = parsePorts(allowedLocalPorts) client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts) // Start the clean-up function: will wait for the socket to be // closed (either by remote, protocol or deadline/timeout) // and close any listeners if any go func() { err := client.SshConn.Wait() client.ListenMutex.Lock() defer client.ListenMutex.Unlock() client.Stopping = true if *verbose { log.Printf("[%s] SSH connection closed: %s", client.Name, err) } for bind, listener := range client.Listeners { if *verbose { log.Printf("[%s] Closing listener bound to %s", client.Name, bind) } listener.Close() } }() // Accept requests & channels go handleRequest(&client, reqs) go handleChannels(&client, chans) }() } } func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) { for c := range chans { go handleChannel(client, c) } } func handleChannel(client *sshClient, newChannel ssh.NewChannel) { if *debug { log.Printf("[%s] Channel type: %v", client.Name, newChannel.ChannelType()) } if t := newChannel.ChannelType(); t == "direct-tcpip" { handleDirect(client, newChannel) return } newChannel.Reject(ssh.Prohibited, "Only \"direct-tcpip\" is accepted") /* // XXX: Use this only for testing purposes -- I add this in if/when I // want to use the ssh escape sequences from ssh (those only work in an // interactive session) c, _, err := newChannel.Accept() if err != nil { log.Fatal(err) } go func() { d := make([]byte, 4096) c.Read(d) }() */ return } func handleDirect(client *sshClient, newChannel ssh.NewChannel) { var payload directTCPPayload if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil { log.Printf("[%s] Could not unmarshal extra data: %s", client.Name, err) newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad payload")) return } /* // XXX: Is this sensible? if payload.Addr != "localhost" && payload.Addr != "::1" && payload.Addr != "127.0.0.1" { log.Printf("[%s] Tried to connect to prohibited host: %s", client.Name, payload.Addr) newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr")) return } */ if !portPermitted(payload.Port, client.AllowedLocalPorts) { newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad port")) log.Printf("[%s] Tried to connect to prohibited port: %d", client.Name, payload.Port) return } connection, requests, err := newChannel.Accept() if err != nil { log.Printf("[%s] Could not accept channel (%s)", client.Name, err) return } go ssh.DiscardRequests(requests) addr := fmt.Sprintf("[%s]:%d", payload.Addr, payload.Port) if *verbose { log.Printf("[%s] Dialing: %s", client.Name, addr) } rconn, err := net.Dial("tcp", addr) if err != nil { log.Printf("[%s] Could not dial remote (%s)", client.Name, err) connection.Close() return } serve(connection, rconn, client, *directtimeout) } func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) { var payload tcpIpForwardPayload if err := ssh.Unmarshal(req.Payload, &payload); err != nil { log.Printf("[%s] Unable to unmarshal payload", client.Name) req.Reply(false, []byte{}) return nil, nil, fmt.Errorf("Unable to parse payload") } if *verbose { log.Printf("[%s] Request: %s %v %v", client.Name, req.Type, req.WantReply, payload) log.Printf("[%s] Request to listen on %s:%d", client.Name, payload.Addr, payload.Port) } if payload.Addr != "localhost" && payload.Addr != "" { log.Printf("[%s] Payload address is not \"localhost\" or empty: %s", client.Name, payload.Addr) req.Reply(false, []byte{}) return nil, nil, fmt.Errorf("Address is not permitted") } if !portPermitted(payload.Port, client.AllowedRemotePorts) { log.Printf("[%s] Port is not permitted: %d", client.Name, payload.Port) req.Reply(false, []byte{}) return nil, nil, fmt.Errorf("Port is not permitted") } laddr := payload.Addr lport := payload.Port bind := fmt.Sprintf("[%s]:%d", laddr, lport) ln, err := net.Listen("tcp", bind) if err != nil { log.Printf("[%s] Listen failed for %s", client.Name, bind) req.Reply(false, []byte{}) return nil, nil, err } // Tell client everything is OK reply := tcpIpForwardPayloadReply{lport} req.Reply(true, ssh.Marshal(&reply)) return ln, &bindInfo{bind, lport, laddr}, nil } func handleListener(client *sshClient, bindinfo *bindInfo, listener net.Listener) { // Start listening for connections for { lconn, err := listener.Accept() if err != nil { neterr := err.(net.Error) if neterr.Timeout() { log.Printf("[%s] Accept failed with timeout: %s", client.Name, err) continue } if neterr.Temporary() { log.Printf("[%s] Accept failed with temporary: %s", client.Name, err) continue } break } go handleForwardTcpIp(client, bindinfo, lconn) } } func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr) raddr := remotetcpaddr.IP.String() rport := uint32(remotetcpaddr.Port) payload := forwardedTCPPayload{bindinfo.Addr, bindinfo.Port, raddr, uint32(rport)} mpayload := ssh.Marshal(&payload) // Open channel with client c, requests, err := client.SshConn.OpenChannel("forwarded-tcpip", mpayload) if err != nil { log.Printf("[%s] Unable to get channel: %s. Hanging up requesting party!", client.Name, err) lconn.Close() return } if *verbose { log.Printf("[%s] Channel opened for client", client.Name) } go ssh.DiscardRequests(requests) serve(c, lconn, client, *forwardedtimeout) } func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { if *verbose { log.Printf("[%s] \"cancel-tcpip-forward\" called by client", client.Name) } var payload tcpIpForwardCancelPayload if err := ssh.Unmarshal(req.Payload, &payload); err != nil { log.Printf("[%s] Unable to unmarshal cancel payload", client.Name) req.Reply(false, []byte{}) } bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) if listener, found := client.Listeners[bound]; found { listener.Close() delete(client.Listeners, bound) req.Reply(true, []byte{}) } req.Reply(false, []byte{}) } func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { for req := range reqs { client.Conn.SetDeadline(time.Now().Add(*maintimeout)) if *debug { log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply) } // RFC4254: 7.1 for forwarding if req.Type == "tcpip-forward" { client.ListenMutex.Lock() /* If we are closing, do not set up a new listener */ if client.Stopping { client.ListenMutex.Unlock() req.Reply(false, []byte{}) continue } listener, bindinfo, err := handleTcpIpForward(client, req) if err != nil { client.ListenMutex.Unlock() continue } client.Listeners[bindinfo.Bound] = listener client.ListenMutex.Unlock() go handleListener(client, bindinfo, listener) continue } else if req.Type == "cancel-tcpip-forward" { client.ListenMutex.Lock() handleTcpIPForwardCancel(client, req) client.ListenMutex.Unlock() continue } else { // Discard everything else req.Reply(false, []byte{}) } } } func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) { close := func() { cssh.Close() conn.Close() if *verbose { log.Printf("[%s] Channel closed.", client.Name) } } var once sync.Once go func() { //io.Copy(cssh, conn) bytes_written, err := copyTimeout(cssh, conn, func() { if *debug { log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (sending data)", client.Name) } conn.SetDeadline(time.Now().Add(timeout)) client.Conn.SetDeadline(time.Now().Add(*maintimeout)) }) if err != nil { if *debug { log.Printf("[%s] copyTimeout failed with: %s", client.Name, err) } } if *verbose { log.Printf("[%s] Connection closed, bytes written: %d", client.Name, bytes_written) } once.Do(close) }() go func() { //io.Copy(conn, cssh) bytes_written, err := copyTimeout(conn, cssh, func() { if *debug { log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (received data)", client.Name) } conn.SetDeadline(time.Now().Add(timeout)) client.Conn.SetDeadline(time.Now().Add(*maintimeout)) }) if err != nil { if *debug { log.Printf("[%s] copyTimeout failed with: %s", client.Name, err) } } if *verbose { log.Printf("[%s] Connection closed, bytes written: %d", client.Name, bytes_written) } once.Do(close) }() } // Changed from pkg/io/io.go copyBuffer func copyTimeout(dst io.Writer, src io.Reader, timeout TimeoutFunc) (written int64, err error) { buf := make([]byte, 32*1024) for { nr, er := src.Read(buf) if nr > 0 { timeout() nw, ew := dst.Write(buf[0:nr]) if nw > 0 { written += int64(nw) } if ew != nil { err = ew break } if nr != nw { err = io.ErrShortWrite break } timeout() } if er != nil { if er != io.EOF { err = er } break } } return written, err } func loadHostKeys(config *ssh.ServerConfig) { privateBytes, err := ioutil.ReadFile(*hostkey) if err != nil { log.Fatal(fmt.Sprintf("Failed to load private key (%s)", *hostkey)) } private, err := ssh.ParsePrivateKey(privateBytes) if err != nil { log.Fatal("Failed to parse private key") } config.AddHostKey(private) } func loadAuthorisedKeys(authorisedkeys string) { authKeys := map[string]deviceInfo{} authKeysBytes, err := ioutil.ReadFile(authorisedkeys) if err != nil { log.Fatal("Cannot load authorised keys") } for len(authKeysBytes) > 0 { pubkey, comment, options, rest, err := ssh.ParseAuthorizedKey(authKeysBytes) if err != nil { log.Printf("Error parsing line: %s", err) authKeysBytes = rest continue } devinfo := deviceInfo{Comment: comment} // TODO: Compatibility with permitopen=foo,permitopen=bar, // permitremoteopen=quux,permitremoteopen=wobble for _, option := range options { ports, err := parseOption(option, "localports") if err == nil { devinfo.LocalPorts = ports continue } ports, err = parseOption(option, "remoteports") if err == nil { devinfo.RemotePorts = ports continue } if *verbose { log.Println("Unknown option:", option) } } authKeys[string(pubkey.Marshal())] = devinfo authKeysBytes = rest } authmutex.Lock() defer authmutex.Unlock() authorisedKeys = authKeys } func portPermitted(port uint32, ports []uint32) bool { ok := false for _, p := range ports { if port == p { ok = true break } } return ok } func parseOption(option string, prefix string) (string, error) { str := fmt.Sprintf("%s=", prefix) if !strings.HasPrefix(option, str) { return "", fmt.Errorf("Option does not start with %s", str) } ports := option[len(str):] if _, err := parsePorts(ports); err != nil { log.Fatal(err) } return ports, nil } func parsePorts(portstr string) (p []uint32, err error) { ports := strings.Split(portstr, ":") for _, port := range ports { port, err := strconv.ParseUint(port, 10, 32) if err != nil { return p, err } p = append(p, uint32(port)) } return }