// Package websvc contains the AdGuard Home HTTP API service.
//
// NOTE: Packages other than cmd must not import this package, as it imports
// most other packages.
//
// TODO(a.garipov): Add tests.
package websvc

import (
	"context"
	"crypto/tls"
	"fmt"
	"io/fs"
	"log/slog"
	"net/http"
	"net/netip"
	"runtime"
	"time"

	"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
	"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/netutil/httputil"
)

// ConfigManager is the configuration manager interface.  It gives access to
// the DNS and web services together with their configurations and accepts new
// configurations for them.
type ConfigManager interface {
	// DNS returns the DNS service along with its configuration.
	DNS() (svc agh.ServiceWithConfig[*dnssvc.Config])

	// UpdateDNS accepts c as the new DNS service configuration.
	UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error)

	// Web returns the web service along with its configuration.
	Web() (svc agh.ServiceWithConfig[*Config])

	// UpdateWeb accepts c as the new web service configuration.
	UpdateWeb(ctx context.Context, c *Config) (err error)
}

// Service is the AdGuard Home web service.  A nil *Service is a valid
// [agh.Service] that does nothing.
//
// All fields are set by New and are not modified afterwards, except for the
// pprof-related fields, which are set by setupPprof during New.
type Service struct {
	// logger is used for the service's own log messages and is passed to
	// every server the service creates.
	logger       *slog.Logger
	// confMgr is the configuration manager from the Config passed to New.
	confMgr      ConfigManager
	// frontend is the frontend filesystem from the Config passed to New;
	// presumably served by the routes set up in route — confirm.
	frontend     fs.FS
	// tls is the TLS configuration used for the servers on secure addresses.
	tls          *tls.Config
	// pprof is the pprof HTTP server.  It is nil when pprof is disabled.
	pprof        *server
	// start is the Start value from the Config passed to New; presumably the
	// time the program started — confirm.
	start        time.Time
	// overrideAddr, if not the zero value, is the only address the HTTP API
	// is served on, replacing all configured addresses.
	overrideAddr netip.AddrPort
	// servers are the HTTP API servers, plain and secure.
	servers      []*server
	// timeout is the timeout passed to every HTTP API server.
	timeout      time.Duration
	// pprofPort is the port of the pprof server.  It is only set when pprof
	// is enabled.
	pprofPort    uint16
	// forceHTTPS is the ForceHTTPS value from the Config passed to New.
	forceHTTPS   bool
}

// New returns a new properly initialized *Service.  If c is nil, svc is a nil
// *Service that does nothing.  The fields of c must not be modified after
// calling New.  err is currently always nil.
//
// If c.OverrideAddress is set, the HTTP API is served only on that address,
// without TLS.  Otherwise, a plain server is created for each address in
// c.Addresses and a TLS server for each address in c.SecureAddresses.
//
// TODO(a.garipov): Get rid of this special handling of nil or explain it
// better.
func New(c *Config) (svc *Service, err error) {
	if c == nil {
		return nil, nil
	}

	svc = &Service{
		logger:       c.Logger,
		confMgr:      c.ConfigManager,
		frontend:     c.Frontend,
		tls:          c.TLS,
		start:        c.Start,
		overrideAddr: c.OverrideAddress,
		timeout:      c.Timeout,
		forceHTTPS:   c.ForceHTTPS,
	}

	mux := http.NewServeMux()
	svc.route(mux)

	logger, timeout := svc.logger, svc.timeout
	if svc.overrideAddr == (netip.AddrPort{}) {
		for _, addr := range c.Addresses {
			svc.servers = append(svc.servers, newServer(logger, addr, nil, mux, timeout))
		}

		for _, addr := range c.SecureAddresses {
			svc.servers = append(svc.servers, newServer(logger, addr, svc.tls, mux, timeout))
		}
	} else {
		// The override address replaces every configured address.
		svc.servers = []*server{newServer(logger, svc.overrideAddr, nil, mux, timeout)}
	}

	svc.setupPprof(c.Pprof)

	return svc, nil
}

