2023-07-18 17:02:07 +03:00
|
|
|
package dnsforward
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"net"
|
2023-12-01 11:12:03 +03:00
|
|
|
"net/netip"
|
|
|
|
"strconv"
|
2023-07-18 17:02:07 +03:00
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/AdguardTeam/golibs/errors"
|
|
|
|
"github.com/AdguardTeam/golibs/log"
|
|
|
|
)
|
|
|
|
|
2023-07-20 14:26:35 +03:00
|
|
|
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
|
2023-12-01 11:12:03 +03:00
|
|
|
// addr should be a valid host:port address, where host could be a domain name
|
|
|
|
// or an IP address.
|
2023-07-18 17:02:07 +03:00
|
|
|
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
|
|
|
log.Debug("dnsforward: dialing %q for network %q", addr, network)
|
|
|
|
|
2023-12-01 11:12:03 +03:00
|
|
|
host, portStr, err := net.SplitHostPort(addr)
|
2023-07-18 17:02:07 +03:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
dialer := &net.Dialer{
|
|
|
|
// TODO(a.garipov): Consider making configurable.
|
|
|
|
Timeout: time.Minute * 5,
|
|
|
|
}
|
|
|
|
|
|
|
|
if net.ParseIP(host) != nil {
|
|
|
|
return dialer.DialContext(ctx, network, addr)
|
|
|
|
}
|
|
|
|
|
2023-12-01 11:12:03 +03:00
|
|
|
port, err := strconv.Atoi(portStr)
|
2023-07-18 17:02:07 +03:00
|
|
|
if err != nil {
|
2023-12-01 11:12:03 +03:00
|
|
|
return nil, fmt.Errorf("invalid port %s: %w", portStr, err)
|
2023-07-18 17:02:07 +03:00
|
|
|
}
|
|
|
|
|
2023-12-01 11:12:03 +03:00
|
|
|
ips, err := s.Resolve(ctx, network, host)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("resolving %q: %w", host, err)
|
|
|
|
} else if len(ips) == 0 {
|
2023-07-18 17:02:07 +03:00
|
|
|
return nil, fmt.Errorf("no addresses for host %q", host)
|
|
|
|
}
|
|
|
|
|
2023-12-01 11:12:03 +03:00
|
|
|
log.Debug("dnsforward: resolved %q: %v", host, ips)
|
|
|
|
|
2023-07-18 17:02:07 +03:00
|
|
|
var dialErrs []error
|
2023-12-01 11:12:03 +03:00
|
|
|
for _, ip := range ips {
|
|
|
|
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
|
|
|
conn, err = dialer.DialContext(ctx, network, addrPort.String())
|
2023-07-18 17:02:07 +03:00
|
|
|
if err != nil {
|
|
|
|
dialErrs = append(dialErrs, err)
|
|
|
|
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2023-08-15 15:09:08 +03:00
|
|
|
return conn, nil
|
2023-07-18 17:02:07 +03:00
|
|
|
}
|
|
|
|
|
2023-08-15 15:09:08 +03:00
|
|
|
return nil, errors.Join(dialErrs...)
|
2023-07-18 17:02:07 +03:00
|
|
|
}
|