diff --git a/AGHTechDoc.md b/AGHTechDoc.md index 5d812e6b..c5d09f8a 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -138,10 +138,13 @@ Request: { "web":{"port":80,"ip":"192.168.11.33"}, "dns":{"port":53,"ip":"127.0.0.1","autofix":false}, + "set_static_ip": true | false } Server should check whether a port is available only in case it itself isn't already listening on that port. +If `set_static_ip` is `true`, Server attempts to set a static IP for the network interface chosen by `dns.ip` setting. If the operation is successful, `static_ip.static` setting will be `yes`. If it fails, `static_ip.static` setting will be set to `error` and `static_ip.error` will contain the error message. + Server replies on success: 200 OK @@ -149,7 +152,14 @@ Server replies on success: { "web":{"status":""}, "dns":{"status":""}, + "static_ip": { + "static": "yes|no|error", + "ip": "", // set if static=no + "error": "..." // set if static=error } + } + +If `static_ip.static` is `no`, Server has detected that the system uses a dynamic address and it can automatically set a static address if `set_static_ip` in request is `true`. See section `Static IP check/set` for detailed process. Server replies on error: @@ -172,7 +182,11 @@ Request: POST /control/install/check_config { - "dns":{"port":53,"ip":"127.0.0.1","autofix":false} + "dns":{ + "port":53, + "ip":"127.0.0.1", + "autofix":false + } } Check if DNSStubListener is enabled: @@ -499,13 +513,7 @@ which will print: default via 192.168.0.1 proto dhcp metric 100 -#### Phase 2 - -This method only works on Raspbian. - -On Ubuntu DHCP for a network interface can't be disabled via `dhcpcd.conf`. This must be configured in `/etc/netplan/01-netcfg.yaml`. - -Fedora doesn't use `dhcpcd.conf` configuration at all. +#### Phase 2 (Raspbian) Step 1. @@ -526,6 +534,44 @@ If we would set a different IP address, we'd need to replace the IP address for ip addr replace dev eth0 192.168.0.1/24 +#### Phase 2 (Ubuntu) + +`/etc/netplan/01-netcfg.yaml` or `/etc/netplan/01-network-manager-all.yaml` + +This configuration example has a static IP set for `enp0s3` interface: + + network: + version: 2 + renderer: networkd + ethernets: + enp0s3: + dhcp4: no + addresses: [192.168.0.2/24] + gateway: 192.168.0.1 + nameservers: + addresses: [192.168.0.1,8.8.8.8] + +For dynamic configuration `dhcp4: yes` is set. + +Make a backup copy to `/etc/netplan/01-netcfg.yaml.backup`. + +Apply: + + netplan apply + +Restart network: + + systemctl restart networking + +or: + + systemctl restart network-manager + +or: + + systemctl restart system-networkd + + ### Add a static lease Request: diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 961b47c7..ee17dbc1 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -458,9 +458,16 @@ "check_reason": "Reason: {{reason}}", "check_rule": "Rule: {{rule}}", "check_service": "Service name: {{service}}", - "check_not_found": "Doesn't exist in any filter", + "check_not_found": "Not found in your filter lists", "client_confirm_block": "Are you sure you want to block the client \"{{ip}}\"?", "client_confirm_unblock": "Are you sure you want to unblock the client \"{{ip}}\"?", "client_blocked": "Client \"{{ip}}\" successfully blocked", - "client_unblocked": "Client \"{{ip}}\" successfully unblocked" + "client_unblocked": "Client \"{{ip}}\" successfully unblocked", + "static_ip": "Static IP Address", + "static_ip_desc": "AdGuard Home is a server so it needs a static IP address to function properly. Otherwise, at some point, your router may assign a different IP address to this device.", + "set_static_ip": "Set a static IP address", + "install_static_ok": "Good news! The static IP address is already configured", + "install_static_error": "AdGuard Home cannot configure it automatically for this network interface. Please look for an instruction on how to do this manually.", + "install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}. Do you want to use it as your static address?", + "confirm_static_ip": "AdGuard Home will configure {{ip}} to be your static IP address. Do you want to proceed?" } diff --git a/client/src/helpers/form.js b/client/src/helpers/form.js index 0c617ff5..7aed918c 100644 --- a/client/src/helpers/form.js +++ b/client/src/helpers/form.js @@ -240,6 +240,13 @@ export const port = (value) => { return undefined; }; +export const validInstallPort = (value) => { + if (value < 1 || value > 65535) { + return form_error_port; + } + return undefined; +}; + export const portTLS = (value) => { if (value === 0) { return undefined; diff --git a/client/src/install/Setup/Settings.js b/client/src/install/Setup/Settings.js index 246206ba..876aa05b 100644 --- a/client/src/install/Setup/Settings.js +++ b/client/src/install/Setup/Settings.js @@ -7,26 +7,17 @@ import flow from 'lodash/flow'; import Controls from './Controls'; import AddressList from './AddressList'; + import { getInterfaceIp } from '../../helpers/helpers'; import { ALL_INTERFACES_IP } from '../../helpers/constants'; -import { renderInputField } from '../../helpers/form'; +import { renderInputField, required, validInstallPort, toNumber } from '../../helpers/form'; -const required = (value) => { - if (value || value === 0) { - return false; - } - return form_error_required; +const STATIC_STATUS = { + ENABLED: 'yes', + DISABLED: 'no', + ERROR: 'error', }; -const port = (value) => { - if (value < 1 || value > 65535) { - return form_error_port; - } - return false; -}; - -const toNumber = value => value && parseInt(value, 10); - const renderInterfaces = (interfaces => ( Object.keys(interfaces).map((item) => { const option = interfaces[item]; @@ -79,11 +70,91 @@ class Settings extends Component { }); } + getStaticIpMessage = (staticIp) => { + const { static: status, ip } = staticIp; + + if (!status) { + return ''; + } + + return ( + + {status === STATIC_STATUS.DISABLED && ( + +
+ text]}> + install_static_configure + +
+ +
+ )} + {status === STATIC_STATUS.ERROR && ( +
+ install_static_error +
+ )} + {status === STATIC_STATUS.ENABLED && ( +
+ + install_static_ok + +
+ )} +
+ ); + }; + + handleAutofix = (type) => { + const { + webIp, + webPort, + dnsIp, + dnsPort, + handleFix, + } = this.props; + + const web = { ip: webIp, port: webPort, autofix: false }; + const dns = { ip: dnsIp, port: dnsPort, autofix: false }; + const set_static_ip = false; + + if (type === 'web') { + web.autofix = true; + } else { + dns.autofix = true; + } + + handleFix(web, dns, set_static_ip); + }; + + handleStaticIp = (ip) => { + const { + webIp, + webPort, + dnsIp, + dnsPort, + handleFix, + } = this.props; + + const web = { ip: webIp, port: webPort, autofix: false }; + const dns = { ip: dnsIp, port: dnsPort, autofix: false }; + const set_static_ip = true; + + if (window.confirm(this.props.t('confirm_static_ip', { ip }))) { + handleFix(web, dns, set_static_ip); + } + }; + render() { const { handleSubmit, handleChange, - handleAutofix, webIp, webPort, dnsIp, @@ -100,6 +171,7 @@ class Settings extends Component { status: dnsStatus, can_autofix: isDnsFixAvailable, } = config.dns; + const { staticIp } = config; return (
@@ -137,7 +209,7 @@ class Settings extends Component { type="number" className="form-control" placeholder="80" - validate={[port, required]} + validate={[validInstallPort, required]} normalize={toNumber} onChange={handleChange} /> @@ -151,11 +223,12 @@ class Settings extends Component { } +
} @@ -171,6 +244,7 @@ class Settings extends Component { +
install_settings_dns @@ -205,7 +279,7 @@ class Settings extends Component { type="number" className="form-control" placeholder="80" - validate={[port, required]} + validate={[validInstallPort, required]} normalize={toNumber} onChange={handleChange} /> @@ -220,7 +294,7 @@ class Settings extends Component { @@ -237,6 +311,7 @@ class Settings extends Component { autofix_warning_result

+
}
@@ -253,6 +328,19 @@ class Settings extends Component { + +
+
+ static_ip +
+ +
+ static_ip_desc +
+ + {this.getStaticIpMessage(staticIp)} +
+ ); @@ -262,7 +350,7 @@ class Settings extends Component { Settings.propTypes = { handleSubmit: PropTypes.func.isRequired, handleChange: PropTypes.func, - handleAutofix: PropTypes.func, + handleFix: PropTypes.func.isRequired, validateForm: PropTypes.func, webIp: PropTypes.string.isRequired, dnsIp: PropTypes.string.isRequired, @@ -278,6 +366,7 @@ Settings.propTypes = { interfaces: PropTypes.object.isRequired, invalid: PropTypes.bool.isRequired, initialValues: PropTypes.object, + t: PropTypes.func.isRequired, }; const selector = formValueSelector('install'); diff --git a/client/src/install/Setup/Setup.css b/client/src/install/Setup/Setup.css index 11ee1430..aac7ea0e 100644 --- a/client/src/install/Setup/Setup.css +++ b/client/src/install/Setup/Setup.css @@ -119,3 +119,8 @@ .setup__error { margin: -5px 0 5px; } + +.divider--small { + margin-top: 1rem; + margin-bottom: 1rem; +} diff --git a/client/src/install/Setup/index.js b/client/src/install/Setup/index.js index ca91cd8f..82d8f84b 100644 --- a/client/src/install/Setup/index.js +++ b/client/src/install/Setup/index.js @@ -33,31 +33,19 @@ class Setup extends Component { } handleFormSubmit = (values) => { - this.props.setAllSettings(values); + const { staticIp, ...config } = values; + this.props.setAllSettings(config); }; handleFormChange = debounce((values) => { - if (values && values.web.port && values.dns.port) { - this.props.checkConfig(values); + const { web, dns } = values; + if (values && web.port && dns.port) { + this.props.checkConfig({ web, dns, set_static_ip: false }); } }, DEBOUNCE_TIMEOUT); - handleAutofix = (type, ip, port) => { - const data = { - ip, - port, - autofix: true, - }; - - if (type === 'web') { - this.props.checkConfig({ - web: { ...data }, - }); - } else { - this.props.checkConfig({ - dns: { ...data }, - }); - } + handleFix = (web, dns, set_static_ip) => { + this.props.checkConfig({ web, dns, set_static_ip }); }; openDashboard = (ip, port) => { @@ -95,7 +83,7 @@ class Setup extends Component { onSubmit={this.nextStep} onChange={this.handleFormChange} validateForm={this.handleFormChange} - handleAutofix={this.handleAutofix} + handleFix={this.handleFix} /> ); case 3: @@ -117,6 +105,7 @@ class Setup extends Component { step, web, dns, + staticIp, interfaces, } = this.props.install; @@ -128,7 +117,7 @@ class Setup extends Component {
logo - {this.renderPage(step, { web, dns }, interfaces)} + {this.renderPage(step, { web, dns, staticIp }, interfaces)}
diff --git a/client/src/reducers/install.js b/client/src/reducers/install.js index 3709b0ec..b3f95dfb 100644 --- a/client/src/reducers/install.js +++ b/client/src/reducers/install.js @@ -32,9 +32,10 @@ const install = handleActions({ [actions.checkConfigSuccess]: (state, { payload }) => { const web = { ...state.web, ...payload.web }; const dns = { ...state.dns, ...payload.dns }; + const staticIp = { ...state.staticIp, ...payload.static_ip }; const newState = { - ...state, web, dns, processingCheck: false, + ...state, web, dns, staticIp, processingCheck: false, }; return newState; }, @@ -55,6 +56,11 @@ const install = handleActions({ status: '', can_autofix: false, }, + staticIp: { + static: '', + ip: '', + error: '', + }, interfaces: {}, }); diff --git a/dhcpd/dhcp_http.go b/dhcpd/dhcp_http.go index e1b3d4fb..9105d76b 100644 --- a/dhcpd/dhcp_http.go +++ b/dhcpd/dhcp_http.go @@ -2,18 +2,16 @@ package dhcpd import ( "encoding/json" - "errors" "fmt" "io/ioutil" "net" "net/http" "os" - "os/exec" - "runtime" "strings" "time" - "github.com/AdguardTeam/golibs/file" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" ) @@ -97,10 +95,9 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { s.conf.ConfigModified() if newconfig.Enabled { - - staticIP, err := hasStaticIP(newconfig.InterfaceName) + staticIP, err := HasStaticIP(newconfig.InterfaceName) if !staticIP && err == nil { - err = setStaticIP(newconfig.InterfaceName) + err = SetStaticIP(newconfig.InterfaceName) if err != nil { httpError(r, w, http.StatusInternalServerError, "Failed to configure static IP: %s", err) return @@ -115,7 +112,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { } } -type netInterface struct { +type netInterfaceJSON struct { Name string `json:"name"` MTU int `json:"mtu"` HardwareAddr string `json:"hardware_address"` @@ -123,33 +120,10 @@ type netInterface struct { Flags string `json:"flags"` } -// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP -// invalid interface is a ppp interface or the one that doesn't allow broadcasts -func getValidNetInterfaces() ([]net.Interface, error) { - ifaces, err := net.Interfaces() - if err != nil { - return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) - } - - netIfaces := []net.Interface{} - - for i := range ifaces { - if ifaces[i].Flags&net.FlagPointToPoint != 0 { - // this interface is ppp, we're not interested in this one - continue - } - - iface := ifaces[i] - netIfaces = append(netIfaces, iface) - } - - return netIfaces, nil -} - func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { response := map[string]interface{}{} - ifaces, err := getValidNetInterfaces() + ifaces, err := util.GetValidNetInterfaces() if err != nil { httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) return @@ -170,7 +144,7 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { return } - jsonIface := netInterface{ + jsonIface := netInterfaceJSON{ Name: iface.Name, MTU: iface.MTU, HardwareAddr: iface.HardwareAddr.String(), @@ -240,14 +214,14 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque othSrv["found"] = foundVal staticIP := map[string]interface{}{} - isStaticIP, err := hasStaticIP(interfaceName) + isStaticIP, err := HasStaticIP(interfaceName) staticIPStatus := "yes" if err != nil { staticIPStatus = "error" staticIP["error"] = err.Error() } else if !isStaticIP { staticIPStatus = "no" - staticIP["ip"] = getFullIP(interfaceName) + staticIP["ip"] = util.GetSubnet(interfaceName) } staticIP["static"] = staticIPStatus @@ -263,137 +237,6 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque } } -// Check if network interface has a static IP configured -func hasStaticIP(ifaceName string) (bool, error) { - if runtime.GOOS == "windows" { - return false, errors.New("Can't detect static IP: not supported on Windows") - } - - body, err := ioutil.ReadFile("/etc/dhcpcd.conf") - if err != nil { - return false, err - } - lines := strings.Split(string(body), "\n") - nameLine := fmt.Sprintf("interface %s", ifaceName) - withinInterfaceCtx := false - - for _, line := range lines { - line = strings.TrimSpace(line) - - if withinInterfaceCtx && len(line) == 0 { - // an empty line resets our state - withinInterfaceCtx = false - } - - if len(line) == 0 || line[0] == '#' { - continue - } - line = strings.TrimSpace(line) - - if !withinInterfaceCtx { - if line == nameLine { - // we found our interface - withinInterfaceCtx = true - } - - } else { - if strings.HasPrefix(line, "interface ") { - // we found another interface - reset our state - withinInterfaceCtx = false - continue - } - if strings.HasPrefix(line, "static ip_address=") { - return true, nil - } - } - } - - return false, nil -} - -// Get IP address with netmask -func getFullIP(ifaceName string) string { - cmd := exec.Command("ip", "-oneline", "-family", "inet", "address", "show", ifaceName) - log.Tracef("executing %s %v", cmd.Path, cmd.Args) - d, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - return "" - } - - fields := strings.Fields(string(d)) - if len(fields) < 4 { - return "" - } - _, _, err = net.ParseCIDR(fields[3]) - if err != nil { - return "" - } - - return fields[3] -} - -// Get gateway IP address -func getGatewayIP(ifaceName string) string { - cmd := exec.Command("ip", "route", "show", "dev", ifaceName) - log.Tracef("executing %s %v", cmd.Path, cmd.Args) - d, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - return "" - } - - fields := strings.Fields(string(d)) - if len(fields) < 3 || fields[0] != "default" { - return "" - } - - ip := net.ParseIP(fields[2]) - if ip == nil { - return "" - } - - return fields[2] -} - -// Set a static IP for network interface -func setStaticIP(ifaceName string) error { - ip := getFullIP(ifaceName) - if len(ip) == 0 { - return errors.New("Can't get IP address") - } - - body, err := ioutil.ReadFile("/etc/dhcpcd.conf") - if err != nil { - return err - } - - ip4, _, err := net.ParseCIDR(ip) - if err != nil { - return err - } - - add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", - ifaceName, ip) - body = append(body, []byte(add)...) - - gatewayIP := getGatewayIP(ifaceName) - if len(gatewayIP) != 0 { - add = fmt.Sprintf("static routers=%s\n", - gatewayIP) - body = append(body, []byte(add)...) - } - - add = fmt.Sprintf("static domain_name_servers=%s\n\n", - ip4) - body = append(body, []byte(add)...) - - err = file.SafeWrite("/etc/dhcpcd.conf", body) - if err != nil { - return err - } - - return nil -} - func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) { lj := staticLeaseJSON{} diff --git a/dhcpd/network_utils.go b/dhcpd/network_utils.go new file mode 100644 index 00000000..79c13285 --- /dev/null +++ b/dhcpd/network_utils.go @@ -0,0 +1,312 @@ +package dhcpd + +import ( + "errors" + "fmt" + "io/ioutil" + "net" + "os/exec" + "regexp" + "runtime" + "strings" + + "github.com/AdguardTeam/AdGuardHome/util" + + "github.com/AdguardTeam/golibs/file" + + "github.com/AdguardTeam/golibs/log" +) + +// Check if network interface has a static IP configured +// Supports: Raspbian. +func HasStaticIP(ifaceName string) (bool, error) { + if runtime.GOOS == "linux" { + body, err := ioutil.ReadFile("/etc/dhcpcd.conf") + if err != nil { + return false, err + } + + return hasStaticIPDhcpcdConf(string(body), ifaceName), nil + } + + if runtime.GOOS == "darwin" { + return hasStaticIPDarwin(ifaceName) + } + + return false, fmt.Errorf("Cannot check if IP is static: not supported on %s", runtime.GOOS) +} + +// Set a static IP for the specified network interface +func SetStaticIP(ifaceName string) error { + if runtime.GOOS == "linux" { + return setStaticIPDhcpdConf(ifaceName) + } + + if runtime.GOOS == "darwin" { + return setStaticIPDarwin(ifaceName) + } + + return fmt.Errorf("Cannot set static IP on %s", runtime.GOOS) +} + +// for dhcpcd.conf +func hasStaticIPDhcpcdConf(dhcpConf, ifaceName string) bool { + lines := strings.Split(dhcpConf, "\n") + nameLine := fmt.Sprintf("interface %s", ifaceName) + withinInterfaceCtx := false + + for _, line := range lines { + line = strings.TrimSpace(line) + + if withinInterfaceCtx && len(line) == 0 { + // an empty line resets our state + withinInterfaceCtx = false + } + + if len(line) == 0 || line[0] == '#' { + continue + } + line = strings.TrimSpace(line) + + if !withinInterfaceCtx { + if line == nameLine { + // we found our interface + withinInterfaceCtx = true + } + + } else { + if strings.HasPrefix(line, "interface ") { + // we found another interface - reset our state + withinInterfaceCtx = false + continue + } + if strings.HasPrefix(line, "static ip_address=") { + return true + } + } + } + return false +} + +// Get gateway IP address +func getGatewayIP(ifaceName string) string { + cmd := exec.Command("ip", "route", "show", "dev", ifaceName) + log.Tracef("executing %s %v", cmd.Path, cmd.Args) + d, err := cmd.Output() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return "" + } + + fields := strings.Fields(string(d)) + if len(fields) < 3 || fields[0] != "default" { + return "" + } + + ip := net.ParseIP(fields[2]) + if ip == nil { + return "" + } + + return fields[2] +} + +// setStaticIPDhcpdConf - updates /etc/dhcpd.conf and sets the current IP address to be static +func setStaticIPDhcpdConf(ifaceName string) error { + ip := util.GetSubnet(ifaceName) + if len(ip) == 0 { + return errors.New("Can't get IP address") + } + + ip4, _, err := net.ParseCIDR(ip) + if err != nil { + return err + } + gatewayIP := getGatewayIP(ifaceName) + add := updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String()) + + body, err := ioutil.ReadFile("/etc/dhcpcd.conf") + if err != nil { + return err + } + + body = append(body, []byte(add)...) + err = file.SafeWrite("/etc/dhcpcd.conf", body) + if err != nil { + return err + } + + return nil +} + +// updates dhcpd.conf content -- sets static IP address there +// for dhcpcd.conf +func updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { + var body []byte + + add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", + ifaceName, ip) + body = append(body, []byte(add)...) + + if len(gatewayIP) != 0 { + add = fmt.Sprintf("static routers=%s\n", + gatewayIP) + body = append(body, []byte(add)...) + } + + add = fmt.Sprintf("static domain_name_servers=%s\n\n", + dnsIP) + body = append(body, []byte(add)...) + + return string(body) +} + +// Check if network interface has a static IP configured +// Supports: MacOS. +func hasStaticIPDarwin(ifaceName string) (bool, error) { + portInfo, err := getCurrentHardwarePortInfo(ifaceName) + if err != nil { + return false, err + } + + return portInfo.static, nil +} + +// setStaticIPDarwin - uses networksetup util to set the current IP address to be static +// Additionally it configures the current DNS servers as well +func setStaticIPDarwin(ifaceName string) error { + portInfo, err := getCurrentHardwarePortInfo(ifaceName) + if err != nil { + return err + } + + if portInfo.static { + return errors.New("IP address is already static") + } + + dnsAddrs, err := getEtcResolvConfServers() + if err != nil { + return err + } + + args := make([]string, 0) + args = append(args, "-setdnsservers", portInfo.name) + args = append(args, dnsAddrs...) + + // Setting DNS servers is necessary when configuring a static IP + code, _, err := util.RunCommand("networksetup", args...) + if err != nil { + return err + } + if code != 0 { + return fmt.Errorf("Failed to set DNS servers, code=%d", code) + } + + // Actually configures hardware port to have static IP + code, _, err = util.RunCommand("networksetup", "-setmanual", + portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP) + if err != nil { + return err + } + if code != 0 { + return fmt.Errorf("Failed to set DNS servers, code=%d", code) + } + + return nil +} + +// getCurrentHardwarePortInfo gets information the specified network interface +func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { + // First of all we should find hardware port name + m := getNetworkSetupHardwareReports() + hardwarePort, ok := m[ifaceName] + if !ok { + return hardwarePortInfo{}, fmt.Errorf("Could not find hardware port for %s", ifaceName) + } + + return getHardwarePortInfo(hardwarePort) +} + +// getNetworkSetupHardwareReports parses the output of the `networksetup -listallhardwareports` command +// it returns a map where the key is the interface name, and the value is the "hardware port" +// returns nil if it fails to parse the output +func getNetworkSetupHardwareReports() map[string]string { + _, out, err := util.RunCommand("networksetup", "-listallhardwareports") + if err != nil { + return nil + } + + re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n") + if err != nil { + return nil + } + + m := make(map[string]string, 0) + + matches := re.FindAllStringSubmatch(out, -1) + for i := range matches { + port := matches[i][1] + device := matches[i][2] + m[device] = port + } + + return m +} + +// hardwarePortInfo - information obtained using MacOS networksetup +// about the current state of the internet connection +type hardwarePortInfo struct { + name string + ip string + subnet string + gatewayIP string + static bool +} + +func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) { + h := hardwarePortInfo{} + + _, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort) + if err != nil { + return h, err + } + + re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n") + + match := re.FindStringSubmatch(out) + if len(match) == 0 { + return h, errors.New("Could not find hardware port info") + } + + h.name = hardwarePort + h.ip = match[1] + h.subnet = match[2] + h.gatewayIP = match[3] + + if strings.Index(out, "Manual Configuration") == 0 { + h.static = true + } + + return h, nil +} + +// Gets a list of nameservers currently configured in the /etc/resolv.conf +func getEtcResolvConfServers() ([]string, error) { + body, err := ioutil.ReadFile("/etc/resolv.conf") + if err != nil { + return nil, err + } + + re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)") + + matches := re.FindAllStringSubmatch(string(body), -1) + if len(matches) == 0 { + return nil, errors.New("Found no DNS servers in /etc/resolv.conf") + } + + addrs := make([]string, 0) + for i := range matches { + addrs = append(addrs, matches[i][1]) + } + + return addrs, nil +} diff --git a/dhcpd/network_utils_test.go b/dhcpd/network_utils_test.go new file mode 100644 index 00000000..2957a411 --- /dev/null +++ b/dhcpd/network_utils_test.go @@ -0,0 +1,61 @@ +package dhcpd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHasStaticIPDhcpcdConf(t *testing.T) { + dhcpdConf := `#comment +# comment + +interface eth0 +static ip_address=192.168.0.1/24 + +# interface wlan0 +static ip_address=192.168.1.1/24 + +# comment +` + assert.True(t, !hasStaticIPDhcpcdConf(dhcpdConf, "wlan0")) + + dhcpdConf = `#comment +# comment + +interface eth0 +static ip_address=192.168.0.1/24 + +# interface wlan0 +static ip_address=192.168.1.1/24 + +# comment + +interface wlan0 +# comment +static ip_address=192.168.2.1/24 +` + assert.True(t, hasStaticIPDhcpcdConf(dhcpdConf, "wlan0")) +} + +func TestSetStaticIPDhcpcdConf(t *testing.T) { + dhcpcdConf := ` +interface wlan0 +static ip_address=192.168.0.2/24 +static routers=192.168.0.1 +static domain_name_servers=192.168.0.2 + +` + s := updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2") + assert.Equal(t, dhcpcdConf, s) + + // without gateway + dhcpcdConf = ` +interface wlan0 +static ip_address=192.168.0.2/24 +static domain_name_servers=192.168.0.2 + +` + s = updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2") + assert.Equal(t, dhcpcdConf, s) +} diff --git a/home/auth.go b/home/auth.go index 9afe2c87..e6e4642a 100644 --- a/home/auth.go +++ b/home/auth.go @@ -152,7 +152,7 @@ func (a *Auth) addSession(data []byte, s *session) { a.sessions[name] = s a.lock.Unlock() if a.storeSession(data, s) { - log.Info("Auth: created session %s: expire=%d", name, s.expire) + log.Debug("Auth: created session %s: expire=%d", name, s.expire) } } @@ -307,7 +307,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { return } - cookie := config.auth.httpCookie(req) + cookie := Context.auth.httpCookie(req) if len(cookie) == 0 { log.Info("Auth: invalid user name or password: name='%s'", req.Name) time.Sleep(1 * time.Second) @@ -328,7 +328,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) { cookie := r.Header.Get("Cookie") sess := parseCookie(cookie) - config.auth.RemoveSession(sess) + Context.auth.RemoveSession(sess) w.Header().Set("Location", "/login.html") @@ -365,10 +365,10 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re if r.URL.Path == "/login.html" { // redirect to dashboard if already authenticated - authRequired := config.auth != nil && config.auth.AuthRequired() + authRequired := Context.auth != nil && Context.auth.AuthRequired() cookie, err := r.Cookie(sessionCookieName) if authRequired && err == nil { - r := config.auth.CheckSession(cookie.Value) + r := Context.auth.CheckSession(cookie.Value) if r == 0 { w.Header().Set("Location", "/") w.WriteHeader(http.StatusFound) @@ -383,12 +383,12 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re strings.HasPrefix(r.URL.Path, "/__locales/") { // process as usual - } else if config.auth != nil && config.auth.AuthRequired() { + } else if Context.auth != nil && Context.auth.AuthRequired() { // redirect to login page if not authenticated ok := false cookie, err := r.Cookie(sessionCookieName) if err == nil { - r := config.auth.CheckSession(cookie.Value) + r := Context.auth.CheckSession(cookie.Value) if r == 0 { ok = true } else if r < 0 { @@ -398,7 +398,7 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re // there's no Cookie, check Basic authentication user, pass, ok2 := r.BasicAuth() if ok2 { - u := config.auth.UserFind(user, pass) + u := Context.auth.UserFind(user, pass) if len(u.Name) != 0 { ok = true } else { @@ -474,7 +474,7 @@ func (a *Auth) GetCurrentUser(r *http.Request) User { // there's no Cookie, check Basic authentication user, pass, ok := r.BasicAuth() if ok { - u := config.auth.UserFind(user, pass) + u := Context.auth.UserFind(user, pass) return u } return User{} diff --git a/home/auth_test.go b/home/auth_test.go index 19cd5001..38f826ec 100644 --- a/home/auth_test.go +++ b/home/auth_test.go @@ -100,7 +100,7 @@ func TestAuthHTTP(t *testing.T) { users := []User{ User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"}, } - config.auth = InitAuth(fn, users, 60) + Context.auth = InitAuth(fn, users, 60) handlerCalled := false handler := func(w http.ResponseWriter, r *http.Request) { @@ -129,7 +129,7 @@ func TestAuthHTTP(t *testing.T) { assert.True(t, handlerCalled) // perform login - cookie := config.auth.httpCookie(loginJSON{Name: "name", Password: "password"}) + cookie := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}) assert.True(t, cookie != "") // get / @@ -173,5 +173,5 @@ func TestAuthHTTP(t *testing.T) { assert.True(t, handlerCalled) r.Header.Del("Cookie") - config.auth.Close() + Context.auth.Close() } diff --git a/home/config.go b/home/config.go index 693c4a40..ab875fdc 100644 --- a/home/config.go +++ b/home/config.go @@ -44,19 +44,6 @@ type configuration struct { // It's reset after config is parsed fileData []byte - ourConfigFilename string // Config filename (can be overridden via the command line arguments) - ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else - firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html - pidFileName string // PID file name. Empty if no PID file was created. - // runningAsService flag is set to true when options are passed from the service runner - runningAsService bool - disableUpdate bool // If set, don't check for updates - appSignalChannel chan os.Signal - controlLock sync.Mutex - transport *http.Transport - client *http.Client - auth *Auth // HTTP authentication module - // cached version.json to avoid hammering github.io for each page reload versionCheckJSON []byte versionCheckLastTime time.Time @@ -152,9 +139,8 @@ type tlsConfig struct { // initialize to default values, will be changed later when reading config or parsing command line var config = configuration{ - ourConfigFilename: "AdGuardHome.yaml", - BindPort: 3000, - BindHost: "0.0.0.0", + BindPort: 3000, + BindHost: "0.0.0.0", DNS: dnsConfig{ BindHost: "0.0.0.0", Port: 53, @@ -185,14 +171,6 @@ var config = configuration{ // initConfig initializes default configuration for the current OS&ARCH func initConfig() { - config.transport = &http.Transport{ - DialContext: customDialContext, - } - config.client = &http.Client{ - Timeout: time.Minute * 5, - Transport: config.transport, - } - config.WebSessionTTLHours = 30 * 24 config.DNS.QueryLogEnabled = true @@ -209,24 +187,19 @@ func initConfig() { // getConfigFilename returns path to the current config file func (c *configuration) getConfigFilename() string { - configFile, err := filepath.EvalSymlinks(config.ourConfigFilename) + configFile, err := filepath.EvalSymlinks(Context.configFilename) if err != nil { if !os.IsNotExist(err) { log.Error("unexpected error while config file path evaluation: %s", err) } - configFile = config.ourConfigFilename + configFile = Context.configFilename } if !filepath.IsAbs(configFile) { - configFile = filepath.Join(config.ourWorkingDir, configFile) + configFile = filepath.Join(Context.workDir, configFile) } return configFile } -// getDataDir returns path to the directory where we store databases and filters -func (c *configuration) getDataDir() string { - return filepath.Join(c.ourWorkingDir, dataDir) -} - // getLogSettings reads logging settings from the config file. // we do it in a separate method in order to configure logger before the actual configuration is parsed and applied. func getLogSettings() logSettings { @@ -292,8 +265,8 @@ func (c *configuration) write() error { Context.clients.WriteDiskConfig(&config.Clients) - if config.auth != nil { - config.Users = config.auth.GetUsers() + if Context.auth != nil { + config.Users = Context.auth.GetUsers() } if Context.stats != nil { diff --git a/home/control.go b/home/control.go index 87247190..031f7a36 100644 --- a/home/control.go +++ b/home/control.go @@ -3,7 +3,13 @@ package home import ( "encoding/json" "fmt" + "net" "net/http" + "net/url" + "strconv" + "strings" + + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/golibs/log" @@ -54,8 +60,7 @@ func getDNSAddresses() []string { dnsAddresses := []string{} if config.DNS.BindHost == "0.0.0.0" { - - ifaces, e := getValidNetInterfacesForWeb() + ifaces, e := util.GetValidNetInterfacesForWeb() if e != nil { log.Error("Couldn't get network interfaces: %v", e) return []string{} @@ -66,7 +71,6 @@ func getDNSAddresses() []string { addDNSAddress(&dnsAddresses, addr) } } - } else { addDNSAddress(&dnsAddresses, config.DNS.BindHost) } @@ -129,7 +133,7 @@ type profileJSON struct { func handleGetProfile(w http.ResponseWriter, r *http.Request) { pj := profileJSON{} - u := config.auth.GetCurrentUser(r) + u := Context.auth.GetCurrentUser(r) pj.Name = u.Name data, err := json.Marshal(pj) @@ -180,3 +184,118 @@ func registerControlHandlers() { func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) { http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler))))) } + +// ---------------------------------- +// helper functions for HTTP handlers +// ---------------------------------- +func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + log.Debug("%s %v", r.Method, r.URL) + + if r.Method != method { + http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed) + return + } + + if method == "POST" || method == "PUT" || method == "DELETE" { + Context.controlLock.Lock() + defer Context.controlLock.Unlock() + } + + handler(w, r) + } +} + +func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return ensure("POST", handler) +} + +func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return ensure("GET", handler) +} + +// Bridge between http.Handler object and Go function +type httpHandler struct { + handler func(http.ResponseWriter, *http.Request) +} + +func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.handler(w, r) +} + +func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler { + h := httpHandler{} + h.handler = ensure(method, handler) + return &h +} + +// preInstall lets the handler run only if firstRun is true, no redirects +func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if !Context.firstRun { + // if it's not first run, don't let users access it (for example /install.html when configuration is done) + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return + } + handler(w, r) + } +} + +// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary +type preInstallHandlerStruct struct { + handler http.Handler +} + +func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { + preInstall(p.handler.ServeHTTP)(w, r) +} + +// preInstallHandler returns http.Handler interface for preInstall wrapper +func preInstallHandler(handler http.Handler) http.Handler { + return &preInstallHandlerStruct{handler} +} + +// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise +// it also enforces HTTPS if it is enabled and configured +func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if Context.firstRun && + !strings.HasPrefix(r.URL.Path, "/install.") && + r.URL.Path != "/favicon.png" { + http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable + return + } + // enforce https? + if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil { + // yes, and we want host from host:port + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + // no port in host + host = r.Host + } + // construct new URL to redirect to + newURL := url.URL{ + Scheme: "https", + Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)), + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + } + http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect) + return + } + w.Header().Set("Access-Control-Allow-Origin", "*") + handler(w, r) + } +} + +type postInstallHandlerStruct struct { + handler http.Handler +} + +func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { + postInstall(p.handler.ServeHTTP)(w, r) +} + +func postInstallHandler(handler http.Handler) http.Handler { + return &postInstallHandlerStruct{handler} +} diff --git a/home/control_filtering.go b/home/control_filtering.go index 77c6cafa..8846b980 100644 --- a/home/control_filtering.go +++ b/home/control_filtering.go @@ -210,9 +210,9 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { } func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { - config.controlLock.Unlock() + Context.controlLock.Unlock() nUpdated, err := refreshFilters() - config.controlLock.Lock() + Context.controlLock.Lock() if err != nil { httpError(w, http.StatusInternalServerError, "%s", err) return diff --git a/home/control_install.go b/home/control_install.go index 5311c091..17ffafa9 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -13,6 +13,10 @@ import ( "runtime" "strconv" + "github.com/AdguardTeam/AdGuardHome/util" + + "github.com/AdguardTeam/AdGuardHome/dhcpd" + "github.com/AdguardTeam/golibs/log" ) @@ -22,13 +26,21 @@ type firstRunData struct { Interfaces map[string]interface{} `json:"interfaces"` } +type netInterfaceJSON struct { + Name string `json:"name"` + MTU int `json:"mtu"` + HardwareAddr string `json:"hardware_address"` + Addresses []string `json:"ip_addresses"` + Flags string `json:"flags"` +} + // Get initial installation settings func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { data := firstRunData{} data.WebPort = 80 data.DNSPort = 53 - ifaces, err := getValidNetInterfacesForWeb() + ifaces, err := util.GetValidNetInterfacesForWeb() if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) return @@ -36,7 +48,14 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { data.Interfaces = make(map[string]interface{}) for _, iface := range ifaces { - data.Interfaces[iface.Name] = iface + ifaceJSON := netInterfaceJSON{ + Name: iface.Name, + MTU: iface.MTU, + HardwareAddr: iface.HardwareAddr, + Addresses: iface.Addresses, + Flags: iface.Flags, + } + data.Interfaces[iface.Name] = ifaceJSON } w.Header().Set("Content-Type", "application/json") @@ -53,17 +72,24 @@ type checkConfigReqEnt struct { Autofix bool `json:"autofix"` } type checkConfigReq struct { - Web checkConfigReqEnt `json:"web"` - DNS checkConfigReqEnt `json:"dns"` + Web checkConfigReqEnt `json:"web"` + DNS checkConfigReqEnt `json:"dns"` + SetStaticIP bool `json:"set_static_ip"` } type checkConfigRespEnt struct { Status string `json:"status"` CanAutofix bool `json:"can_autofix"` } +type staticIPJSON struct { + Static string `json:"static"` + IP string `json:"ip"` + Error string `json:"error"` +} type checkConfigResp struct { - Web checkConfigRespEnt `json:"web"` - DNS checkConfigRespEnt `json:"dns"` + Web checkConfigRespEnt `json:"web"` + DNS checkConfigRespEnt `json:"dns"` + StaticIP staticIPJSON `json:"static_ip"` } // Check if ports are available, respond with results @@ -77,16 +103,16 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { } if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort { - err = checkPortAvailable(reqData.Web.IP, reqData.Web.Port) + err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port) if err != nil { respData.Web.Status = fmt.Sprintf("%v", err) } } if reqData.DNS.Port != 0 { - err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) - if errorIsAddrInUse(err) { + if util.ErrorIsAddrInUse(err) { canAutofix := checkDNSStubListener() if canAutofix && reqData.DNS.Autofix { @@ -95,7 +121,7 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { log.Error("Couldn't disable DNSStubListener: %s", err) } - err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) canAutofix = false } @@ -103,11 +129,13 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { } if err == nil { - err = checkPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port) } if err != nil { respData.DNS.Status = fmt.Sprintf("%v", err) + } else { + respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP) } } @@ -119,6 +147,46 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { } } +// handleStaticIP - handles static IP request +// It either checks if we have a static IP +// Or if set=true, it tries to set it +func handleStaticIP(ip string, set bool) staticIPJSON { + resp := staticIPJSON{} + + interfaceName := util.GetInterfaceByIP(ip) + resp.Static = "no" + + if len(interfaceName) == 0 { + resp.Static = "error" + resp.Error = fmt.Sprintf("Couldn't find network interface by IP %s", ip) + return resp + } + + if set { + // Try to set static IP for the specified interface + err := dhcpd.SetStaticIP(interfaceName) + if err != nil { + resp.Static = "error" + resp.Error = err.Error() + return resp + } + } + + // Fallthrough here even if we set static IP + // Check if we have a static IP and return the details + isStaticIP, err := dhcpd.HasStaticIP(interfaceName) + if err != nil { + resp.Static = "error" + resp.Error = err.Error() + } else { + if isStaticIP { + resp.Static = "yes" + } + resp.IP = util.GetSubnet(interfaceName) + } + return resp +} + // Check if DNSStubListener is active func checkDNSStubListener() bool { if runtime.GOOS != "linux" { @@ -129,7 +197,7 @@ func checkDNSStubListener() bool { log.Tracef("executing %s %v", cmd.Path, cmd.Args) _, err := cmd.Output() if err != nil || cmd.ProcessState.ExitCode() != 0 { - log.Error("command %s has failed: %v code:%d", + log.Info("command %s has failed: %v code:%d", cmd.Path, err, cmd.ProcessState.ExitCode()) return false } @@ -138,7 +206,7 @@ func checkDNSStubListener() bool { log.Tracef("executing %s %v", cmd.Path, cmd.Args) _, err = cmd.Output() if err != nil || cmd.ProcessState.ExitCode() != 0 { - log.Error("command %s has failed: %v code:%d", + log.Info("command %s has failed: %v code:%d", cmd.Path, err, cmd.ProcessState.ExitCode()) return false } @@ -228,7 +296,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { // validate that hosts and ports are bindable if restartHTTP { - err = checkPortAvailable(newSettings.Web.IP, newSettings.Web.Port) + err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port) if err != nil { httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err) @@ -236,13 +304,13 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { } } - err = checkPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) + err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return } - err = checkPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) + err = util.CheckPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return @@ -251,7 +319,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { var curConfig configuration copyInstallSettings(&curConfig, &config) - config.firstRun = false + Context.firstRun = false config.BindHost = newSettings.Web.IP config.BindPort = newSettings.Web.Port config.DNS.BindHost = newSettings.DNS.IP @@ -266,7 +334,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { } } if err != nil || err2 != nil { - config.firstRun = true + Context.firstRun = true copyInstallSettings(&config, &curConfig) if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err) @@ -278,11 +346,11 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { u := User{} u.Name = newSettings.Username - config.auth.UserAdd(&u, newSettings.Password) + Context.auth.UserAdd(&u, newSettings.Password) err = config.write() if err != nil { - config.firstRun = true + Context.firstRun = true copyInstallSettings(&config, &curConfig) httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err) return diff --git a/home/control_tls.go b/home/control_tls.go index f0f4c655..0df8b729 100644 --- a/home/control_tls.go +++ b/home/control_tls.go @@ -20,6 +20,8 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) @@ -84,7 +86,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) { alreadyRunning = true } if !alreadyRunning { - err = checkPortAvailable(config.BindHost, data.PortHTTPS) + err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS) if err != nil { httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) return @@ -114,7 +116,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { alreadyRunning = true } if !alreadyRunning { - err = checkPortAvailable(config.BindHost, data.PortHTTPS) + err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS) if err != nil { httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) return diff --git a/home/control_update.go b/home/control_update.go index 7864cfbb..87fe4034 100644 --- a/home/control_update.go +++ b/home/control_update.go @@ -17,6 +17,8 @@ import ( "syscall" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" ) @@ -64,7 +66,7 @@ type getVersionJSONRequest struct { // Get the latest available version from the Internet func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { - if config.disableUpdate { + if Context.disableUpdate { return } @@ -77,10 +79,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { now := time.Now() if !req.RecheckNow { - config.controlLock.Lock() + Context.controlLock.Lock() cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0 data := config.versionCheckJSON - config.controlLock.Unlock() + Context.controlLock.Unlock() if cached { log.Tracef("Returning cached data") @@ -93,7 +95,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { var resp *http.Response for i := 0; i != 3; i++ { log.Tracef("Downloading data from %s", versionCheckURL) - resp, err = config.client.Get(versionCheckURL) + resp, err = Context.client.Get(versionCheckURL) if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") { // This case may happen while we're restarting DNS server // https://github.com/AdguardTeam/AdGuardHome/issues/934 @@ -116,10 +118,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { return } - config.controlLock.Lock() + Context.controlLock.Lock() config.versionCheckLastTime = now config.versionCheckJSON = body - config.controlLock.Unlock() + Context.controlLock.Unlock() w.Header().Set("Content-Type", "application/json") _, err = w.Write(getVersionResp(body)) @@ -158,7 +160,7 @@ type updateInfo struct { func getUpdateInfo(jsonData []byte) (*updateInfo, error) { var u updateInfo - workDir := config.ourWorkingDir + workDir := Context.workDir versionJSON := make(map[string]interface{}) err := json.Unmarshal(jsonData, &versionJSON) @@ -196,7 +198,7 @@ func getUpdateInfo(jsonData []byte) (*updateInfo, error) { binName = "AdGuardHome.exe" } u.curBinName = filepath.Join(workDir, binName) - if !fileExists(u.curBinName) { + if !util.FileExists(u.curBinName) { return nil, fmt.Errorf("Executable file %s doesn't exist", u.curBinName) } u.bkpBinName = filepath.Join(u.backupDir, binName) @@ -365,7 +367,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, // Download package file and save it to disk func getPackageFile(u *updateInfo) error { - resp, err := config.client.Get(u.pkgURL) + resp, err := Context.client.Get(u.pkgURL) if err != nil { return fmt.Errorf("HTTP request failed: %s", err) } @@ -436,17 +438,17 @@ func doUpdate(u *updateInfo) error { } // ./README.md -> backup/README.md - err = copySupportingFiles(files, config.ourWorkingDir, u.backupDir, true, true) + err = copySupportingFiles(files, Context.workDir, u.backupDir, true, true) if err != nil { return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - config.ourWorkingDir, u.backupDir, err) + Context.workDir, u.backupDir, err) } // update/[AdGuardHome/]README.md -> ./README.md - err = copySupportingFiles(files, u.updateDir, config.ourWorkingDir, false, true) + err = copySupportingFiles(files, u.updateDir, Context.workDir, false, true) if err != nil { return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - u.updateDir, config.ourWorkingDir, err) + u.updateDir, Context.workDir, err) } log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) @@ -478,8 +480,7 @@ func finishUpdate(u *updateInfo) { cleanupAlways() if runtime.GOOS == "windows" { - - if config.runningAsService { + if Context.runningAsService { // Note: // we can't restart the service via "kardianos/service" package - it kills the process first // we can't start a new instance - Windows doesn't allow it diff --git a/home/control_update_test.go b/home/control_update_test.go index c30a72e4..6ec4a186 100644 --- a/home/control_update_test.go +++ b/home/control_update_test.go @@ -8,9 +8,8 @@ import ( ) func TestDoUpdate(t *testing.T) { - config.DNS.Port = 0 - config.ourWorkingDir = "..." // set absolute path + Context.workDir = "..." // set absolute path newver := "v0.96" data := `{ @@ -35,15 +34,15 @@ func TestDoUpdate(t *testing.T) { u := updateInfo{ pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/" + newver + "/AdGuardHome_linux_amd64.tar.gz", - pkgName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz", + pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz", newVer: newver, - updateDir: config.ourWorkingDir + "/agh-update-" + newver, - backupDir: config.ourWorkingDir + "/agh-backup", - configName: config.ourWorkingDir + "/AdGuardHome.yaml", - updateConfigName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome.yaml", - curBinName: config.ourWorkingDir + "/AdGuardHome", - bkpBinName: config.ourWorkingDir + "/agh-backup/AdGuardHome", - newBinName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome", + updateDir: Context.workDir + "/agh-update-" + newver, + backupDir: Context.workDir + "/agh-backup", + configName: Context.workDir + "/AdGuardHome.yaml", + updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome.yaml", + curBinName: Context.workDir + "/AdGuardHome", + bkpBinName: Context.workDir + "/agh-backup/AdGuardHome", + newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome", } if uu.pkgURL != u.pkgURL || diff --git a/home/dns.go b/home/dns.go index 4d5dceb3..167662b1 100644 --- a/home/dns.go +++ b/home/dns.go @@ -25,7 +25,7 @@ func onConfigModified() { // Please note that we must do it even if we don't start it // so that we had access to the query log and the stats func initDNSServer() error { - baseDir := config.getDataDir() + baseDir := Context.getDataDir() err := os.MkdirAll(baseDir, 0755) if err != nil { @@ -71,8 +71,8 @@ func initDNSServer() error { } sessFilename := filepath.Join(baseDir, "sessions.db") - config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) - if config.auth == nil { + Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) + if Context.auth == nil { closeDNSServer() return fmt.Errorf("Couldn't initialize Auth module") } @@ -294,9 +294,9 @@ func closeDNSServer() { Context.queryLog = nil } - if config.auth != nil { - config.auth.Close() - config.auth = nil + if Context.auth != nil { + Context.auth.Close() + Context.auth = nil } log.Debug("Closed all DNS modules") diff --git a/home/filter.go b/home/filter.go index 6b0a16ef..befdd873 100644 --- a/home/filter.go +++ b/home/filter.go @@ -13,6 +13,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" ) @@ -401,7 +402,7 @@ func parseFilterContents(contents []byte) (int, string) { // Count lines in the filter for len(data) != 0 { - line := SplitNext(&data, '\n') + line := util.SplitNext(&data, '\n') if len(line) == 0 { continue } @@ -424,7 +425,7 @@ func parseFilterContents(contents []byte) (int, string) { func (filter *filter) update() (bool, error) { log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL) - resp, err := config.client.Get(filter.URL) + resp, err := Context.client.Get(filter.URL) if resp != nil && resp.Body != nil { defer resp.Body.Close() } @@ -538,7 +539,7 @@ func (filter *filter) unload() { // Path to the filter contents func (filter *filter) Path() string { - return filepath.Join(config.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt") + return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt") } // LastTimeUpdated returns the time when the filter was last time updated diff --git a/home/filter_test.go b/home/filter_test.go index 63736c38..edda556a 100644 --- a/home/filter_test.go +++ b/home/filter_test.go @@ -10,7 +10,12 @@ import ( ) func TestFilters(t *testing.T) { - config.client = &http.Client{ + dir := prepareTestDir() + defer func() { _ = os.RemoveAll(dir) }() + + Context = homeContext{} + Context.workDir = dir + Context.client = &http.Client{ Timeout: time.Minute * 5, } @@ -33,5 +38,5 @@ func TestFilters(t *testing.T) { assert.True(t, err == nil) f.unload() - os.Remove(f.Path()) + _ = os.Remove(f.Path()) } diff --git a/home/helpers.go b/home/helpers.go deleted file mode 100644 index 982a10ef..00000000 --- a/home/helpers.go +++ /dev/null @@ -1,380 +0,0 @@ -package home - -import ( - "context" - "errors" - "fmt" - "net" - "net/http" - "net/url" - "os" - "os/exec" - "path" - "path/filepath" - "runtime" - "strconv" - "strings" - "syscall" - "time" - - "github.com/AdguardTeam/golibs/log" - "github.com/joomcode/errorx" -) - -// ---------------------------------- -// helper functions for HTTP handlers -// ---------------------------------- -func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - log.Debug("%s %v", r.Method, r.URL) - - if r.Method != method { - http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed) - return - } - - if method == "POST" || method == "PUT" || method == "DELETE" { - config.controlLock.Lock() - defer config.controlLock.Unlock() - } - - handler(w, r) - } -} - -func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return ensure("POST", handler) -} - -func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return ensure("GET", handler) -} - -// Bridge between http.Handler object and Go function -type httpHandler struct { - handler func(http.ResponseWriter, *http.Request) -} - -func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.handler(w, r) -} - -func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler { - h := httpHandler{} - h.handler = ensure(method, handler) - return &h -} - -// ------------------- -// first run / install -// ------------------- -func detectFirstRun() bool { - configfile := config.ourConfigFilename - if !filepath.IsAbs(configfile) { - configfile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename) - } - _, err := os.Stat(configfile) - if !os.IsNotExist(err) { - // do nothing, file exists - return false - } - return true -} - -// preInstall lets the handler run only if firstRun is true, no redirects -func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if !config.firstRun { - // if it's not first run, don't let users access it (for example /install.html when configuration is done) - http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) - return - } - handler(w, r) - } -} - -// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary -type preInstallHandlerStruct struct { - handler http.Handler -} - -func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { - preInstall(p.handler.ServeHTTP)(w, r) -} - -// preInstallHandler returns http.Handler interface for preInstall wrapper -func preInstallHandler(handler http.Handler) http.Handler { - return &preInstallHandlerStruct{handler} -} - -// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise -// it also enforces HTTPS if it is enabled and configured -func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if config.firstRun && - !(strings.HasPrefix(r.URL.Path, "/install.") || - strings.HasPrefix(r.URL.Path, "/__locales/") || - r.URL.Path == "/favicon.png") { - http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable - return - } - // enforce https? - if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil { - // yes, and we want host from host:port - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - // no port in host - host = r.Host - } - // construct new URL to redirect to - newURL := url.URL{ - Scheme: "https", - Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)), - Path: r.URL.Path, - RawQuery: r.URL.RawQuery, - } - http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect) - return - } - w.Header().Set("Access-Control-Allow-Origin", "*") - handler(w, r) - } -} - -type postInstallHandlerStruct struct { - handler http.Handler -} - -func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { - postInstall(p.handler.ServeHTTP)(w, r) -} - -func postInstallHandler(handler http.Handler) http.Handler { - return &postInstallHandlerStruct{handler} -} - -// ------------------ -// network interfaces -// ------------------ -type netInterface struct { - Name string `json:"name"` - MTU int `json:"mtu"` - HardwareAddr string `json:"hardware_address"` - Addresses []string `json:"ip_addresses"` - Flags string `json:"flags"` -} - -// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP -// invalid interface is a ppp interface or the one that doesn't allow broadcasts -func getValidNetInterfaces() ([]net.Interface, error) { - ifaces, err := net.Interfaces() - if err != nil { - return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) - } - - netIfaces := []net.Interface{} - - for i := range ifaces { - if ifaces[i].Flags&net.FlagPointToPoint != 0 { - // this interface is ppp, we're not interested in this one - continue - } - - iface := ifaces[i] - netIfaces = append(netIfaces, iface) - } - - return netIfaces, nil -} - -// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only -// we do not return link-local addresses here -func getValidNetInterfacesForWeb() ([]netInterface, error) { - ifaces, err := getValidNetInterfaces() - if err != nil { - return nil, errorx.Decorate(err, "Couldn't get interfaces") - } - if len(ifaces) == 0 { - return nil, errors.New("couldn't find any legible interface") - } - - var netInterfaces []netInterface - - for _, iface := range ifaces { - addrs, e := iface.Addrs() - if e != nil { - return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name) - } - - netIface := netInterface{ - Name: iface.Name, - MTU: iface.MTU, - HardwareAddr: iface.HardwareAddr.String(), - } - - if iface.Flags != 0 { - netIface.Flags = iface.Flags.String() - } - - // we don't want link-local addresses in json, so skip them - for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) - if !ok { - // not an IPNet, should not happen - return nil, fmt.Errorf("SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) - } - // ignore link-local - if ipnet.IP.IsLinkLocalUnicast() { - continue - } - netIface.Addresses = append(netIface.Addresses, ipnet.IP.String()) - } - if len(netIface.Addresses) != 0 { - netInterfaces = append(netInterfaces, netIface) - } - } - - return netInterfaces, nil -} - -// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily -func checkPortAvailable(host string, port int) error { - ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) - if err != nil { - return err - } - _ = ln.Close() - - // It seems that net.Listener.Close() doesn't close file descriptors right away. - // We wait for some time and hope that this fd will be closed. - time.Sleep(100 * time.Millisecond) - return nil -} - -func checkPacketPortAvailable(host string, port int) error { - ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) - if err != nil { - return err - } - _ = ln.Close() - - // It seems that net.Listener.Close() doesn't close file descriptors right away. - // We wait for some time and hope that this fd will be closed. - time.Sleep(100 * time.Millisecond) - return err -} - -// Connect to a remote server resolving hostname using our own DNS server -func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { - log.Tracef("network:%v addr:%v", network, addr) - - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - dialer := &net.Dialer{ - Timeout: time.Minute * 5, - } - - if net.ParseIP(host) != nil || config.DNS.Port == 0 { - con, err := dialer.DialContext(ctx, network, addr) - return con, err - } - - addrs, e := Context.dnsServer.Resolve(host) - log.Debug("dnsServer.Resolve: %s: %v", host, addrs) - if e != nil { - return nil, e - } - - if len(addrs) == 0 { - return nil, fmt.Errorf("couldn't lookup host: %s", host) - } - - var dialErrs []error - for _, a := range addrs { - addr = net.JoinHostPort(a.String(), port) - con, err := dialer.DialContext(ctx, network, addr) - if err != nil { - dialErrs = append(dialErrs, err) - continue - } - return con, err - } - return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) -} - -// check if error is "address already in use" -func errorIsAddrInUse(err error) bool { - errOpError, ok := err.(*net.OpError) - if !ok { - return false - } - - errSyscallError, ok := errOpError.Err.(*os.SyscallError) - if !ok { - return false - } - - errErrno, ok := errSyscallError.Err.(syscall.Errno) - if !ok { - return false - } - - if runtime.GOOS == "windows" { - const WSAEADDRINUSE = 10048 - return errErrno == WSAEADDRINUSE - } - - return errErrno == syscall.EADDRINUSE -} - -// --------------------- -// general helpers -// --------------------- - -// fileExists returns TRUE if file exists -func fileExists(fn string) bool { - _, err := os.Stat(fn) - if err != nil { - return false - } - return true -} - -// runCommand runs shell command -func runCommand(command string, arguments ...string) (int, string, error) { - cmd := exec.Command(command, arguments...) - out, err := cmd.Output() - if err != nil { - return 1, "", fmt.Errorf("exec.Command(%s) failed: %s", command, err) - } - - return cmd.ProcessState.ExitCode(), string(out), nil -} - -// --------------------- -// debug logging helpers -// --------------------- -func _Func() string { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - return path.Base(f.Name()) -} - -// SplitNext - split string by a byte and return the first chunk -// Whitespace is trimmed -func SplitNext(str *string, splitBy byte) string { - i := strings.IndexByte(*str, splitBy) - s := "" - if i != -1 { - s = (*str)[0:i] - *str = (*str)[i+1:] - } else { - s = *str - *str = "" - } - return strings.TrimSpace(s) -} diff --git a/home/home.go b/home/home.go index b414d671..f7c697f6 100644 --- a/home/home.go +++ b/home/home.go @@ -20,6 +20,10 @@ import ( "syscall" "time" + "github.com/AdguardTeam/AdGuardHome/util" + + "github.com/joomcode/errorx" + "github.com/AdguardTeam/AdGuardHome/isdelve" "github.com/AdguardTeam/AdGuardHome/dhcpd" @@ -49,6 +53,9 @@ const versionCheckPeriod = time.Hour * 8 // Global context type homeContext struct { + // Modules + // -- + clients clientsContainer // per-client-settings module stats stats.Stats // statistics module queryLog querylog.QueryLog // query log module @@ -57,8 +64,29 @@ type homeContext struct { whois *Whois // WHOIS module dnsFilter *dnsfilter.Dnsfilter // DNS filtering module dhcpServer *dhcpd.Server // DHCP module + auth *Auth // HTTP authentication module httpServer *http.Server // HTTP module httpsServer HTTPSServer // HTTPS module + + // Runtime properties + // -- + + configFilename string // Config filename (can be overridden via the command line arguments) + workDir string // Location of our directory, used to protect against CWD being somewhere else + firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html + pidFileName string // PID file name. Empty if no PID file was created. + disableUpdate bool // If set, don't check for updates + controlLock sync.Mutex + transport *http.Transport + client *http.Client + appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app + // runningAsService flag is set to true when options are passed from the service runner + runningAsService bool +} + +// getDataDir returns path to the directory where we store databases and filters +func (c *homeContext) getDataDir() string { + return filepath.Join(c.workDir, dataDir) } // Context - a global context object @@ -81,17 +109,38 @@ func Main(version string, channel string, armVer string) { return } + Context.appSignalChannel = make(chan os.Signal) + signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) + go func() { + <-Context.appSignalChannel + cleanup() + cleanupAlways() + os.Exit(0) + }() + // run the protection run(args) } // run initializes configuration and runs the AdGuard Home -// run is a blocking method and it won't exit until the service is stopped! +// run is a blocking method! // nolint func run(args options) { // config file path can be overridden by command-line arguments: if args.configFilename != "" { - config.ourConfigFilename = args.configFilename + Context.configFilename = args.configFilename + } else { + // Default config file name + Context.configFilename = "AdGuardHome.yaml" + } + + // Init some of the Context fields right away + Context.transport = &http.Transport{ + DialContext: customDialContext, + } + Context.client = &http.Client{ + Timeout: time.Minute * 5, + Transport: Context.transport, } // configure working dir and config path @@ -106,31 +155,22 @@ func run(args options) { msg = msg + " v" + ARMVersion } log.Printf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH, ARMVersion) - log.Debug("Current working directory is %s", config.ourWorkingDir) + log.Debug("Current working directory is %s", Context.workDir) if args.runningAsService { log.Info("AdGuard Home is running as a service") } - config.runningAsService = args.runningAsService - config.disableUpdate = args.disableUpdate + Context.runningAsService = args.runningAsService + Context.disableUpdate = args.disableUpdate - config.firstRun = detectFirstRun() - if config.firstRun { + Context.firstRun = detectFirstRun() + if Context.firstRun { requireAdminRights() } - config.appSignalChannel = make(chan os.Signal) - signal.Notify(config.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) - go func() { - <-config.appSignalChannel - cleanup() - cleanupAlways() - os.Exit(0) - }() - initConfig() initServices() - if !config.firstRun { + if !Context.firstRun { // Do the upgrade if necessary err := upgradeConfig() if err != nil { @@ -148,7 +188,7 @@ func run(args options) { } } - config.DHCP.WorkDir = config.ourWorkingDir + config.DHCP.WorkDir = Context.workDir config.DHCP.HTTPRegister = httpRegister config.DHCP.ConfigModified = onConfigModified Context.dhcpServer = dhcpd.Create(config.DHCP) @@ -157,7 +197,7 @@ func run(args options) { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && config.RlimitNoFile != 0 { - setRlimit(config.RlimitNoFile) + util.SetRlimit(config.RlimitNoFile) } // override bind host/port from the console @@ -168,7 +208,7 @@ func run(args options) { config.BindPort = args.bindPort } - if !config.firstRun { + if !Context.firstRun { // Save the updated config err := config.write() if err != nil { @@ -193,7 +233,7 @@ func run(args options) { } if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { - config.pidFileName = args.pidFile + Context.pidFileName = args.pidFile } // Initialize and run the admin Web interface @@ -204,7 +244,7 @@ func run(args options) { registerControlHandlers() // add handlers for /install paths, we only need them when we're not configured yet - if config.firstRun { + if Context.firstRun { log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ") http.Handle("/install.html", preInstallHandler(http.FileServer(box))) registerInstallHandlers() @@ -291,7 +331,7 @@ func httpServerLoop() { // Check if the current user has root (administrator) rights // and if not, ask and try to run as root func requireAdminRights() { - admin, _ := haveAdminRights() + admin, _ := util.HaveAdminRights() if //noinspection ALL admin || isdelve.Enabled { return @@ -331,7 +371,7 @@ func writePIDFile(fn string) bool { return true } -// initWorkingDir initializes the ourWorkingDir +// initWorkingDir initializes the workDir // if no command-line arguments specified, we use the directory where our binary file is located func initWorkingDir(args options) { execPath, err := os.Executable() @@ -341,9 +381,9 @@ func initWorkingDir(args options) { if args.workDir != "" { // If there is a custom config file, use it's directory as our working dir - config.ourWorkingDir = args.workDir + Context.workDir = args.workDir } else { - config.ourWorkingDir = filepath.Dir(execPath) + Context.workDir = filepath.Dir(execPath) } } @@ -376,12 +416,12 @@ func configureLogger(args options) { if ls.LogFile == configSyslog { // Use syslog where it is possible and eventlog on Windows - err := configureSyslog() + err := util.ConfigureSyslog(serviceName) if err != nil { log.Fatalf("cannot initialize syslog: %s", err) } } else { - logFilePath := filepath.Join(config.ourWorkingDir, ls.LogFile) + logFilePath := filepath.Join(Context.workDir, ls.LogFile) if filepath.IsAbs(ls.LogFile) { logFilePath = ls.LogFile } @@ -420,8 +460,8 @@ func stopHTTPServer() { // This function is called before application exits func cleanupAlways() { - if len(config.pidFileName) != 0 { - _ = os.Remove(config.pidFileName) + if len(Context.pidFileName) != 0 { + _ = os.Remove(Context.pidFileName) } log.Info("Stopped") } @@ -544,7 +584,7 @@ func printHTTPAddresses(proto string) { } } else if config.BindHost == "0.0.0.0" { log.Println("AdGuard Home is available on the following addresses:") - ifaces, err := getValidNetInterfacesForWeb() + ifaces, err := util.GetValidNetInterfacesForWeb() if err != nil { // That's weird, but we'll ignore it address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) @@ -561,3 +601,60 @@ func printHTTPAddresses(proto string) { log.Printf("Go to %s://%s", proto, address) } } + +// ------------------- +// first run / install +// ------------------- +func detectFirstRun() bool { + configfile := Context.configFilename + if !filepath.IsAbs(configfile) { + configfile = filepath.Join(Context.workDir, Context.configFilename) + } + _, err := os.Stat(configfile) + if !os.IsNotExist(err) { + // do nothing, file exists + return false + } + return true +} + +// Connect to a remote server resolving hostname using our own DNS server +func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + log.Tracef("network:%v addr:%v", network, addr) + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + dialer := &net.Dialer{ + Timeout: time.Minute * 5, + } + + if net.ParseIP(host) != nil || config.DNS.Port == 0 { + con, err := dialer.DialContext(ctx, network, addr) + return con, err + } + + addrs, e := Context.dnsServer.Resolve(host) + log.Debug("dnsServer.Resolve: %s: %v", host, addrs) + if e != nil { + return nil, e + } + + if len(addrs) == 0 { + return nil, fmt.Errorf("couldn't lookup host: %s", host) + } + + var dialErrs []error + for _, a := range addrs { + addr = net.JoinHostPort(a.String(), port) + con, err := dialer.DialContext(ctx, network, addr) + if err != nil { + dialErrs = append(dialErrs, err) + continue + } + return con, err + } + return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) +} diff --git a/home/home_test.go b/home/home_test.go index 771c74d0..fba9b43d 100644 --- a/home/home_test.go +++ b/home/home_test.go @@ -107,6 +107,9 @@ schema_version: 5 // . Wait until the filters are downloaded // . Stop and cleanup func TestHome(t *testing.T) { + // Reinit context + Context = homeContext{} + dir := prepareTestDir() defer func() { _ = os.RemoveAll(dir) }() fn := filepath.Join(dir, "AdGuardHome.yaml") @@ -123,12 +126,12 @@ func TestHome(t *testing.T) { var err error var resp *http.Response h := http.Client{} - for i := 0; i != 5; i++ { + for i := 0; i != 50; i++ { resp, err = h.Get("http://127.0.0.1:3000/") if err == nil && resp.StatusCode != 404 { break } - time.Sleep(1 * time.Second) + time.Sleep(100 * time.Millisecond) } assert.Truef(t, err == nil, "%s", err) assert.Equal(t, 200, resp.StatusCode) @@ -140,7 +143,7 @@ func TestHome(t *testing.T) { // test DNS over UDP r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second) addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com") - assert.Truef(t, err == nil, "%s", err) + assert.Nil(t, err) haveIP := len(addrs) != 0 assert.True(t, haveIP) diff --git a/home/service.go b/home/service.go index edca9244..d066e118 100644 --- a/home/service.go +++ b/home/service.go @@ -7,6 +7,7 @@ import ( "strings" "syscall" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" "github.com/kardianos/service" ) @@ -34,10 +35,10 @@ func (p *program) Start(s service.Service) error { // Stop stops the program func (p *program) Stop(s service.Service) error { // Stop should not block. Return with a few seconds. - if config.appSignalChannel == nil { + if Context.appSignalChannel == nil { os.Exit(0) } - config.appSignalChannel <- syscall.SIGINT + Context.appSignalChannel <- syscall.SIGINT return nil } @@ -229,7 +230,7 @@ func configureService(c *service.Config) { // returns command code or error if any func runInitdCommand(action string) (int, error) { confPath := "/etc/init.d/" + serviceName - code, _, err := runCommand("sh", "-c", confPath+" "+action) + code, _, err := util.RunCommand("sh", "-c", confPath+" "+action) return code, err } diff --git a/home/upgrade.go b/home/upgrade.go index 3a703ebc..b2936a3e 100644 --- a/home/upgrade.go +++ b/home/upgrade.go @@ -5,6 +5,8 @@ import ( "os" "path/filepath" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" "golang.org/x/crypto/bcrypt" @@ -114,9 +116,9 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err // The first schema upgrade: // No more "dnsfilter.txt", filters are now kept in data/filters/ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) - dnsFilterPath := filepath.Join(config.ourWorkingDir, "dnsfilter.txt") + dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt") if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) { log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) err = os.Remove(dnsFilterPath) @@ -135,9 +137,9 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { // coredns is now dns in config // delete 'Corefile', since we don't use that anymore func upgradeSchema1to2(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) - coreFilePath := filepath.Join(config.ourWorkingDir, "Corefile") + coreFilePath := filepath.Join(Context.workDir, "Corefile") if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) { log.Printf("Deleting %s as we don't need it anymore", coreFilePath) err = os.Remove(coreFilePath) @@ -159,7 +161,7 @@ func upgradeSchema1to2(diskConfig *map[string]interface{}) error { // Third schema upgrade: // Bootstrap DNS becomes an array func upgradeSchema2to3(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) // Let's read dns configuration from diskConfig dnsConfig, ok := (*diskConfig)["dns"] @@ -196,7 +198,7 @@ func upgradeSchema2to3(diskConfig *map[string]interface{}) error { // Add use_global_blocked_services=true setting for existing "clients" array func upgradeSchema3to4(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) (*diskConfig)["schema_version"] = 4 @@ -233,7 +235,7 @@ func upgradeSchema3to4(diskConfig *map[string]interface{}) error { // password: "..." // ... func upgradeSchema4to5(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) (*diskConfig)["schema_version"] = 5 @@ -288,7 +290,7 @@ func upgradeSchema4to5(diskConfig *map[string]interface{}) error { // - 127.0.0.1 // - ... func upgradeSchema5to6(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) (*diskConfig)["schema_version"] = 6 diff --git a/home/whois.go b/home/whois.go index 25fe211a..321b4ef2 100644 --- a/home/whois.go +++ b/home/whois.go @@ -8,6 +8,8 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" ) @@ -61,7 +63,7 @@ func whoisParse(data string) map[string]string { descr := "" netname := "" for len(data) != 0 { - ln := SplitNext(&data, '\n') + ln := util.SplitNext(&data, '\n') if len(ln) == 0 || ln[0] == '#' || ln[0] == '%' { continue } diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 9ad2c5d1..5f3539f0 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -1831,6 +1831,9 @@ definitions: $ref: "#/definitions/CheckConfigRequestInfo" web: $ref: "#/definitions/CheckConfigRequestInfo" + set_static_ip: + type: "boolean" + example: false CheckConfigRequestInfo: type: "object" properties: @@ -1851,6 +1854,8 @@ definitions: $ref: "#/definitions/CheckConfigResponseInfo" web: $ref: "#/definitions/CheckConfigResponseInfo" + static_ip: + $ref: "#/definitions/CheckConfigStaticIpInfo" CheckConfigResponseInfo: type: "object" properties: @@ -1860,6 +1865,23 @@ definitions: can_autofix: type: "boolean" example: false + CheckConfigStaticIpInfo: + type: "object" + properties: + static: + type: "string" + example: "no" + description: "Can be: yes, no, error" + ip: + type: "string" + example: "192.168.1.1" + description: "Current dynamic IP address. Set if static=no" + error: + type: "string" + example: "" + description: "Error text. Set if static=error" + + InitialConfiguration: type: "object" description: "AdGuard Home initial configuration (for the first-install wizard)" diff --git a/util/helpers.go b/util/helpers.go new file mode 100644 index 00000000..c50c940d --- /dev/null +++ b/util/helpers.go @@ -0,0 +1,59 @@ +package util + +import ( + "fmt" + "os" + "os/exec" + "path" + "runtime" + "strings" +) + +// --------------------- +// general helpers +// --------------------- + +// fileExists returns TRUE if file exists +func FileExists(fn string) bool { + _, err := os.Stat(fn) + if err != nil { + return false + } + return true +} + +// runCommand runs shell command +func RunCommand(command string, arguments ...string) (int, string, error) { + cmd := exec.Command(command, arguments...) + out, err := cmd.Output() + if err != nil { + return 1, "", fmt.Errorf("exec.Command(%s) failed: %v: %s", command, err, string(out)) + } + + return cmd.ProcessState.ExitCode(), string(out), nil +} + +// --------------------- +// debug logging helpers +// --------------------- +func FuncName() string { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + return path.Base(f.Name()) +} + +// SplitNext - split string by a byte and return the first chunk +// Whitespace is trimmed +func SplitNext(str *string, splitBy byte) string { + i := strings.IndexByte(*str, splitBy) + s := "" + if i != -1 { + s = (*str)[0:i] + *str = (*str)[i+1:] + } else { + s = *str + *str = "" + } + return strings.TrimSpace(s) +} diff --git a/util/helpers_test.go b/util/helpers_test.go new file mode 100644 index 00000000..d5e90637 --- /dev/null +++ b/util/helpers_test.go @@ -0,0 +1,14 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSplitNext(t *testing.T) { + s := " a,b , c " + assert.True(t, SplitNext(&s, ',') == "a") + assert.True(t, SplitNext(&s, ',') == "b") + assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0) +} diff --git a/util/network_utils.go b/util/network_utils.go new file mode 100644 index 00000000..af410201 --- /dev/null +++ b/util/network_utils.go @@ -0,0 +1,194 @@ +package util + +import ( + "errors" + "fmt" + "net" + "os" + "runtime" + "strconv" + "syscall" + "time" + + "github.com/AdguardTeam/golibs/log" + + "github.com/joomcode/errorx" +) + +// NetInterface represents a list of network interfaces +type NetInterface struct { + Name string // Network interface name + MTU int // MTU + HardwareAddr string // Hardware address + Addresses []string // Array with the network interface addresses + Subnets []string // Array with CIDR addresses of this network interface + Flags string // Network interface flags (up, broadcast, etc) +} + +// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP +// invalid interface is a ppp interface or the one that doesn't allow broadcasts +func GetValidNetInterfaces() ([]net.Interface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) + } + + netIfaces := []net.Interface{} + + for i := range ifaces { + if ifaces[i].Flags&net.FlagPointToPoint != 0 { + // this interface is ppp, we're not interested in this one + continue + } + + iface := ifaces[i] + netIfaces = append(netIfaces, iface) + } + + return netIfaces, nil +} + +// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only +// we do not return link-local addresses here +func GetValidNetInterfacesForWeb() ([]NetInterface, error) { + ifaces, err := GetValidNetInterfaces() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't get interfaces") + } + if len(ifaces) == 0 { + return nil, errors.New("couldn't find any legible interface") + } + + var netInterfaces []NetInterface + + for _, iface := range ifaces { + addrs, e := iface.Addrs() + if e != nil { + return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name) + } + + netIface := NetInterface{ + Name: iface.Name, + MTU: iface.MTU, + HardwareAddr: iface.HardwareAddr.String(), + } + + if iface.Flags != 0 { + netIface.Flags = iface.Flags.String() + } + + // Collect network interface addresses + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + // not an IPNet, should not happen + return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) + } + // ignore link-local + if ipNet.IP.IsLinkLocalUnicast() { + continue + } + // ignore IPv6 + if ipNet.IP.To4() == nil { + continue + } + netIface.Addresses = append(netIface.Addresses, ipNet.IP.String()) + netIface.Subnets = append(netIface.Subnets, ipNet.String()) + } + + // Discard interfaces with no addresses + if len(netIface.Addresses) != 0 { + netInterfaces = append(netInterfaces, netIface) + } + } + + return netInterfaces, nil +} + +// Get interface name by its IP address. +func GetInterfaceByIP(ip string) string { + ifaces, err := GetValidNetInterfacesForWeb() + if err != nil { + return "" + } + + for _, iface := range ifaces { + for _, addr := range iface.Addresses { + if ip == addr { + return iface.Name + } + } + } + + return "" +} + +// Get IP address with netmask for the specified interface +// Returns an empty string if it fails to find it +func GetSubnet(ifaceName string) string { + netIfaces, err := GetValidNetInterfacesForWeb() + if err != nil { + log.Error("Could not get network interfaces info: %v", err) + return "" + } + + for _, netIface := range netIfaces { + if netIface.Name == ifaceName && len(netIface.Subnets) > 0 { + return netIface.Subnets[0] + } + } + + return "" +} + +// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily +func CheckPortAvailable(host string, port int) error { + ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + _ = ln.Close() + + // It seems that net.Listener.Close() doesn't close file descriptors right away. + // We wait for some time and hope that this fd will be closed. + time.Sleep(100 * time.Millisecond) + return nil +} + +func CheckPacketPortAvailable(host string, port int) error { + ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + _ = ln.Close() + + // It seems that net.Listener.Close() doesn't close file descriptors right away. + // We wait for some time and hope that this fd will be closed. + time.Sleep(100 * time.Millisecond) + return err +} + +// check if error is "address already in use" +func ErrorIsAddrInUse(err error) bool { + errOpError, ok := err.(*net.OpError) + if !ok { + return false + } + + errSyscallError, ok := errOpError.Err.(*os.SyscallError) + if !ok { + return false + } + + errErrno, ok := errSyscallError.Err.(syscall.Errno) + if !ok { + return false + } + + if runtime.GOOS == "windows" { + const WSAEADDRINUSE = 10048 + return errErrno == WSAEADDRINUSE + } + + return errErrno == syscall.EADDRINUSE +} diff --git a/home/helpers_test.go b/util/network_utils_test.go similarity index 52% rename from home/helpers_test.go rename to util/network_utils_test.go index c2095966..7feac0f2 100644 --- a/home/helpers_test.go +++ b/util/network_utils_test.go @@ -1,14 +1,12 @@ -package home +package util import ( + "log" "testing" - - "github.com/AdguardTeam/golibs/log" - "github.com/stretchr/testify/assert" ) func TestGetValidNetInterfacesForWeb(t *testing.T) { - ifaces, err := getValidNetInterfacesForWeb() + ifaces, err := GetValidNetInterfacesForWeb() if err != nil { t.Fatalf("Cannot get net interfaces: %s", err) } @@ -24,10 +22,3 @@ func TestGetValidNetInterfacesForWeb(t *testing.T) { log.Printf("%v", iface) } } - -func TestSplitNext(t *testing.T) { - s := " a,b , c " - assert.True(t, SplitNext(&s, ',') == "a") - assert.True(t, SplitNext(&s, ',') == "b") - assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0) -} diff --git a/home/os_freebsd.go b/util/os_freebsd.go similarity index 86% rename from home/os_freebsd.go rename to util/os_freebsd.go index 43ee223e..33311e16 100644 --- a/home/os_freebsd.go +++ b/util/os_freebsd.go @@ -1,6 +1,6 @@ // +build freebsd -package home +package util import ( "os" @@ -11,7 +11,7 @@ import ( // Set user-specified limit of how many fd's we can use // https://github.com/AdguardTeam/AdGuardHome/issues/659 -func setRlimit(val uint) { +func SetRlimit(val uint) { var rlim syscall.Rlimit rlim.Max = int64(val) rlim.Cur = int64(val) @@ -22,6 +22,6 @@ func setRlimit(val uint) { } // Check if the current user has root (administrator) rights -func haveAdminRights() (bool, error) { +func HaveAdminRights() (bool, error) { return os.Getuid() == 0, nil } diff --git a/home/os_unix.go b/util/os_unix.go similarity index 87% rename from home/os_unix.go rename to util/os_unix.go index 2623376e..338edfa8 100644 --- a/home/os_unix.go +++ b/util/os_unix.go @@ -1,6 +1,6 @@ // +build aix darwin dragonfly linux netbsd openbsd solaris -package home +package util import ( "os" @@ -11,7 +11,7 @@ import ( // Set user-specified limit of how many fd's we can use // https://github.com/AdguardTeam/AdGuardHome/issues/659 -func setRlimit(val uint) { +func SetRlimit(val uint) { var rlim syscall.Rlimit rlim.Max = uint64(val) rlim.Cur = uint64(val) @@ -22,6 +22,6 @@ func setRlimit(val uint) { } // Check if the current user has root (administrator) rights -func haveAdminRights() (bool, error) { +func HaveAdminRights() (bool, error) { return os.Getuid() == 0, nil } diff --git a/home/os_windows.go b/util/os_windows.go similarity index 87% rename from home/os_windows.go rename to util/os_windows.go index f6949d93..e081f758 100644 --- a/home/os_windows.go +++ b/util/os_windows.go @@ -1,12 +1,12 @@ -package home +package util import "golang.org/x/sys/windows" // Set user-specified limit of how many fd's we can use -func setRlimit(val uint) { +func SetRlimit(val uint) { } -func haveAdminRights() (bool, error) { +func HaveAdminRights() (bool, error) { var token windows.Token h, _ := windows.GetCurrentProcess() err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token) diff --git a/home/syslog_others.go b/util/syslog_others.go similarity index 62% rename from home/syslog_others.go rename to util/syslog_others.go index 8aa0f8b0..f4ad9119 100644 --- a/home/syslog_others.go +++ b/util/syslog_others.go @@ -1,14 +1,14 @@ // +build !windows,!nacl,!plan9 -package home +package util import ( "log" "log/syslog" ) -// configureSyslog reroutes standard logger output to syslog -func configureSyslog() error { +// ConfigureSyslog reroutes standard logger output to syslog +func ConfigureSyslog(serviceName string) error { w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName) if err != nil { return err diff --git a/home/syslog_windows.go b/util/syslog_windows.go similarity index 94% rename from home/syslog_windows.go rename to util/syslog_windows.go index a80933bb..30ee7815 100644 --- a/home/syslog_windows.go +++ b/util/syslog_windows.go @@ -1,4 +1,4 @@ -package home +package util import ( "log" @@ -17,7 +17,7 @@ func (w *eventLogWriter) Write(b []byte) (int, error) { return len(b), w.el.Info(1, string(b)) } -func configureSyslog() error { +func ConfigureSyslog(serviceName string) error { // Note that the eventlog src is the same as the service name // Otherwise, we will get "the description for event id cannot be found" warning in every log record