Skip to content

Commit 634764c

Browse files
committed
Most tests
1 parent bb668c7 commit 634764c

File tree

13 files changed

+321
-27
lines changed

13 files changed

+321
-27
lines changed

Rpc.Proto/ApiVersion.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class ApiVersion(int major, int minor, params int[] additionalMajors)
2626
/// <param name="versionString">Version string to parse</param>
2727
/// <returns>Parsed ApiVersion</returns>
2828
/// <exception cref="ArgumentException">The version string is invalid</exception>
29-
public static ApiVersion ParseString(string versionString)
29+
public static ApiVersion Parse(string versionString)
3030
{
3131
var parts = versionString.Split('.');
3232
if (parts.Length != 2) throw new ArgumentException($"Invalid version string '{versionString}'");
@@ -68,4 +68,36 @@ public void Validate(ApiVersion other)
6868
if (AdditionalMajors.Any(major => other.Major == major)) return;
6969
throw new ApiCompatibilityException(this, other, "Version is no longer supported");
7070
}
71+
72+
#region ApiVersion Equality
73+
74+
public static bool operator ==(ApiVersion a, ApiVersion b)
75+
{
76+
return a.Equals(b);
77+
}
78+
79+
public static bool operator !=(ApiVersion a, ApiVersion b)
80+
{
81+
return !a.Equals(b);
82+
}
83+
84+
private bool Equals(ApiVersion other)
85+
{
86+
return Major == other.Major && Minor == other.Minor && AdditionalMajors.SequenceEqual(other.AdditionalMajors);
87+
}
88+
89+
public override bool Equals(object? obj)
90+
{
91+
if (obj is null) return false;
92+
if (ReferenceEquals(this, obj)) return true;
93+
if (obj.GetType() != GetType()) return false;
94+
return Equals((ApiVersion)obj);
95+
}
96+
97+
public override int GetHashCode()
98+
{
99+
return HashCode.Combine(Major, Minor, AdditionalMajors);
100+
}
101+
102+
#endregion
71103
}

Rpc.Proto/RpcHeader.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public static RpcHeader Parse(string header)
2626
if (parts.Length != 3) throw new ArgumentException($"Wrong number of parts in header string '{header}'");
2727
if (parts[0] != Preamble) throw new ArgumentException($"Invalid preamble in header string '{header}'");
2828

29-
var version = ApiVersion.ParseString(parts[1]);
29+
var version = ApiVersion.Parse(parts[1]);
3030
var role = new RpcRole(parts[2]);
3131
return new RpcHeader(role, version);
3232
}

Rpc.Proto/RpcMessage.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public static RpcRole GetRole()
3636
{
3737
var type = typeof(T);
3838
var attr = type.GetCustomAttribute<RpcRoleAttribute>();
39-
if (attr is null) throw new ArgumentException($"Message type {type} does not have a RpcRoleAttribute");
39+
if (attr is null) throw new ArgumentException($"Message type '{type}' does not have a RpcRoleAttribute");
4040
return attr.Role;
4141
}
4242
}

