diff --git a/LightlessSync/PlayerData/Factories/PlayerDataFactory.cs b/LightlessSync/PlayerData/Factories/PlayerDataFactory.cs index 51dba8f..2a45fbf 100644 --- a/LightlessSync/PlayerData/Factories/PlayerDataFactory.cs +++ b/LightlessSync/PlayerData/Factories/PlayerDataFactory.cs @@ -9,10 +9,10 @@ using LightlessSync.PlayerData.Data; using LightlessSync.PlayerData.Handlers; using LightlessSync.Services; using LightlessSync.Services.Mediator; +using LightlessSync.Utils; using Microsoft.Extensions.Logging; using System.Collections.Concurrent; using System.Diagnostics; -using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; namespace LightlessSync.PlayerData.Factories; @@ -34,7 +34,7 @@ public class PlayerDataFactory private const int _maxTransientResolvedEntries = 1000; // Character build caches - private readonly ConcurrentDictionary> _characterBuildInflight = new(); + private readonly TaskRegistry _characterBuildInflight = new(); private readonly ConcurrentDictionary _characterBuildCache = new(); // Time out thresholds @@ -170,10 +170,10 @@ public class PlayerDataFactory { var key = obj.Address; - if (_characterBuildCache.TryGetValue(key, out var cached) && IsCacheFresh(cached) && !_characterBuildInflight.ContainsKey(key)) + if (_characterBuildCache.TryGetValue(key, out CacheEntry cached) && IsCacheFresh(cached) && !_characterBuildInflight.TryGetExisting(key, out _)) return cached.Fragment; - var buildTask = _characterBuildInflight.GetOrAdd(key, _ => BuildAndCacheAsync(obj, key)); + Task buildTask = _characterBuildInflight.GetOrStart(key, () => BuildAndCacheAsync(obj, key)); if (_characterBuildCache.TryGetValue(key, out cached)) { @@ -189,20 +189,13 @@ public class PlayerDataFactory private async Task BuildAndCacheAsync(GameObjectHandler obj, nint key) { - try - { - using var cts = new CancellationTokenSource(_hardBuildTimeout); - var fragment = await CreateCharacterDataInternal(obj, cts.Token).ConfigureAwait(false); + using var cts = new CancellationTokenSource(_hardBuildTimeout); + CharacterDataFragment fragment = await CreateCharacterDataInternal(obj, cts.Token).ConfigureAwait(false); - _characterBuildCache[key] = new CacheEntry(fragment, DateTime.UtcNow); - PruneCharacterCacheIfNeeded(); + _characterBuildCache[key] = new CacheEntry(fragment, DateTime.UtcNow); + PruneCharacterCacheIfNeeded(); - return fragment; - } - finally - { - _characterBuildInflight.TryRemove(key, out _); - } + return fragment; } private void PruneCharacterCacheIfNeeded() diff --git a/LightlessSync/Utils/TaskRegistry.cs b/LightlessSync/Utils/TaskRegistry.cs index d888bbd..90b6fcf 100644 --- a/LightlessSync/Utils/TaskRegistry.cs +++ b/LightlessSync/Utils/TaskRegistry.cs @@ -1,37 +1,81 @@ using System.Collections.Concurrent; - namespace LightlessSync.Utils; public sealed class TaskRegistry where HandleType : notnull { - private readonly ConcurrentDictionary _activeTasks = new(); + private readonly ConcurrentDictionary> _activeTasks = new(); public Task GetOrStart(HandleType handle, Func taskFactory) - { - ActiveTask entry = _activeTasks.GetOrAdd(handle, i => new ActiveTask(() => ExecuteAndRemove(i, taskFactory))); - return entry.EnsureStarted(); - } + => GetOrStartInternal(handle, taskFactory); public Task GetOrStart(HandleType handle, Func> taskFactory) - { - ActiveTask entry = _activeTasks.GetOrAdd(handle, i => new ActiveTask(() => ExecuteAndRemove(i, taskFactory))); - return (Task)entry.EnsureStarted(); - } + => GetOrStartInternal(handle, taskFactory); public bool TryGetExisting(HandleType handle, out Task task) { - if (_activeTasks.TryGetValue(handle, out ActiveTask? entry)) + if (_activeTasks.TryGetValue(handle, out Lazy? entry)) { - task = entry.EnsureStarted(); - return true; + task = entry.Value; + if (!task.IsCompleted) + { + return true; + } + + _activeTasks.TryRemove(new KeyValuePair>(handle, entry)); } task = Task.CompletedTask; return false; } - private async Task ExecuteAndRemove(HandleType handle, Func taskFactory) + private Task GetOrStartInternal(HandleType handle, Func taskFactory) + { + while (true) + { + Lazy entry = _activeTasks.GetOrAdd(handle, _ => CreateEntry(handle, taskFactory)); + Task task = entry.Value; + + if (!task.IsCompleted) + { + return task; + } + + _activeTasks.TryRemove(new KeyValuePair>(handle, entry)); + } + } + + private Task GetOrStartInternal(HandleType handle, Func> taskFactory) + { + while (true) + { + Lazy entry = _activeTasks.GetOrAdd(handle, _ => CreateEntry(handle, taskFactory)); + Task task = entry.Value; + + if (!task.IsCompleted) + { + return (Task)task; + } + + _activeTasks.TryRemove(new KeyValuePair>(handle, entry)); + } + } + + private Lazy CreateEntry(HandleType handle, Func taskFactory) + { + Lazy entry = null!; + entry = new Lazy(() => ExecuteAndRemove(handle, entry, taskFactory), LazyThreadSafetyMode.ExecutionAndPublication); + return entry; + } + + private Lazy CreateEntry(HandleType handle, Func> taskFactory) + { + Lazy entry = null!; + entry = new Lazy(() => ExecuteAndRemove(handle, entry, taskFactory), LazyThreadSafetyMode.ExecutionAndPublication); + return entry; + } + + private async Task ExecuteAndRemove(HandleType handle, Lazy entry, Func taskFactory) { try { @@ -39,11 +83,11 @@ public sealed class TaskRegistry where HandleType : notnull } finally { - _activeTasks.TryRemove(handle, out _); + _activeTasks.TryRemove(new KeyValuePair>(handle, entry)); } } - private async Task ExecuteAndRemove(HandleType handle, Func> taskFactory) + private async Task ExecuteAndRemove(HandleType handle, Lazy entry, Func> taskFactory) { try { @@ -51,31 +95,7 @@ public sealed class TaskRegistry where HandleType : notnull } finally { - _activeTasks.TryRemove(handle, out _); - } - } - - private sealed class ActiveTask - { - private readonly object _gate = new(); - private readonly Func _starter; - private Task? _cached; - - public ActiveTask(Func starter) - { - _starter = starter; - } - - public Task EnsureStarted() - { - lock (_gate) - { - if (_cached == null || _cached.IsCompleted) - { - _cached = _starter(); - } - return _cached; - } + _activeTasks.TryRemove(new KeyValuePair>(handle, entry)); } } }