Add flag parsing
This commit is contained in:
parent
bbbc8fad88
commit
a89882af11
1 changed files with 23 additions and 10 deletions
33
sshd.go
33
sshd.go
|
@ -9,6 +9,7 @@ package main
|
|||
// * https://github.com/tg123/sshpiper/commit/9db468b52dfc2cbe936efb7bef0fd5b88e0c1649
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -23,9 +24,15 @@ import (
|
|||
|
||||
var (
|
||||
authorisedKeys map[string]string
|
||||
|
||||
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load")
|
||||
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")
|
||||
verbose = flag.Bool("verbose", false, "Enable verbose mode")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
config := &ssh.ServerConfig{
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if ports, found := authorisedKeys[string(key.Marshal())]; found {
|
||||
|
@ -38,8 +45,7 @@ func main() {
|
|||
},
|
||||
}
|
||||
|
||||
privateBytes, err := ioutil.ReadFile("id_ed25519")
|
||||
//privateBytes, err := ioutil.ReadFile("id_rsa")
|
||||
privateBytes, err := ioutil.ReadFile(*hostkey)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load private key (./id_rsa)")
|
||||
}
|
||||
|
@ -50,7 +56,7 @@ func main() {
|
|||
}
|
||||
|
||||
config.AddHostKey(private)
|
||||
loadKeys()
|
||||
loadAuthorisedKeys(*authorisedkeys)
|
||||
|
||||
listener, err := net.Listen("tcp", "0.0.0.0:2200")
|
||||
if err != nil {
|
||||
|
@ -74,7 +80,9 @@ func main() {
|
|||
|
||||
allowedPorts := sshConn.Permissions.CriticalOptions["ports"]
|
||||
|
||||
log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts)
|
||||
if *verbose {
|
||||
log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts)
|
||||
}
|
||||
|
||||
// Parsing a second time should not error
|
||||
ports, _ := parsePorts(allowedPorts)
|
||||
|
@ -101,7 +109,9 @@ type directTCPPayload struct {
|
|||
}
|
||||
|
||||
func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
||||
log.Println("Channel type:", newChannel.ChannelType())
|
||||
if *verbose {
|
||||
log.Println("Channel type:", newChannel.ChannelType())
|
||||
}
|
||||
if t := newChannel.ChannelType(); t != "direct-tcpip" {
|
||||
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted"))
|
||||
return
|
||||
|
@ -115,7 +125,6 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
|||
return
|
||||
}
|
||||
|
||||
//log.Println("Got payload: %v", payload)
|
||||
if payload.Addr != "localhost" {
|
||||
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr"))
|
||||
return
|
||||
|
@ -144,7 +153,9 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
|||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", payload.Addr, payload.Port)
|
||||
log.Println("Going to dial:", addr)
|
||||
if *verbose {
|
||||
log.Println("Going to dial:", addr)
|
||||
}
|
||||
|
||||
rconn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
|
@ -156,7 +167,9 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
|||
close := func() {
|
||||
connection.Close()
|
||||
rconn.Close()
|
||||
log.Printf("Session closed")
|
||||
if *verbose {
|
||||
log.Printf("Session closed")
|
||||
}
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
|
@ -182,9 +195,9 @@ func parsePorts(portstr string) (p []uint32, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func loadKeys() {
|
||||
func loadAuthorisedKeys(authorisedkeys string) {
|
||||
authorisedKeys = map[string]string{}
|
||||
authorisedKeysBytes, err := ioutil.ReadFile("authorized_keys")
|
||||
authorisedKeysBytes, err := ioutil.ReadFile(authorisedkeys)
|
||||
if err != nil {
|
||||
log.Fatal("Cannot load authorised keys")
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue