From 03506df25d5c6af6daf1ba19300d2697ad241bda Mon Sep 17 00:00:00 2001
From: David Sheets <sheets@alum.mit.edu>
Date: Mon, 7 Sep 2020 09:10:56 +0100
Subject: [PATCH] cli: factor options struct and parsing into home/options.go

---
 home/home.go         | 109 +++---------------
 home/options.go      | 255 +++++++++++++++++++++++++++++++++++++++++++
 home/options_test.go | 171 +++++++++++++++++++++++++++++
 3 files changed, 439 insertions(+), 96 deletions(-)
 create mode 100644 home/options.go
 create mode 100644 home/options_test.go

diff --git a/home/home.go b/home/home.go
index b9d7a760..806c59e2 100644
--- a/home/home.go
+++ b/home/home.go
@@ -526,108 +526,25 @@ func cleanupAlways() {
 	log.Info("Stopped")
 }
 
-// command-line arguments
-type options struct {
-	verbose        bool   // is verbose logging enabled
-	configFilename string // path to the config file
-	workDir        string // path to the working directory where we will store the filters data and the querylog
-	bindHost       string // host address to bind HTTP server on
-	bindPort       int    // port to serve HTTP pages on
-	logFile        string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
-	pidFile        string // File name to save PID to
-	checkConfig    bool   // Check configuration and exit
-	disableUpdate  bool   // If set, don't check for updates
-
-	// service control action (see service.ControlAction array + "status" command)
-	serviceControlAction string
-
-	// runningAsService flag is set to true when options are passed from the service runner
-	runningAsService bool
-
-	glinetMode bool // Activate GL-Inet mode
+func exitWithError() {
+	os.Exit(64)
 }
 
 // loadOptions reads command line arguments and initializes configuration
 func loadOptions() options {
-	o := options{}
+	o, f, err := parse(os.Args[0], os.Args[1:])
 
-	var printHelp func()
-	var opts = []struct {
-		longName          string
-		shortName         string
-		description       string
-		callbackWithValue func(value string)
-		callbackNoValue   func()
-	}{
-		{"config", "c", "Path to the config file", func(value string) { o.configFilename = value }, nil},
-		{"work-dir", "w", "Path to the working directory", func(value string) { o.workDir = value }, nil},
-		{"host", "h", "Host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil},
-		{"port", "p", "Port to serve HTTP pages on", func(value string) {
-			v, err := strconv.Atoi(value)
-			if err != nil {
-				panic("Got port that is not a number")
-			}
-			o.bindPort = v
-		}, nil},
-		{"service", "s", "Service control action: status, install, uninstall, start, stop, restart, reload (configuration)", func(value string) {
-			o.serviceControlAction = value
-		}, nil},
-		{"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) {
-			o.logFile = value
-		}, nil},
-		{"pidfile", "", "Path to a file where PID is stored", func(value string) { o.pidFile = value }, nil},
-		{"check-config", "", "Check configuration and exit", nil, func() { o.checkConfig = true }},
-		{"no-check-update", "", "Don't check for updates", nil, func() { o.disableUpdate = true }},
-		{"verbose", "v", "Enable verbose output", nil, func() { o.verbose = true }},
-		{"glinet", "", "Run in GL-Inet compatibility mode", nil, func() { o.glinetMode = true }},
-		{"version", "", "Show the version and exit", nil, func() {
-			fmt.Println(version())
+	if err != nil {
+		log.Error(err.Error())
+		_ = printHelp(os.Args[0])
+		exitWithError()
+	} else if f != nil {
+		err = f()
+		if err != nil {
+			log.Error(err.Error())
+			exitWithError()
+		} else {
 			os.Exit(0)
-		}},
-		{"help", "", "Print this help", nil, func() {
-			printHelp()
-			os.Exit(64)
-		}},
-	}
-	printHelp = func() {
-		fmt.Printf("Usage:\n\n")
-		fmt.Printf("%s [options]\n\n", os.Args[0])
-		fmt.Printf("Options:\n")
-		for _, opt := range opts {
-			val := ""
-			if opt.callbackWithValue != nil {
-				val = " VALUE"
-			}
-			if opt.shortName != "" {
-				fmt.Printf("  -%s, %-30s %s\n", opt.shortName, "--"+opt.longName+val, opt.description)
-			} else {
-				fmt.Printf("  %-34s %s\n", "--"+opt.longName+val, opt.description)
-			}
-		}
-	}
-	for i := 1; i < len(os.Args); i++ {
-		v := os.Args[i]
-		knownParam := false
-		for _, opt := range opts {
-			if v == "--"+opt.longName || (opt.shortName != "" && v == "-"+opt.shortName) {
-				if opt.callbackWithValue != nil {
-					if i+1 >= len(os.Args) {
-						log.Error("Got %s without argument\n", v)
-						os.Exit(64)
-					}
-					i++
-					opt.callbackWithValue(os.Args[i])
-				} else if opt.callbackNoValue != nil {
-					opt.callbackNoValue()
-				}
-				knownParam = true
-				break
-			}
-		}
-		if !knownParam {
-			log.Error("unknown option %v\n", v)
-			printHelp()
-			os.Exit(64)
 		}
 	}
 
diff --git a/home/options.go b/home/options.go
new file mode 100644
index 00000000..ec789d4f
--- /dev/null
+++ b/home/options.go
@@ -0,0 +1,255 @@
+package home
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+)
+
+// options passed from command-line arguments
+type options struct {
+	verbose        bool   // is verbose logging enabled
+	configFilename string // path to the config file
+	workDir        string // path to the working directory where we will store the filters data and the querylog
+	bindHost       string // host address to bind HTTP server on
+	bindPort       int    // port to serve HTTP pages on
+	logFile        string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
+	pidFile        string // File name to save PID to
+	checkConfig    bool   // Check configuration and exit
+	disableUpdate  bool   // If set, don't check for updates
+
+	// service control action (see service.ControlAction array + "status" command)
+	serviceControlAction string
+
+	// runningAsService flag is set to true when options are passed from the service runner
+	runningAsService bool
+
+	glinetMode bool // Activate GL-Inet mode
+}
+
+// functions used for their side-effects
+type effect func() error
+
+type arg struct {
+	description string // a short, English description of the argument
+	longName    string // the name of the argument used after '--'
+	shortName   string // the name of the argument used after '-'
+
+	// only one of updateWithValue, updateNoValue, and effect should be present
+
+	updateWithValue func(o options, v string) (options, error)         // the mutator for arguments with parameters
+	updateNoValue   func(o options) (options, error)                   // the mutator for arguments without parameters
+	effect          func(o options, exec string) (f effect, err error) // the side-effect closure generator
+}
+
+var args []arg
+
+var configArg = arg{
+	"Path to the config file",
+	"config", "c",
+	func(o options, v string) (options, error) { o.configFilename = v; return o, nil },
+	nil,
+	nil,
+}
+
+var workDirArg = arg{
+	"Path to the working directory",
+	"work-dir", "w",
+	func(o options, v string) (options, error) { o.workDir = v; return o, nil }, nil, nil,
+}
+
+var hostArg = arg{
+	"Host address to bind HTTP server on",
+	"host", "h",
+	func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil,
+}
+
+var portArg = arg{
+	"Port to serve HTTP pages on",
+	"port", "p",
+	func(o options, v string) (options, error) {
+		var err error
+		var p int
+		minPort, maxPort := 0, 1<<16-1
+		if p, err = strconv.Atoi(v); err != nil {
+			err = fmt.Errorf("port '%s' is not a number", v)
+		} else if p < minPort || p > maxPort {
+			err = fmt.Errorf("port %d not in range %d - %d", p, minPort, maxPort)
+		} else {
+			o.bindPort = p
+		}
+		return o, err
+	}, nil, nil,
+}
+
+var serviceArg = arg{
+	"Service control action: status, install, uninstall, start, stop, restart, reload (configuration)",
+	"service", "s",
+	func(o options, v string) (options, error) {
+		o.serviceControlAction = v
+		return o, nil
+	}, nil, nil,
+}
+
+var logfileArg = arg{
+	"Path to log file. If empty: write to stdout; if 'syslog': write to system log",
+	"logfile", "l",
+	func(o options, v string) (options, error) { o.logFile = v; return o, nil }, nil, nil,
+}
+
+var pidfileArg = arg{
+	"Path to a file where PID is stored",
+	"pidfile", "",
+	func(o options, v string) (options, error) { o.pidFile = v; return o, nil }, nil, nil,
+}
+
+var checkConfigArg = arg{
+	"Check configuration and exit",
+	"check-config", "",
+	nil, func(o options) (options, error) { o.checkConfig = true; return o, nil }, nil,
+}
+
+var noCheckUpdateArg = arg{
+	"Don't check for updates",
+	"no-check-update", "",
+	nil, func(o options) (options, error) { o.disableUpdate = true; return o, nil }, nil,
+}
+
+var verboseArg = arg{
+	"Enable verbose output",
+	"verbose", "v",
+	nil, func(o options) (options, error) { o.verbose = true; return o, nil }, nil,
+}
+
+var glinetArg = arg{
+	"Run in GL-Inet compatibility mode",
+	"glinet", "",
+	nil, func(o options) (options, error) { o.glinetMode = true; return o, nil }, nil,
+}
+
+var versionArg = arg{
+	"Show the version and exit",
+	"version", "",
+	nil, nil, func(o options, exec string) (effect, error) {
+		return func() error { fmt.Println(version()); os.Exit(0); return nil }, nil
+	},
+}
+
+var helpArg = arg{
+	"Print this help",
+	"help", "",
+	nil, nil, func(o options, exec string) (effect, error) {
+		return func() error { _ = printHelp(exec); os.Exit(64); return nil }, nil
+	},
+}
+
+func init() {
+	args = []arg{
+		configArg,
+		workDirArg,
+		hostArg,
+		portArg,
+		serviceArg,
+		logfileArg,
+		pidfileArg,
+		checkConfigArg,
+		noCheckUpdateArg,
+		verboseArg,
+		glinetArg,
+		versionArg,
+		helpArg,
+	}
+}
+
+func getUsageLines(exec string, args []arg) []string {
+	usage := []string{
+		"Usage:",
+		"",
+		fmt.Sprintf("%s [options]", exec),
+		"",
+		"Options:",
+	}
+	for _, arg := range args {
+		val := ""
+		if arg.updateWithValue != nil {
+			val = " VALUE"
+		}
+		if arg.shortName != "" {
+			usage = append(usage, fmt.Sprintf("  -%s, %-30s %s",
+				arg.shortName,
+				"--"+arg.longName+val,
+				arg.description))
+		} else {
+			usage = append(usage, fmt.Sprintf("  %-34s %s",
+				"--"+arg.longName+val,
+				arg.description))
+		}
+	}
+	return usage
+}
+
+func printHelp(exec string) error {
+	for _, line := range getUsageLines(exec, args) {
+		_, err := fmt.Println(line)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func argMatches(a arg, v string) bool {
+	return v == "--"+a.longName || (a.shortName != "" && v == "-"+a.shortName)
+}
+
+func parse(exec string, ss []string) (o options, f effect, err error) {
+	for i := 0; i < len(ss); i++ {
+		v := ss[i]
+		knownParam := false
+		for _, arg := range args {
+			if argMatches(arg, v) {
+				if arg.updateWithValue != nil {
+					if i+1 >= len(ss) {
+						return o, f, fmt.Errorf("got %s without argument", v)
+					}
+					i++
+					o, err = arg.updateWithValue(o, ss[i])
+					if err != nil {
+						return
+					}
+				} else if arg.updateNoValue != nil {
+					o, err = arg.updateNoValue(o)
+					if err != nil {
+						return
+					}
+				} else if arg.effect != nil {
+					var eff effect
+					eff, err = arg.effect(o, exec)
+					if err != nil {
+						return
+					}
+					if eff != nil {
+						prevf := f
+						f = func() error {
+							var err error
+							if prevf != nil {
+								err = prevf()
+							}
+							if err == nil {
+								err = eff()
+							}
+							return err
+						}
+					}
+				}
+				knownParam = true
+				break
+			}
+		}
+		if !knownParam {
+			return o, f, fmt.Errorf("unknown option %v", v)
+		}
+	}
+
+	return
+}
diff --git a/home/options_test.go b/home/options_test.go
new file mode 100644
index 00000000..750f3ce2
--- /dev/null
+++ b/home/options_test.go
@@ -0,0 +1,171 @@
+package home
+
+import (
+	"fmt"
+	"testing"
+)
+
+func testParseOk(t *testing.T, ss ...string) options {
+	o, _, err := parse("", ss)
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+	return o
+}
+
+func testParseErr(t *testing.T, descr string, ss ...string) {
+	_, _, err := parse("", ss)
+	if err == nil {
+		t.Fatalf("expected an error because %s but no error returned", descr)
+	}
+}
+
+func testParseParamMissing(t *testing.T, param string) {
+	testParseErr(t, fmt.Sprintf("%s parameter missing", param), param)
+}
+
+func TestParseVerbose(t *testing.T) {
+	if testParseOk(t).verbose {
+		t.Fatal("empty is not verbose")
+	}
+	if !testParseOk(t, "-v").verbose {
+		t.Fatal("-v is verbose")
+	}
+	if !testParseOk(t, "--verbose").verbose {
+		t.Fatal("--verbose is verbose")
+	}
+}
+
+func TestParseConfigFilename(t *testing.T) {
+	if testParseOk(t).configFilename != "" {
+		t.Fatal("empty is no config filename")
+	}
+	if testParseOk(t, "-c", "path").configFilename != "path" {
+		t.Fatal("-c is config filename")
+	}
+	testParseParamMissing(t, "-c")
+	if testParseOk(t, "--config", "path").configFilename != "path" {
+		t.Fatal("--configFilename is config filename")
+	}
+	testParseParamMissing(t, "--config")
+}
+
+func TestParseWorkDir(t *testing.T) {
+	if testParseOk(t).workDir != "" {
+		t.Fatal("empty is no work dir")
+	}
+	if testParseOk(t, "-w", "path").workDir != "path" {
+		t.Fatal("-w is work dir")
+	}
+	testParseParamMissing(t, "-w")
+	if testParseOk(t, "--work-dir", "path").workDir != "path" {
+		t.Fatal("--work-dir is work dir")
+	}
+	testParseParamMissing(t, "--work-dir")
+}
+
+func TestParseBindHost(t *testing.T) {
+	if testParseOk(t).bindHost != "" {
+		t.Fatal("empty is no host")
+	}
+	if testParseOk(t, "-h", "addr").bindHost != "addr" {
+		t.Fatal("-h is host")
+	}
+	testParseParamMissing(t, "-h")
+	if testParseOk(t, "--host", "addr").bindHost != "addr" {
+		t.Fatal("--host is host")
+	}
+	testParseParamMissing(t, "--host")
+}
+
+func TestParseBindPort(t *testing.T) {
+	if testParseOk(t).bindPort != 0 {
+		t.Fatal("empty is port 0")
+	}
+	if testParseOk(t, "-p", "65535").bindPort != 65535 {
+		t.Fatal("-p is port")
+	}
+	testParseParamMissing(t, "-p")
+	if testParseOk(t, "--port", "65535").bindPort != 65535 {
+		t.Fatal("--port is port")
+	}
+	testParseParamMissing(t, "--port")
+}
+
+func TestParseBindPortBad(t *testing.T) {
+	testParseErr(t, "not an int", "-p", "x")
+	testParseErr(t, "hex not supported", "-p", "0x100")
+	testParseErr(t, "port negative", "-p", "-1")
+	testParseErr(t, "port too high", "-p", "65536")
+	testParseErr(t, "port too high", "-p", "4294967297")           // 2^32 + 1
+	testParseErr(t, "port too high", "-p", "18446744073709551617") // 2^64 + 1
+}
+
+func TestParseLogfile(t *testing.T) {
+	if testParseOk(t).logFile != "" {
+		t.Fatal("empty is no log file")
+	}
+	if testParseOk(t, "-l", "path").logFile != "path" {
+		t.Fatal("-l is log file")
+	}
+	if testParseOk(t, "--logfile", "path").logFile != "path" {
+		t.Fatal("--logfile is log file")
+	}
+}
+
+func TestParsePidfile(t *testing.T) {
+	if testParseOk(t).pidFile != "" {
+		t.Fatal("empty is no pid file")
+	}
+	if testParseOk(t, "--pidfile", "path").pidFile != "path" {
+		t.Fatal("--pidfile is pid file")
+	}
+}
+
+func TestParseCheckConfig(t *testing.T) {
+	if testParseOk(t).checkConfig {
+		t.Fatal("empty is not check config")
+	}
+	if !testParseOk(t, "--check-config").checkConfig {
+		t.Fatal("--check-config is check config")
+	}
+}
+
+func TestParseDisableUpdate(t *testing.T) {
+	if testParseOk(t).disableUpdate {
+		t.Fatal("empty is not disable update")
+	}
+	if !testParseOk(t, "--no-check-update").disableUpdate {
+		t.Fatal("--no-check-update is disable update")
+	}
+}
+
+func TestParseService(t *testing.T) {
+	if testParseOk(t).serviceControlAction != "" {
+		t.Fatal("empty is no service command")
+	}
+	if testParseOk(t, "-s", "command").serviceControlAction != "command" {
+		t.Fatal("-s is service command")
+	}
+	if testParseOk(t, "--service", "command").serviceControlAction != "command" {
+		t.Fatal("--service is service command")
+	}
+}
+
+func TestParseGLInet(t *testing.T) {
+	if testParseOk(t).glinetMode {
+		t.Fatal("empty is not GL-Inet mode")
+	}
+	if !testParseOk(t, "--glinet").glinetMode {
+		t.Fatal("--glinet is GL-Inet mode")
+	}
+}
+
+func TestParseUnknown(t *testing.T) {
+	testParseErr(t, "unknown word", "x")
+	testParseErr(t, "unknown short", "-x")
+	testParseErr(t, "unknown long", "--x")
+	testParseErr(t, "unknown triple", "---x")
+	testParseErr(t, "unknown plus", "+x")
+	testParseErr(t, "unknown dash", "-")
+}