2019-08-29 12:34:07 +03:00
package home
import (
2020-11-20 17:32:41 +03:00
"crypto/rand"
2019-08-29 12:34:07 +03:00
"encoding/binary"
"encoding/hex"
"fmt"
"net/http"
"sync"
"time"
2024-10-02 21:00:15 +03:00
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
2022-09-29 19:04:26 +03:00
"github.com/AdguardTeam/golibs/errors"
2019-08-29 12:34:07 +03:00
"github.com/AdguardTeam/golibs/log"
2024-03-20 19:25:59 +03:00
"github.com/AdguardTeam/golibs/netutil"
2020-04-05 18:21:26 +03:00
"go.etcd.io/bbolt"
2019-08-29 12:34:07 +03:00
"golang.org/x/crypto/bcrypt"
)
2021-04-06 14:31:20 +03:00
// sessionTokenSize is the size of a session token in bytes, that is 128 bits
// of data read from crypto/rand by newSessionToken.
const sessionTokenSize = 128 / 8
2019-08-29 12:34:07 +03:00
2019-10-21 17:44:07 +03:00
// session is a single authenticated Web UI session.
type session struct {
	// userName is the name of the user that owns the session.
	userName string

	// expire is the moment the session expires, as a Unix timestamp in
	// seconds.
	expire uint32
}
func ( s * session ) serialize ( ) [ ] byte {
2020-11-06 12:15:08 +03:00
const (
expireLen = 4
nameLen = 2
)
data := make ( [ ] byte , expireLen + nameLen + len ( s . userName ) )
2019-10-21 17:44:07 +03:00
binary . BigEndian . PutUint32 ( data [ 0 : 4 ] , s . expire )
binary . BigEndian . PutUint16 ( data [ 4 : 6 ] , uint16 ( len ( s . userName ) ) )
copy ( data [ 6 : ] , [ ] byte ( s . userName ) )
return data
}
// deserialize decodes data, as produced by serialize, into s.  data must hold
// a 4-byte big-endian expiration time, a 2-byte big-endian name length, and at
// least that many bytes of the user name.  Any bytes after the name are
// ignored.  It returns false, leaving s unchanged, if data is malformed.
func (s *session) deserialize(data []byte) bool {
	const (
		expireLen = 4
		nameLen   = 2
		hdrLen    = expireLen + nameLen
	)

	if len(data) < hdrLen {
		return false
	}

	expire := binary.BigEndian.Uint32(data[:expireLen])
	n := int(binary.BigEndian.Uint16(data[expireLen:hdrLen]))
	name := data[hdrLen:]
	if len(name) < n {
		return false
	}

	s.expire = expire
	// Honor the length prefix instead of taking the rest of the buffer, so
	// that trailing bytes do not end up in the user name.
	s.userName = string(name[:n])

	return true
}
2024-03-20 19:25:59 +03:00
// Auth is the global authentication object.
2019-08-29 12:34:07 +03:00
type Auth struct {
2024-03-20 19:25:59 +03:00
trustedProxies netutil . SubnetSet
db * bbolt . DB
rateLimiter * authRateLimiter
sessions map [ string ] * session
users [ ] webUser
lock sync . Mutex
sessionTTL uint32
2019-08-29 12:34:07 +03:00
}
2022-09-29 19:04:26 +03:00
// webUser represents a user of the Web UI.
//
// TODO(s.chzhen): Improve naming.
type webUser struct {
	// Name is the login of the user.
	Name string `yaml:"name"`

	// PasswordHash is the bcrypt hash of the user's password.
	PasswordHash string `yaml:"password"`
}
2024-03-20 19:25:59 +03:00
// InitAuth initializes the global authentication object.
func InitAuth (
dbFilename string ,
users [ ] webUser ,
sessionTTL uint32 ,
rateLimiter * authRateLimiter ,
trustedProxies netutil . SubnetSet ,
) ( a * Auth ) {
2020-04-15 15:17:57 +03:00
log . Info ( "Initializing auth module: %s" , dbFilename )
2024-03-20 19:25:59 +03:00
a = & Auth {
sessionTTL : sessionTTL ,
rateLimiter : rateLimiter ,
sessions : make ( map [ string ] * session ) ,
users : users ,
trustedProxies : trustedProxies ,
2021-04-27 18:56:32 +03:00
}
2019-08-29 12:34:07 +03:00
var err error
2024-10-02 21:00:15 +03:00
a . db , err = bbolt . Open ( dbFilename , aghos . DefaultPermFile , nil )
2019-08-29 12:34:07 +03:00
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: open DB: %s: %s" , dbFilename , err )
2020-07-02 16:52:29 +03:00
if err . Error ( ) == "invalid argument" {
2021-04-08 16:44:01 +03:00
log . Error ( "AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations" )
2020-07-02 16:52:29 +03:00
}
2021-04-06 14:31:20 +03:00
2019-08-29 12:34:07 +03:00
return nil
}
a . loadSessions ( )
2021-04-06 14:31:20 +03:00
log . Info ( "auth: initialized. users:%d sessions:%d" , len ( a . users ) , len ( a . sessions ) )
2021-04-27 18:56:32 +03:00
return a
2019-08-29 12:34:07 +03:00
}
2024-03-20 19:25:59 +03:00
// Close closes the authentication database.
2019-08-29 12:34:07 +03:00
func ( a * Auth ) Close ( ) {
_ = a . db . Close ( )
}
2019-10-21 17:44:07 +03:00
// bucketName returns the name of the bbolt bucket that stores sessions.  Each
// call returns a fresh slice, so callers cannot affect one another by
// modifying it.
func bucketName() []byte {
	const name = "sessions-2"

	return []byte(name)
}
2024-03-20 19:25:59 +03:00
// loadSessions loads sessions from the database file and removes expired
// sessions.
2019-08-29 12:34:07 +03:00
func ( a * Auth ) loadSessions ( ) {
tx , err := a . db . Begin ( true )
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Begin: %s" , err )
2019-08-29 12:34:07 +03:00
return
}
defer func ( ) {
_ = tx . Rollback ( )
} ( )
2019-10-21 17:44:07 +03:00
bkt := tx . Bucket ( bucketName ( ) )
2019-08-29 12:34:07 +03:00
if bkt == nil {
return
}
removed := 0
2019-10-21 17:44:07 +03:00
if tx . Bucket ( [ ] byte ( "sessions" ) ) != nil {
_ = tx . DeleteBucket ( [ ] byte ( "sessions" ) )
removed = 1
}
2019-08-29 12:34:07 +03:00
now := uint32 ( time . Now ( ) . UTC ( ) . Unix ( ) )
forEach := func ( k , v [ ] byte ) error {
2019-10-21 17:44:07 +03:00
s := session { }
if ! s . deserialize ( v ) || s . expire <= now {
2019-08-29 12:34:07 +03:00
err = bkt . Delete ( k )
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Delete: %s" , err )
2019-08-29 12:34:07 +03:00
} else {
removed ++
}
2021-04-06 14:31:20 +03:00
2019-08-29 12:34:07 +03:00
return nil
}
2019-10-21 17:44:07 +03:00
a . sessions [ hex . EncodeToString ( k ) ] = & s
2019-08-29 12:34:07 +03:00
return nil
}
_ = bkt . ForEach ( forEach )
if removed != 0 {
2019-09-18 13:17:35 +03:00
err = tx . Commit ( )
if err != nil {
log . Error ( "bolt.Commit(): %s" , err )
}
2019-08-29 12:34:07 +03:00
}
2021-04-06 14:31:20 +03:00
log . Debug ( "auth: loaded %d sessions from DB (removed %d expired)" , len ( a . sessions ) , removed )
2019-08-29 12:34:07 +03:00
}
2024-03-20 19:25:59 +03:00
// addSession adds a new session to the list of sessions and saves it in the
// database file.
2019-10-21 17:44:07 +03:00
func ( a * Auth ) addSession ( data [ ] byte , s * session ) {
2019-11-12 14:24:27 +03:00
name := hex . EncodeToString ( data )
2019-08-29 12:34:07 +03:00
a . lock . Lock ( )
2019-11-12 14:24:27 +03:00
a . sessions [ name ] = s
2019-08-29 12:34:07 +03:00
a . lock . Unlock ( )
2019-11-12 14:24:27 +03:00
if a . storeSession ( data , s ) {
2021-04-06 14:31:20 +03:00
log . Debug ( "auth: created session %s: expire=%d" , name , s . expire )
2019-11-12 14:24:27 +03:00
}
2019-10-21 17:44:07 +03:00
}
2019-08-29 12:34:07 +03:00
2024-03-20 19:25:59 +03:00
// storeSession saves a session in the database file.
2019-11-12 14:24:27 +03:00
func ( a * Auth ) storeSession ( data [ ] byte , s * session ) bool {
2019-08-29 12:34:07 +03:00
tx , err := a . db . Begin ( true )
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Begin: %s" , err )
2019-11-12 14:24:27 +03:00
return false
2019-08-29 12:34:07 +03:00
}
defer func ( ) {
_ = tx . Rollback ( )
} ( )
2019-10-21 17:44:07 +03:00
bkt , err := tx . CreateBucketIfNotExists ( bucketName ( ) )
2019-08-29 12:34:07 +03:00
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.CreateBucketIfNotExists: %s" , err )
2019-11-12 14:24:27 +03:00
return false
2019-08-29 12:34:07 +03:00
}
2021-04-06 14:31:20 +03:00
2019-10-21 17:44:07 +03:00
err = bkt . Put ( data , s . serialize ( ) )
2019-08-29 12:34:07 +03:00
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Put: %s" , err )
2019-11-12 14:24:27 +03:00
return false
2019-08-29 12:34:07 +03:00
}
err = tx . Commit ( )
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Commit: %s" , err )
2019-11-12 14:24:27 +03:00
return false
2019-08-29 12:34:07 +03:00
}
2021-04-06 14:31:20 +03:00
2019-11-12 14:24:27 +03:00
return true
2019-08-29 12:34:07 +03:00
}
2023-11-03 16:07:15 +03:00
// removeSessionFromFile removes a stored session from the DB file on disk.
func ( a * Auth ) removeSessionFromFile ( sess [ ] byte ) {
2019-08-29 12:34:07 +03:00
tx , err := a . db . Begin ( true )
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Begin: %s" , err )
2019-08-29 12:34:07 +03:00
return
}
2021-04-06 14:31:20 +03:00
2019-08-29 12:34:07 +03:00
defer func ( ) {
_ = tx . Rollback ( )
} ( )
2019-10-21 17:44:07 +03:00
bkt := tx . Bucket ( bucketName ( ) )
2019-08-29 12:34:07 +03:00
if bkt == nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Bucket" )
2019-08-29 12:34:07 +03:00
return
}
2021-04-06 14:31:20 +03:00
2019-08-29 12:34:07 +03:00
err = bkt . Delete ( sess )
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Put: %s" , err )
2019-08-29 12:34:07 +03:00
return
}
err = tx . Commit ( )
if err != nil {
2021-04-06 14:31:20 +03:00
log . Error ( "auth: bbolt.Commit: %s" , err )
2019-08-29 12:34:07 +03:00
return
}
2021-04-06 14:31:20 +03:00
log . Debug ( "auth: removed session from DB" )
2019-08-29 12:34:07 +03:00
}
2020-12-22 21:05:12 +03:00
// checkSessionResult is the outcome of checkSession.
type checkSessionResult int

