package dnsforward import ( "crypto/tls" "fmt" "path" "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" "github.com/lucas-clemente/quic-go" ) // ValidateClientID returns an error if clientID is not a valid client ID. func ValidateClientID(clientID string) (err error) { err = aghnet.ValidateDomainNameLabel(clientID) if err != nil { // Replace the domain name label wrapper with our own. return fmt.Errorf("invalid client id %q: %w", clientID, errors.Unwrap(err)) } return nil } // clientIDFromClientServerName extracts and validates a client ID. hostSrvName // is the server name of the host. cliSrvName is the server name as sent by the // client. When strict is true, and client and host server name don't match, // clientIDFromClientServerName will return an error. func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (clientID string, err error) { if hostSrvName == cliSrvName { return "", nil } if !strings.HasSuffix(cliSrvName, hostSrvName) { if !strict { return "", nil } return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName) } clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] err = ValidateClientID(clientID) if err != nil { // Don't wrap the error, because it's informative enough as is. return "", err } return clientID, nil } // processClientIDHTTPS extracts the client's ID from the path of the // client's DNS-over-HTTPS request. func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { pctx := ctx.proxyCtx r := pctx.HTTPRequest if r == nil { ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto) return resultCodeError } origPath := r.URL.Path parts := strings.Split(path.Clean(origPath), "/") if parts[0] == "" { parts = parts[1:] } if len(parts) == 0 || parts[0] != "dns-query" { ctx.err = fmt.Errorf("client id check: invalid path %q", origPath) return resultCodeError } clientID := "" switch len(parts) { case 1: // Just /dns-query, no client ID. return resultCodeSuccess case 2: clientID = parts[1] default: ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath) return resultCodeError } err := ValidateClientID(clientID) if err != nil { ctx.err = fmt.Errorf("client id check: %w", err) return resultCodeError } ctx.clientID = clientID return resultCodeSuccess } // tlsConn is a narrow interface for *tls.Conn to simplify testing. type tlsConn interface { ConnectionState() (cs tls.ConnectionState) } // quicSession is a narrow interface for quic.Session to simplify testing. type quicSession interface { ConnectionState() (cs quic.ConnectionState) } // processClientID extracts the client's ID from the server name of the client's // DoT or DoQ request or the path of the client's DoH. func processClientID(dctx *dnsContext) (rc resultCode) { pctx := dctx.proxyCtx proto := pctx.Proto if proto == proxy.ProtoHTTPS { return processClientIDHTTPS(dctx) } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { return resultCodeSuccess } srvConf := dctx.srv.conf hostSrvName := srvConf.TLSConfig.ServerName if hostSrvName == "" { return resultCodeSuccess } cliSrvName := "" if proto == proxy.ProtoTLS { conn := pctx.Conn tc, ok := conn.(tlsConn) if !ok { dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn) return resultCodeError } cliSrvName = tc.ConnectionState().ServerName } else if proto == proxy.ProtoQUIC { qs, ok := pctx.QUICSession.(quicSession) if !ok { dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession) return resultCodeError } cliSrvName = qs.ConnectionState().TLS.ServerName } clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck) if err != nil { dctx.err = fmt.Errorf("client id check: %w", err) return resultCodeError } dctx.clientID = clientID return resultCodeSuccess }