fix task register
This commit is contained in:
@@ -1,37 +1,81 @@
|
||||
using System.Collections.Concurrent;
|
||||
|
||||
|
||||
namespace LightlessSync.Utils;
|
||||
|
||||
public sealed class TaskRegistry<HandleType> where HandleType : notnull
|
||||
{
|
||||
private readonly ConcurrentDictionary<HandleType, ActiveTask> _activeTasks = new();
|
||||
private readonly ConcurrentDictionary<HandleType, Lazy<Task>> _activeTasks = new();
|
||||
|
||||
public Task GetOrStart(HandleType handle, Func<Task> taskFactory)
|
||||
{
|
||||
ActiveTask entry = _activeTasks.GetOrAdd(handle, i => new ActiveTask(() => ExecuteAndRemove(i, taskFactory)));
|
||||
return entry.EnsureStarted();
|
||||
}
|
||||
=> GetOrStartInternal(handle, taskFactory);
|
||||
|
||||
public Task<T> GetOrStart<T>(HandleType handle, Func<Task<T>> taskFactory)
|
||||
{
|
||||
ActiveTask entry = _activeTasks.GetOrAdd(handle, i => new ActiveTask(() => ExecuteAndRemove(i, taskFactory)));
|
||||
return (Task<T>)entry.EnsureStarted();
|
||||
}
|
||||
=> GetOrStartInternal(handle, taskFactory);
|
||||
|
||||
public bool TryGetExisting(HandleType handle, out Task task)
|
||||
{
|
||||
if (_activeTasks.TryGetValue(handle, out ActiveTask? entry))
|
||||
if (_activeTasks.TryGetValue(handle, out Lazy<Task>? entry))
|
||||
{
|
||||
task = entry.EnsureStarted();
|
||||
return true;
|
||||
task = entry.Value;
|
||||
if (!task.IsCompleted)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
_activeTasks.TryRemove(new KeyValuePair<HandleType, Lazy<Task>>(handle, entry));
|
||||
}
|
||||
|
||||
task = Task.CompletedTask;
|
||||
return false;
|
||||
}
|
||||
|
||||
private async Task ExecuteAndRemove(HandleType handle, Func<Task> taskFactory)
|
||||
private Task GetOrStartInternal(HandleType handle, Func<Task> taskFactory)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
Lazy<Task> entry = _activeTasks.GetOrAdd(handle, _ => CreateEntry(handle, taskFactory));
|
||||
Task task = entry.Value;
|
||||
|
||||
if (!task.IsCompleted)
|
||||
{
|
||||
return task;
|
||||
}
|
||||
|
||||
_activeTasks.TryRemove(new KeyValuePair<HandleType, Lazy<Task>>(handle, entry));
|
||||
}
|
||||
}
|
||||
|
||||
private Task<T> GetOrStartInternal<T>(HandleType handle, Func<Task<T>> taskFactory)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
Lazy<Task> entry = _activeTasks.GetOrAdd(handle, _ => CreateEntry(handle, taskFactory));
|
||||
Task task = entry.Value;
|
||||
|
||||
if (!task.IsCompleted)
|
||||
{
|
||||
return (Task<T>)task;
|
||||
}
|
||||
|
||||
_activeTasks.TryRemove(new KeyValuePair<HandleType, Lazy<Task>>(handle, entry));
|
||||
}
|
||||
}
|
||||
|
||||
private Lazy<Task> CreateEntry(HandleType handle, Func<Task> taskFactory)
|
||||
{
|
||||
Lazy<Task> entry = null!;
|
||||
entry = new Lazy<Task>(() => ExecuteAndRemove(handle, entry, taskFactory), LazyThreadSafetyMode.ExecutionAndPublication);
|
||||
return entry;
|
||||
}
|
||||
|
||||
private Lazy<Task> CreateEntry<T>(HandleType handle, Func<Task<T>> taskFactory)
|
||||
{
|
||||
Lazy<Task> entry = null!;
|
||||
entry = new Lazy<Task>(() => ExecuteAndRemove(handle, entry, taskFactory), LazyThreadSafetyMode.ExecutionAndPublication);
|
||||
return entry;
|
||||
}
|
||||
|
||||
private async Task ExecuteAndRemove(HandleType handle, Lazy<Task> entry, Func<Task> taskFactory)
|
||||
{
|
||||
try
|
||||
{
|
||||
@@ -39,11 +83,11 @@ public sealed class TaskRegistry<HandleType> where HandleType : notnull
|
||||
}
|
||||
finally
|
||||
{
|
||||
_activeTasks.TryRemove(handle, out _);
|
||||
_activeTasks.TryRemove(new KeyValuePair<HandleType, Lazy<Task>>(handle, entry));
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<T> ExecuteAndRemove<T>(HandleType handle, Func<Task<T>> taskFactory)
|
||||
private async Task<T> ExecuteAndRemove<T>(HandleType handle, Lazy<Task> entry, Func<Task<T>> taskFactory)
|
||||
{
|
||||
try
|
||||
{
|
||||
@@ -51,31 +95,7 @@ public sealed class TaskRegistry<HandleType> where HandleType : notnull
|
||||
}
|
||||
finally
|
||||
{
|
||||
_activeTasks.TryRemove(handle, out _);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class ActiveTask
|
||||
{
|
||||
private readonly object _gate = new();
|
||||
private readonly Func<Task> _starter;
|
||||
private Task? _cached;
|
||||
|
||||
public ActiveTask(Func<Task> starter)
|
||||
{
|
||||
_starter = starter;
|
||||
}
|
||||
|
||||
public Task EnsureStarted()
|
||||
{
|
||||
lock (_gate)
|
||||
{
|
||||
if (_cached == null || _cached.IsCompleted)
|
||||
{
|
||||
_cached = _starter();
|
||||
}
|
||||
return _cached;
|
||||
}
|
||||
_activeTasks.TryRemove(new KeyValuePair<HandleType, Lazy<Task>>(handle, entry));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user