Initial implementation of welcome/firstrun/installer page in go backend

This commit is contained in:
Eugene Bujak 2019-01-29 20:41:57 +03:00
parent c494e17df5
commit 302c3a767a
5 changed files with 241 additions and 199 deletions

97
app.go
View file

@ -1,7 +1,6 @@
package main
import (
"bufio"
"fmt"
stdlog "log"
"net"
@ -17,7 +16,6 @@ import (
"github.com/gobuffalo/packr"
"github.com/hmage/golibs/log"
"golang.org/x/crypto/ssh/terminal"
)
// VersionString will be set through ldflags, contains current version
@ -72,13 +70,10 @@ func run(args options) {
log.Printf("AdGuard Home is running as a service")
}
err := askUsernamePasswordIfPossible()
if err != nil {
log.Fatal(err)
}
config.firstRun = detectFirstRun()
// Do the upgrade if necessary
err = upgradeConfig()
err := upgradeConfig()
if err != nil {
log.Fatal(err)
}
@ -145,7 +140,9 @@ func run(args options) {
// Initialize and run the admin Web interface
box := packr.NewBox("build/static")
http.Handle("/", optionalAuthHandler(http.FileServer(box)))
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
http.Handle("/", postInstallHandler(optionalAuthHandler(http.FileServer(box))))
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
registerControlHandlers()
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
@ -222,14 +219,6 @@ func cleanup() {
}
}
func getInput() (string, error) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
text := scanner.Text()
err := scanner.Err()
return text, err
}
// command-line arguments
type options struct {
verbose bool // is verbose logging enabled
@ -318,79 +307,3 @@ func loadOptions() options {
return o
}
func promptAndGet(prompt string) (string, error) {
for {
fmt.Print(prompt)
input, err := getInput()
if err != nil {
log.Printf("Failed to get input, aborting: %s", err)
return "", err
}
if len(input) != 0 {
return input, nil
}
// try again
}
}
func promptAndGetPassword(prompt string) (string, error) {
for {
fmt.Print(prompt)
password, err := terminal.ReadPassword(int(os.Stdin.Fd()))
fmt.Print("\n")
if err != nil {
log.Printf("Failed to get input, aborting: %s", err)
return "", err
}
if len(password) != 0 {
return string(password), nil
}
// try again
}
}
func askUsernamePasswordIfPossible() error {
configFile := config.getConfigFilename()
_, err := os.Stat(configFile)
if !os.IsNotExist(err) {
// do nothing, file exists
return nil
}
if !terminal.IsTerminal(int(os.Stdin.Fd())) {
return nil // do nothing
}
if !terminal.IsTerminal(int(os.Stdout.Fd())) {
return nil // do nothing
}
fmt.Printf("Would you like to set user/password for the web interface authentication (yes/no)?\n")
yesno, err := promptAndGet("Please type 'yes' or 'no': ")
if err != nil {
return err
}
if yesno[0] != 'y' && yesno[0] != 'Y' {
return nil
}
username, err := promptAndGet("Please enter the username: ")
if err != nil {
return err
}
password, err := promptAndGetPassword("Please enter the password: ")
if err != nil {
return err
}
password2, err := promptAndGetPassword("Please enter password again: ")
if err != nil {
return err
}
if password2 != password {
fmt.Printf("Passwords do not match! Aborting\n")
os.Exit(1)
}
config.AuthName = username
config.AuthPass = password
return nil
}

View file

@ -29,6 +29,7 @@ type logSettings struct {
type configuration struct {
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
BindHost string `yaml:"bind_host"`
BindPort int `yaml:"bind_port"`
@ -152,6 +153,10 @@ func readConfigFile() ([]byte, error) {
func (c *configuration) write() error {
c.Lock()
defer c.Unlock()
if config.firstRun {
log.Tracef("Silently refusing to write config because first run and not configured yet")
return nil
}
configFile := config.getConfigFilename()
log.Printf("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)

View file

@ -694,24 +694,43 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
}
}
func handleGetDefaultAddresses(w http.ResponseWriter, r *http.Request) {
type ipport struct {
IP string `json:"ip"`
Port int `json:"port"`
}
data := struct {
type firstRunData struct {
Web ipport `json:"web"`
DNS ipport `json:"dns"`
}{}
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
}
// TODO: replace mockup with actual data
data.Web.IP = "192.168.104.104"
data.Web.Port = 3000
data.DNS.IP = "192.168.104.104"
data.DNS.Port = 53
func handleGetDefaultAddresses(w http.ResponseWriter, r *http.Request) {
data := firstRunData{}
ifaces, err := getValidNetInterfaces()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
if len(ifaces) == 0 {
httpError(w, http.StatusServiceUnavailable, "Couldn't find any legible interface, plase try again later")
return
}
// find an interface with an ipv4 address
addr := findIPv4IfaceAddr(ifaces)
if len(addr) == 0 {
httpError(w, http.StatusServiceUnavailable, "Couldn't find any interface with IPv4, plase try again later")
return
}
data.Web.IP = addr
data.DNS.IP = addr
data.Web.Port = 3000 // TODO: find out if port 80 is available -- if not, fall back to 3000
data.DNS.Port = 53 // TODO: find out if port 53 is available -- if not, show a big warning
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(data)
err = json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
return
@ -719,7 +738,7 @@ func handleGetDefaultAddresses(w http.ResponseWriter, r *http.Request) {
}
func handleSetAllSettings(w http.ResponseWriter, r *http.Request) {
newSettings := map[string]interface{}{}
newSettings := firstRunData{}
err := json.NewDecoder(r.Body).Decode(&newSettings)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err)
@ -727,48 +746,57 @@ func handleSetAllSettings(w http.ResponseWriter, r *http.Request) {
}
spew.Dump(newSettings)
config.firstRun = false
config.BindHost = newSettings.Web.IP
config.BindPort = newSettings.Web.Port
config.DNS.BindHost = newSettings.DNS.IP
config.DNS.Port = newSettings.DNS.Port
config.AuthName = newSettings.Username
config.AuthPass = newSettings.Password
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func registerControlHandlers() {
http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus)))
http.HandleFunc("/control/enable_protection", optionalAuth(ensurePOST(handleProtectionEnable)))
http.HandleFunc("/control/disable_protection", optionalAuth(ensurePOST(handleProtectionDisable)))
http.HandleFunc("/control/querylog", optionalAuth(ensureGET(dnsforward.HandleQueryLog)))
http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable)))
http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable)))
http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS)))
http.HandleFunc("/control/test_upstream_dns", optionalAuth(ensurePOST(handleTestUpstreamDNS)))
http.HandleFunc("/control/i18n/change_language", optionalAuth(ensurePOST(handleI18nChangeLanguage)))
http.HandleFunc("/control/i18n/current_language", optionalAuth(ensureGET(handleI18nCurrentLanguage)))
http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(dnsforward.HandleStatsTop)))
http.HandleFunc("/control/stats", optionalAuth(ensureGET(dnsforward.HandleStats)))
http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(dnsforward.HandleStatsHistory)))
http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(dnsforward.HandleStatsReset)))
http.HandleFunc("/control/version.json", optionalAuth(handleGetVersionJSON))
http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable)))
http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable)))
http.HandleFunc("/control/filtering/add_url", optionalAuth(ensurePUT(handleFilteringAddURL)))
http.HandleFunc("/control/filtering/remove_url", optionalAuth(ensureDELETE(handleFilteringRemoveURL)))
http.HandleFunc("/control/filtering/enable_url", optionalAuth(ensurePOST(handleFilteringEnableURL)))
http.HandleFunc("/control/filtering/disable_url", optionalAuth(ensurePOST(handleFilteringDisableURL)))
http.HandleFunc("/control/filtering/refresh", optionalAuth(ensurePOST(handleFilteringRefresh)))
http.HandleFunc("/control/filtering/status", optionalAuth(ensureGET(handleFilteringStatus)))
http.HandleFunc("/control/filtering/set_rules", optionalAuth(ensurePUT(handleFilteringSetRules)))
http.HandleFunc("/control/safebrowsing/enable", optionalAuth(ensurePOST(handleSafeBrowsingEnable)))
http.HandleFunc("/control/safebrowsing/disable", optionalAuth(ensurePOST(handleSafeBrowsingDisable)))
http.HandleFunc("/control/safebrowsing/status", optionalAuth(ensureGET(handleSafeBrowsingStatus)))
http.HandleFunc("/control/parental/enable", optionalAuth(ensurePOST(handleParentalEnable)))
http.HandleFunc("/control/parental/disable", optionalAuth(ensurePOST(handleParentalDisable)))
http.HandleFunc("/control/parental/status", optionalAuth(ensureGET(handleParentalStatus)))
http.HandleFunc("/control/safesearch/enable", optionalAuth(ensurePOST(handleSafeSearchEnable)))
http.HandleFunc("/control/safesearch/disable", optionalAuth(ensurePOST(handleSafeSearchDisable)))
http.HandleFunc("/control/safesearch/status", optionalAuth(ensureGET(handleSafeSearchStatus)))
http.HandleFunc("/control/dhcp/status", optionalAuth(ensureGET(handleDHCPStatus)))
http.HandleFunc("/control/dhcp/interfaces", optionalAuth(ensureGET(handleDHCPInterfaces)))
http.HandleFunc("/control/dhcp/set_config", optionalAuth(ensurePOST(handleDHCPSetConfig)))
http.HandleFunc("/control/dhcp/find_active_dhcp", optionalAuth(ensurePOST(handleDHCPFindActiveServer)))
http.HandleFunc("/control/status", postInstall(optionalAuth(ensureGET(handleStatus))))
http.HandleFunc("/control/enable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionEnable))))
http.HandleFunc("/control/disable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionDisable))))
http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(dnsforward.HandleQueryLog))))
http.HandleFunc("/control/querylog_enable", postInstall(optionalAuth(ensurePOST(handleQueryLogEnable))))
http.HandleFunc("/control/querylog_disable", postInstall(optionalAuth(ensurePOST(handleQueryLogDisable))))
http.HandleFunc("/control/set_upstream_dns", postInstall(optionalAuth(ensurePOST(handleSetUpstreamDNS))))
http.HandleFunc("/control/test_upstream_dns", postInstall(optionalAuth(ensurePOST(handleTestUpstreamDNS))))
http.HandleFunc("/control/i18n/change_language", postInstall(optionalAuth(ensurePOST(handleI18nChangeLanguage))))
http.HandleFunc("/control/i18n/current_language", postInstall(optionalAuth(ensureGET(handleI18nCurrentLanguage))))
http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsTop))))
http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(dnsforward.HandleStats))))
http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsHistory))))
http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(dnsforward.HandleStatsReset))))
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable))))
http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable))))
http.HandleFunc("/control/filtering/add_url", postInstall(optionalAuth(ensurePUT(handleFilteringAddURL))))
http.HandleFunc("/control/filtering/remove_url", postInstall(optionalAuth(ensureDELETE(handleFilteringRemoveURL))))
http.HandleFunc("/control/filtering/enable_url", postInstall(optionalAuth(ensurePOST(handleFilteringEnableURL))))
http.HandleFunc("/control/filtering/disable_url", postInstall(optionalAuth(ensurePOST(handleFilteringDisableURL))))
http.HandleFunc("/control/filtering/refresh", postInstall(optionalAuth(ensurePOST(handleFilteringRefresh))))
http.HandleFunc("/control/filtering/status", postInstall(optionalAuth(ensureGET(handleFilteringStatus))))
http.HandleFunc("/control/filtering/set_rules", postInstall(optionalAuth(ensurePUT(handleFilteringSetRules))))
http.HandleFunc("/control/safebrowsing/enable", postInstall(optionalAuth(ensurePOST(handleSafeBrowsingEnable))))
http.HandleFunc("/control/safebrowsing/disable", postInstall(optionalAuth(ensurePOST(handleSafeBrowsingDisable))))
http.HandleFunc("/control/safebrowsing/status", postInstall(optionalAuth(ensureGET(handleSafeBrowsingStatus))))
http.HandleFunc("/control/parental/enable", postInstall(optionalAuth(ensurePOST(handleParentalEnable))))
http.HandleFunc("/control/parental/disable", postInstall(optionalAuth(ensurePOST(handleParentalDisable))))
http.HandleFunc("/control/parental/status", postInstall(optionalAuth(ensureGET(handleParentalStatus))))
http.HandleFunc("/control/safesearch/enable", postInstall(optionalAuth(ensurePOST(handleSafeSearchEnable))))
http.HandleFunc("/control/safesearch/disable", postInstall(optionalAuth(ensurePOST(handleSafeSearchDisable))))
http.HandleFunc("/control/safesearch/status", postInstall(optionalAuth(ensureGET(handleSafeSearchStatus))))
http.HandleFunc("/control/dhcp/status", postInstall(optionalAuth(ensureGET(handleDHCPStatus))))
http.HandleFunc("/control/dhcp/interfaces", postInstall(optionalAuth(ensureGET(handleDHCPInterfaces))))
http.HandleFunc("/control/dhcp/set_config", postInstall(optionalAuth(ensurePOST(handleDHCPSetConfig))))
http.HandleFunc("/control/dhcp/find_active_dhcp", postInstall(optionalAuth(ensurePOST(handleDHCPFindActiveServer))))
// TODO: move to registerInstallHandlers()
http.HandleFunc("/control/install/get_default_addresses", ensureGET(handleGetDefaultAddresses))
http.HandleFunc("/control/install/set_all_settings", ensurePOST(handleSetAllSettings))
http.HandleFunc("/control/install/get_default_addresses", preInstall(ensureGET(handleGetDefaultAddresses)))
http.HandleFunc("/control/install/set_all_settings", preInstall(ensurePOST(handleSetAllSettings)))
}

