From 34eb54b8dae797251684dbee6518e58fd9b8e097 Mon Sep 17 00:00:00 2001 From: Joel Christner Date: Wed, 4 Sep 2019 17:41:44 -0700 Subject: [PATCH] NuGet v2.0, breaking changes, thanks to @MrMikeJJ for his extensive commits and pull requests! --- README.md | 45 +++-- TestClient/Program.cs | 116 ++++++++----- TestClientStream/Program.cs | 87 +++++----- TestDebug.bat | 8 +- TestMultiClient/ConcurrentList.cs | 24 +-- TestMultiClient/Program.cs | 85 +++++----- TestMultiThread/Program.cs | 68 +++++--- TestParallel/Program.cs | 60 ++++--- TestServer/Program.cs | 55 +++--- TestServerStream/Program.cs | 67 ++++---- WatsonTcp/ClientMetadata.cs | 18 +- WatsonTcp/Common.cs | 128 +++----------- WatsonTcp/Message/FieldType.cs | 10 +- WatsonTcp/Message/MessageField.cs | 21 +-- WatsonTcp/Message/MessageStatus.cs | 11 +- WatsonTcp/Message/WatsonMessage.cs | 162 +++++++++--------- WatsonTcp/Mode.cs | 8 +- WatsonTcp/WatsonTcp.csproj | 4 +- WatsonTcp/WatsonTcpClient.cs | 260 ++++++++++++++++------------- WatsonTcp/WatsonTcpServer.cs | 156 ++++++++--------- 20 files changed, 681 insertions(+), 712 deletions(-) diff --git a/README.md b/README.md index 817b630..7dbf489 100644 --- a/README.md +++ b/README.md @@ -9,15 +9,12 @@ A simple C# async TCP server and client with integrated framing for reliable transmission and receipt of data. -## New in v1.3.x +## New in v2.x -- Numerous fixes to authentication using preshared keys -- Authentication callbacks in the client to handle authentication events - - ```AuthenticationRequested``` - authentication requested by the server, return the preshared key string (16 bytes) - - ```AuthenticationSucceeded``` - authentication has succeeded, return true - - ```AuthenticationFailure``` - authentication has failed, return true -- Support for sending and receiving larger messages by using streams instead of byte arrays -- Refer to ```TestServerStream``` and ```TestClientStream``` for a reference implementation. You must set ```client.ReadDataStream = false``` and ```server.ReadDataStream = false``` and use the ```StreamReceived``` callback instead of ```MessageReceived``` +- Async Task-based callbacks +- Configurable connect timeout in WatsonTcpClient +- Clients can now connect via SSL without a certificate +- Big thanks to @MrMikeJJ for his extensive commits and pull requests ## Test Applications @@ -99,24 +96,21 @@ static void Main(string[] args) } } -static bool ClientConnected(string ipPort) +static async Task ClientConnected(string ipPort) { Console.WriteLine("Client connected: " + ipPort); - return true; } -static bool ClientDisconnected(string ipPort) +static async Task ClientDisconnected(string ipPort) { Console.WriteLine("Client disconnected: " + ipPort); - return true; } -static bool MessageReceived(string ipPort, byte[] data) +static async Task MessageReceived(string ipPort, byte[] data) { string msg = ""; if (data != null && data.Length > 0) msg = Encoding.UTF8.GetString(data); Console.WriteLine("Message received from " + ipPort + ": " + msg); - return true; } ``` @@ -164,22 +158,19 @@ static void Main(string[] args) } } -static bool MessageReceived(byte[] data) +static async Task MessageReceived(byte[] data) { Console.WriteLine("Message from server: " + Encoding.UTF8.GetString(data)); - return true; } -static bool ServerConnected() +static async Task ServerConnected() { Console.WriteLine("Server connected"); - return true; } -static bool ServerDisconnected() +static async Task ServerDisconnected() { Console.WriteLine("Server disconnected"); - return true; } ``` @@ -218,7 +209,7 @@ server.StreamReceived = StreamReceived; server.ReadDataStream = false; server.Start(); -static bool StreamReceived(string ipPort, long contentLength, Stream stream) +static async Task StreamReceived(string ipPort, long contentLength, Stream stream) { // read contentLength bytes from the stream from client ipPort and process return true; @@ -232,15 +223,23 @@ client.StreamReceived = StreamReceived; client.ReadDataStream = false; client.Start(); -static bool StreamReceived(long contentLength, Stream stream) +static async Task StreamReceived(long contentLength, Stream stream) { // read contentLength bytes from the stream and process - return true; } ``` ## Version History +v1.3.x +- Numerous fixes to authentication using preshared keys +- Authentication callbacks in the client to handle authentication events + - ```AuthenticationRequested``` - authentication requested by the server, return the preshared key string (16 bytes) + - ```AuthenticationSucceeded``` - authentication has succeeded, return true + - ```AuthenticationFailure``` - authentication has failed, return true +- Support for sending and receiving larger messages by using streams instead of byte arrays +- Refer to ```TestServerStream``` and ```TestClientStream``` for a reference implementation. You must set ```client.ReadDataStream = false``` and ```server.ReadDataStream = false``` and use the ```StreamReceived``` callback instead of ```MessageReceived``` + v1.2.x - Breaking changes for assigning callbacks, various server/client class variables, and starting them - Consolidated SSL and non-SSL clients and servers into single classes for each diff --git a/TestClient/Program.cs b/TestClient/Program.cs index d591f58..a570a33 100644 --- a/TestClient/Program.cs +++ b/TestClient/Program.cs @@ -1,27 +1,24 @@ using System; using System.Text; +using System.Threading.Tasks; using WatsonTcp; namespace TestClient { - class TestClient + internal class TestClient { - static string serverIp = ""; - static int serverPort = 0; - static bool useSsl = false; - static string certFile = ""; - static string certPass = ""; - static bool acceptInvalidCerts = true; - static bool mutualAuthentication = true; - static WatsonTcpClient client = null; - static string presharedKey = null; - - static void Main(string[] args) + private static string serverIp = ""; + private static int serverPort = 0; + private static bool useSsl = false; + private static string certFile = ""; + private static string certPass = ""; + private static bool acceptInvalidCerts = true; + private static bool mutualAuthentication = true; + private static WatsonTcpClient client = null; + private static string presharedKey = null; + + private static void Main(string[] args) { - serverIp = Common.InputString("Server IP:", "127.0.0.1", false); - serverPort = Common.InputInteger("Server port:", 9000, true, false); - useSsl = Common.InputBoolean("Use SSL:", false); - InitializeClient(); bool runForever = true; @@ -68,7 +65,7 @@ static void Main(string[] args) break; } - client.Send(Encoding.UTF8.GetBytes(userInput)); + if (!client.Send(Encoding.UTF8.GetBytes(userInput))) Console.WriteLine("Failed"); break; case "sendasync": @@ -79,7 +76,7 @@ static void Main(string[] args) break; } - bool success = client.SendAsync(Encoding.UTF8.GetBytes(userInput)).Result; + if (!client.SendAsync(Encoding.UTF8.GetBytes(userInput)).Result) Console.WriteLine("Failed"); break; case "status": @@ -109,17 +106,12 @@ static void Main(string[] args) client.ServerConnected = ServerConnected; client.ServerDisconnected = ServerDisconnected; client.MessageReceived = MessageReceived; - client.Start(); + client.Start(); } break; case "reconnect": - if (client != null) client.Dispose(); - client = new WatsonTcpClient(serverIp, serverPort); - client.ServerConnected = ServerConnected; - client.ServerDisconnected = ServerDisconnected; - client.MessageReceived = MessageReceived; - client.Start(); + ConnectClient(); break; case "psk": @@ -141,19 +133,47 @@ static void Main(string[] args) } } - static void InitializeClient() - { + private static void InitializeClient() + { + serverIp = Common.InputString("Server IP:", "127.0.0.1", false); + serverPort = Common.InputInteger("Server port:", 9000, true, false); + useSsl = Common.InputBoolean("Use SSL:", false); + if (!useSsl) { client = new WatsonTcpClient(serverIp, serverPort); } else { - certFile = Common.InputString("Certificate file:", "test.pfx", false); - certPass = Common.InputString("Certificate password:", "password", false); + bool supplyCert = Common.InputBoolean("Supply SSL certificate:", false); + + if (supplyCert) + { + certFile = Common.InputString("Certificate file:", "test.pfx", false); + certPass = Common.InputString("Certificate password:", "password", false); + } + acceptInvalidCerts = Common.InputBoolean("Accept Invalid Certs:", true); - mutualAuthentication = Common.InputBoolean("Mutually authenticate:", true); + mutualAuthentication = Common.InputBoolean("Mutually authenticate:", false); + + client = new WatsonTcpClient(serverIp, serverPort, certFile, certPass); + client.AcceptInvalidCertificates = acceptInvalidCerts; + client.MutuallyAuthenticate = mutualAuthentication; + } + ConnectClient(); + } + + private static void ConnectClient() + { + if (client != null) client.Dispose(); + + if (!useSsl) + { + client = new WatsonTcpClient(serverIp, serverPort); + } + else + { client = new WatsonTcpClient(serverIp, serverPort, certFile, certPass); client.AcceptInvalidCertificates = acceptInvalidCerts; client.MutuallyAuthenticate = mutualAuthentication; @@ -168,10 +188,10 @@ static void InitializeClient() client.ReadDataStream = true; client.ReadStreamBufferSize = 65536; // client.Debug = true; - client.Start(); + client.Start(); } - static string AuthenticationRequested() + private static string AuthenticationRequested() { Console.WriteLine(""); Console.WriteLine(""); @@ -181,34 +201,44 @@ static string AuthenticationRequested() return presharedKey; } - static bool AuthenticationSucceeded() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task AuthenticationSucceeded() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Authentication succeeded"); - return true; } - static bool AuthenticationFailure() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task AuthenticationFailure() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Authentication failed"); - return true; } - static bool MessageReceived(byte[] data) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task MessageReceived(byte[] data) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Message from server: " + Encoding.UTF8.GetString(data)); - return true; } - static bool ServerConnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerConnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Server connected"); - return true; } - static bool ServerDisconnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerDisconnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Server disconnected"); - return true; } } -} +} \ No newline at end of file diff --git a/TestClientStream/Program.cs b/TestClientStream/Program.cs index b018bc9..75b4000 100644 --- a/TestClientStream/Program.cs +++ b/TestClientStream/Program.cs @@ -1,23 +1,24 @@ using System; using System.IO; using System.Text; +using System.Threading.Tasks; using WatsonTcp; namespace TestClientStream { - class TestClientStream + internal class TestClientStream { - static string serverIp = ""; - static int serverPort = 0; - static bool useSsl = false; - static string certFile = ""; - static string certPass = ""; - static bool acceptInvalidCerts = true; - static bool mutualAuthentication = true; - static WatsonTcpClient client = null; - static string presharedKey = null; - - static void Main(string[] args) + private static string serverIp = ""; + private static int serverPort = 0; + private static bool useSsl = false; + private static string certFile = ""; + private static string certPass = ""; + private static bool acceptInvalidCerts = true; + private static bool mutualAuthentication = true; + private static WatsonTcpClient client = null; + private static string presharedKey = null; + + private static void Main(string[] args) { serverIp = Common.InputString("Server IP:", "127.0.0.1", false); serverPort = Common.InputInteger("Server port:", 9000, true, false); @@ -144,7 +145,10 @@ static void Main(string[] args) } } - static bool StreamReceived(long contentLength, Stream stream) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task StreamReceived(long contentLength, Stream stream) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { try { @@ -156,9 +160,9 @@ static bool StreamReceived(long contentLength, Stream stream) long bytesRemaining = contentLength; if (stream != null && stream.CanRead) - { + { while (bytesRemaining > 0) - { + { bytesRead = stream.Read(buffer, 0, buffer.Length); Console.WriteLine("Read " + bytesRead); @@ -178,17 +182,14 @@ static bool StreamReceived(long contentLength, Stream stream) { Console.WriteLine("[null]"); } - - return true; - } + } catch (Exception e) { - LogException(e); - return false; + Common.LogException("StreamReceived", e); } } - static void InitializeClient() + private static void InitializeClient() { if (!useSsl) { @@ -214,10 +215,10 @@ static void InitializeClient() client.StreamReceived = StreamReceived; client.ReadDataStream = false; // client.Debug = true; - client.Start(); + client.Start(); } - static string AuthenticationRequested() + private static string AuthenticationRequested() { Console.WriteLine(""); Console.WriteLine(""); @@ -227,40 +228,36 @@ static string AuthenticationRequested() return presharedKey; } - static bool AuthenticationSucceeded() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task AuthenticationSucceeded() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Authentication succeeded"); - return true; } - static bool AuthenticationFailure() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task AuthenticationFailure() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Authentication failed"); - return true; } - static bool ServerConnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerConnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Server connected"); - return true; } - static bool ServerDisconnected() - { - Console.WriteLine("Server disconnected"); - return true; - } +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - static void LogException(Exception e) + private static async Task ServerDisconnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - Console.WriteLine("================================================================================"); - Console.WriteLine(" = Exception Type: " + e.GetType().ToString()); - Console.WriteLine(" = Exception Data: " + e.Data); - Console.WriteLine(" = Inner Exception: " + e.InnerException); - Console.WriteLine(" = Exception Message: " + e.Message); - Console.WriteLine(" = Exception Source: " + e.Source); - Console.WriteLine(" = Exception StackTrace: " + e.StackTrace); - Console.WriteLine("================================================================================"); - } + Console.WriteLine("Server disconnected"); + } } -} +} \ No newline at end of file diff --git a/TestDebug.bat b/TestDebug.bat index 82f03e8..d2b3a4f 100644 --- a/TestDebug.bat +++ b/TestDebug.bat @@ -1,17 +1,17 @@ @echo off IF [%1] == [] GOTO Usage -cd TestServer\bin\debug +cd TestServer\bin\debug\net452 start TestServer.exe TIMEOUT 3 > NUL -cd ..\..\.. +cd ..\..\..\.. -cd TestClient\bin\debug +cd TestClient\bin\debug\net452 FOR /L %%i IN (1,1,%1) DO ( ECHO Starting client %%i start TestClient.exe TIMEOUT 1 > NUL ) -cd ..\..\.. +cd ..\..\..\.. @echo on EXIT /b diff --git a/TestMultiClient/ConcurrentList.cs b/TestMultiClient/ConcurrentList.cs index 3ef7ff9..e30bc5f 100644 --- a/TestMultiClient/ConcurrentList.cs +++ b/TestMultiClient/ConcurrentList.cs @@ -7,8 +7,8 @@ namespace ConcurrentList { public sealed class ConcurrentList : ThreadSafeList { - static readonly int[] Sizes; - static readonly int[] Counts; + private static readonly int[] Sizes; + private static readonly int[] Counts; static ConcurrentList() { @@ -30,10 +30,10 @@ static ConcurrentList() } } - int _index; - int _fuzzyCount; - int _count; - T[][] _array; + private int _index; + private int _fuzzyCount; + private int _count; + private T[][] _array; public ConcurrentList() { @@ -46,7 +46,7 @@ public override T this[int index] { if (index < 0 || index >= _count) { - throw new ArgumentOutOfRangeException("index"); + throw new ArgumentOutOfRangeException(nameof(index)); } int arrayIndex = GetArrayIndex(index + 1); @@ -98,7 +98,7 @@ public override void CopyTo(T[] array, int index) { if (array == null) { - throw new ArgumentNullException("array"); + throw new ArgumentNullException(nameof(array)); } int count = _count; @@ -165,7 +165,7 @@ protected override bool IsSynchronizedBase get { return false; } } - #endregion + #endregion "Protected methods" } public abstract class ThreadSafeList : IList, IList @@ -226,7 +226,7 @@ protected virtual int AddBase(object value) return Count - 1; } - #endregion + #endregion "Protected methods" #region "Explicit interface implementations" @@ -332,6 +332,6 @@ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() return GetEnumerator(); } - #endregion + #endregion "Explicit interface implementations" } -} +} \ No newline at end of file diff --git a/TestMultiClient/Program.cs b/TestMultiClient/Program.cs index 991540a..02fb360 100644 --- a/TestMultiClient/Program.cs +++ b/TestMultiClient/Program.cs @@ -1,29 +1,26 @@ -using System; -using System.Collections.Generic; -using System.Linq; +using ConcurrentList; +using System; using System.Security.Cryptography; -using System.Text; using System.Threading; using System.Threading.Tasks; using WatsonTcp; -using ConcurrentList; namespace TestMultiClient { - class Program + internal class Program { - static int serverPort = 9000; - static WatsonTcpServer server = null; - static int clientThreads = 16; - static int numIterations = 1000; - static int connectionCount = 0; - static ConcurrentList connections = new ConcurrentList(); - static bool clientsStarted = false; - - static Random rng; - static byte[] data; - - static void Main(string[] args) + private static int serverPort = 9000; + private static WatsonTcpServer server = null; + private static int clientThreads = 16; + private static int numIterations = 1000; + private static int connectionCount = 0; + private static ConcurrentList connections = new ConcurrentList(); + private static bool clientsStarted = false; + + private static Random rng; + private static byte[] data; + + private static void Main(string[] args) { rng = new Random((int)DateTime.Now.Ticks); data = InitByteArray(65536, 0x00); @@ -44,12 +41,12 @@ static void Main(string[] args) Console.WriteLine("Starting client " + i); Task.Run(() => ClientTask()); } - + Console.WriteLine("Press ENTER to exit"); Console.ReadLine(); } - static void ClientTask() + private static void ClientTask() { Console.WriteLine("ClientTask entering"); using (WatsonTcpClient client = new WatsonTcpClient("localhost", serverPort)) @@ -74,7 +71,10 @@ static void ClientTask() Console.WriteLine("[client] finished"); } - static bool ServerClientConnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerClientConnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { connectionCount++; Console.WriteLine("[server] connection from " + ipPort + " (now " + connectionCount + ")"); @@ -85,36 +85,43 @@ static bool ServerClientConnected(string ipPort) } connections.Add(ipPort); - return true; } - static bool ServerClientDisconnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerClientDisconnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { connectionCount--; Console.WriteLine("[server] disconnection from " + ipPort + " (now " + connectionCount + ")"); - return true; } - static bool ServerMsgReceived(string ipPort, byte[] data) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerMsgReceived(string ipPort, byte[] data) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - // Console.WriteLine("[server] msg from " + ipPort + ": " + BytesToHex(Md5(data)) + " (" + data.Length + " bytes)"); - return true; } - static bool ClientServerConnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientServerConnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - return true; } - static bool ClientServerDisconnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientServerDisconnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - return true; } - static bool ClientMsgReceived(byte[] data) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientMsgReceived(byte[] data) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - // Console.WriteLine("[server] msg from server: " + BytesToHex(Md5(data)) + " (" + data.Length + " bytes)"); - return true; } public static byte[] InitByteArray(int count, byte val) @@ -127,15 +134,17 @@ public static byte[] InitByteArray(int count, byte val) return ret; } - static byte[] Md5(byte[] data) + private static byte[] Md5(byte[] data) { if (data == null || data.Length < 1) { return null; } - MD5 m = MD5.Create(); - return m.ComputeHash(data); + using (MD5 m = MD5.Create()) + { + return m.ComputeHash(data); + } } public static string BytesToHex(byte[] bytes) @@ -153,4 +162,4 @@ public static string BytesToHex(byte[] bytes) return BitConverter.ToString(bytes).Replace("-", ""); } } -} +} \ No newline at end of file diff --git a/TestMultiThread/Program.cs b/TestMultiThread/Program.cs index 9ad4457..795bdb6 100644 --- a/TestMultiThread/Program.cs +++ b/TestMultiThread/Program.cs @@ -6,18 +6,18 @@ namespace TestMultiThread { - class Program + internal class Program { - static int serverPort = 8000; - static int clientThreads = 128; - static int numIterations = 10000; - static Random rng; - static byte[] data; + private static int serverPort = 8000; + private static int clientThreads = 128; + private static int numIterations = 10000; + private static Random rng; + private static byte[] data; - static WatsonTcpServer server; - static WatsonTcpClient c; + private static WatsonTcpServer server; + private static WatsonTcpClient c; - static void Main(string[] args) + private static void Main(string[] args) { rng = new Random((int)DateTime.Now.Ticks); data = InitByteArray(262144, 0x00); @@ -36,7 +36,7 @@ static void Main(string[] args) c.ServerConnected = ClientServerConnected; c.ServerDisconnected = ClientServerDisconnected; c.MessageReceived = ClientMsgReceived; - c.Start(); + c.Start(); Console.WriteLine("Press ENTER to exit"); @@ -48,7 +48,7 @@ static void Main(string[] args) Console.ReadLine(); } - static void ClientTask() + private static void ClientTask() { for (int i = 0; i < numIterations; i++) { @@ -59,38 +59,50 @@ static void ClientTask() Console.WriteLine("[client] finished"); } - static bool ServerClientConnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerClientConnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("[server] connection from " + ipPort); - return true; } - static bool ServerClientDisconnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerClientDisconnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("[server] disconnection from " + ipPort); - return true; } - static bool ServerMsgReceived(string ipPort, byte[] data) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerMsgReceived(string ipPort, byte[] data) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("[server] msg from " + ipPort + ": " + BytesToHex(Md5(data)) + " (" + data.Length + " bytes)"); - return true; } - static bool ClientServerConnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientServerConnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - return true; } - static bool ClientServerDisconnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientServerDisconnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - return true; } - static bool ClientMsgReceived(byte[] data) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientMsgReceived(byte[] data) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("[server] msg from server: " + BytesToHex(Md5(data)) + " (" + data.Length + " bytes)"); - return true; } public static byte[] InitByteArray(int count, byte val) @@ -103,15 +115,17 @@ public static byte[] InitByteArray(int count, byte val) return ret; } - static byte[] Md5(byte[] data) + private static byte[] Md5(byte[] data) { if (data == null || data.Length < 1) { return null; } - MD5 m = MD5.Create(); - return m.ComputeHash(data); + using (MD5 m = MD5.Create()) + { + return m.ComputeHash(data); + } } public static string BytesToHex(byte[] bytes) @@ -129,4 +143,4 @@ public static string BytesToHex(byte[] bytes) return BitConverter.ToString(bytes).Replace("-", ""); } } -} +} \ No newline at end of file diff --git a/TestParallel/Program.cs b/TestParallel/Program.cs index f439289..848b590 100644 --- a/TestParallel/Program.cs +++ b/TestParallel/Program.cs @@ -6,16 +6,16 @@ namespace TestParallel { - class Program + internal class Program { - static int serverPort = 8000; - static int clientThreads = 8; - static int numIterations = 10000; - static Random rng; - static byte[] data; - static WatsonTcpServer server; - - static void Main(string[] args) + private static int serverPort = 8000; + private static int clientThreads = 8; + private static int numIterations = 10000; + private static Random rng; + private static byte[] data; + private static WatsonTcpServer server; + + private static void Main(string[] args) { rng = new Random((int)DateTime.Now.Ticks); data = InitByteArray(262144, 0x00); @@ -40,7 +40,7 @@ static void Main(string[] args) Console.ReadLine(); } - static void ClientTask() + private static void ClientTask() { using (WatsonTcpClient client = new WatsonTcpClient("localhost", serverPort)) { @@ -59,38 +59,50 @@ static void ClientTask() Console.WriteLine("[client] finished"); } - static bool ServerClientConnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerClientConnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("[server] connection from " + ipPort); - return true; } - static bool ServerClientDisconnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerClientDisconnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("[server] disconnection from " + ipPort); - return true; } - static bool ServerMsgReceived(string ipPort, byte[] data) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ServerMsgReceived(string ipPort, byte[] data) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("[server] msg from " + ipPort + ": " + BytesToHex(Md5(data)) + " (" + data.Length + " bytes)"); - return true; } - static bool ClientServerConnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientServerConnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - return true; } - static bool ClientServerDisconnected() +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientServerDisconnected() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { - return true; } - static bool ClientMsgReceived(byte[] data) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientMsgReceived(byte[] data) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("[server] msg from server: " + BytesToHex(Md5(data)) + " (" + data.Length + " bytes)"); - return true; } public static byte[] InitByteArray(int count, byte val) @@ -103,7 +115,7 @@ public static byte[] InitByteArray(int count, byte val) return ret; } - static byte[] Md5(byte[] data) + private static byte[] Md5(byte[] data) { if (data == null || data.Length < 1) { @@ -129,4 +141,4 @@ public static string BytesToHex(byte[] bytes) return BitConverter.ToString(bytes).Replace("-", ""); } } -} +} \ No newline at end of file diff --git a/TestServer/Program.cs b/TestServer/Program.cs index f223628..7e1fbf6 100644 --- a/TestServer/Program.cs +++ b/TestServer/Program.cs @@ -1,22 +1,23 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading.Tasks; using WatsonTcp; namespace TestServer { - class TestServer + internal class TestServer { - static string serverIp = ""; - static int serverPort = 0; - static bool useSsl = false; - static WatsonTcpServer server = null; - static string certFile = ""; - static string certPass = ""; - static bool acceptInvalidCerts = true; - static bool mutualAuthentication = true; - - static void Main(string[] args) + private static string serverIp = ""; + private static int serverPort = 0; + private static bool useSsl = false; + private static WatsonTcpServer server = null; + private static string certFile = ""; + private static string certPass = ""; + private static bool acceptInvalidCerts = true; + private static bool mutualAuthentication = true; + + private static void Main(string[] args) { serverIp = Common.InputString("Server IP:", "127.0.0.1", false); serverPort = Common.InputInteger("Server port:", 9000, true, false); @@ -24,14 +25,14 @@ static void Main(string[] args) if (!useSsl) { - server = new WatsonTcpServer(serverIp, serverPort); + server = new WatsonTcpServer(serverIp, serverPort); } else - { + { certFile = Common.InputString("Certificate file:", "test.pfx", false); - certPass = Common.InputString("Certificate password:", "password", false); + certPass = Common.InputString("Certificate password:", "password", false); acceptInvalidCerts = Common.InputBoolean("Accept Invalid Certs:", true); - mutualAuthentication = Common.InputBoolean("Mutually authenticate:", true); + mutualAuthentication = Common.InputBoolean("Mutually authenticate:", false); server = new WatsonTcpServer(serverIp, serverPort, certFile, certPass); server.AcceptInvalidCertificates = acceptInvalidCerts; @@ -101,7 +102,7 @@ static void Main(string[] args) if (String.IsNullOrEmpty(ipPort)) break; Console.Write("Data: "); userInput = Console.ReadLine(); - if (String.IsNullOrEmpty(userInput)) break; + if (String.IsNullOrEmpty(userInput)) break; success = server.Send(ipPort, Encoding.UTF8.GetBytes(userInput)); Console.WriteLine(success); break; @@ -135,22 +136,29 @@ static void Main(string[] args) default: break; } - } + } } - static bool ClientConnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientConnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Client connected: " + ipPort); - return true; } - static bool ClientDisconnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientDisconnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Client disconnected: " + ipPort); - return true; } - static bool MessageReceived(string ipPort, byte[] data) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task MessageReceived(string ipPort, byte[] data) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { string msg = ""; if (data != null && data.Length > 0) @@ -159,7 +167,6 @@ static bool MessageReceived(string ipPort, byte[] data) } Console.WriteLine("Message received from " + ipPort + ": " + msg); - return true; } } -} +} \ No newline at end of file diff --git a/TestServerStream/Program.cs b/TestServerStream/Program.cs index 47409e0..b241faa 100644 --- a/TestServerStream/Program.cs +++ b/TestServerStream/Program.cs @@ -2,22 +2,23 @@ using System.Collections.Generic; using System.IO; using System.Text; +using System.Threading.Tasks; using WatsonTcp; namespace TestServerStream { - class TestServerStream + internal class TestServerStream { - static string serverIp = ""; - static int serverPort = 0; - static bool useSsl = false; - static WatsonTcpServer server = null; - static string certFile = ""; - static string certPass = ""; - static bool acceptInvalidCerts = true; - static bool mutualAuthentication = true; - - static void Main(string[] args) + private static string serverIp = ""; + private static int serverPort = 0; + private static bool useSsl = false; + private static WatsonTcpServer server = null; + private static string certFile = ""; + private static string certPass = ""; + private static bool acceptInvalidCerts = true; + private static bool mutualAuthentication = true; + + private static void Main(string[] args) { serverIp = Common.InputString("Server IP:", "127.0.0.1", false); serverPort = Common.InputInteger("Server port:", 9000, true, false); @@ -42,7 +43,7 @@ static void Main(string[] args) server.ClientConnected = ClientConnected; server.ClientDisconnected = ClientDisconnected; server.StreamReceived = StreamReceived; - server.ReadDataStream = false; + server.ReadDataStream = false; // server.Debug = true; server.Start(); @@ -110,7 +111,7 @@ static void Main(string[] args) data = Encoding.UTF8.GetBytes(userInput); ms = new MemoryStream(data); success = server.Send(ipPort, data.Length, ms); - Console.WriteLine(success); + Console.WriteLine(success); break; case "sendasync": @@ -147,19 +148,26 @@ static void Main(string[] args) } } - static bool ClientConnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientConnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Client connected: " + ipPort); - return true; } - static bool ClientDisconnected(string ipPort) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task ClientDisconnected(string ipPort) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { Console.WriteLine("Client disconnected: " + ipPort); - return true; } - static bool StreamReceived(string ipPort, long contentLength, Stream stream) +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + + private static async Task StreamReceived(string ipPort, long contentLength, Stream stream) +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { try { @@ -171,9 +179,9 @@ static bool StreamReceived(string ipPort, long contentLength, Stream stream) long bytesRemaining = contentLength; if (stream != null && stream.CanRead) - { + { while (bytesRemaining > 0) - { + { bytesRead = stream.Read(buffer, 0, buffer.Length); Console.WriteLine("Read " + bytesRead); @@ -193,26 +201,11 @@ static bool StreamReceived(string ipPort, long contentLength, Stream stream) { Console.WriteLine("[null]"); } - - return true; } catch (Exception e) { - LogException(e); - return false; + Common.LogException("StreamReceived", e); } - } - - static void LogException(Exception e) - { - Console.WriteLine("================================================================================"); - Console.WriteLine(" = Exception Type: " + e.GetType().ToString()); - Console.WriteLine(" = Exception Data: " + e.Data); - Console.WriteLine(" = Inner Exception: " + e.InnerException); - Console.WriteLine(" = Exception Message: " + e.Message); - Console.WriteLine(" = Exception Source: " + e.Source); - Console.WriteLine(" = Exception StackTrace: " + e.StackTrace); - Console.WriteLine("================================================================================"); } } -} +} \ No newline at end of file diff --git a/WatsonTcp/ClientMetadata.cs b/WatsonTcp/ClientMetadata.cs index ec93a9f..4ca2916 100644 --- a/WatsonTcp/ClientMetadata.cs +++ b/WatsonTcp/ClientMetadata.cs @@ -1,9 +1,7 @@ using System; -using System.Net; using System.Net.Security; using System.Net.Sockets; using System.Threading; -using System.Threading.Tasks; namespace WatsonTcp { @@ -36,7 +34,7 @@ public string IpPort public SemaphoreSlim WriteLock { get; set; } - #endregion + #endregion Public-Members #region Private-Members @@ -47,21 +45,21 @@ public string IpPort private SslStream _SslStream; private string _IpPort; - #endregion + #endregion Private-Members #region Constructors-and-Factories public ClientMetadata(TcpClient tcp) { - _TcpClient = tcp ?? throw new ArgumentNullException(nameof(tcp)); - _NetworkStream = tcp.GetStream(); + _TcpClient = tcp ?? throw new ArgumentNullException(nameof(tcp)); + _NetworkStream = tcp.GetStream(); _IpPort = tcp.Client.RemoteEndPoint.ToString(); ReadLock = new SemaphoreSlim(1); WriteLock = new SemaphoreSlim(1); } - #endregion + #endregion Constructors-and-Factories #region Public-Methods @@ -71,7 +69,7 @@ public void Dispose() GC.SuppressFinalize(this); } - #endregion + #endregion Public-Methods #region Private-Methods @@ -106,6 +104,6 @@ protected virtual void Dispose(bool disposing) _Disposed = true; } - #endregion + #endregion Private-Methods } -} +} \ No newline at end of file diff --git a/WatsonTcp/Common.cs b/WatsonTcp/Common.cs index be17de0..4c29162 100644 --- a/WatsonTcp/Common.cs +++ b/WatsonTcp/Common.cs @@ -1,19 +1,12 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections; -using System.Collections.Generic; using System.Text; -using Newtonsoft.Json; - namespace WatsonTcp { public static class Common - { - /// - /// Serialize an object to JSON. - /// - /// Object to serialize. - /// JSON string. + { public static string SerializeJson(object obj) { if (obj == null) return null; @@ -28,13 +21,7 @@ public static string SerializeJson(object obj) return json; } - - /// - /// Deserialize JSON string to an object using Newtonsoft JSON.NET. - /// - /// The type of object. - /// JSON string. - /// An object of the specified type. + public static T DeserializeJson(string json) { if (String.IsNullOrEmpty(json)) throw new ArgumentNullException(nameof(json)); @@ -43,22 +30,16 @@ public static T DeserializeJson(string json) { return JsonConvert.DeserializeObject(json); } - catch (Exception e) + catch (Exception) { Console.WriteLine(""); Console.WriteLine("Exception while deserializing:"); Console.WriteLine(json); Console.WriteLine(""); - throw e; + throw; } } - - /// - /// Deserialize JSON string to an object using Newtonsoft JSON.NET. - /// - /// The type of object. - /// Byte array containing the JSON string. - /// An object of the specified type. + public static T DeserializeJson(byte[] data) { if (data == null || data.Length < 1) throw new ArgumentNullException(nameof(data)); @@ -175,91 +156,24 @@ public static int InputInteger(string question, int defaultAnswer, bool positive return ret; } } - - public static void InitByteArray(byte[] data) - { - if (data == null || data.Length < 1) throw new ArgumentNullException(nameof(data)); - for (int i = 0; i < data.Length; i++) - { - data[i] = 0x00; - } - } - - public static void InitBitArray(BitArray data) - { - if (data == null || data.Length < 1) throw new ArgumentNullException(nameof(data)); - for (int i = 0; i < data.Length; i++) - { - data[i] = false; - } - } - - public static byte[] AppendBytes(byte[] head, byte[] tail) - { - byte[] arrayCombined = new byte[head.Length + tail.Length]; - Array.Copy(head, 0, arrayCombined, 0, head.Length); - Array.Copy(tail, 0, arrayCombined, head.Length, tail.Length); - return arrayCombined; - } - - public static string ByteArrayToHex(byte[] data) - { - StringBuilder hex = new StringBuilder(data.Length * 2); - foreach (byte b in data) hex.AppendFormat("{0:x2}", b); - return hex.ToString(); - } - - public static void ReverseBitArray(BitArray array) - { - int length = array.Length; - int mid = (length / 2); - - for (int i = 0; i < mid; i++) - { - bool bit = array[i]; - array[i] = array[length - i - 1]; - array[length - i - 1] = bit; - } - } - - public static byte[] ReverseByteArray(byte[] bytes) - { - if (bytes == null || bytes.Length < 1) throw new ArgumentNullException(nameof(bytes)); - - byte[] ret = new byte[bytes.Length]; - for (int i = 0; i < bytes.Length; i++) - { - ret[i] = ReverseByte(bytes[i]); - } - - return ret; - } - + public static byte ReverseByte(byte b) { return (byte)(((b * 0x0802u & 0x22110u) | (b * 0x8020u & 0x88440u)) * 0x10101u >> 16); } - - public static byte[] BitArrayToBytes(BitArray bits) - { - if (bits == null || bits.Length < 1) throw new ArgumentNullException(nameof(bits)); - if (bits.Length % 8 != 0) throw new ArgumentException("BitArray length must be divisible by 8."); - - byte[] ret = new byte[(bits.Length - 1) / 8 + 1]; - bits.CopyTo(ret, 0); - return ret; - } - - public static void LogException(Exception e) + + public static void LogException(string method, Exception e) { - Console.WriteLine("================================================================================"); - Console.WriteLine(" = Exception Type: " + e.GetType().ToString()); - Console.WriteLine(" = Exception Data: " + e.Data); - Console.WriteLine(" = Inner Exception: " + e.InnerException); - Console.WriteLine(" = Exception Message: " + e.Message); - Console.WriteLine(" = Exception Source: " + e.Source); - Console.WriteLine(" = Exception StackTrace: " + e.StackTrace); - Console.WriteLine("================================================================================"); + Console.WriteLine(""); + Console.WriteLine("An exception was encountered."); + Console.WriteLine(" Method : " + method); + Console.WriteLine(" Type : " + e.GetType().ToString()); + Console.WriteLine(" Data : " + e.Data); + Console.WriteLine(" Inner : " + e.InnerException); + Console.WriteLine(" Message : " + e.Message); + Console.WriteLine(" Source : " + e.Source); + Console.WriteLine(" StackTrace : " + e.StackTrace); + Console.WriteLine(""); } } -} +} \ No newline at end of file diff --git a/WatsonTcp/Message/FieldType.cs b/WatsonTcp/Message/FieldType.cs index f6f06ea..b6ef866 100644 --- a/WatsonTcp/Message/FieldType.cs +++ b/WatsonTcp/Message/FieldType.cs @@ -1,11 +1,7 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace WatsonTcp.Message +namespace WatsonTcp.Message { public enum FieldType - { + { Int32, Int64, Bits, @@ -13,4 +9,4 @@ public enum FieldType DateTime, String } -} +} \ No newline at end of file diff --git a/WatsonTcp/Message/MessageField.cs b/WatsonTcp/Message/MessageField.cs index 0c21cec..a658d6e 100644 --- a/WatsonTcp/Message/MessageField.cs +++ b/WatsonTcp/Message/MessageField.cs @@ -1,7 +1,4 @@ using System; -using System.Collections; -using System.Collections.Generic; -using System.Text; namespace WatsonTcp.Message { @@ -14,11 +11,9 @@ public class MessageField public FieldType Type { get; set; } public int Length { get; set; } - #endregion + #endregion Public-Members - #region Private-Members - #endregion #region Constructors-and-Factories @@ -39,14 +34,6 @@ public MessageField(int bitNumber, string name, FieldType fieldType, int length) Length = length; } - #endregion - - #region Public-Methods - - #endregion - - #region Private-Methods - - #endregion - } -} + #endregion Constructors-and-Factories + } +} \ No newline at end of file diff --git a/WatsonTcp/Message/MessageStatus.cs b/WatsonTcp/Message/MessageStatus.cs index 8bdafa4..b8e6a23 100644 --- a/WatsonTcp/Message/MessageStatus.cs +++ b/WatsonTcp/Message/MessageStatus.cs @@ -1,8 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace WatsonTcp.Message +namespace WatsonTcp.Message { public enum MessageStatus { @@ -12,6 +8,7 @@ public enum MessageStatus AuthRequired, AuthRequested, AuthSuccess, - AuthFailure + AuthFailure, + Removed } -} +} \ No newline at end of file diff --git a/WatsonTcp/Message/WatsonMessage.cs b/WatsonTcp/Message/WatsonMessage.cs index 341eca6..32c3aaf 100644 --- a/WatsonTcp/Message/WatsonMessage.cs +++ b/WatsonTcp/Message/WatsonMessage.cs @@ -5,11 +5,9 @@ using System.Globalization; using System.IO; using System.Linq; -using System.Net; using System.Net.Security; using System.Net.Sockets; using System.Text; -using System.Threading; using System.Threading.Tasks; namespace WatsonTcp.Message @@ -36,7 +34,7 @@ internal class WatsonMessage /// /// Preshared key for connection authentication. HeaderFields[0], 16 bytes. /// - internal byte[] PresharedKey + internal byte[] PresharedKey { get { @@ -44,12 +42,13 @@ internal byte[] PresharedKey } set { + if (value == null) throw new ArgumentNullException(nameof(PresharedKey)); if (value != null && value.Length != 16) throw new ArgumentException("PresharedKey must be 16 bytes."); _PresharedKey = new byte[16]; Buffer.BlockCopy(value, 0, _PresharedKey, 0, 16); HeaderFields[0] = true; } - } + } /// /// Status of the message. HeaderFields[1], 4 bytes. @@ -93,7 +92,7 @@ internal int ReadStreamBuffer } } - #endregion + #endregion Public-Members #region Private-Members @@ -108,7 +107,7 @@ internal int ReadStreamBuffer private byte[] _PresharedKey; private MessageStatus _Status; - #endregion + #endregion Private-Members #region Constructors-and-Factories @@ -139,7 +138,7 @@ internal WatsonMessage(byte[] data, bool debug) ContentLength = data.Length; Data = new byte[data.Length]; Buffer.BlockCopy(data, 0, Data, 0, data.Length); - DataStream = null; + DataStream = null; _Debug = debug; } @@ -160,7 +159,7 @@ internal WatsonMessage(long contentLength, Stream stream, bool debug) throw new ArgumentException("Cannot read from supplied stream."); } } - + HeaderFields = new BitArray(64); InitBitArray(HeaderFields); @@ -168,8 +167,8 @@ internal WatsonMessage(long contentLength, Stream stream, bool debug) ContentLength = contentLength; Data = null; - DataStream = stream; - + DataStream = stream; + _Debug = debug; } @@ -209,7 +208,7 @@ internal WatsonMessage(SslStream stream, bool debug) _Debug = debug; } - #endregion + #endregion Constructors-and-Factories #region Public-Methods @@ -220,9 +219,9 @@ internal WatsonMessage(SslStream stream, bool debug) internal async Task Build() { try - { + { #region Read-Message-Length - + using (MemoryStream msgLengthMs = new MemoryStream()) { while (true) @@ -230,19 +229,19 @@ internal async Task Build() byte[] data = await ReadFromNetwork(1, "MessageLength"); await msgLengthMs.WriteAsync(data, 0, 1); if (data[0] == 58) break; - } + } byte[] msgLengthBytes = msgLengthMs.ToArray(); if (msgLengthBytes == null || msgLengthBytes.Length < 1) return false; - string msgLengthString = Encoding.UTF8.GetString(msgLengthBytes).Replace(":", ""); + string msgLengthString = Encoding.UTF8.GetString(msgLengthBytes).Replace(":", ""); long length; Int64.TryParse(msgLengthString, out length); Length = length; - if (_Debug) Console.WriteLine("Message payload length: " + Length + " bytes"); + Log("Message payload length: " + Length + " bytes"); } - #endregion + #endregion Read-Message-Length #region Process-Header-Fields @@ -256,7 +255,7 @@ internal async Task Build() { if (HeaderFields[i]) { - MessageField field = GetMessageField(i); + MessageField field = GetMessageField(i); object val = await ReadField(field.Type, field.Length, field.Name); SetMessageValue(field, val); payloadLength -= field.Length; @@ -267,39 +266,32 @@ internal async Task Build() DataStream = null; Data = await ReadFromNetwork(ContentLength, "Payload"); - #endregion + #endregion Process-Header-Fields return true; } catch (Exception e) { - if (_Debug) - { - Console.WriteLine(Common.SerializeJson(e)); - } - + Log(Common.SerializeJson(e)); throw; } finally { - if (_Debug) - { - Console.WriteLine("Message build completed:"); - Console.WriteLine(this.ToString()); - } + Log("Message build completed:"); + Log(this.ToString()); } } /// - /// Awaitable async method to build the Message object from data that awaits in a NetworkStream or SslStream, returning the stream itself. + /// Awaitable async method to build the Message object from data that awaits in a NetworkStream or SslStream, returning the stream itself. /// /// Always returns true (void cannot be a return parameter). internal async Task BuildStream() { try - { + { #region Read-Message-Length - + using (MemoryStream msgLengthMs = new MemoryStream()) { while (true) @@ -317,10 +309,10 @@ internal async Task BuildStream() Int64.TryParse(msgLengthString, out length); Length = length; - if (_Debug) Console.WriteLine("Message payload length: " + Length + " bytes"); + Log("Message payload length: " + Length + " bytes"); } - #endregion + #endregion Read-Message-Length #region Process-Header-Fields @@ -335,7 +327,7 @@ internal async Task BuildStream() if (HeaderFields[i]) { MessageField field = GetMessageField(i); - if (_Debug) Console.WriteLine("Reading header field " + i + " " + field.Name + " " + field.Type.ToString() + " " + field.Length + " bytes"); + Log("Reading header field " + i + " " + field.Name + " " + field.Type.ToString() + " " + field.Length + " bytes"); object val = await ReadField(field.Type, field.Length, field.Name); SetMessageValue(field, val); payloadLength -= field.Length; @@ -344,7 +336,7 @@ internal async Task BuildStream() ContentLength = payloadLength; Data = null; - + if (_NetworkStream != null) { DataStream = _NetworkStream; @@ -358,25 +350,19 @@ internal async Task BuildStream() throw new IOException("No suitable input stream found."); } - #endregion + #endregion Process-Header-Fields return true; } catch (Exception e) { - if (_Debug) - { - Console.WriteLine(Common.SerializeJson(e)); - } + Log(Common.SerializeJson(e)); throw e; } finally { - if (_Debug) - { - Console.WriteLine("Message build completed:"); - Console.WriteLine(this.ToString()); - } + Log("Message build completed:"); + Log(this.ToString()); } } @@ -394,46 +380,48 @@ internal byte[] ToHeaderBytes(long contentLength) byte[] ret = new byte[headerFieldsBytes.Length]; Buffer.BlockCopy(headerFieldsBytes, 0, ret, 0, headerFieldsBytes.Length); - + #region Header-Fields - + for (int i = 0; i < HeaderFields.Length; i++) - { + { if (HeaderFields[i]) { - if (_Debug) Console.WriteLine("Header field " + i + " is set"); + Log("Header field " + i + " is set"); MessageField field = GetMessageField(i); switch (i) { case 0: // preshared key - if (_Debug) Console.WriteLine("PresharedKey: " + Encoding.UTF8.GetString(PresharedKey)); + Log("PresharedKey: " + Encoding.UTF8.GetString(PresharedKey)); ret = AppendBytes(ret, PresharedKey); break; + case 1: // status - if (_Debug) Console.WriteLine("Status: " + Status.ToString() + " " + (int)Status); + Log("Status: " + Status.ToString() + " " + (int)Status); ret = AppendBytes(ret, IntegerToBytes((int)Status)); break; + default: throw new ArgumentException("Unknown bit number."); } } } - #endregion + #endregion Header-Fields #region Prepend-Message-Length long finalLen = ret.Length + contentLength; - if (_Debug) Console.WriteLine("Content length: " + finalLen + " (" + ret.Length + " + " + contentLength + ")"); + Log("Content length: " + finalLen + " (" + ret.Length + " + " + contentLength + ")"); byte[] lengthHeader = Encoding.UTF8.GetBytes(finalLen.ToString() + ":"); byte[] final = new byte[(lengthHeader.Length + ret.Length)]; Buffer.BlockCopy(lengthHeader, 0, final, 0, lengthHeader.Length); Buffer.BlockCopy(ret, 0, final, lengthHeader.Length, ret.Length); - #endregion + #endregion Prepend-Message-Length - if (_Debug) Console.WriteLine("ToHeaderBytes returning: " + Encoding.UTF8.GetString(final)); + Log("ToHeaderBytes returning: " + Encoding.UTF8.GetString(final)); return final; } @@ -463,7 +451,7 @@ public override string ToString() return ret; } - #endregion + #endregion Public-Methods #region Private-Methods @@ -481,7 +469,7 @@ private async Task ReadField(FieldType fieldType, int maxLength, string string logMessage = "ReadField " + fieldType.ToString() + " " + maxLength + " " + name; try - { + { byte[] data = null; int headerLength = 0; @@ -505,16 +493,16 @@ private async Task ReadField(FieldType fieldType, int maxLength, string { data = await ReadFromNetwork(maxLength, name + " String (" + maxLength + ")"); logMessage += " " + ByteArrayToHex(data); - ret = Encoding.UTF8.GetString(data); + ret = Encoding.UTF8.GetString(data); logMessage += ": " + headerLength + " " + ret; - } + } else if (fieldType == FieldType.DateTime) { - data = await ReadFromNetwork(22, name + " DateTime"); + data = await ReadFromNetwork(_DateTimeFormat.Length, name + " DateTime"); logMessage += " " + ByteArrayToHex(data); ret = DateTime.ParseExact(Encoding.UTF8.GetString(data), _DateTimeFormat, CultureInfo.InvariantCulture); logMessage += ": " + headerLength + " " + ret.ToString(); - } + } else if (fieldType == FieldType.ByteArray) { ret = await ReadFromNetwork(maxLength, name + " ByteArray (" + maxLength + ")"); @@ -551,7 +539,7 @@ private byte[] FieldToBytes(FieldType fieldType, object data, int maxLength) string lengthVar = ""; for (int i = 0; i < maxLength; i++) lengthVar += "0"; return Encoding.UTF8.GetBytes(longVar.ToString(lengthVar)); - } + } else if (fieldType == FieldType.String) { string dataStr = data.ToString().ToUpper(); @@ -583,7 +571,7 @@ private byte[] FieldToBytes(FieldType fieldType, object data, int maxLength) InitByteArray(ret); Buffer.BlockCopy((byte[])data, 0, ret, 0, maxLength); return ret; - } + } else { throw new ArgumentException("Unknown field type: " + fieldType.ToString()); @@ -605,11 +593,11 @@ private string FieldToString(FieldType fieldType, object data) else if (fieldType == FieldType.String) { return "[s]" + data.ToString(); - } + } else if (fieldType == FieldType.DateTime) { return "[d]" + Convert.ToDateTime(data).ToString(_DateTimeFormat); - } + } else if (fieldType == FieldType.ByteArray) { return "[b]" + ByteArrayToHex((byte[])data); @@ -639,7 +627,7 @@ private string FieldToString(FieldType fieldType, object data) private async Task ReadFromNetwork(long count, string field) { - if (_Debug) Console.WriteLine("ReadFromNetwork " + count + " " + field); + Log("ReadFromNetwork " + count + " " + field); string logMessage = null; try @@ -652,7 +640,7 @@ private async Task ReadFromNetwork(long count, string field) InitByteArray(buffer); if (_NetworkStream != null) - { + { while (true) { read = await _NetworkStream.ReadAsync(buffer, 0, buffer.Length); @@ -662,10 +650,10 @@ private async Task ReadFromNetwork(long count, string field) Buffer.BlockCopy(buffer, 0, ret, 0, read); break; } - } + } } else if (_SslStream != null) - { + { while (true) { read = await _SslStream.ReadAsync(buffer, 0, buffer.Length); @@ -675,8 +663,8 @@ private async Task ReadFromNetwork(long count, string field) Buffer.BlockCopy(buffer, 0, ret, 0, read); break; } - } - } + } + } else { throw new IOException("No suitable input stream found."); @@ -689,7 +677,7 @@ private async Task ReadFromNetwork(long count, string field) } finally { - if (_Debug) Console.WriteLine("- Result: " + field + " " + count + ": " + logMessage); + Log("- Result: " + field + " " + count + ": " + logMessage); } } @@ -824,11 +812,13 @@ private MessageField GetMessageField(int bitNumber) switch (bitNumber) { case 0: - if (_Debug) Console.WriteLine("Returning field PresharedKey"); + Log("Returning field PresharedKey"); return new MessageField(0, "PresharedKey", FieldType.ByteArray, 16); + case 1: - if (_Debug) Console.WriteLine("Returning field Status"); + Log("Returning field Status"); return new MessageField(1, "Status", FieldType.Int32, 4); + default: throw new KeyNotFoundException(); } @@ -843,17 +833,27 @@ private void SetMessageValue(MessageField field, object val) { case 0: PresharedKey = (byte[])val; - if (_Debug) Console.WriteLine("PresharedKey set: " + Encoding.UTF8.GetString(PresharedKey)); + Log("PresharedKey set: " + Encoding.UTF8.GetString(PresharedKey)); return; + case 1: Status = (MessageStatus)((int)val); - if (_Debug) Console.WriteLine("Status set: " + Status.ToString()); + Log("Status set: " + Status.ToString()); return; + default: throw new ArgumentException("Unknown bit number."); - } + } } - - #endregion + + private void Log(string msg) + { + if (_Debug) + { + Console.WriteLine(msg); + } + } + + #endregion Private-Methods } -} +} \ No newline at end of file diff --git a/WatsonTcp/Mode.cs b/WatsonTcp/Mode.cs index 7215463..86b01ab 100644 --- a/WatsonTcp/Mode.cs +++ b/WatsonTcp/Mode.cs @@ -1,12 +1,8 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace WatsonTcp +namespace WatsonTcp { internal enum Mode { Tcp, Ssl } -} +} \ No newline at end of file diff --git a/WatsonTcp/WatsonTcp.csproj b/WatsonTcp/WatsonTcp.csproj index 852cdf9..a02d241 100644 --- a/WatsonTcp/WatsonTcp.csproj +++ b/WatsonTcp/WatsonTcp.csproj @@ -3,7 +3,7 @@ netstandard2.0;net452 true - 1.3.12 + 2.0.0 Joel Christner Joel Christner A simple C# async TCP server and client with integrated framing for reliable transmission and receipt of data @@ -13,7 +13,7 @@ https://github.com/jchristn/WatsonTcp Github https://github.com/jchristn/WatsonTcp/blob/master/LICENSE.TXT - Reduce instances of calls to GetStream + Breaking changes, configurable connect timeout, Task-based async callbacks https://raw.githubusercontent.com/jchristn/watsontcp/master/assets/watson.ico diff --git a/WatsonTcp/WatsonTcpClient.cs b/WatsonTcp/WatsonTcpClient.cs index a5a6766..2492d0e 100644 --- a/WatsonTcp/WatsonTcpClient.cs +++ b/WatsonTcp/WatsonTcpClient.cs @@ -52,40 +52,36 @@ public int ReadStreamBufferSize public Func AuthenticationRequested = null; /// - /// Function called when authentication has succeeded. Expects a response of 'true'. + /// Function called when authentication has succeeded. /// - public Func AuthenticationSucceeded = null; + public Func AuthenticationSucceeded = null; /// - /// Function called when authentication has failed. Expects a response of 'true'. + /// Function called when authentication has failed. /// - public Func AuthenticationFailure = null; + public Func AuthenticationFailure = null; /// - /// Function called when a message is received. + /// Function called when a message is received. /// A byte array containing the message data is passed to this function. - /// It is expected that 'true' will be returned. /// - public Func MessageReceived = null; + public Func MessageReceived = null; /// /// Method to call when a message is received from a client. - /// The IP:port is passed to this method as a string, along with a long indicating the number of bytes to read from the stream. - /// It is expected that the method will return true; + /// The number of bytes (long) and the stream containing the data are passed to this function. /// - public Func StreamReceived = null; + public Func StreamReceived = null; /// /// Function called when the client successfully connects to the server. - /// It is expected that 'true' will be returned. /// - public Func ServerConnected = null; + public Func ServerConnected = null; /// /// Function called when the client disconnects from the server. - /// It is expected that 'true' will be returned. /// - public Func ServerDisconnected = null; + public Func ServerDisconnected = null; /// /// Enable acceptance of SSL certificates from the server that cannot be validated. @@ -95,38 +91,72 @@ public int ReadStreamBufferSize /// /// Require mutual authentication between the server and this client. /// - public bool MutuallyAuthenticate = false; - + public bool MutuallyAuthenticate + { + get + { + return _MutuallyAuthenticate; + } + set + { + if (value) + { + if (_Mode == Mode.Tcp) throw new ArgumentException("Mutual authentication only supported with SSL."); + if (_SslCertificate == null) throw new ArgumentException("Mutual authentication requires a certificate."); + } + + _MutuallyAuthenticate = value; + } + } + /// /// Indicates whether or not the client is connected to the server. /// public bool Connected { get; private set; } - #endregion + /// + /// The number of seconds to wait before timing out a connection attempt. Default is 5 seconds. + /// + public int ConnectTimeoutSeconds + { + get + { + return _ConnectTimeoutSeconds; + } + set + { + if (value < 1) throw new ArgumentException("ConnectTimeoutSeconds must be greater than zero."); + _ConnectTimeoutSeconds = value; + } + } + + #endregion Public-Members #region Private-Members private bool _Disposed = false; private int _ReadStreamBufferSize = 65536; - private Mode _Mode; + private int _ConnectTimeoutSeconds = 5; + private Mode _Mode; private string _SourceIp; private int _SourcePort; private string _ServerIp; - private int _ServerPort; - private TcpClient _Client; - private NetworkStream _TcpStream; - private SslStream _SslStream; + private int _ServerPort; + private bool _MutuallyAuthenticate = false; + private TcpClient _Client = null; + private NetworkStream _TcpStream = null; + private SslStream _SslStream = null; - private X509Certificate2 _SslCertificate; - private X509Certificate2Collection _SslCertificateCollection; + private X509Certificate2 _SslCertificate = null; + private X509Certificate2Collection _SslCertificateCollection = null; - private SemaphoreSlim _WriteLock; - private SemaphoreSlim _ReadLock; + private SemaphoreSlim _WriteLock = new SemaphoreSlim(1); + private SemaphoreSlim _ReadLock = new SemaphoreSlim(1); - private CancellationTokenSource _TokenSource; + private CancellationTokenSource _TokenSource = new CancellationTokenSource(); private CancellationToken _Token; - #endregion + #endregion Private-Members #region Constructors-and-Factories @@ -139,17 +169,15 @@ public WatsonTcpClient( string serverIp, int serverPort) { - if (String.IsNullOrEmpty(serverIp)) throw new ArgumentNullException(nameof(serverIp)); + if (String.IsNullOrEmpty(serverIp)) throw new ArgumentNullException(nameof(serverIp)); if (serverPort < 1) throw new ArgumentOutOfRangeException(nameof(serverPort)); + _Token = _TokenSource.Token; _Mode = Mode.Tcp; _ServerIp = serverIp; - _ServerPort = serverPort; - _WriteLock = new SemaphoreSlim(1); - _ReadLock = new SemaphoreSlim(1); - _SslStream = null; + _ServerPort = serverPort; } - + /// /// Initialize the Watson TCP client with SSL. Call Start() afterward to connect to the server. /// @@ -163,26 +191,37 @@ public WatsonTcpClient( string pfxCertFile, string pfxCertPass) { - if (String.IsNullOrEmpty(serverIp)) throw new ArgumentNullException(nameof(serverIp)); + if (String.IsNullOrEmpty(serverIp)) throw new ArgumentNullException(nameof(serverIp)); if (serverPort < 1) throw new ArgumentOutOfRangeException(nameof(serverPort)); + _Token = _TokenSource.Token; _Mode = Mode.Ssl; _ServerIp = serverIp; _ServerPort = serverPort; - _WriteLock = new SemaphoreSlim(1); - _ReadLock = new SemaphoreSlim(1); - _TcpStream = null; - _SslCertificate = null; - if (String.IsNullOrEmpty(pfxCertPass)) _SslCertificate = new X509Certificate2(pfxCertFile); - else _SslCertificate = new X509Certificate2(pfxCertFile, pfxCertPass); - _SslCertificateCollection = new X509Certificate2Collection + if (!String.IsNullOrEmpty(pfxCertFile)) { - _SslCertificate - }; + if (String.IsNullOrEmpty(pfxCertPass)) + { + _SslCertificate = new X509Certificate2(pfxCertFile); + } + else + { + _SslCertificate = new X509Certificate2(pfxCertFile, pfxCertPass); + } + + _SslCertificateCollection = new X509Certificate2Collection + { + _SslCertificate + }; + } + else + { + _SslCertificateCollection = new X509Certificate2Collection(); + } } - #endregion + #endregion Constructors-and-Factories #region Public-Methods @@ -203,6 +242,7 @@ public void Start() _Client = new TcpClient(); IAsyncResult asyncResult = null; WaitHandle waitHandle = null; + bool connectSuccess = false; if (_Mode == Mode.Tcp) { @@ -216,7 +256,8 @@ public void Start() try { - if (!asyncResult.AsyncWaitHandle.WaitOne(TimeSpan.FromSeconds(5), false)) + connectSuccess = waitHandle.WaitOne(TimeSpan.FromSeconds(_ConnectTimeoutSeconds), false); + if (!connectSuccess) { _Client.Close(); throw new TimeoutException("Timeout connecting to " + _ServerIp + ":" + _ServerPort); @@ -240,20 +281,22 @@ public void Start() waitHandle.Close(); } - #endregion + #endregion TCP } else if (_Mode == Mode.Ssl) { #region SSL Log("Watson TCP client connecting with SSL to " + _ServerIp + ":" + _ServerPort); - + + _Client.LingerState = new LingerOption(true, 0); asyncResult = _Client.BeginConnect(_ServerIp, _ServerPort, null, null); waitHandle = asyncResult.AsyncWaitHandle; try { - if (!asyncResult.AsyncWaitHandle.WaitOne(TimeSpan.FromSeconds(5), false)) + connectSuccess = waitHandle.WaitOne(TimeSpan.FromSeconds(_ConnectTimeoutSeconds), false); + if (!connectSuccess) { _Client.Close(); throw new TimeoutException("Timeout connecting to " + _ServerIp + ":" + _ServerPort); @@ -292,7 +335,6 @@ public void Start() throw new AuthenticationException("Mutual authentication failed"); } - Connected = true; } catch (Exception) @@ -302,22 +344,20 @@ public void Start() finally { waitHandle.Close(); - } + } - #endregion + #endregion SSL } else { throw new ArgumentException("Unknown mode: " + _Mode.ToString()); } - + if (ServerConnected != null) { Task.Run(() => ServerConnected()); } - _TokenSource = new CancellationTokenSource(); - _Token = _TokenSource.Token; Task.Run(async () => await DataReceiver(_Token), _Token); } @@ -326,7 +366,7 @@ public void Start() /// /// Up to 16-character string. public void Authenticate(string presharedKey) - { + { if (String.IsNullOrEmpty(presharedKey)) throw new ArgumentNullException(nameof(presharedKey)); if (presharedKey.Length != 16) throw new ArgumentException("Preshared key length must be 16 bytes."); @@ -359,7 +399,7 @@ public bool Send(long contentLength, Stream stream) { return MessageWrite(contentLength, stream); } - + /// /// Send data to the server asynchronously /// @@ -380,8 +420,8 @@ public async Task SendAsync(long contentLength, Stream stream) { return await MessageWriteAsync(contentLength, stream); } - - #endregion + + #endregion Public-Methods #region Private-Methods @@ -404,7 +444,6 @@ protected virtual void Dispose(bool disposing) } catch (Exception) { - } finally { @@ -414,17 +453,16 @@ protected virtual void Dispose(bool disposing) } if (_TcpStream != null) - { + { try { _WriteLock.Wait(1); _ReadLock.Wait(1); - if (_TcpStream != null) _TcpStream.Close(); + if (_TcpStream != null) _TcpStream.Close(); } catch (Exception) { - - } + } try { @@ -432,7 +470,6 @@ protected virtual void Dispose(bool disposing) } catch (Exception) { - } finally { @@ -442,14 +479,17 @@ protected virtual void Dispose(bool disposing) } _TokenSource.Cancel(); - _TokenSource.Dispose(); + _TokenSource.Dispose(); + + if (_WriteLock != null) _WriteLock.Dispose(); + if (_ReadLock != null) _ReadLock.Dispose(); Connected = false; } _Disposed = true; } - + private bool AcceptCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { // return true; // Allow untrusted certificates. @@ -463,21 +503,8 @@ private void Log(string msg) Console.WriteLine(msg); } } - - private void LogException(string method, Exception e) - { - Log("================================================================================"); - Log(" = Method: " + method); - Log(" = Exception Type: " + e.GetType().ToString()); - Log(" = Exception Data: " + e.Data); - Log(" = Inner Exception: " + e.InnerException); - Log(" = Exception Message: " + e.Message); - Log(" = Exception Source: " + e.Source); - Log(" = Exception StackTrace: " + e.StackTrace); - Log("================================================================================"); - } - - private async Task DataReceiver(CancellationToken? cancelToken=null) + + private async Task DataReceiver(CancellationToken? cancelToken = null) { try { @@ -500,14 +527,14 @@ private async Task DataReceiver(CancellationToken? cancelToken=null) Log("*** DataReceiver server disconnected"); break; } - + if (_SslStream != null && !_SslStream.CanRead) { Log("*** DataReceiver cannot read from SSL stream"); break; } - #endregion + #endregion Check-Connection #region Read-Message-and-Handle @@ -555,7 +582,12 @@ private async Task DataReceiver(CancellationToken? cancelToken=null) continue; } - if (msg.Status == MessageStatus.AuthSuccess) + if (msg.Status == MessageStatus.Removed) + { + Log("*** DataReceiver removed from the server"); + break; + } + else if (msg.Status == MessageStatus.AuthSuccess) { Log("DataReceiver successfully authenticated"); AuthenticationSucceeded?.Invoke(); @@ -586,7 +618,7 @@ private async Task DataReceiver(CancellationToken? cancelToken=null) { if (MessageReceived != null) { - Task unawaited = Task.Run(() => MessageReceived(msg.Data)); + Task unawaited = Task.Run(() => MessageReceived(msg.Data)); } } else @@ -594,22 +626,19 @@ private async Task DataReceiver(CancellationToken? cancelToken=null) StreamReceived?.Invoke(msg.ContentLength, msg.DataStream); } - #endregion + #endregion Read-Message-and-Handle } - #endregion + #endregion Wait-for-Data } catch (OperationCanceledException) - { - + { } catch (ObjectDisposedException) - { - + { } catch (IOException) - { - + { } catch (Exception e) { @@ -622,7 +651,7 @@ private async Task DataReceiver(CancellationToken? cancelToken=null) finally { Connected = false; - ServerDisconnected?.Invoke(); + ServerDisconnected?.Invoke(); } } @@ -633,22 +662,22 @@ private bool MessageWrite(WatsonMessage msg) if (msg.Data != null) dataLen = msg.Data.Length; try - { + { if (_Client == null) { Log("MessageWrite client is null"); disconnectDetected = true; return false; - } + } byte[] headerBytes = msg.ToHeaderBytes(dataLen); _WriteLock.Wait(1); try - { + { if (_Mode == Mode.Tcp) - { + { _TcpStream.Write(headerBytes, 0, headerBytes.Length); if (msg.Data != null && msg.Data.Length > 0) _TcpStream.Write(msg.Data, 0, msg.Data.Length); _TcpStream.Flush(); @@ -662,16 +691,16 @@ private bool MessageWrite(WatsonMessage msg) else { throw new ArgumentException("Unknown mode: " + _Mode.ToString()); - } + } } finally { _WriteLock.Release(); } - string logMessage = "MessageWrite sent " + Encoding.UTF8.GetString(headerBytes); + string logMessage = "MessageWrite sent " + Encoding.UTF8.GetString(headerBytes); Log(logMessage); - return true; + return true; } catch (ObjectDisposedException ObjDispInner) { @@ -724,7 +753,7 @@ private bool MessageWrite(byte[] data) ms.Seek(0, SeekOrigin.Begin); } - return MessageWrite(dataLen, ms); + return MessageWrite(dataLen, ms); } private bool MessageWrite(long contentLength, Stream stream) @@ -741,14 +770,14 @@ private bool MessageWrite(long contentLength, Stream stream) bool disconnectDetected = false; try - { + { if (_Client == null) { Log("MessageWrite client is null"); disconnectDetected = true; return false; } - + WatsonMessage msg = new WatsonMessage(contentLength, stream, Debug); byte[] headerBytes = msg.ToHeaderBytes(contentLength); @@ -761,7 +790,7 @@ private bool MessageWrite(long contentLength, Stream stream) try { if (_Mode == Mode.Tcp) - { + { _TcpStream.Write(headerBytes, 0, headerBytes.Length); if (contentLength > 0) @@ -838,7 +867,7 @@ private bool MessageWrite(long contentLength, Stream stream) } catch (Exception e) { - LogException("MessageWrite", e); + Common.LogException("MessageWrite", e); disconnectDetected = true; return false; } @@ -863,11 +892,12 @@ private async Task MessageWriteAsync(byte[] data) ms.Seek(0, SeekOrigin.Begin); } - return await MessageWriteAsync(dataLen, ms); + return await MessageWriteAsync(dataLen, ms); } private async Task MessageWriteAsync(long contentLength, Stream stream) { + if (!Connected) return false; if (contentLength < 0) throw new ArgumentException("Content length must be zero or greater bytes."); if (contentLength > 0) { @@ -875,19 +905,19 @@ private async Task MessageWriteAsync(long contentLength, Stream stream) { throw new ArgumentException("Cannot read from supplied stream."); } - } + } bool disconnectDetected = false; try - { + { if (_Client == null) { Log("MessageWriteAsync client is null"); disconnectDetected = true; return false; } - + WatsonMessage msg = new WatsonMessage(contentLength, stream, Debug); byte[] headerBytes = msg.ToHeaderBytes(contentLength); @@ -900,7 +930,7 @@ private async Task MessageWriteAsync(long contentLength, Stream stream) try { if (_Mode == Mode.Tcp) - { + { await _TcpStream.WriteAsync(headerBytes, 0, headerBytes.Length); if (contentLength > 0) @@ -977,7 +1007,7 @@ private async Task MessageWriteAsync(long contentLength, Stream stream) } catch (Exception e) { - LogException("MessageWriteAsync", e); + Common.LogException("MessageWriteAsync", e); disconnectDetected = true; return false; } @@ -991,6 +1021,6 @@ private async Task MessageWriteAsync(long contentLength, Stream stream) } } - #endregion + #endregion Private-Methods } -} +} \ No newline at end of file diff --git a/WatsonTcp/WatsonTcpServer.cs b/WatsonTcp/WatsonTcpServer.cs index fb9da1b..8b3b597 100644 --- a/WatsonTcp/WatsonTcpServer.cs +++ b/WatsonTcp/WatsonTcpServer.cs @@ -55,30 +55,28 @@ public int ReadStreamBufferSize public List PermittedIPs = null; /// - /// Method to call when a client connects to the server. - /// The IP:port is passed to this method as a string, and it is expected that the method will return true. + /// Method to call when a client connects to the server. + /// The IP:port is passed to this method as a string. /// - public Func ClientConnected = null; + public Func ClientConnected = null; /// - /// Method to call when a client disconnects from the server. - /// The IP:port is passed to this method as a string, and it is expected that the method will return true. + /// Method to call when a client disconnects from the server. + /// The IP:port is passed to this method as a string. /// - public Func ClientDisconnected = null; + public Func ClientDisconnected = null; /// - /// Method to call when a message is received from a client. - /// The IP:port is passed to this method as a string, along with a byte array containing the message data. - /// It is expected that the method will return true. + /// Method to call when a message is received from a client. + /// The IP:port is passed to this method as a string, along with a byte array containing the message data. /// - public Func MessageReceived = null; + public Func MessageReceived = null; /// /// Method to call when a message is received from a client. /// The IP:port is passed to this method as a string, along with a long indicating the number of bytes to read from the stream. - /// It is expected that the method will return true; /// - public Func StreamReceived = null; + public Func StreamReceived = null; /// /// Enable acceptance of SSL certificates from clients that cannot be validated. @@ -95,15 +93,15 @@ public int ReadStreamBufferSize /// public string PresharedKey = null; - #endregion + #endregion Public-Members #region Private-Members private bool _Disposed = false; private int _ReadStreamBufferSize = 65536; - private Mode _Mode; + private Mode _Mode; private string _ListenerIp; - private int _ListenerPort; + private int _ListenerPort; private IPAddress _ListenerIpAddress; private TcpListener _Listener; @@ -112,14 +110,14 @@ public int ReadStreamBufferSize private int _ActiveClients; private ConcurrentDictionary _Clients; private ConcurrentDictionary _UnauthenticatedClients; - + private CancellationTokenSource _TokenSource; private CancellationToken _Token; - #endregion + #endregion Private-Members #region Constructors-and-Factories - + /// /// Initialize the Watson TCP server without SSL. Call Start() afterward to start Watson. /// @@ -145,7 +143,7 @@ public WatsonTcpServer( } _ListenerPort = listenerPort; - + _Listener = new TcpListener(_ListenerIpAddress, _ListenerPort); _TokenSource = new CancellationTokenSource(); @@ -153,9 +151,9 @@ public WatsonTcpServer( _ActiveClients = 0; _Clients = new ConcurrentDictionary(); - _UnauthenticatedClients = new ConcurrentDictionary(); + _UnauthenticatedClients = new ConcurrentDictionary(); } - + /// /// Initialize the Watson TCP server with SSL. Call Start() afterward to start Watson. /// @@ -170,6 +168,7 @@ public WatsonTcpServer( string pfxCertPass) { if (listenerPort < 1) throw new ArgumentOutOfRangeException(nameof(listenerPort)); + if (String.IsNullOrEmpty(pfxCertFile)) throw new ArgumentNullException(nameof(pfxCertFile)); _Mode = Mode.Ssl; @@ -186,7 +185,7 @@ public WatsonTcpServer( _ListenerPort = listenerPort; - _SslCertificate = null; + _SslCertificate = null; if (String.IsNullOrEmpty(pfxCertPass)) { _SslCertificate = new X509Certificate2(pfxCertFile); @@ -194,17 +193,17 @@ public WatsonTcpServer( else { _SslCertificate = new X509Certificate2(pfxCertFile, pfxCertPass); - } + } _Listener = new TcpListener(_ListenerIpAddress, _ListenerPort); _TokenSource = new CancellationTokenSource(); _Token = _TokenSource.Token; _ActiveClients = 0; _Clients = new ConcurrentDictionary(); - _UnauthenticatedClients = new ConcurrentDictionary(); + _UnauthenticatedClients = new ConcurrentDictionary(); } - - #endregion + + #endregion Constructors-and-Factories #region Public-Methods @@ -218,7 +217,7 @@ public void Dispose() } /// - /// Start the server. + /// Start the server. /// public void Start() { @@ -347,11 +346,18 @@ public void DisconnectClient(string ipPort) } else { + byte[] data = Encoding.UTF8.GetBytes("Removed from server"); + WatsonMessage removeMsg = new WatsonMessage(); + removeMsg.Status = MessageStatus.Removed; + removeMsg.Data = null; + removeMsg.ContentLength = 0; + MessageWrite(client, removeMsg, null); client.Dispose(); + _Clients.TryRemove(ipPort, out ClientMetadata removed); } } - #endregion + #endregion Public-Methods #region Private-Methods @@ -381,7 +387,7 @@ protected virtual void Dispose(bool disposing) } } } - + _Disposed = true; } @@ -389,20 +395,7 @@ private void Log(string msg) { if (Debug) Console.WriteLine(msg); } - - private void LogException(string method, Exception e) - { - Log("================================================================================"); - Log(" = Method: " + method); - Log(" = Exception Type: " + e.GetType().ToString()); - Log(" = Exception Data: " + e.Data); - Log(" = Inner Exception: " + e.InnerException); - Log(" = Exception Message: " + e.Message); - Log(" = Exception Source: " + e.Source); - Log(" = Exception StackTrace: " + e.StackTrace); - Log("================================================================================"); - } - + private bool AcceptCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { // return true; // Allow untrusted certificates. @@ -422,7 +415,7 @@ private async Task AcceptConnections() TcpClient tcpClient = await _Listener.AcceptTcpClientAsync(); tcpClient.LingerState.Enabled = false; - + string clientIp = ((IPEndPoint)tcpClient.Client.RemoteEndPoint).Address.ToString(); if (PermittedIPs != null && PermittedIPs.Count > 0) { @@ -437,18 +430,15 @@ private async Task AcceptConnections() ClientMetadata client = new ClientMetadata(tcpClient); clientIpPort = client.IpPort; - #endregion + #endregion Accept-Connection-and-Validate-IP if (_Mode == Mode.Tcp) { #region Tcp - Task unawaited = Task.Run(() => - { - FinalizeConnection(client); - }, _Token); + Task unawaited = Task.Run(() => FinalizeConnection(client), _Token); - #endregion + #endregion Tcp } else if (_Mode == Mode.Ssl) { @@ -463,7 +453,8 @@ private async Task AcceptConnections() client.SslStream = new SslStream(client.NetworkStream, false); } - Task unawaited = Task.Run(() => { + Task unawaited = Task.Run(() => + { Task success = StartTls(client); if (success.Result) { @@ -471,13 +462,13 @@ private async Task AcceptConnections() } }, _Token); - #endregion + #endregion SSL } else { throw new ArgumentException("Unknown mode: " + _Mode.ToString()); } - + Log("*** AcceptConnections accepted connection from " + client.IpPort); } catch (Exception e) @@ -490,10 +481,8 @@ private async Task AcceptConnections() private async Task StartTls(ClientMetadata client) { try - { - // the two bools in this should really be contruction paramaters - // maybe re-use mutualAuthentication and acceptInvalidCerts ? - await client.SslStream.AuthenticateAsServerAsync(_SslCertificate, true, SslProtocols.Tls12, false); + { + await client.SslStream.AuthenticateAsServerAsync(_SslCertificate, true, SslProtocols.Tls12, !AcceptInvalidCertificates); if (!client.SslStream.IsEncrypted) { @@ -525,9 +514,11 @@ private async Task StartTls(ClientMetadata client) case "Unable to read data from the transport connection: An existing connection was forcibly closed by the remote host.": Log("*** StartTls IOException " + client.IpPort + " closed the connection."); break; + case "The handshake failed due to an unexpected packet format.": Log("*** StartTls IOException " + client.IpPort + " disconnected, invalid handshake."); break; + default: Log("*** StartTls IOException from " + client.IpPort + Environment.NewLine + ex.ToString()); break; @@ -560,7 +551,7 @@ private void FinalizeConnection(ClientMetadata client) // Do not decrement in this block, decrement is done by the connection reader int activeCount = Interlocked.Increment(ref _ActiveClients); - #endregion + #endregion Add-to-Client-List #region Request-Authentication @@ -577,7 +568,7 @@ private void FinalizeConnection(ClientMetadata client) MessageWrite(client, authMsg, null); } - #endregion + #endregion Request-Authentication #region Start-Data-Receiver @@ -589,7 +580,8 @@ private void FinalizeConnection(ClientMetadata client) Task.Run(async () => await DataReceiver(client)); - #endregion + #endregion Start-Data-Receiver + } private bool IsConnected(ClientMetadata client) @@ -610,16 +602,14 @@ private bool IsConnected(ClientMetadata client) } catch (ObjectDisposedException) { - } catch (IOException) { - } catch (SocketException se) { if (se.NativeErrorCode.Equals(10035)) success = true; - } + } catch (Exception e) { Log("*** IsConnected " + client.IpPort + " exception using send: " + e.Message); @@ -659,7 +649,7 @@ private bool IsConnected(ClientMetadata client) { Log("*** IsConnected " + client.IpPort + " exception using poll/peek: " + e.Message); return false; - } + } finally { if (readLocked) client.ReadLock.Release(); @@ -697,12 +687,12 @@ private async Task DataReceiver(ClientMetadata client) if (!String.IsNullOrEmpty(PresharedKey)) { if (_UnauthenticatedClients.ContainsKey(client.IpPort)) - { + { Log("*** DataReceiver message received from unauthenticated endpoint: " + client.IpPort); if (msg.Status == MessageStatus.AuthRequested) { - // check preshared key + // check preshared key if (msg.PresharedKey != null && msg.PresharedKey.Length > 0) { string clientPsk = Encoding.UTF8.GetString(msg.PresharedKey).Trim(); @@ -753,14 +743,14 @@ private async Task DataReceiver(ClientMetadata client) { if (MessageReceived != null) { - Task unawaited = Task.Run(() => MessageReceived(client.IpPort, msg.Data)); + Task unawaited = Task.Run(() => MessageReceived(client.IpPort, msg.Data)); } } else { if (StreamReceived != null) - { - StreamReceived(client.IpPort, msg.ContentLength, msg.DataStream); + { + Task unawaited = Task.Run(() => StreamReceived(client.IpPort, msg.ContentLength, msg.DataStream)); } } } @@ -770,7 +760,7 @@ private async Task DataReceiver(ClientMetadata client) } } - #endregion + #endregion Wait-for-Data } finally { @@ -779,7 +769,7 @@ private async Task DataReceiver(ClientMetadata client) if (ClientDisconnected != null) { - Task unawaited = Task.Run(() => ClientDisconnected(client.IpPort)); + Task unawaited = Task.Run(() => ClientDisconnected(client.IpPort)); } Log("*** DataReceiver client " + client.IpPort + " disconnected (now " + activeCount + " clients active)"); @@ -804,7 +794,7 @@ private bool RemoveClient(ClientMetadata client) Log("*** RemoveClient removed client " + client.IpPort); return true; } - + private async Task MessageReadAsync(ClientMetadata client) { /* @@ -847,7 +837,7 @@ private async Task MessageReadAsync(ClientMetadata client) throw new ArgumentException("Unknown mode: " + _Mode.ToString()); } - return msg; + return msg; } private bool MessageWrite(ClientMetadata client, WatsonMessage msg, byte[] data) @@ -861,7 +851,7 @@ private bool MessageWrite(ClientMetadata client, WatsonMessage msg, byte[] data) ms.Seek(0, SeekOrigin.Begin); } - return MessageWrite(client, msg, dataLen, ms); + return MessageWrite(client, msg, dataLen, ms); } private bool MessageWrite(ClientMetadata client, WatsonMessage msg, long contentLength, Stream stream) @@ -876,13 +866,13 @@ private bool MessageWrite(ClientMetadata client, WatsonMessage msg, long content throw new ArgumentException("Cannot read from supplied stream."); } } - + byte[] headerBytes = msg.ToHeaderBytes(contentLength); int bytesRead = 0; long bytesRemaining = contentLength; byte[] buffer = new byte[_ReadStreamBufferSize]; - + client.WriteLock.Wait(1); try @@ -890,7 +880,7 @@ private bool MessageWrite(ClientMetadata client, WatsonMessage msg, long content if (_Mode == Mode.Tcp) { client.NetworkStream.Write(headerBytes, 0, headerBytes.Length); - + if (contentLength > 0) { while (bytesRemaining > 0) @@ -944,7 +934,7 @@ private bool MessageWrite(ClientMetadata client, WatsonMessage msg, long content } private async Task MessageWriteAsync(ClientMetadata client, WatsonMessage msg, byte[] data) - { + { int dataLen = 0; MemoryStream ms = new MemoryStream(); if (data != null && data.Length > 0) @@ -954,7 +944,7 @@ private async Task MessageWriteAsync(ClientMetadata client, WatsonMessage ms.Seek(0, SeekOrigin.Begin); } - return await MessageWriteAsync(client, msg, dataLen, ms); + return await MessageWriteAsync(client, msg, dataLen, ms); } private async Task MessageWriteAsync(ClientMetadata client, WatsonMessage msg, long contentLength, Stream stream) @@ -969,7 +959,7 @@ private async Task MessageWriteAsync(ClientMetadata client, WatsonMessage throw new ArgumentException("Cannot read from supplied stream."); } } - + byte[] headerBytes = msg.ToHeaderBytes(contentLength); int bytesRead = 0; @@ -983,7 +973,7 @@ private async Task MessageWriteAsync(ClientMetadata client, WatsonMessage if (_Mode == Mode.Tcp) { await client.NetworkStream.WriteAsync(headerBytes, 0, headerBytes.Length); - + if (contentLength > 0) { while (bytesRemaining > 0) @@ -1036,6 +1026,6 @@ private async Task MessageWriteAsync(ClientMetadata client, WatsonMessage } } - #endregion + #endregion Private-Methods } -} +} \ No newline at end of file