Add basic support for timeouts

master
Merlijn Wajer 7 years ago
parent 6d97caadfd
commit 1216e33e23
  1. 73
      sshd.go

@ -17,6 +17,7 @@ import (
"strings"
"sync"
"syscall"
"time"
"golang.org/x/crypto/ssh"
)
@ -30,12 +31,21 @@ var (
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")
verbose = flag.Bool("verbose", false, "Enable verbose mode")
// TODO: Separate for read/write? (Right now assume that on either
// read/write, reset deadline)
maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout")
directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout")
forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout")
authmutex sync.Mutex
)
type TimeoutFunc func()
type sshClient struct {
Name string
Conn *ssh.ServerConn
Conn net.Conn
SshConn *ssh.ServerConn
Listeners map[string]net.Listener
AllowedLocalPorts []uint32
AllowedRemotePorts []uint32
@ -122,6 +132,7 @@ func main() {
continue
}
tcpConn.SetDeadline(time.Now().Add(*maintimeout))
go func() {
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config)
if err != nil {
@ -129,7 +140,7 @@ func main() {
return
}
client := sshClient{sshConn.Permissions.CriticalOptions["name"], sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}}
client := sshClient{sshConn.Permissions.CriticalOptions["name"], tcpConn, sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}}
allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"]
allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"]
@ -143,7 +154,7 @@ func main() {
client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts)
go func() {
err := client.Conn.Wait()
err := client.SshConn.Wait()
client.ListenMutex.Lock()
client.Stopping = true
@ -240,7 +251,7 @@ func handleDirect(client *sshClient, newChannel ssh.NewChannel) {
return
}
serve(connection, rconn, client)
serve(connection, rconn, client, *directtimeout)
}
func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) {
@ -318,7 +329,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) {
mpayload := ssh.Marshal(&payload)
// Open channel with client
c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload)
c, requests, err := client.SshConn.OpenChannel("forwarded-tcpip", mpayload)
if err != nil {
log.Printf("[%s] Unable to get channel: %s. Hanging up requesting party!", client.Name, err)
lconn.Close()
@ -329,7 +340,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) {
}
go ssh.DiscardRequests(requests)
serve(c, lconn, client)
serve(c, lconn, client, *forwardedtimeout)
}
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
@ -353,7 +364,7 @@ func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
req.Reply(false, []byte{})
}
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) {
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) {
// TODO: Maybe just do this with defer instead? (And only one copy in a
// goroutine)
close := func() {
@ -364,17 +375,59 @@ func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) {
}
}
// TODO: Share timeout between both the read-conn and the write-conn
var once sync.Once
go func() {
io.Copy(cssh, conn)
//io.Copy(cssh, conn)
_, _ = copyTimeout(cssh, conn, func() {
conn.SetDeadline(time.Now().Add(timeout))
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
})
once.Do(close)
}()
go func() {
io.Copy(conn, cssh)
//io.Copy(conn, cssh)
_, _ = copyTimeout(conn, cssh, func() {
conn.SetDeadline(time.Now().Add(timeout))
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
})
once.Do(close)
}()
}
// Changed from pkg/io/io.go copyBuffer
func copyTimeout(dst io.Writer, src io.Reader, timeout TimeoutFunc) (written int64, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := src.Read(buf)
if nr > 0 {
timeout()
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
timeout()
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}
func loadHostKeys(config *ssh.ServerConfig) {
privateBytes, err := ioutil.ReadFile(*hostkey)
if err != nil {
@ -454,6 +507,8 @@ func registerReloadSignal() {
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
for req := range reqs {
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
if *verbose {
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
}

Loading…
Cancel
Save