|
|
|
@ -11,17 +11,16 @@ import ( |
|
|
|
|
"io/ioutil" |
|
|
|
|
"log" |
|
|
|
|
"net" |
|
|
|
|
"os" |
|
|
|
|
"os/signal" |
|
|
|
|
"strconv" |
|
|
|
|
"strings" |
|
|
|
|
"sync" |
|
|
|
|
"syscall" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
"golang.org/x/crypto/ssh" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
var ( |
|
|
|
|
// Contains a mapping of authorised keys to permissions for said key
|
|
|
|
|
authorisedKeys map[string]deviceInfo |
|
|
|
|
|
|
|
|
|
listenaddr = flag.String("listenaddr", "0.0.0.0", "Addr to listen on for incoming ssh connections") |
|
|
|
@ -29,26 +28,50 @@ var ( |
|
|
|
|
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") |
|
|
|
|
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") |
|
|
|
|
verbose = flag.Bool("verbose", false, "Enable verbose mode") |
|
|
|
|
debug = flag.Bool("debug", false, "Enable debug mode") |
|
|
|
|
|
|
|
|
|
// Currently the timeouts are not separate for read and write deadlines.
|
|
|
|
|
// This could be done, but I currently don't really see a reason for this.
|
|
|
|
|
maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout") |
|
|
|
|
directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout") |
|
|
|
|
forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout") |
|
|
|
|
|
|
|
|
|
// Mutex protecting 'authorisedKeys' map
|
|
|
|
|
authmutex sync.Mutex |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// Structure that holds all information for each connection/client
|
|
|
|
|
type sshClient struct { |
|
|
|
|
Name string |
|
|
|
|
Conn *ssh.ServerConn |
|
|
|
|
Listeners map[string]net.Listener |
|
|
|
|
Name string |
|
|
|
|
|
|
|
|
|
// We keep track of the normal Conn as well so that we have access to the
|
|
|
|
|
// SetDeadline() methods
|
|
|
|
|
Conn net.Conn |
|
|
|
|
|
|
|
|
|
SshConn *ssh.ServerConn |
|
|
|
|
|
|
|
|
|
// Listener sockets opened by the client
|
|
|
|
|
Listeners map[string]net.Listener |
|
|
|
|
|
|
|
|
|
AllowedLocalPorts []uint32 |
|
|
|
|
AllowedRemotePorts []uint32 |
|
|
|
|
Stopping bool |
|
|
|
|
ListenMutex sync.Mutex |
|
|
|
|
|
|
|
|
|
// This indicates that a client is shutting down. When a client is stopping,
|
|
|
|
|
// we do not allow new listening requests, to prevent a listener connection
|
|
|
|
|
// being opened just after we closed all of them.
|
|
|
|
|
Stopping bool |
|
|
|
|
ListenMutex sync.Mutex |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Structure containing what address/port we should bind on, for forwarded-tcpip
|
|
|
|
|
// connections
|
|
|
|
|
type bindInfo struct { |
|
|
|
|
Bound string |
|
|
|
|
Port uint32 |
|
|
|
|
Addr string |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Information parsed from the authorized_keys file
|
|
|
|
|
type deviceInfo struct { |
|
|
|
|
LocalPorts string |
|
|
|
|
RemotePorts string |
|
|
|
@ -84,6 +107,10 @@ type tcpIpForwardCancelPayload struct { |
|
|
|
|
Port uint32 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Function that can be used to implement calls to SetDeadline() after
|
|
|
|
|
// read/writes in copyTimeout()
|
|
|
|
|
type TimeoutFunc func() |
|
|
|
|
|
|
|
|
|
func main() { |
|
|
|
|
flag.Parse() |
|
|
|
|
|
|
|
|
@ -122,6 +149,10 @@ func main() { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
tcpConn.SetDeadline(time.Now().Add(*maintimeout)) |
|
|
|
|
|
|
|
|
|
// We perform the ssh handshake in a goroutine so the handshake cannot
|
|
|
|
|
// block incoming connections.
|
|
|
|
|
go func() { |
|
|
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config) |
|
|
|
|
if err != nil { |
|
|
|
@ -129,7 +160,7 @@ func main() { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
client := sshClient{sshConn.Permissions.CriticalOptions["name"], sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}} |
|
|
|
|
client := sshClient{sshConn.Permissions.CriticalOptions["name"], tcpConn, sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}} |
|
|
|
|
allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"] |
|
|
|
|
allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"] |
|
|
|
|
|
|
|
|
@ -143,7 +174,7 @@ func main() { |
|
|
|
|
client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts) |
|
|
|
|
|
|
|
|
|
go func() { |
|
|
|
|
err := client.Conn.Wait() |
|
|
|
|
err := client.SshConn.Wait() |
|
|
|
|
client.ListenMutex.Lock() |
|
|
|
|
client.Stopping = true |
|
|
|
|
|
|
|
|
@ -159,8 +190,8 @@ func main() { |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
}() |
|
|
|
|
|
|
|
|
|
// Accept requests & channels
|
|
|
|
|
go handleRequest(&client, reqs) |
|
|
|
|
// Accept all channels
|
|
|
|
|
go handleChannels(&client, chans) |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
@ -173,7 +204,7 @@ func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleChannel(client *sshClient, newChannel ssh.NewChannel) { |
|
|
|
|
if *verbose { |
|
|
|
|
if *debug { |
|
|
|
|
log.Printf("[%s] Channel type: %v", client.Name, newChannel.ChannelType()) |
|
|
|
|
} |
|
|
|
|
if t := newChannel.ChannelType(); t == "direct-tcpip" { |
|
|
|
@ -181,9 +212,11 @@ func handleChannel(client *sshClient, newChannel ssh.NewChannel) { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\", \"forwarded-tcpip\" and \"cancel-tcpip-forward\" are accepted")) |
|
|
|
|
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted")) |
|
|
|
|
/* |
|
|
|
|
// TODO: USE THIS ONLY FOR USING SSH ESCAPE SEQUENCES
|
|
|
|
|
// XXX: Use this only for testing purposes -- I add this in if/when I
|
|
|
|
|
// want to use the ssh escape sequences from ssh (those only work in an
|
|
|
|
|
// interactive session)
|
|
|
|
|
c, _, err := newChannel.Accept() |
|
|
|
|
if err != nil { |
|
|
|
|
log.Fatal(err) |
|
|
|
@ -240,7 +273,7 @@ func handleDirect(client *sshClient, newChannel ssh.NewChannel) { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
serve(connection, rconn, client) |
|
|
|
|
serve(connection, rconn, client, *directtimeout) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) { |
|
|
|
@ -318,7 +351,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { |
|
|
|
|
mpayload := ssh.Marshal(&payload) |
|
|
|
|
|
|
|
|
|
// Open channel with client
|
|
|
|
|
c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload) |
|
|
|
|
c, requests, err := client.SshConn.OpenChannel("forwarded-tcpip", mpayload) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("[%s] Unable to get channel: %s. Hanging up requesting party!", client.Name, err) |
|
|
|
|
lconn.Close() |
|
|
|
@ -329,7 +362,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { |
|
|
|
|
} |
|
|
|
|
go ssh.DiscardRequests(requests) |
|
|
|
|
|
|
|
|
|
serve(c, lconn, client) |
|
|
|
|
serve(c, lconn, client, *forwardedtimeout) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { |
|
|
|
@ -353,7 +386,48 @@ func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) { |
|
|
|
|
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { |
|
|
|
|
for req := range reqs { |
|
|
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout)) |
|
|
|
|
|
|
|
|
|
if *debug { |
|
|
|
|
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// RFC4254: 7.1 for forwarding
|
|
|
|
|
if req.Type == "tcpip-forward" { |
|
|
|
|
client.ListenMutex.Lock() |
|
|
|
|
/* If we are closing, do not set up a new listener */ |
|
|
|
|
if client.Stopping { |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
listener, bindinfo, err := handleTcpIpForward(client, req) |
|
|
|
|
if err != nil { |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
client.Listeners[bindinfo.Bound] = listener |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
|
|
|
|
|
go handleListener(client, bindinfo, listener) |
|
|
|
|
continue |
|
|
|
|
} else if req.Type == "cancel-tcpip-forward" { |
|
|
|
|
client.ListenMutex.Lock() |
|
|
|
|
handleTcpIPForwardCancel(client, req) |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
continue |
|
|
|
|
} else { |
|
|
|
|
// Discard everything else
|
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) { |
|
|
|
|
// TODO: Maybe just do this with defer instead? (And only one copy in a
|
|
|
|
|
// goroutine)
|
|
|
|
|
close := func() { |
|
|
|
@ -364,17 +438,81 @@ func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// TODO: Share timeout between both the read-conn and the write-conn
|
|
|
|
|
var once sync.Once |
|
|
|
|
go func() { |
|
|
|
|
io.Copy(cssh, conn) |
|
|
|
|
//io.Copy(cssh, conn)
|
|
|
|
|
bytes_written, err := copyTimeout(cssh, conn, func() { |
|
|
|
|
if *debug { |
|
|
|
|
log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (sending data)", client.Name) |
|
|
|
|
} |
|
|
|
|
conn.SetDeadline(time.Now().Add(timeout)) |
|
|
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout)) |
|
|
|
|
}) |
|
|
|
|
if err != nil { |
|
|
|
|
if *debug { |
|
|
|
|
log.Printf("[%s] copyTimeout failed with: %s", client.Name, err) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("[%s] Connection closed, bytes written: %d", client.Name, bytes_written) |
|
|
|
|
} |
|
|
|
|
once.Do(close) |
|
|
|
|
}() |
|
|
|
|
go func() { |
|
|
|
|
io.Copy(conn, cssh) |
|
|
|
|
//io.Copy(conn, cssh)
|
|
|
|
|
bytes_written, err := copyTimeout(conn, cssh, func() { |
|
|
|
|
if *debug { |
|
|
|
|
log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (received data)", client.Name) |
|
|
|
|
} |
|
|
|
|
conn.SetDeadline(time.Now().Add(timeout)) |
|
|
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout)) |
|
|
|
|
}) |
|
|
|
|
if err != nil { |
|
|
|
|
if *debug { |
|
|
|
|
log.Printf("[%s] copyTimeout failed with: %s", client.Name, err) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("[%s] Connection closed, bytes written: %d", client.Name, bytes_written) |
|
|
|
|
} |
|
|
|
|
once.Do(close) |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Changed from pkg/io/io.go copyBuffer
|
|
|
|
|
func copyTimeout(dst io.Writer, src io.Reader, timeout TimeoutFunc) (written int64, err error) { |
|
|
|
|
buf := make([]byte, 32*1024) |
|
|
|
|
|
|
|
|
|
for { |
|
|
|
|
nr, er := src.Read(buf) |
|
|
|
|
if nr > 0 { |
|
|
|
|
timeout() |
|
|
|
|
|
|
|
|
|
nw, ew := dst.Write(buf[0:nr]) |
|
|
|
|
if nw > 0 { |
|
|
|
|
written += int64(nw) |
|
|
|
|
} |
|
|
|
|
if ew != nil { |
|
|
|
|
err = ew |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
if nr != nw { |
|
|
|
|
err = io.ErrShortWrite |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
timeout() |
|
|
|
|
} |
|
|
|
|
if er != nil { |
|
|
|
|
if er != io.EOF { |
|
|
|
|
err = er |
|
|
|
|
} |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return written, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func loadHostKeys(config *ssh.ServerConfig) { |
|
|
|
|
privateBytes, err := ioutil.ReadFile(*hostkey) |
|
|
|
|
if err != nil { |
|
|
|
@ -435,62 +573,6 @@ func loadAuthorisedKeys(authorisedkeys string) { |
|
|
|
|
authorisedKeys = authKeys |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func registerReloadSignal() { |
|
|
|
|
c := make(chan os.Signal) |
|
|
|
|
signal.Notify(c, syscall.SIGUSR1) |
|
|
|
|
|
|
|
|
|
go func() { |
|
|
|
|
for sig := range c { |
|
|
|
|
if sig == syscall.SIGUSR1 { |
|
|
|
|
log.Printf("Received signal: SIGUSR1. Reloading authorised keys.") |
|
|
|
|
loadAuthorisedKeys(*authorisedkeys) |
|
|
|
|
} else { |
|
|
|
|
log.Printf("Received unexpected signal: \"%s\".", sig.String()) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { |
|
|
|
|
for req := range reqs { |
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// RFC4254: 7.1 for forwarding
|
|
|
|
|
if req.Type == "tcpip-forward" { |
|
|
|
|
client.ListenMutex.Lock() |
|
|
|
|
/* If we are closing, do not set up a new listener */ |
|
|
|
|
if client.Stopping { |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
listener, bindinfo, err := handleTcpIpForward(client, req) |
|
|
|
|
if err != nil { |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
client.Listeners[bindinfo.Bound] = listener |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
|
|
|
|
|
go handleListener(client, bindinfo, listener) |
|
|
|
|
continue |
|
|
|
|
} else if req.Type == "cancel-tcpip-forward" { |
|
|
|
|
client.ListenMutex.Lock() |
|
|
|
|
handleTcpIPForwardCancel(client, req) |
|
|
|
|
client.ListenMutex.Unlock() |
|
|
|
|
continue |
|
|
|
|
} else { |
|
|
|
|
// Discard everything else
|
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func portPermitted(port uint32, ports []uint32) bool { |
|
|
|
|
ok := false |
|
|
|
|
for _, p := range ports { |
|
|
|
|