Remove global state, refactoring

master
Merlijn Wajer 7 years ago
parent 8d13c9501c
commit ebbf5692fa
  1. 2
      README.rst
  2. 180
      sshd.go

@ -2,7 +2,7 @@ Motivation
========== ==========
sshd implementation in Go, for the sole purpose of restricting the ports that sshd implementation in Go, for the sole purpose of restricting the ports that
clients can request using direct-tcpip. clients can request using direct-tcpip and tcpip-forward / forwarded-tcpip.
OpenSSH refuses to merge patches to support this, but there is a fork of OpenSSH OpenSSH refuses to merge patches to support this, but there is a fork of OpenSSH
with patches that achieve something similar to this. [1] with patches that achieve something similar to this. [1]

@ -18,18 +18,31 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// TODO: Use defer where useful
var ( var (
authorisedKeys map[string]string authorisedKeys map[string]string
/* Global listeners, we keep a global state for cancel-tcpip-forward */
globalListens map[string]net.Listener
listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections") listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections")
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") hostkey = flag.String("hostkey", "id_rsa", "Server host key to load")
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")
verbose = flag.Bool("verbose", false, "Enable verbose mode") verbose = flag.Bool("verbose", false, "Enable verbose mode")
) )
type sshClient struct {
Name string
Conn *ssh.ServerConn
Listeners map[string]net.Listener
AllowedLocalPorts []uint32
AllowedRemotePorts []uint32
}
type bindInfo struct {
Bound string
Port uint32
Addr string
}
/* RFC4254 7.2 */ /* RFC4254 7.2 */
type directTCPPayload struct { type directTCPPayload struct {
Addr string // To connect to Addr string // To connect to
@ -62,8 +75,6 @@ type tcpIpForwardCancelPayload struct {
func main() { func main() {
flag.Parse() flag.Parse()
globalListens = make(map[string]net.Listener)
config := &ssh.ServerConfig{ config := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if ports, found := authorisedKeys[string(key.Marshal())]; found { if ports, found := authorisedKeys[string(key.Marshal())]; found {
@ -101,6 +112,23 @@ func main() {
return return
} }
client := sshClient{"TODO FIXME XXX", sshConn, make(map[string]net.Listener), nil, nil}
go func() {
err := client.Conn.Wait()
if *verbose {
log.Printf("SSH connection closed for client %s: %s", client.Name, err)
}
// TODO: Make this safe? Is it impossible for cancel code to be
// running at this point?
for bind, listener := range client.Listeners {
if *verbose {
log.Printf("Closing listener bound to %s", bind)
}
listener.Close()
}
}()
allowedPorts := sshConn.Permissions.CriticalOptions["ports"] allowedPorts := sshConn.Permissions.CriticalOptions["ports"]
if *verbose { if *verbose {
@ -110,27 +138,30 @@ func main() {
// Parsing a second time should not error, so we can ignore the error // Parsing a second time should not error, so we can ignore the error
// safely // safely
ports, _ := parsePorts(allowedPorts) ports, _ := parsePorts(allowedPorts)
// TODO: Don't share same port/host limit
client.AllowedLocalPorts = ports
client.AllowedRemotePorts = ports
go handleRequest(sshConn, reqs) go handleRequest(&client, reqs)
// Accept all channels // Accept all channels (TODO: Pass client)
go handleChannels(chans, ports) go handleChannels(&client, chans)
}() }()
} }
} }
func handleChannels(chans <-chan ssh.NewChannel, ports []uint32) { func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) {
for c := range chans { for c := range chans {
go handleChannel(c, ports) go handleChannel(client, c)
} }
} }
func handleChannel(newChannel ssh.NewChannel, ports []uint32) { func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
if *verbose { if *verbose {
log.Println("Channel type:", newChannel.ChannelType()) log.Println("Channel type:", newChannel.ChannelType())
} }
if t := newChannel.ChannelType(); t == "direct-tcpip" { if t := newChannel.ChannelType(); t == "direct-tcpip" {
handleDirect(newChannel, ports) handleDirect(client, newChannel)
return return
} }
@ -145,12 +176,12 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
d := make([]byte, 4096) d := make([]byte, 4096)
c.Read(d) c.Read(d)
}() }()
return
*/ */
return
} }
func handleDirect(newChannel ssh.NewChannel, ports []uint32) { func handleDirect(client *sshClient, newChannel ssh.NewChannel) {
var payload directTCPPayload var payload directTCPPayload
if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil { if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil {
log.Printf("Could not unmarshal extra data: %s\n", err) log.Printf("Could not unmarshal extra data: %s\n", err)
@ -166,7 +197,7 @@ func handleDirect(newChannel ssh.NewChannel, ports []uint32) {
} }
ok := false ok := false
for _, port := range ports { for _, port := range client.AllowedLocalPorts {
if payload.Port == port { if payload.Port == port {
ok = true ok = true
break break
@ -203,11 +234,12 @@ func handleDirect(newChannel ssh.NewChannel, ports []uint32) {
serve(connection, rconn) serve(connection, rconn)
} }
func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) { func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) {
var payload tcpIpForwardPayload var payload tcpIpForwardPayload
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
log.Println("Unable to unmarshal payload") log.Println("Unable to unmarshal payload")
req.Reply(false, []byte{}) req.Reply(false, []byte{})
return nil, nil, fmt.Errorf("Unable to parse payload")
} }
log.Println("Request:", req.Type, req.WantReply, payload) log.Println("Request:", req.Type, req.WantReply, payload)
@ -217,7 +249,7 @@ func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) {
if payload.Addr != "localhost" { if payload.Addr != "localhost" {
log.Printf("Payload address is not \"localhost\"") log.Printf("Payload address is not \"localhost\"")
req.Reply(false, []byte{}) req.Reply(false, []byte{})
return return nil, nil, fmt.Errorf("Address is not permitted")
} }
// TODO: Check port // TODO: Check port
@ -233,84 +265,76 @@ func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) {
if err != nil { if err != nil {
log.Printf("Listen failed for %s", bind) log.Printf("Listen failed for %s", bind)
req.Reply(false, []byte{}) req.Reply(false, []byte{})
return return nil, nil, err
} }
globalListens[bind] = ln
// Tell client everything is OK // Tell client everything is OK
reply := tcpIpForwardPayloadReply{lport} reply := tcpIpForwardPayloadReply{lport}
req.Reply(true, ssh.Marshal(&reply)) req.Reply(true, ssh.Marshal(&reply))
// Ensure that we get notified when the client connection is (unexpectedly) return ln, &bindInfo{bind, lport, laddr}, nil
// closed
go func() {
err := conn.Wait()
if *verbose {
log.Printf("SSH connection closed: %s. Stopping listen", err)
}
ln.Close()
delete(globalListens, bind)
// We don't close existing connections }
}()
func handleListener(client *sshClient, bindinfo *bindInfo, listener net.Listener) {
// Start listening for connections // Start listening for connections
go func() { for {
for { lconn, err := listener.Accept()
lconn, err := ln.Accept() if err != nil {
if err != nil { neterr := err.(net.Error)
neterr := err.(net.Error) if neterr.Timeout() {
if neterr.Timeout() { log.Println("Accept failed with timeout:", err)
log.Println("Accept failed with timeout:", err) continue
continue }
} if neterr.Temporary() {
if neterr.Temporary() { log.Println("Accept failed with temporary:", err)
log.Println("Accept failed with temporary:", err) continue
continue
}
break
} }
// TODO: Sep function? break
go func() {
remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr)
raddr := remotetcpaddr.IP.String()
rport := uint32(remotetcpaddr.Port)
payload := forwardedTCPPayload{laddr, lport, raddr, uint32(rport)}
mpayload := ssh.Marshal(&payload)
// Open channel with client
c, requests, err := conn.OpenChannel("forwarded-tcpip", mpayload)
if err != nil {
log.Printf("Error: %s", err)
log.Println("Unable to get channel. Hanging up requesting party!")
lconn.Close()
return
}
go ssh.DiscardRequests(requests)
serve(c, lconn)
}()
} }
}()
// TODO: I don't think a goroutine is required here
go handleForwardTcpIp(client, bindinfo, lconn)
}
} }
func handleTcpIPForwardCancel(req *ssh.Request) { func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) {
remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr)
raddr := remotetcpaddr.IP.String()
rport := uint32(remotetcpaddr.Port)
payload := forwardedTCPPayload{bindinfo.Addr, bindinfo.Port, raddr, uint32(rport)}
mpayload := ssh.Marshal(&payload)
// Open channel with client
c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload)
if err != nil {
log.Printf("Error: %s", err)
log.Println("Unable to get channel. Hanging up requesting party!")
lconn.Close()
return
}
go ssh.DiscardRequests(requests)
serve(c, lconn)
}
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
if *verbose {
log.Println("Cancel called by client", client)
}
var payload tcpIpForwardCancelPayload var payload tcpIpForwardCancelPayload
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
log.Println("Unable to unmarshal cancel payload") log.Println("Unable to unmarshal cancel payload")
req.Reply(false, []byte{}) req.Reply(false, []byte{})
} }
//bound := fmt.Sprintf(":%d", payload.Port)
bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port)
if listener, found := globalListens[bound]; found { if listener, found := client.Listeners[bound]; found {
listener.Close() listener.Close()
delete(globalListens, bound) delete(client.Listeners, bound)
req.Reply(true, []byte{}) req.Reply(true, []byte{})
} }
@ -401,7 +425,7 @@ func loadAuthorisedKeys(authorisedkeys string) {
} }
} }
func handleRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) { func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
for req := range reqs { for req := range reqs {
if *verbose { if *verbose {
log.Println("Out of band request:", req.Type, req.WantReply) log.Println("Out of band request:", req.Type, req.WantReply)
@ -409,10 +433,16 @@ func handleRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) {
// RFC4254: 7.1 for forwarding // RFC4254: 7.1 for forwarding
if req.Type == "tcpip-forward" { if req.Type == "tcpip-forward" {
handleTcpIpForward(sshConn, req) listener, bindinfo, err := handleTcpIpForward(client, req)
if err != nil {
continue
}
client.Listeners[bindinfo.Bound] = listener
go handleListener(client, bindinfo, listener)
continue continue
} else if req.Type == "cancel-tcpip-forward" { } else if req.Type == "cancel-tcpip-forward" {
handleTcpIPForwardCancel(req) handleTcpIPForwardCancel(client, req)
continue continue
} else { } else {
// Discard everything else // Discard everything else

Loading…
Cancel
Save