// setupPprof sets the pprof properties of svc according to c.  A nil c is
// treated the same as a disabled configuration.
//
// When pprof is disabled, the block and mutex profiles are turned off and
// svc.pprof is not set.  When it is enabled, both profiles record every event
// and a pprof HTTP server is prepared on 127.0.0.1 at c.Port.
func (svc *Service) setupPprof(c *PprofConfig) {
	if c == nil || !c.Enabled {
		// Set to zero explicitly in case pprof used to be enabled before a
		// reconfiguration took place.
		runtime.SetBlockProfileRate(0)
		runtime.SetMutexProfileFraction(0)

		return
	}

	// A rate and fraction of 1 make the runtime report every blocking and
	// every mutex contention event.
	runtime.SetBlockProfileRate(1)
	runtime.SetMutexProfileFraction(1)

	pprofMux := http.NewServeMux()
	httputil.RoutePprof(pprofMux)

	svc.pprofPort = c.Port
	addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)

	// pprofTimeout is much longer than a usual API timeout, presumably so that
	// long-running profile collection requests aren't cut off — confirm.
	const pprofTimeout = 10 * time.Minute

	svc.pprof = newServer(svc.logger, addr, nil, pprofMux, pprofTimeout)
}

// addrs returns the addresses on which this service serves the HTTP API, split
// into plain and secure ones.  If an override address is set, it is the only
// plain address returned.  addrs must not be called simultaneously with Start.
// If svc was initialized with ":0" addresses, addrs will not return the actual
// bound ports until Start is finished.
func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
	if override := svc.overrideAddr; override != (netip.AddrPort{}) {
		return []netip.AddrPort{override}, nil
	}

	for _, srv := range svc.servers {
		ap := netutil.NetAddrToAddrPort(srv.localAddr())

		switch {
		case ap == (netip.AddrPort{}):
			// Skip servers without a usable local address, such as those that
			// haven't started yet.
		case srv.tlsConf != nil:
			secureAddrs = append(secureAddrs, ap)
		default:
			addrs = append(addrs, ap)
		}
	}

	return addrs, secureAddrs
}

// type check: *Service must implement agh.ServiceWithConfig[*Config].
var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)

// Start implements the [agh.Service] interface for *Service.  svc may be nil.
// Every HTTP API server, as well as the pprof server if it is enabled, is
// launched in its own goroutine, and then Start blocks until all of them have
// started or ctx is canceled.  After Start exits, all HTTP servers have tried
// to start, possibly failing and writing error messages to the log.
//
// TODO(a.garipov): Use the context for cancelation as well.
func (svc *Service) Start(ctx context.Context) (err error) {
	if svc == nil {
		return nil
	}

	l := svc.logger
	l.InfoContext(ctx, "starting")
	defer l.InfoContext(ctx, "started")

	for i := range svc.servers {
		go svc.servers[i].serve(ctx, l)
	}

	if p := svc.pprof; p != nil {
		go p.serve(ctx, l)
	}

	return svc.wait(ctx)
}

// wait waits until either the context is canceled or all servers have started.
// It returns nil if the servers have started, even if ctx is already canceled
// by then, and ctx.Err() otherwise.
func (svc *Service) wait(ctx context.Context) (err error) {
	if svc.serversHaveStarted() {
		return nil
	}

	// Poll on a short interval instead of spinning with runtime.Gosched, which
	// keeps a CPU core fully busy for as long as any server hasn't started.
	const pollInterval = 1 * time.Millisecond

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for !svc.serversHaveStarted() {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			// Check again.
		}
	}

	return nil
}

// serversHaveStarted reports whether every HTTP API server and, if enabled,
// the pprof server have a local address.  It always returns false when there
// are no HTTP API servers.
func (svc *Service) serversHaveStarted() (started bool) {
	if len(svc.servers) == 0 {
		return false
	}

	for _, srv := range svc.servers {
		if srv.localAddr() == nil {
			return false
		}
	}

	return svc.pprof == nil || svc.pprof.localAddr() != nil
}

// Shutdown implements the [agh.Service] interface for *Service.  svc may be
// nil.  It shuts down every HTTP API server and then the pprof server, if any.
// The errors of all servers that failed to shut down are joined and annotated.
func (svc *Service) Shutdown(ctx context.Context) (err error) {
	if svc == nil {
		return nil
	}

	svc.logger.InfoContext(ctx, "shutting down")
	defer svc.logger.InfoContext(ctx, "shut down")

	defer func() { err = errors.Annotate(err, "shutting down: %w") }()

	var errs []error
	for _, srv := range svc.servers {
		shutdownErr := srv.shutdown(ctx)
		if shutdownErr != nil {
			// Don't wrap the error, because it's informative enough as is.
			// Append shutdownErr itself: err, the named result, is still nil
			// here, and appending it would silently drop the failure.
			errs = append(errs, shutdownErr)
		}
	}

	if svc.pprof != nil {
		shutdownErr := svc.pprof.shutdown(ctx)
		if shutdownErr != nil {
			errs = append(errs, fmt.Errorf("pprof: %w", shutdownErr))
		}
	}

	return errors.Join(errs...)
}