Simple Web Transport

This commit is contained in:
Derek S 2021-04-07 01:07:51 -05:00
parent 99163d7d4e
commit 0ed36a5a2f
21 changed files with 2177 additions and 0 deletions

View file

@ -0,0 +1,4 @@
using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("SimpleWebTransport.Tests.Runtime")]
[assembly: InternalsVisibleTo("SimpleWebTransport.Tests.Editor")]

View file

@ -0,0 +1,265 @@
using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Threading;
namespace Mirror.SimpleWeb
{
public interface IBufferOwner
{
void Return(ArrayBuffer buffer);
}
public sealed class ArrayBuffer : IDisposable
{
readonly IBufferOwner owner;
public readonly byte[] array;
/// <summary>
/// number of bytes writen to buffer
/// </summary>
internal int count;
/// <summary>
/// How many times release needs to be called before buffer is returned to pool
/// <para>This allows the buffer to be used in multiple places at the same time</para>
/// </summary>
public void SetReleasesRequired(int required)
{
releasesRequired = required;
}
/// <summary>
/// How many times release needs to be called before buffer is returned to pool
/// <para>This allows the buffer to be used in multiple places at the same time</para>
/// </summary>
/// <remarks>
/// This value is normally 0, but can be changed to require release to be called multiple times
/// </remarks>
int releasesRequired;
public ArrayBuffer(IBufferOwner owner, int size)
{
this.owner = owner;
array = new byte[size];
}
public void Release()
{
int newValue = Interlocked.Decrement(ref releasesRequired);
if (newValue <= 0)
{
count = 0;
owner.Return(this);
}
}
public void Dispose()
{
Release();
}
public void CopyTo(byte[] target, int offset)
{
if (count > (target.Length + offset)) throw new ArgumentException($"{nameof(count)} was greater than {nameof(target)}.length", nameof(target));
Buffer.BlockCopy(array, 0, target, offset, count);
}
public void CopyFrom(ArraySegment<byte> segment)
{
CopyFrom(segment.Array, segment.Offset, segment.Count);
}
public void CopyFrom(byte[] source, int offset, int length)
{
if (length > array.Length) throw new ArgumentException($"{nameof(length)} was greater than {nameof(array)}.length", nameof(length));
count = length;
Buffer.BlockCopy(source, offset, array, 0, length);
}
public void CopyFrom(IntPtr bufferPtr, int length)
{
if (length > array.Length) throw new ArgumentException($"{nameof(length)} was greater than {nameof(array)}.length", nameof(length));
count = length;
Marshal.Copy(bufferPtr, array, 0, length);
}
public ArraySegment<byte> ToSegment()
{
return new ArraySegment<byte>(array, 0, count);
}
[Conditional("UNITY_ASSERTIONS")]
internal void Validate(int arraySize)
{
if (array.Length != arraySize)
{
Log.Error("Buffer that was returned had an array of the wrong size");
}
}
}
internal class BufferBucket : IBufferOwner
{
public readonly int arraySize;
readonly ConcurrentQueue<ArrayBuffer> buffers;
/// <summary>
/// keeps track of how many arrays are taken vs returned
/// </summary>
internal int _current = 0;
public BufferBucket(int arraySize)
{
this.arraySize = arraySize;
buffers = new ConcurrentQueue<ArrayBuffer>();
}
public ArrayBuffer Take()
{
IncrementCreated();
if (buffers.TryDequeue(out ArrayBuffer buffer))
{
return buffer;
}
else
{
Log.Verbose($"BufferBucket({arraySize}) create new");
return new ArrayBuffer(this, arraySize);
}
}
public void Return(ArrayBuffer buffer)
{
DecrementCreated();
buffer.Validate(arraySize);
buffers.Enqueue(buffer);
}
[Conditional("DEBUG")]
void IncrementCreated()
{
int next = Interlocked.Increment(ref _current);
Log.Verbose($"BufferBucket({arraySize}) count:{next}");
}
[Conditional("DEBUG")]
void DecrementCreated()
{
int next = Interlocked.Decrement(ref _current);
Log.Verbose($"BufferBucket({arraySize}) count:{next}");
}
}
/// <summary>
/// Collection of different sized buffers
/// </summary>
/// <remarks>
/// <para>
/// Problem: <br/>
/// * Need to cached byte[] so that new ones aren't created each time <br/>
/// * Arrays sent are multiple different sizes <br/>
/// * Some message might be big so need buffers to cover that size <br/>
/// * Most messages will be small compared to max message size <br/>
/// </para>
/// <br/>
/// <para>
/// Solution: <br/>
/// * Create multiple groups of buffers covering the range of allowed sizes <br/>
/// * Split range exponentially (using math.log) so that there are more groups for small buffers <br/>
/// </para>
/// </remarks>
public class BufferPool
{
internal readonly BufferBucket[] buckets;
readonly int bucketCount;
readonly int smallest;
readonly int largest;
public BufferPool(int bucketCount, int smallest, int largest)
{
if (bucketCount < 2) throw new ArgumentException("Count must be at least 2");
if (smallest < 1) throw new ArgumentException("Smallest must be at least 1");
if (largest < smallest) throw new ArgumentException("Largest must be greater than smallest");
this.bucketCount = bucketCount;
this.smallest = smallest;
this.largest = largest;
// split range over log scale (more buckets for smaller sizes)
double minLog = Math.Log(this.smallest);
double maxLog = Math.Log(this.largest);
double range = maxLog - minLog;
double each = range / (bucketCount - 1);
buckets = new BufferBucket[bucketCount];
for (int i = 0; i < bucketCount; i++)
{
double size = smallest * Math.Pow(Math.E, each * i);
buckets[i] = new BufferBucket((int)Math.Ceiling(size));
}
Validate();
// Example
// 5 count
// 20 smallest
// 16400 largest
// 3.0 log 20
// 9.7 log 16400
// 6.7 range 9.7 - 3
// 1.675 each 6.7 / (5-1)
// 20 e^ (3 + 1.675 * 0)
// 107 e^ (3 + 1.675 * 1)
// 572 e^ (3 + 1.675 * 2)
// 3056 e^ (3 + 1.675 * 3)
// 16,317 e^ (3 + 1.675 * 4)
// perceision wont be lose when using doubles
}
[Conditional("UNITY_ASSERTIONS")]
void Validate()
{
if (buckets[0].arraySize != smallest)
{
Log.Error($"BufferPool Failed to create bucket for smallest. bucket:{buckets[0].arraySize} smallest{smallest}");
}
int largestBucket = buckets[bucketCount - 1].arraySize;
// rounded using Ceiling, so allowed to be 1 more that largest
if (largestBucket != largest && largestBucket != largest + 1)
{
Log.Error($"BufferPool Failed to create bucket for largest. bucket:{largestBucket} smallest{largest}");
}
}
public ArrayBuffer Take(int size)
{
if (size > largest) { throw new ArgumentException($"Size ({size}) is greatest that largest ({largest})"); }
for (int i = 0; i < bucketCount; i++)
{
if (size <= buckets[i].arraySize)
{
return buckets[i].Take();
}
}
throw new ArgumentException($"Size ({size}) is greatest that largest ({largest})");
}
}
}

