From 8cdccd48e4c3719a5f8fdfdd29c975e8f352582c Mon Sep 17 00:00:00 2001 From: "Merlijn B. W. Wajer" Date: Fri, 10 Mar 2017 22:52:14 +0100 Subject: [PATCH] Allow authkeys reloading with SIGUSR1 --- sshd.go | 43 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/sshd.go b/sshd.go index dc8036c..3d612c6 100644 --- a/sshd.go +++ b/sshd.go @@ -11,9 +11,12 @@ import ( "io/ioutil" "log" "net" + "os" + "os/signal" "strconv" "strings" "sync" + "syscall" "golang.org/x/crypto/ssh" ) @@ -28,6 +31,8 @@ var ( hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") verbose = flag.Bool("verbose", false, "Enable verbose mode") + + authmutex sync.Mutex ) type sshClient struct { @@ -85,6 +90,8 @@ func main() { config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if deviceinfo, found := authorisedKeys[string(key.Marshal())]; found { + authmutex.Lock() + defer authmutex.Unlock() return &ssh.Permissions{ CriticalOptions: map[string]string{"name": deviceinfo.Comment, "localports": deviceinfo.LocalPorts, @@ -99,6 +106,8 @@ func main() { loadHostKeys(config) loadAuthorisedKeys(*authorisedkeys) + registerReloadSignal() + listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", *listenaddr, *listenport)) if err != nil { log.Fatalf("Failed to listen on %s (%s)", listenport, err) @@ -125,7 +134,7 @@ func main() { allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"] if *verbose { - log.Printf("Connection from %s (%s). Allowed local ports: %s remote ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedLocalPorts, allowedRemotePorts) + log.Printf("Connection from %s, %s (%s). Allowed local ports: %s remote ports: %s", client.Name, sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedLocalPorts, allowedRemotePorts) } // Parsing a second time should not error, so we can ignore the error @@ -377,18 +386,18 @@ func loadHostKeys(config *ssh.ServerConfig) { } func loadAuthorisedKeys(authorisedkeys string) { - authorisedKeys = map[string]deviceInfo{} - authorisedKeysBytes, err := ioutil.ReadFile(authorisedkeys) + authKeys := map[string]deviceInfo{} + authKeysBytes, err := ioutil.ReadFile(authorisedkeys) if err != nil { log.Fatal("Cannot load authorised keys") } - for len(authorisedKeysBytes) > 0 { - pubkey, comment, options, rest, err := ssh.ParseAuthorizedKey(authorisedKeysBytes) + for len(authKeysBytes) > 0 { + pubkey, comment, options, rest, err := ssh.ParseAuthorizedKey(authKeysBytes) if err != nil { log.Printf("Error parsing line: %s", err) - authorisedKeysBytes = rest + authKeysBytes = rest continue } @@ -412,10 +421,28 @@ func loadAuthorisedKeys(authorisedkeys string) { } } - authorisedKeys[string(pubkey.Marshal())] = devinfo + authKeys[string(pubkey.Marshal())] = devinfo - authorisedKeysBytes = rest + authKeysBytes = rest } + + authmutex.Lock() + defer authmutex.Unlock() + authorisedKeys = authKeys +} + +func registerReloadSignal() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGUSR1) + + go func() { + for sig := range c { + _ = sig + log.Printf("Received signal: \"%s\". Reloading authorised keys.", sig.String()) + loadAuthorisedKeys(*authorisedkeys) + } + + }() } func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {