Merge branch 'master' into 3745-cleanup-lists

This commit is contained in:
Eugene Burkov 2021-10-21 14:30:18 +03:00
commit 8388b3d5bc
23 changed files with 549 additions and 184 deletions

View file

@ -46,6 +46,10 @@ and this project adheres to
### Changed
- `$dnsrewrite` rules and other DNS rewrites will now be applied even when the
protection is disabled ([#1558]).
- DHCP gateway address, subnet mask, IP address range, and leases validations
([#3529]).
- The `systemd` service script will now create the `/var/log` directory when it
doesn't exist ([#3579]).
- Items in allowed clients, disallowed clients, and blocked hosts lists are now
@ -114,6 +118,7 @@ In this release, the schema version has changed from 10 to 12.
### Fixed
- Incorrect assignment of explicitly configured DHCP options ([#3744]).
- Occasional panic during shutdown ([#3655]).
- Addition of IPs into only one as opposed to all matching ipsets on Linux
([#3638]).
@ -152,6 +157,7 @@ In this release, the schema version has changed from 10 to 12.
- Go 1.15 support.
[#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381
[#1558]: https://github.com/AdguardTeam/AdGuardHome/issues/1558
[#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691
[#1898]: https://github.com/AdguardTeam/AdGuardHome/issues/1898
[#1992]: https://github.com/AdguardTeam/AdGuardHome/issues/1992
@ -195,6 +201,7 @@ In this release, the schema version has changed from 10 to 12.
[#3450]: https://github.com/AdguardTeam/AdGuardHome/issues/3450
[#3457]: https://github.com/AdguardTeam/AdGuardHome/issues/3457
[#3506]: https://github.com/AdguardTeam/AdGuardHome/issues/3506
[#3529]: https://github.com/AdguardTeam/AdGuardHome/issues/3529
[#3538]: https://github.com/AdguardTeam/AdGuardHome/issues/3538
[#3551]: https://github.com/AdguardTeam/AdGuardHome/issues/3551
[#3564]: https://github.com/AdguardTeam/AdGuardHome/issues/3564
@ -204,6 +211,7 @@ In this release, the schema version has changed from 10 to 12.
[#3607]: https://github.com/AdguardTeam/AdGuardHome/issues/3607
[#3638]: https://github.com/AdguardTeam/AdGuardHome/issues/3638
[#3655]: https://github.com/AdguardTeam/AdGuardHome/issues/3655
[#3744]: https://github.com/AdguardTeam/AdGuardHome/issues/3744

View file

@ -37,6 +37,9 @@
"dhcp_ipv6_settings": "DHCP IPv6 Settings",
"form_error_required": "Required field",
"form_error_ip4_format": "Invalid IPv4 format",
"form_error_ip4_range_start_format": "Invalid range start IPv4 format",
"form_error_ip4_range_end_format": "Invalid range end IPv4 format",
"form_error_ip4_gateway_format": "Invalid gateway IPv4 format",
"form_error_ip6_format": "Invalid IPv6 format",
"form_error_ip_format": "Invalid IP format",
"form_error_mac_format": "Invalid MAC format",
@ -45,7 +48,14 @@
"form_error_subnet": "Subnet \"{{cidr}}\" does not contain the IP address \"{{ip}}\"",
"form_error_positive": "Must be greater than 0",
"form_error_negative": "Must be equal to 0 or greater",
"range_end_error": "Must be greater than range start",
"out_of_range_error": "Must be out of range \"{{start}}\"-\"{{end}}\"",
"in_range_error": "Must be in range \"{{start}}\"-\"{{end}}\"",
"lower_range_start_error": "Must be lower than range start",
"lower_range_end_error": "Must be lower than range end",
"greater_range_start_error": "Must be greater than range start",
"greater_range_end_error": "Must be greater than range end",
"subnet_error": "Addresses must be in one subnet",
"gateway_or_subnet_invalid": "Subnet mask invalid",
"dhcp_form_gateway_input": "Gateway IP",
"dhcp_form_subnet_input": "Subnet mask",
"dhcp_form_range_title": "Range of IP addresses",

View file

@ -13,6 +13,9 @@ import {
validateIpv4,
validateRequiredValue,
validateIpv4RangeEnd,
validateGatewaySubnetMask,
validateIpForGatewaySubnetMask,
validateNotInRange,
} from '../../../helpers/validators';
const FormDHCPv4 = ({
@ -54,7 +57,11 @@ const FormDHCPv4 = ({
type="text"
className="form-control"
placeholder={t(ipv4placeholders.gateway_ip)}
validate={[validateIpv4, validateRequired]}
validate={[
validateIpv4,
validateRequired,
validateNotInRange,
]}
disabled={!isInterfaceIncludesIpv4}
/>
</div>
@ -66,7 +73,11 @@ const FormDHCPv4 = ({
type="text"
className="form-control"
placeholder={t(ipv4placeholders.subnet_mask)}
validate={[validateIpv4, validateRequired]}
validate={[
validateIpv4,
validateRequired,
validateGatewaySubnetMask,
]}
disabled={!isInterfaceIncludesIpv4}
/>
</div>
@ -84,7 +95,11 @@ const FormDHCPv4 = ({
type="text"
className="form-control"
placeholder={t(ipv4placeholders.range_start)}
validate={[validateIpv4]}
validate={[
validateIpv4,
validateGatewaySubnetMask,
validateIpForGatewaySubnetMask,
]}
disabled={!isInterfaceIncludesIpv4}
/>
</div>
@ -95,7 +110,12 @@ const FormDHCPv4 = ({
type="text"
className="form-control"
placeholder={t(ipv4placeholders.range_end)}
validate={[validateIpv4, validateIpv4RangeEnd]}
validate={[
validateIpv4,
validateIpv4RangeEnd,
validateGatewaySubnetMask,
validateIpForGatewaySubnetMask,
]}
disabled={!isInterfaceIncludesIpv4}
/>
</div>

View file

@ -10,6 +10,7 @@ import {
validateMac,
validateRequiredValue,
validateIpv4InCidr,
validateInRange,
} from '../../../../helpers/validators';
import { FORM_NAME } from '../../../../helpers/constants';
import { toggleLeaseModal } from '../../../../actions';
@ -53,7 +54,12 @@ const Form = ({
type="text"
className="form-control"
placeholder={t('form_enter_subnet_ip', { cidr })}
validate={[validateRequiredValue, validateIpv4, validateIpv4InCidr]}
validate={[
validateRequiredValue,
validateIpv4,
validateIpv4InCidr,
validateInRange,
]}
/>
</div>
<div className="form__group">

View file

@ -11,6 +11,8 @@ const Modal = ({
handleSubmit,
processingAdding,
cidr,
rangeStart,
rangeEnd,
}) => {
const dispatch = useDispatch();
@ -38,10 +40,14 @@ const Modal = ({
ip: '',
hostname: '',
cidr,
rangeStart,
rangeEnd,
}}
onSubmit={handleSubmit}
processingAdding={processingAdding}
cidr={cidr}
rangeStart={rangeStart}
rangeEnd={rangeEnd}
/>
</div>
</ReactModal>
@ -53,6 +59,8 @@ Modal.propTypes = {
handleSubmit: PropTypes.func.isRequired,
processingAdding: PropTypes.bool.isRequired,
cidr: PropTypes.string.isRequired,
rangeStart: PropTypes.string,
rangeEnd: PropTypes.string,
};
export default withTranslation()(Modal);

View file

@ -22,6 +22,8 @@ const StaticLeases = ({
processingDeleting,
staticLeases,
cidr,
rangeStart,
rangeEnd,
}) => {
const [t] = useTranslation();
const dispatch = useDispatch();
@ -100,6 +102,8 @@ const StaticLeases = ({
handleSubmit={handleSubmit}
processingAdding={processingAdding}
cidr={cidr}
rangeStart={rangeStart}
rangeEnd={rangeEnd}
/>
</>
);
@ -111,6 +115,8 @@ StaticLeases.propTypes = {
processingAdding: PropTypes.bool.isRequired,
processingDeleting: PropTypes.bool.isRequired,
cidr: PropTypes.string.isRequired,
rangeStart: PropTypes.string,
rangeEnd: PropTypes.string,
};
cellWrap.propTypes = {

View file

@ -275,6 +275,8 @@ const Dhcp = () => {
processingAdding={processingAdding}
processingDeleting={processingDeleting}
cidr={cidr}
rangeStart={dhcp?.values?.v4?.range_start}
rangeEnd={dhcp?.values?.v4?.range_end}
/>
<div className="btn-list mt-2">
<button

View file

@ -552,6 +552,20 @@ export const isIpInCidr = (ip, cidr) => {
}
};
/**
*
* @param {string} subnetMask
* @returns {IPv4 | null}
*/
export const parseSubnetMask = (subnetMask) => {
try {
return ipaddr.parse(subnetMask).prefixLengthFromSubnetMask();
} catch (e) {
console.error(e);
return null;
}
};
/**
*
* @param {string} subnetMask

View file

@ -1,4 +1,5 @@
import i18next from 'i18next';
import {
MAX_PORT,
R_CIDR,
@ -14,7 +15,7 @@ import {
R_DOMAIN,
} from './constants';
import { ip4ToInt, isValidAbsolutePath } from './form';
import { isIpInCidr } from './helpers';
import { isIpInCidr, parseSubnetMask } from './helpers';
// Validation functions
// https://redux-form.com/8.3.0/examples/fieldlevelvalidation/
@ -44,7 +45,7 @@ export const validateIpv4RangeEnd = (_, allValues) => {
const { range_end, range_start } = allValues.v4;
if (ip4ToInt(range_end) <= ip4ToInt(range_start)) {
return 'range_end_error';
return 'greater_range_start_error';
}
return undefined;
@ -61,6 +62,114 @@ export const validateIpv4 = (value) => {
return undefined;
};
/**
* @returns {undefined|string}
* @param _
* @param allValues
*/
export const validateNotInRange = (value, allValues) => {
const { range_start, range_end } = allValues.v4;
if (range_start && validateIpv4(range_start)) {
return 'form_error_ip4_range_start_format';
}
if (range_end && validateIpv4(range_end)) {
return 'form_error_ip4_range_end_format';
}
const isAboveMin = range_start && ip4ToInt(value) >= ip4ToInt(range_start);
const isBelowMax = range_end && ip4ToInt(value) <= ip4ToInt(range_end);
if (isAboveMin && isBelowMax) {
return i18next.t('out_of_range_error', {
start: range_start,
end: range_end,
});
}
if (!range_end && isAboveMin) {
return 'lower_range_start_error';
}
if (!range_start && isBelowMax) {
return 'greater_range_end_error';
}
return undefined;
};
/**
* @returns {undefined|string}
* @param _
* @param allValues
*/
export const validateInRange = (value, allValues) => {
const { rangeStart, rangeEnd } = allValues;
if (rangeStart && validateIpv4(rangeStart)) {
return 'form_error_ip4_range_start_format';
}
if (rangeEnd && validateIpv4(rangeEnd)) {
return 'form_error_ip4_range_end_format';
}
const isBelowMin = rangeStart && ip4ToInt(value) < ip4ToInt(rangeStart);
const isAboveMax = rangeEnd && ip4ToInt(value) > ip4ToInt(rangeEnd);
if (isAboveMax || isBelowMin) {
return i18next.t('in_range_error', {
start: rangeStart,
end: rangeEnd,
});
}
return undefined;
};
/**
* @returns {undefined|string}
* @param _
* @param allValues
*/
export const validateGatewaySubnetMask = (_, allValues) => {
if (!allValues || !allValues.v4 || !allValues.v4.subnet_mask || !allValues.v4.gateway_ip) {
return 'gateway_or_subnet_invalid';
}
const { subnet_mask, gateway_ip } = allValues.v4;
if (validateIpv4(gateway_ip)) {
return 'form_error_ip4_gateway_format';
}
return parseSubnetMask(subnet_mask) ? undefined : 'gateway_or_subnet_invalid';
};
/**
* @returns {undefined|string}
* @param value
* @param allValues
*/
export const validateIpForGatewaySubnetMask = (value, allValues) => {
if (!allValues || !allValues.v4 || !value) {
return undefined;
}
const {
gateway_ip, subnet_mask,
} = allValues.v4;
const subnetPrefix = parseSubnetMask(subnet_mask);
if (!isIpInCidr(value, `${gateway_ip}/${subnetPrefix}`)) {
return 'subnet_error';
}
return undefined;
};
/**
* @param value {string}
* @returns {undefined|string}

2
go.mod
View file

@ -4,7 +4,7 @@ go 1.16
require (
github.com/AdguardTeam/dnsproxy v0.39.8
github.com/AdguardTeam/golibs v0.10.0
github.com/AdguardTeam/golibs v0.10.2
github.com/AdguardTeam/urlfilter v0.14.6
github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.2

4
go.sum
View file

@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.39.8/go.mod h1:eDpJKAdkHORRwAedjuERv+7SWlcz4c
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.9.2/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY=
github.com/AdguardTeam/golibs v0.10.0 h1:A7MXRfZ+ItpOyS9tWKtqrLj3vZtE9FJFC+dOVY/LcWs=
github.com/AdguardTeam/golibs v0.10.0/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.10.2 h1:TAwnS4Y49sSUa4UX1yz/MWNGbIlXHqafrWr9MxdIh9A=
github.com/AdguardTeam/golibs v0.10.2/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo=
github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U=

View file

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -138,6 +139,49 @@ func TestNormalizeLeases(t *testing.T) {
assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr)
}
func TestV4Server_badRange(t *testing.T) {
testCases := []struct {
name string
gatewayIP net.IP
subnetMask net.IP
wantErrMsg string
}{{
name: "gateway_in_range",
gatewayIP: net.IP{192, 168, 10, 120},
subnetMask: net.IP{255, 255, 255, 0},
wantErrMsg: "dhcpv4: gateway ip 192.168.10.120 in the ip range: " +
"192.168.10.20-192.168.10.200",
}, {
name: "outside_range_start",
gatewayIP: net.IP{192, 168, 10, 1},
subnetMask: net.IP{255, 255, 255, 240},
wantErrMsg: "dhcpv4: range start 192.168.10.20 is outside network " +
"192.168.10.1/28",
}, {
name: "outside_range_end",
gatewayIP: net.IP{192, 168, 10, 1},
subnetMask: net.IP{255, 255, 255, 224},
wantErrMsg: "dhcpv4: range end 192.168.10.200 is outside network " +
"192.168.10.1/27",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 20},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: tc.gatewayIP,
SubnetMask: tc.subnetMask,
notify: testNotify,
}
_, err := v4Create(conf)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
// cloneUDPAddr returns a deep copy of a.
func cloneUDPAddr(a *net.UDPAddr) (clone *net.UDPAddr) {
return &net.UDPAddr{

View file

@ -293,6 +293,8 @@ func (s *v4Server) addLease(l *Lease) (err error) {
offset, inOffset := r.offset(l.IP)
if l.IsStatic() {
// TODO(a.garipov, d.seregin): Subnet can be nil when dhcp server is
// disabled.
if sn := s.conf.subnet; !sn.Contains(l.IP) {
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
}
@ -900,9 +902,10 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
resp.UpdateOption(dhcpv4.OptGeneric(code, configured.Get(code)))
}
}
// Update the value of Domain Name Server option separately from others
// since its value is set after server's creating.
if requested.Has(dhcpv4.OptionDomainNameServer) {
// Update the value of Domain Name Server option separately from others if
// not assigned yet since its value is set after server's creating.
if requested.Has(dhcpv4.OptionDomainNameServer) &&
!resp.Options.Has(dhcpv4.OptionDomainNameServer) {
resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
}
@ -1124,6 +1127,29 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
return s, fmt.Errorf("dhcpv4: %w", err)
}
if s.conf.ipRange.contains(routerIP) {
return s, fmt.Errorf("dhcpv4: gateway ip %v in the ip range: %v-%v",
routerIP,
conf.RangeStart,
conf.RangeEnd,
)
}
if !s.conf.subnet.Contains(conf.RangeStart) {
return s, fmt.Errorf("dhcpv4: range start %v is outside network %v",
conf.RangeStart,
s.conf.subnet,
)
}
if !s.conf.subnet.Contains(conf.RangeEnd) {
return s, fmt.Errorf("dhcpv4: range end %v is outside network %v",
conf.RangeEnd,
s.conf.subnet,
)
}
// TODO(a.garipov, d.seregin): Check that every lease is inside the IPRange.
s.leasedOffsets = newBitSet()
if conf.LeaseDuration == 0 {

View file

@ -5,8 +5,10 @@ package dhcpd
import (
"net"
"strings"
"testing"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/raw"
"github.com/stretchr/testify/assert"
@ -16,17 +18,34 @@ import (
func notify4(flags uint32) {
}
func TestV4_AddRemove_static(t *testing.T) {
s, err := v4Create(V4ServerConf{
// defaultV4ServerConf returns the default configuration for *v4Server to use in
// tests.
func defaultV4ServerConf() (conf V4ServerConf) {
return V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
})
}
}
// defaultSrv prepares the default DHCPServer to use in tests. The underlying
// type of s is *v4Server.
func defaultSrv(t *testing.T) (s DHCPServer) {
t.Helper()
var err error
s, err = v4Create(defaultV4ServerConf())
require.NoError(t, err)
return s
}
func TestV4_AddRemove_static(t *testing.T) {
s := defaultSrv(t)
ls := s.GetLeases(LeasesStatic)
assert.Empty(t, ls)
@ -37,7 +56,7 @@ func TestV4_AddRemove_static(t *testing.T) {
IP: net.IP{192, 168, 10, 150},
}
err = s.AddStaticLease(l)
err := s.AddStaticLease(l)
require.NoError(t, err)
err = s.AddStaticLease(l)
@ -65,15 +84,7 @@ func TestV4_AddRemove_static(t *testing.T) {
}
func TestV4_AddReplace(t *testing.T) {
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
})
require.NoError(t, err)
sIface := defaultSrv(t)
s, ok := sIface.(*v4Server)
require.True(t, ok)
@ -89,7 +100,7 @@ func TestV4_AddReplace(t *testing.T) {
}}
for i := range dynLeases {
err = s.addLease(&dynLeases[i])
err := s.addLease(&dynLeases[i])
require.NoError(t, err)
}
@ -104,7 +115,7 @@ func TestV4_AddReplace(t *testing.T) {
}}
for _, l := range stLeases {
err = s.AddStaticLease(l)
err := s.AddStaticLease(l)
require.NoError(t, err)
}
@ -118,17 +129,80 @@ func TestV4_AddReplace(t *testing.T) {
}
}
func TestV4StaticLease_Get(t *testing.T) {
var err error
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
func TestV4Server_Process_optionsPriority(t *testing.T) {
defaultIP := net.IP{192, 168, 1, 1}
knownIP := net.IP{1, 2, 3, 4}
// prepareSrv creates a *v4Server and sets the opt6IPs in the initial
// configuration of the server as the value for DHCP option 6.
prepareSrv := func(t *testing.T, opt6IPs []net.IP) (s *v4Server) {
t.Helper()
conf := defaultV4ServerConf()
if len(opt6IPs) > 0 {
b := &strings.Builder{}
stringutil.WriteToBuilder(b, "6 ips ", opt6IPs[0].String())
for _, ip := range opt6IPs[1:] {
stringutil.WriteToBuilder(b, ",", ip.String())
}
conf.Options = []string{b.String()}
}
ss, err := v4Create(conf)
require.NoError(t, err)
var ok bool
s, ok = ss.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{defaultIP}
return s
}
// checkResp creates a discovery message with DHCP option 6 requested amd
// asserts the response to contain wantIPs in this option.
checkResp := func(t *testing.T, s *v4Server, wantIPs []net.IP) {
t.Helper()
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
req, err := dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
dhcpv4.OptionDomainNameServer,
))
require.NoError(t, err)
var resp *dhcpv4.DHCPv4
resp, err = dhcpv4.NewReplyFromRequest(req)
require.NoError(t, err)
res := s.process(req, resp)
require.Equal(t, 1, res)
o := resp.GetOneOption(dhcpv4.OptionDomainNameServer)
require.NotEmpty(t, o)
wantData := []byte{}
for _, ip := range wantIPs {
wantData = append(wantData, ip...)
}
assert.Equal(t, o, wantData)
}
t.Run("default", func(t *testing.T) {
s := prepareSrv(t, nil)
checkResp(t, s, []net.IP{defaultIP})
})
require.NoError(t, err)
t.Run("explicitly_configured", func(t *testing.T) {
s := prepareSrv(t, []net.IP{knownIP, knownIP})
checkResp(t, s, []net.IP{knownIP, knownIP})
})
}
func TestV4StaticLease_Get(t *testing.T) {
sIface := defaultSrv(t)
s, ok := sIface.(*v4Server)
require.True(t, ok)
@ -140,7 +214,7 @@ func TestV4StaticLease_Get(t *testing.T) {
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
}
err = s.AddStaticLease(l)
err := s.AddStaticLease(l)
require.NoError(t, err)
var req, resp *dhcpv4.DHCPv4
@ -208,19 +282,14 @@ func TestV4StaticLease_Get(t *testing.T) {
}
func TestV4DynamicLease_Get(t *testing.T) {
conf := defaultV4ServerConf()
conf.Options = []string{
"81 hex 303132",
"82 ip 1.2.3.4",
}
var err error
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
Options: []string{
"81 hex 303132",
"82 ip 1.2.3.4",
},
})
sIface, err := v4Create(conf)
require.NoError(t, err)
s, ok := sIface.(*v4Server)

View file

@ -90,7 +90,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
s.processRestrictLocal,
s.processInternalIPAddrs,
s.processClientID,
processFilteringBeforeRequest,
s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream,
processDNSSECAfterResponse,
@ -468,19 +468,18 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
}
// Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultCodeSuccess // response is already set - nothing to do
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
if ctx.proxyCtx.Res != nil {
// Go on since the response is already set.
return resultCodeSuccess
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
if !ctx.protectionEnabled {
ctx.protectionEnabled = s.conf.ProtectionEnabled
if s.dnsFilter == nil {
return resultCodeSuccess
}
@ -489,8 +488,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
}
var err error
ctx.result, err = s.filterDNSRequest(ctx)
if err != nil {
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
ctx.err = err
return resultCodeError
@ -608,48 +606,50 @@ func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
res := ctx.result
var err error
switch res.Reason {
case filtering.Rewritten,
switch res := ctx.result; res.Reason {
case filtering.NotFilteredAllowList:
// Go on.
case
filtering.Rewritten,
filtering.RewrittenRule:
if len(ctx.origQuestion.Name) == 0 {
// origQuestion is set in case we get only CNAME without IP from rewrites table
// origQuestion is set in case we get only CNAME without IP from
// rewrites table.
break
}
d.Req.Question[0] = ctx.origQuestion
d.Res.Question[0] = ctx.origQuestion
if len(d.Res.Answer) != 0 {
answer := []dns.RR{}
answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName))
answer = append(answer, d.Res.Answer...)
d.Req.Question[0], d.Res.Question[0] = ctx.origQuestion, ctx.origQuestion
if len(d.Res.Answer) > 0 {
answer := append([]dns.RR{s.genAnswerCNAME(d.Req, res.CanonName)}, d.Res.Answer...)
d.Res.Answer = answer
}
case filtering.NotFilteredAllowList:
// nothing
default:
if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for
!ctx.responseFromUpstream { // only check response if it's from an upstream server
// Check the response only if the it's from an upstream. Don't check
// the response if the protection is disabled since dnsrewrite rules
// aren't applied to it anyway.
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
break
}
origResp2 := d.Res
ctx.result, err = s.filterDNSResponse(ctx)
origResp := d.Res
result, err := s.filterDNSResponse(ctx)
if err != nil {
ctx.err = err
return resultCodeError
}
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response
} else {
ctx.result = &filtering.Result{}
if result != nil {
ctx.result = result
ctx.origResp = origResp
}
}
if ctx.result == nil {
ctx.result = &filtering.Result{}
}
return resultCodeSuccess
}

View file

@ -909,6 +909,7 @@ func TestRewrite(t *testing.T) {
}},
}
f := filtering.New(c, nil)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
@ -945,45 +946,56 @@ func TestRewrite(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("test.com.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
subTestFunc := func(t *testing.T) {
req := createTestMessageWithType("test.com.", dns.TypeA)
reply, eerr := dns.Exchange(req, addr.String())
require.NoError(t, eerr)
require.Len(t, reply.Answer, 1)
require.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A)
require.True(t, ok)
a, ok := reply.Answer[0].(*dns.A)
require.True(t, ok)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
assert.Empty(t, reply.Answer)
assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
require.Len(t, reply.Answer, 2)
require.Len(t, reply.Answer, 2)
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
// The original question is restored.
require.Len(t, reply.Question, 1)
// The original question is restored.
require.Len(t, reply.Question, 1)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
require.Len(t, reply.Answer, 2)
require.Len(t, reply.Answer, 2)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
}
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
}
func publicKey(priv interface{}) interface{} {
@ -1092,9 +1104,10 @@ func TestPTRResponseFromHosts(t *testing.T) {
require.ErrorIs(t, hc.Close(), closeCalled)
})
c := filtering.Config{
flt := filtering.New(&filtering.Config{
EtcHosts: hc,
}
}, nil)
flt.SetEnabled(true)
var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector()
@ -1104,7 +1117,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: filtering.New(&c, nil),
DNSFilter: flt,
SubnetDetector: snd,
})
require.NoError(t, err)
@ -1112,32 +1125,41 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err = s.Prepare(nil)
require.NoError(t, err)
err = s.Start()
require.NoError(t, err)
t.Cleanup(func() {
s.Close()
})
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
subTestFunc := func(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
resp, eerr := dns.Exchange(req, addr.String())
require.NoError(t, eerr)
require.Lenf(t, resp.Answer, 1, "%#v", resp)
require.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok)
assert.Equal(t, "host.", ptr.Ptr)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok)
assert.Equal(t, "host.", ptr.Ptr)
}
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
}
func TestNewServer(t *testing.T) {

View file

@ -52,6 +52,7 @@ func (s *Server) beforeRequestHandler(
// the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig()
setts.ProtectionEnabled = ctx.protectionEnabled
if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr)
s.conf.FilterHandler(ip, ctx.clientID, &setts)
@ -65,32 +66,31 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.S
func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
d := ctx.proxyCtx
req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, req.Question[0].Qtype, ctx.setts)
if err != nil {
// Return immediately if there's an error
return nil, fmt.Errorf("filtering failed to check host %q: %w", host, err)
} else if res.IsFiltered {
log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text)
q := req.Question[0]
host := strings.TrimSuffix(q.Name, ".")
res, err := s.dnsFilter.CheckHost(host, q.Qtype, ctx.setts)
switch {
case err != nil:
return nil, fmt.Errorf("failed to check host %q: %w", host, err)
case res.IsFiltered:
log.Tracef("host %q is filtered, reason %q, rule: %q", host, res.Reason, res.Rules[0].Text)
d.Res = s.genDNSFilterMessage(d, &res)
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
res.CanonName != "" &&
len(res.IPList) == 0 {
// Resolve the new canonical name, not the original host
// name. The original question is readded in
// processFilteringAfterResponse.
ctx.origQuestion = req.Question[0]
len(res.IPList) == 0:
// Resolve the new canonical name, not the original host name. The
// original question is readded in processFilteringAfterResponse.
ctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName)
} else if res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0 {
case res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0:
resp := s.makeResponse(req)
hdr := dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
for _, h := range res.ReverseHosts {
hdr := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
ptr := &dns.PTR{
Hdr: hdr,
Ptr: h,
@ -100,7 +100,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
}
d.Res = resp
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts) {
case res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts):
resp := s.makeResponse(req)
name := host
@ -110,11 +110,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
}
for _, ip := range res.IPList {
if req.Question[0].Qtype == dns.TypeA {
switch q.Qtype {
case dns.TypeA:
a := s.genAnswerA(req, ip.To4())
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
} else if req.Question[0].Qtype == dns.TypeAAAA {
case dns.TypeAAAA:
a := s.genAnswerAAAA(req, ip)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
@ -122,9 +123,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
}
d.Res = resp
} else if res.Reason == filtering.RewrittenRule {
err = s.filterDNSRewrite(req, res, d)
if err != nil {
case res.Reason == filtering.RewrittenRule:
if err = s.filterDNSRewrite(req, res, d); err != nil {
return nil, err
}
}
@ -179,6 +179,7 @@ func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
continue
}
host = strings.TrimSuffix(host, ".")
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
if err != nil {
return nil, err

View file

@ -38,6 +38,7 @@ type Settings struct {
ServicesRules []ServiceEntry
ProtectionEnabled bool
FilteringEnabled bool
SafeSearchEnabled bool
SafeBrowsingEnabled bool
@ -221,12 +222,13 @@ func (r Reason) String() string {
}
// In returns true if reasons include r.
func (r Reason) In(reasons ...Reason) bool {
func (r Reason) In(reasons ...Reason) (ok bool) {
for _, reason := range reasons {
if r == reason {
return true
}
}
return false
}
@ -245,7 +247,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
defer d.confLock.RUnlock()
return Settings{
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1,
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
SafeSearchEnabled: d.Config.SafeSearchEnabled,
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
ParentalEnabled: d.Config.ParentalEnabled,
@ -421,14 +423,16 @@ func (d *DNSFilter) CheckHost(
// Sometimes clients try to resolve ".", which is a request to get root
// servers.
if host == "" {
return Result{Reason: NotFilteredNotFound}, nil
return Result{}, nil
}
host = strings.ToLower(host)
res = d.processRewrites(host, qtype)
if res.Reason == Rewritten {
return res, nil
if setts.FilteringEnabled {
res = d.processRewrites(host, qtype)
if res.Reason == Rewritten {
return res, nil
}
}
for _, hc := range d.hostCheckers {
@ -448,7 +452,7 @@ func (d *DNSFilter) CheckHost(
// matchSysHosts tries to match the host against the operating system's hosts
// database.
func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) {
if d.EtcHosts == nil {
if !setts.FilteringEnabled || d.EtcHosts == nil {
return Result{}, nil
}
@ -468,10 +472,8 @@ func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (r
var ips []net.IP
var revHosts []string
for _, nr := range dnsr {
dr := nr.DNSRewrite
if dr == nil {
if nr.DNSRewrite == nil {
continue
}
@ -553,6 +555,10 @@ func matchBlockedServicesRules(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled {
return Result{}, nil
}
svcs := setts.ServicesRules
if len(svcs) == 0 {
return Result{}, nil
@ -784,7 +790,7 @@ func (d *DNSFilter) matchHost(
// TODO(e.burkov): Inspect if the above is true.
defer d.engineLock.RUnlock()
if d.filteringEngineAllow != nil {
if setts.ProtectionEnabled && d.filteringEngineAllow != nil {
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
if ok {
return d.matchHostProcessAllowList(host, dnsres)
@ -810,6 +816,11 @@ func (d *DNSFilter) matchHost(
return Result{}, nil
}
if !setts.ProtectionEnabled {
// Don't check non-dnsrewrite filtering results.
return Result{}, nil
}
res = d.matchHostProcessDNSResult(qtype, dnsres)
for _, r := range res.Rules {
log.Debug(

View file

@ -21,7 +21,9 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
var setts Settings
var setts = Settings{
ProtectionEnabled: true,
}
// Helpers.
@ -39,9 +41,9 @@ func purgeCaches() {
func newForTest(c *Config, filters []Filter) *DNSFilter {
setts = Settings{
FilteringEnabled: true,
ProtectionEnabled: true,
FilteringEnabled: true,
}
setts.FilteringEnabled = true
if c != nil {
c.SafeBrowsingCacheSize = 10000
c.ParentalCacheSize = 10000
@ -797,7 +799,11 @@ func TestClientSettings(t *testing.T) {
makeTester := func(tc testCase, before bool) func(t *testing.T) {
return func(t *testing.T) {
r, _ := d.CheckHost(tc.host, dns.TypeA, &setts)
t.Helper()
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
require.NoError(t, err)
if before {
assert.True(t, r.IsFiltered)
assert.Equal(t, tc.wantReason, r.Reason)
@ -808,7 +814,7 @@ func TestClientSettings(t *testing.T) {
}
// Check behaviour without any per-client settings, then apply per-client
// settings and check behaviour once again.
// settings and check behavior once again.
for _, tc := range testCases {
t.Run(tc.name, makeTester(tc, tc.before))
}

View file

@ -306,7 +306,7 @@ func (d *DNSFilter) checkSafeBrowsing(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.SafeBrowsingEnabled {
if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
return Result{}, nil
}
@ -339,7 +339,7 @@ func (d *DNSFilter) checkParental(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ParentalEnabled {
if !setts.ProtectionEnabled || !setts.ParentalEnabled {
return Result{}, nil
}

View file

@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d.SetParentalUpstream(ups)
setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
@ -135,35 +136,36 @@ func TestSBPC(t *testing.T) {
const hostname = "example.org"
setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
testCases := []struct {
testCache cache.Cache
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
name string
block bool
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
testCache cache.Cache
}{{
testCache: gctx.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_no_block",
block: false,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
testCache: gctx.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_block",
block: true,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
testCache: gctx.parentalCache,
testFunc: d.checkParental,
name: "pc_no_block",
block: false,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}, {
testCache: gctx.parentalCache,
testFunc: d.checkParental,
name: "pc_block",
block: true,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}}
for _, tc := range testCases {

View file

@ -74,7 +74,7 @@ func (d *DNSFilter) checkSafeSearch(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.SafeSearchEnabled {
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
return Result{}, nil
}

View file

@ -404,6 +404,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true
setts.ProtectionEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil {