Skip to content

Commit 6d4ba04

Browse files
committed
Fix Rpc field management to match Go
1 parent d095896 commit 6d4ba04

File tree

6 files changed

+84
-45
lines changed

6 files changed

+84
-45
lines changed

Coder.Desktop.sln.DotSettings

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,5 @@
252252
</TypePattern>
253253
</Patterns>
254254
</s:String>
255+
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002EMemberReordering_002EMigrations_002ECSharpFileLayoutPatternRemoveIsAttributeUpgrade/@EntryIndexedValue">True</s:Boolean>
255256
<s:Boolean x:Key="/Default/UserDictionary/Words/=serdes/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>

Rpc.Proto/RpcMessage.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,20 @@ public abstract class RpcMessage<T> where T : IMessage<T>
1919
/// The inner RPC component of the message. This is a separate field as the C# compiler does not allow the existing Rpc
2020
/// field to be overridden or implement this abstract property.
2121
/// </summary>
22-
public abstract RPC RpcField { get; set; }
22+
public abstract RPC? RpcField { get; set; }
2323

2424
/// <summary>
2525
/// The inner message component of the message. This exists so values of type RpcMessage can easily get message
2626
/// contents.
2727
/// </summary>
2828
public abstract T Message { get; }
2929

30+
/// <summary>
31+
/// Check if the message is valid. Checks for empty <c>oneof</c> of fields.
32+
/// </summary>
33+
/// <exception cref="ArgumentException">Invalid message</exception>
34+
public abstract void Validate();
35+
3036
/// <summary>
3137
/// Gets the RpcRole of the message type from it's RpcRole attribute.
3238
/// </summary>
@@ -44,23 +50,33 @@ public static RpcRole GetRole()
4450
[RpcRole(RpcRole.Manager)]
4551
public partial class ManagerMessage : RpcMessage<ManagerMessage>
4652
{
47-
public override RPC RpcField
53+
public override RPC? RpcField
4854
{
4955
get => Rpc;
5056
set => Rpc = value;
5157
}
5258

5359
public override ManagerMessage Message => this;
60+
61+
public override void Validate()
62+
{
63+
if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type");
64+
}
5465
}
5566

5667
[RpcRole(RpcRole.Tunnel)]
5768
public partial class TunnelMessage : RpcMessage<TunnelMessage>
5869
{
59-
public override RPC RpcField
70+
public override RPC? RpcField
6071
{
6172
get => Rpc;
6273
set => Rpc = value;
6374
}
6475

6576
public override TunnelMessage Message => this;
77+
78+
public override void Validate()
79+
{
80+
if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type");
81+
}
6682
}

