Add flag parsing

master
Merlijn B. W. Wajer 7 years ago
parent bbbc8fad88
commit a89882af11
  1. 33
      sshd.go

@ -9,6 +9,7 @@ package main
// * https://github.com/tg123/sshpiper/commit/9db468b52dfc2cbe936efb7bef0fd5b88e0c1649
import (
"flag"
"fmt"
"io"
"io/ioutil"
@ -23,9 +24,15 @@ import (
var (
authorisedKeys map[string]string
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load")
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")
verbose = flag.Bool("verbose", false, "Enable verbose mode")
)
func main() {
flag.Parse()
config := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if ports, found := authorisedKeys[string(key.Marshal())]; found {
@ -38,8 +45,7 @@ func main() {
},
}
privateBytes, err := ioutil.ReadFile("id_ed25519")
//privateBytes, err := ioutil.ReadFile("id_rsa")
privateBytes, err := ioutil.ReadFile(*hostkey)
if err != nil {
log.Fatal("Failed to load private key (./id_rsa)")
}
@ -50,7 +56,7 @@ func main() {
}
config.AddHostKey(private)
loadKeys()
loadAuthorisedKeys(*authorisedkeys)
listener, err := net.Listen("tcp", "0.0.0.0:2200")
if err != nil {
@ -74,7 +80,9 @@ func main() {
allowedPorts := sshConn.Permissions.CriticalOptions["ports"]
log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts)
if *verbose {
log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts)
}
// Parsing a second time should not error
ports, _ := parsePorts(allowedPorts)
@ -101,7 +109,9 @@ type directTCPPayload struct {
}
func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
log.Println("Channel type:", newChannel.ChannelType())
if *verbose {
log.Println("Channel type:", newChannel.ChannelType())
}
if t := newChannel.ChannelType(); t != "direct-tcpip" {
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted"))
return
@ -115,7 +125,6 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
return
}
//log.Println("Got payload: %v", payload)
if payload.Addr != "localhost" {
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr"))
return
@ -144,7 +153,9 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
}
addr := fmt.Sprintf("%s:%d", payload.Addr, payload.Port)
log.Println("Going to dial:", addr)
if *verbose {
log.Println("Going to dial:", addr)
}
rconn, err := net.Dial("tcp", addr)
if err != nil {
@ -156,7 +167,9 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
close := func() {
connection.Close()
rconn.Close()
log.Printf("Session closed")
if *verbose {
log.Printf("Session closed")
}
}
var once sync.Once
@ -182,9 +195,9 @@ func parsePorts(portstr string) (p []uint32, err error) {
return
}
func loadKeys() {
func loadAuthorisedKeys(authorisedkeys string) {
authorisedKeys = map[string]string{}
authorisedKeysBytes, err := ioutil.ReadFile("authorized_keys")
authorisedKeysBytes, err := ioutil.ReadFile(authorisedkeys)
if err != nil {
log.Fatal("Cannot load authorised keys")
}

Loading…
Cancel
Save