diff --git a/src/AdsClient/Ams.cs b/src/AdsClient/Ams.cs index af19237..08e41aa 100644 --- a/src/AdsClient/Ams.cs +++ b/src/AdsClient/Ams.cs @@ -40,9 +40,9 @@ public Ams(IAmsSocket amsSocket) /// public ushort AmsPortSource { get; set; } = 32905; - public Task ConnectAsync() + public Task ConnectAsync(CancellationToken cancellationToken = default) { - return AmsSocket.ConnectAsync(new MessageHandler(this)); + return AmsSocket.ConnectAsync(new MessageHandler(this), cancellationToken); } public void Dispose() diff --git a/src/AdsClient/AmsSocket.cs b/src/AdsClient/AmsSocket.cs index 8456714..5f632e0 100644 --- a/src/AdsClient/AmsSocket.cs +++ b/src/AdsClient/AmsSocket.cs @@ -1,5 +1,6 @@ using System; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; using Viscon.Communication.Ads.Internal; @@ -31,11 +32,25 @@ public void Close() connection?.Close(); } - public async Task ConnectAsync(IIncomingMessageHandler messageHandler) + public async Task ConnectAsync(IIncomingMessageHandler messageHandler, CancellationToken cancellationToken = default) { if (connection is not null) throw new InvalidOperationException("Connection was already established."); - await TcpClient.ConnectAsync(Host, Port).ConfigureAwait(false); + using (cancellationToken.Register(state => ((Socket)state).Close(), Socket)) + { + try + { + await TcpClient.ConnectAsync(Host, Port).ConfigureAwait(false); + } + catch (Exception) when (cancellationToken.IsCancellationRequested) + { + // The exception handling is quite generic, but exceptions thrown differ across target frameworks. + // (See https://stackoverflow.com/a/66656805/1085457) + // This is probably not something to worry about, since apparently cancellation was requested anyway. + cancellationToken.ThrowIfCancellationRequested(); + } + } + connection = new AmsSocketConnection(TcpClient.Client, messageHandler); } diff --git a/src/AdsClient/IAmsSocket.cs b/src/AdsClient/IAmsSocket.cs index 061000b..f74d602 100644 --- a/src/AdsClient/IAmsSocket.cs +++ b/src/AdsClient/IAmsSocket.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; namespace Viscon.Communication.Ads @@ -9,7 +10,7 @@ public interface IAmsSocket void Close(); - Task ConnectAsync(IIncomingMessageHandler messageHandler); + Task ConnectAsync(IIncomingMessageHandler messageHandler, CancellationToken cancellationToken = default); Task SendAsync(byte[] message); diff --git a/test/AdsClient/AdsCommandsAsyncTest.cs b/test/AdsClient/AdsCommandsAsyncTest.cs index 215114b..86a8912 100644 --- a/test/AdsClient/AdsCommandsAsyncTest.cs +++ b/test/AdsClient/AdsCommandsAsyncTest.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using System.Threading; using System.Threading.Tasks; using FakeItEasy; using Shouldly; @@ -18,12 +19,13 @@ public class AdsCommandsAsyncTest : IDisposable public AdsCommandsAsyncTest() { - A.CallTo(() => amsSocket.ConnectAsync(A.Ignored)).ReturnsLazily(call => - { - messageHandler = call.GetArgument(0); - connected = true; - return Task.CompletedTask; - }); + A.CallTo(() => amsSocket.ConnectAsync(A.Ignored, A.Ignored)) + .ReturnsLazily(call => + { + messageHandler = call.GetArgument(0); + connected = true; + return Task.CompletedTask; + }); A.CallTo(() => amsSocket.Connected).ReturnsLazily(() => connected); client = new AdsClient(amsNetIdSource: "10.0.0.120.1.1", amsSocket: amsSocket,