|
|
|
@ -18,18 +18,31 @@ import ( |
|
|
|
|
"golang.org/x/crypto/ssh" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// TODO: Use defer where useful
|
|
|
|
|
|
|
|
|
|
var ( |
|
|
|
|
authorisedKeys map[string]string |
|
|
|
|
|
|
|
|
|
/* Global listeners, we keep a global state for cancel-tcpip-forward */ |
|
|
|
|
globalListens map[string]net.Listener |
|
|
|
|
|
|
|
|
|
listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections") |
|
|
|
|
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") |
|
|
|
|
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") |
|
|
|
|
verbose = flag.Bool("verbose", false, "Enable verbose mode") |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
type sshClient struct { |
|
|
|
|
Name string |
|
|
|
|
Conn *ssh.ServerConn |
|
|
|
|
Listeners map[string]net.Listener |
|
|
|
|
AllowedLocalPorts []uint32 |
|
|
|
|
AllowedRemotePorts []uint32 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type bindInfo struct { |
|
|
|
|
Bound string |
|
|
|
|
Port uint32 |
|
|
|
|
Addr string |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/* RFC4254 7.2 */ |
|
|
|
|
type directTCPPayload struct { |
|
|
|
|
Addr string // To connect to
|
|
|
|
@ -62,8 +75,6 @@ type tcpIpForwardCancelPayload struct { |
|
|
|
|
func main() { |
|
|
|
|
flag.Parse() |
|
|
|
|
|
|
|
|
|
globalListens = make(map[string]net.Listener) |
|
|
|
|
|
|
|
|
|
config := &ssh.ServerConfig{ |
|
|
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
|
|
|
|
if ports, found := authorisedKeys[string(key.Marshal())]; found { |
|
|
|
@ -101,6 +112,23 @@ func main() { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
client := sshClient{"TODO FIXME XXX", sshConn, make(map[string]net.Listener), nil, nil} |
|
|
|
|
|
|
|
|
|
go func() { |
|
|
|
|
err := client.Conn.Wait() |
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("SSH connection closed for client %s: %s", client.Name, err) |
|
|
|
|
} |
|
|
|
|
// TODO: Make this safe? Is it impossible for cancel code to be
|
|
|
|
|
// running at this point?
|
|
|
|
|
for bind, listener := range client.Listeners { |
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("Closing listener bound to %s", bind) |
|
|
|
|
} |
|
|
|
|
listener.Close() |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
|
|
|
|
|
allowedPorts := sshConn.Permissions.CriticalOptions["ports"] |
|
|
|
|
|
|
|
|
|
if *verbose { |
|
|
|
@ -110,27 +138,30 @@ func main() { |
|
|
|
|
// Parsing a second time should not error, so we can ignore the error
|
|
|
|
|
// safely
|
|
|
|
|
ports, _ := parsePorts(allowedPorts) |
|
|
|
|
// TODO: Don't share same port/host limit
|
|
|
|
|
client.AllowedLocalPorts = ports |
|
|
|
|
client.AllowedRemotePorts = ports |
|
|
|
|
|
|
|
|
|
go handleRequest(sshConn, reqs) |
|
|
|
|
go handleRequest(&client, reqs) |
|
|
|
|
|
|
|
|
|
// Accept all channels
|
|
|
|
|
go handleChannels(chans, ports) |
|
|
|
|
// Accept all channels (TODO: Pass client)
|
|
|
|
|
go handleChannels(&client, chans) |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleChannels(chans <-chan ssh.NewChannel, ports []uint32) { |
|
|
|
|
func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) { |
|
|
|
|
for c := range chans { |
|
|
|
|
go handleChannel(c, ports) |
|
|
|
|
go handleChannel(client, c) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleChannel(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
func handleChannel(client *sshClient, newChannel ssh.NewChannel) { |
|
|
|
|
if *verbose { |
|
|
|
|
log.Println("Channel type:", newChannel.ChannelType()) |
|
|
|
|
} |
|
|
|
|
if t := newChannel.ChannelType(); t == "direct-tcpip" { |
|
|
|
|
handleDirect(newChannel, ports) |
|
|
|
|
handleDirect(client, newChannel) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -145,12 +176,12 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
d := make([]byte, 4096) |
|
|
|
|
c.Read(d) |
|
|
|
|
}() |
|
|
|
|
return |
|
|
|
|
*/ |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleDirect(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
func handleDirect(client *sshClient, newChannel ssh.NewChannel) { |
|
|
|
|
var payload directTCPPayload |
|
|
|
|
if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil { |
|
|
|
|
log.Printf("Could not unmarshal extra data: %s\n", err) |
|
|
|
@ -166,7 +197,7 @@ func handleDirect(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ok := false |
|
|
|
|
for _, port := range ports { |
|
|
|
|
for _, port := range client.AllowedLocalPorts { |
|
|
|
|
if payload.Port == port { |
|
|
|
|
ok = true |
|
|
|
|
break |
|
|
|
@ -203,11 +234,12 @@ func handleDirect(newChannel ssh.NewChannel, ports []uint32) { |
|
|
|
|
serve(connection, rconn) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { |
|
|
|
|
func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) { |
|
|
|
|
var payload tcpIpForwardPayload |
|
|
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { |
|
|
|
|
log.Println("Unable to unmarshal payload") |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
return nil, nil, fmt.Errorf("Unable to parse payload") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
log.Println("Request:", req.Type, req.WantReply, payload) |
|
|
|
@ -217,7 +249,7 @@ func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { |
|
|
|
|
if payload.Addr != "localhost" { |
|
|
|
|
log.Printf("Payload address is not \"localhost\"") |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
return |
|
|
|
|
return nil, nil, fmt.Errorf("Address is not permitted") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// TODO: Check port
|
|
|
|
@ -233,84 +265,76 @@ func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("Listen failed for %s", bind) |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
return |
|
|
|
|
return nil, nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
globalListens[bind] = ln |
|
|
|
|
|
|
|
|
|
// Tell client everything is OK
|
|
|
|
|
reply := tcpIpForwardPayloadReply{lport} |
|
|
|
|
req.Reply(true, ssh.Marshal(&reply)) |
|
|
|
|
|
|
|
|
|
// Ensure that we get notified when the client connection is (unexpectedly)
|
|
|
|
|
// closed
|
|
|
|
|
go func() { |
|
|
|
|
err := conn.Wait() |
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("SSH connection closed: %s. Stopping listen", err) |
|
|
|
|
} |
|
|
|
|
ln.Close() |
|
|
|
|
delete(globalListens, bind) |
|
|
|
|
return ln, &bindInfo{bind, lport, laddr}, nil |
|
|
|
|
|
|
|
|
|
// We don't close existing connections
|
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleListener(client *sshClient, bindinfo *bindInfo, listener net.Listener) { |
|
|
|
|
// Start listening for connections
|
|
|
|
|
go func() { |
|
|
|
|
for { |
|
|
|
|
lconn, err := ln.Accept() |
|
|
|
|
if err != nil { |
|
|
|
|
neterr := err.(net.Error) |
|
|
|
|
if neterr.Timeout() { |
|
|
|
|
log.Println("Accept failed with timeout:", err) |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if neterr.Temporary() { |
|
|
|
|
log.Println("Accept failed with temporary:", err) |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
break |
|
|
|
|
for { |
|
|
|
|
lconn, err := listener.Accept() |
|
|
|
|
if err != nil { |
|
|
|
|
neterr := err.(net.Error) |
|
|
|
|
if neterr.Timeout() { |
|
|
|
|
log.Println("Accept failed with timeout:", err) |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if neterr.Temporary() { |
|
|
|
|
log.Println("Accept failed with temporary:", err) |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// TODO: Sep function?
|
|
|
|
|
go func() { |
|
|
|
|
remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr) |
|
|
|
|
raddr := remotetcpaddr.IP.String() |
|
|
|
|
rport := uint32(remotetcpaddr.Port) |
|
|
|
|
|
|
|
|
|
payload := forwardedTCPPayload{laddr, lport, raddr, uint32(rport)} |
|
|
|
|
mpayload := ssh.Marshal(&payload) |
|
|
|
|
|
|
|
|
|
// Open channel with client
|
|
|
|
|
c, requests, err := conn.OpenChannel("forwarded-tcpip", mpayload) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("Error: %s", err) |
|
|
|
|
log.Println("Unable to get channel. Hanging up requesting party!") |
|
|
|
|
lconn.Close() |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
go ssh.DiscardRequests(requests) |
|
|
|
|
|
|
|
|
|
serve(c, lconn) |
|
|
|
|
}() |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
|
|
|
|
|
// TODO: I don't think a goroutine is required here
|
|
|
|
|
go handleForwardTcpIp(client, bindinfo, lconn) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIPForwardCancel(req *ssh.Request) { |
|
|
|
|
func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { |
|
|
|
|
remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr) |
|
|
|
|
raddr := remotetcpaddr.IP.String() |
|
|
|
|
rport := uint32(remotetcpaddr.Port) |
|
|
|
|
|
|
|
|
|
payload := forwardedTCPPayload{bindinfo.Addr, bindinfo.Port, raddr, uint32(rport)} |
|
|
|
|
mpayload := ssh.Marshal(&payload) |
|
|
|
|
|
|
|
|
|
// Open channel with client
|
|
|
|
|
c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("Error: %s", err) |
|
|
|
|
log.Println("Unable to get channel. Hanging up requesting party!") |
|
|
|
|
lconn.Close() |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
go ssh.DiscardRequests(requests) |
|
|
|
|
|
|
|
|
|
serve(c, lconn) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { |
|
|
|
|
if *verbose { |
|
|
|
|
log.Println("Cancel called by client", client) |
|
|
|
|
} |
|
|
|
|
var payload tcpIpForwardCancelPayload |
|
|
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { |
|
|
|
|
log.Println("Unable to unmarshal cancel payload") |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
//bound := fmt.Sprintf(":%d", payload.Port)
|
|
|
|
|
bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) |
|
|
|
|
|
|
|
|
|
if listener, found := globalListens[bound]; found { |
|
|
|
|
if listener, found := client.Listeners[bound]; found { |
|
|
|
|
listener.Close() |
|
|
|
|
delete(globalListens, bound) |
|
|
|
|
delete(client.Listeners, bound) |
|
|
|
|
req.Reply(true, []byte{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -401,7 +425,7 @@ func loadAuthorisedKeys(authorisedkeys string) { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) { |
|
|
|
|
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { |
|
|
|
|
for req := range reqs { |
|
|
|
|
if *verbose { |
|
|
|
|
log.Println("Out of band request:", req.Type, req.WantReply) |
|
|
|
@ -409,10 +433,16 @@ func handleRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) { |
|
|
|
|
|
|
|
|
|
// RFC4254: 7.1 for forwarding
|
|
|
|
|
if req.Type == "tcpip-forward" { |
|
|
|
|
handleTcpIpForward(sshConn, req) |
|
|
|
|
listener, bindinfo, err := handleTcpIpForward(client, req) |
|
|
|
|
if err != nil { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
client.Listeners[bindinfo.Bound] = listener |
|
|
|
|
go handleListener(client, bindinfo, listener) |
|
|
|
|
continue |
|
|
|
|
} else if req.Type == "cancel-tcpip-forward" { |
|
|
|
|
handleTcpIPForwardCancel(req) |
|
|
|
|
handleTcpIPForwardCancel(client, req) |
|
|
|
|
continue |
|
|
|
|
} else { |
|
|
|
|
// Discard everything else
|
|
|
|
|