Rpc.Proto/RpcRole.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ namespace Coder.Desktop.Rpc.Proto;
55
/// </summary>
66
public sealed class RpcRole
77
{
8-
internal const string Manager = "manager";
9-
internal const string Tunnel = "tunnel";
8+
public const string Manager = "manager";
9+
public const string Tunnel = "tunnel";
1010

1111
public RpcRole(string role)
1212
{

Rpc.Proto/vpn.proto

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
syntax = "proto3";
22
option go_package = "github.com/coder/coder/v2/vpn";
3-
// TODO: add this upstream
43
option csharp_namespace = "Coder.Desktop.Rpc.Proto";
54

65
import "google/protobuf/timestamp.proto";

Rpc/Serdes.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ public async Task WriteMessage(Stream conn, TS message, CancellationToken ct = d
6969
/// <param name="ct">Optional cancellation token</param>
7070
/// <returns>Decoded message</returns>
7171
/// <exception cref="IOException">Could not decode the message</exception>
72-
/// <exception cref="InvalidOperationException">Could not cast the received message to the expected type</exception>
7372
public async Task<TR> ReadMessage(Stream conn, CancellationToken ct = default)
7473
{
7574
using var _ = await _readLock.LockAsync(ct);
@@ -83,8 +82,16 @@ public async Task<TR> ReadMessage(Stream conn, CancellationToken ct = default)
8382
var msgBytes = new byte[len];
8483
await conn.ReadExactlyAsync(msgBytes, ct);
8584

86-
var msg = _parser.ParseFrom(msgBytes);
87-
if (msg == null) throw new IOException("Failed to parse message");
88-
return msg;
85+
try
86+
{
87+
var msg = _parser.ParseFrom(msgBytes);
88+
if (msg?.RpcField is null)
89+
throw new IOException("Parsed message is empty or invalid");
90+
return msg;
91+
}
92+
catch (Exception e)
93+
{
94+
throw new IOException("Failed to parse message", e);
95+
}
8996
}
9097
}

Rpc/Speaker.cs

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,17 @@ public class Speaker<TS, TR> : IAsyncDisposable
4646

4747
public delegate void OnReceiveDelegate(ReplyableRpcMessage<TS, TR> message);
4848

49+
/// <summary>
50+
/// Event that is triggered when a message is received.
51+
/// </summary>
52+
public event OnReceiveDelegate? Receive;
53+
54+
/// <summary>
55+
/// Event that is triggered when an error occurs. The handling code should dispose the Speaker after this event is
56+
/// triggered.
57+
/// </summary>
58+
public event OnErrorDelegate? Error;
59+
4960
private readonly Stream _conn;
5061

5162
// _cts is cancelled when Dispose is called and will cause all ongoing I/O
@@ -70,24 +81,21 @@ public Speaker(Stream conn)
7081

7182
public async ValueTask DisposeAsync()
7283
{
84+
Error = null;
7385
await _cts.CancelAsync();
7486
if (_receiveTask is not null) await _receiveTask.WaitAsync(TimeSpan.FromSeconds(5));
7587
await _conn.DisposeAsync();
7688
GC.SuppressFinalize(this);
7789
}
7890

79-
// TODO: do we want to do events API or channels API?
80-
public event OnReceiveDelegate? Receive;
81-
public event OnErrorDelegate? Error;
82-
8391
/// <summary>
8492
/// Performs a handshake with the peer and starts the async receive loop. The caller should attach it's Receive and
8593
/// Error event handlers before calling this method.
8694
/// </summary>
8795
public async Task StartAsync(CancellationToken ct = default)
8896
{
8997
// Handshakes should always finish quickly, so enforce a 5s timeout.
90-
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
98+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
9199
cts.CancelAfter(TimeSpan.FromSeconds(5));
92100
await PerformHandshake(ct);
93101

@@ -174,23 +182,25 @@ private async Task ReceiveLoop(CancellationToken ct = default)
174182
/// <param name="ct">Optional cancellation token</param>
175183
public async Task SendMessage(TS message, CancellationToken ct = default)
176184
{
185+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
177186
message.RpcField = new RPC
178187
{
179188
MsgId = Interlocked.Add(ref _lastMessageId, 1),
180189
ResponseTo = 0,
181190
};
182-
await _serdes.WriteMessage(_conn, message, ct);
191+
await _serdes.WriteMessage(_conn, message, cts.Token);
183192
}
184193

185194
/// <summary>
186195
/// Send a message and wait for a reply. The reply will be returned and the callback will not be invoked as long as the
187196
/// reply is received before cancellation.
188197
/// </summary>
189-
/// <param name="message">Message to send</param>
198+
/// <param name="message">Message to send - the Rpc field will be overwritten</param>
190199
/// <param name="ct">Optional cancellation token</param>
191200
/// <returns>Received reply</returns>
192201
public async ValueTask<TR> SendMessageAwaitReply(TS message, CancellationToken ct = default)
193202
{
203+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
194204
message.RpcField = new RPC
195205
{
196206
MsgId = Interlocked.Add(ref _lastMessageId, 1),
@@ -203,31 +213,32 @@ public async ValueTask<TR> SendMessageAwaitReply(TS message, CancellationToken c
203213
_pendingReplies[message.RpcField.MsgId] = tcs;
204214
try
205215
{
206-
await _serdes.WriteMessage(_conn, message, ct);
216+
await _serdes.WriteMessage(_conn, message, cts.Token);
207217
// Wait for the reply to be received.
208-
return await tcs.Task.WaitAsync(ct);
218+
return await tcs.Task.WaitAsync(cts.Token);
209219
}
210220
finally
211221
{
212222
// Clean up the pending reply if it was not received before
213-
// cancellation.
223+
// cancellation or another exception occurred.
214224
_pendingReplies.TryRemove(message.RpcField.MsgId, out _);
215225
}
216226
}
217227

218228
/// <summary>
219-
/// Sends a reply to a received request.
229+
/// Sends a reply to a received message.
220230
/// </summary>
221-
/// <param name="originalMessage">Message to reply to</param>
231+
/// <param name="originalMessage">Message to reply to - the Rpc field will be overwritten</param>
222232
/// <param name="reply">Reply message</param>
223233
/// <param name="ct">Optional cancellation token</param>
224234
public async Task SendReply(TR originalMessage, TS reply, CancellationToken ct = default)
225235
{
236+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
226237
reply.RpcField = new RPC
227238
{
228239
MsgId = Interlocked.Add(ref _lastMessageId, 1),
229240
ResponseTo = originalMessage.RpcField.MsgId,
230241
};
231-
await _serdes.WriteMessage(_conn, reply, ct);
242+
await _serdes.WriteMessage(_conn, reply, cts.Token);
232243
}
233244
}

Tests/Rpc.Proto/ApiVersionTest.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Coder.Desktop.Rpc.Proto;
2+
3+
namespace Coder.Desktop.Tests.Rpc.Proto;
4+
5+
[TestFixture]
6+
public class ApiVersionTest
7+
{
8+
[Test(Description = "Parse a variety of version strings")]
9+
public void Parse()
10+
{
11+
Assert.That(ApiVersion.Parse("2.1"), Is.EqualTo(new ApiVersion(2, 1)));
12+
Assert.That(ApiVersion.Parse("1.0"), Is.EqualTo(new ApiVersion(1, 0)));
13+
14+
Assert.Throws<ArgumentException>(() => ApiVersion.Parse("cats"));
15+
Assert.Throws<ArgumentException>(() => ApiVersion.Parse("cats.dogs"));
16+
Assert.Throws<ArgumentException>(() => ApiVersion.Parse("1.dogs"));
17+
Assert.Throws<ArgumentException>(() => ApiVersion.Parse("1.0.1"));
18+
Assert.Throws<ArgumentException>(() => ApiVersion.Parse("11"));
19+
}
20+
21+
[Test(Description = "Test that versions are compatible")]
22+
public void Validate()
23+
{
24+
var twoOne = new ApiVersion(2, 1, 1);
25+
Assert.DoesNotThrow(() => twoOne.Validate(twoOne));
26+
Assert.DoesNotThrow(() => twoOne.Validate(new ApiVersion(2, 0)));
27+
Assert.DoesNotThrow(() => twoOne.Validate(new ApiVersion(1, 0)));
28+
29+
var ex = Assert.Throws<ApiCompatibilityException>(() => twoOne.Validate(new ApiVersion(2, 2)));
30+
Assert.That(ex.Message, Does.Contain("Peer supports newer minor version"));
31+
ex = Assert.Throws<ApiCompatibilityException>(() => twoOne.Validate(new ApiVersion(3, 1)));
32+
Assert.That(ex.Message, Does.Contain("Peer supports newer major version"));
33+
ex = Assert.Throws<ApiCompatibilityException>(() => twoOne.Validate(new ApiVersion(0, 8)));
34+
Assert.That(ex.Message, Does.Contain("Version is no longer supported"));
35+
}
36+
}

Tests/Rpc.Proto/RpcHeaderTest.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using System.Text;
2+
using Coder.Desktop.Rpc.Proto;
3+
4+
namespace Coder.Desktop.Tests.Rpc.Proto;
5+
6+
[TestFixture]
7+
public class RpcHeaderTest
8+
{
9+
[Test(Description = "Parse and use some valid header strings")]
10+
public void Valid()
11+
{
12+
var headerStr = "codervpn 2.1 manager";
13+
var header = RpcHeader.Parse(headerStr);
14+
Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Manager));
15+
Assert.That(header.Version, Is.EqualTo(new ApiVersion(2, 1)));
16+
Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n"));
17+
Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n")));
18+
19+
headerStr = "codervpn 1.0 tunnel";
20+
header = RpcHeader.Parse(headerStr);
21+
Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Tunnel));
22+
Assert.That(header.Version, Is.EqualTo(new ApiVersion(1, 0)));
23+
Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n"));
24+
Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n")));
25+
}
26+
27+
[Test(Description = "Try to parse some invalid header strings")]
28+
public void ParseInvalid()
29+
{
30+
var ex = Assert.Throws<ArgumentException>(() => RpcHeader.Parse("codervpn"));
31+
Assert.That(ex.Message, Does.Contain("Wrong number of parts"));
32+
ex = Assert.Throws<ArgumentException>(() => RpcHeader.Parse("codervpn 1.0 manager cats"));
33+
Assert.That(ex.Message, Does.Contain("Wrong number of parts"));
34+
ex = Assert.Throws<ArgumentException>(() => RpcHeader.Parse("codervpn 1.0"));
35+
Assert.That(ex.Message, Does.Contain("Wrong number of parts"));
36+
ex = Assert.Throws<ArgumentException>(() => RpcHeader.Parse("cats 1.0 manager"));
37+
Assert.That(ex.Message, Does.Contain("Invalid preamble"));
38+
ex = Assert.Throws<ArgumentException>(() => RpcHeader.Parse("codervpn 1.0 cats"));
39+
Assert.That(ex.Message, Does.Contain("Unknown role 'cats'"));
40+
}
41+
}

