diff --git a/module/CLI/config/dependencies.config.php b/module/CLI/config/dependencies.config.php index 294b7f83..d296d1aa 100644 --- a/module/CLI/config/dependencies.config.php +++ b/module/CLI/config/dependencies.config.php @@ -3,6 +3,8 @@ declare(strict_types=1); namespace Shlinkio\Shlink\CLI; +use GeoIp2\Database\Reader; +use Shlinkio\Shlink\CLI\Util\GeolocationDbUpdater; use Shlinkio\Shlink\Common\IpGeolocation\GeoLite2\DbUpdater; use Shlinkio\Shlink\Common\IpGeolocation\IpLocationResolverInterface; use Shlinkio\Shlink\Common\Service\PreviewGenerator; @@ -19,6 +21,8 @@ return [ 'factories' => [ Application::class => Factory\ApplicationFactory::class, + GeolocationDbUpdater::class => ConfigAbstractFactory::class, + Command\ShortUrl\GenerateShortUrlCommand::class => ConfigAbstractFactory::class, Command\ShortUrl\ResolveUrlCommand::class => ConfigAbstractFactory::class, Command\ShortUrl\ListShortUrlsCommand::class => ConfigAbstractFactory::class, @@ -44,6 +48,8 @@ return [ ], ConfigAbstractFactory::class => [ + GeolocationDbUpdater::class => [DbUpdater::class, Reader::class], + Command\ShortUrl\GenerateShortUrlCommand::class => [Service\UrlShortener::class, 'config.url_shortener.domain'], Command\ShortUrl\ResolveUrlCommand::class => [Service\UrlShortener::class], Command\ShortUrl\ListShortUrlsCommand::class => [Service\ShortUrlService::class, 'config.url_shortener.domain'], diff --git a/module/CLI/src/Util/GeolocationDbUpdater.php b/module/CLI/src/Util/GeolocationDbUpdater.php index c4fb5e7f..1133d5a2 100644 --- a/module/CLI/src/Util/GeolocationDbUpdater.php +++ b/module/CLI/src/Util/GeolocationDbUpdater.php @@ -26,16 +26,16 @@ class GeolocationDbUpdater implements GeolocationDbUpdaterInterface /** * @throws GeolocationDbUpdateFailedException */ - public function checkDbUpdate(callable $handleProgress = null): void + public function checkDbUpdate(callable $mustBeUpdated = null, callable $handleProgress = null): void { try { $meta = $this->geoLiteDbReader->metadata(); if ($this->buildIsOlderThanOneWeek($meta->__get('buildEpoch'))) { - $this->downloadNewDb(true, $handleProgress); + $this->downloadNewDb(true, $mustBeUpdated, $handleProgress); } } catch (InvalidArgumentException $e) { // This is the exception thrown by the reader when the database file does not exist - $this->downloadNewDb(false, $handleProgress); + $this->downloadNewDb(false, $mustBeUpdated, $handleProgress); } } @@ -49,8 +49,15 @@ class GeolocationDbUpdater implements GeolocationDbUpdaterInterface /** * @throws GeolocationDbUpdateFailedException */ - private function downloadNewDb(bool $olderDbExists, callable $handleProgress = null): void - { + private function downloadNewDb( + bool $olderDbExists, + callable $mustBeUpdated = null, + callable $handleProgress = null + ): void { + if ($mustBeUpdated !== null) { + $mustBeUpdated(); + } + try { $this->dbUpdater->downloadFreshCopy($handleProgress); } catch (RuntimeException $e) { diff --git a/module/CLI/src/Util/GeolocationDbUpdaterInterface.php b/module/CLI/src/Util/GeolocationDbUpdaterInterface.php index 9b3d70d0..1d5bcf48 100644 --- a/module/CLI/src/Util/GeolocationDbUpdaterInterface.php +++ b/module/CLI/src/Util/GeolocationDbUpdaterInterface.php @@ -10,5 +10,5 @@ interface GeolocationDbUpdaterInterface /** * @throws GeolocationDbUpdateFailedException */ - public function checkDbUpdate(callable $handleProgress = null): void; + public function checkDbUpdate(callable $mustBeUpdated = null, callable $handleProgress = null): void; } diff --git a/module/CLI/test/Util/GeolocationDbUpdaterTest.php b/module/CLI/test/Util/GeolocationDbUpdaterTest.php index 2e7a7b10..be6709fa 100644 --- a/module/CLI/test/Util/GeolocationDbUpdaterTest.php +++ b/module/CLI/test/Util/GeolocationDbUpdaterTest.php @@ -41,12 +41,15 @@ class GeolocationDbUpdaterTest extends TestCase /** @test */ public function exceptionIsThrownWhenOlderDbDoesNotExistAndDownloadFails(): void { + $mustBeUpdated = function () { + $this->assertTrue(true); + }; $getMeta = $this->geoLiteDbReader->metadata()->willThrow(InvalidArgumentException::class); $prev = new RuntimeException(''); $download = $this->dbUpdater->downloadFreshCopy(null)->willThrow($prev); try { - $this->geolocationDbUpdater->checkDbUpdate(); + $this->geolocationDbUpdater->checkDbUpdate($mustBeUpdated); $this->assertTrue(false); // If this is reached, the test will fail } catch (Throwable $e) { /** @var GeolocationDbUpdateFailedException $e */