View file

@ -0,0 +1,90 @@
using System;
using System.Collections.Concurrent;
using System.IO;
using System.Net.Sockets;
using System.Threading;
namespace Mirror.SimpleWeb
{
internal sealed class Connection : IDisposable
{
public const int IdNotSet = -1;
readonly object disposedLock = new object();
public TcpClient client;
public int connId = IdNotSet;
public Stream stream;
public Thread receiveThread;
public Thread sendThread;
public ManualResetEventSlim sendPending = new ManualResetEventSlim(false);
public ConcurrentQueue<ArrayBuffer> sendQueue = new ConcurrentQueue<ArrayBuffer>();
public Action<Connection> onDispose;
volatile bool hasDisposed;
public Connection(TcpClient client, Action<Connection> onDispose)
{
this.client = client ?? throw new ArgumentNullException(nameof(client));
this.onDispose = onDispose;
}
/// <summary>
/// disposes client and stops threads
/// </summary>
public void Dispose()
{
Log.Verbose($"Dispose {ToString()}");
// check hasDisposed first to stop ThreadInterruptedException on lock
if (hasDisposed) { return; }
Log.Info($"Connection Close: {ToString()}");
lock (disposedLock)
{
// check hasDisposed again inside lock to make sure no other object has called this
if (hasDisposed) { return; }
hasDisposed = true;
// stop threads first so they don't try to use disposed objects
receiveThread.Interrupt();
sendThread?.Interrupt();
try
{
// stream
stream?.Dispose();
stream = null;
client.Dispose();
client = null;
}
catch (Exception e)
{
Log.Exception(e);
}
sendPending.Dispose();
// release all buffers in send queue
while (sendQueue.TryDequeue(out ArrayBuffer buffer))
{
buffer.Release();
}
onDispose.Invoke(this);
}
}
public override string ToString()
{
System.Net.EndPoint endpoint = client?.Client?.RemoteEndPoint;
return $"[Conn:{connId}, endPoint:{endpoint}]";
}
}
}

View file

@ -0,0 +1,72 @@
using System.Text;
namespace Mirror.SimpleWeb
{
/// <summary>
/// Constant values that should never change
/// <para>
/// Some values are from https://tools.ietf.org/html/rfc6455
/// </para>
/// </summary>
internal static class Constants
{
/// <summary>
/// Header is at most 4 bytes
/// <para>
/// If message is less than 125 then header is 2 bytes, else header is 4 bytes
/// </para>
/// </summary>
public const int HeaderSize = 4;
/// <summary>
/// Smallest size of header
/// <para>
/// If message is less than 125 then header is 2 bytes, else header is 4 bytes
/// </para>
/// </summary>
public const int HeaderMinSize = 2;
/// <summary>
/// bytes for short length
/// </summary>
public const int ShortLength = 2;
/// <summary>
/// Message mask is always 4 bytes
/// </summary>
public const int MaskSize = 4;
/// <summary>
/// Max size of a message for length to be 1 byte long
/// <para>
/// payload length between 0-125
/// </para>
/// </summary>
public const int BytePayloadLength = 125;
/// <summary>
/// if payload length is 126 when next 2 bytes will be the length
/// </summary>
public const int UshortPayloadLength = 126;
/// <summary>
/// if payload length is 127 when next 8 bytes will be the length
/// </summary>
public const int UlongPayloadLength = 127;
/// <summary>
/// Guid used for WebSocket Protocol
/// </summary>
public const string HandshakeGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
public static readonly int HandshakeGUIDLength = HandshakeGUID.Length;
public static readonly byte[] HandshakeGUIDBytes = Encoding.ASCII.GetBytes(HandshakeGUID);
/// <summary>
/// Handshake messages will end with \r\n\r\n
/// </summary>
public static readonly byte[] endOfHandshake = new byte[4] { (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' };
}
}

View file

@ -0,0 +1,10 @@
namespace Mirror.SimpleWeb
{
public enum EventType
{
Connected,
Data,
Disconnected,
Error
}
}

View file

@ -0,0 +1,94 @@
using System;
using Conditional = System.Diagnostics.ConditionalAttribute;
namespace Mirror.SimpleWeb
{
public static class Log
{
// used for Conditional
const string SIMPLEWEB_LOG_ENABLED = nameof(SIMPLEWEB_LOG_ENABLED);
const string DEBUG = nameof(DEBUG);
public enum Levels
{
none = 0,
error = 1,
warn = 2,
info = 3,
verbose = 4,
}
public static Levels level = Levels.none;
public static string BufferToString(byte[] buffer, int offset = 0, int? length = null)
{
return BitConverter.ToString(buffer, offset, length ?? buffer.Length);
}
[Conditional(SIMPLEWEB_LOG_ENABLED)]
public static void DumpBuffer(string label, byte[] buffer, int offset, int length)
{
if (level < Levels.verbose)
return;
}
[Conditional(SIMPLEWEB_LOG_ENABLED)]
public static void DumpBuffer(string label, ArrayBuffer arrayBuffer)
{
if (level < Levels.verbose)
return;
}
[Conditional(SIMPLEWEB_LOG_ENABLED)]
public static void Verbose(string msg, bool showColor = true)
{
if (level < Levels.verbose)
return;
}
[Conditional(SIMPLEWEB_LOG_ENABLED)]
public static void Info(string msg, bool showColor = true)
{
if (level < Levels.info)
return;
}
/// <summary>
/// An expected Exception was caught, useful for debugging but not important
/// </summary>
/// <param name="msg"></param>
/// <param name="showColor"></param>
[Conditional(SIMPLEWEB_LOG_ENABLED)]
public static void InfoException(Exception e)
{
if (level < Levels.info)
return;
}
[Conditional(SIMPLEWEB_LOG_ENABLED), Conditional(DEBUG)]
public static void Warn(string msg, bool showColor = true)
{
if (level < Levels.warn)
return;
}
[Conditional(SIMPLEWEB_LOG_ENABLED), Conditional(DEBUG)]
public static void Error(string msg, bool showColor = true)
{
if (level < Levels.error)
return;
}
public static void Exception(Exception e)
{
// always log Exceptions
Console.WriteLine("SWT Exception: " + e);
}
}
}

View file

@ -0,0 +1,49 @@
using System;
namespace Mirror.SimpleWeb
{
public struct Message
{
public readonly int connId;
public readonly EventType type;
public readonly ArrayBuffer data;
public readonly Exception exception;
public Message(EventType type) : this()
{
this.type = type;
}
public Message(ArrayBuffer data) : this()
{
type = EventType.Data;
this.data = data;
}
public Message(Exception exception) : this()
{
type = EventType.Error;
this.exception = exception;
}
public Message(int connId, EventType type) : this()
{
this.connId = connId;
this.type = type;
}
public Message(int connId, ArrayBuffer data) : this()
{
this.connId = connId;
type = EventType.Data;
this.data = data;
}
public Message(int connId, Exception exception) : this()
{
this.connId = connId;
type = EventType.Error;
this.exception = exception;
}
}
}

View file

@ -0,0 +1,140 @@
using System.IO;
using System.Runtime.CompilerServices;
namespace Mirror.SimpleWeb
{
public static class MessageProcessor
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static byte FirstLengthByte(byte[] buffer) => (byte)(buffer[1] & 0b0111_1111);
public static bool NeedToReadShortLength(byte[] buffer)
{
byte lenByte = FirstLengthByte(buffer);
return lenByte >= Constants.UshortPayloadLength;
}
public static int GetOpcode(byte[] buffer)
{
return buffer[0] & 0b0000_1111;
}
public static int GetPayloadLength(byte[] buffer)
{
byte lenByte = FirstLengthByte(buffer);
return GetMessageLength(buffer, 0, lenByte);
}
public static void ValidateHeader(byte[] buffer, int maxLength, bool expectMask)
{
bool finished = (buffer[0] & 0b1000_0000) != 0; // has full message been sent
bool hasMask = (buffer[1] & 0b1000_0000) != 0; // true from clients, false from server, "All messages from the client to the server have this bit set"
int opcode = buffer[0] & 0b0000_1111; // expecting 1 - text message
byte lenByte = FirstLengthByte(buffer);
ThrowIfNotFinished(finished);
ThrowIfMaskNotExpected(hasMask, expectMask);
ThrowIfBadOpCode(opcode);
int msglen = GetMessageLength(buffer, 0, lenByte);
ThrowIfLengthZero(msglen);
ThrowIfMsgLengthTooLong(msglen, maxLength);
}
public static void ToggleMask(byte[] src, int sourceOffset, int messageLength, byte[] maskBuffer, int maskOffset)
{
ToggleMask(src, sourceOffset, src, sourceOffset, messageLength, maskBuffer, maskOffset);
}
public static void ToggleMask(byte[] src, int sourceOffset, ArrayBuffer dst, int messageLength, byte[] maskBuffer, int maskOffset)
{
ToggleMask(src, sourceOffset, dst.array, 0, messageLength, maskBuffer, maskOffset);
dst.count = messageLength;
}
public static void ToggleMask(byte[] src, int srcOffset, byte[] dst, int dstOffset, int messageLength, byte[] maskBuffer, int maskOffset)
{
for (int i = 0; i < messageLength; i++)
{
byte maskByte = maskBuffer[maskOffset + i % Constants.MaskSize];
dst[dstOffset + i] = (byte)(src[srcOffset + i] ^ maskByte);
}
}
/// <exception cref="InvalidDataException"></exception>
static int GetMessageLength(byte[] buffer, int offset, byte lenByte)
{
if (lenByte == Constants.UshortPayloadLength)
{
// header is 4 bytes long
ushort value = 0;
value |= (ushort)(buffer[offset + 2] << 8);
value |= buffer[offset + 3];
return value;
}
else if (lenByte == Constants.UlongPayloadLength)
{
throw new InvalidDataException("Max length is longer than allowed in a single message");
}
else // is less than 126
{
// header is 2 bytes long
return lenByte;
}
}
/// <exception cref="InvalidDataException"></exception>
static void ThrowIfNotFinished(bool finished)
{
if (!finished)
{
throw new InvalidDataException("Full message should have been sent, if the full message wasn't sent it wasn't sent from this trasnport");
}
}
/// <exception cref="InvalidDataException"></exception>
static void ThrowIfMaskNotExpected(bool hasMask, bool expectMask)
{
if (hasMask != expectMask)
{
throw new InvalidDataException($"Message expected mask to be {expectMask} but was {hasMask}");
}
}
/// <exception cref="InvalidDataException"></exception>
static void ThrowIfBadOpCode(int opcode)
{
// 2 = binary
// 8 = close
if (opcode != 2 && opcode != 8)
{
throw new InvalidDataException("Expected opcode to be binary or close");
}
}
/// <exception cref="InvalidDataException"></exception>
static void ThrowIfLengthZero(int msglen)
{
if (msglen == 0)
{
throw new InvalidDataException("Message length was zero");
}
}
/// <summary>
/// need to check this so that data from previous buffer isn't used
/// </summary>
/// <exception cref="InvalidDataException"></exception>
static void ThrowIfMsgLengthTooLong(int msglen, int maxLength)
{
if (msglen > maxLength)
{
throw new InvalidDataException("Message length is greater than max length");
}
}
}
}

View file

@ -0,0 +1,132 @@
using System;
using System.IO;
using System.Runtime.Serialization;
namespace Mirror.SimpleWeb
{
public static class ReadHelper
{
/// <summary>
/// Reads exactly length from stream
/// </summary>
/// <returns>outOffset + length</returns>
/// <exception cref="ReadHelperException"></exception>
public static int Read(Stream stream, byte[] outBuffer, int outOffset, int length)
{
int received = 0;
try
{
while (received < length)
{
int read = stream.Read(outBuffer, outOffset + received, length - received);
if (read == 0)
{
throw new ReadHelperException("returned 0");
}
received += read;
}
}
catch (AggregateException ae)
{
// if interrupt is called we don't care about Exceptions
Utils.CheckForInterupt();
// rethrow
ae.Handle(e => false);
}
if (received != length)
{
throw new ReadHelperException("returned not equal to length");
}
return outOffset + received;
}
/// <summary>
/// Reads and returns results. This should never throw an exception
/// </summary>
public static bool TryRead(Stream stream, byte[] outBuffer, int outOffset, int length)
{
try
{
Read(stream, outBuffer, outOffset, length);
return true;
}
catch (ReadHelperException)
{
return false;
}
catch (IOException)
{
return false;
}
catch (Exception e)
{
Log.Exception(e);
return false;
}
}
public static int? SafeReadTillMatch(Stream stream, byte[] outBuffer, int outOffset, int maxLength, byte[] endOfHeader)
{
try
{
int read = 0;
int endIndex = 0;
int endLength = endOfHeader.Length;
while (true)
{
int next = stream.ReadByte();
if (next == -1) // closed
return null;
if (read >= maxLength)
{
Log.Error("SafeReadTillMatch exceeded maxLength");
return null;
}
outBuffer[outOffset + read] = (byte)next;
read++;
// if n is match, check n+1 next
if (endOfHeader[endIndex] == next)
{
endIndex++;
// when all is match return with read length
if (endIndex >= endLength)
{
return read;
}
}
// if n not match reset to 0
else
{
endIndex = 0;
}
}
}
catch (IOException e)
{
Log.InfoException(e);
return null;
}
catch (Exception e)
{
Log.Exception(e);
return null;
}
}
}
[Serializable]
public class ReadHelperException : Exception
{
public ReadHelperException(string message) : base(message) {}
protected ReadHelperException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
}

View file

@ -0,0 +1,194 @@
using System;
using System.Collections.Concurrent;
using System.IO;
using System.Net.Sockets;
using System.Text;
using System.Threading;
namespace Mirror.SimpleWeb
{
internal static class ReceiveLoop
{
public struct Config
{
public readonly Connection conn;
public readonly int maxMessageSize;
public readonly bool expectMask;
public readonly ConcurrentQueue<Message> queue;
public readonly BufferPool bufferPool;
public Config(Connection conn, int maxMessageSize, bool expectMask, ConcurrentQueue<Message> queue, BufferPool bufferPool)
{
this.conn = conn ?? throw new ArgumentNullException(nameof(conn));
this.maxMessageSize = maxMessageSize;
this.expectMask = expectMask;
this.queue = queue ?? throw new ArgumentNullException(nameof(queue));
this.bufferPool = bufferPool ?? throw new ArgumentNullException(nameof(bufferPool));
}
public void Deconstruct(out Connection conn, out int maxMessageSize, out bool expectMask, out ConcurrentQueue<Message> queue, out BufferPool bufferPool)
{
conn = this.conn;
maxMessageSize = this.maxMessageSize;
expectMask = this.expectMask;
queue = this.queue;
bufferPool = this.bufferPool;
}
}
public static void Loop(Config config)
{
(Connection conn, int maxMessageSize, bool expectMask, ConcurrentQueue<Message> queue, BufferPool _) = config;
byte[] readBuffer = new byte[Constants.HeaderSize + (expectMask ? Constants.MaskSize : 0) + maxMessageSize];
try
{
try
{
TcpClient client = conn.client;
while (client.Connected)
{
ReadOneMessage(config, readBuffer);
}
Log.Info($"{conn} Not Connected");
}
catch (Exception)
{
// if interrupted we don't care about other exceptions
Utils.CheckForInterupt();
throw;
}
}
catch (ThreadInterruptedException e) { Log.InfoException(e); }
catch (ThreadAbortException e) { Log.InfoException(e); }
catch (ObjectDisposedException e) { Log.InfoException(e); }
catch (ReadHelperException e)
{
// log as info only
Log.InfoException(e);
}
catch (SocketException e)
{
// this could happen if wss client closes stream
Log.Warn($"ReceiveLoop SocketException\n{e.Message}", false);
queue.Enqueue(new Message(conn.connId, e));
}
catch (IOException e)
{
// this could happen if client disconnects
Log.Warn($"ReceiveLoop IOException\n{e.Message}", false);
queue.Enqueue(new Message(conn.connId, e));
}
catch (InvalidDataException e)
{
Log.Error($"Invalid data from {conn}: {e.Message}");
queue.Enqueue(new Message(conn.connId, e));
}
catch (Exception e)
{
Log.Exception(e);
queue.Enqueue(new Message(conn.connId, e));
}
finally
{
conn.Dispose();
}
}
static void ReadOneMessage(Config config, byte[] buffer)
{
(Connection conn, int maxMessageSize, bool expectMask, ConcurrentQueue<Message> queue, BufferPool bufferPool) = config;
Stream stream = conn.stream;
int offset = 0;
// read 2
offset = ReadHelper.Read(stream, buffer, offset, Constants.HeaderMinSize);
// log after first blocking call
Log.Verbose($"Message From {conn}");
if (MessageProcessor.NeedToReadShortLength(buffer))
{
offset = ReadHelper.Read(stream, buffer, offset, Constants.ShortLength);
}
MessageProcessor.ValidateHeader(buffer, maxMessageSize, expectMask);
if (expectMask)
{
offset = ReadHelper.Read(stream, buffer, offset, Constants.MaskSize);
}
int opcode = MessageProcessor.GetOpcode(buffer);
int payloadLength = MessageProcessor.GetPayloadLength(buffer);
Log.Verbose($"Header ln:{payloadLength} op:{opcode} mask:{expectMask}");
Log.DumpBuffer($"Raw Header", buffer, 0, offset);
int msgOffset = offset;
offset = ReadHelper.Read(stream, buffer, offset, payloadLength);
switch (opcode)
{
case 2:
HandleArrayMessage(config, buffer, msgOffset, payloadLength);
break;
case 8:
HandleCloseMessage(config, buffer, msgOffset, payloadLength);
break;
}
}
static void HandleArrayMessage(Config config, byte[] buffer, int msgOffset, int payloadLength)
{
(Connection conn, int _, bool expectMask, ConcurrentQueue<Message> queue, BufferPool bufferPool) = config;
ArrayBuffer arrayBuffer = bufferPool.Take(payloadLength);
if (expectMask)
{
int maskOffset = msgOffset - Constants.MaskSize;
// write the result of toggle directly into arrayBuffer to avoid 2nd copy call
MessageProcessor.ToggleMask(buffer, msgOffset, arrayBuffer, payloadLength, buffer, maskOffset);
}
else
{
arrayBuffer.CopyFrom(buffer, msgOffset, payloadLength);
}
// dump after mask off
Log.DumpBuffer($"Message", arrayBuffer);
queue.Enqueue(new Message(conn.connId, arrayBuffer));
}
static void HandleCloseMessage(Config config, byte[] buffer, int msgOffset, int payloadLength)
{
(Connection conn, int _, bool expectMask, ConcurrentQueue<Message> _, BufferPool _) = config;
if (expectMask)
{
int maskOffset = msgOffset - Constants.MaskSize;
MessageProcessor.ToggleMask(buffer, msgOffset, payloadLength, buffer, maskOffset);
}
// dump after mask off
Log.DumpBuffer($"Message", buffer, msgOffset, payloadLength);
Log.Info($"Close: {GetCloseCode(buffer, msgOffset)} message:{GetCloseMessage(buffer, msgOffset, payloadLength)}");
conn.Dispose();
}
static string GetCloseMessage(byte[] buffer, int msgOffset, int payloadLength)
{
return Encoding.UTF8.GetString(buffer, msgOffset + 2, payloadLength - 2);
}
static int GetCloseCode(byte[] buffer, int msgOffset)
{
return buffer[msgOffset + 0] << 8 | buffer[msgOffset + 1];
}
}
}

View file

@ -0,0 +1,203 @@
using System;
using System.IO;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Threading;
namespace Mirror.SimpleWeb
{
public static class SendLoopConfig
{
public static volatile bool batchSend = false;
public static volatile bool sleepBeforeSend = false;
}
internal static class SendLoop
{
public struct Config
{
public readonly Connection conn;
public readonly int bufferSize;
public readonly bool setMask;
public Config(Connection conn, int bufferSize, bool setMask)
{
this.conn = conn ?? throw new ArgumentNullException(nameof(conn));
this.bufferSize = bufferSize;
this.setMask = setMask;
}
public void Deconstruct(out Connection conn, out int bufferSize, out bool setMask)
{
conn = this.conn;
bufferSize = this.bufferSize;
setMask = this.setMask;
}
}
public static void Loop(Config config)
{
(Connection conn, int bufferSize, bool setMask) = config;
// create write buffer for this thread
byte[] writeBuffer = new byte[bufferSize];
MaskHelper maskHelper = setMask ? new MaskHelper() : null;
try
{
TcpClient client = conn.client;
Stream stream = conn.stream;
// null check in case disconnect while send thread is starting
if (client == null)
return;
while (client.Connected)
{
// wait for message
conn.sendPending.Wait();
// wait for 1ms for mirror to send other messages
if (SendLoopConfig.sleepBeforeSend)
{
Thread.Sleep(1);
}
conn.sendPending.Reset();
if (SendLoopConfig.batchSend)
{
int offset = 0;
while (conn.sendQueue.TryDequeue(out ArrayBuffer msg))
{
// check if connected before sending message
if (!client.Connected) { Log.Info($"SendLoop {conn} not connected"); return; }
int maxLength = msg.count + Constants.HeaderSize + Constants.MaskSize;
// if next writer could overflow, write to stream and clear buffer
if (offset + maxLength > bufferSize)
{
stream.Write(writeBuffer, 0, offset);
offset = 0;
}
offset = SendMessage(writeBuffer, offset, msg, setMask, maskHelper);
msg.Release();
}
// after no message in queue, send remaining messages
// don't need to check offset > 0 because last message in queue will always be sent here
stream.Write(writeBuffer, 0, offset);
}
else
{
while (conn.sendQueue.TryDequeue(out ArrayBuffer msg))
{
// check if connected before sending message
if (!client.Connected) { Log.Info($"SendLoop {conn} not connected"); return; }
int length = SendMessage(writeBuffer, 0, msg, setMask, maskHelper);
stream.Write(writeBuffer, 0, length);
msg.Release();
}
}
}
Log.Info($"{conn} Not Connected");
}
catch (ThreadInterruptedException e) { Log.InfoException(e); }
catch (ThreadAbortException e) { Log.InfoException(e); }
catch (Exception e)
{
Log.Exception(e);
}
finally
{
conn.Dispose();
maskHelper?.Dispose();
}
}
/// <returns>new offset in buffer</returns>
static int SendMessage(byte[] buffer, int startOffset, ArrayBuffer msg, bool setMask, MaskHelper maskHelper)
{
int msgLength = msg.count;
int offset = WriteHeader(buffer, startOffset, msgLength, setMask);
if (setMask)
{
offset = maskHelper.WriteMask(buffer, offset);
}
msg.CopyTo(buffer, offset);
offset += msgLength;
// dump before mask on
Log.DumpBuffer("Send", buffer, startOffset, offset);
if (setMask)
{
int messageOffset = offset - msgLength;
MessageProcessor.ToggleMask(buffer, messageOffset, msgLength, buffer, messageOffset - Constants.MaskSize);
}
return offset;
}
static int WriteHeader(byte[] buffer, int startOffset, int msgLength, bool setMask)
{
int sendLength = 0;
const byte finished = 128;
const byte byteOpCode = 2;
buffer[startOffset + 0] = finished | byteOpCode;
sendLength++;
if (msgLength <= Constants.BytePayloadLength)
{
buffer[startOffset + 1] = (byte)msgLength;
sendLength++;
}
else if (msgLength <= ushort.MaxValue)
{
buffer[startOffset + 1] = 126;
buffer[startOffset + 2] = (byte)(msgLength >> 8);
buffer[startOffset + 3] = (byte)msgLength;
sendLength += 3;
}
else
{
throw new InvalidDataException($"Trying to send a message larger than {ushort.MaxValue} bytes");
}
if (setMask)
{
buffer[startOffset + 1] |= 0b1000_0000;
}
return sendLength + startOffset;
}
sealed class MaskHelper : IDisposable
{
readonly byte[] maskBuffer;
readonly RNGCryptoServiceProvider random;
public MaskHelper()
{
maskBuffer = new byte[4];
random = new RNGCryptoServiceProvider();
}
public void Dispose()
{
random.Dispose();
}
public int WriteMask(byte[] buffer, int offset)
{
random.GetBytes(maskBuffer);
Buffer.BlockCopy(maskBuffer, 0, buffer, offset, 4);
return offset + 4;
}
}
}
}

View file

@ -0,0 +1,25 @@
using System.Net.Sockets;
namespace Mirror.SimpleWeb
{
public struct TcpConfig
{
public readonly bool noDelay;
public readonly int sendTimeout;
public readonly int receiveTimeout;
public TcpConfig(bool noDelay, int sendTimeout, int receiveTimeout)
{
this.noDelay = noDelay;
this.sendTimeout = sendTimeout;
this.receiveTimeout = receiveTimeout;
}
public void ApplyTo(TcpClient client)
{
client.SendTimeout = sendTimeout;
client.ReceiveTimeout = receiveTimeout;
client.NoDelay = noDelay;
}
}
}

View file

@ -0,0 +1,13 @@
using System.Threading;
namespace Mirror.SimpleWeb
{
internal static class Utils
{
public static void CheckForInterupt()
{
// sleep in order to check for ThreadInterruptedException
Thread.Sleep(1);
}
}
}

View file

@ -0,0 +1,22 @@
SimpleWebTransport is a Transport that implements websocket for Webgl builds of
mirror. This transport can also work on standalone builds and has support for
encryption with websocket secure.
How to use:
Replace your existing Transport with SimpleWebTransport on your NetworkManager
Requirements:
Unity 2018.4 LTS
Mirror v18.0.0
Documentation:
https://mirror-networking.com/docs/
https://github.com/MirrorNetworking/SimpleWebTransport/blob/master/README.md
Support:
Discord: https://discordapp.com/invite/N9QVxbM
Bug Reports: https://github.com/MirrorNetworking/SimpleWebTransport/issues
**To get most recent updates and fixes download from github**
https://github.com/MirrorNetworking/SimpleWebTransport/releases

View file

@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Authentication;
using System.Text;
using System.Threading.Tasks;
namespace Mirror
{
class SWTConfig
{
public int maxMessageSize = 16 * 1024;
public int handshakeMaxSize = 3000;
public bool noDelay = true;
public int sendTimeout = 5000;
public int receiveTimeout = 20000;
public int serverMaxMessagesPerTick = 10000;
public bool waitBeforeSend = false;
public bool clientUseWss;
public bool sslEnabled;
public string sslCertJson = "./cert.json";
public SslProtocols sslProtocols = SslProtocols.Tls12;
}
}

View file

@ -0,0 +1,149 @@
using System;
using System.IO;
using System.Security.Cryptography;
using System.Text;
namespace Mirror.SimpleWeb
{
/// <summary>
/// Handles Handshakes from new clients on the server
/// <para>The server handshake has buffers to reduce allocations when clients connect</para>
/// </summary>
internal class ServerHandshake
{
const int GetSize = 3;
const int ResponseLength = 129;
const int KeyLength = 24;
const int MergedKeyLength = 60;
const string KeyHeaderString = "Sec-WebSocket-Key: ";
// this isn't an official max, just a reasonable size for a websocket handshake
readonly int maxHttpHeaderSize = 3000;
readonly SHA1 sha1 = SHA1.Create();
readonly BufferPool bufferPool;
public ServerHandshake(BufferPool bufferPool, int handshakeMaxSize)
{
this.bufferPool = bufferPool;
this.maxHttpHeaderSize = handshakeMaxSize;
}
~ServerHandshake()
{
sha1.Dispose();
}
public bool TryHandshake(Connection conn)
{
Stream stream = conn.stream;
using (ArrayBuffer getHeader = bufferPool.Take(GetSize))
{
if (!ReadHelper.TryRead(stream, getHeader.array, 0, GetSize))
return false;
getHeader.count = GetSize;
if (!IsGet(getHeader.array))
{
Log.Warn($"First bytes from client was not 'GET' for handshake, instead was {Log.BufferToString(getHeader.array, 0, GetSize)}");
return false;
}
}
string msg = ReadToEndForHandshake(stream);
if (string.IsNullOrEmpty(msg))
return false;
try
{
AcceptHandshake(stream, msg);
return true;
}
catch (ArgumentException e)
{
Log.InfoException(e);
return false;
}
}
string ReadToEndForHandshake(Stream stream)
{
using (ArrayBuffer readBuffer = bufferPool.Take(maxHttpHeaderSize))
{
int? readCountOrFail = ReadHelper.SafeReadTillMatch(stream, readBuffer.array, 0, maxHttpHeaderSize, Constants.endOfHandshake);
if (!readCountOrFail.HasValue)
return null;
int readCount = readCountOrFail.Value;
string msg = Encoding.ASCII.GetString(readBuffer.array, 0, readCount);
Log.Verbose(msg);
return msg;
}
}
static bool IsGet(byte[] getHeader)
{
// just check bytes here instead of using Encoding.ASCII
return getHeader[0] == 71 && // G
getHeader[1] == 69 && // E
getHeader[2] == 84; // T
}
void AcceptHandshake(Stream stream, string msg)
{
using (
ArrayBuffer keyBuffer = bufferPool.Take(KeyLength),
responseBuffer = bufferPool.Take(ResponseLength))
{
GetKey(msg, keyBuffer.array);
AppendGuid(keyBuffer.array);
byte[] keyHash = CreateHash(keyBuffer.array);
CreateResponse(keyHash, responseBuffer.array);
stream.Write(responseBuffer.array, 0, ResponseLength);
}
}
static void GetKey(string msg, byte[] keyBuffer)
{
int start = msg.IndexOf(KeyHeaderString) + KeyHeaderString.Length;
Log.Verbose($"Handshake Key: {msg.Substring(start, KeyLength)}");
Encoding.ASCII.GetBytes(msg, start, KeyLength, keyBuffer, 0);
}
static void AppendGuid(byte[] keyBuffer)
{
Buffer.BlockCopy(Constants.HandshakeGUIDBytes, 0, keyBuffer, KeyLength, Constants.HandshakeGUID.Length);
}
byte[] CreateHash(byte[] keyBuffer)
{
Log.Verbose($"Handshake Hashing {Encoding.ASCII.GetString(keyBuffer, 0, MergedKeyLength)}");
return sha1.ComputeHash(keyBuffer, 0, MergedKeyLength);
}
static void CreateResponse(byte[] keyHash, byte[] responseBuffer)
{
string keyHashString = Convert.ToBase64String(keyHash);
// compiler should merge these strings into 1 string before format
string message = string.Format(
"HTTP/1.1 101 Switching Protocols\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: websocket\r\n" +
"Sec-WebSocket-Accept: {0}\r\n\r\n",
keyHashString);
Log.Verbose($"Handshake Response length {message.Length}, IsExpected {message.Length == ResponseLength}");
Encoding.ASCII.GetBytes(message, 0, ResponseLength, responseBuffer, 0);
}
}
}

View file

@ -0,0 +1,74 @@
using System;
using System.IO;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
namespace Mirror.SimpleWeb
{
public struct SslConfig
{
public readonly bool enabled;
public readonly string certPath;
public readonly string certPassword;
public readonly SslProtocols sslProtocols;
public SslConfig(bool enabled, string certPath, string certPassword, SslProtocols sslProtocols)
{
this.enabled = enabled;
this.certPath = certPath;
this.certPassword = certPassword;
this.sslProtocols = sslProtocols;
}
}
internal class ServerSslHelper
{
readonly SslConfig config;
readonly X509Certificate2 certificate;
public ServerSslHelper(SslConfig sslConfig)
{
config = sslConfig;
if (config.enabled)
certificate = new X509Certificate2(config.certPath, config.certPassword);
}
internal bool TryCreateStream(Connection conn)
{
NetworkStream stream = conn.client.GetStream();
if (config.enabled)
{
try
{
conn.stream = CreateStream(stream);
return true;
}
catch (Exception e)
{
Log.Error($"Create SSLStream Failed: {e}", false);
return false;
}
}
else
{
conn.stream = stream;
return true;
}
}
Stream CreateStream(NetworkStream stream)
{
SslStream sslStream = new SslStream(stream, true, acceptClient);
sslStream.AuthenticateAsServer(certificate, false, config.sslProtocols, false);
return sslStream;
}
bool acceptClient(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
{
// always accept client
return true;
}
}
}

View file

@ -0,0 +1,103 @@
using System;
using System.Collections.Generic;
namespace Mirror.SimpleWeb
{
public class SimpleWebServer
{
readonly int maxMessagesPerTick;
readonly WebSocketServer server;
readonly BufferPool bufferPool;
public SimpleWebServer(int maxMessagesPerTick, TcpConfig tcpConfig, int maxMessageSize, int handshakeMaxSize, SslConfig sslConfig)
{
this.maxMessagesPerTick = maxMessagesPerTick;
// use max because bufferpool is used for both messages and handshake
int max = Math.Max(maxMessageSize, handshakeMaxSize);
bufferPool = new BufferPool(5, 20, max);
server = new WebSocketServer(tcpConfig, maxMessageSize, handshakeMaxSize, sslConfig, bufferPool);
}
public bool Active { get; private set; }
public event Action<int> onConnect;
public event Action<int> onDisconnect;
public event Action<int, ArraySegment<byte>> onData;
public event Action<int, Exception> onError;
public void Start(ushort port)
{
server.Listen(port);
Active = true;
}
public void Stop()
{
server.Stop();
Active = false;
}
public void SendAll(List<int> connectionIds, ArraySegment<byte> source)
{
ArrayBuffer buffer = bufferPool.Take(source.Count);
buffer.CopyFrom(source);
buffer.SetReleasesRequired(connectionIds.Count);
// make copy of array before for each, data sent to each client is the same
foreach (int id in connectionIds)
{
server.Send(id, buffer);
}
}
public void SendOne(int connectionId, ArraySegment<byte> source)
{
ArrayBuffer buffer = bufferPool.Take(source.Count);
buffer.CopyFrom(source);
server.Send(connectionId, buffer);
}
public bool KickClient(int connectionId)
{
return server.CloseConnection(connectionId);
}
public string GetClientAddress(int connectionId)
{
return server.GetClientAddress(connectionId);
}
public void ProcessMessageQueue()
{
int processedCount = 0;
// check enabled every time in case behaviour was disabled after data
while (
processedCount < maxMessagesPerTick &&
// Dequeue last
server.receiveQueue.TryDequeue(out Message next)
)
{
processedCount++;
switch (next.type)
{
case EventType.Connected:
onConnect?.Invoke(next.connId);
break;
case EventType.Data:
onData?.Invoke(next.connId, next.data.ToSegment());
next.data.Release();
break;
case EventType.Disconnected:
onDisconnect?.Invoke(next.connId);
break;
case EventType.Error:
onError?.Invoke(next.connId, next.exception);
break;
}
}
}
}
}

View file

@ -0,0 +1,230 @@
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Net.Sockets;
using System.Threading;
namespace Mirror.SimpleWeb
{
public class WebSocketServer
{
public readonly ConcurrentQueue<Message> receiveQueue = new ConcurrentQueue<Message>();
readonly TcpConfig tcpConfig;
readonly int maxMessageSize;
TcpListener listener;
Thread acceptThread;
bool serverStopped;
readonly ServerHandshake handShake;
readonly ServerSslHelper sslHelper;
readonly BufferPool bufferPool;
readonly ConcurrentDictionary<int, Connection> connections = new ConcurrentDictionary<int, Connection>();
int _idCounter = 0;
public WebSocketServer(TcpConfig tcpConfig, int maxMessageSize, int handshakeMaxSize, SslConfig sslConfig, BufferPool bufferPool)
{
this.tcpConfig = tcpConfig;
this.maxMessageSize = maxMessageSize;
sslHelper = new ServerSslHelper(sslConfig);
this.bufferPool = bufferPool;
handShake = new ServerHandshake(this.bufferPool, handshakeMaxSize);
}
public void Listen(int port)
{
listener = TcpListener.Create(port);
listener.Start();
Log.Info($"Server has started on port {port}");
acceptThread = new Thread(acceptLoop);
acceptThread.IsBackground = true;
acceptThread.Start();
}
public void Stop()
{
serverStopped = true;
// Interrupt then stop so that Exception is handled correctly
acceptThread?.Interrupt();
listener?.Stop();
acceptThread = null;
Log.Info("Server stopped, Closing all connections...");
// make copy so that foreach doesn't break if values are removed
Connection[] connectionsCopy = connections.Values.ToArray();
foreach (Connection conn in connectionsCopy)
{
conn.Dispose();
}
connections.Clear();
}
void acceptLoop()
{
try
{
try
{
while (true)
{
TcpClient client = listener.AcceptTcpClient();
tcpConfig.ApplyTo(client);
// TODO keep track of connections before they are in connections dictionary
// this might not be a problem as HandshakeAndReceiveLoop checks for stop
// and returns/disposes before sending message to queue
Connection conn = new Connection(client, AfterConnectionDisposed);
Log.Info($"A client connected {conn}");
// handshake needs its own thread as it needs to wait for message from client
Thread receiveThread = new Thread(() => HandshakeAndReceiveLoop(conn));
conn.receiveThread = receiveThread;
receiveThread.IsBackground = true;
receiveThread.Start();
}
}
catch (SocketException)
{
// check for Interrupted/Abort
Utils.CheckForInterupt();
throw;
}
}
catch (ThreadInterruptedException e) { Log.InfoException(e); }
catch (ThreadAbortException e) { Log.InfoException(e); }
catch (Exception e) { Log.Exception(e); }
}
void HandshakeAndReceiveLoop(Connection conn)
{
try
{
bool success = sslHelper.TryCreateStream(conn);
if (!success)
{
Log.Error($"Failed to create SSL Stream {conn}");
conn.Dispose();
return;
}
success = handShake.TryHandshake(conn);
if (success)
{
Log.Info($"Sent Handshake {conn}");
}
else
{
Log.Error($"Handshake Failed {conn}");
conn.Dispose();
return;
}
// check if Stop has been called since accepting this client
if (serverStopped)
{
Log.Info("Server stops after successful handshake");
return;
}
conn.connId = Interlocked.Increment(ref _idCounter);
connections.TryAdd(conn.connId, conn);
receiveQueue.Enqueue(new Message(conn.connId, EventType.Connected));
Thread sendThread = new Thread(() =>
{
SendLoop.Config sendConfig = new SendLoop.Config(
conn,
bufferSize: Constants.HeaderSize + maxMessageSize,
setMask: false);
SendLoop.Loop(sendConfig);
});
conn.sendThread = sendThread;
sendThread.IsBackground = true;
sendThread.Name = $"SendLoop {conn.connId}";
sendThread.Start();
ReceiveLoop.Config receiveConfig = new ReceiveLoop.Config(
conn,
maxMessageSize,
expectMask: true,
receiveQueue,
bufferPool);
ReceiveLoop.Loop(receiveConfig);
}
catch (ThreadInterruptedException e) { Log.InfoException(e); }
catch (ThreadAbortException e) { Log.InfoException(e); }
catch (Exception e) { Log.Exception(e); }
finally
{
// close here in case connect fails
conn.Dispose();
}
}
void AfterConnectionDisposed(Connection conn)
{
if (conn.connId != Connection.IdNotSet)
{
receiveQueue.Enqueue(new Message(conn.connId, EventType.Disconnected));
connections.TryRemove(conn.connId, out Connection _);
}
}
public void Send(int id, ArrayBuffer buffer)
{
if (connections.TryGetValue(id, out Connection conn))
{
conn.sendQueue.Enqueue(buffer);
conn.sendPending.Set();
}
else
{
Log.Warn($"Cant send message to {id} because connection was not found in dictionary. Maybe it disconnected.");
}
}
public bool CloseConnection(int id)
{
if (connections.TryGetValue(id, out Connection conn))
{
Log.Info($"Kicking connection {id}");
conn.Dispose();
return true;
}
else
{
Log.Warn($"Failed to kick {id} because id not found");
return false;
}
}
public string GetClientAddress(int id)
{
if (connections.TryGetValue(id, out Connection conn))
{
return conn.client.Client.RemoteEndPoint.ToString();
}
else
{
Log.Error($"Cant close connection to {id} because connection was not found in dictionary");
return null;
}
}
}
}

View file

@ -0,0 +1,225 @@
using Mirror.SimpleWeb;
using Newtonsoft.Json;
using System;
using System.IO;
using System.Net;
using System.Security.Authentication;
namespace Mirror
{
public class SimpleWebTransport : Transport
{
public const string NormalScheme = "ws";
public const string SecureScheme = "wss";
public int maxMessageSize = 16 * 1024;
public int handshakeMaxSize = 3000;
public bool noDelay = true;
public int sendTimeout = 5000;
public int receiveTimeout = 20000;
public int serverMaxMessagesPerTick = 10000;
public int clientMaxMessagesPerTick = 1000;
public bool batchSend = true;
public bool waitBeforeSend = false;
public bool clientUseWss;
public bool sslEnabled;
public string sslCertJson = "./cert.json";
public SslProtocols sslProtocols = SslProtocols.Tls12;
Log.Levels _logLevels = Log.Levels.none;
/// <summary>
/// <para>Gets _logLevels field</para>
/// <para>Sets _logLevels and Log.level fields</para>
/// </summary>
public Log.Levels LogLevels
{
get => _logLevels;
set
{
_logLevels = value;
Log.level = _logLevels;
}
}
void OnValidate()
{
if (maxMessageSize > ushort.MaxValue)
{
Console.WriteLine($"max supported value for maxMessageSize is {ushort.MaxValue}");
maxMessageSize = ushort.MaxValue;
}
Log.level = _logLevels;
}
SimpleWebServer server;
TcpConfig TcpConfig => new TcpConfig(noDelay, sendTimeout, receiveTimeout);
public override bool Available()
{
return true;
}
public override int GetMaxPacketSize(int channelId = 0)
{
return maxMessageSize;
}
void Awake()
{
Log.level = _logLevels;
SWTConfig conf = new SWTConfig();
if (!File.Exists("SWTConfig.json"))
{
File.WriteAllText("SWTConfig.json", JsonConvert.SerializeObject(conf, Formatting.Indented));
}
else
{
conf = JsonConvert.DeserializeObject<SWTConfig>(File.ReadAllText("SWTConfig.json"));
}
maxMessageSize = conf.maxMessageSize;
handshakeMaxSize = conf.handshakeMaxSize;
noDelay = conf.noDelay;
sendTimeout = conf.sendTimeout;
receiveTimeout = conf.receiveTimeout;
serverMaxMessagesPerTick = conf.serverMaxMessagesPerTick;
waitBeforeSend = conf.waitBeforeSend;
clientUseWss = conf.clientUseWss;
sslEnabled = conf.sslEnabled;
sslCertJson = conf.sslCertJson;
sslProtocols = conf.sslProtocols;
}
public override void Shutdown()
{
server?.Stop();
server = null;
}
#region Client
string GetClientScheme() => (sslEnabled || clientUseWss) ? SecureScheme : NormalScheme;
string GetServerScheme() => sslEnabled ? SecureScheme : NormalScheme;
public override bool ClientConnected()
{
// not null and not NotConnected (we want to return true if connecting or disconnecting)
return false;
}
public override void ClientConnect(string hostname) { }
public override void ClientDisconnect() { }
public override void ClientSend(int channelId, ArraySegment<byte> segment) { }
#endregion
#region Server
public override bool ServerActive()
{
return server != null && server.Active;
}
public override void ServerStart(ushort requestedPort)
{
if (ServerActive())
{
Console.WriteLine("SimpleWebServer Already Started");
}
SslConfig config = SslConfigLoader.Load(this);
server = new SimpleWebServer(serverMaxMessagesPerTick, TcpConfig, maxMessageSize, handshakeMaxSize, config);
server.onConnect += OnServerConnected.Invoke;
server.onDisconnect += OnServerDisconnected.Invoke;
server.onData += (int connId, ArraySegment<byte> data) => OnServerDataReceived.Invoke(connId, data, 0);
server.onError += OnServerError.Invoke;
SendLoopConfig.batchSend = batchSend || waitBeforeSend;
SendLoopConfig.sleepBeforeSend = waitBeforeSend;
server.Start(requestedPort);
}
public override void ServerStop()
{
if (!ServerActive())
{
Console.WriteLine("SimpleWebServer Not Active");
}
server.Stop();
server = null;
}
public override bool ServerDisconnect(int connectionId)
{
if (!ServerActive())
{
Console.WriteLine("SimpleWebServer Not Active");
return false;
}
return server.KickClient(connectionId);
}
public override void ServerSend(int connectionId, int channelId, ArraySegment<byte> segment)
{
if (!ServerActive())
{
Console.WriteLine("SimpleWebServer Not Active");
return;
}
if (segment.Count > maxMessageSize)
{
Console.WriteLine("Message greater than max size");
return;
}
if (segment.Count == 0)
{
Console.WriteLine("Message count was zero");
return;
}
server.SendOne(connectionId, segment);
return;
}
public override string ServerGetClientAddress(int connectionId)
{
return server.GetClientAddress(connectionId);
}
public override Uri ServerUri()
{
UriBuilder builder = new UriBuilder
{
Scheme = GetServerScheme(),
Host = Dns.GetHostName()
};
return builder.Uri;
}
public void Update()
{
server?.ProcessMessageQueue();
}
#endregion
}
}

View file

@ -0,0 +1,49 @@
using Newtonsoft.Json;
using System.IO;
namespace Mirror.SimpleWeb
{
internal class SslConfigLoader
{
internal struct Cert
{
public string path;
public string password;
}
internal static SslConfig Load(SimpleWebTransport transport)
{
// don't need to load anything if ssl is not enabled
if (!transport.sslEnabled)
return default;
string certJsonPath = transport.sslCertJson;
Cert cert = LoadCertJson(certJsonPath);
return new SslConfig(
enabled: transport.sslEnabled,
sslProtocols: transport.sslProtocols,
certPath: cert.path,
certPassword: cert.password
);
}
internal static Cert LoadCertJson(string certJsonPath)
{
string json = File.ReadAllText(certJsonPath);
Cert cert = JsonConvert.DeserializeObject<Cert>(json);
if (string.IsNullOrEmpty(cert.path))
{
throw new InvalidDataException("Cert Json didn't not contain \"path\"");
}
if (string.IsNullOrEmpty(cert.password))
{
// password can be empty
cert.password = string.Empty;
}
return cert;
}
}
}