Pull request: dnsforward: fix clientid check

Closes #3437.

Squashed commit of the following:

commit fc4207a6ee1a09ade9db5eb5c8b58f88011db2f9
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Aug 12 18:22:31 2021 +0300

    dnsforward: imp code, docs

commit 0c608e0b7ca0b68b7810fc1ca798fb7d80d6ac24
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Aug 12 18:01:22 2021 +0300

    dnsforward: fix clientid check
This commit is contained in:
Ainar Garipov 2021-08-12 18:35:30 +03:00
parent 506b459842
commit e3ad46876f
3 changed files with 32 additions and 3 deletions

View file

@ -64,6 +64,7 @@ and this project adheres to
### Fixed
- Client ID checking ([#3437]).
- Discovering other DHCP servers on `darwin` and `freebsd` ([#3417]).
- Switching listening address to unspecified one when bound to a single
specified IPv4 address on Darwin (macOS) ([#2807]).
@ -122,6 +123,7 @@ and this project adheres to
[#3351]: https://github.com/AdguardTeam/AdGuardHome/issues/3351
[#3372]: https://github.com/AdguardTeam/AdGuardHome/issues/3372
[#3417]: https://github.com/AdguardTeam/AdGuardHome/issues/3417
[#3437]: https://github.com/AdguardTeam/AdGuardHome/issues/3437

View file

@ -24,21 +24,39 @@ func ValidateClientID(clientID string) (err error) {
return nil
}
// hasLabelSuffix returns true if s ends with suffix preceded by a dot. It's
// a helper function to prevent unnecessary allocations in code like:
//
// if strings.HasSuffix(s, "." + suffix) { /* … */ }
//
// s must be longer than suffix.
func hasLabelSuffix(s, suffix string) (ok bool) {
return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.'
}
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the
// client. When strict is true, and client and host server name don't match,
// clientIDFromClientServerName will return an error.
func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (clientID string, err error) {
func clientIDFromClientServerName(
hostSrvName string,
cliSrvName string,
strict bool,
) (clientID string, err error) {
if hostSrvName == cliSrvName {
return "", nil
}
if !strings.HasSuffix(cliSrvName, hostSrvName) {
if !hasLabelSuffix(cliSrvName, hostSrvName) {
if !strict {
return "", nil
}
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
return "", fmt.Errorf(
"client server name %q doesn't match host server name %q",
cliSrvName,
hostSrvName,
)
}
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]

View file

@ -134,6 +134,15 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantClientID: "cli",
wantErrMsg: "",
strictSNI: true,
}, {
name: "tls_client_id_issue3437",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.myexample.com",
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.myexample.com" ` +
`doesn't match host server name "example.com"`,
strictSNI: true,
}}
for _, tc := range testCases {