From e4b53db55846dd74bbcfe7f3abc66f7481a1c577 Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Mon, 1 Apr 2019 12:22:54 +0300
Subject: [PATCH] + app: unix, windows: require root user on first launch

---
 app.go        | 56 ++++++++++++++++++++++++++++++++++++++++++---------
 os_unix.go    |  6 ++++++
 os_windows.go | 23 +++++++++++++++++++++
 3 files changed, 76 insertions(+), 9 deletions(-)

diff --git a/app.go b/app.go
index dacf1034..acb463de 100644
--- a/app.go
+++ b/app.go
@@ -1,16 +1,20 @@
 package main
 
 import (
+	"bufio"
 	"crypto/tls"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net"
 	"net/http"
 	"os"
+	"os/exec"
 	"os/signal"
 	"path/filepath"
 	"runtime"
 	"strconv"
+	"strings"
 	"sync"
 	"syscall"
 
@@ -45,15 +49,6 @@ func main() {
 		return
 	}
 
-	signalChannel := make(chan os.Signal)
-	signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
-	go func() {
-		<-signalChannel
-		cleanup()
-		cleanupAlways()
-		os.Exit(0)
-	}()
-
 	// run the protection
 	run(args)
 }
@@ -83,6 +78,18 @@ func run(args options) {
 	}
 
 	config.firstRun = detectFirstRun()
+	if config.firstRun {
+		requireAdminRights()
+	}
+
+	signalChannel := make(chan os.Signal)
+	signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
+	go func() {
+		<-signalChannel
+		cleanup()
+		cleanupAlways()
+		os.Exit(0)
+	}()
 
 	// Do the upgrade if necessary
 	err := upgradeConfig()
@@ -228,6 +235,37 @@ func run(args options) {
 	}
 }
 
+// Check if the current user has root (administrator) rights
+//  and if not, ask and try to run as root
+func requireAdminRights() {
+	admin, _ := haveAdminRights()
+	if admin {
+		return
+	}
+
+	if runtime.GOOS == "windows" {
+		log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
+
+	} else {
+		log.Error("This is the first launch of AdGuard Home. You must run it as root.")
+
+		_, _ = io.WriteString(os.Stdout, "Do you want to start AdGuard Home as root user? [y/n] ")
+		stdin := bufio.NewReader(os.Stdin)
+		buf, _ := stdin.ReadString('\n')
+		buf = strings.TrimSpace(buf)
+		if buf != "y" {
+			os.Exit(1)
+		}
+
+		cmd := exec.Command("sudo", os.Args...)
+		cmd.Stdin = os.Stdin
+		cmd.Stdout = os.Stdout
+		cmd.Stderr = os.Stderr
+		_ = cmd.Run()
+		os.Exit(1)
+	}
+}
+
 // Write PID to a file
 func writePIDFile(fn string) bool {
 	data := fmt.Sprintf("%d", os.Getpid())
diff --git a/os_unix.go b/os_unix.go
index 12a918c8..9baa357d 100644
--- a/os_unix.go
+++ b/os_unix.go
@@ -3,6 +3,7 @@
 package main
 
 import (
+	"os"
 	"syscall"
 
 	"github.com/AdguardTeam/golibs/log"
@@ -19,3 +20,8 @@ func setRlimit(val uint) {
 		log.Error("Setrlimit() failed: %v", err)
 	}
 }
+
+// Check if the current user has root (administrator) rights
+func haveAdminRights() (bool, error) {
+	return os.Getuid() == 0, nil
+}
diff --git a/os_windows.go b/os_windows.go
index 1155e04b..e847ccce 100644
--- a/os_windows.go
+++ b/os_windows.go
@@ -1,5 +1,28 @@
 package main
 
+import "golang.org/x/sys/windows"
+
 // Set user-specified limit of how many fd's we can use
 func setRlimit(val uint) {
 }
+
+func haveAdminRights() (bool, error) {
+	var token windows.Token
+	h, _ := windows.GetCurrentProcess()
+	err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token)
+	if err != nil {
+		return false, err
+	}
+
+	info := make([]byte, 4)
+	var returnedLen uint32
+	err = windows.GetTokenInformation(token, windows.TokenElevation, &info[0], uint32(len(info)), &returnedLen)
+	token.Close()
+	if err != nil {
+		return false, err
+	}
+	if info[0] == 0 {
+		return false, nil
+	}
+	return true, nil
+}