AdGuardHome/internal/dnsforward/recursiondetector.go
Eugene Burkov 78d47d8884 Pull request: 3185 fix recursion vol.2
Merge in DNS/adguard-home from fix-recursion to master

Closes #3185.

Squashed commit of the following:

commit c78650b762163f39b2eb4b10f76f1845913134b2
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon May 31 13:12:46 2021 +0300

    all: fix changelog

commit e43017a4b6d245f14e1a3bdf735a9c9f03501156
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Sun May 30 22:39:05 2021 +0300

    dnsforward: reduce recursion ttl

commit 79208a82bdad8280c439669413aef8dc7df0a85c
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Sun May 30 22:29:27 2021 +0300

    dnsforward: only check recursion for private rdns

commit 1b8075b086f33e58e273dfcf4168a6ba0473ebae
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Sun May 30 22:18:11 2021 +0300

    dnsforward: rm recursion detecting from upstream
2021-05-31 13:25:40 +03:00

115 lines
2.8 KiB
Go

package dnsforward
import (
"bytes"
"encoding/binary"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// Sizes of the fixed-width unsigned integer types in bytes, named to improve
// the readability of buffer size and offset computations.
//
// TODO(e.burkov): Remove once there is a more convenient way to define
// those. See https://github.com/golang/go/issues/29982.
const (
// uint16sz is the size of a uint16 in bytes.
uint16sz = 2
// uint64sz is the size of a uint64 in bytes.
uint64sz = 8
)
// recursionDetector detects recursion in DNS forwarding.
type recursionDetector struct {
// recentRequests maps the signature of a recorded message to its
// expiration time, stored as Unix nanoseconds in a big-endian uint64.
recentRequests cache.Cache
// ttl is the duration added to the current time to compute the expiration
// time of a newly recorded message.
ttl time.Duration
}
// check reports whether msg has a record in the recent requests cache that
// has not expired yet.  A msg with an empty question section is never
// reported.
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
	if len(msg.Question) == 0 {
		return false
	}

	rawExpire := rd.recentRequests.Get(msgToSignature(msg))
	if rawExpire == nil {
		return false
	}

	// The cached value is the expiration instant in Unix nanoseconds, encoded
	// as a big-endian uint64.
	expireNano := int64(binary.BigEndian.Uint64(rawExpire))

	return time.Unix(0, expireNano).After(time.Now())
}
// add records msg in the recent requests cache together with its expiration
// time, which is the current time plus rd.ttl.  A msg with an empty question
// section is ignored.
func (rd *recursionDetector) add(msg dns.Msg) {
	// Compute the deadline first so that it is based on the moment add was
	// entered.
	deadline := time.Now().Add(rd.ttl)

	if len(msg.Question) == 0 {
		return
	}

	// Store the deadline as Unix nanoseconds in a big-endian uint64.
	var encoded [uint64sz]byte
	binary.BigEndian.PutUint64(encoded[:], uint64(deadline.UnixNano()))

	rd.recentRequests.Set(msgToSignature(msg), encoded[:])
}
// clear empties the recent requests cache by calling its Clear method, so
// previously recorded messages are no longer tracked.
func (rd *recursionDetector) clear() {
rd.recentRequests.Clear()
}
// newRecursionDetector returns an initialized *recursionDetector.  ttl is the
// lifetime used for recorded messages, and suspectsNum is passed as MaxCount
// to the underlying cache, which is created with EnableLRU set.
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
	conf := cache.Config{
		EnableLRU: true,
		MaxCount:  suspectsNum,
	}

	rd = &recursionDetector{ttl: ttl}
	rd.recentRequests = cache.New(conf)

	return rd
}
// msgToSignature converts msg into its signature represented in bytes.
//
// The signature has a fixed length and consists of the message ID, the type
// of the first question, and the name of the first question, in that order.
// The name is zero-padded or truncated to aghnet.MaxDomainNameLen bytes.
//
// msg must have at least one question, otherwise the function panics.
func msgToSignature(msg dns.Msg) (sig []byte) {
	// Offsets of the signature parts within the resulting slice.
	const (
		idOff   = 0
		typeOff = idOff + uint16sz
		nameOff = typeOff + uint16sz
	)

	question := msg.Question[0]

	sig = make([]byte, nameOff+aghnet.MaxDomainNameLen)

	// The binary.BigEndian byte order is used everywhere except when the
	// real machine's endianness is needed.
	binary.BigEndian.PutUint16(sig[idOff:], msg.Id)
	binary.BigEndian.PutUint16(sig[typeOff:], question.Qtype)
	copy(sig[nameOff:], question.Name)

	return sig
}
// msgToSignatureSlow converts msg into its signature represented in bytes in
// a less efficient way, by serializing a struct with binary.Write.
//
// Unlike msgToSignature, the serialized layout puts the name first, followed
// by the message ID and the question type.  A serialization error is only
// logged at the debug level, and whatever has been written is returned.
//
// msg must have at least one question, otherwise the function panics.
//
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
	type msgSignature struct {
		name  [aghnet.MaxDomainNameLen]byte
		id    uint16
		qtype uint16
	}

	question := msg.Question[0]

	var signature msgSignature
	signature.id = msg.Id
	signature.qtype = question.Qtype
	copy(signature.name[:], question.Name)

	buf := new(bytes.Buffer)
	err := binary.Write(buf, binary.BigEndian, signature)
	if err != nil {
		log.Debug("writing message signature: %s", err)
	}

	return buf.Bytes()
}