diff --git a/sshd.go b/sshd.go index 38163c2..516282a 100644 --- a/sshd.go +++ b/sshd.go @@ -21,7 +21,7 @@ import ( // TODO: Use defer where useful var ( - authorisedKeys map[string]string + authorisedKeys map[string]deviceInfo listenport = flag.Int("listenport", 2200, "Port to listen on for incoming ssh connections") hostkey = flag.String("hostkey", "id_rsa", "Server host key to load") @@ -43,6 +43,12 @@ type bindInfo struct { Addr string } +type deviceInfo struct { + LocalPorts string + RemotePorts string + Comment string +} + /* RFC4254 7.2 */ type directTCPPayload struct { Addr string // To connect to @@ -77,9 +83,11 @@ func main() { config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if ports, found := authorisedKeys[string(key.Marshal())]; found { + if deviceinfo, found := authorisedKeys[string(key.Marshal())]; found { return &ssh.Permissions{ - CriticalOptions: map[string]string{"ports": ports}, + CriticalOptions: map[string]string{"name": deviceinfo.Comment, + "localports": deviceinfo.LocalPorts, + "remoteports": deviceinfo.RemotePorts}, }, nil } @@ -111,7 +119,18 @@ func main() { return } - client := sshClient{"TODO FIXME XXX", sshConn, make(map[string]net.Listener), nil, nil} + client := sshClient{sshConn.Permissions.CriticalOptions["name"], sshConn, make(map[string]net.Listener), nil, nil} + allowedLocalPorts := sshConn.Permissions.CriticalOptions["localports"] + allowedRemotePorts := sshConn.Permissions.CriticalOptions["remoteports"] + + if *verbose { + log.Printf("Connection from %s (%s). Allowed local ports: %s remote ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedLocalPorts, allowedRemotePorts) + } + + // Parsing a second time should not error, so we can ignore the error + // safely + client.AllowedLocalPorts, _ = parsePorts(allowedLocalPorts) + client.AllowedRemotePorts, _ = parsePorts(allowedRemotePorts) go func() { err := client.Conn.Wait() @@ -128,19 +147,6 @@ func main() { } }() - allowedPorts := sshConn.Permissions.CriticalOptions["ports"] - - if *verbose { - log.Printf("Connection from %s (%s). Allowed ports: %s", sshConn.RemoteAddr(), sshConn.ClientVersion(), allowedPorts) - } - - // Parsing a second time should not error, so we can ignore the error - // safely - ports, _ := parsePorts(allowedPorts) - // TODO: Don't share same port/host limit - client.AllowedLocalPorts = ports - client.AllowedRemotePorts = ports - go handleRequest(&client, reqs) // Accept all channels @@ -355,18 +361,6 @@ func serve(cssh ssh.Channel, conn net.Conn) { }() } -func parsePorts(portstr string) (p []uint32, err error) { - ports := strings.Split(portstr, ":") - for _, port := range ports { - port, err := strconv.ParseUint(port, 10, 32) - if err != nil { - return p, err - } - p = append(p, uint32(port)) - } - return -} - func loadHostKeys(config *ssh.ServerConfig) { privateBytes, err := ioutil.ReadFile(*hostkey) if err != nil { @@ -382,39 +376,41 @@ func loadHostKeys(config *ssh.ServerConfig) { } func loadAuthorisedKeys(authorisedkeys string) { - authorisedKeys = map[string]string{} + authorisedKeys = map[string]deviceInfo{} authorisedKeysBytes, err := ioutil.ReadFile(authorisedkeys) if err != nil { log.Fatal("Cannot load authorised keys") } for len(authorisedKeysBytes) > 0 { - pubkey, _, options, rest, err := ssh.ParseAuthorizedKey(authorisedKeysBytes) + pubkey, comment, options, rest, err := ssh.ParseAuthorizedKey(authorisedKeysBytes) if err != nil { log.Fatal(err) } log.Println("Options:", options) - if len(options) != 1 { - log.Fatal(fmt.Errorf("Only one option is accepted: \"ports=...\"")) - } - - option := options[0] - if !strings.HasPrefix(option, "ports=") { - log.Fatal(fmt.Errorf("Options does not start with \"ports=\"")) + devinfo := deviceInfo{ + Comment: comment, } - ports := option[len("ports="):] - - _, err = parsePorts(ports) - - if err != nil { - log.Fatal(err) + for _, option := range options { + ports, err := parseOption(option, "local") + if err != nil { + ports, err := parseOption(option, "remote") + if err != nil { + log.Fatal(err) + } else { + devinfo.RemotePorts = ports + } + } else { + devinfo.LocalPorts = ports + } } - authorisedKeys[string(pubkey.Marshal())] = ports + authorisedKeys[string(pubkey.Marshal())] = devinfo + authorisedKeysBytes = rest } } @@ -456,3 +452,32 @@ func portPermitted(port uint32, ports []uint32) bool { return ok } + +func parseOption(option string, prefix string) (string, error) { + str := fmt.Sprintf("%sports=", prefix) + if !strings.HasPrefix(option, str) { + return "", fmt.Errorf("Option does not start with %s", str) + } + + ports := option[len(str):] + + _, err := parsePorts(ports) + + if err != nil { + log.Fatal(err) + } + + return ports, nil +} + +func parsePorts(portstr string) (p []uint32, err error) { + ports := strings.Split(portstr, ":") + for _, port := range ports { + port, err := strconv.ParseUint(port, 10, 32) + if err != nil { + return p, err + } + p = append(p, uint32(port)) + } + return +}