Compare commits

...

3 Commits

Author SHA1 Message Date
Merlijn Wajer 9f8a9042fc Clean up code and more comments 7 years ago
Merlijn Wajer 1f1df93791 Add "-debug" flag, and make use of it. 7 years ago
Merlijn Wajer 1216e33e23 Add basic support for timeouts 7 years ago
  1. 26
      signal_unix.go
  2. 5
      signal_windows.go
  3. 234
      sshd.go

@ -0,0 +1,26 @@
package main
import (
"log"
"os"
"os/signal"
"syscall"
)
// +build !windows
func registerReloadSignal() {
c := make(chan os.Signal)
signal.Notify(c, syscall.SIGUSR1)
go func() {
for sig := range c {
if sig == syscall.SIGUSR1 {
log.Printf("Received signal: SIGUSR1. Reloading authorised keys.")
loadAuthorisedKeys(*authorisedkeys)
} else {
log.Printf("Received unexpected signal: \"%s\".", sig.String())
}
}
}()
}

@ -0,0 +1,5 @@
package main
// +build windows
func registerReloadSignal() {
}

@ -11,17 +11,16 @@ import (
"io/ioutil"
"log"
"net"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/crypto/ssh"
)
var (
// Contains a mapping of authorised keys to permissions for said key
authorisedKeys map[string]deviceInfo
listenaddr = flag.String("listenaddr", "0.0.0.0", "Addr to listen on for incoming ssh connections")
@ -29,26 +28,50 @@ var (
hostkey = flag.String("hostkey", "id_rsa", "Server host key to load")
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys")
verbose = flag.Bool("verbose", false, "Enable verbose mode")
debug = flag.Bool("debug", false, "Enable debug mode")
// Currently the timeouts are not separate for read and write deadlines.
// This could be done, but I currently don't really see a reason for this.
maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout")
directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout")
forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout")
// Mutex protecting 'authorisedKeys' map
authmutex sync.Mutex
)
// Structure that holds all information for each connection/client
type sshClient struct {
Name string
Conn *ssh.ServerConn
Listeners map[string]net.Listener
Name string
// We keep track of the normal Conn as well so that we have access to the
// SetDeadline() methods
Conn net.Conn
SshConn *ssh.ServerConn
// Listener sockets opened by the client
Listeners map[string]net.Listener
AllowedLocalPorts []uint32
AllowedRemotePorts []uint32
Stopping bool
ListenMutex sync.Mutex
// This indicates that a client is shutting down. When a client is stopping,
// we do not allow new listening requests, to prevent a listener connection
// being opened just after we closed all of them.
Stopping bool
ListenMutex sync.Mutex
}
// Structure containing what address/port we should bind on, for forwarded-tcpip
// connections
type bindInfo struct {
Bound string
Port uint32
Addr string
}
// Information parsed from the authorized_keys file
type deviceInfo struct {
LocalPorts string
RemotePorts string
@ -84,6 +107,10 @@ type tcpIpForwardCancelPayload struct {
Port uint32
}
// Function that can be used to implement calls to SetDeadline() after
// read/writes in copyTimeout()
type TimeoutFunc func()
func main() {
flag.Parse()
@ -122,6 +149,10 @@ func main() {
continue
}
tcpConn.SetDeadline(time.Now().Add(*maintimeout))
// We perform the ssh handshake in a goroutine so the handshake cannot
// block incoming connections.
go func() {
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config)
if err != nil {
@ -129,7 +160,7 @@ func main() {
return
}
client := sshClient{sshConn.Permissions.CriticalOptions["name"], sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}}
client := sshClient{sshConn.Permissions.CriticalOptions["name"], tcpConn, sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}}
allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"]
allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"]
@ -143,7 +174,7 @@ func main() {
client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts)
go func() {
err := client.Conn.Wait()
err := client.SshConn.Wait()
client.ListenMutex.Lock()
client.Stopping = true
@ -159,8 +190,8 @@ func main() {
client.ListenMutex.Unlock()
}()
// Accept requests & channels
go handleRequest(&client, reqs)
// Accept all channels
go handleChannels(&client, chans)
}()
}
@ -173,7 +204,7 @@ func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) {
}
func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
if *verbose {
if *debug {
log.Printf("[%s] Channel type: %v", client.Name, newChannel.ChannelType())
}
if t := newChannel.ChannelType(); t == "direct-tcpip" {
@ -181,9 +212,11 @@ func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
return
}
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\", \"forwarded-tcpip\" and \"cancel-tcpip-forward\" are accepted"))
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted"))
/*
// TODO: USE THIS ONLY FOR USING SSH ESCAPE SEQUENCES
// XXX: Use this only for testing purposes -- I add this in if/when I
// want to use the ssh escape sequences from ssh (those only work in an
// interactive session)
c, _, err := newChannel.Accept()
if err != nil {
log.Fatal(err)
@ -240,7 +273,7 @@ func handleDirect(client *sshClient, newChannel ssh.NewChannel) {
return
}
serve(connection, rconn, client)
serve(connection, rconn, client, *directtimeout)
}
func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) {
@ -318,7 +351,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) {
mpayload := ssh.Marshal(&payload)
// Open channel with client
c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload)
c, requests, err := client.SshConn.OpenChannel("forwarded-tcpip", mpayload)
if err != nil {
log.Printf("[%s] Unable to get channel: %s. Hanging up requesting party!", client.Name, err)
lconn.Close()
@ -329,7 +362,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) {
}
go ssh.DiscardRequests(requests)
serve(c, lconn, client)
serve(c, lconn, client, *forwardedtimeout)
}
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
@ -353,7 +386,48 @@ func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
req.Reply(false, []byte{})
}
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) {
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
for req := range reqs {
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
if *debug {
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
}
// RFC4254: 7.1 for forwarding
if req.Type == "tcpip-forward" {
client.ListenMutex.Lock()
/* If we are closing, do not set up a new listener */
if client.Stopping {
client.ListenMutex.Unlock()
req.Reply(false, []byte{})
continue
}
listener, bindinfo, err := handleTcpIpForward(client, req)
if err != nil {
client.ListenMutex.Unlock()
continue
}
client.Listeners[bindinfo.Bound] = listener
client.ListenMutex.Unlock()
go handleListener(client, bindinfo, listener)
continue
} else if req.Type == "cancel-tcpip-forward" {
client.ListenMutex.Lock()
handleTcpIPForwardCancel(client, req)
client.ListenMutex.Unlock()
continue
} else {
// Discard everything else
req.Reply(false, []byte{})
}
}
}
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) {
// TODO: Maybe just do this with defer instead? (And only one copy in a
// goroutine)
close := func() {
@ -364,17 +438,81 @@ func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) {
}
}
// TODO: Share timeout between both the read-conn and the write-conn
var once sync.Once
go func() {
io.Copy(cssh, conn)
//io.Copy(cssh, conn)
bytes_written, err := copyTimeout(cssh, conn, func() {
if *debug {
log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (sending data)", client.Name)
}
conn.SetDeadline(time.Now().Add(timeout))
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
})
if err != nil {
if *debug {
log.Printf("[%s] copyTimeout failed with: %s", client.Name, err)
}
}
if *verbose {
log.Printf("[%s] Connection closed, bytes written: %d", client.Name, bytes_written)
}
once.Do(close)
}()
go func() {
io.Copy(conn, cssh)
//io.Copy(conn, cssh)
bytes_written, err := copyTimeout(conn, cssh, func() {
if *debug {
log.Printf("[%s] Updating deadline for direct|forwarded socket and main socket (received data)", client.Name)
}
conn.SetDeadline(time.Now().Add(timeout))
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
})
if err != nil {
if *debug {
log.Printf("[%s] copyTimeout failed with: %s", client.Name, err)
}
}
if *verbose {
log.Printf("[%s] Connection closed, bytes written: %d", client.Name, bytes_written)
}
once.Do(close)
}()
}
// Changed from pkg/io/io.go copyBuffer
func copyTimeout(dst io.Writer, src io.Reader, timeout TimeoutFunc) (written int64, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := src.Read(buf)
if nr > 0 {
timeout()
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
timeout()
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}
func loadHostKeys(config *ssh.ServerConfig) {
privateBytes, err := ioutil.ReadFile(*hostkey)
if err != nil {
@ -435,62 +573,6 @@ func loadAuthorisedKeys(authorisedkeys string) {
authorisedKeys = authKeys
}
func registerReloadSignal() {
c := make(chan os.Signal)
signal.Notify(c, syscall.SIGUSR1)
go func() {
for sig := range c {
if sig == syscall.SIGUSR1 {
log.Printf("Received signal: SIGUSR1. Reloading authorised keys.")
loadAuthorisedKeys(*authorisedkeys)
} else {
log.Printf("Received unexpected signal: \"%s\".", sig.String())
}
}
}()
}
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
for req := range reqs {
if *verbose {
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
}
// RFC4254: 7.1 for forwarding
if req.Type == "tcpip-forward" {
client.ListenMutex.Lock()
/* If we are closing, do not set up a new listener */
if client.Stopping {
client.ListenMutex.Unlock()
req.Reply(false, []byte{})
continue
}
listener, bindinfo, err := handleTcpIpForward(client, req)
if err != nil {
client.ListenMutex.Unlock()
continue
}
client.Listeners[bindinfo.Bound] = listener
client.ListenMutex.Unlock()
go handleListener(client, bindinfo, listener)
continue
} else if req.Type == "cancel-tcpip-forward" {
client.ListenMutex.Lock()
handleTcpIPForwardCancel(client, req)
client.ListenMutex.Unlock()
continue
} else {
// Discard everything else
req.Reply(false, []byte{})
}
}
}
func portPermitted(port uint32, ports []uint32) bool {
ok := false
for _, p := range ports {

Loading…
Cancel
Save