Files
LightlessClient/LightlessSync/Utils/TaskRegistry.cs
2026-01-16 11:00:58 +09:00

82 lines
2.1 KiB
C#

using System.Collections.Concurrent;
namespace LightlessSync.Utils;
/// <summary>
/// Deduplicates asynchronous work by key: while a task for a handle is in flight, further requests for
/// the same handle join that task instead of starting new work. The entry is removed when the task
/// finishes, so the next request after completion starts fresh work.
/// </summary>
/// <typeparam name="HandleType">Key identifying a unit of work.</typeparam>
public sealed class TaskRegistry<HandleType> where HandleType : notnull
{
    // One entry per handle whose work is in flight (or about to be started). An entry is added by
    // GetOrStart and removed by the completion path of its own task (ExecuteAndRemove).
    private readonly ConcurrentDictionary<HandleType, ActiveTask> _activeTasks = new();

    /// <summary>
    /// Returns the task already running for <paramref name="handle"/>, or starts one with
    /// <paramref name="taskFactory"/> when none is registered.
    /// </summary>
    /// <param name="handle">Key of the work to run or join.</param>
    /// <param name="taskFactory">Creates the work; only invoked when no task is registered for the handle.</param>
    /// <returns>The task shared by every caller that requested this handle while it was in flight.</returns>
    public Task GetOrStart(HandleType handle, Func<Task> taskFactory)
    {
        // ActiveTask does not start anything in its constructor, so if GetOrAdd builds extra
        // candidates under contention, the discarded ones have no side effects.
        ActiveTask entry = _activeTasks.GetOrAdd(handle, key => new ActiveTask(() => ExecuteAndRemove(key, taskFactory)));
        return entry.EnsureStarted();
    }

    /// <summary>
    /// Returns the task already running for <paramref name="handle"/>, or starts one with
    /// <paramref name="taskFactory"/> when none is registered.
    /// </summary>
    /// <typeparam name="T">Result type of the work.</typeparam>
    /// <param name="handle">Key of the work to run or join.</param>
    /// <param name="taskFactory">Creates the work; only invoked when no task is registered for the handle.</param>
    /// <returns>The task shared by every caller that requested this handle while it was in flight.</returns>
    /// <exception cref="InvalidCastException">
    /// The task already in flight for <paramref name="handle"/> was started by the non-generic overload or
    /// with a different result type.
    /// </exception>
    public Task<T> GetOrStart<T>(HandleType handle, Func<Task<T>> taskFactory)
    {
        ActiveTask entry = _activeTasks.GetOrAdd(handle, key => new ActiveTask(() => ExecuteAndRemove(key, taskFactory)));
        Task started = entry.EnsureStarted();

        // The entry may have been registered by another overload/T for the same handle; report that clearly.
        return started as Task<T>
            ?? throw new InvalidCastException(
                $"The task already running for handle '{handle}' does not produce a result of type {typeof(T)}.");
    }

    /// <summary>
    /// Looks up the task currently registered for <paramref name="handle"/> without starting new work.
    /// </summary>
    /// <param name="handle">Key to look up.</param>
    /// <param name="task">
    /// The registered task when found (started on demand if its creator has not started it yet);
    /// otherwise <see cref="Task.CompletedTask"/>.
    /// </param>
    /// <returns><see langword="true"/> if an entry was registered for the handle.</returns>
    public bool TryGetExisting(HandleType handle, out Task task)
    {
        if (_activeTasks.TryGetValue(handle, out ActiveTask? entry))
        {
            task = entry.EnsureStarted();
            return true;
        }

        task = Task.CompletedTask;
        return false;
    }

    /// <summary>
    /// Runs the work and then removes the handle's entry. The removal sits in <c>finally</c>, so it happens
    /// whether the work succeeds or faults, and before the returned task completes.
    /// </summary>
    private async Task ExecuteAndRemove(HandleType handle, Func<Task> taskFactory)
    {
        try
        {
            await taskFactory().ConfigureAwait(false);
        }
        finally
        {
            // Each entry starts its work at most once (see ActiveTask), so the value under the handle
            // here is this run's own entry.
            _activeTasks.TryRemove(handle, out _);
        }
    }

    /// <summary>
    /// Runs the work, returns its result, and then removes the handle's entry. The removal sits in
    /// <c>finally</c>, so it happens whether the work succeeds or faults, and before the returned task completes.
    /// </summary>
    private async Task<T> ExecuteAndRemove<T>(HandleType handle, Func<Task<T>> taskFactory)
    {
        try
        {
            return await taskFactory().ConfigureAwait(false);
        }
        finally
        {
            _activeTasks.TryRemove(handle, out _);
        }
    }

    /// <summary>
    /// A registry entry that starts its work lazily, exactly once, and hands the same task to every caller.
    /// </summary>
    private sealed class ActiveTask
    {
        private readonly object _gate = new();
        private readonly Func<Task> _starter;
        private Task? _task;

        public ActiveTask(Func<Task> starter)
        {
            _starter = starter;
        }

        /// <summary>
        /// Starts the work on the first call. Later calls return that same task, even after it has completed.
        /// </summary>
        public Task EnsureStarted()
        {
            // Start-once is deliberate. A caller may obtain this entry just before its work finishes and
            // call this afterwards. Re-running the work then would execute outside the dictionary,
            // duplicate any newer entry's work for the same handle, and the re-run's cleanup would evict
            // that newer entry. Joining the already-finished task avoids all of that.
            // NOTE(review): if the factory synchronously re-requests its own handle before its first await,
            // this recurses (the lock is re-entrant and _task is not yet assigned).
            lock (_gate)
            {
                return _task ??= _starter();
            }
        }
    }
}