sshd implementation in Go, for the sole purpose of restricting the ports that clients can request using direct-tcpip and tcpip-forward / forwarded-tcpip.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
go-sshd/sshd.go

614 lines
16 KiB

package main
// Copyright:
// Merlijn B. W. Wajer <merlijn@wizzup.org>
// (C) 2017
import (
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
// Package-level state: the authorised key table, the command-line flags and
// the mutex that guards the key table.
var (
	// authorisedKeys maps a marshalled public key (used as a string) to the
	// settings parsed for that key. Guarded by authmutex.
	authorisedKeys map[string]deviceInfo

	// Where to accept incoming ssh connections.
	listenaddr = flag.String("listenaddr", "0.0.0.0", "Addr to listen on for incoming ssh connections")
	listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections")

	// Key material.
	hostkey        = flag.String("hostkey", "id_rsa", "Server host key to load")
	authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")

	// Logging verbosity.
	verbose = flag.Bool("verbose", false, "Enable verbose mode")
	debug   = flag.Bool("debug", false, "Enable debug mode")

	// Idle timeouts. Each one is applied with SetDeadline(), so a single
	// value covers both the read and the write deadline; there has been no
	// reason so far to configure them separately.
	maintimeout      = flag.Duration("main-timeout", 2*time.Minute, "Client socket timeout")
	directtimeout    = flag.Duration("direct-timeout", 2*time.Minute, "direct-tcpip timeout")
	forwardedtimeout = flag.Duration("forwarded-timeout", 2*time.Minute, "forwarded-tcpip timeout")

	// authmutex serialises every access to the authorisedKeys map.
	authmutex sync.Mutex
)
// sshClient holds all per-connection state for one authenticated client.
type sshClient struct {
// Name identifies the client in log lines; it is filled from the "name"
// critical option set during public key authentication.
Name string
// We keep track of the normal Conn as well so that we have access to the
// SetDeadline() methods
Conn net.Conn
SshConn *ssh.ServerConn
// Listener sockets opened by the client, keyed by the bind string they
// were created with (bindInfo.Bound).
Listeners map[string]net.Listener
// Ports the client may reach through "direct-tcpip" channels.
AllowedLocalPorts []uint32
// Ports the client may ask the server to listen on via "tcpip-forward".
AllowedRemotePorts []uint32
// This indicates that a client is shutting down. When a client is stopping,
// we do not allow new listening requests, to prevent a listener connection
// being opened just after we closed all of them.
Stopping bool
// ListenMutex guards Listeners and Stopping.
ListenMutex sync.Mutex
}
// bindInfo records what address/port we bound on for a forwarded-tcpip
// listener.
type bindInfo struct {
// Bound is the full bind string handed to net.Listen; it is also the key
// under which the listener is stored in sshClient.Listeners.
Bound string
// Port is the port number the client asked for.
Port uint32
// Addr is the address the client asked for.
Addr string
}
// deviceInfo is the information parsed from one entry of the authorized_keys
// file.
type deviceInfo struct {
// LocalPorts is the raw value of the "localports=" option: a
// colon-separated list of port numbers.
LocalPorts string
// RemotePorts is the raw value of the "remoteports=" option, in the same
// format as LocalPorts.
RemotePorts string
// Comment is the comment field of the key entry; it is used as the
// client's name.
Comment string
}
/* RFC4254 7.2 */
// directTCPPayload is decoded from the extra data of a "direct-tcpip" channel
// open request.
// NOTE(review): ssh.Unmarshal presumably relies on the field order matching
// the wire format, so fields should not be reordered — confirm.
type directTCPPayload struct {
Addr string // To connect to
Port uint32
OriginAddr string
OriginPort uint32
}
// forwardedTCPPayload is marshalled into the "forwarded-tcpip" channel open
// request sent to the client when a connection arrives on one of its
// listeners. Addr/Port are the bound address and port, OriginAddr/OriginPort
// describe the connecting peer.
type forwardedTCPPayload struct {
Addr string // Is connected to
Port uint32
OriginAddr string
OriginPort uint32
}
// tcpIpForwardPayload is decoded from the payload of a "tcpip-forward" global
// request: the address and port the client wants the server to listen on.
type tcpIpForwardPayload struct {
Addr string
Port uint32
}
// tcpIpForwardPayloadReply is marshalled into the positive reply to a
// "tcpip-forward" request and carries the port that was bound.
type tcpIpForwardPayloadReply struct {
Port uint32
}
// tcpIpForwardCancelPayload is decoded from the payload of a
// "cancel-tcpip-forward" global request: the address and port of the
// forwarding that should be torn down.
type tcpIpForwardCancelPayload struct {
Addr string
Port uint32
}
// TimeoutFunc is a callback that copyTimeout() invokes around each successful
// read/write; callers use it to push socket deadlines forward with
// SetDeadline().
type TimeoutFunc func()
// main parses the command-line flags, loads the host key and the authorised
// keys, and then accepts ssh connections forever. Every accepted connection is
// handshaked and served in its own goroutines.
func main() {
	flag.Parse()

	config := &ssh.ServerConfig{
		// Only public key authentication is offered. The settings stored for
		// the key are handed to the connection as critical options.
		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
			authmutex.Lock()
			defer authmutex.Unlock()

			clientinfo, found := authorisedKeys[string(key.Marshal())]
			if !found {
				// No trailing newline: error strings end up inside log lines.
				return nil, fmt.Errorf("Unknown public key")
			}
			return &ssh.Permissions{
				CriticalOptions: map[string]string{
					"name":        clientinfo.Comment,
					"localports":  clientinfo.LocalPorts,
					"remoteports": clientinfo.RemotePorts,
				},
			}, nil
		},
	}

	loadHostKeys(config)
	loadAuthorisedKeys(*authorisedkeys)
	// NOTE(review): registerReloadSignal is defined outside this part of the
	// file; presumably it re-reads the authorised keys on a signal — confirm.
	registerReloadSignal()

	bind := fmt.Sprintf("[%s]:%d", *listenaddr, *listenport)
	listener, err := net.Listen("tcp", bind)
	if err != nil {
		// Log the address we tried to bind, not the flag pointer.
		log.Fatalf("Failed to listen on %s (%s)", bind, err)
	}

	// Accept all connections
	log.Printf("Listening on %d...", *listenport)
	for {
		tcpConn, err := listener.Accept()
		if err != nil {
			log.Printf("Failed to accept incoming connection (%s)", err)
			continue
		}
		// The deadline also bounds the handshake below.
		tcpConn.SetDeadline(time.Now().Add(*maintimeout))

		// We perform the ssh handshake in a goroutine so the handshake cannot
		// block incoming connections.
		go func() {
			sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config)
			if err != nil {
				log.Printf("Failed to handshake: %s (rip: %v)", err, tcpConn.RemoteAddr())
				return
			}

			client := &sshClient{
				Name:      sshConn.Permissions.CriticalOptions["name"],
				Conn:      tcpConn,
				SshConn:   sshConn,
				Listeners: make(map[string]net.Listener),
			}
			allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"]
			allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"]

			if *verbose {
				log.Printf("[%s] Connection from %s (%s). Allowed local ports: %s remote ports: %s", client.Name, sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedLocalPorts, allowedRemotePorts)
			}

			// The option values were already validated when the authorised
			// keys were loaded, so the error is deliberately ignored here. An
			// option that was never set is an empty string, which yields an
			// empty list, i.e. no permitted ports.
			client.AllowedLocalPorts, _ = parsePorts(allowedLocalPorts)
			client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts)

			// Start the clean-up function: will wait for the socket to be
			// closed (either by remote, protocol or deadline/timeout)
			// and close any listeners if any
			go func() {
				err := client.SshConn.Wait()

				client.ListenMutex.Lock()
				defer client.ListenMutex.Unlock()
				// From here on handleRequest refuses new listeners.
				client.Stopping = true

				if *verbose {
					log.Printf("[%s] SSH connection closed: %s", client.Name, err)
				}
				for bind, listener := range client.Listeners {
					if *verbose {
						log.Printf("[%s] Closing listener bound to %s", client.Name, bind)
					}
					listener.Close()
				}
			}()

			// Accept requests & channels
			go handleRequest(client, reqs)
			go handleChannels(client, chans)
		}()
	}
}
// handleChannels dispatches every channel the client tries to open to
// handleChannel, each in its own goroutine. It returns once chans is closed.
func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) {
	for {
		pending, open := <-chans
		if !open {
			return
		}
		go handleChannel(client, pending)
	}
}
// handleChannel serves one channel open request. Only "direct-tcpip" channels
// are handed on (to handleDirect); every other channel type is rejected as
// prohibited.
func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
	kind := newChannel.ChannelType()
	if *debug {
		log.Printf("[%s] Channel type: %v", client.Name, kind)
	}

	switch kind {
	case "direct-tcpip":
		handleDirect(client, newChannel)
	default:
		newChannel.Reject(ssh.Prohibited, "Only \"direct-tcpip\" is accepted")

		// XXX: Use this only for testing purposes -- I add this in if/when I
		// want to use the ssh escape sequences from ssh (those only work in an
		// interactive session)
		//
		//	c, _, err := newChannel.Accept()
		//	if err != nil {
		//		log.Fatal(err)
		//	}
		//	go func() {
		//		d := make([]byte, 4096)
		//		c.Read(d)
		//	}()
	}
}
// handleDirect serves a "direct-tcpip" channel: it decodes the requested
// destination, refuses ports that are not in the client's local port list,
// accepts the channel, dials the destination and then relays data between the
// channel and the dialed connection.
func handleDirect(client *sshClient, newChannel ssh.NewChannel) {
	var target directTCPPayload
	if err := ssh.Unmarshal(newChannel.ExtraData(), &target); err != nil {
		log.Printf("[%s] Could not unmarshal extra data: %s", client.Name, err)
		newChannel.Reject(ssh.Prohibited, "Bad payload")
		return
	}

	// XXX: Is this sensible?
	//
	//	if payload.Addr != "localhost" && payload.Addr != "::1" && payload.Addr != "127.0.0.1" {
	//		log.Printf("[%s] Tried to connect to prohibited host: %s", client.Name, payload.Addr)
	//		newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr"))
	//		return
	//	}

	if allowed := portPermitted(target.Port, client.AllowedLocalPorts); !allowed {
		newChannel.Reject(ssh.Prohibited, "Bad port")
		log.Printf("[%s] Tried to connect to prohibited port: %d", client.Name, target.Port)
		return
	}

	channel, requests, err := newChannel.Accept()
	if err != nil {
		log.Printf("[%s] Could not accept channel (%s)", client.Name, err)
		return
	}
	// Requests made on the channel itself are not of interest.
	go ssh.DiscardRequests(requests)

	dest := fmt.Sprintf("[%s]:%d", target.Addr, target.Port)
	if *verbose {
		log.Printf("[%s] Dialing: %s", client.Name, dest)
	}

	remote, err := net.Dial("tcp", dest)
	if err != nil {
		log.Printf("[%s] Could not dial remote (%s)", client.Name, err)
		channel.Close()
		return
	}

	serve(channel, remote, client, *directtimeout)
}
// handleTcpIpForward handles a "tcpip-forward" request: it validates the
// requested address and port, opens the listener and answers the request.
//
// On success it returns the listener together with the bindInfo describing
// it. On any failure the request is answered negatively and a non-nil error
// is returned.
func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) {
	// refuse answers the request negatively and hands the error to the caller.
	refuse := func(err error) (net.Listener, *bindInfo, error) {
		req.Reply(false, []byte{})
		return nil, nil, err
	}

	var fwd tcpIpForwardPayload
	if err := ssh.Unmarshal(req.Payload, &fwd); err != nil {
		log.Printf("[%s] Unable to unmarshal payload", client.Name)
		return refuse(fmt.Errorf("Unable to parse payload"))
	}

	if *verbose {
		log.Printf("[%s] Request: %s %v %v", client.Name, req.Type, req.WantReply, fwd)
		log.Printf("[%s] Request to listen on %s:%d", client.Name, fwd.Addr, fwd.Port)
	}

	// Only "localhost" and the empty address are accepted.
	addressOK := fwd.Addr == "localhost" || fwd.Addr == ""
	if !addressOK {
		log.Printf("[%s] Payload address is not \"localhost\" or empty: %s", client.Name, fwd.Addr)
		return refuse(fmt.Errorf("Address is not permitted"))
	}

	if !portPermitted(fwd.Port, client.AllowedRemotePorts) {
		log.Printf("[%s] Port is not permitted: %d", client.Name, fwd.Port)
		return refuse(fmt.Errorf("Port is not permitted"))
	}

	bind := fmt.Sprintf("[%s]:%d", fwd.Addr, fwd.Port)
	ln, err := net.Listen("tcp", bind)
	if err != nil {
		log.Printf("[%s] Listen failed for %s", client.Name, bind)
		return refuse(err)
	}

	// Tell client everything is OK
	req.Reply(true, ssh.Marshal(&tcpIpForwardPayloadReply{Port: fwd.Port}))

	return ln, &bindInfo{Bound: bind, Port: fwd.Port, Addr: fwd.Addr}, nil
}
// handleListener accepts connections on a listener opened for the client and
// hands each one to handleForwardTcpIp in its own goroutine.
//
// Accept errors that report themselves as timeout or temporary are logged and
// accepting continues. Any other error (for instance the listener having been
// closed) ends the loop and the function returns.
func handleListener(client *sshClient, bindinfo *bindInfo, listener net.Listener) {
	for {
		lconn, err := listener.Accept()
		if err != nil {
			// Use a checked assertion: an error that does not implement
			// net.Error must end the loop rather than panic.
			neterr, isNetErr := err.(net.Error)
			if isNetErr && neterr.Timeout() {
				log.Printf("[%s] Accept failed with timeout: %s", client.Name, err)
				continue
			}
			if isNetErr && neterr.Temporary() {
				log.Printf("[%s] Accept failed with temporary: %s", client.Name, err)
				continue
			}
			return
		}

		go handleForwardTcpIp(client, bindinfo, lconn)
	}
}
// handleForwardTcpIp serves one connection accepted on a client's listener:
// it opens a "forwarded-tcpip" channel to the client, describing the bound
// address and the connecting peer, and then relays data between that channel
// and the accepted connection. If the channel cannot be opened the accepted
// connection is closed.
func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) {
	peer := lconn.RemoteAddr().(*net.TCPAddr)
	origin := forwardedTCPPayload{
		Addr:       bindinfo.Addr,
		Port:       bindinfo.Port,
		OriginAddr: peer.IP.String(),
		OriginPort: uint32(peer.Port),
	}

	// Open channel with client
	channel, requests, err := client.SshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&origin))
	if err != nil {
		log.Printf("[%s] Unable to get channel: %s. Hanging up requesting party!", client.Name, err)
		lconn.Close()
		return
	}

	if *verbose {
		log.Printf("[%s] Channel opened for client", client.Name)
	}

	// Requests made on the channel itself are not of interest.
	go ssh.DiscardRequests(requests)

	serve(channel, lconn, client, *forwardedtimeout)
}
// handleTcpIPForwardCancel handles a "cancel-tcpip-forward" request: it closes
// and forgets the listener that was opened for the given address and port.
//
// The request is answered exactly once: positively when a matching listener
// was found and closed, negatively when the payload cannot be decoded or no
// such listener exists.
//
// It reads and modifies client.Listeners without locking; handleRequest holds
// client.ListenMutex around the call.
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
	if *verbose {
		log.Printf("[%s] \"cancel-tcpip-forward\" called by client", client.Name)
	}

	var payload tcpIpForwardCancelPayload
	if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
		log.Printf("[%s] Unable to unmarshal cancel payload", client.Name)
		req.Reply(false, []byte{})
		// Nothing sensible can be looked up without a payload.
		return
	}

	// The key must use the same "[addr]:port" form that handleTcpIpForward
	// uses for bindInfo.Bound, otherwise the listener is never found.
	bound := fmt.Sprintf("[%s]:%d", payload.Addr, payload.Port)

	listener, found := client.Listeners[bound]
	if !found {
		req.Reply(false, []byte{})
		return
	}

	listener.Close()
	delete(client.Listeners, bound)
	req.Reply(true, []byte{})
}
// handleRequest processes the client's global (out of band) requests until
// reqs is closed. Every request pushes the main socket deadline forward.
// "tcpip-forward" and "cancel-tcpip-forward" are handled with ListenMutex
// held; every other request type is answered negatively.
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
	for req := range reqs {
		client.Conn.SetDeadline(time.Now().Add(*maintimeout))

		if *debug {
			log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
		}

		switch req.Type {
		case "tcpip-forward":
			// RFC4254: 7.1 for forwarding
			client.ListenMutex.Lock()
			if client.Stopping {
				// If we are closing, do not set up a new listener
				client.ListenMutex.Unlock()
				req.Reply(false, []byte{})
				continue
			}

			listener, bindinfo, err := handleTcpIpForward(client, req)
			if err == nil {
				client.Listeners[bindinfo.Bound] = listener
			}
			client.ListenMutex.Unlock()

			// Start accepting only after the lock has been released.
			if err == nil {
				go handleListener(client, bindinfo, listener)
			}

		case "cancel-tcpip-forward":
			client.ListenMutex.Lock()
			handleTcpIPForwardCancel(client, req)
			client.ListenMutex.Unlock()

		default:
			// Discard everything else
			req.Reply(false, []byte{})
		}
	}
}
// serve relays data in both directions between an ssh channel and a network
// connection. It starts one copying goroutine per direction and returns
// immediately. Whenever data moves, the deadline of conn is pushed forward by
// timeout and the deadline of the client's main socket by the main timeout.
// As soon as either direction stops, both the channel and the connection are
// closed (exactly once).
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) {
	var teardown sync.Once
	shutdown := func() {
		cssh.Close()
		conn.Close()
		if *verbose {
			log.Printf("[%s] Channel closed.", client.Name)
		}
	}

	// refresh pushes both deadlines forward.
	refresh := func() {
		conn.SetDeadline(time.Now().Add(timeout))
		client.Conn.SetDeadline(time.Now().Add(*maintimeout))
	}

	// pump copies src into dst until the copy ends, reports the outcome and
	// then tears both ends down.
	pump := func(dst io.Writer, src io.Reader, onActivity func()) {
		copied, err := copyTimeout(dst, src, onActivity)
		if err != nil && *debug {
			log.Printf("[%s] copyTimeout failed with: %s", client.Name, err)
		}
		if *verbose {
			log.Printf("[%s] Connection closed, bytes written: %d", client.Name, copied)
		}
		teardown.Do(shutdown)
	}

	// Connection -> ssh channel.
	go pump(cssh, conn, func() {
		if *debug {
			log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (sending data)", client.Name)
		}
		refresh()
	})

	// ssh channel -> connection.
	go pump(conn, cssh, func() {
		if *debug {
			log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (received data)", client.Name)
		}
		refresh()
	})
}
// copyTimeout copies from src to dst until src reports an error or EOF, or a
// write fails. It is adapted from copyBuffer in pkg/io/io.go; the difference
// is that timeout is invoked right before and right after every write of a
// non-empty chunk.
//
// It returns the number of bytes written and the first error encountered.
// Reaching EOF on src is not an error. A write that accepts fewer bytes than
// it was given, without reporting an error, yields io.ErrShortWrite.
func copyTimeout(dst io.Writer, src io.Reader, timeout TimeoutFunc) (written int64, err error) {
	chunk := make([]byte, 32*1024)
	for {
		got, readErr := src.Read(chunk)
		if got > 0 {
			timeout()
			put, writeErr := dst.Write(chunk[:got])
			if put > 0 {
				written += int64(put)
			}
			if writeErr != nil {
				return written, writeErr
			}
			if put != got {
				return written, io.ErrShortWrite
			}
			timeout()
		}

		// Data returned together with an error has been written above; now
		// the read error decides how the copy ends.
		if readErr == io.EOF {
			return written, nil
		}
		if readErr != nil {
			return written, readErr
		}
	}
}
// loadHostKeys reads the private key file named by the -hostkey flag, parses
// it and registers it as a host key on config. The program exits if the file
// cannot be read or parsed; the underlying error is included in the message
// so the cause (missing file, permissions, bad format) is visible.
func loadHostKeys(config *ssh.ServerConfig) {
	privateBytes, err := ioutil.ReadFile(*hostkey)
	if err != nil {
		log.Fatalf("Failed to load private key (%s): %s", *hostkey, err)
	}

	private, err := ssh.ParsePrivateKey(privateBytes)
	if err != nil {
		log.Fatalf("Failed to parse private key (%s): %s", *hostkey, err)
	}

	config.AddHostKey(private)
}
// loadAuthorisedKeys reads the authorized_keys file at the given path and
// replaces the global authorisedKeys map with its contents. For every key the
// comment and the "localports=" / "remoteports=" options are recorded; other
// options are reported in verbose mode and otherwise ignored.
//
// The program exits if the file cannot be read (the underlying error is part
// of the message). The new map is built completely before it is swapped in
// under authmutex.
func loadAuthorisedKeys(authorisedkeys string) {
	authKeys := map[string]deviceInfo{}

	authKeysBytes, err := ioutil.ReadFile(authorisedkeys)
	if err != nil {
		log.Fatalf("Cannot load authorised keys (%s): %s", authorisedkeys, err)
	}

	for len(authKeysBytes) > 0 {
		pubkey, comment, options, rest, err := ssh.ParseAuthorizedKey(authKeysBytes)
		if err != nil {
			log.Printf("Error parsing line: %s", err)
			// NOTE(review): presumably rest is empty when no key could be
			// found, which ends the loop — confirm against the ssh package.
			authKeysBytes = rest
			continue
		}

		devinfo := deviceInfo{Comment: comment}

		// TODO: Compatibility with permitopen=foo,permitopen=bar,
		// permitremoteopen=quux,permitremoteopen=wobble
		for _, option := range options {
			if ports, err := parseOption(option, "localports"); err == nil {
				devinfo.LocalPorts = ports
				continue
			}
			if ports, err := parseOption(option, "remoteports"); err == nil {
				devinfo.RemotePorts = ports
				continue
			}
			if *verbose {
				log.Println("Unknown option:", option)
			}
		}

		authKeys[string(pubkey.Marshal())] = devinfo
		authKeysBytes = rest
	}

	authmutex.Lock()
	defer authmutex.Unlock()
	authorisedKeys = authKeys
}
// portPermitted reports whether port occurs in ports. An empty or nil list
// permits nothing.
func portPermitted(port uint32, ports []uint32) bool {
	for i := range ports {
		if ports[i] == port {
			return true
		}
	}
	return false
}
// parseOption extracts the port list from an authorized_keys option of the
// form "<prefix>=<ports>". It returns an error when option does not start
// with "<prefix>=". When the prefix matches but the port list does not parse,
// the program exits.
func parseOption(option string, prefix string) (string, error) {
	want := prefix + "="
	if !strings.HasPrefix(option, want) {
		return "", fmt.Errorf("Option does not start with %s", want)
	}

	ports := strings.TrimPrefix(option, want)
	_, err := parsePorts(ports)
	if err != nil {
		log.Fatal(err)
	}
	return ports, nil
}
// parsePorts converts a colon-separated list of decimal port numbers, such as
// "22:80:443", into a slice. Parsing stops at the first field that is not an
// unsigned 32-bit number; the ports parsed up to that point are returned
// together with the error. An empty string is therefore an error with no
// ports.
func parsePorts(portstr string) (p []uint32, err error) {
	for _, field := range strings.Split(portstr, ":") {
		value, convErr := strconv.ParseUint(field, 10, 32)
		if convErr != nil {
			return p, convErr
		}
		p = append(p, uint32(value))
	}
	return p, nil
}