diff --git a/sshd.go b/sshd.go index f2a67ee..19cc790 100644 --- a/sshd.go +++ b/sshd.go @@ -9,6 +9,7 @@ package main // * https://github.com/tg123/sshpiper/commit/9db468b52dfc2cbe936efb7bef0fd5b88e0c1649 import ( + "flag" "fmt" "io" "io/ioutil" @@ -23,9 +24,15 @@ import ( var ( authorisedKeys map[string]string + + hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") + authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") + verbose = flag.Bool("verbose", false, "Enable verbose mode") ) func main() { + flag.Parse() + config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if ports, found := authorisedKeys[string(key.Marshal())]; found { @@ -38,8 +45,7 @@ func main() { }, } - privateBytes, err := ioutil.ReadFile("id_ed25519") - //privateBytes, err := ioutil.ReadFile("id_rsa") + privateBytes, err := ioutil.ReadFile(*hostkey) if err != nil { log.Fatal("Failed to load private key (./id_rsa)") } @@ -50,7 +56,7 @@ func main() { } config.AddHostKey(private) - loadKeys() + loadAuthorisedKeys(*authorisedkeys) listener, err := net.Listen("tcp", "0.0.0.0:2200") if err != nil { @@ -74,7 +80,9 @@ func main() { allowedPorts := sshConn.Permissions.CriticalOptions["ports"] - log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts) + if *verbose { + log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts) + } // Parsing a second time should not error ports, _ := parsePorts(allowedPorts) @@ -101,7 +109,9 @@ type directTCPPayload struct { } func handleChannel(newChannel ssh.NewChannel, ports []uint32) { - log.Println("Channel type:", newChannel.ChannelType()) + if *verbose { + log.Println("Channel type:", newChannel.ChannelType()) + } if t := newChannel.ChannelType(); t != "direct-tcpip" { newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted")) return @@ -115,7 +125,6 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { return } - //log.Println("Got payload: %v", payload) if payload.Addr != "localhost" { newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr")) return @@ -144,7 +153,9 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { } addr := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) - log.Println("Going to dial:", addr) + if *verbose { + log.Println("Going to dial:", addr) + } rconn, err := net.Dial("tcp", addr) if err != nil { @@ -156,7 +167,9 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { close := func() { connection.Close() rconn.Close() - log.Printf("Session closed") + if *verbose { + log.Printf("Session closed") + } } var once sync.Once @@ -182,9 +195,9 @@ func parsePorts(portstr string) (p []uint32, err error) { return } -func loadKeys() { +func loadAuthorisedKeys(authorisedkeys string) { authorisedKeys = map[string]string{} - authorisedKeysBytes, err := ioutil.ReadFile("authorized_keys") + authorisedKeysBytes, err := ioutil.ReadFile(authorisedkeys) if err != nil { log.Fatal("Cannot load authorised keys") }