AdGuardHome/internal/dnsforward/dialcontext.go

64 lines
1.4 KiB
Go
Raw Normal View History

2023-07-26 13:18:44 +03:00
package dnsforward
import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
2023-07-26 13:18:44 +03:00
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
// addr should be a valid host:port address, where host could be a domain name
// or an IP address.
2023-07-26 13:18:44 +03:00
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Debug("dnsforward: dialing %q for network %q", addr, network)
host, portStr, err := net.SplitHostPort(addr)
2023-07-26 13:18:44 +03:00
if err != nil {
return nil, err
}
dialer := &net.Dialer{
// TODO(a.garipov): Consider making configurable.
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil {
return dialer.DialContext(ctx, network, addr)
}
port, err := strconv.Atoi(portStr)
2023-07-26 13:18:44 +03:00
if err != nil {
return nil, fmt.Errorf("invalid port %s: %w", portStr, err)
2023-07-26 13:18:44 +03:00
}
ips, err := s.Resolve(ctx, network, host)
if err != nil {
return nil, fmt.Errorf("resolving %q: %w", host, err)
} else if len(ips) == 0 {
2023-07-26 13:18:44 +03:00
return nil, fmt.Errorf("no addresses for host %q", host)
}
log.Debug("dnsforward: resolved %q: %v", host, ips)
2023-07-26 13:18:44 +03:00
var dialErrs []error
for _, ip := range ips {
addrPort := netip.AddrPortFrom(ip, uint16(port))
conn, err = dialer.DialContext(ctx, network, addrPort.String())
2023-07-26 13:18:44 +03:00
if err != nil {
dialErrs = append(dialErrs, err)
continue
}
2023-09-07 17:13:48 +03:00
return conn, nil
2023-07-26 13:18:44 +03:00
}
2023-09-07 17:13:48 +03:00
return nil, errors.Join(dialErrs...)
2023-07-26 13:18:44 +03:00
}