AdGuardHome/internal/home/auth_test.go
Eugene Burkov f603c21b55 Pull request: 2826 auth block
Merge in DNS/adguard-home from 2826-auth-block to master

Updates #2826.

Squashed commit of the following:

commit ae87360379270012869ad2bc4e528e07eb9af91e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Apr 27 15:35:49 2021 +0300

    home: fix mistake

commit dfa2ab05e9a8e70ac1bec36c4eb8ef3b02283b92
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Apr 27 15:31:53 2021 +0300

    home: imp code

commit ff4220d3c3d92ae604e92a0c5c274d9527350d99
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Apr 27 15:14:20 2021 +0300

    home: imp authratelimiter

commit c73a407d8652d64957e35046dbae7168aa45202f
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Apr 27 14:20:17 2021 +0300

    home: fix authratelimiter

commit 724db4380928b055f9995006cf421cbe9c5363e7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Apr 23 12:15:48 2021 +0300

    home: introduce auth blocker
2021-04-27 18:56:32 +03:00

262 lines
5.9 KiB
Go

package home
import (
"bytes"
"crypto/rand"
"encoding/hex"
"net"
"net/http"
"net/textproto"
"net/url"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMain is the package-level test entry point.  It delegates entirely to
// aghtest.DiscardLogOutput, handing it the *testing.M.
//
// NOTE(review): judging by its name, the helper silences log output for the
// duration of the run and is presumably also responsible for calling m.Run —
// confirm in the aghtest package.
func TestMain(m *testing.M) {
	aghtest.DiscardLogOutput(m)
}
// TestNewSessionToken checks newSessionToken twice: once with the default
// crypto/rand reader, where a token of sessionTokenSize bytes is expected, and
// once with rand.Reader replaced by an empty buffer, where an error and an
// empty token are expected.
func TestNewSessionToken(t *testing.T) {
	// Successful case.
	token, err := newSessionToken()
	require.NoError(t, err)
	assert.Len(t, token, sessionTokenSize)

	// Break the rand.Reader.  An empty bytes.Buffer returns io.EOF on every
	// non-empty read.  rand.Reader is a process-wide variable, so the
	// original reader is put back when the test finishes.
	prevReader := rand.Reader
	t.Cleanup(func() {
		rand.Reader = prevReader
	})
	rand.Reader = &bytes.Buffer{}

	// Unsuccessful case.
	//
	// NOTE(review): this relies on newSessionToken surfacing the reader's
	// error.  Since Go 1.24, crypto/rand.Read crashes the program instead of
	// returning an error when an overridden Reader fails — confirm how
	// newSessionToken reads its randomness before bumping the toolchain.
	token, err = newSessionToken()
	require.Error(t, err)
	assert.Empty(t, token)
}
func TestAuth(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
users := []User{{
Name: "name",
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
}}
a := InitAuth(fn, nil, 60, nil)
s := session{}
user := User{Name: "name"}
a.UserAdd(&user, "password")
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.RemoveSession("notfound")
sess, err := newSessionToken()
assert.Nil(t, err)
sessStr := hex.EncodeToString(sess)
now := time.Now().UTC().Unix()
// check expiration
s.expire = uint32(now)
a.addSession(sess, &s)
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
// add session with TTL = 2 sec
s = session{}
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.addSession(sess, &s)
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
a.Close()
// load saved session
a = InitAuth(fn, users, 60, nil)
// the session is still alive
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
// reset our expiration time because checkSession() has just updated it
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.storeSession(sess, &s)
a.Close()
u := a.UserFind("name", "password")
assert.NotEmpty(t, u.Name)
time.Sleep(3 * time.Second)
// load and remove expired sessions
a = InitAuth(fn, users, 60, nil)
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
a.Close()
}
// testResponseWriter is a minimal http.ResponseWriter for tests.  It keeps the
// headers set by a handler and the most recent status code passed to
// WriteHeader, and discards the response body.
type testResponseWriter struct {
	hdr        http.Header
	statusCode int
}

// Compile-time check that the type satisfies the interface.
var _ http.ResponseWriter = (*testResponseWriter)(nil)

// Header implements the http.ResponseWriter interface for
// *testResponseWriter.  The header map is created on first use, so a
// zero-value writer can be passed to a handler without panicking on the first
// Set or Add.
func (w *testResponseWriter) Header() http.Header {
	if w.hdr == nil {
		w.hdr = make(http.Header)
	}

	return w.hdr
}

// Write implements the http.ResponseWriter interface for
// *testResponseWriter.  The data is discarded, but its full length is
// reported as written: io.Writer requires a non-nil error whenever fewer than
// len(b) bytes are written.
func (w *testResponseWriter) Write(b []byte) (int, error) {
	return len(b), nil
}

// WriteHeader implements the http.ResponseWriter interface for
// *testResponseWriter.  It records statusCode, overwriting any earlier value.
func (w *testResponseWriter) WriteHeader(statusCode int) {
	w.statusCode = statusCode
}
// TestAuthHTTP drives a handler wrapped in optionalAuth through a sequence of
// GET requests: without credentials, to the login page, with a session
// cookie, with basic auth, and to the login page with a valid and with an
// invalid cookie.  For each request it checks whether the wrapped handler was
// invoked or a redirect was issued.
//
// The test replaces the package-level Context.auth.
func TestAuthHTTP(t *testing.T) {
	dir := t.TempDir()
	fn := filepath.Join(dir, "sessions.db")

	users := []User{
		{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
	}

	auth := InitAuth(fn, users, 60, nil)
	Context.auth = auth
	// Close through Cleanup so that it also runs when a require below stops
	// the test early.
	t.Cleanup(func() {
		auth.Close()
	})

	handlerCalled := false
	handler := func(_ http.ResponseWriter, _ *http.Request) {
		handlerCalled = true
	}
	handler2 := optionalAuth(handler)

	w := testResponseWriter{}
	w.hdr = make(http.Header)

	r := http.Request{}
	r.Header = make(http.Header)
	r.Method = http.MethodGet

	// get / - we're redirected to login page
	r.URL = &url.URL{Path: "/"}
	handlerCalled = false
	handler2(&w, &r)
	assert.Equal(t, http.StatusFound, w.statusCode)
	assert.NotEmpty(t, w.hdr.Get("Location"))
	assert.False(t, handlerCalled)

	// go to login page
	loginURL := w.hdr.Get("Location")
	r.URL = &url.URL{Path: loginURL}
	handlerCalled = false
	handler2(&w, &r)
	assert.True(t, handlerCalled)

	// perform login; the cookie is required by the steps below, so stop on
	// failure instead of continuing with an empty value
	cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, "")
	require.NoError(t, err)
	assert.NotEmpty(t, cookie)

	// get /
	handler2 = optionalAuth(handler)
	w.hdr = make(http.Header)
	r.Header.Set("Cookie", cookie)
	r.URL = &url.URL{Path: "/"}
	handlerCalled = false
	handler2(&w, &r)
	assert.True(t, handlerCalled)
	r.Header.Del("Cookie")

	// get / with basic auth
	handler2 = optionalAuth(handler)
	w.hdr = make(http.Header)
	r.URL = &url.URL{Path: "/"}
	r.SetBasicAuth("name", "password")
	handlerCalled = false
	handler2(&w, &r)
	assert.True(t, handlerCalled)
	r.Header.Del("Authorization")

	// get login page with a valid cookie - we're redirected to /
	handler2 = optionalAuth(handler)
	w.hdr = make(http.Header)
	r.Header.Set("Cookie", cookie)
	r.URL = &url.URL{Path: loginURL}
	handlerCalled = false
	handler2(&w, &r)
	assert.NotEmpty(t, w.hdr.Get("Location"))
	assert.False(t, handlerCalled)
	r.Header.Del("Cookie")

	// get login page with an invalid cookie
	handler2 = optionalAuth(handler)
	w.hdr = make(http.Header)
	r.Header.Set("Cookie", "bad")
	r.URL = &url.URL{Path: loginURL}
	handlerCalled = false
	handler2(&w, &r)
	assert.True(t, handlerCalled)
	r.Header.Del("Cookie")
}
// TestRealIP is a table-driven test for realIP.  Each case builds a request
// from a header set and a remote address, then compares the returned IP and
// error with the expected ones.
func TestRealIP(t *testing.T) {
	const clientAddr = "1.2.3.4:5678"

	// Canonical forms of the proxy header names used as map keys below.
	realIPKey := textproto.CanonicalMIMEHeaderKey("X-Real-IP")
	forwardedKey := textproto.CanonicalMIMEHeaderKey("X-Forwarded-For")

	type realIPCase struct {
		name       string
		header     http.Header
		remoteAddr string
		wantErrMsg string
		wantIP     net.IP
	}

	// Fields left out of a case keep their zero value: nil header, empty
	// error message, nil IP.
	cases := []realIPCase{{
		name:       "success_no_proxy",
		remoteAddr: clientAddr,
		wantIP:     net.IPv4(1, 2, 3, 4),
	}, {
		name:       "success_proxy",
		header:     http.Header{realIPKey: []string{"1.2.3.5"}},
		remoteAddr: clientAddr,
		wantIP:     net.IPv4(1, 2, 3, 5),
	}, {
		// The first address of the comma-separated list is expected.
		name:       "success_proxy_multiple",
		header:     http.Header{forwardedKey: []string{"1.2.3.6, 1.2.3.5"}},
		remoteAddr: clientAddr,
		wantIP:     net.IPv4(1, 2, 3, 6),
	}, {
		name:       "error_no_proxy",
		remoteAddr: "1:::2",
		wantErrMsg: `getting ip from client addr: address 1:::2: ` +
			`too many colons in address`,
	}}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := &http.Request{
				Header:     tc.header,
				RemoteAddr: tc.remoteAddr,
			}

			gotIP, err := realIP(req)
			assert.Equal(t, tc.wantIP, gotIP)

			if tc.wantErrMsg != "" {
				require.Error(t, err)
				assert.Equal(t, tc.wantErrMsg, err.Error())

				return
			}

			assert.NoError(t, err)
		})
	}
}