diff --git a/src/Core/Utilities/ServiceContainer.cs b/src/Core/Utilities/ServiceContainer.cs index 19450173b..84dda33bd 100644 --- a/src/Core/Utilities/ServiceContainer.cs +++ b/src/Core/Utilities/ServiceContainer.cs @@ -1,5 +1,7 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; +using System.Text; using System.Threading.Tasks; using Bit.Core.Abstractions; using Bit.Core.Services; @@ -8,7 +10,7 @@ namespace Bit.Core.Utilities { public static class ServiceContainer { - public static Dictionary RegisteredServices { get; set; } = new Dictionary(); + public static ConcurrentDictionary RegisteredServices { get; set; } = new ConcurrentDictionary(); public static bool Inited { get; set; } public static void Init(string customUserAgent = null, string clearCipherCacheKey = null, @@ -109,18 +111,17 @@ namespace Bit.Core.Utilities public static void Register(string serviceName, T obj) { - if (RegisteredServices.ContainsKey(serviceName)) + if (!RegisteredServices.TryAdd(serviceName, obj)) { throw new Exception($"Service {serviceName} has already been registered."); } - RegisteredServices.Add(serviceName, obj); } public static T Resolve(string serviceName, bool dontThrow = false) { - if (RegisteredServices.ContainsKey(serviceName)) + if (RegisteredServices.TryGetValue(serviceName, out var service)) { - return (T)RegisteredServices[serviceName]; + return (T)service; } if (dontThrow) { @@ -129,6 +130,59 @@ namespace Bit.Core.Utilities throw new Exception($"Service {serviceName} is not registered."); } + public static void Register(T obj) + where T : class + { + Register(typeof(T), obj); + } + + public static void Register(Type type, object obj) + { + var serviceName = GetServiceRegistrationName(type); + if (!RegisteredServices.TryAdd(serviceName, obj)) + { + throw new Exception($"Service {serviceName} has already been registered."); + } + } + + public static T Resolve() + where T : class + { + return (T)Resolve(typeof(T)); + } + + public static object Resolve(Type type) + { + var serviceName = GetServiceRegistrationName(type); + if (RegisteredServices.TryGetValue(serviceName, out var service)) + { + return service; + } + throw new Exception($"Service {serviceName} is not registered."); + } + + public static bool TryResolve(out T service) + where T : class + { + try + { + var toReturn = TryResolve(typeof(T), out var serviceObj); + service = (T)serviceObj; + return toReturn; + } + catch (Exception) + { + service = null; + return false; + } + } + + public static bool TryResolve(Type type, out object service) + { + var serviceName = GetServiceRegistrationName(type); + return RegisteredServices.TryGetValue(serviceName, out service); + } + public static void Reset() { foreach (var service in RegisteredServices) @@ -140,7 +194,33 @@ namespace Bit.Core.Utilities } Inited = false; RegisteredServices.Clear(); - RegisteredServices = new Dictionary(); + RegisteredServices = new ConcurrentDictionary(); + } + + /// + /// Gets the service registration name + /// + /// Type of the service + /// + /// In order to work with already register/resolve we need to maintain the naming convention + /// of camelCase without the first "I" on the services interfaces + /// e.g. "ITokenService" -> "tokenService" + /// + static string GetServiceRegistrationName(Type type) + { + var typeName = type.Name; + var sb = new StringBuilder(); + + var indexToLowerCase = 0; + if (typeName[0] == 'I' && char.IsUpper(typeName[1])) + { + // if it's an interface then we ignore the first char + // and lower case the 2nd one (index 1) + indexToLowerCase = 1; + } + sb.Append(char.ToLower(typeName[indexToLowerCase])); + sb.Append(typeName.Substring(++indexToLowerCase)); + return sb.ToString(); } } }