Skip to content

Commit 20be4cb

Browse files
authored
Add AwaitCompletion for await “tasks”
Merge of PR #479
1 parent 06b7bdb commit 20be4cb

File tree

3 files changed

+126
-34
lines changed

3 files changed

+126
-34
lines changed

MoreLinq/Experimental/Await.cs

Lines changed: 118 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,87 @@ public static IAwaitQuery<T> Await<T>(this IEnumerable<Task<T>> source) =>
329329
/// </remarks>
330330

331331
public static IAwaitQuery<TResult> Await<T, TResult>(
332-
this IEnumerable<T> source, Func<T, CancellationToken, Task<TResult>> evaluator)
332+
this IEnumerable<T> source, Func<T, CancellationToken, Task<TResult>> evaluator) =>
333+
AwaitQuery.Create(options =>
334+
from t in source.AwaitCompletion(evaluator, (_, t) => t)
335+
.WithOptions(options)
336+
select t.GetAwaiter().GetResult());
337+
338+
/*
339+
/// <summary>
340+
/// Awaits completion of all asynchronous evaluations.
341+
/// </summary>
342+
343+
public static IAwaitQuery<TResult> AwaitCompletion<T, TT, TResult>(
344+
this IEnumerable<T> source,
345+
Func<T, CancellationToken, Task<TT>> evaluator,
346+
Func<T, TT, TResult> resultSelector,
347+
Func<T, Exception, TResult> errorSelector,
348+
Func<T, TResult> cancellationSelector) =>
349+
AwaitQuery.Create(options =>
350+
from e in source.AwaitCompletion(evaluator, (item, task) => (Item: item, Task: task))
351+
.WithOptions(options)
352+
select e.Task.IsFaulted
353+
? errorSelector(e.Item, e.Task.Exception)
354+
: e.Task.IsCanceled
355+
? cancellationSelector(e.Item)
356+
: resultSelector(e.Item, e.Task.Result));
357+
*/
358+
359+
/// <summary>
360+
/// Awaits completion of all asynchronous evaluations irrespective of
361+
/// whether they succeed or fail. An additional argument specifies a
362+
/// function that projects the final result given the source item and
363+
/// completed task.
364+
/// </summary>
365+
/// <typeparam name="T">The type of the source elements.</typeparam>
366+
/// <typeparam name="TTaskResult"> The type of the tasks's result.</typeparam>
367+
/// <typeparam name="TResult">The type of the result elements.</typeparam>
368+
/// <param name="source">The source sequence.</param>
369+
/// <param name="evaluator">A function to begin the asynchronous
370+
/// evaluation of each element, the second parameter of which is a
371+
/// <see cref="CancellationToken"/> that can be used to abort
372+
/// asynchronous operations.</param>
373+
/// <param name="resultSelector">A fucntion that projects the final
374+
/// result given the source item and its asynchronous completion
375+
/// result.</param>
376+
/// <returns>
377+
/// A sequence query that stream its results as they are
378+
/// evaluated asynchronously.
379+
/// </returns>
380+
/// <remarks>
381+
/// <para>
382+
/// This method uses deferred execution semantics. The results are
383+
/// yielded as each asynchronous evaluation completes and, by default,
384+
/// not guaranteed to be based on the source sequence order. If order
385+
/// is important, compose further with
386+
/// <see cref="AsOrdered{T}"/>.</para>
387+
/// <para>
388+
/// This method starts a new task where the asynchronous evaluations
389+
/// take place and awaited. If the resulting sequence is partially
390+
/// consumed then there's a good chance that some projection work will
391+
/// be wasted and a cooperative effort is done that depends on the
392+
/// <paramref name="evaluator"/> function (via a
393+
/// <see cref="CancellationToken"/> as its second argument) to cancel
394+
/// those in flight.</para>
395+
/// <para>
396+
/// The <paramref name="evaluator"/> function should be designed to be
397+
/// thread-agnostic.</para>
398+
/// <para>
399+
/// The task returned by <paramref name="evaluator"/> should be started
400+
/// when the function is called (and not just a mere projection)
401+
/// otherwise changing concurrency options via
402+
/// <see cref="AsSequential{T}"/>, <see cref="MaxConcurrency{T}"/> or
403+
/// <see cref="UnboundedConcurrency{T}"/> will only change how many
404+
/// tasks are awaited at any given moment, not how many will be
405+
/// kept in flight.
406+
/// </para>
407+
/// </remarks>
408+
409+
public static IAwaitQuery<TResult> AwaitCompletion<T, TTaskResult, TResult>(
410+
this IEnumerable<T> source,
411+
Func<T, CancellationToken, Task<TTaskResult>> evaluator,
412+
Func<T, Task<TTaskResult>, TResult> resultSelector)
333413
{
334414
if (source == null) throw new ArgumentNullException(nameof(source));
335415
if (evaluator == null) throw new ArgumentNullException(nameof(evaluator));
@@ -342,14 +422,14 @@ public static IAwaitQuery<TResult> Await<T, TResult>(
342422

343423
IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered)
344424
{
345-
var notices = new BlockingCollection<(Notice, (int, TResult), ExceptionDispatchInfo)>();
425+
var notices = new BlockingCollection<(Notice, (int, T, Task<TTaskResult>), ExceptionDispatchInfo)>();
346426
var cancellationTokenSource = new CancellationTokenSource();
347427
var cancellationToken = cancellationTokenSource.Token;
348428
var completed = false;
349429

350430
var enumerator =
351431
source.Index()
352-
.Select(e => (e.Key, Task: evaluator(e.Value, cancellationToken)))
432+
.Select(e => (e.Key, Item: e.Value, Task: evaluator(e.Value, cancellationToken)))
353433
.GetEnumerator();
354434

355435
IDisposable disposable = enumerator; // disables AccessToDisposedClosure warnings
@@ -362,7 +442,7 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
362442
enumerator,
363443
e => e.Task,
364444
notices,
365-
(e, r) => (Notice.Result, (e.Key, r), default),
445+
(e, r) => (Notice.Result, (e.Key, e.Item, e.Task), default),
366446
ex => (Notice.Error, default, ExceptionDispatchInfo.Capture(ex)),
367447
(Notice.End, default, default),
368448
maxConcurrency, cancellationTokenSource),
@@ -371,7 +451,7 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
371451
scheduler);
372452

