From 99b46e8767edf9f2a230a707d6ce194d01ea3ecd Mon Sep 17 00:00:00 2001 From: Merlijn Wajer Date: Wed, 8 Mar 2017 02:12:54 +0100 Subject: [PATCH] Merge port filtering code --- sshd.go | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/sshd.go b/sshd.go index f339332..adedd08 100644 --- a/sshd.go +++ b/sshd.go @@ -196,15 +196,7 @@ func handleDirect(client *sshClient, newChannel ssh.NewChannel) { return } - ok := false - for _, port := range client.AllowedLocalPorts { - if payload.Port == port { - ok = true - break - } - } - - if !ok { + if !portPermitted(payload.Port, client.AllowedLocalPorts) { newChannel.Reject(ssh.Prohibited, fmt.Sprintf("Bad port")) log.Printf("Tried to connect to prohibited port: %d", payload.Port) return @@ -252,15 +244,7 @@ func handleTcpIpForward(client *sshClient, req *ssh.Request) (net.Listener, *bin return nil, nil, fmt.Errorf("Address is not permitted") } - ok := false - for _, port := range client.AllowedRemotePorts { - if payload.Port == port { - ok = true - break - } - } - - if !ok { + if !portPermitted(payload.Port, client.AllowedRemotePorts) { log.Printf("Port is not permitted.") req.Reply(false, []byte{}) return nil, nil, fmt.Errorf("Port is not permitted") @@ -461,3 +445,15 @@ func handleRequest(client *sshClient, reqs <-chan *ssh.Request) { } } } + +func portPermitted(port uint32, ports []uint32) bool { + ok := false + for _, p := range ports { + if port == p { + ok = true + break + } + } + + return ok +}