Extract logic to match IP address against list of groups

This commit is contained in:
Alejandro Celaya 2024-07-06 10:12:05 +02:00
parent b6b2530cb6
commit 8d90661d0a
4 changed files with 103 additions and 43 deletions

View file

@ -0,0 +1,15 @@
<?php
declare(strict_types=1);
namespace Shlinkio\Shlink\Core\Exception;
use function sprintf;
class InvalidIpFormatException extends RuntimeException implements ExceptionInterface
{
public static function fromInvalidIp(string $ipAddress): self
{
return new self(sprintf('Provided IP %s does not have the right format. Expected X.X.X.X', $ipAddress));
}
}

View file

@ -0,0 +1,65 @@
<?php
declare(strict_types=1);
namespace Shlinkio\Shlink\Core\Util;
use IPLib\Address\IPv4;
use IPLib\Factory;
use IPLib\Range\RangeInterface;
use Shlinkio\Shlink\Core\Exception\InvalidIpFormatException;
use function array_keys;
use function array_map;
use function explode;
use function implode;
use function Shlinkio\Shlink\Core\ArrayUtils\some;
class IpAddressUtils
{
/**
* Checks if an IP address matches any of provided groups.
* Every group can be a static IP address (100.200.80.40), a CIDR block (192.168.10.0/24) or a wildcard pattern
* (11.22.*.*).
*
* Matching will happen as follows:
* * Static IP address -> strict equality with provided IP address.
* * CIDR block -> provided IP address is part of that block.
* * Wildcard -> static parts match the corresponding ones in provided IP address.
*
* @param string[] $groups
* @throws InvalidIpFormatException
*/
public static function ipAddressMatchesGroups(string $ipAddress, array $groups): bool
{
$ip = IPv4::parseString($ipAddress);
if ($ip === null) {
throw InvalidIpFormatException::fromInvalidIp($ipAddress);
}
$ipAddressParts = explode('.', $ipAddress);
return some($groups, function (string $value) use ($ip, $ipAddressParts): bool {
$range = str_contains($value, '*')
? self::parseValueWithWildcards($value, $ipAddressParts)
: Factory::parseRangeString($value);
return $range !== null && $ip->matches($range);
});
}
private static function parseValueWithWildcards(string $value, array $ipAddressParts): ?RangeInterface
{
$octets = explode('.', $value);
$keys = array_keys($octets);
// Replace wildcard parts with the corresponding ones from the remote address
return Factory::parseRangeString(
implode('.', array_map(
fn (string $part, int $index) => $part === '*' ? $ipAddressParts[$index] : $part,
$octets,
$keys,
)),
);
}
}

View file

@ -5,30 +5,20 @@ declare(strict_types=1);
namespace Shlinkio\Shlink\Core\Visit; namespace Shlinkio\Shlink\Core\Visit;
use Fig\Http\Message\RequestMethodInterface; use Fig\Http\Message\RequestMethodInterface;
use IPLib\Address\IPv4;
use IPLib\Factory;
use IPLib\Range\RangeInterface;
use Mezzio\Router\Middleware\ImplicitHeadMiddleware; use Mezzio\Router\Middleware\ImplicitHeadMiddleware;
use Psr\Http\Message\ServerRequestInterface; use Psr\Http\Message\ServerRequestInterface;
use Shlinkio\Shlink\Common\Middleware\IpAddressMiddlewareFactory; use Shlinkio\Shlink\Common\Middleware\IpAddressMiddlewareFactory;
use Shlinkio\Shlink\Core\ErrorHandler\Model\NotFoundType; use Shlinkio\Shlink\Core\ErrorHandler\Model\NotFoundType;
use Shlinkio\Shlink\Core\Exception\InvalidIpFormatException;
use Shlinkio\Shlink\Core\Options\TrackingOptions; use Shlinkio\Shlink\Core\Options\TrackingOptions;
use Shlinkio\Shlink\Core\ShortUrl\Entity\ShortUrl; use Shlinkio\Shlink\Core\ShortUrl\Entity\ShortUrl;
use Shlinkio\Shlink\Core\Util\IpAddressUtils;
use Shlinkio\Shlink\Core\Visit\Model\Visitor; use Shlinkio\Shlink\Core\Visit\Model\Visitor;
use function array_keys; readonly class RequestTracker implements RequestTrackerInterface, RequestMethodInterface
use function array_map; {
use function explode; public function __construct(private VisitsTrackerInterface $visitsTracker, private TrackingOptions $trackingOptions)
use function implode;
use function Shlinkio\Shlink\Core\ArrayUtils\some;
use function str_contains;
class RequestTracker implements RequestTrackerInterface, RequestMethodInterface
{ {
public function __construct(
private readonly VisitsTrackerInterface $visitsTracker,
private readonly TrackingOptions $trackingOptions,
) {
} }
public function trackIfApplicable(ShortUrl $shortUrl, ServerRequestInterface $request): void public function trackIfApplicable(ShortUrl $shortUrl, ServerRequestInterface $request): void
@ -78,35 +68,10 @@ class RequestTracker implements RequestTrackerInterface, RequestMethodInterface
return false; return false;
} }
$ip = IPv4::parseString($remoteAddr); try {
if ($ip === null) { return IpAddressUtils::ipAddressMatchesGroups($remoteAddr, $this->trackingOptions->disableTrackingFrom);
} catch (InvalidIpFormatException) {
return false; return false;
} }
$remoteAddrParts = explode('.', $remoteAddr);
$disableTrackingFrom = $this->trackingOptions->disableTrackingFrom;
return some($disableTrackingFrom, function (string $value) use ($ip, $remoteAddrParts): bool {
$range = str_contains($value, '*')
? $this->parseValueWithWildcards($value, $remoteAddrParts)
: Factory::parseRangeString($value);
return $range !== null && $ip->matches($range);
});
}
private function parseValueWithWildcards(string $value, array $remoteAddrParts): ?RangeInterface
{
$octets = explode('.', $value);
$keys = array_keys($octets);
// Replace wildcard parts with the corresponding ones from the remote address
return Factory::parseRangeString(
implode('.', array_map(
fn (string $part, int $index) => $part === '*' ? $remoteAddrParts[$index] : $part,
$octets,
$keys,
)),
);
} }
} }

View file

@ -92,6 +92,21 @@ class RequestTrackerTest extends TestCase
$this->requestTracker->trackIfApplicable($shortUrl, $this->request); $this->requestTracker->trackIfApplicable($shortUrl, $this->request);
} }
#[Test]
public function trackingHappensOverShortUrlsWhenRemoteAddressIsInvalid(): void
{
$shortUrl = ShortUrl::withLongUrl(self::LONG_URL);
$this->visitsTracker->expects($this->once())->method('track')->with(
$shortUrl,
$this->isInstanceOf(Visitor::class),
);
$this->requestTracker->trackIfApplicable($shortUrl, ServerRequestFactory::fromGlobals()->withAttribute(
IpAddressMiddlewareFactory::REQUEST_ATTR,
'invalid',
));
}
#[Test] #[Test]
public function baseUrlErrorIsTracked(): void public function baseUrlErrorIsTracked(): void
{ {