373453
var nextKey = 0;
374-
var holds = ordered ? new List<(int, TResult)>() : null;
454+
var holds = ordered ? new List<(int, T, Task<TTaskResult>)>() : null;
375455

376456
foreach (var (kind, result, error) in notices.GetConsumingEnumerable())
377457
{
@@ -383,14 +463,14 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
383463

384464
Debug.Assert(kind == Notice.Result);
385465

386-
var (key, value) = result;
466+
var (key, inp, value) = result;
387467
if (holds == null || key == nextKey)
388468
{
389469
// If order does not need to be preserved or the key
390470
// is the next that should be yielded then yield
391471
// the result.
392472

393-
yield return value;
473+
yield return resultSelector(inp, value);
394474

395475
if (holds != null) // preserve order?
396476
{
@@ -401,12 +481,12 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
401481

402482
for (nextKey++; holds.Count > 0; nextKey++)
403483
{
404-
var (candidateKey, candidate) = holds[0];
484+
var (candidateKey, ic, candidate) = holds[0];
405485
if (candidateKey != nextKey)
406486
break;
407487

408488
releaseCount++;
409-
yield return candidate;
489+
yield return resultSelector(ic, candidate);
410490
}
411491

412492
holds.RemoveRange(0, releaseCount);
@@ -419,18 +499,18 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
419499
// where it belongs in the order of results withheld
420500
// so far and insert it in the list.
421501

422-
var i = holds.BinarySearch(result, TupleComparer<int, TResult>.Item1);
502+
var i = holds.BinarySearch(result, TupleComparer<int, T, Task<TTaskResult>>.Item1);
423503
Debug.Assert(i < 0);
424504
holds.Insert(~i, result);
425505
}
426506
}
427507

428508
if (holds?.Count > 0) // yield any withheld, which should be in order...
429509
{
430-
foreach (var (key, value) in holds)
510+
foreach (var (key, x, value) in holds)
431511
{
432512
Debug.Assert(nextKey++ == key); //...assert so!
433-
yield return value;
513+
yield return resultSelector(x, value);
434514
}
435515
}
436516

@@ -453,14 +533,11 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
453533

454534
enum Notice { Result, Error, End }
455535

456-
static async Task<TResult> Select<T, TResult>(this Task<T> task, Func<T, TResult> selector) =>
457-
selector(await task.ConfigureAwait(continueOnCapturedContext: false));
458-
459536
static async Task CollectToAsync<T, TResult, TNotice>(
460537
this IEnumerator<T> e,
461538
Func<T, Task<TResult>> taskSelector,
462539
BlockingCollection<TNotice> collection,
463-
Func<T, TResult, TNotice> resultNoticeSelector,
540+
Func<T, Task<TResult>, TNotice> completionNoticeSelector,
464541
Func<Exception, TNotice> errorNoticeSelector,
465542
TNotice endNotice,
466543
int maxConcurrency,
@@ -476,13 +553,13 @@ static async Task CollectToAsync<T, TResult, TNotice>(
476553
var cancellationTaskSource = new TaskCompletionSource<bool>();
477554
cancellationToken.Register(() => cancellationTaskSource.TrySetResult(true));
478555

479-
var tasks = new List<Task<(T, TResult)>>();
556+
var tasks = new List<(T Item, Task<TResult> Task)>();
480557

481558
for (var i = 0; i < maxConcurrency; i++)
482559
{
483560
if (!reader.TryRead(out var item))
484561
break;
485-
tasks.Add(taskSelector(item).Select(r => (item, r)));
562+
tasks.Add((item, taskSelector(item)));
486563
}
487564

488565
while (tasks.Count > 0)
@@ -518,7 +595,7 @@ static async Task CollectToAsync<T, TResult, TNotice>(
518595
// a consequence generate new and unique task objects.
519596

520597
var completedTask = await
521-
Task.WhenAny(tasks.Cast<Task>().Concat(cancellationTaskSource.Task))
598+
Task.WhenAny(tasks.Select(it => (Task) it.Task).Concat(cancellationTaskSource.Task))
522599
.ConfigureAwait(continueOnCapturedContext: false);
523600

524601
if (completedTask == cancellationTaskSource.Task)
@@ -534,18 +611,23 @@ static async Task CollectToAsync<T, TResult, TNotice>(
534611
return;
535612
}
536613

537-
var task = (Task<(T Input, TResult Result)>) completedTask;
538-
tasks.Remove(task);
614+
var i = tasks.FindIndex(it => it.Task.Equals(completedTask));
539615

540-
// Await the task rather than using its result directly
541-
// to avoid having the task's exception bubble up as
542-
// AggregateException if the task failed.
616+
{
617+
var (item, task) = tasks[i];
618+
tasks.RemoveAt(i);
619+
620+
// Await the task rather than using its result directly
621+
// to avoid having the task's exception bubble up as
622+
// AggregateException if the task failed.
543623

544-
var eval = await task;
545-
collection.Add(resultNoticeSelector(eval.Input, eval.Result));
624+
collection.Add(completionNoticeSelector(item, task));
625+
}
546626

547-
if (reader.TryRead(out var item))
548-
tasks.Add(taskSelector(item).Select(r => (item, r)));
627+
{
628+
if (reader.TryRead(out var item))
629+
tasks.Add((item, taskSelector(item)));
630+
}
549631
}
550632

551633
collection.Add(endNotice);
@@ -627,13 +709,16 @@ public IAwaitQuery<T> WithOptions(AwaitQueryOptions options)
627709
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
628710
}
629711

630-
static class TupleComparer<T1, T2>
712+
static class TupleComparer<T1, T2, T3>
631713
{
632-
public static readonly IComparer<(T1, T2)> Item1 =
633-
Comparer<(T1, T2)>.Create((x, y) => Comparer<T1>.Default.Compare(x.Item1, y.Item1));
714+
public static readonly IComparer<(T1, T2, T3)> Item1 =
715+
Comparer<(T1, T2, T3)>.Create((x, y) => Comparer<T1>.Default.Compare(x.Item1, y.Item1));
716+
717+
public static readonly IComparer<(T1, T2, T3)> Item2 =
718+
Comparer<(T1, T2, T3)>.Create((x, y) => Comparer<T2>.Default.Compare(x.Item2, y.Item2));
634719

635-
public static readonly IComparer<(T1, T2)> Item2 =
636-
Comparer<(T1, T2)>.Create((x, y) => Comparer<T2>.Default.Compare(x.Item2, y.Item2));
720+
public static readonly IComparer<(T1, T2, T3)> Item3 =
721+
Comparer<(T1, T2, T3)>.Create((x, y) => Comparer<T3>.Default.Compare(x.Item3, y.Item3));
637722
}
638723
}
639724
}

