|
|
|
@ -17,6 +17,7 @@ import ( |
|
|
|
|
"strings" |
|
|
|
|
"sync" |
|
|
|
|
"syscall" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
"golang.org/x/crypto/ssh" |
|
|
|
|
) |
|
|
|
@ -30,12 +31,21 @@ var ( |
|
|
|
|
authorisedkeys = flag.String("authorisedkeys", "authorized_keys", "Authorised keys") |
|
|
|
|
verbose = flag.Bool("verbose", false, "Enable verbose mode") |
|
|
|
|
|
|
|
|
|
// TODO: Separate for read/write? (Right now assume that on either
|
|
|
|
|
// read/write, reset deadline)
|
|
|
|
|
maintimeout = flag.Duration("main-timeout", time.Duration(2)*time.Minute, "Client socket timeout") |
|
|
|
|
directtimeout = flag.Duration("direct-timeout", time.Duration(2)*time.Minute, "direct-tcpip timeout") |
|
|
|
|
forwardedtimeout = flag.Duration("forwarded-timeout", time.Duration(2)*time.Minute, "forwarded-tcpip timeout") |
|
|
|
|
|
|
|
|
|
authmutex sync.Mutex |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
type TimeoutFunc func() |
|
|
|
|
|
|
|
|
|
type sshClient struct { |
|
|
|
|
Name string |
|
|
|
|
Conn *ssh.ServerConn |
|
|
|
|
Conn net.Conn |
|
|
|
|
SshConn *ssh.ServerConn |
|
|
|
|
Listeners map[string]net.Listener |
|
|
|
|
AllowedLocalPorts []uint32 |
|
|
|
|
AllowedRemotePorts []uint32 |
|
|
|
@ -122,6 +132,7 @@ func main() { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
tcpConn.SetDeadline(time.Now().Add(*maintimeout)) |
|
|
|
|
go func() { |
|
|
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config) |
|
|
|
|
if err != nil { |
|
|
|
@ -129,7 +140,7 @@ func main() { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
client := sshClient{sshConn.Permissions.CriticalOptions["name"], sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}} |
|
|
|
|
client := sshClient{sshConn.Permissions.CriticalOptions["name"], tcpConn, sshConn, make(map[string]net.Listener), nil, nil, false, sync.Mutex{}} |
|
|
|
|
allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"] |
|
|
|
|
allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"] |
|
|
|
|
|
|
|
|
@ -143,7 +154,7 @@ func main() { |
|
|
|
|
client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts) |
|
|
|
|
|
|
|
|
|
go func() { |
|
|
|
|
err := client.Conn.Wait() |
|
|
|
|
err := client.SshConn.Wait() |
|
|
|
|
client.ListenMutex.Lock() |
|
|
|
|
client.Stopping = true |
|
|
|
|
|
|
|
|
@ -240,7 +251,7 @@ func handleDirect(client *sshClient, newChannel ssh.NewChannel) { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
serve(connection, rconn, client) |
|
|
|
|
serve(connection, rconn, client, *directtimeout) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bindInfo, error) { |
|
|
|
@ -318,7 +329,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { |
|
|
|
|
mpayload := ssh.Marshal(&payload) |
|
|
|
|
|
|
|
|
|
// Open channel with client
|
|
|
|
|
c, requests, err := client.Conn.OpenChannel("forwarded-tcpip", mpayload) |
|
|
|
|
c, requests, err := client.SshConn.OpenChannel("forwarded-tcpip", mpayload) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("[%s] Unable to get channel: %s. Hanging up requesting party!", client.Name, err) |
|
|
|
|
lconn.Close() |
|
|
|
@ -329,7 +340,7 @@ func handleForwardTcpIp(client *sshClient, bindinfo *bindInfo, lconn net.Conn) { |
|
|
|
|
} |
|
|
|
|
go ssh.DiscardRequests(requests) |
|
|
|
|
|
|
|
|
|
serve(c, lconn, client) |
|
|
|
|
serve(c, lconn, client, *forwardedtimeout) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { |
|
|
|
@ -353,7 +364,7 @@ func handleTcpIPForwardCancel(client *sshClient, req *ssh.Request) { |
|
|
|
|
req.Reply(false, []byte{}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) { |
|
|
|
|
func serve(cssh ssh.Channel, conn net.Conn, client *sshClient, timeout time.Duration) { |
|
|
|
|
// TODO: Maybe just do this with defer instead? (And only one copy in a
|
|
|
|
|
// goroutine)
|
|
|
|
|
close := func() { |
|
|
|
@ -364,17 +375,59 @@ func serve(cssh ssh.Channel, conn net.Conn, client *sshClient) { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// TODO: Share timeout between both the read-conn and the write-conn
|
|
|
|
|
var once sync.Once |
|
|
|
|
go func() { |
|
|
|
|
io.Copy(cssh, conn) |
|
|
|
|
//io.Copy(cssh, conn)
|
|
|
|
|
_, _ = copyTimeout(cssh, conn, func() { |
|
|
|
|
conn.SetDeadline(time.Now().Add(timeout)) |
|
|
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout)) |
|
|
|
|
}) |
|
|
|
|
once.Do(close) |
|
|
|
|
}() |
|
|
|
|
go func() { |
|
|
|
|
io.Copy(conn, cssh) |
|
|
|
|
//io.Copy(conn, cssh)
|
|
|
|
|
_, _ = copyTimeout(conn, cssh, func() { |
|
|
|
|
conn.SetDeadline(time.Now().Add(timeout)) |
|
|
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout)) |
|
|
|
|
}) |
|
|
|
|
once.Do(close) |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Changed from pkg/io/io.go copyBuffer
|
|
|
|
|
func copyTimeout(dst io.Writer, src io.Reader, timeout TimeoutFunc) (written int64, err error) { |
|
|
|
|
buf := make([]byte, 32*1024) |
|
|
|
|
|
|
|
|
|
for { |
|
|
|
|
nr, er := src.Read(buf) |
|
|
|
|
if nr > 0 { |
|
|
|
|
timeout() |
|
|
|
|
|
|
|
|
|
nw, ew := dst.Write(buf[0:nr]) |
|
|
|
|
if nw > 0 { |
|
|
|
|
written += int64(nw) |
|
|
|
|
} |
|
|
|
|
if ew != nil { |
|
|
|
|
err = ew |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
if nr != nw { |
|
|
|
|
err = io.ErrShortWrite |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
timeout() |
|
|
|
|
} |
|
|
|
|
if er != nil { |
|
|
|
|
if er != io.EOF { |
|
|
|
|
err = er |
|
|
|
|
} |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return written, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func loadHostKeys(config *ssh.ServerConfig) { |
|
|
|
|
privateBytes, err := ioutil.ReadFile(*hostkey) |
|
|
|
|
if err != nil { |
|
|
|
@ -454,6 +507,8 @@ func registerReloadSignal() { |
|
|
|
|
|
|
|
|
|
func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { |
|
|
|
|
for req := range reqs { |
|
|
|
|
client.Conn.SetDeadline(time.Now().Add(*maintimeout)) |
|
|
|
|
|
|
|
|
|
if *verbose { |
|
|
|
|
log.Printf("[%s] Out of band request: %v %v", client.Name, req.Type, req.WantReply) |
|
|
|
|
} |
|
|
|
|