mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-05-10 09:40:11 +03:00
home: improve checkSession
This commit is contained in:
parent
7eb3e00b35
commit
925c5df801
2 changed files with 34 additions and 21 deletions
internal/home
|
@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool {
|
|||
// Auth - global object
|
||||
type Auth struct {
|
||||
db *bbolt.DB
|
||||
sessions map[string]*session // session name -> session data
|
||||
lock sync.Mutex
|
||||
sessions map[string]*session
|
||||
users []User
|
||||
sessionTTL uint32 // in seconds
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
}
|
||||
|
||||
// User object
|
||||
|
@ -223,23 +223,35 @@ func (a *Auth) removeSession(sess []byte) {
|
|||
log.Debug("Auth: removed session from DB")
|
||||
}
|
||||
|
||||
// CheckSession - check if session is valid
|
||||
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
|
||||
func (a *Auth) CheckSession(sess string) int {
|
||||
// checkSessionResult is the result of checking a session.
|
||||
type checkSessionResult int
|
||||
|
||||
// checkSessionResult constants.
|
||||
const (
|
||||
checkSessionOK checkSessionResult = 0
|
||||
checkSessionNotFound checkSessionResult = -1
|
||||
checkSessionExpired checkSessionResult = 1
|
||||
)
|
||||
|
||||
// checkSession checks if the session is valid.
|
||||
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
update := false
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
s, ok := a.sessions[sess]
|
||||
if !ok {
|
||||
return -1
|
||||
return checkSessionNotFound
|
||||
}
|
||||
|
||||
if s.expire <= now {
|
||||
delete(a.sessions, sess)
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.removeSession(key)
|
||||
return 1
|
||||
|
||||
return checkSessionExpired
|
||||
}
|
||||
|
||||
newExpire := now + a.sessionTTL
|
||||
|
@ -256,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int {
|
|||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
return checkSessionOK
|
||||
}
|
||||
|
||||
// RemoveSession - remove session
|
||||
|
@ -389,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
|
|||
ok = true
|
||||
|
||||
} else if err == nil {
|
||||
r := Context.auth.CheckSession(cookie.Value)
|
||||
if r == 0 {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
if r == checkSessionOK {
|
||||
ok = true
|
||||
} else if r < 0 {
|
||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||
|
@ -431,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
|
|||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
r := Context.auth.CheckSession(cookie.Value)
|
||||
if r == 0 {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
if r == checkSessionOK {
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
|
||||
return
|
||||
} else if r < 0 {
|
||||
} else if r == checkSessionNotFound {
|
||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue