package home import ( "bytes" "cmp" "encoding/json" "io" "net/http" "net/http/httptest" "net/netip" "net/url" "slices" "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // testTimeout is the common timeout for tests and contexts. const testTimeout = 1 * time.Second const ( testClientIP1 = "1.1.1.1" testClientIP2 = "2.2.2.2" ) // testBlockedClientChecker is a mock implementation of the // [BlockedClientChecker] interface. type testBlockedClientChecker struct { onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string) } // type check var _ BlockedClientChecker = (*testBlockedClientChecker)(nil) // IsBlockedClient implements the [BlockedClientChecker] interface for // *testBlockedClientChecker. func (c *testBlockedClientChecker) IsBlockedClient( ip netip.Addr, clientID string, ) (blocked bool, rule string) { return c.onIsBlockedClient(ip, clientID) } // newPersistentClient is a helper function that returns a persistent client // with the specified name and newly generated UID. func newPersistentClient(name string) (c *client.Persistent) { return &client.Persistent{ Name: name, UID: client.MustNewUID(), BlockedServices: &filtering.BlockedServices{ Schedule: schedule.EmptyWeekly(), }, } } // newPersistentClientWithIDs is a helper function that returns a persistent // client with the specified name and ids. func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *client.Persistent) { tb.Helper() c = newPersistentClient(name) err := c.SetIDs(ids) require.NoError(tb, err) return c } // assertClients is a helper function that compares lists of persistent clients. func assertClients(tb testing.TB, want, got []*client.Persistent) { tb.Helper() require.Len(tb, got, len(want)) sortFunc := func(a, b *client.Persistent) (n int) { return cmp.Compare(a.Name, b.Name) } slices.SortFunc(want, sortFunc) slices.SortFunc(got, sortFunc) slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) { assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name) return 0 }) } // assertPersistentClients is a helper function that uses HTTP API to check // whether want persistent clients are the same as the persistent clients stored // in the clients container. func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) { tb.Helper() rw := httptest.NewRecorder() clients.handleGetClients(rw, &http.Request{}) body, err := io.ReadAll(rw.Body) require.NoError(tb, err) clientList := &clientListJSON{} err = json.Unmarshal(body, clientList) require.NoError(tb, err) var got []*client.Persistent ctx := testutil.ContextWithTimeout(tb, testTimeout) for _, cj := range clientList.Clients { var c *client.Persistent c, err = clients.jsonToClient(ctx, *cj, nil) require.NoError(tb, err) got = append(got, c) } assertClients(tb, want, got) } // assertPersistentClientsData is a helper function that checks whether want // persistent clients are the same as the persistent clients stored in data. func assertPersistentClientsData( tb testing.TB, clients *clientsContainer, data []map[string]*clientJSON, want []*client.Persistent, ) { tb.Helper() var got []*client.Persistent ctx := testutil.ContextWithTimeout(tb, testTimeout) for _, cm := range data { for _, cj := range cm { var c *client.Persistent c, err := clients.jsonToClient(ctx, *cj, nil) require.NoError(tb, err) got = append(got, c) } } assertClients(tb, want, got) } func TestClientsContainer_HandleAddClient(t *testing.T) { clients := newClientsContainer(t) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) clientEmptyID := newPersistentClient("empty_client_id") clientEmptyID.ClientIDs = []string{""} testCases := []struct { name string client *client.Persistent wantCode int wantClient []*client.Persistent }{{ name: "add_one", client: clientOne, wantCode: http.StatusOK, wantClient: []*client.Persistent{clientOne}, }, { name: "add_two", client: clientTwo, wantCode: http.StatusOK, wantClient: []*client.Persistent{clientOne, clientTwo}, }, { name: "duplicate_client", client: clientTwo, wantCode: http.StatusBadRequest, wantClient: []*client.Persistent{clientOne, clientTwo}, }, { name: "empty_client_id", client: clientEmptyID, wantCode: http.StatusBadRequest, wantClient: []*client.Persistent{clientOne, clientTwo}, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { cj := clientToJSON(tc.client) body, err := json.Marshal(cj) require.NoError(t, err) r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) require.NoError(t, err) rw := httptest.NewRecorder() clients.handleAddClient(rw, r) require.NoError(t, err) require.Equal(t, tc.wantCode, rw.Code) assertPersistentClients(t, clients, tc.wantClient) }) } } func TestClientsContainer_HandleDelClient(t *testing.T) { clients := newClientsContainer(t) ctx := testutil.ContextWithTimeout(t, testTimeout) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) err := clients.storage.Add(ctx, clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) err = clients.storage.Add(ctx, clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) testCases := []struct { name string client *client.Persistent wantCode int wantClient []*client.Persistent }{{ name: "remove_one", client: clientOne, wantCode: http.StatusOK, wantClient: []*client.Persistent{clientTwo}, }, { name: "duplicate_client", client: clientOne, wantCode: http.StatusBadRequest, wantClient: []*client.Persistent{clientTwo}, }, { name: "empty_client_name", client: newPersistentClient(""), wantCode: http.StatusBadRequest, wantClient: []*client.Persistent{clientTwo}, }, { name: "remove_two", client: clientTwo, wantCode: http.StatusOK, wantClient: []*client.Persistent{}, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { cj := clientToJSON(tc.client) var body []byte body, err = json.Marshal(cj) require.NoError(t, err) var r *http.Request r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) require.NoError(t, err) rw := httptest.NewRecorder() clients.handleDelClient(rw, r) require.NoError(t, err) require.Equal(t, tc.wantCode, rw.Code) assertPersistentClients(t, clients, tc.wantClient) }) } } func TestClientsContainer_HandleUpdateClient(t *testing.T) { clients := newClientsContainer(t) ctx := testutil.ContextWithTimeout(t, testTimeout) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) err := clients.storage.Add(ctx, clientOne) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne}) clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) clientEmptyID := newPersistentClient("empty_client_id") clientEmptyID.ClientIDs = []string{""} testCases := []struct { name string clientName string modified *client.Persistent wantCode int wantClient []*client.Persistent }{{ name: "update_one", clientName: clientOne.Name, modified: clientModified, wantCode: http.StatusOK, wantClient: []*client.Persistent{clientModified}, }, { name: "empty_name", clientName: "", modified: clientOne, wantCode: http.StatusBadRequest, wantClient: []*client.Persistent{clientModified}, }, { name: "client_not_found", clientName: "client_not_found", modified: clientOne, wantCode: http.StatusBadRequest, wantClient: []*client.Persistent{clientModified}, }, { name: "empty_client_id", clientName: clientModified.Name, modified: clientEmptyID, wantCode: http.StatusBadRequest, wantClient: []*client.Persistent{clientModified}, }, { name: "no_ids", clientName: clientModified.Name, modified: newPersistentClient("no_ids"), wantCode: http.StatusBadRequest, wantClient: []*client.Persistent{clientModified}, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { uj := updateJSON{ Name: tc.clientName, Data: *clientToJSON(tc.modified), } var body []byte body, err = json.Marshal(uj) require.NoError(t, err) var r *http.Request r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) require.NoError(t, err) rw := httptest.NewRecorder() clients.handleUpdateClient(rw, r) require.NoError(t, err) require.Equal(t, tc.wantCode, rw.Code) assertPersistentClients(t, clients, tc.wantClient) }) } } func TestClientsContainer_HandleFindClient(t *testing.T) { clients := newClientsContainer(t) clients.clientChecker = &testBlockedClientChecker{ onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) { return false, "" }, } ctx := testutil.ContextWithTimeout(t, testTimeout) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) err := clients.storage.Add(ctx, clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) err = clients.storage.Add(ctx, clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) testCases := []struct { name string query url.Values wantCode int wantClient []*client.Persistent }{{ name: "single", query: url.Values{ "ip0": []string{testClientIP1}, }, wantCode: http.StatusOK, wantClient: []*client.Persistent{clientOne}, }, { name: "multiple", query: url.Values{ "ip0": []string{testClientIP1}, "ip1": []string{testClientIP2}, }, wantCode: http.StatusOK, wantClient: []*client.Persistent{clientOne, clientTwo}, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var r *http.Request r, err = http.NewRequest(http.MethodGet, "", nil) require.NoError(t, err) r.URL.RawQuery = tc.query.Encode() rw := httptest.NewRecorder() clients.handleFindClient(rw, r) require.NoError(t, err) require.Equal(t, tc.wantCode, rw.Code) var body []byte body, err = io.ReadAll(rw.Body) require.NoError(t, err) clientData := []map[string]*clientJSON{} err = json.Unmarshal(body, &clientData) require.NoError(t, err) assertPersistentClientsData(t, clients, clientData, tc.wantClient) }) } } func TestClientsContainer_HandleSearchClient(t *testing.T) { var ( runtimeCli = "runtime_client1" runtimeCliIP = "3.3.3.3" blockedCliIP = "4.4.4.4" nonExistentCliIP = "5.5.5.5" allowed = false dissallowed = true emptyRule = "" disallowedRule = "disallowed_rule" ) clients := newClientsContainer(t) clients.clientChecker = &testBlockedClientChecker{ onIsBlockedClient: func(ip netip.Addr, _ string) (ok bool, rule string) { if ip == netip.MustParseAddr(blockedCliIP) { return true, disallowedRule } return false, emptyRule }, } ctx := testutil.ContextWithTimeout(t, testTimeout) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) err := clients.storage.Add(ctx, clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) err = clients.storage.Add(ctx, clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) clients.UpdateAddress(ctx, netip.MustParseAddr(runtimeCliIP), runtimeCli, nil) testCases := []struct { name string query *searchQueryJSON wantPersistent []*client.Persistent wantRuntime *clientJSON }{{ name: "single", query: &searchQueryJSON{ Clients: []searchClientJSON{{ ID: testClientIP1, }}, }, wantPersistent: []*client.Persistent{clientOne}, }, { name: "multiple", query: &searchQueryJSON{ Clients: []searchClientJSON{{ ID: testClientIP1, }, { ID: testClientIP2, }}, }, wantPersistent: []*client.Persistent{clientOne, clientTwo}, }, { name: "runtime", query: &searchQueryJSON{ Clients: []searchClientJSON{{ ID: runtimeCliIP, }}, }, wantRuntime: &clientJSON{ Name: runtimeCli, IDs: []string{runtimeCliIP}, Disallowed: &allowed, DisallowedRule: &emptyRule, WHOIS: &whois.Info{}, }, }, { name: "blocked_access", query: &searchQueryJSON{ Clients: []searchClientJSON{{ ID: blockedCliIP, }}, }, wantRuntime: &clientJSON{ IDs: []string{blockedCliIP}, Disallowed: &dissallowed, DisallowedRule: &disallowedRule, WHOIS: &whois.Info{}, }, }, { name: "non_existing_client", query: &searchQueryJSON{ Clients: []searchClientJSON{{ ID: nonExistentCliIP, }}, }, wantRuntime: &clientJSON{ IDs: []string{nonExistentCliIP}, Disallowed: &allowed, DisallowedRule: &emptyRule, WHOIS: &whois.Info{}, }, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var body []byte body, err = json.Marshal(tc.query) require.NoError(t, err) var r *http.Request r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) require.NoError(t, err) rw := httptest.NewRecorder() clients.handleSearchClient(rw, r) require.NoError(t, err) require.Equal(t, http.StatusOK, rw.Code) body, err = io.ReadAll(rw.Body) require.NoError(t, err) clientData := []map[string]*clientJSON{} err = json.Unmarshal(body, &clientData) require.NoError(t, err) if tc.wantPersistent != nil { assertPersistentClientsData(t, clients, clientData, tc.wantPersistent) return } require.Len(t, clientData, 1) require.Len(t, clientData[0], 1) rc := clientData[0][tc.wantRuntime.IDs[0]] assert.Equal(t, tc.wantRuntime, rc) }) } }