all: imp code

This commit is contained in:
Stanislav Chzhen 2024-09-12 14:15:49 +03:00
parent 9feda414b6
commit f63d0e80fb
4 changed files with 21 additions and 23 deletions

View file

@ -28,8 +28,8 @@ func (ri *runtimeIndex) add(rc *Runtime) {
ri.index[ip] = rc ri.index[ip] = rc
} }
// rangeF calls f for each runtime client in an undefined order. // rangeClients calls f for each runtime client in an undefined order.
func (ri *runtimeIndex) rangeF(f func(rc *Runtime) (cont bool)) { func (ri *runtimeIndex) rangeClients(f func(rc *Runtime) (cont bool)) {
for _, rc := range ri.index { for _, rc := range ri.index {
if !f(rc) { if !f(rc) {
return return

View file

@ -125,14 +125,6 @@ func NewStorage(conf *Config) (s *Storage, err error) {
done: make(chan struct{}), done: make(chan struct{}),
} }
// TODO(s.chzhen): Refactor it.
switch v := s.etcHosts.(type) {
case *aghnet.HostsContainer:
if v == nil {
s.etcHosts = nil
}
}
for i, p := range conf.InitialClients { for i, p := range conf.InitialClients {
err = s.Add(p) err = s.Add(p)
if err != nil { if err != nil {
@ -140,6 +132,12 @@ func NewStorage(conf *Config) (s *Storage, err error) {
} }
} }
if hc, ok := s.etcHosts.(*aghnet.HostsContainer); ok && hc == nil {
s.etcHosts = nil
}
s.ReloadARP()
return s, nil return s, nil
} }
@ -163,9 +161,6 @@ func (s *Storage) Shutdown(_ context.Context) (err error) {
func (s *Storage) periodicARPUpdate() { func (s *Storage) periodicARPUpdate() {
defer log.OnPanic("storage") defer log.OnPanic("storage")
// Initial ARP refresh.
s.ReloadARP()
t := time.NewTicker(s.arpClientsUpdatePeriod) t := time.NewTicker(s.arpClientsUpdatePeriod)
for { for {
@ -376,6 +371,9 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
// Find finds persistent client by string representation of the client ID, IP // Find finds persistent client by string representation of the client ID, IP
// address, or MAC. And returns its shallow copy. // address, or MAC. And returns its shallow copy.
//
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
// the parsed IP address, if any.
func (s *Storage) Find(id string) (p *Persistent, ok bool) { func (s *Storage) Find(id string) (p *Persistent, ok bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -540,5 +538,5 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.runtimeIndex.rangeF(f) s.runtimeIndex.rangeClients(f)
} }

View file

@ -44,22 +44,22 @@ type Interface interface {
Neighbors() (ns []arpdb.Neighbor) Neighbors() (ns []arpdb.Neighbor)
} }
// testARP is a mock implementation of the [arpdb.Interface]. // testARPDB is a mock implementation of the [arpdb.Interface].
type testARP struct { type testARPDB struct {
onRefresh func() (err error) onRefresh func() (err error)
onNeighbors func() (ns []arpdb.Neighbor) onNeighbors func() (ns []arpdb.Neighbor)
} }
// type check // type check
var _ arpdb.Interface = (*testARP)(nil) var _ arpdb.Interface = (*testARPDB)(nil)
// Refresh implements the [arpdb.Interface] interface for *testARP. // Refresh implements the [arpdb.Interface] interface for *testARP.
func (c *testARP) Refresh() (err error) { func (c *testARPDB) Refresh() (err error) {
return c.onRefresh() return c.onRefresh()
} }
// Neighbors implements the [arpdb.Interface] interface for *testARP. // Neighbors implements the [arpdb.Interface] interface for *testARP.
func (c *testARP) Neighbors() (ns []arpdb.Neighbor) { func (c *testARPDB) Neighbors() (ns []arpdb.Neighbor) {
return c.onNeighbors() return c.onNeighbors()
} }
@ -186,7 +186,7 @@ func TestStorage_Add_arp(t *testing.T) {
cliName2 = "client_two" cliName2 = "client_two"
) )
a := &testARP{ a := &testARPDB{
onRefresh: func() (err error) { return nil }, onRefresh: func() (err error) { return nil },
onNeighbors: func() (ns []arpdb.Neighbor) { onNeighbors: func() (ns []arpdb.Neighbor) {
mu.Lock() mu.Lock()
@ -327,7 +327,7 @@ func TestClientsDHCP(t *testing.T) {
prsCliIP = netip.MustParseAddr("4.3.2.1") prsCliIP = netip.MustParseAddr("4.3.2.1")
prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA") prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
prsCliName = "persitent.dhcp" prsCliName = "persistent.dhcp"
) )
ipToHost := map[netip.Addr]string{ ipToHost := map[netip.Addr]string{

View file

@ -460,12 +460,12 @@ func startDNSServer() error {
err := Context.clients.Start(context.TODO()) err := Context.clients.Start(context.TODO())
if err != nil { if err != nil {
return fmt.Errorf("couldn't start clients container: %w", err) return fmt.Errorf("starting clients container: %w", err)
} }
err = Context.dnsServer.Start() err = Context.dnsServer.Start()
if err != nil { if err != nil {
return fmt.Errorf("couldn't start forwarding DNS server: %w", err) return fmt.Errorf("starting dns server: %w", err)
} }
Context.filters.Start() Context.filters.Start()