43
dhcp.go
View file

@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"time"
@ -70,50 +69,14 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{}
ifaces, err := net.Interfaces()
ifaces, err := getValidNetInterfaces()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get list of interfaces: %s", err)
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
type responseInterface struct {
Name string `json:"name"`
MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"`
}
for i := range ifaces {
if ifaces[i].Flags&net.FlagLoopback != 0 {
// it's a loopback, skip it
continue
}
if ifaces[i].Flags&net.FlagBroadcast == 0 {
// this interface doesn't support broadcast, skip it
continue
}
if ifaces[i].Flags&net.FlagPointToPoint != 0 {
// this interface is ppp, don't do dhcp over it
continue
}
iface := responseInterface{
Name: ifaces[i].Name,
MTU: ifaces[i].MTU,
HardwareAddr: ifaces[i].HardwareAddr.String(),
}
addrs, errAddrs := ifaces[i].Addrs()
if errAddrs != nil {
httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %v: %s", ifaces[i].Name, errAddrs)
return
}
for _, addr := range addrs {
iface.Addresses = append(iface.Addresses, addr.String())
}
if len(iface.Addresses) == 0 {
// this interface has no addresses, skip it
continue
}
response[ifaces[i].Name] = iface
response[ifaces[i].Name] = ifaces[i]
}
err = json.NewEncoder(w).Encode(response)

