package dnsforward import ( "encoding/binary" "log" "math" "strings" "sync" "time" "github.com/miekg/dns" ) type item struct { m *dns.Msg when time.Time } type cache struct { items map[string]item sync.RWMutex } func (c *cache) Get(request *dns.Msg) (*dns.Msg, bool) { if request == nil { return nil, false } ok, key := key(request) if !ok { log.Printf("Get(): key returned !ok") return nil, false } c.RLock() item, ok := c.items[key] c.RUnlock() if !ok { return nil, false } // get item's TTL ttl := findLowestTTL(item.m) // zero TTL? delete and don't serve it if ttl == 0 { c.Lock() delete(c.items, key) c.Unlock() return nil, false } // too much time has passed? delete and don't serve it if time.Since(item.when) >= time.Duration(ttl)*time.Second { c.Lock() delete(c.items, key) c.Unlock() return nil, false } response := item.fromItem(request) return response, true } func (c *cache) Set(m *dns.Msg) { if m == nil { return // no-op } if !isRequestCacheable(m) { return } if !isResponseCacheable(m) { return } ok, key := key(m) if !ok { return } i := toItem(m) c.Lock() if c.items == nil { c.items = map[string]item{} } c.items[key] = i c.Unlock() } // check only request fields func isRequestCacheable(m *dns.Msg) bool { // truncated messages aren't valid if m.Truncated { log.Printf("Refusing to cache truncated message") return false } // if has wrong number of questions, also don't cache if len(m.Question) != 1 { log.Printf("Refusing to cache message with wrong number of questions") return false } // only OK or NXdomain replies are cached switch m.Rcode { case dns.RcodeSuccess: case dns.RcodeNameError: // that's an NXDomain case dns.RcodeServerFailure: return false // quietly refuse, don't log default: log.Printf("%s: Refusing to cache message with rcode: %s", m.Question[0].Name, dns.RcodeToString[m.Rcode]) return false } return true } func isResponseCacheable(m *dns.Msg) bool { ttl := findLowestTTL(m) if ttl == 0 { return false } return true } func findLowestTTL(m *dns.Msg) uint32 { var ttl uint32 = math.MaxUint32 found := false if m.Answer != nil { for _, r := range m.Answer { if r.Header().Ttl < ttl { ttl = r.Header().Ttl found = true } } } if m.Ns != nil { for _, r := range m.Ns { if r.Header().Ttl < ttl { ttl = r.Header().Ttl found = true } } } if m.Extra != nil { for _, r := range m.Extra { if r.Header().Rrtype == dns.TypeOPT { continue // OPT records use TTL for other purposes } if r.Header().Ttl < ttl { ttl = r.Header().Ttl found = true } } } if found == false { return 0 } return ttl } // key is binary little endian in sequence: // uint16(qtype) then uint16(qclass) then name func key(m *dns.Msg) (bool, string) { if len(m.Question) != 1 { log.Printf("got msg with len(m.Question) != 1: %d", len(m.Question)) return false, "" } bb := strings.Builder{} b := make([]byte, 2) binary.LittleEndian.PutUint16(b, m.Question[0].Qtype) bb.Write(b) binary.LittleEndian.PutUint16(b, m.Question[0].Qclass) bb.Write(b) name := strings.ToLower(m.Question[0].Name) bb.WriteString(name) return true, bb.String() } func toItem(m *dns.Msg) item { return item{ m: m, when: time.Now(), } } func (i *item) fromItem(request *dns.Msg) *dns.Msg { response := &dns.Msg{} response.SetReply(request) response.Authoritative = false response.AuthenticatedData = i.m.AuthenticatedData response.RecursionAvailable = i.m.RecursionAvailable response.Rcode = i.m.Rcode ttl := findLowestTTL(i.m) timeleft := math.Round(float64(ttl) - time.Since(i.when).Seconds()) var newttl uint32 if timeleft > 0 { newttl = uint32(timeleft) } for _, r := range i.m.Answer { answer := dns.Copy(r) answer.Header().Ttl = newttl response.Answer = append(response.Answer, answer) } for _, r := range i.m.Ns { ns := dns.Copy(r) ns.Header().Ttl = newttl response.Ns = append(response.Ns, ns) } for _, r := range i.m.Extra { // don't return OPT records as these are hop-by-hop if r.Header().Rrtype == dns.TypeOPT { continue } extra := dns.Copy(r) extra.Header().Ttl = newttl response.Extra = append(response.Extra, extra) } return response }