You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
614 lines
16 KiB
614 lines
16 KiB
package main
|
|
|
|
// Copyright:
|
|
// Merlijn B. W. Wajer <merlijn@wizzup.org>
|
|
// (C) 2017
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
var (
|
|
// Contains a mapping of authorised keys to permissions for said key
|
|
authorisedKeys map[string]deviceInfo
|
|
|
|
listenaddr = flag.String("listenaddr", "0.0.0.0", "Addr to listen on for incoming ssh connections")
|
|
listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections")
|
|
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load")
|
|
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")
|
|
verbose = flag.Bool("verbose", false, "Enable verbose mode")
|
|
debug = flag.Bool("debug", false, "Enable debug mode")
|
|
|
|
// Currently the timeouts are not separate for read and write deadlines.
|
|
// This could be done, but I currently don't really see a reason for this.
|
|
maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout")
|
|
directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout")
|
|
forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout")
|
|
|
|
// Mutex protecting 'authorisedKeys' map
|
|
authmutex sync.Mutex
|
|
)
|
|
|
|
// Structure that holds all information for each connection/client
|
|
type sshClient struct {
|
|
Name string
|
|
|
|
// We keep track of the normal Conn as well so that we have access to the
|
|
// SetDeadline() methods
|
|
Conn net.Conn
|
|
|
|
SshConn *ssh.ServerConn
|
|
|
|
// Listener sockets opened by the client
|
|
Listeners map[string]net.Listener
|
|
|
|
AllowedLocalPorts []uint32
|
|
AllowedRemotePorts []uint32
|
|
|
|
// This indicates that a client is shutting down. When a client is stopping,
|
|
// we do not allow new listening requests, to prevent a listener connection
|
|
// being opened just after we closed all of them.
|
|
Stopping bool
|
|
ListenMutex sync.Mutex
|
|
}
|
|
|
|
// bindInfo describes a listener opened for a "tcpip-forward" request and is
// used to fill in the "forwarded-tcpip" channel payload for connections that
// arrive on it.
type bindInfo struct {
	Bound string // "[addr]:port" string passed to net.Listen; also the Listeners map key
	Port  uint32 // port the client asked for
	Addr  string // address the client asked for
}
|
|
|
|
// deviceInfo holds what was read for one key in the authorised keys file.
// The port fields keep the raw colon-separated option values (for example
// "22:80"); they are turned into numbers later by parsePorts.
type deviceInfo struct {
	LocalPorts  string // value of the "localports=" option
	RemotePorts string // value of the "remoteports=" option
	Comment     string // key comment, used as the client name
}
|
|
|
|
// directTCPPayload is the extra data of a "direct-tcpip" channel open
// request (RFC4254 7.2). It is decoded with ssh.Unmarshal, so the field
// order is the wire order and must not change.
type directTCPPayload struct {
	Addr       string // host to connect to
	Port       uint32 // port to connect to
	OriginAddr string
	OriginPort uint32
}
|
|
|
|
// forwardedTCPPayload is the extra data sent when opening a "forwarded-tcpip"
// channel back to the client (RFC4254 7.2). It is encoded with ssh.Marshal,
// so the field order is the wire order and must not change.
type forwardedTCPPayload struct {
	Addr       string // address that was connected to
	Port       uint32 // port that was connected to
	OriginAddr string // peer that connected
	OriginPort uint32
}
|
|
|
|
// tcpIpForwardPayload is the body of a "tcpip-forward" global request
// (RFC4254 7.1), decoded with ssh.Unmarshal; field order is wire order.
type tcpIpForwardPayload struct {
	Addr string // address to bind
	Port uint32 // port to bind
}
|
|
|
|
// tcpIpForwardPayloadReply is the success reply data for a "tcpip-forward"
// request, encoded with ssh.Marshal.
type tcpIpForwardPayloadReply struct {
	Port uint32 // the port that was bound
}
|
|
|
|
// tcpIpForwardCancelPayload is the body of a "cancel-tcpip-forward" global
// request (RFC4254 7.1), decoded with ssh.Unmarshal; field order is wire order.
type tcpIpForwardCancelPayload struct {
	Addr string // address of the forward to cancel
	Port uint32 // port of the forward to cancel
}
|
|
|
|
// TimeoutFunc is the hook copyTimeout calls around every write; callers use
// it to push socket deadlines forward (SetDeadline) while data keeps flowing.
type TimeoutFunc func()
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
config := &ssh.ServerConfig{
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
authmutex.Lock()
|
|
defer authmutex.Unlock()
|
|
if clientinfo, found := authorisedKeys[string(key.Marshal())]; found {
|
|
return &ssh.Permissions{
|
|
CriticalOptions: map[string]string{"name": clientinfo.Comment,
|
|
"localports": clientinfo.LocalPorts,
|
|
"remoteports": clientinfo.RemotePorts},
|
|
}, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("Unknown public key\n")
|
|
},
|
|
}
|
|
|
|
loadHostKeys(config)
|
|
loadAuthorisedKeys(*authorisedkeys)
|
|
|
|
registerReloadSignal()
|
|
|
|
bind := fmt.Sprintf("[%s]:%d", *listenaddr, *listenport)
|
|
listener, err := net.Listen("tcp", bind)
|
|
if err != nil {
|
|
log.Fatalf("Failed to listen on %s (%s)", listenport, err)
|
|
}
|
|
|
|
// Accept all connections
|
|
log.Printf("Listening on %d...", *listenport)
|
|
for {
|
|
tcpConn, err := listener.Accept()
|
|
if err != nil {
|
|
log.Printf("Failed to accept incoming connection (%s)", err)
|
|
continue
|
|
}
|
|
|
|
tcpConn.SetDeadline(time.Now().Add(*maintimeout))
|
|
|
|
// We perform the ssh handshake in a goroutine so the handshake cannot
|
|
// block incoming connections.
|
|
go func() {
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config)
|
|
if err != nil {
|
|
log.Printf("Failed to handshake: %s (rip: %v)", err, tcpConn.RemoteAddr())
|
|
return
|
|
}
|
|
|
|
client := sshClient{sshConn.Permissions.CriticalOptions["name"], tcpConn, sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}}
|
|
allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"]
|
|
allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"]
|
|
|
|
if *verbose {
|
|
log.Printf("[%s] Connection from %s (%s). Allowed local ports: %s remote ports: %s", client.Name, sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedLocalPorts, allowedRemotePorts)
|
|
}
|
|
|
|
// Parsing a second time should not error, so we can ignore the error
|
|
// safely
|
|
client.AllowedLocalPorts, _ = parsePorts(allowedLocalPorts)
|
|
client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts)
|
|
|
|
// Start the clean-up function: will wait for the socket to be
|
|
// closed (either by remote, protocol or deadline/timeout)
|
|
// and close any listeners if any
|
|
go func() {
|
|
err := client.SshConn.Wait()
|
|
client.ListenMutex.Lock()
|
|
defer client.ListenMutex.Unlock()
|
|
client.Stopping = true
|
|
|
|
if *verbose {
|
|
log.Printf("[%s] SSH connection closed: %s", client.Name, err)
|
|
}
|
|
|
|
for bind, listener := range client.Listeners {
|
|
if *verbose {
|
|
log.Printf("[%s] Closing listener bound to %s", client.Name, bind)
|
|
}
|
|
listener.Close()
|
|
}
|
|
}()
|
|
|
|
// Accept requests & channels
|
|
go handleRequest(&client, reqs)
|
|
go handleChannels(&client, chans)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) {
|
|
for c := range chans {
|
|
go handleChannel(client, c)
|
|
}
|
|
}
|
|
|
|
func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
|
|
if *debug {
|
|
log.Printf("[%s] Channel type: %v", client.Name, newChannel.ChannelType())
|
|
}
|
|
if t := newChannel.ChannelType(); t == "direct-tcpip" {
|
|
handleDirect(client, newChannel)
|
|
return
|
|
}
|
|
|
|
newChannel.Reject(ssh.Prohibited, "Only \"direct-tcpip\" is accepted")
|
|
/*
|
|
// XXX: Use this only for testing purposes -- I add this in if/when I
|
|
// want to use the ssh escape sequences from ssh (those only work in an
|
|
// interactive session)
|
|
c, _, err := newChannel.Accept()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
go func() {
|
|
d := make([]byte, 4096)
|
|
c.Read(d)
|
|
}()
|
|
*/
|
|
return
|
|
|
|
}
|
|
|
|
func handleDirect(client *sshClient, newChannel ssh.NewChannel) {
|
|
var payload directTCPPayload
|
|
if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil {
|
|
log.Printf("[%s] Could not unmarshal extra data: %s", client.Name, err)
|
|
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad payload"))
|
|
return
|
|
}
|
|
|
|
/*
|
|
// XXX: Is this sensible?
|
|
if payload.Addr != "localhost" && payload.Addr != "::1" && payload.Addr != "127.0.0.1" {
|
|
log.Printf("[%s] Tried to connect to prohibited host: %s", client.Name, payload.Addr)
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad addr"))
|
|
return
|
|
}
|
|
*/
|
|
|
|
if !portPermitted(payload.Port, client.AllowedLocalPorts) {
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad port"))
|
|
log.Printf("[%s] Tried to connect to prohibited port: %d", client.Name, payload.Port)
|
|
return
|
|
}
|
|
|
|
connection, requests, err := newChannel.Accept()
|
|
if err != nil {
|
|
log.Printf("[%s] Could not accept channel (%s)", client.Name, err)
|
|
return
|
|
}
|
|
go ssh.DiscardRequests(requests)
|
|
|
|
addr := fmt.Sprintf("[%s]:%d", payload.Addr, payload.Port)
|
|
if *verbose {
|
|
log.Printf("[%s] Dialing: %s", client.Name, addr)
|
|
}
|
|
|
|
rconn, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
log.Printf("[%s] Could not dial remote (%s)", client.Name, err)
|
|
connection.Close()
|
|
return
|
|
}
|
|
|
|
serve(connection, rconn, client, *directtimeout)
|
|
}
|
|
|
|
func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) {
|
|
var payload tcpIpForwardPayload
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
|
|
log.Printf("[%s] Unable to unmarshal payload", client.Name)
|
|
req.Reply(false, []byte{})
|
|
return nil, nil, fmt.Errorf("Unable to parse payload")
|
|
}
|
|
|
|
if *verbose {
|
|
log.Printf("[%s] Request: %s %v %v", client.Name, req.Type, req.WantReply, payload)
|
|
log.Printf("[%s] Request to listen on %s:%d", client.Name, payload.Addr, payload.Port)
|
|
}
|
|
|
|
if payload.Addr != "localhost" && payload.Addr != "" {
|
|
log.Printf("[%s] Payload address is not \"localhost\" or empty: %s", client.Name, payload.Addr)
|
|
req.Reply(false, []byte{})
|
|
return nil, nil, fmt.Errorf("Address is not permitted")
|
|
}
|
|
|
|
if !portPermitted(payload.Port, client.AllowedRemotePorts) {
|
|
log.Printf("[%s] Port is not permitted: %d", client.Name, payload.Port)
|
|
req.Reply(false, []byte{})
|
|
return nil, nil, fmt.Errorf("Port is not permitted")
|
|
}
|
|
|
|
laddr := payload.Addr
|
|
lport := payload.Port
|
|
|
|
bind := fmt.Sprintf("[%s]:%d", laddr, lport)
|
|
ln, err := net.Listen("tcp", bind)
|
|
if err != nil {
|
|
log.Printf("[%s] Listen failed for %s", client.Name, bind)
|
|
req.Reply(false, []byte{})
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Tell client everything is OK
|
|
reply := tcpIpForwardPayloadReply{lport}
|
|
req.Reply(true, ssh.Marshal(&reply))
|
|
|
|
return ln, &bindInfo{bind, lport, laddr}, nil
|
|
|
|
}
|
|
|
|
func handleListener(client *sshClient, bindinfo *bindInfo, listener net.Listener) {
|
|
// Start listening for connections
|
|
for {
|
|
lconn, err := listener.Accept()
|
|
if err != nil {
|
|
neterr := err.(net.Error)
|
|
if neterr.Timeout() {
|
|
log.Printf("[%s] Accept failed with timeout: %s", client.Name, err)
|
|
continue
|
|
}
|
|
if neterr.Temporary() {
|
|
log.Printf("[%s] Accept failed with temporary: %s", client.Name, err)
|
|
continue
|
|
}
|
|
|
|
break
|
|
}
|
|
|
|
go handleForwardTcpIp(client, bindinfo, lconn)
|
|
}
|
|
}
|
|
|
|
func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) {
|
|
remotetcpaddr := lconn.RemoteAddr().(*net.TCPAddr)
|
|
raddr := remotetcpaddr.IP.String()
|
|
rport := uint32(remotetcpaddr.Port)
|
|
|
|
payload := forwardedTCPPayload{bindinfo.Addr, bindinfo.Port, raddr, uint32(rport)}
|
|
mpayload := ssh.Marshal(&payload)
|
|
|
|
// Open channel with client
|
|
c, requests, err := client.SshConn.OpenChannel("forwarded-tcpip", mpayload)
|
|
if err != nil {
|
|
log.Printf("[%s] Unable to get channel: %s. Hanging up requesting party!", client.Name, err)
|
|
lconn.Close()
|
|
return
|
|
}
|
|
if *verbose {
|
|
log.Printf("[%s] Channel opened for client", client.Name)
|
|
}
|
|
go ssh.DiscardRequests(requests)
|
|
|
|
serve(c, lconn, client, *forwardedtimeout)
|
|
}
|
|
|
|
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
|
|
if *verbose {
|
|
log.Printf("[%s] \"cancel-tcpip-forward\" called by client", client.Name)
|
|
}
|
|
var payload tcpIpForwardCancelPayload
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
|
|
log.Printf("[%s] Unable to unmarshal cancel payload", client.Name)
|
|
req.Reply(false, []byte{})
|
|
}
|
|
|
|
bound := fmt.Sprintf("%s:%d", payload.Addr, payload.Port)
|
|
|
|
if listener, found := client.Listeners[bound]; found {
|
|
listener.Close()
|
|
delete(client.Listeners, bound)
|
|
req.Reply(true, []byte{})
|
|
}
|
|
|
|
req.Reply(false, []byte{})
|
|
}
|
|
|
|
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
|
|
for req := range reqs {
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
|
|
|
|
if *debug {
|
|
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
|
|
}
|
|
|
|
// RFC4254: 7.1 for forwarding
|
|
if req.Type == "tcpip-forward" {
|
|
client.ListenMutex.Lock()
|
|
/* If we are closing, do not set up a new listener */
|
|
if client.Stopping {
|
|
client.ListenMutex.Unlock()
|
|
req.Reply(false, []byte{})
|
|
continue
|
|
}
|
|
|
|
listener, bindinfo, err := handleTcpIpForward(client, req)
|
|
if err != nil {
|
|
client.ListenMutex.Unlock()
|
|
continue
|
|
}
|
|
|
|
client.Listeners[bindinfo.Bound] = listener
|
|
client.ListenMutex.Unlock()
|
|
|
|
go handleListener(client, bindinfo, listener)
|
|
continue
|
|
} else if req.Type == "cancel-tcpip-forward" {
|
|
client.ListenMutex.Lock()
|
|
handleTcpIPForwardCancel(client, req)
|
|
client.ListenMutex.Unlock()
|
|
continue
|
|
} else {
|
|
// Discard everything else
|
|
req.Reply(false, []byte{})
|
|
}
|
|
}
|
|
}
|
|
|
|
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) {
|
|
close := func() {
|
|
cssh.Close()
|
|
conn.Close()
|
|
if *verbose {
|
|
log.Printf("[%s] Channel closed.", client.Name)
|
|
}
|
|
}
|
|
|
|
var once sync.Once
|
|
go func() {
|
|
//io.Copy(cssh, conn)
|
|
bytes_written, err := copyTimeout(cssh, conn, func() {
|
|
if *debug {
|
|
log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (sending data)", client.Name)
|
|
}
|
|
conn.SetDeadline(time.Now().Add(timeout))
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
|
|
})
|
|
if err != nil {
|
|
if *debug {
|
|
log.Printf("[%s] copyTimeout failed with: %s", client.Name, err)
|
|
}
|
|
}
|
|
if *verbose {
|
|
log.Printf("[%s] Connection closed, bytes written: %d", client.Name, bytes_written)
|
|
}
|
|
once.Do(close)
|
|
}()
|
|
go func() {
|
|
//io.Copy(conn, cssh)
|
|
bytes_written, err := copyTimeout(conn, cssh, func() {
|
|
if *debug {
|
|
log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (received data)", client.Name)
|
|
}
|
|
conn.SetDeadline(time.Now().Add(timeout))
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
|
|
})
|
|
if err != nil {
|
|
if *debug {
|
|
log.Printf("[%s] copyTimeout failed with: %s", client.Name, err)
|
|
}
|
|
}
|
|
if *verbose {
|
|
log.Printf("[%s] Connection closed, bytes written: %d", client.Name, bytes_written)
|
|
}
|
|
once.Do(close)
|
|
}()
|
|
}
|
|
|
|
// Changed from pkg/io/io.go copyBuffer
|
|
func copyTimeout(dst io.Writer, src io.Reader, timeout TimeoutFunc) (written int64, err error) {
|
|
buf := make([]byte, 32*1024)
|
|
|
|
for {
|
|
nr, er := src.Read(buf)
|
|
if nr > 0 {
|
|
timeout()
|
|
|
|
nw, ew := dst.Write(buf[0:nr])
|
|
if nw > 0 {
|
|
written += int64(nw)
|
|
}
|
|
if ew != nil {
|
|
err = ew
|
|
break
|
|
}
|
|
if nr != nw {
|
|
err = io.ErrShortWrite
|
|
break
|
|
}
|
|
timeout()
|
|
}
|
|
if er != nil {
|
|
if er != io.EOF {
|
|
err = er
|
|
}
|
|
break
|
|
}
|
|
}
|
|
return written, err
|
|
}
|
|
|
|
func loadHostKeys(config *ssh.ServerConfig) {
|
|
privateBytes, err := ioutil.ReadFile(*hostkey)
|
|
if err != nil {
|
|
log.Fatal(fmt.Sprintf("Failed to load private key (%s)", *hostkey))
|
|
}
|
|
|
|
private, err := ssh.ParsePrivateKey(privateBytes)
|
|
if err != nil {
|
|
log.Fatal("Failed to parse private key")
|
|
}
|
|
|
|
config.AddHostKey(private)
|
|
}
|
|
|
|
func loadAuthorisedKeys(authorisedkeys string) {
|
|
authKeys := map[string]deviceInfo{}
|
|
authKeysBytes, err := ioutil.ReadFile(authorisedkeys)
|
|
if err != nil {
|
|
log.Fatal("Cannot load authorised keys")
|
|
}
|
|
|
|
for len(authKeysBytes) > 0 {
|
|
pubkey, comment, options, rest, err := ssh.ParseAuthorizedKey(authKeysBytes)
|
|
|
|
if err != nil {
|
|
log.Printf("Error parsing line: %s", err)
|
|
authKeysBytes = rest
|
|
continue
|
|
}
|
|
|
|
devinfo := deviceInfo{Comment: comment}
|
|
|
|
// TODO: Compatibility with permitopen=foo,permitopen=bar,
|
|
// permitremoteopen=quux,permitremoteopen=wobble
|
|
for _, option := range options {
|
|
ports, err := parseOption(option, "localports")
|
|
if err == nil {
|
|
devinfo.LocalPorts = ports
|
|
continue
|
|
}
|
|
ports, err = parseOption(option, "remoteports")
|
|
if err == nil {
|
|
devinfo.RemotePorts = ports
|
|
continue
|
|
}
|
|
if *verbose {
|
|
log.Println("Unknown option:", option)
|
|
}
|
|
}
|
|
|
|
authKeys[string(pubkey.Marshal())] = devinfo
|
|
|
|
authKeysBytes = rest
|
|
}
|
|
|
|
authmutex.Lock()
|
|
defer authmutex.Unlock()
|
|
authorisedKeys = authKeys
|
|
}
|
|
|
|
// portPermitted reports whether port appears in the allow-list ports.
// A nil or empty allow-list permits nothing.
func portPermitted(port uint32, ports []uint32) bool {
	for _, allowed := range ports {
		if allowed == port {
			return true
		}
	}
	return false
}
|
|
|
|
func parseOption(option string, prefix string) (string, error) {
|
|
str := fmt.Sprintf("%s=", prefix)
|
|
if !strings.HasPrefix(option, str) {
|
|
return "", fmt.Errorf("Option does not start with %s", str)
|
|
}
|
|
ports := option[len(str):]
|
|
|
|
if _, err := parsePorts(ports); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return ports, nil
|
|
}
|
|
|
|
// parsePorts parses a colon-separated list of decimal port numbers, such as
// "22:80:443". Each entry must fit in 32 bits (the width of the SSH payload
// port fields).
//
// On any malformed entry, including an empty string, it returns a nil slice
// and the strconv error. Callers that ignore the error therefore end up with
// an empty allow-list rather than a partial one.
func parsePorts(portstr string) (p []uint32, err error) {
	fields := strings.Split(portstr, ":")
	ports := make([]uint32, 0, len(fields))
	for _, field := range fields {
		value, perr := strconv.ParseUint(field, 10, 32)
		if perr != nil {
			return nil, perr
		}
		ports = append(ports, uint32(value))
	}
	return ports, nil
}
|
|
|