diff --git a/signal_unix.go b/signal_unix.go new file mode 100644 index 0000000..5f54945 --- /dev/null +++ b/signal_unix.go @@ -0,0 +1,26 @@ +package main + +import ( + "log" + "os" + "os/signal" + "syscall" +) + +// +build !windows +func registerReloadSignal() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGUSR1) + + go func() { + for sig := range c { + if sig == syscall.SIGUSR1 { + log.Printf("Received signal: SIGUSR1. Reloading authorised keys.") + loadAuthorisedKeys(*authorisedkeys) + } else { + log.Printf("Received unexpected signal: \"%s\".", sig.String()) + } + } + + }() +} diff --git a/signal_windows.go b/signal_windows.go new file mode 100644 index 0000000..87b19f4 --- /dev/null +++ b/signal_windows.go @@ -0,0 +1,5 @@ +package main + +// +build windows +func registerReloadSignal() { +} diff --git a/sshd.go b/sshd.go index 34a96cc..ac28ac5 100644 --- a/sshd.go +++ b/sshd.go @@ -11,18 +11,16 @@ import ( "io/ioutil" "log" "net" - "os" - "os/signal" "strconv" "strings" "sync" - "syscall" "time" "golang.org/x/crypto/ssh" ) var ( + // Contains a mapping of authorised keys to permissions for said key authorisedKeys map[string]deviceInfo listenaddr = flag.String("listenaddr", "0.0.0.0", "Addr to listen on for incoming ssh connections") @@ -32,34 +30,48 @@ var ( verbose = flag.Bool("verbose", false, "Enable verbose mode") debug = flag.Bool("debug", false, "Enable debug mode") - // TODO: Separate for read/write? (Right now assume that on either - // read/write, reset deadline) + // Currently the timeouts are not separate for read and write deadlines. + // This could be done, but I currently don't really see a reason for this. maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout") directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout") forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout") + // Mutex protecting 'authorisedKeys' map authmutex sync.Mutex ) -type TimeoutFunc func() - +// Structure that holds all information for each connection/client type sshClient struct { - Name string - Conn net.Conn - SshConn *ssh.ServerConn - Listeners map[string]net.Listener + Name string + + // We keep track of the normal Conn as well so that we have access to the + // SetDeadline() methods + Conn net.Conn + + SshConn *ssh.ServerConn + + // Listener sockets opened by the client + Listeners map[string]net.Listener + AllowedLocalPorts []uint32 AllowedRemotePorts []uint32 - Stopping bool - ListenMutex sync.Mutex + + // This indicates that a client is shutting down. When a client is stopping, + // we do not allow new listening requests, to prevent a listener connection + // being opened just after we closed all of them. + Stopping bool + ListenMutex sync.Mutex } +// Structure containing what address/port we should bind on, for forwarded-tcpip +// connections type bindInfo struct { Bound string Port uint32 Addr string } +// Information parsed from the authorized_keys file type deviceInfo struct { LocalPorts string RemotePorts string @@ -95,6 +107,10 @@ type tcpIpForwardCancelPayload struct { Port uint32 } +// Function that can be used to implement calls to SetDeadline() after +// read/writes in copyTimeout() +type TimeoutFunc func() + func main() { flag.Parse() @@ -134,6 +150,9 @@ func main() { } tcpConn.SetDeadline(time.Now().Add(*maintimeout)) + + // We perform the ssh handshake in a goroutine so the handshake cannot + // block incoming connections. go func() { sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config) if err != nil { @@ -171,8 +190,8 @@ func main() { client.ListenMutex.Unlock() }() + // Accept requests & channels go handleRequest(&client, reqs) - // Accept all channels go handleChannels(&client, chans) }() } @@ -185,7 +204,7 @@ func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) { } func handleChannel(client *sshClient, newChannel ssh.NewChannel) { - if *verbose { + if *debug { log.Printf("[%s] Channel type: %v", client.Name, newChannel.ChannelType()) } if t := newChannel.ChannelType(); t == "direct-tcpip" { @@ -193,9 +212,11 @@ func handleChannel(client *sshClient, newChannel ssh.NewChannel) { return } - newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\", \"forwarded-tcpip\" and \"cancel-tcpip-forward\" are accepted")) + newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted")) /* - // TODO: USE THIS ONLY FOR USING SSH ESCAPE SEQUENCES + // XXX: Use this only for testing purposes -- I add this in if/when I + // want to use the ssh escape sequences from ssh (those only work in an + // interactive session) c, _, err := newChannel.Accept() if err != nil { log.Fatal(err) @@ -365,6 +386,47 @@ func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { req.Reply(false, []byte{}) } +func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { + for req := range reqs { + client.Conn.SetDeadline(time.Now().Add(*maintimeout)) + + if *debug { + log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply) + } + + // RFC4254: 7.1 for forwarding + if req.Type == "tcpip-forward" { + client.ListenMutex.Lock() + /* If we are closing, do not set up a new listener */ + if client.Stopping { + client.ListenMutex.Unlock() + req.Reply(false, []byte{}) + continue + } + + listener, bindinfo, err := handleTcpIpForward(client, req) + if err != nil { + client.ListenMutex.Unlock() + continue + } + + client.Listeners[bindinfo.Bound] = listener + client.ListenMutex.Unlock() + + go handleListener(client, bindinfo, listener) + continue + } else if req.Type == "cancel-tcpip-forward" { + client.ListenMutex.Lock() + handleTcpIPForwardCancel(client, req) + client.ListenMutex.Unlock() + continue + } else { + // Discard everything else + req.Reply(false, []byte{}) + } + } +} + func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) { // TODO: Maybe just do this with defer instead? (And only one copy in a // goroutine) @@ -511,64 +573,6 @@ func loadAuthorisedKeys(authorisedkeys string) { authorisedKeys = authKeys } -func registerReloadSignal() { - c := make(chan os.Signal) - signal.Notify(c, syscall.SIGUSR1) - - go func() { - for sig := range c { - if sig == syscall.SIGUSR1 { - log.Printf("Received signal: SIGUSR1. Reloading authorised keys.") - loadAuthorisedKeys(*authorisedkeys) - } else { - log.Printf("Received unexpected signal: \"%s\".", sig.String()) - } - } - - }() -} - -func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { - for req := range reqs { - client.Conn.SetDeadline(time.Now().Add(*maintimeout)) - - if *verbose { - log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply) - } - - // RFC4254: 7.1 for forwarding - if req.Type == "tcpip-forward" { - client.ListenMutex.Lock() - /* If we are closing, do not set up a new listener */ - if client.Stopping { - client.ListenMutex.Unlock() - req.Reply(false, []byte{}) - continue - } - - listener, bindinfo, err := handleTcpIpForward(client, req) - if err != nil { - client.ListenMutex.Unlock() - continue - } - - client.Listeners[bindinfo.Bound] = listener - client.ListenMutex.Unlock() - - go handleListener(client, bindinfo, listener) - continue - } else if req.Type == "cancel-tcpip-forward" { - client.ListenMutex.Lock() - handleTcpIPForwardCancel(client, req) - client.ListenMutex.Unlock() - continue - } else { - // Discard everything else - req.Reply(false, []byte{}) - } - } -} - func portPermitted(port uint32, ports []uint32) bool { ok := false for _, p := range ports {