diff --git a/sshd.go b/sshd.go index d71609f..01f71bb 100644 --- a/sshd.go +++ b/sshd.go @@ -17,6 +17,7 @@ import ( "strings" "sync" "syscall" + "time" "golang.org/x/crypto/ssh" ) @@ -30,12 +31,21 @@ var ( authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") verbose = flag.Bool("verbose", false, "Enable verbose mode") + // TODO: Separate for read/write? (Right now assume that on either + // read/write, reset deadline) + maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout") + directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout") + forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout") + authmutex sync.Mutex ) +type TimeoutFunc func() + type sshClient struct { Name string - Conn *ssh.ServerConn + Conn net.Conn + SshConn *ssh.ServerConn Listeners map[string]net.Listener AllowedLocalPorts []uint32 AllowedRemotePorts []uint32 @@ -122,6 +132,7 @@ func main() { continue } + tcpConn.SetDeadline(time.Now().Add(*maintimeout)) go func() { sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config) if err != nil { @@ -129,7 +140,7 @@ func main() { return } - client := sshClient{sshConn.Permissions.CriticalOptions["name"], sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}} + client := sshClient{sshConn.Permissions.CriticalOptions["name"], tcpConn, sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}} allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"] allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"] @@ -143,7 +154,7 @@ func main() { client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts) go func() { - err := client.Conn.Wait() + err := client.SshConn.Wait() client.ListenMutex.Lock() client.Stopping = true @@ -240,7 +251,7 @@ func handleDirect(client *sshClient, newChannel ssh.NewChannel) { return } - serve(connection, rconn, client) + serve(connection, rconn, client, *directtimeout) } func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) { @@ -318,7 +329,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { mpayload := ssh.Marshal(&payload) // Open channel with client - c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload) + c, requests, err := client.SshConn.OpenChannel("forwarded-tcpip", mpayload) if err != nil { log.Printf("[%s] Unable to get channel: %s. Hanging up requesting party!", client.Name, err) lconn.Close() @@ -329,7 +340,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { } go ssh.DiscardRequests(requests) - serve(c, lconn, client) + serve(c, lconn, client, *forwardedtimeout) } func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { @@ -353,7 +364,7 @@ func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { req.Reply(false, []byte{}) } -func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) { +func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) { // TODO: Maybe just do this with defer instead? (And only one copy in a // goroutine) close := func() { @@ -364,17 +375,59 @@ func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) { } } + // TODO: Share timeout between both the read-conn and the write-conn var once sync.Once go func() { - io.Copy(cssh, conn) + //io.Copy(cssh, conn) + _, _ = copyTimeout(cssh, conn, func() { + conn.SetDeadline(time.Now().Add(timeout)) + client.Conn.SetDeadline(time.Now().Add(*maintimeout)) + }) once.Do(close) }() go func() { - io.Copy(conn, cssh) + //io.Copy(conn, cssh) + _, _ = copyTimeout(conn, cssh, func() { + conn.SetDeadline(time.Now().Add(timeout)) + client.Conn.SetDeadline(time.Now().Add(*maintimeout)) + }) once.Do(close) }() } +// Changed from pkg/io/io.go copyBuffer +func copyTimeout(dst io.Writer, src io.Reader, timeout TimeoutFunc) (written int64, err error) { + buf := make([]byte, 32*1024) + + for { + nr, er := src.Read(buf) + if nr > 0 { + timeout() + + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + timeout() + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err +} + func loadHostKeys(config *ssh.ServerConfig) { privateBytes, err := ioutil.ReadFile(*hostkey) if err != nil { @@ -454,6 +507,8 @@ func registerReloadSignal() { func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { for req := range reqs { + client.Conn.SetDeadline(time.Now().Add(*maintimeout)) + if *verbose { log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply) }