Tests/Rpc.Proto/RpcMessageTest.cs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using Coder.Desktop.Rpc.Proto;
2+
3+
namespace Coder.Desktop.Tests.Rpc.Proto;
4+
5+
[TestFixture]
6+
public class RpcRoleAttributeTest
7+
{
8+
[Test]
9+
public void Valid()
10+
{
11+
var role = new RpcRoleAttribute(RpcRole.Manager);
12+
Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Manager));
13+
role = new RpcRoleAttribute(RpcRole.Tunnel);
14+
Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Tunnel));
15+
}
16+
17+
[Test]
18+
public void Invalid()
19+
{
20+
Assert.Throws<ArgumentException>(() => _ = new RpcRoleAttribute("cats"));
21+
}
22+
}
23+
24+
[TestFixture]
25+
public class RpcMessageTest
26+
{
27+
[Test]
28+
public void GetRole()
29+
{
30+
// RpcMessage<RPC> is not a supported message type and doesn't have an
31+
// RpcRoleAttribute
32+
var ex = Assert.Throws<ArgumentException>(() => _ = RpcMessage<RPC>.GetRole());
33+
Assert.That(ex.Message,
34+
Does.Contain("Message type 'Coder.Desktop.Rpc.Proto.RPC' does not have a RpcRoleAttribute"));
35+
36+
Assert.That(ManagerMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Manager));
37+
Assert.That(TunnelMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Tunnel));
38+
}
39+
}

Tests/Rpc.Proto/RpcRoleTest.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Coder.Desktop.Rpc.Proto;
2+
3+
namespace Coder.Desktop.Tests.Rpc.Proto;
4+
5+
[TestFixture]
6+
public class RpcRoleTest
7+
{
8+
[Test(Description = "Instantiate a RpcRole with a valid name")]
9+
public void ValidRole()
10+
{
11+
var role = new RpcRole(RpcRole.Manager);
12+
Assert.That(role.ToString(), Is.EqualTo(RpcRole.Manager));
13+
role = new RpcRole(RpcRole.Tunnel);
14+
Assert.That(role.ToString(), Is.EqualTo(RpcRole.Tunnel));
15+
}
16+
17+
[Test(Description = "Try to instantiate a RpcRole with an invalid name")]
18+
public void InvalidRole()
19+
{
20+
Assert.Throws<ArgumentException>(() => _ = new RpcRole("cats"));
21+
}
22+
}

0 commit comments

Comments
 (0)