Skip to content

Commit fa4b6d3

Browse files
committed
pr
1 parent b819cc7 commit fa4b6d3

File tree

2 files changed

+134
-5
lines changed

2 files changed

+134
-5
lines changed

src/MongoDB.Driver/OperationContext.cs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
using System.Diagnostics;
1818
using System.Threading;
1919
using System.Threading.Tasks;
20+
#if !NET6_0_OR_GREATER
21+
using MongoDB.Driver.Core.Misc;
22+
#endif
2023

2124
namespace MongoDB.Driver
2225
{
@@ -91,22 +94,47 @@ public OperationContext WithTimeout(TimeSpan timeout)
9194

9295
public void WaitTask(Task task)
9396
{
97+
if (task.IsCompleted)
98+
{
99+
task.GetAwaiter().GetResult(); // re-throws exception if any
100+
return;
101+
}
102+
94103
var timeout = RemainingTimeout;
95-
if (timeout < TimeSpan.Zero)
104+
if (timeout != System.Threading.Timeout.InfiniteTimeSpan && timeout < TimeSpan.Zero)
96105
{
97106
throw new TimeoutException();
98107
}
99108

100-
if (!task.Wait((int)timeout.TotalMilliseconds, CancellationToken))
109+
try
101110
{
102-
throw new TimeoutException();
111+
if (!task.Wait((int)timeout.TotalMilliseconds, CancellationToken))
112+
{
113+
CancellationToken.ThrowIfCancellationRequested();
114+
throw new TimeoutException();
115+
}
116+
}
117+
catch (AggregateException e)
118+
{
119+
if (e.InnerExceptions.Count == 1)
120+
{
121+
throw e.InnerExceptions[0];
122+
}
123+
124+
throw;
103125
}
104126
}
105127

106128
public async Task WaitTaskAsync(Task task)
107129
{
130+
if (task.IsCompleted)
131+
{
132+
await task.ConfigureAwait(false); // re-throws exception if any
133+
return;
134+
}
135+
108136
var timeout = RemainingTimeout;
109-
if (timeout < TimeSpan.Zero)
137+
if (timeout != System.Threading.Timeout.InfiniteTimeSpan && timeout < TimeSpan.Zero)
110138
{
111139
throw new TimeoutException();
112140
}

tests/MongoDB.Driver.Tests/OperationContextTests.cs

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
using System.Collections.Generic;
1818
using System.Diagnostics;
1919
using System.Threading;
20+
using System.Threading.Tasks;
2021
using FluentAssertions;
22+
using MongoDB.TestHelpers.XunitExtensions;
2123
using Xunit;
2224

2325
namespace MongoDB.Driver.Tests
@@ -137,7 +139,106 @@ public void WithTimeout_should_set_ParentContext()
137139
resultContext.ParentContext.Should().Be(operationContext);
138140
}
139141

140-
// TODO: Add tests for WaitTask and WaitTaskAsync.
142+
[Theory]
143+
[ParameterAttributeData]
144+
public async Task Wait_should_throw_if_context_is_timedout([Values(true, false)] bool async)
145+
{
146+
var taskCompletionSource = new TaskCompletionSource<bool>();
147+
var operationContext = new OperationContext(TimeSpan.FromMilliseconds(10), CancellationToken.None);
148+
Thread.Sleep(20);
149+
150+
var exception = async ?
151+
await Record.ExceptionAsync(() => operationContext.WaitTaskAsync(taskCompletionSource.Task)) :
152+
Record.Exception(() => operationContext.WaitTask(taskCompletionSource.Task));
153+
154+
exception.Should().BeOfType<TimeoutException>();
155+
}
156+
157+
[Theory]
158+
[ParameterAttributeData]
159+
public async Task Wait_should_throw_if_context_is_cancelled([Values(true, false)] bool async)
160+
{
161+
var taskCompletionSource = new TaskCompletionSource<bool>();
162+
var cancellationTokenSource = new CancellationTokenSource();
163+
cancellationTokenSource.Cancel();
164+
var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationTokenSource.Token);
165+
166+
var exception = async ?
167+
await Record.ExceptionAsync(() => operationContext.WaitTaskAsync(taskCompletionSource.Task)) :
168+
Record.Exception(() => operationContext.WaitTask(taskCompletionSource.Task));
169+
170+
exception.Should().BeOfType<OperationCanceledException>();
171+
}
172+
173+
[Theory]
174+
[ParameterAttributeData]
175+
public async Task Wait_should_rethrow_on_failed_task([Values(true, false)] bool async)
176+
{
177+
var ex = new InvalidOperationException();
178+
var task = Task.FromException(ex);
179+
var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, CancellationToken.None);
180+
181+
var exception = async ?
182+
await Record.ExceptionAsync(() => operationContext.WaitTaskAsync(task)) :
183+
Record.Exception(() => operationContext.WaitTask(task));
184+
185+
exception.Should().Be(ex);
186+
}
187+
188+
[Theory]
189+
[ParameterAttributeData]
190+
public async Task Wait_should_rethrow_on_failed_promise_task([Values(true, false)] bool async)
191+
{
192+
var ex = new InvalidOperationException("Ups!");
193+
var taskCompletionSource = new TaskCompletionSource<bool>();
194+
var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, CancellationToken.None);
195+
196+
var task = Task.Run(async () =>
197+
{
198+
if (async)
199+
{
200+
await operationContext.WaitTaskAsync(taskCompletionSource.Task);
201+
}
202+
else
203+
{
204+
operationContext.WaitTask(taskCompletionSource.Task);
205+
}
206+
});
207+
Thread.Sleep(20);
208+
taskCompletionSource.SetException(ex);
209+
210+
var exception = await Record.ExceptionAsync(() => task);
211+
exception.Should().Be(ex);
212+
}
213+
214+
[Theory]
215+
[ParameterAttributeData]
216+
public async Task Wait_should_throw_on_timeout([Values(true, false)] bool async)
217+
{
218+
var taskCompletionSource = new TaskCompletionSource<bool>();
219+
var operationContext = new OperationContext(TimeSpan.FromMilliseconds(20), CancellationToken.None);
220+
221+
var exception = async ?
222+
await Record.ExceptionAsync(() => operationContext.WaitTaskAsync(taskCompletionSource.Task)) :
223+
Record.Exception(() => operationContext.WaitTask(taskCompletionSource.Task));
224+
225+
exception.Should().BeOfType<TimeoutException>();
226+
}
227+
228+
[Theory]
229+
[ParameterAttributeData]
230+
public async Task Wait_should_not_throw_on_resolved_task_with_timedout_context([Values(true, false)] bool async)
231+
{
232+
var task = Task.FromResult(42);
233+
var operationContext = new OperationContext(TimeSpan.FromMilliseconds(10), CancellationToken.None);
234+
Thread.Sleep(20);
235+
236+
var exception = async ?
237+
await Record.ExceptionAsync(() => operationContext.WaitTaskAsync(task)) :
238+
Record.Exception(() => operationContext.WaitTask(task));
239+
240+
exception.Should().BeNull();
241+
}
141242
}
142243
}
143244

0 commit comments

Comments
 (0)