Clean up code and more comments

master
Merlijn Wajer 7 years ago
parent 1f1df93791
commit 9f8a9042fc
  1. 26
      signal_unix.go
  2. 5
      signal_windows.go
  3. 154
      sshd.go

@ -0,0 +1,26 @@
package main
import (
"log"
"os"
"os/signal"
"syscall"
)
// +build !windows
func registerReloadSignal() {
c := make(chan os.Signal)
signal.Notify(c, syscall.SIGUSR1)
go func() {
for sig := range c {
if sig == syscall.SIGUSR1 {
log.Printf("Received signal: SIGUSR1. Reloading authorised keys.")
loadAuthorisedKeys(*authorisedkeys)
} else {
log.Printf("Received unexpected signal: \"%s\".", sig.String())
}
}
}()
}

@ -0,0 +1,5 @@
package main
// +build windows
func registerReloadSignal() {
}

@ -11,18 +11,16 @@ import (
"io/ioutil"
"log"
"net"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/crypto/ssh"
)
var (
// Contains a mapping of authorised keys to permissions for said key
authorisedKeys map[string]deviceInfo
listenaddr = flag.String("listenaddr", "0.0.0.0", "Addr to listen on for incoming ssh connections")
@ -32,34 +30,48 @@ var (
verbose = flag.Bool("verbose", false, "Enable verbose mode")
debug = flag.Bool("debug", false, "Enable debug mode")
// TODO: Separate for read/write? (Right now assume that on either
// read/write, reset deadline)
// Currently the timeouts are not separate for read and write deadlines.
// This could be done, but I currently don't really see a reason for this.
maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout")
directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout")
forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout")
// Mutex protecting 'authorisedKeys' map
authmutex sync.Mutex
)
type TimeoutFunc func()
// Structure that holds all information for each connection/client
type sshClient struct {
Name string
Conn net.Conn
SshConn *ssh.ServerConn
Listeners map[string]net.Listener
Name string
// We keep track of the normal Conn as well so that we have access to the
// SetDeadline() methods
Conn net.Conn
SshConn *ssh.ServerConn
// Listener sockets opened by the client
Listeners map[string]net.Listener
AllowedLocalPorts []uint32
AllowedRemotePorts []uint32
Stopping bool
ListenMutex sync.Mutex
// This indicates that a client is shutting down. When a client is stopping,
// we do not allow new listening requests, to prevent a listener connection
// being opened just after we closed all of them.
Stopping bool
ListenMutex sync.Mutex
}
// Structure containing what address/port we should bind on, for forwarded-tcpip
// connections
type bindInfo struct {
Bound string
Port uint32
Addr string
}
// Information parsed from the authorized_keys file
type deviceInfo struct {
LocalPorts string
RemotePorts string
@ -95,6 +107,10 @@ type tcpIpForwardCancelPayload struct {
Port uint32
}
// Function that can be used to implement calls to SetDeadline() after
// read/writes in copyTimeout()
type TimeoutFunc func()
func main() {
flag.Parse()
@ -134,6 +150,9 @@ func main() {
}
tcpConn.SetDeadline(time.Now().Add(*maintimeout))
// We perform the ssh handshake in a goroutine so the handshake cannot
// block incoming connections.
go func() {
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config)
if err != nil {
@ -171,8 +190,8 @@ func main() {
client.ListenMutex.Unlock()
}()
// Accept requests & channels
go handleRequest(&client, reqs)
// Accept all channels
go handleChannels(&client, chans)
}()
}
@ -185,7 +204,7 @@ func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) {
}
func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
if *verbose {
if *debug {
log.Printf("[%s] Channel type: %v", client.Name, newChannel.ChannelType())
}
if t := newChannel.ChannelType(); t == "direct-tcpip" {
@ -193,9 +212,11 @@ func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
return
}
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\", \"forwarded-tcpip\" and \"cancel-tcpip-forward\" are accepted"))
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted"))
/*
// TODO: USE THIS ONLY FOR USING SSH ESCAPE SEQUENCES
// XXX: Use this only for testing purposes -- I add this in if/when I
// want to use the ssh escape sequences from ssh (those only work in an
// interactive session)
c, _, err := newChannel.Accept()
if err != nil {
log.Fatal(err)
@ -365,6 +386,47 @@ func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
req.Reply(false, []byte{})
}
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
for req := range reqs {
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
if *debug {
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
}
// RFC4254: 7.1 for forwarding
if req.Type == "tcpip-forward" {
client.ListenMutex.Lock()
/* If we are closing, do not set up a new listener */
if client.Stopping {
client.ListenMutex.Unlock()
req.Reply(false, []byte{})
continue
}
listener, bindinfo, err := handleTcpIpForward(client, req)
if err != nil {
client.ListenMutex.Unlock()
continue
}
client.Listeners[bindinfo.Bound] = listener
client.ListenMutex.Unlock()
go handleListener(client, bindinfo, listener)
continue
} else if req.Type == "cancel-tcpip-forward" {
client.ListenMutex.Lock()
handleTcpIPForwardCancel(client, req)
client.ListenMutex.Unlock()
continue
} else {
// Discard everything else
req.Reply(false, []byte{})
}
}
}
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) {
// TODO: Maybe just do this with defer instead? (And only one copy in a
// goroutine)
@ -511,64 +573,6 @@ func loadAuthorisedKeys(authorisedkeys string) {
authorisedKeys = authKeys
}
func registerReloadSignal() {
c := make(chan os.Signal)
signal.Notify(c, syscall.SIGUSR1)
go func() {
for sig := range c {
if sig == syscall.SIGUSR1 {
log.Printf("Received signal: SIGUSR1. Reloading authorised keys.")
loadAuthorisedKeys(*authorisedkeys)
} else {
log.Printf("Received unexpected signal: \"%s\".", sig.String())
}
}
}()
}
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
for req := range reqs {
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
if *verbose {
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
}
// RFC4254: 7.1 for forwarding
if req.Type == "tcpip-forward" {
client.ListenMutex.Lock()
/* If we are closing, do not set up a new listener */
if client.Stopping {
client.ListenMutex.Unlock()
req.Reply(false, []byte{})
continue
}
listener, bindinfo, err := handleTcpIpForward(client, req)
if err != nil {
client.ListenMutex.Unlock()
continue
}
client.Listeners[bindinfo.Bound] = listener
client.ListenMutex.Unlock()
go handleListener(client, bindinfo, listener)
continue
} else if req.Type == "cancel-tcpip-forward" {
client.ListenMutex.Lock()
handleTcpIPForwardCancel(client, req)
client.ListenMutex.Unlock()
continue
} else {
// Discard everything else
req.Reply(false, []byte{})
}
}
}
func portPermitted(port uint32, ports []uint32) bool {
ok := false
for _, p := range ports {

Loading…
Cancel
Save