Add flag parsing
This commit is contained in:
parent
bbbc8fad88
commit
a89882af11
1 changed files with 23 additions and 10 deletions
25
sshd.go
25
sshd.go
|
@ -9,6 +9,7 @@ package main
|
||||||
// * https://github.com/tg123/sshpiper/commit/9db468b52dfc2cbe936efb7bef0fd5b88e0c1649
|
// * https://github.com/tg123/sshpiper/commit/9db468b52dfc2cbe936efb7bef0fd5b88e0c1649
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -23,9 +24,15 @@ import (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
authorisedKeys map[string]string
|
authorisedKeys map[string]string
|
||||||
|
|
||||||
|
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load")
|
||||||
|
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")
|
||||||
|
verbose = flag.Bool("verbose", false, "Enable verbose mode")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
config := &ssh.ServerConfig{
|
config := &ssh.ServerConfig{
|
||||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
if ports, found := authorisedKeys[string(key.Marshal())]; found {
|
if ports, found := authorisedKeys[string(key.Marshal())]; found {
|
||||||
|
@ -38,8 +45,7 @@ func main() {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
privateBytes, err := ioutil.ReadFile("id_ed25519")
|
privateBytes, err := ioutil.ReadFile(*hostkey)
|
||||||
//privateBytes, err := ioutil.ReadFile("id_rsa")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Failed to load private key (./id_rsa)")
|
log.Fatal("Failed to load private key (./id_rsa)")
|
||||||
}
|
}
|
||||||
|
@ -50,7 +56,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
config.AddHostKey(private)
|
config.AddHostKey(private)
|
||||||
loadKeys()
|
loadAuthorisedKeys(*authorisedkeys)
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "0.0.0.0:2200")
|
listener, err := net.Listen("tcp", "0.0.0.0:2200")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -74,7 +80,9 @@ func main() {
|
||||||
|
|
||||||
allowedPorts := sshConn.Permissions.CriticalOptions["ports"]
|
allowedPorts := sshConn.Permissions.CriticalOptions["ports"]
|
||||||
|
|
||||||
|
if *verbose {
|
||||||
log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts)
|
log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts)
|
||||||
|
}
|
||||||
|
|
||||||
// Parsing a second time should not error
|
// Parsing a second time should not error
|
||||||
ports, _ := parsePorts(allowedPorts)
|
ports, _ := parsePorts(allowedPorts)
|
||||||
|
@ -101,7 +109,9 @@ type directTCPPayload struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
||||||
|
if *verbose {
|
||||||
log.Println("Channel type:", newChannel.ChannelType())
|
log.Println("Channel type:", newChannel.ChannelType())
|
||||||
|
}
|
||||||
if t := newChannel.ChannelType(); t != "direct-tcpip" {
|
if t := newChannel.ChannelType(); t != "direct-tcpip" {
|
||||||
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted"))
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted"))
|
||||||
return
|
return
|
||||||
|
@ -115,7 +125,6 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//log.Println("Got payload: %v", payload)
|
|
||||||
if payload.Addr != "localhost" {
|
if payload.Addr != "localhost" {
|
||||||
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr"))
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr"))
|
||||||
return
|
return
|
||||||
|
@ -144,7 +153,9 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := fmt.Sprintf("%s:%d", payload.Addr, payload.Port)
|
addr := fmt.Sprintf("%s:%d", payload.Addr, payload.Port)
|
||||||
|
if *verbose {
|
||||||
log.Println("Going to dial:", addr)
|
log.Println("Going to dial:", addr)
|
||||||
|
}
|
||||||
|
|
||||||
rconn, err := net.Dial("tcp", addr)
|
rconn, err := net.Dial("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -156,8 +167,10 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
|
||||||
close := func() {
|
close := func() {
|
||||||
connection.Close()
|
connection.Close()
|
||||||
rconn.Close()
|
rconn.Close()
|
||||||
|
if *verbose {
|
||||||
log.Printf("Session closed")
|
log.Printf("Session closed")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -182,9 +195,9 @@ func parsePorts(portstr string) (p []uint32, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadKeys() {
|
func loadAuthorisedKeys(authorisedkeys string) {
|
||||||
authorisedKeys = map[string]string{}
|
authorisedKeys = map[string]string{}
|
||||||
authorisedKeysBytes, err := ioutil.ReadFile("authorized_keys")
|
authorisedKeysBytes, err := ioutil.ReadFile(authorisedkeys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Cannot load authorised keys")
|
log.Fatal("Cannot load authorised keys")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue