mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-21 20:45:33 +03:00
Pull request: 2471 defer
Merge in DNS/adguard-home from 2471-defer to master Updates #2471. * commit '1c754788f9139ed9741cf01c6d94bcced6909b8c': home: improve getCurrentUser home: improve checkSession Use a couple of defer in internal/home/auth.go
This commit is contained in:
commit
e829e7a064
3 changed files with 46 additions and 34 deletions
|
@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool {
|
|||
// Auth - global object
|
||||
type Auth struct {
|
||||
db *bbolt.DB
|
||||
sessions map[string]*session // session name -> session data
|
||||
lock sync.Mutex
|
||||
sessions map[string]*session
|
||||
users []User
|
||||
sessionTTL uint32 // in seconds
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
}
|
||||
|
||||
// User object
|
||||
|
@ -223,24 +223,35 @@ func (a *Auth) removeSession(sess []byte) {
|
|||
log.Debug("Auth: removed session from DB")
|
||||
}
|
||||
|
||||
// CheckSession - check if session is valid
|
||||
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
|
||||
func (a *Auth) CheckSession(sess string) int {
|
||||
// checkSessionResult is the result of checking a session.
|
||||
type checkSessionResult int
|
||||
|
||||
// checkSessionResult constants.
|
||||
const (
|
||||
checkSessionOK checkSessionResult = 0
|
||||
checkSessionNotFound checkSessionResult = -1
|
||||
checkSessionExpired checkSessionResult = 1
|
||||
)
|
||||
|
||||
// checkSession checks if the session is valid.
|
||||
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
update := false
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
s, ok := a.sessions[sess]
|
||||
if !ok {
|
||||
a.lock.Unlock()
|
||||
return -1
|
||||
return checkSessionNotFound
|
||||
}
|
||||
|
||||
if s.expire <= now {
|
||||
delete(a.sessions, sess)
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.removeSession(key)
|
||||
a.lock.Unlock()
|
||||
return 1
|
||||
|
||||
return checkSessionExpired
|
||||
}
|
||||
|
||||
newExpire := now + a.sessionTTL
|
||||
|
@ -250,8 +261,6 @@ func (a *Auth) CheckSession(sess string) int {
|
|||
s.expire = newExpire
|
||||
}
|
||||
|
||||
a.lock.Unlock()
|
||||
|
||||
if update {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
if a.storeSession(key, s) {
|
||||
|
@ -259,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int {
|
|||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
return checkSessionOK
|
||||
}
|
||||
|
||||
// RemoveSession - remove session
|
||||
|
@ -392,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
|
|||
ok = true
|
||||
|
||||
} else if err == nil {
|
||||
r := Context.auth.CheckSession(cookie.Value)
|
||||
if r == 0 {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
if r == checkSessionOK {
|
||||
ok = true
|
||||
} else if r < 0 {
|
||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||
|
@ -434,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
|
|||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
r := Context.auth.CheckSession(cookie.Value)
|
||||
if r == 0 {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
if r == checkSessionOK {
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
|
||||
return
|
||||
} else if r < 0 {
|
||||
} else if r == checkSessionNotFound {
|
||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
}
|
||||
|
@ -503,32 +513,34 @@ func (a *Auth) UserFind(login, password string) User {
|
|||
return User{}
|
||||
}
|
||||
|
||||
// GetCurrentUser - get the current user
|
||||
func (a *Auth) GetCurrentUser(r *http.Request) User {
|
||||
// getCurrentUser returns the current user. It returns an empty User if the
|
||||
// user is not found.
|
||||
func (a *Auth) getCurrentUser(r *http.Request) User {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// there's no Cookie, check Basic authentication
|
||||
// There's no Cookie, check Basic authentication.
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
u := Context.auth.UserFind(user, pass)
|
||||
return u
|
||||
return Context.auth.UserFind(user, pass)
|
||||
}
|
||||
|
||||
return User{}
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
s, ok := a.sessions[cookie.Value]
|
||||
if !ok {
|
||||
a.lock.Unlock()
|
||||
return User{}
|
||||
}
|
||||
|
||||
for _, u := range a.users {
|
||||
if u.Name == s.userName {
|
||||
a.lock.Unlock()
|
||||
return u
|
||||
}
|
||||
}
|
||||
a.lock.Unlock()
|
||||
|
||||
return User{}
|
||||
}
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ func TestAuth(t *testing.T) {
|
|||
user := User{Name: "name"}
|
||||
a.UserAdd(&user, "password")
|
||||
|
||||
assert.True(t, a.CheckSession("notfound") == -1)
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||
a.RemoveSession("notfound")
|
||||
|
||||
sess, err := getSession(&users[0])
|
||||
|
@ -49,13 +49,13 @@ func TestAuth(t *testing.T) {
|
|||
// check expiration
|
||||
s.expire = uint32(now)
|
||||
a.addSession(sess, &s)
|
||||
assert.True(t, a.CheckSession(sessStr) == 1)
|
||||
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
||||
|
||||
// add session with TTL = 2 sec
|
||||
s = session{}
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.addSession(sess, &s)
|
||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
|
||||
|
@ -63,8 +63,8 @@ func TestAuth(t *testing.T) {
|
|||
a = InitAuth(fn, users, 60)
|
||||
|
||||
// the session is still alive
|
||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
||||
// reset our expiration time because CheckSession() has just updated it
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
// reset our expiration time because checkSession() has just updated it
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.storeSession(sess, &s)
|
||||
a.Close()
|
||||
|
@ -76,7 +76,7 @@ func TestAuth(t *testing.T) {
|
|||
|
||||
// load and remove expired sessions
|
||||
a = InitAuth(fn, users, 60)
|
||||
assert.True(t, a.CheckSession(sessStr) == -1)
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
os.Remove(fn)
|
||||
|
@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) {
|
|||
Context.auth = InitAuth(fn, users, 60)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
||||
handlerCalled = true
|
||||
}
|
||||
handler2 := optionalAuth(handler)
|
||||
|
|
|
@ -89,7 +89,7 @@ type profileJSON struct {
|
|||
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
pj := profileJSON{}
|
||||
u := Context.auth.GetCurrentUser(r)
|
||||
u := Context.auth.getCurrentUser(r)
|
||||
pj.Name = u.Name
|
||||
|
||||
data, err := json.Marshal(pj)
|
||||
|
|
Loading…
Reference in a new issue