Remove global state, refactoring

master
Merlijn Wajer 7 years ago
parent 8d13c9501c
commit ebbf5692fa
  1. 2
      README.rst
  2. 180
      sshd.go

@ -2,7 +2,7 @@ Motivation
==========
sshd implementation in Go, for the sole purpose of restricting the ports that
clients can request using direct-tcpip.
clients can request using direct-tcpip and tcpip-forward / forwarded-tcpip.
OpenSSH refuses to merge patches to support this, but there is a fork of OpenSSH
with patches that achieve something similar to this. [1]

@ -18,18 +18,31 @@ import (
"golang.org/x/crypto/ssh"
)
// TODO: Use defer where useful
var (
authorisedKeys map[string]string
/* Global listeners, we keep a global state for cancel-tcpip-forward */
globalListens map[string]net.Listener
listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections")
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load")
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")
verbose = flag.Bool("verbose", false, "Enable verbose mode")
)
type sshClient struct {
Name string
Conn *ssh.ServerConn
Listeners map[string]net.Listener
AllowedLocalPorts []uint32
AllowedRemotePorts []uint32
}
type bindInfo struct {
Bound string
Port uint32
Addr string
}
/* RFC4254 7.2 */
type directTCPPayload struct {
Addr string // To connect to
@ -62,8 +75,6 @@ type tcpIpForwardCancelPayload struct {
func main() {
flag.Parse()
globalListens = make(map[string]net.Listener)
config := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if ports, found := authorisedKeys[string(key.Marshal())]; found {
@ -101,6 +112,23 @@ func main() {
return
}
client := sshClient{"TODO FIXME XXX", sshConn, make(map[string]net.Listener), nil, nil}
go func() {
err := client.Conn.Wait()
if *verbose {
log.Printf("SSH connection closed for client %s: %s", client.Name, err)
}
// TODO: Make this safe? Is it impossible for cancel code to be
// running at this point?
for bind, listener := range client.Listeners {
if *verbose {
log.Printf("Closing listener bound to %s", bind)
}
listener.Close()
}
}()
allowedPorts := sshConn.Permissions.CriticalOptions["ports"]
if *verbose {
@ -110,27 +138,30 @@ func main() {
// Parsing a second time should not error, so we can ignore the error
// safely
ports, _ := parsePorts(allowedPorts)
// TODO: Don't share same port/host limit
client.AllowedLocalPorts = ports
client.AllowedRemotePorts = ports
go handleRequest(sshConn, reqs)
go handleRequest(&client, reqs)
// Accept all channels
go handleChannels(chans, ports)
// Accept all channels (TODO: Pass client)
go handleChannels(&client, chans)
}()
}
}
func handleChannels(chans <-chan ssh.NewChannel, ports []uint32) {
func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) {
for c := range chans {
go handleChannel(c, ports)
go handleChannel(client, c)
}
}
func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
if *verbose {
log.Println("Channel type:", newChannel.ChannelType())
}
if t := newChannel.ChannelType(); t == "direct-tcpip" {
handleDirect(newChannel, ports)
handleDirect(client, newChannel)
return
}
@ -145,12 +176,12 @@ func handleChannel(newChannel ssh.NewChannel, ports []uint32) {
d := make([]byte, 4096)
c.Read(d)
}()
return
*/
return
}
func handleDirect(newChannel ssh.NewChannel, ports []uint32) {
func handleDirect(client *sshClient, newChannel ssh.NewChannel) {
var payload directTCPPayload
if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil {
log.Printf("Could not unmarshal extra data: %s\n", err)
@ -166,7 +197,7 @@ func handleDirect(newChannel ssh.NewChannel, ports []uint32) {
}
ok := false
for _, port := range ports {
for _, port := range client.AllowedLocalPorts {
if payload.Port == port {
ok = true
break
@ -203,11 +234,12 @@ func handleDirect(newChannel ssh.NewChannel, ports []uint32) {
serve(connection, rconn)
}
func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) {
func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) {
var payload tcpIpForwardPayload
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
log.Println("Unable to unmarshal payload")
req.Reply(false, []byte{})
return nil, nil, fmt.Errorf("Unable to parse payload")
}
log.Println("Request:", req.Type, req.WantReply, payload)
@ -217,7 +249,7 @@ func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) {
if payload.Addr != "localhost" {
log.Printf("Payload address is not \"localhost\"")
req.Reply(false, []byte{})
return
return nil, nil, fmt.Errorf("Address is not permitted")
}
// TODO: Check port
@ -233,84 +265,76 @@ func handleTcpIpForward(conn *ssh.ServerConn, req *ssh.Request) {
if err != nil {
log.Printf("Listen failed for %s", bind)
req.Reply(false, []byte{})
return
return nil, nil, err
}
globalListens[bind] = ln
// Tell client everything is OK
reply := tcpIpForwardPayloadReply{lport}
req.Reply(true, ssh.Marshal(&reply))
// Ensure that we get notified when the client connection is (unexpectedly)
// closed
go func() {
err := conn.Wait()
if *verbose {
log.Printf("SSH connection closed: %s. Stopping listen", err)
}
ln.Close()
delete(globalListens, bind)
return ln, &bindInfo{bind, lport, laddr}, nil
// We don't close existing connections
}()
}
func handleListener(client *sshClient, bindinfo *bindInfo, listener net.Listener) {
// Start listening for connections
go func() {
for {
lconn, err := ln.Accept()
if err != nil {
neterr := err.(net.Error)
if neterr.Timeout() {
log.Println("Accept failed with timeout:", err)
continue
}
if neterr.Temporary() {
log.Println("Accept failed with temporary:", err)
continue
}
break
for {
lconn, err := listener.Accept()
if err != nil {
neterr := err.(net.Error)
if neterr.Timeout() {
log.Println("Accept failed with timeout:", err)
continue
}
if neterr.Temporary() {
log.Println("Accept failed with temporary:", err)
continue
}
// TODO: Sep function?
go func() {
remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr)
raddr := remotetcpaddr.IP.String()
rport := uint32(remotetcpaddr.Port)
payload := forwardedTCPPayload{laddr, lport, raddr, uint32(rport)}
mpayload := ssh.Marshal(&payload)
// Open channel with client
c, requests, err := conn.OpenChannel("forwarded-tcpip", mpayload)
if err != nil {
log.Printf("Error: %s", err)
log.Println("Unable to get channel. Hanging up requesting party!")
lconn.Close()
return
}
go ssh.DiscardRequests(requests)
serve(c, lconn)
}()
break
}
}()
// TODO: I don't think a goroutine is required here
go handleForwardTcpIp(client, bindinfo, lconn)
}
}
func handleTcpIPForwardCancel(req *ssh.Request) {
func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) {
remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr)
raddr := remotetcpaddr.IP.String()
rport := uint32(remotetcpaddr.Port)
payload := forwardedTCPPayload{bindinfo.Addr, bindinfo.Port, raddr, uint32(rport)}
mpayload := ssh.Marshal(&payload)
// Open channel with client
c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload)
if err != nil {
log.Printf("Error: %s", err)
log.Println("Unable to get channel. Hanging up requesting party!")
lconn.Close()
return
}
go ssh.DiscardRequests(requests)
serve(c, lconn)
}
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
if *verbose {
log.Println("Cancel called by client", client)
}
var payload tcpIpForwardCancelPayload
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
log.Println("Unable to unmarshal cancel payload")
req.Reply(false, []byte{})
}
//bound := fmt.Sprintf(":%d", payload.Port)
bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port)
if listener, found := globalListens[bound]; found {
if listener, found := client.Listeners[bound]; found {
listener.Close()
delete(globalListens, bound)
delete(client.Listeners, bound)
req.Reply(true, []byte{})
}
@ -401,7 +425,7 @@ func loadAuthorisedKeys(authorisedkeys string) {
}
}
func handleRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) {
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
for req := range reqs {
if *verbose {
log.Println("Out of band request:", req.Type, req.WantReply)
@ -409,10 +433,16 @@ func handleRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) {
// RFC4254: 7.1 for forwarding
if req.Type == "tcpip-forward" {
handleTcpIpForward(sshConn, req)
listener, bindinfo, err := handleTcpIpForward(client, req)
if err != nil {
continue
}
client.Listeners[bindinfo.Bound] = listener
go handleListener(client, bindinfo, listener)
continue
} else if req.Type == "cancel-tcpip-forward" {
handleTcpIPForwardCancel(req)
handleTcpIPForwardCancel(client, req)
continue
} else {
// Discard everything else

Loading…
Cancel
Save