// Possible checkSessionResult values.
const (
	// checkSessionNotFound means that there is no session with such a token.
	checkSessionNotFound checkSessionResult = iota - 1

	// checkSessionOK means that the session exists and has not expired.
	checkSessionOK

	// checkSessionExpired means that the session had expired and was removed.
	checkSessionExpired
)
// checkSession checks if the session is valid.
func ( a * Auth ) checkSession ( sess string ) ( res checkSessionResult ) {
2019-08-29 12:34:07 +03:00
now := uint32 ( time . Now ( ) . UTC ( ) . Unix ( ) )
update := false
a . lock . Lock ( )
2020-12-21 21:39:39 +03:00
defer a . lock . Unlock ( )
2020-12-22 21:05:12 +03:00
2019-10-21 17:44:07 +03:00
s , ok := a . sessions [ sess ]
2019-08-29 12:34:07 +03:00
if ! ok {
2020-12-22 21:05:12 +03:00
return checkSessionNotFound
2019-08-29 12:34:07 +03:00
}
2020-12-22 21:05:12 +03:00
2019-10-21 17:44:07 +03:00
if s . expire <= now {
2019-08-29 12:34:07 +03:00
delete ( a . sessions , sess )
key , _ := hex . DecodeString ( sess )
2023-11-03 16:07:15 +03:00
a . removeSessionFromFile ( key )
2020-12-22 21:05:12 +03:00
return checkSessionExpired
2019-08-29 12:34:07 +03:00
}
2019-11-12 14:23:00 +03:00
newExpire := now + a . sessionTTL
2019-10-21 17:44:07 +03:00
if s . expire / ( 24 * 60 * 60 ) != newExpire / ( 24 * 60 * 60 ) {
2019-08-29 12:34:07 +03:00
// update expiration time once a day
update = true
2019-10-21 17:44:07 +03:00
s . expire = newExpire
2019-08-29 12:34:07 +03:00
}
if update {
key , _ := hex . DecodeString ( sess )
2019-11-12 14:24:27 +03:00
if a . storeSession ( key , s ) {
2021-04-06 14:31:20 +03:00
log . Debug ( "auth: updated session %s: expire=%d" , sess , s . expire )
2019-11-12 14:24:27 +03:00
}
2019-08-29 12:34:07 +03:00
}
2020-12-22 21:05:12 +03:00
return checkSessionOK
2019-08-29 12:34:07 +03:00
}
2023-11-03 16:07:15 +03:00
// removeSession removes the session from the active sessions and the disk.
func ( a * Auth ) removeSession ( sess string ) {
2019-08-29 12:34:07 +03:00
key , _ := hex . DecodeString ( sess )
a . lock . Lock ( )
delete ( a . sessions , sess )
a . lock . Unlock ( )
2023-11-03 16:07:15 +03:00
a . removeSessionFromFile ( key )
2019-08-29 12:34:07 +03:00
}
2023-11-03 16:07:15 +03:00
// addUser adds a new user with the given password.
func ( a * Auth ) addUser ( u * webUser , password string ) ( err error ) {
2019-08-29 12:34:07 +03:00
if len ( password ) == 0 {
2023-10-05 16:20:28 +03:00
return errors . Error ( "empty password" )
2019-08-29 12:34:07 +03:00
}
hash , err := bcrypt . GenerateFromPassword ( [ ] byte ( password ) , bcrypt . DefaultCost )
if err != nil {
2023-10-05 16:20:28 +03:00
return fmt . Errorf ( "generating hash: %w" , err )
2019-08-29 12:34:07 +03:00
}
2023-10-05 16:20:28 +03:00
2019-08-29 12:34:07 +03:00
u . PasswordHash = string ( hash )
a . lock . Lock ( )
2023-10-05 16:20:28 +03:00
defer a . lock . Unlock ( )
2019-08-29 12:34:07 +03:00
a . users = append ( a . users , * u )
2023-10-05 16:20:28 +03:00
log . Debug ( "auth: added user with login %q" , u . Name )
return nil
2019-08-29 12:34:07 +03:00
}
2022-09-29 19:04:26 +03:00
// findUser returns a user if there is one.
func ( a * Auth ) findUser ( login , password string ) ( u webUser , ok bool ) {
2019-08-29 12:34:07 +03:00
a . lock . Lock ( )
defer a . lock . Unlock ( )
2022-09-29 19:04:26 +03:00
for _ , u = range a . users {
2019-08-29 12:34:07 +03:00
if u . Name == login &&
bcrypt . CompareHashAndPassword ( [ ] byte ( u . PasswordHash ) , [ ] byte ( password ) ) == nil {
2022-09-29 19:04:26 +03:00
return u , true
2019-08-29 12:34:07 +03:00
}
}
2022-09-29 19:04:26 +03:00
return webUser { } , false
2019-08-29 12:34:07 +03:00
}
2020-12-22 21:09:53 +03:00
// getCurrentUser returns the current user. It returns an empty User if the
// user is not found.
2022-09-29 19:04:26 +03:00
func ( a * Auth ) getCurrentUser ( r * http . Request ) ( u webUser ) {
2019-11-25 15:45:50 +03:00
cookie , err := r . Cookie ( sessionCookieName )
2019-10-21 17:44:07 +03:00
if err != nil {
2020-12-22 21:09:53 +03:00
// There's no Cookie, check Basic authentication.
2019-10-21 17:44:07 +03:00
user , pass , ok := r . BasicAuth ( )
if ok {
2022-09-29 19:04:26 +03:00
u , _ = Context . auth . findUser ( user , pass )
return u
2019-10-21 17:44:07 +03:00
}
2020-12-22 21:09:53 +03:00
2022-09-29 19:04:26 +03:00
return webUser { }
2019-10-21 17:44:07 +03:00
}
a . lock . Lock ( )
2020-12-21 21:39:39 +03:00
defer a . lock . Unlock ( )
2020-12-22 21:09:53 +03:00
2019-10-21 17:44:07 +03:00
s , ok := a . sessions [ cookie . Value ]
if ! ok {
2022-09-29 19:04:26 +03:00
return webUser { }
2019-10-21 17:44:07 +03:00
}
2020-12-22 21:09:53 +03:00
2022-09-29 19:04:26 +03:00
for _ , u = range a . users {
2019-10-21 17:44:07 +03:00
if u . Name == s . userName {
return u
}
}
2020-12-22 21:09:53 +03:00
2022-09-29 19:04:26 +03:00
return webUser { }
2019-10-21 17:44:07 +03:00
}
2023-11-03 16:07:15 +03:00
// usersList returns a copy of a users list.
func ( a * Auth ) usersList ( ) ( users [ ] webUser ) {
2019-08-29 12:34:07 +03:00
a . lock . Lock ( )
2023-11-03 16:07:15 +03:00
defer a . lock . Unlock ( )
users = make ( [ ] webUser , len ( a . users ) )
copy ( users , a . users )
2019-08-29 12:34:07 +03:00
return users
}
2023-11-03 16:07:15 +03:00
// authRequired returns true if a authentication is required.
func ( a * Auth ) authRequired ( ) bool {
2020-07-03 20:34:08 +03:00
if GLMode {
return true
}
2019-08-29 12:34:07 +03:00
a . lock . Lock ( )
2023-11-03 16:07:15 +03:00
defer a . lock . Unlock ( )
return len ( a . users ) != 0
}
// newSessionToken returns cryptographically secure randomly generated slice of
// bytes of sessionTokenSize length.
//
// TODO(e.burkov): Think about using byte array instead of byte slice.
func newSessionToken ( ) ( data [ ] byte , err error ) {
randData := make ( [ ] byte , sessionTokenSize )
_ , err = rand . Read ( randData )
if err != nil {
return nil , err
}
return randData , nil
2019-08-29 12:34:07 +03:00
}