Rpc/Serdes.cs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,15 @@ public class Serdes<TS, TR>
4747
/// <param name="conn">Stream to write the encoded message to</param>
4848
/// <param name="message">Message to encode and write</param>
4949
/// <param name="ct">Optional cancellation token</param>
50-
/// <exception cref="ArgumentException">If the message exceeds the maximum message size of 16 MiB</exception>
50+
/// <exception cref="ArgumentException">If the message is invalid</exception>
5151
public async Task WriteMessage(Stream conn, TS message, CancellationToken ct = default)
5252
{
53+
message.Validate(); // throws ArgumentException if invalid
5354
using var _ = await _writeLock.LockAsync(ct);
5455

5556
var mb = message.ToByteArray();
57+
if (mb.Length == 0)
58+
throw new ArgumentException("Marshalled message is empty");
5659
if (mb.Length > MaxMessageSize)
5760
throw new ArgumentException($"Marshalled message size {mb.Length} exceeds maximum {MaxMessageSize}");
5861

@@ -69,13 +72,16 @@ public async Task WriteMessage(Stream conn, TS message, CancellationToken ct = d
6972
/// <param name="ct">Optional cancellation token</param>
7073
/// <returns>Decoded message</returns>
7174
/// <exception cref="IOException">Could not decode the message</exception>
75+
/// <exception cref="ArgumentException">The message is invalid</exception>
7276
public async Task<TR> ReadMessage(Stream conn, CancellationToken ct = default)
7377
{
7478
using var _ = await _readLock.LockAsync(ct);
7579

7680
var lenBytes = new byte[sizeof(uint)];
7781
await conn.ReadExactlyAsync(lenBytes, ct);
7882
var len = BinaryPrimitives.ReadUInt32BigEndian(lenBytes);
83+
if (len == 0)
84+
throw new IOException("Received message size 0");
7985
if (len > MaxMessageSize)
8086
throw new IOException($"Received message size {len} exceeds maximum {MaxMessageSize}");
8187

@@ -85,8 +91,9 @@ public async Task<TR> ReadMessage(Stream conn, CancellationToken ct = default)
8591
try
8692
{
8793
var msg = _parser.ParseFrom(msgBytes);
88-
if (msg?.RpcField is null)
89-
throw new IOException("Parsed message is empty or invalid");
94+
if (msg is null)
95+
throw new IOException("Parsed message is null");
96+
msg.Validate(); // throws ArgumentException if invalid
9097
return msg;
9198
}
9299
catch (Exception e)

Rpc/Speaker.cs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,19 @@ public class ReplyableRpcMessage<TS, TR>(Speaker<TS, TR> speaker, TR message) :
1515
where TS : RpcMessage<TS>, IMessage<TS>
1616
where TR : RpcMessage<TR>, IMessage<TR>, new()
1717
{
18-
public override RPC RpcField
18+
public override RPC? RpcField
1919
{
2020
get => message.RpcField;
2121
set => message.RpcField = value;
2222
}
2323

2424
public override TR Message => message;
2525

26+
public override void Validate()
27+
{
28+
message.Validate();
29+
}
30+
2631
/// <summary>
2732
/// Sends a reply to the original message.
2833
/// </summary>
@@ -55,9 +60,9 @@ public class Speaker<TS, TR> : IAsyncDisposable
5560
private readonly ConcurrentDictionary<ulong, TaskCompletionSource<TR>> _pendingReplies = new();
5661
private readonly Serdes<TS, TR> _serdes = new();
5762

58-
// _lastMessageId is incremented using an atomic operation, and as such the
59-
// first message ID will actually be 1.
60-
private ulong _lastMessageId;
63+
// _lastRequestId is incremented using an atomic operation, and as such the
64+
// first request ID will actually be 1.
65+
private ulong _lastRequestId;
6166
private Task? _receiveTask;
6267

6368
/// <summary>
@@ -156,13 +161,17 @@ private async Task ReceiveLoop(CancellationToken ct = default)
156161
while (!ct.IsCancellationRequested)
157162
{
158163
var message = await _serdes.ReadMessage(_conn, ct);
159-
if (message.RpcField.ResponseTo != 0)
164+
if (message is { RpcField.ResponseTo : not 0 })
165+
{
160166
// Look up the TaskCompletionSource for the message ID and
161167
// complete it with the message.
162168
if (_pendingReplies.TryRemove(message.RpcField.ResponseTo, out var tcs))
163169
tcs.SetResult(message);
170+
else
171+
// TODO: we should log unknown replies
172+
continue;
173+
}
164174

165-
// TODO: we should log unknown replies
166175
// Start a new task in the background to handle the message.
167176
_ = Task.Run(() => Receive?.Invoke(new ReplyableRpcMessage<TS, TR>(this, message)), ct);
168177
}
@@ -178,18 +187,14 @@ private async Task ReceiveLoop(CancellationToken ct = default)
178187
}
179188

180189
/// <summary>
181-
/// Send a message without waiting for a reply. If a reply is received it will be handled by the callback.
190+
/// Send a message that does not expect a reply.
182191
/// </summary>
183192
/// <param name="message">Message to send</param>
184193
/// <param name="ct">Optional cancellation token</param>
185194
public async Task SendMessage(TS message, CancellationToken ct = default)
186195
{
187196
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
188-
message.RpcField = new RPC
189-
{
190-
MsgId = Interlocked.Add(ref _lastMessageId, 1),
191-
ResponseTo = 0,
192-
};
197+
message.RpcField = null;
193198
await _serdes.WriteMessage(_conn, message, cts.Token);
194199
}
195200

@@ -200,12 +205,12 @@ public async Task SendMessage(TS message, CancellationToken ct = default)
200205
/// <param name="message">Message to send - the Rpc field will be overwritten</param>
201206
/// <param name="ct">Optional cancellation token</param>
202207
/// <returns>Received reply</returns>
203-
public async ValueTask<TR> SendMessageAwaitReply(TS message, CancellationToken ct = default)
208+
public async ValueTask<TR> SendRequestAwaitReply(TS message, CancellationToken ct = default)
204209
{
205210
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
206211
message.RpcField = new RPC
207212
{
208-
MsgId = Interlocked.Add(ref _lastMessageId, 1),
213+
MsgId = Interlocked.Add(ref _lastRequestId, 1),
209214
ResponseTo = 0,
210215
};
211216

@@ -233,12 +238,16 @@ public async ValueTask<TR> SendMessageAwaitReply(TS message, CancellationToken c
233238
/// <param name="originalMessage">Message to reply to - the Rpc field will be overwritten</param>
234239
/// <param name="reply">Reply message</param>
235240
/// <param name="ct">Optional cancellation token</param>
241+
/// <exception cref="ArgumentException">The original message is not a request and cannot be replied to</exception>
236242
public async Task SendReply(TR originalMessage, TS reply, CancellationToken ct = default)
237243
{
244+
if (originalMessage.RpcField == null || originalMessage.RpcField.MsgId == 0)
245+
throw new ArgumentException("Original message is not a request");
246+
238247
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
239248
reply.RpcField = new RPC
240249
{
241-
MsgId = Interlocked.Add(ref _lastMessageId, 1),
250+
MsgId = 0,
242251
ResponseTo = originalMessage.RpcField.MsgId,
243252
};
244253
await _serdes.WriteMessage(_conn, reply, cts.Token);

Tests/Rpc/SerdesTest.cs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Buffers.Binary;
22
using Coder.Desktop.Rpc;
33
using Coder.Desktop.Rpc.Proto;
4+
using Google.Protobuf;
45

56
namespace Coder.Desktop.Tests.Rpc;
67

@@ -16,10 +17,7 @@ public async Task WriteReadMessage()
1617

1718
var msg = new ManagerMessage
1819
{
19-
Rpc = new RPC
20-
{
21-
MsgId = 1,
22-
},
20+
Start = new StartRequest(),
2321
};
2422
await serdes.WriteMessage(stream1, msg);
2523
var got = await serdes.ReadMessage(stream2);
@@ -35,10 +33,6 @@ public void WriteMessageTooLarge()
3533

3634
var msg = new ManagerMessage
3735
{
38-
Rpc = new RPC
39-
{
40-
MsgId = 1,
41-
},
4236
Start = new StartRequest
4337
{
4438
ApiToken = new string('a', 0x1000001),
@@ -75,8 +69,7 @@ public async Task ReadEmptyMessage()
7569
BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0);
7670
await stream1.WriteAsync(lenBytes);
7771
var ex = Assert.ThrowsAsync<IOException>(() => serdes.ReadMessage(stream2));
78-
Assert.That(ex.InnerException, Is.Not.Null);
79-
Assert.That(ex.InnerException?.Message, Does.Contain("Parsed message is empty or invalid"));
72+
Assert.That(ex.Message, Does.Contain("Received message size 0"));
8073
}
8174

8275
[Test(Description = "Read an invalid/corrupt message from the stream")]
@@ -91,6 +84,6 @@ public async Task ReadInvalidMessage()
9184
await stream1.WriteAsync(lenBytes);
9285
await stream1.WriteAsync(new byte[1]);
9386
var ex = Assert.ThrowsAsync<IOException>(() => serdes.ReadMessage(stream2));
94-
Assert.That(ex.Message, Does.Not.Contain("Parsed message is empty or invalid"));
87+
Assert.That(ex.InnerException, Is.TypeOf(typeof(InvalidProtocolBufferException)));
9588
}
9689
}

Tests/Rpc/SpeakerTest.cs

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,28 +179,29 @@ public async Task SendReceiveReplyReceive()
179179
await using var speaker1 = new Speaker<ManagerMessage, TunnelMessage>(stream1);
180180
var speaker1Ch = Channel
181181
.CreateUnbounded<ReplyableRpcMessage<ManagerMessage, TunnelMessage>>();
182-
speaker1.Receive += msg =>
183-
{
184-
Console.WriteLine($"speaker1 received message: {msg.RpcField.MsgId}");
185-
Assert.That(speaker1Ch.Writer.TryWrite(msg), Is.True);
186-
};
182+
speaker1.Receive += msg => { Assert.That(speaker1Ch.Writer.TryWrite(msg), Is.True); };
187183
speaker1.Error += ex => { Assert.Fail($"speaker1 error: {ex}"); };
188184

189185
await using var speaker2 = new Speaker<TunnelMessage, ManagerMessage>(stream2);
190186
var speaker2Ch = Channel
191187
.CreateUnbounded<ReplyableRpcMessage<TunnelMessage, ManagerMessage>>();
192-
speaker2.Receive += msg =>
193-
{
194-
Console.WriteLine($"speaker2 received message: {msg.RpcField.MsgId}");
195-
Assert.That(speaker2Ch.Writer.TryWrite(msg), Is.True);
196-
};
188+
speaker2.Receive += msg => { Assert.That(speaker2Ch.Writer.TryWrite(msg), Is.True); };
197189
speaker2.Error += ex => { Assert.Fail($"speaker2 error: {ex}"); };
198190

199191
// Start both speakers simultaneously
200192
Task.WaitAll(speaker1.StartAsync(), speaker2.StartAsync());
201193

194+
// Send a normal message from speaker2 to speaker1
195+
await speaker2.SendMessage(new TunnelMessage
196+
{
197+
PeerUpdate = new PeerUpdate(),
198+
});
199+
var receivedMessage = await speaker1Ch.Reader.ReadAsync();
200+
Assert.That(receivedMessage.RpcField, Is.Null); // not a request
201+
Assert.That(receivedMessage.Message.PeerUpdate, Is.Not.Null);
202+
202203
// Send a message from speaker1 to speaker2 in the background
203-
var sendTask = speaker1.SendMessageAwaitReply(new ManagerMessage
204+
var sendTask = speaker1.SendRequestAwaitReply(new ManagerMessage
204205
{
205206
Start = new StartRequest
206207
{
@@ -211,6 +212,9 @@ public async Task SendReceiveReplyReceive()
211212

212213
// Receive the message in speaker2
213214
var message = await speaker2Ch.Reader.ReadAsync();
215+
Assert.That(message.RpcField, Is.Not.Null);
216+
Assert.That(message.RpcField!.MsgId, Is.Not.EqualTo(0));
217+
Assert.That(message.RpcField!.ResponseTo, Is.EqualTo(0));
214218
Assert.That(message.Message.Start.ApiToken, Is.EqualTo("test"));
215219

216220
// Send a reply back to speaker1
@@ -224,6 +228,9 @@ await message.SendReply(new TunnelMessage
224228

225229
// Receive the reply in speaker1 by awaiting sendTask
226230
var reply = await sendTask;
231+
Assert.That(message.RpcField, Is.Not.Null);
232+
Assert.That(reply.RpcField!.MsgId, Is.EqualTo(0));
233+
Assert.That(reply.RpcField!.ResponseTo, Is.EqualTo(message.RpcField!.MsgId));
227234
Assert.That(reply.Message.Start.Success, Is.True);
228235
}
229236

@@ -288,7 +295,10 @@ public async Task SendMessageWriteError()
288295
var writeEx = new IOException("Test write error");
289296
failStream.SetWriteException(writeEx);
290297

291-
var gotEx = Assert.ThrowsAsync<IOException>(() => speaker1.SendMessage(new ManagerMessage()));
298+
var gotEx = Assert.ThrowsAsync<IOException>(() => speaker1.SendMessage(new ManagerMessage
299+
{
300+
Start = new StartRequest(),
301+
}));
292302
Assert.That(gotEx, Is.EqualTo(writeEx));
293303
}
294304

@@ -367,7 +377,10 @@ public async Task DisposeWhileAwaitingReply()
367377
await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync());
368378

369379
// Send a message from speaker1 to speaker2
370-
var sendTask = speaker1.SendMessageAwaitReply(new ManagerMessage());
380+
var sendTask = speaker1.SendRequestAwaitReply(new ManagerMessage
381+
{
382+
Start = new StartRequest(),
383+
});
371384

372385
// Dispose speaker1
373386
await speaker1.DisposeAsync();

0 commit comments

Comments
 (0)