Clean up code and more comments
This commit is contained in:
parent
1f1df93791
commit
9f8a9042fc
3 changed files with 110 additions and 75 deletions
26
signal_unix.go
Normal file
26
signal_unix.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// +build !windows
|
||||
func registerReloadSignal() {
|
||||
c := make(chan os.Signal)
|
||||
signal.Notify(c, syscall.SIGUSR1)
|
||||
|
||||
go func() {
|
||||
for sig := range c {
|
||||
if sig == syscall.SIGUSR1 {
|
||||
log.Printf("Received signal: SIGUSR1. Reloading authorised keys.")
|
||||
loadAuthorisedKeys(*authorisedkeys)
|
||||
} else {
|
||||
log.Printf("Received unexpected signal: \"%s\".", sig.String())
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
}
|
5
signal_windows.go
Normal file
5
signal_windows.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package main
|
||||
|
||||
// +build windows
|
||||
func registerReloadSignal() {
|
||||
}
|
154
sshd.go
154
sshd.go
|
@ -11,18 +11,16 @@ import (
|
|||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
// Contains a mapping of authorised keys to permissions for said key
|
||||
authorisedKeys map[string]deviceInfo
|
||||
|
||||
listenaddr = flag.String("listenaddr", "0.0.0.0", "Addr to listen on for incoming ssh connections")
|
||||
|
@ -32,34 +30,48 @@ var (
|
|||
verbose = flag.Bool("verbose", false, "Enable verbose mode")
|
||||
debug = flag.Bool("debug", false, "Enable debug mode")
|
||||
|
||||
// TODO: Separate for read/write? (Right now assume that on either
|
||||
// read/write, reset deadline)
|
||||
// Currently the timeouts are not separate for read and write deadlines.
|
||||
// This could be done, but I currently don't really see a reason for this.
|
||||
maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout")
|
||||
directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout")
|
||||
forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout")
|
||||
|
||||
// Mutex protecting 'authorisedKeys' map
|
||||
authmutex sync.Mutex
|
||||
)
|
||||
|
||||
type TimeoutFunc func()
|
||||
|
||||
// Structure that holds all information for each connection/client
|
||||
type sshClient struct {
|
||||
Name string
|
||||
Conn net.Conn
|
||||
SshConn *ssh.ServerConn
|
||||
Listeners map[string]net.Listener
|
||||
Name string
|
||||
|
||||
// We keep track of the normal Conn as well so that we have access to the
|
||||
// SetDeadline() methods
|
||||
Conn net.Conn
|
||||
|
||||
SshConn *ssh.ServerConn
|
||||
|
||||
// Listener sockets opened by the client
|
||||
Listeners map[string]net.Listener
|
||||
|
||||
AllowedLocalPorts []uint32
|
||||
AllowedRemotePorts []uint32
|
||||
Stopping bool
|
||||
ListenMutex sync.Mutex
|
||||
|
||||
// This indicates that a client is shutting down. When a client is stopping,
|
||||
// we do not allow new listening requests, to prevent a listener connection
|
||||
// being opened just after we closed all of them.
|
||||
Stopping bool
|
||||
ListenMutex sync.Mutex
|
||||
}
|
||||
|
||||
// Structure containing what address/port we should bind on, for forwarded-tcpip
|
||||
// connections
|
||||
type bindInfo struct {
|
||||
Bound string
|
||||
Port uint32
|
||||
Addr string
|
||||
}
|
||||
|
||||
// Information parsed from the authorized_keys file
|
||||
type deviceInfo struct {
|
||||
LocalPorts string
|
||||
RemotePorts string
|
||||
|
@ -95,6 +107,10 @@ type tcpIpForwardCancelPayload struct {
|
|||
Port uint32
|
||||
}
|
||||
|
||||
// Function that can be used to implement calls to SetDeadline() after
|
||||
// read/writes in copyTimeout()
|
||||
type TimeoutFunc func()
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
|
@ -134,6 +150,9 @@ func main() {
|
|||
}
|
||||
|
||||
tcpConn.SetDeadline(time.Now().Add(*maintimeout))
|
||||
|
||||
// We perform the ssh handshake in a goroutine so the handshake cannot
|
||||
// block incoming connections.
|
||||
go func() {
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config)
|
||||
if err != nil {
|
||||
|
@ -171,8 +190,8 @@ func main() {
|
|||
client.ListenMutex.Unlock()
|
||||
}()
|
||||
|
||||
// Accept requests & channels
|
||||
go handleRequest(&client, reqs)
|
||||
// Accept all channels
|
||||
go handleChannels(&client, chans)
|
||||
}()
|
||||
}
|
||||
|
@ -185,7 +204,7 @@ func handleChannels(client *sshClient, chans <-chan ssh.NewChannel) {
|
|||
}
|
||||
|
||||
func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
|
||||
if *verbose {
|
||||
if *debug {
|
||||
log.Printf("[%s] Channel type: %v", client.Name, newChannel.ChannelType())
|
||||
}
|
||||
if t := newChannel.ChannelType(); t == "direct-tcpip" {
|
||||
|
@ -193,9 +212,11 @@ func handleChannel(client *sshClient, newChannel ssh.NewChannel) {
|
|||
return
|
||||
}
|
||||
|
||||
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\", \"forwarded-tcpip\" and \"cancel-tcpip-forward\" are accepted"))
|
||||
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Only \"direct-tcpip\" is accepted"))
|
||||
/*
|
||||
// TODO: USE THIS ONLY FOR USING SSH ESCAPE SEQUENCES
|
||||
// XXX: Use this only for testing purposes -- I add this in if/when I
|
||||
// want to use the ssh escape sequences from ssh (those only work in an
|
||||
// interactive session)
|
||||
c, _, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -365,6 +386,47 @@ func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) {
|
|||
req.Reply(false, []byte{})
|
||||
}
|
||||
|
||||
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
|
||||
for req := range reqs {
|
||||
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
|
||||
|
||||
if *debug {
|
||||
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
|
||||
}
|
||||
|
||||
// RFC4254: 7.1 for forwarding
|
||||
if req.Type == "tcpip-forward" {
|
||||
client.ListenMutex.Lock()
|
||||
/* If we are closing, do not set up a new listener */
|
||||
if client.Stopping {
|
||||
client.ListenMutex.Unlock()
|
||||
req.Reply(false, []byte{})
|
||||
continue
|
||||
}
|
||||
|
||||
listener, bindinfo, err := handleTcpIpForward(client, req)
|
||||
if err != nil {
|
||||
client.ListenMutex.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
client.Listeners[bindinfo.Bound] = listener
|
||||
client.ListenMutex.Unlock()
|
||||
|
||||
go handleListener(client, bindinfo, listener)
|
||||
continue
|
||||
} else if req.Type == "cancel-tcpip-forward" {
|
||||
client.ListenMutex.Lock()
|
||||
handleTcpIPForwardCancel(client, req)
|
||||
client.ListenMutex.Unlock()
|
||||
continue
|
||||
} else {
|
||||
// Discard everything else
|
||||
req.Reply(false, []byte{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) {
|
||||
// TODO: Maybe just do this with defer instead? (And only one copy in a
|
||||
// goroutine)
|
||||
|
@ -511,64 +573,6 @@ func loadAuthorisedKeys(authorisedkeys string) {
|
|||
authorisedKeys = authKeys
|
||||
}
|
||||
|
||||
func registerReloadSignal() {
|
||||
c := make(chan os.Signal)
|
||||
signal.Notify(c, syscall.SIGUSR1)
|
||||
|
||||
go func() {
|
||||
for sig := range c {
|
||||
if sig == syscall.SIGUSR1 {
|
||||
log.Printf("Received signal: SIGUSR1. Reloading authorised keys.")
|
||||
loadAuthorisedKeys(*authorisedkeys)
|
||||
} else {
|
||||
log.Printf("Received unexpected signal: \"%s\".", sig.String())
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
}
|
||||
|
||||
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) {
|
||||
for req := range reqs {
|
||||
client.Conn.SetDeadline(time.Now().Add(*maintimeout))
|
||||
|
||||
if *verbose {
|
||||
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply)
|
||||
}
|
||||
|
||||
// RFC4254: 7.1 for forwarding
|
||||
if req.Type == "tcpip-forward" {
|
||||
client.ListenMutex.Lock()
|
||||
/* If we are closing, do not set up a new listener */
|
||||
if client.Stopping {
|
||||
client.ListenMutex.Unlock()
|
||||
req.Reply(false, []byte{})
|
||||
continue
|
||||
}
|
||||
|
||||
listener, bindinfo, err := handleTcpIpForward(client, req)
|
||||
if err != nil {
|
||||
client.ListenMutex.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
client.Listeners[bindinfo.Bound] = listener
|
||||
client.ListenMutex.Unlock()
|
||||
|
||||
go handleListener(client, bindinfo, listener)
|
||||
continue
|
||||
} else if req.Type == "cancel-tcpip-forward" {
|
||||
client.ListenMutex.Lock()
|
||||
handleTcpIPForwardCancel(client, req)
|
||||
client.ListenMutex.Unlock()
|
||||
continue
|
||||
} else {
|
||||
// Discard everything else
|
||||
req.Reply(false, []byte{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func portPermitted(port uint32, ports []uint32) bool {
|
||||
ok := false
|
||||
for _, p := range ports {
|
||||
|
|
Loading…
Add table
Reference in a new issue