MoreLinq/MoreLinq.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
- AtLeast
1616
- AtMost
1717
- Await (EXPERIMENTAL)
18+
- AwaitCompletion (EXPERIMENTAL)
1819
- Batch
1920
- Cartesian
2021
- Choose
@@ -121,7 +122,7 @@
121122
<PublicSign Condition=" '$(OS)' != 'Windows_NT' ">true</PublicSign>
122123
<PackageId>morelinq</PackageId>
123124
<PackageTags>linq;extensions</PackageTags>
124-
<PackageReleaseNotes>Adds new operators: Await (EXPERIMENTAL), CompareCount, Choose, CountDown, Memoize (EXPERIMENTAL), Transpose. See also https://github.com/morelinq/MoreLINQ/wiki/API-Changes.</PackageReleaseNotes>
125+
<PackageReleaseNotes>Adds new operators: Await (EXPERIMENTAL), AwaitCompletion (EXPERIMENTAL), CompareCount, Choose, CountDown, Memoize (EXPERIMENTAL), Transpose. See also https://github.com/morelinq/MoreLINQ/wiki/API-Changes.</PackageReleaseNotes>
125126
<PackageProjectUrl>https://morelinq.github.io/</PackageProjectUrl>
126127
<PackageLicenseUrl>http://www.apache.org/licenses/LICENSE-2.0</PackageLicenseUrl>
127128
<GenerateAssemblyTitleAttribute>false</GenerateAssemblyTitleAttribute>

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,12 @@ sequence as it completes asynchronously.
622622

623623
This method has 2 overloads.
624624

625+
### AwaitCompletion
626+
627+
Awaits completion of all asynchronous evaluations irrespective of whether they
628+
succeed or fail. An additional argument specifies a function that projects the
629+
final result given the source item and completed task.
630+
625631
### Memoize
626632

627633
Creates a sequence that lazily caches the source as it is iterated for the

0 commit comments

Comments
 (0)