View file

@ -3,14 +3,18 @@ package main
import (
"bufio"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"os"
"path"
"path/filepath"
"runtime"
"strings"
"github.com/hmage/golibs/log"
)
// ----------------------------------
@ -84,24 +88,78 @@ type authHandler struct {
}
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if config.AuthName == "" || config.AuthPass == "" {
a.handler.ServeHTTP(w, r)
return
}
user, pass, ok := r.BasicAuth()
if !ok || user != config.AuthName || pass != config.AuthPass {
w.Header().Set("WWW-Authenticate", `Basic realm="dnsfilter"`)
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorised.\n"))
return
}
a.handler.ServeHTTP(w, r)
optionalAuth(a.handler.ServeHTTP)(w, r)
}
func optionalAuthHandler(handler http.Handler) http.Handler {
return &authHandler{handler}
}
// -------------------
// first run / install
// -------------------
func detectFirstRun() bool {
configfile := config.ourConfigFilename
if !filepath.IsAbs(configfile) {
configfile = filepath.Join(config.ourBinaryDir, config.ourConfigFilename)
}
_, err := os.Stat(configfile)
if !os.IsNotExist(err) {
// do nothing, file exists
return false
}
return true
}
// preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !config.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
handler(w, r)
}
}
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where neccessary
type preInstallHandlerStruct struct {
handler http.Handler
}
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
preInstall(p.handler.ServeHTTP)(w, r)
}
// preInstallHandler returns http.Handler interface for preInstall wrapper
func preInstallHandler(handler http.Handler) http.Handler {
return &preInstallHandlerStruct{handler}
}
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if config.firstRun && !strings.HasPrefix(r.URL.Path, "/install.") {
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
return
}
handler(w, r)
}
}
type postInstallHandlerStruct struct {
handler http.Handler
}
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
postInstall(p.handler.ServeHTTP)(w, r)
}
func postInstallHandler(handler http.Handler) http.Handler {
return &postInstallHandlerStruct{handler}
}
// -------------------------------------------------
// helper functions for parsing parameters from body
// -------------------------------------------------
@ -125,6 +183,81 @@ func parseParametersFromBody(r io.Reader) (map[string]string, error) {
return parameters, nil
}
// ------------------
// network interfaces
// ------------------
type netInterface struct {
Name string `json:"name"`
MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"`
}
// getValidNetInterfaces() returns interfaces that are eligible for DNS and/or DHCP
// invalid interface is either a loopback, ppp interface, or the one that doesn't allow broadcasts
func getValidNetInterfaces() ([]netInterface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err)
}
netIfaces := []netInterface{}
for i := range ifaces {
if ifaces[i].Flags&net.FlagLoopback != 0 {
// it's a loopback, skip it
continue
}
if ifaces[i].Flags&net.FlagBroadcast == 0 {
// this interface doesn't support broadcast, skip it
continue
}
if ifaces[i].Flags&net.FlagPointToPoint != 0 {
// this interface is ppp, don't do dhcp over it
continue
}
iface := netInterface{
Name: ifaces[i].Name,
MTU: ifaces[i].MTU,
HardwareAddr: ifaces[i].HardwareAddr.String(),
}
addrs, err := ifaces[i].Addrs()
if err != nil {
return nil, fmt.Errorf("Failed to get addresses for interface %v: %s", ifaces[i].Name, err)
}
for _, addr := range addrs {
iface.Addresses = append(iface.Addresses, addr.String())
}
if len(iface.Addresses) == 0 {
// this interface has no addresses, skip it
continue
}
netIfaces = append(netIfaces, iface)
}
return netIfaces, nil
}
func findIPv4IfaceAddr(ifaces []netInterface) string {
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
ip, _, err := net.ParseCIDR(addr)
if err != nil {
log.Printf("SHOULD NOT HAPPEN: got iface.Addresses element that's not a parseable CIDR: %s", addr)
continue
}
if ip.To4() == nil {
log.Tracef("Ignoring IP that isn't IPv4: %s", ip)
continue
}
return ip.To4().String()
}
}
return ""
}
// ---------------------
// debug logging helpers
// ---------------------