diff --git a/internal/client/client.go b/internal/client/client.go index 780415e6..9e76f01e 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -121,6 +121,7 @@ func (r *Runtime) Info() (cs Source, host string) { // SetInfo sets a host as a client information from the cs. func (r *Runtime) SetInfo(cs Source, hosts []string) { + // TODO(s.chzhen): Use contract where hosts must contain non-empty host. if len(hosts) == 1 && hosts[0] == "" { hosts = []string{} } diff --git a/internal/client/storage.go b/internal/client/storage.go index 66601812..23bb6ca8 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -237,21 +237,21 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { // UpdateRuntime updates the stored runtime client with information from rc. If // no such client exists, saves the copy of rc in storage. rc must not be nil. -func (s *Storage) UpdateRuntime(rc *Runtime) { +func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) { s.mu.Lock() defer s.mu.Unlock() - s.updateRuntimeLocked(rc) + return s.updateRuntimeLocked(rc) } // updateRuntimeLocked updates the stored runtime client with information from // rc. rc must not be nil. Storage.mu is expected to be locked. -func (s *Storage) updateRuntimeLocked(rc *Runtime) { +func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) { stored := s.runtimeIndex.Client(rc.ip) if stored == nil { s.runtimeIndex.Add(rc.Clone()) - return + return true } if rc.whois != nil { @@ -273,11 +273,13 @@ func (s *Storage) updateRuntimeLocked(rc *Runtime) { if rc.hostsFile != nil { stored.hostsFile = slices.Clone(rc.hostsFile) } + + return false } // BatchUpdateBySource updates the stored runtime clients information from the -// specified source. -func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) { +// specified source and returns the number of added and removed clients. +func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) { s.mu.Lock() defer s.mu.Unlock() @@ -286,14 +288,19 @@ func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) { } for _, rc := range rcs { - s.updateRuntimeLocked(rc) + if s.updateRuntimeLocked(rc) { + added++ + } } for ip, rc := range s.runtimeIndex.index { if rc.isEmpty() { delete(s.runtimeIndex.index, ip) + removed++ } } + + return added, removed } // SizeRuntime returns the number of the runtime clients. diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 4a534580..5ac02747 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -659,7 +659,7 @@ func TestStorage_UpdateRuntime(t *testing.T) { added := client.NewRuntime(ip) added.SetInfo(client.SourceARP, []string{addedARP}) - s.UpdateRuntime(added) + require.True(t, s.UpdateRuntime(added)) require.Equal(t, 1, s.SizeRuntime()) got := s.ClientRuntime(ip) @@ -672,7 +672,7 @@ func TestStorage_UpdateRuntime(t *testing.T) { added := client.NewRuntime(ip2) added.SetInfo(client.SourceARP, []string{addedSecondARP}) - s.UpdateRuntime(added) + require.True(t, s.UpdateRuntime(added)) require.Equal(t, 2, s.SizeRuntime()) got := s.ClientRuntime(ip2) @@ -682,7 +682,7 @@ func TestStorage_UpdateRuntime(t *testing.T) { }) t.Run("update_first_client", func(t *testing.T) { - s.UpdateRuntime(updated) + require.False(t, s.UpdateRuntime(updated)) got := s.ClientRuntime(ip) require.Equal(t, 2, s.SizeRuntime()) @@ -744,7 +744,9 @@ func TestStorage_BatchUpdateBySource(t *testing.T) { }) t.Run("populate_storage_with_first_clients", func(t *testing.T) { - s.BatchUpdateBySource(defSrc, firstClients) + added, removed := s.BatchUpdateBySource(defSrc, firstClients) + require.Equal(t, len(firstClients), added) + require.Equal(t, 0, removed) require.Equal(t, len(firstClients), s.SizeRuntime()) rc := s.ClientRuntime(cliFirstIP1) @@ -754,7 +756,9 @@ func TestStorage_BatchUpdateBySource(t *testing.T) { }) t.Run("update_storage", func(t *testing.T) { - s.BatchUpdateBySource(defSrc, updatedClients) + added, removed := s.BatchUpdateBySource(defSrc, updatedClients) + require.Equal(t, len(updatedClients), added) + require.Equal(t, len(firstClients), removed) require.Equal(t, len(updatedClients), s.SizeRuntime()) rc := s.ClientRuntime(cliUpdatedIP3) @@ -767,7 +771,9 @@ func TestStorage_BatchUpdateBySource(t *testing.T) { }) t.Run("remove_all", func(t *testing.T) { - s.BatchUpdateBySource(defSrc, []*client.Runtime{}) + added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{}) + require.Equal(t, 0, added) + require.Equal(t, len(updatedClients), removed) require.Equal(t, 0, s.SizeRuntime()) }) } diff --git a/internal/home/clients.go b/internal/home/clients.go index e64f5d15..819564bc 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -642,7 +642,6 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag clients.lock.Lock() defer clients.lock.Unlock() - added := 0 var rcs []*client.Runtime hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { // Only the first name of the first record is considered a canonical @@ -652,15 +651,13 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag rc := client.NewRuntime(addr) rc.SetInfo(client.SourceHostsFile, []string{names[0]}) - added++ rcs = append(rcs, rc) return true }) - clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs) - - log.Debug("clients: added %d client aliases from system hosts file", added) + added, removed := clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs) + log.Debug("clients: added %d, removed %d client aliases from system hosts file", added, removed) } // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a @@ -684,19 +681,16 @@ func (clients *clientsContainer) addFromSystemARP() { clients.lock.Lock() defer clients.lock.Unlock() - added := 0 var rcs []*client.Runtime for _, n := range ns { rc := client.NewRuntime(n.IP) rc.SetInfo(client.SourceARP, []string{n.Name}) - added++ rcs = append(rcs, rc) } - clients.storage.BatchUpdateBySource(client.SourceARP, rcs) - - log.Debug("clients: added %d client aliases from arp neighborhood", added) + added, removed := clients.storage.BatchUpdateBySource(client.SourceARP, rcs) + log.Debug("clients: added %d, removed %d client aliases from arp neighborhood", added, removed) } // close gracefully closes all the client-specific upstream configurations of