From 7eb3e00b351e41445a97bbf246134fdef84f0ce8 Mon Sep 17 00:00:00 2001
From: jvoisin <julien.voisin@dustri.org>
Date: Mon, 21 Dec 2020 19:39:39 +0100
Subject: [PATCH 1/3] Use a couple of defer in internal/home/auth.go

---
 internal/home/auth.go | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/internal/home/auth.go b/internal/home/auth.go
index 00407fa0..941c97ab 100644
--- a/internal/home/auth.go
+++ b/internal/home/auth.go
@@ -230,16 +230,15 @@ func (a *Auth) CheckSession(sess string) int {
 	update := false
 
 	a.lock.Lock()
+	defer a.lock.Unlock()
 	s, ok := a.sessions[sess]
 	if !ok {
-		a.lock.Unlock()
 		return -1
 	}
 	if s.expire <= now {
 		delete(a.sessions, sess)
 		key, _ := hex.DecodeString(sess)
 		a.removeSession(key)
-		a.lock.Unlock()
 		return 1
 	}
 
@@ -250,8 +249,6 @@ func (a *Auth) CheckSession(sess string) int {
 		s.expire = newExpire
 	}
 
-	a.lock.Unlock()
-
 	if update {
 		key, _ := hex.DecodeString(sess)
 		if a.storeSession(key, s) {
@@ -517,18 +514,16 @@ func (a *Auth) GetCurrentUser(r *http.Request) User {
 	}
 
 	a.lock.Lock()
+	defer a.lock.Unlock()
 	s, ok := a.sessions[cookie.Value]
 	if !ok {
-		a.lock.Unlock()
 		return User{}
 	}
 	for _, u := range a.users {
 		if u.Name == s.userName {
-			a.lock.Unlock()
 			return u
 		}
 	}
-	a.lock.Unlock()
 	return User{}
 }
 

From 925c5df801dc940239930034cdee3b683304b860 Mon Sep 17 00:00:00 2001
From: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Tue, 22 Dec 2020 21:05:12 +0300
Subject: [PATCH 2/3] home: improve checkSession

---
 internal/home/auth.go      | 41 +++++++++++++++++++++++++-------------
 internal/home/auth_test.go | 14 ++++++-------
 2 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/internal/home/auth.go b/internal/home/auth.go
index 941c97ab..01f89a26 100644
--- a/internal/home/auth.go
+++ b/internal/home/auth.go
@@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool {
 // Auth - global object
 type Auth struct {
 	db         *bbolt.DB
-	sessions   map[string]*session // session name -> session data
-	lock       sync.Mutex
+	sessions   map[string]*session
 	users      []User
-	sessionTTL uint32 // in seconds
+	lock       sync.Mutex
+	sessionTTL uint32
 }
 
 // User object
@@ -223,23 +223,35 @@ func (a *Auth) removeSession(sess []byte) {
 	log.Debug("Auth: removed session from DB")
 }
 
-// CheckSession - check if session is valid
-// Return 0 if OK;  -1 if session doesn't exist;  1 if session has expired
-func (a *Auth) CheckSession(sess string) int {
+// checkSessionResult is the result of checking a session.
+type checkSessionResult int
+
+// checkSessionResult constants.
+const (
+	checkSessionOK       checkSessionResult = 0
+	checkSessionNotFound checkSessionResult = -1
+	checkSessionExpired  checkSessionResult = 1
+)
+
+// checkSession checks if the session is valid.
+func (a *Auth) checkSession(sess string) (res checkSessionResult) {
 	now := uint32(time.Now().UTC().Unix())
 	update := false
 
 	a.lock.Lock()
 	defer a.lock.Unlock()
+
 	s, ok := a.sessions[sess]
 	if !ok {
-		return -1
+		return checkSessionNotFound
 	}
+
 	if s.expire <= now {
 		delete(a.sessions, sess)
 		key, _ := hex.DecodeString(sess)
 		a.removeSession(key)
-		return 1
+
+		return checkSessionExpired
 	}
 
 	newExpire := now + a.sessionTTL
@@ -256,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int {
 		}
 	}
 
-	return 0
+	return checkSessionOK
 }
 
 // RemoveSession - remove session
@@ -389,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
 		ok = true
 
 	} else if err == nil {
-		r := Context.auth.CheckSession(cookie.Value)
-		if r == 0 {
+		r := Context.auth.checkSession(cookie.Value)
+		if r == checkSessionOK {
 			ok = true
 		} else if r < 0 {
 			log.Debug("Auth: invalid cookie value: %s", cookie)
@@ -431,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
 			authRequired := Context.auth != nil && Context.auth.AuthRequired()
 			cookie, err := r.Cookie(sessionCookieName)
 			if authRequired && err == nil {
-				r := Context.auth.CheckSession(cookie.Value)
-				if r == 0 {
+				r := Context.auth.checkSession(cookie.Value)
+				if r == checkSessionOK {
 					w.Header().Set("Location", "/")
 					w.WriteHeader(http.StatusFound)
+
 					return
-				} else if r < 0 {
+				} else if r == checkSessionNotFound {
 					log.Debug("Auth: invalid cookie value: %s", cookie)
 				}
 			}
diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go
index 25db2dd6..0998a2a6 100644
--- a/internal/home/auth_test.go
+++ b/internal/home/auth_test.go
@@ -38,7 +38,7 @@ func TestAuth(t *testing.T) {
 	user := User{Name: "name"}
 	a.UserAdd(&user, "password")
 
-	assert.True(t, a.CheckSession("notfound") == -1)
+	assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
 	a.RemoveSession("notfound")
 
 	sess, err := getSession(&users[0])
@@ -49,13 +49,13 @@ func TestAuth(t *testing.T) {
 	// check expiration
 	s.expire = uint32(now)
 	a.addSession(sess, &s)
-	assert.True(t, a.CheckSession(sessStr) == 1)
+	assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
 
 	// add session with TTL = 2 sec
 	s = session{}
 	s.expire = uint32(time.Now().UTC().Unix() + 2)
 	a.addSession(sess, &s)
-	assert.True(t, a.CheckSession(sessStr) == 0)
+	assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
 
 	a.Close()
 
@@ -63,8 +63,8 @@ func TestAuth(t *testing.T) {
 	a = InitAuth(fn, users, 60)
 
 	// the session is still alive
-	assert.True(t, a.CheckSession(sessStr) == 0)
-	// reset our expiration time because CheckSession() has just updated it
+	assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
+	// reset our expiration time because checkSession() has just updated it
 	s.expire = uint32(time.Now().UTC().Unix() + 2)
 	a.storeSession(sess, &s)
 	a.Close()
@@ -76,7 +76,7 @@ func TestAuth(t *testing.T) {
 
 	// load and remove expired sessions
 	a = InitAuth(fn, users, 60)
-	assert.True(t, a.CheckSession(sessStr) == -1)
+	assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
 
 	a.Close()
 	os.Remove(fn)
@@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) {
 	Context.auth = InitAuth(fn, users, 60)
 
 	handlerCalled := false
-	handler := func(w http.ResponseWriter, r *http.Request) {
+	handler := func(_ http.ResponseWriter, _ *http.Request) {
 		handlerCalled = true
 	}
 	handler2 := optionalAuth(handler)

From 1c754788f9139ed9741cf01c6d94bcced6909b8c Mon Sep 17 00:00:00 2001
From: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Tue, 22 Dec 2020 21:09:53 +0300
Subject: [PATCH 3/3] home: improve getCurrentUser

---
 internal/home/auth.go    | 14 +++++++++-----
 internal/home/control.go |  2 +-
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/internal/home/auth.go b/internal/home/auth.go
index 01f89a26..dce17241 100644
--- a/internal/home/auth.go
+++ b/internal/home/auth.go
@@ -513,30 +513,34 @@ func (a *Auth) UserFind(login, password string) User {
 	return User{}
 }
 
-// GetCurrentUser - get the current user
-func (a *Auth) GetCurrentUser(r *http.Request) User {
+// getCurrentUser returns the current user.  It returns an empty User if the
+// user is not found.
+func (a *Auth) getCurrentUser(r *http.Request) User {
 	cookie, err := r.Cookie(sessionCookieName)
 	if err != nil {
-		// there's no Cookie, check Basic authentication
+		// There's no Cookie, check Basic authentication.
 		user, pass, ok := r.BasicAuth()
 		if ok {
-			u := Context.auth.UserFind(user, pass)
-			return u
+			return Context.auth.UserFind(user, pass)
 		}
+
 		return User{}
 	}
 
 	a.lock.Lock()
 	defer a.lock.Unlock()
+
 	s, ok := a.sessions[cookie.Value]
 	if !ok {
 		return User{}
 	}
+
 	for _, u := range a.users {
 		if u.Name == s.userName {
 			return u
 		}
 	}
+
 	return User{}
 }
 
diff --git a/internal/home/control.go b/internal/home/control.go
index 3443515a..616557a8 100644
--- a/internal/home/control.go
+++ b/internal/home/control.go
@@ -89,7 +89,7 @@ type profileJSON struct {
 
 func handleGetProfile(w http.ResponseWriter, r *http.Request) {
 	pj := profileJSON{}
-	u := Context.auth.GetCurrentUser(r)
+	u := Context.auth.getCurrentUser(r)
 	pj.Name = u.Name
 
 	data, err := json.Marshal(pj)