|
|
|
@ -4,10 +4,6 @@ package main |
|
|
|
|
// Merlijn B. W. Wajer <merlijn@wizzup.org>
|
|
|
|
|
// (C) 2017
|
|
|
|
|
|
|
|
|
|
// Trivial parts taken from:
|
|
|
|
|
// * https://blog.gopheracademy.com/advent-2015/ssh-server-in-go/
|
|
|
|
|
// * https://github.com/tg123/sshpiper/commit/9db468b52dfc2cbe936efb7bef0fd5b88e0c1649
|
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"flag" |
|
|
|
|
"fmt" |
|
|
|
@ -25,15 +21,49 @@ import ( |
|
|
|
|
var ( |
|
|
|
|
authorisedKeys map[string]string |
|
|
|
|
|
|
|
|
|
/* Global listeners, we keep a global state for cancel-tcpip-forward */ |
|
|
|
|
globalListens map[string]net.Listener |
|
|
|
|
|
|
|
|
|
listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections") |
|
|
|
|
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") |
|
|
|
|
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") |
|
|
|
|
verbose = flag.Bool("verbose", false, "Enable verbose mode") |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
/* RFC4254 7.2 */ |
|
|
|
|
type directTCPPayload struct { |
|
|
|
|
Addr string // To connect to
|
|
|
|
|
Port uint32 |
|
|
|
|
OriginAddr string |
|
|
|
|
OriginPort uint32 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type forwardedTCPPayload struct { |
|
|
|
|
Addr string // Is connected to
|
|
|
|
|
Port uint32 |
|
|
|
|
OriginAddr string |
|
|
|
|
OriginPort uint32 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type tcpIpForwardPayload struct { |
|
|
|
|
Addr string |
|
|
|
|
Port uint32 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type tcpIpForwardPayloadReply struct { |
|
|
|
|
Port uint32 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type tcpIpForwardCancelPayload struct { |
|
|
|
|
Addr string |
|
|
|
|
Port uint32 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func main() { |
|
|
|
|
flag.Parse() |
|
|
|
|
|
|
|
|
|
globalListens = make(map[string]net.Listener) |
|
|
|
|
|
|
|
|
|
config := &ssh.ServerConfig{ |
|
|
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
|
|
|
|
if ports, found := authorisedKeys[string(key.Marshal())]; found { |
|
|
|
@ -46,17 +76,7 @@ func main() { |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
privateBytes, err := ioutil.ReadFile(*hostkey) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Fatal("Failed to load private key (./id_rsa)") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private, err := ssh.ParsePrivateKey(privateBytes) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Fatal("Failed to parse private key") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
config.AddHostKey(private) |
|
|
|
|
loadHostKeys(config) |
|
|
|
|
loadAuthorisedKeys(*authorisedkeys) |
|
|
|
|
|
|
|
|
|
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", *listenport)) |
|
|
|
@ -85,11 +105,31 @@ func main() { |
|
|
|
|
log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Parsing a second time should not error
|
|
|
|
|
// Parsing a second time should not error, so we can ignore the error
|
|
|
|
|
// safely
|
|
|
|
|
ports, _ := parsePorts(allowedPorts) |
|
|
|
|
|
|
|
|
|
// Discard all global out-of-band Requests
|
|
|
|
|
go ssh.DiscardRequests(reqs) |
|
|
|
|
// Handle global out-of-band Requests
|
|
|
|
|
go func() { |
|
|
|
|
for req := range reqs { |
|
|
|
|
if *verbose { |
|
|
|
|
log.Println("Out of band request:", req.Type, req.WantReply) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// RFC4254: 7.1 for forwarding
|
|
|
|
|
if req.Type == "tcpip-forward" { |
|
|
|
|
handleTcpIpForward(sshConn, req) |
|
|
|
|
continue |
|
|
|
|
} else if req.Type == "cancel-tcpip-forward" { |
|
|
|
|
handleTcpIPForwardCancel(req) |
|
|
|
|
continue |
|
|
|
|
} else { |
|
|
|
|
// Discard everything else
|
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
|
|
|
|
|
// Accept all channels
|
|
|
|
|
go handleChannels(chans, ports) |
|
|
|
|
} |
|
|
|
@ -101,23 +141,32 @@ func handleChannels(chans <-chan ssh.NewChannel, ports []uint32) { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/* RFC4254 7.2 */ |
|
|
|
|
type directTCPPayload struct { |
|
|
|
|
Addr string |
|
|
|
|
Port uint32 |
|
|
|
|
OriginAddr string |
|
|
|
|
OriginPort uint32 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleChannel(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
if *verbose { |
|
|
|
|
log.Println("Channel type:", newChannel.ChannelType()) |
|
|
|
|
} |
|
|
|
|
if t := newChannel.ChannelType(); t != "direct-tcpip" { |
|
|
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted")) |
|
|
|
|
if t := newChannel.ChannelType(); t == "direct-tcpip" { |
|
|
|
|
handleDirect(newChannel, ports) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted")) |
|
|
|
|
/* |
|
|
|
|
// TODO: USE THIS ONLY FOR USING SSH ESCAPE SEQUENCES
|
|
|
|
|
c, _, err := newChannel.Accept() |
|
|
|
|
if err != nil { |
|
|
|
|
log.Fatal(err) |
|
|
|
|
} |
|
|
|
|
go func() { |
|
|
|
|
d := make([]byte, 4096) |
|
|
|
|
c.Read(d) |
|
|
|
|
}() |
|
|
|
|
return |
|
|
|
|
*/ |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleDirect(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
var payload directTCPPayload |
|
|
|
|
if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil { |
|
|
|
|
log.Printf("Could not unmarshal extra data: %s\n", err) |
|
|
|
@ -127,6 +176,7 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if payload.Addr != "localhost" { |
|
|
|
|
log.Printf("Tried to connect to prohibited host: %s", payload.Addr) |
|
|
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr")) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
@ -141,22 +191,22 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
|
|
|
|
|
if !ok { |
|
|
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad port")) |
|
|
|
|
log.Printf("Tried to forward prohibited port: %d", payload.Port) |
|
|
|
|
log.Printf("Tried to connect to prohibited port: %d", payload.Port) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// At this point, we have the opportunity to reject the client's
|
|
|
|
|
// At this point, we have the opportunity to reject the clients
|
|
|
|
|
// request for another logical connection
|
|
|
|
|
connection, requests, err := newChannel.Accept() |
|
|
|
|
_ = requests // TODO: Think we can just ignore these
|
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("Could not accept channel (%s)", err) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
go ssh.DiscardRequests(requests) |
|
|
|
|
|
|
|
|
|
addr := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) |
|
|
|
|
if *verbose { |
|
|
|
|
log.Println("Going to dial:", addr) |
|
|
|
|
log.Println("Dialing:", addr) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
rconn, err := net.Dial("tcp", addr) |
|
|
|
@ -166,21 +216,134 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
serve(connection, rconn) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { |
|
|
|
|
var payload tcpIpForwardPayload |
|
|
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { |
|
|
|
|
log.Println("Unable to unmarshal payload") |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
log.Println("Request:", req.Type, req.WantReply, payload) |
|
|
|
|
|
|
|
|
|
log.Printf("Request to listen on %s:%d", payload.Addr, payload.Port) |
|
|
|
|
|
|
|
|
|
if payload.Addr != "localhost" { |
|
|
|
|
log.Printf("Payload address is not \"localhost\"") |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// TODO: Check port
|
|
|
|
|
|
|
|
|
|
laddr := payload.Addr |
|
|
|
|
lport := payload.Port |
|
|
|
|
|
|
|
|
|
// TODO: We currently bind to localhost:port, and not to :port
|
|
|
|
|
// Need to figure out what we want - perhaps just part of policy
|
|
|
|
|
bind := fmt.Sprintf("%s:%d", laddr, lport) |
|
|
|
|
ln, err := net.Listen("tcp", bind) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("Listen failed for %s", bind) |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
globalListens[bind] = ln |
|
|
|
|
|
|
|
|
|
// Tell client everything is OK
|
|
|
|
|
reply := tcpIpForwardPayloadReply{lport} |
|
|
|
|
req.Reply(true, ssh.Marshal(&reply)) |
|
|
|
|
|
|
|
|
|
// Ensure that we get notified when the client connection is (unexpectedly)
|
|
|
|
|
// closed
|
|
|
|
|
go func() { |
|
|
|
|
err := conn.Wait() |
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("SSH connection closed: %s. Stopping listen", err) |
|
|
|
|
} |
|
|
|
|
ln.Close() |
|
|
|
|
delete(globalListens, bind) |
|
|
|
|
|
|
|
|
|
// We don't close existing connections
|
|
|
|
|
}() |
|
|
|
|
|
|
|
|
|
// Start listening for connections
|
|
|
|
|
go func() { |
|
|
|
|
for { |
|
|
|
|
lconn, err := ln.Accept() |
|
|
|
|
if err != nil { |
|
|
|
|
log.Println("Accept failed") |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
go func() { |
|
|
|
|
remoteaddr := lconn.RemoteAddr().String() |
|
|
|
|
|
|
|
|
|
p_index := strings.LastIndex(remoteaddr, ":") |
|
|
|
|
raddr := remoteaddr[:p_index] |
|
|
|
|
rport, err := strconv.ParseUint(remoteaddr[p_index+1:], 10, 32) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("Unable to parse RemoteAddr! (%s)", err) |
|
|
|
|
lconn.Close() |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
payload := forwardedTCPPayload{laddr, lport, raddr, uint32(rport)} |
|
|
|
|
mpayload := ssh.Marshal(&payload) |
|
|
|
|
|
|
|
|
|
// Open channel with client
|
|
|
|
|
c, requests, err := conn.OpenChannel("forwarded-tcpip", mpayload) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("Error: %s", err) |
|
|
|
|
log.Println("Unable to get channel. Hanging up requesting party!") |
|
|
|
|
lconn.Close() |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
go ssh.DiscardRequests(requests) |
|
|
|
|
|
|
|
|
|
serve(c, lconn) |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIPForwardCancel(req *ssh.Request) { |
|
|
|
|
var payload tcpIpForwardCancelPayload |
|
|
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { |
|
|
|
|
log.Println("Unable to unmarshal cancel payload") |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) |
|
|
|
|
|
|
|
|
|
if listener, found := globalListens[bound]; found { |
|
|
|
|
listener.Close() |
|
|
|
|
delete(globalListens, bound) |
|
|
|
|
req.Reply(true, []byte{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func serve(cssh ssh.Channel, conn net.Conn) { |
|
|
|
|
close := func() { |
|
|
|
|
connection.Close() |
|
|
|
|
rconn.Close() |
|
|
|
|
cssh.Close() |
|
|
|
|
conn.Close() |
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("Session closed") |
|
|
|
|
log.Printf("Channel closed") |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var once sync.Once |
|
|
|
|
go func() { |
|
|
|
|
io.Copy(connection, rconn) |
|
|
|
|
io.Copy(cssh, conn) |
|
|
|
|
once.Do(close) |
|
|
|
|
}() |
|
|
|
|
go func() { |
|
|
|
|
io.Copy(rconn, connection) |
|
|
|
|
io.Copy(conn, cssh) |
|
|
|
|
once.Do(close) |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
@ -197,6 +360,20 @@ func parsePorts(portstr string) (p []uint32, err error) { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func loadHostKeys(config *ssh.ServerConfig) { |
|
|
|
|
privateBytes, err := ioutil.ReadFile(*hostkey) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Fatal("Failed to load private key (./id_rsa)") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private, err := ssh.ParsePrivateKey(privateBytes) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Fatal("Failed to parse private key") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
config.AddHostKey(private) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func loadAuthorisedKeys(authorisedkeys string) { |
|
|
|
|
authorisedKeys = map[string]string{} |
|
|
|
|
authorisedKeysBytes, err := ioutil.ReadFile(authorisedkeys) |
|
|
|
@ -213,13 +390,13 @@ func loadAuthorisedKeys(authorisedkeys string) { |
|
|
|
|
|
|
|
|
|
log.Println("Options:", options) |
|
|
|
|
if len(options) != 1 { |
|
|
|
|
log.Fatal(fmt.Errorf("Only one option is accepted: 'ports=...'")) |
|
|
|
|
log.Fatal(fmt.Errorf("Only one option is accepted: \"ports=...\"")) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
option := options[0] |
|
|
|
|
|
|
|
|
|
if !strings.HasPrefix(option, "ports=") { |
|
|
|
|
log.Fatal(fmt.Errorf("Options does not start with 'ports='")) |
|
|
|
|
log.Fatal(fmt.Errorf("Options does not start with \"ports=\"")) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ports := option[len("ports="):] |
|
|
|
|