diff --git a/App/App.xaml.cs b/App/App.xaml.cs index e756efd..5b82ced 100644 --- a/App/App.xaml.cs +++ b/App/App.xaml.cs @@ -72,6 +72,7 @@ public App() new WindowsCredentialBackend(WindowsCredentialBackend.CoderCredentialsTargetName)); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddOptions() .Bind(builder.Configuration.GetSection(MutagenControllerConfigSection)); diff --git a/App/Models/CredentialModel.cs b/App/Models/CredentialModel.cs index d30f894..b38bbba 100644 --- a/App/Models/CredentialModel.cs +++ b/App/Models/CredentialModel.cs @@ -1,4 +1,5 @@ using System; +using Coder.Desktop.CoderSdk.Coder; namespace Coder.Desktop.App.Models; @@ -14,7 +15,7 @@ public enum CredentialState Valid, } -public class CredentialModel +public class CredentialModel : ICoderApiClientCredentialProvider { public CredentialState State { get; init; } = CredentialState.Unknown; @@ -33,4 +34,14 @@ public CredentialModel Clone() Username = Username, }; } + + public CoderApiClientCredential? GetCoderApiClientCredential() + { + if (State != CredentialState.Valid) return null; + return new CoderApiClientCredential + { + ApiToken = ApiToken!, + CoderUrl = CoderUrl!, + }; + } } diff --git a/App/Services/HostnameSuffixGetter.cs b/App/Services/HostnameSuffixGetter.cs new file mode 100644 index 0000000..3816623 --- /dev/null +++ b/App/Services/HostnameSuffixGetter.cs @@ -0,0 +1,144 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Coder.Desktop.App.Models; +using Coder.Desktop.CoderSdk.Coder; +using Coder.Desktop.Vpn.Utilities; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.App.Services; + +public interface IHostnameSuffixGetter +{ + public event EventHandler SuffixChanged; + + public string GetCachedSuffix(); +} + +public class HostnameSuffixGetter : IHostnameSuffixGetter +{ + private const string DefaultSuffix = ".coder"; + + private readonly ICredentialManager _credentialManager; + private readonly ICoderApiClientFactory _clientFactory; + private readonly ILogger _logger; + + // _lock protects all private (non-readonly) values + private readonly RaiiSemaphoreSlim _lock = new(1, 1); + private string _domainSuffix = DefaultSuffix; + private bool _dirty = false; + private bool _getInProgress = false; + private CredentialModel _credentialModel = new() { State = CredentialState.Invalid }; + + public event EventHandler? SuffixChanged; + + public HostnameSuffixGetter(ICredentialManager credentialManager, ICoderApiClientFactory apiClientFactory, + ILogger logger) + { + _credentialManager = credentialManager; + _clientFactory = apiClientFactory; + _logger = logger; + credentialManager.CredentialsChanged += HandleCredentialsChanged; + HandleCredentialsChanged(this, _credentialManager.GetCachedCredentials()); + } + + ~HostnameSuffixGetter() + { + _credentialManager.CredentialsChanged -= HandleCredentialsChanged; + } + + private void HandleCredentialsChanged(object? sender, CredentialModel credentials) + { + using var _ = _lock.Lock(); + _logger.LogDebug("credentials updated with state {state}", credentials.State); + _credentialModel = credentials; + if (credentials.State != CredentialState.Valid) return; + + _dirty = true; + if (!_getInProgress) + { + _getInProgress = true; + Task.Run(Refresh).ContinueWith(MaybeRefreshAgain); + } + } + + private async Task Refresh() + { + _logger.LogDebug("refreshing domain suffix"); + CredentialModel credentials; + using (_ = await _lock.LockAsync()) + { + credentials = _credentialModel; + if (credentials.State != CredentialState.Valid) + { + _logger.LogDebug("abandoning refresh because credentials are now invalid"); + return; + } + + _dirty = false; + } + + var client = _clientFactory.Create(credentials); + using var timeoutSrc = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + var connInfo = await client.GetAgentConnectionInfoGeneric(timeoutSrc.Token); + + // older versions of Coder might not set this + var suffix = string.IsNullOrEmpty(connInfo.HostnameSuffix) + ? DefaultSuffix + // and, it doesn't include the leading dot. + : "." + connInfo.HostnameSuffix; + + var changed = false; + using (_ = await _lock.LockAsync(CancellationToken.None)) + { + if (_domainSuffix != suffix) changed = true; + _domainSuffix = suffix; + } + + if (changed) + { + _logger.LogInformation("got new domain suffix '{suffix}'", suffix); + // grab a local copy of the EventHandler to avoid TOCTOU race on the `?.` null-check + var del = SuffixChanged; + del?.Invoke(this, suffix); + } + else + { + _logger.LogDebug("domain suffix unchanged '{suffix}'", suffix); + } + } + + private async Task MaybeRefreshAgain(Task prev) + { + if (prev.IsFaulted) + { + _logger.LogError(prev.Exception, "failed to query domain suffix"); + // back off here before retrying. We're just going to use a fixed, long + // delay since this just affects UI stuff; we're not in a huge rush as + // long as we eventually get the right value. + await Task.Delay(TimeSpan.FromSeconds(10)); + } + + using var l = await _lock.LockAsync(CancellationToken.None); + if ((_dirty || prev.IsFaulted) && _credentialModel.State == CredentialState.Valid) + { + // we still have valid credentials and we're either dirty or the last Get failed. + _logger.LogDebug("retrying domain suffix query"); + _ = Task.Run(Refresh).ContinueWith(MaybeRefreshAgain); + return; + } + + // Getting here means either the credentials are not valid or we don't need to + // refresh anyway. + // The next time we get new, valid credentials, HandleCredentialsChanged will kick off + // a new Refresh + _getInProgress = false; + return; + } + + public string GetCachedSuffix() + { + using var _ = _lock.Lock(); + return _domainSuffix; + } +} diff --git a/App/ViewModels/TrayWindowViewModel.cs b/App/ViewModels/TrayWindowViewModel.cs index 1dccab0..cfa5163 100644 --- a/App/ViewModels/TrayWindowViewModel.cs +++ b/App/ViewModels/TrayWindowViewModel.cs @@ -35,6 +35,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost private readonly IRpcController _rpcController; private readonly ICredentialManager _credentialManager; private readonly IAgentViewModelFactory _agentViewModelFactory; + private readonly IHostnameSuffixGetter _hostnameSuffixGetter; private FileSyncListWindow? _fileSyncListWindow; @@ -91,15 +92,14 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [ObservableProperty] public partial string DashboardUrl { get; set; } = DefaultDashboardUrl; - private string _hostnameSuffix = DefaultHostnameSuffix; - public TrayWindowViewModel(IServiceProvider services, IRpcController rpcController, - ICredentialManager credentialManager, IAgentViewModelFactory agentViewModelFactory) + ICredentialManager credentialManager, IAgentViewModelFactory agentViewModelFactory, IHostnameSuffixGetter hostnameSuffixGetter) { _services = services; _rpcController = rpcController; _credentialManager = credentialManager; _agentViewModelFactory = agentViewModelFactory; + _hostnameSuffixGetter = hostnameSuffixGetter; // Since the property value itself never changes, we add event // listeners for the underlying collection changing instead. @@ -139,6 +139,9 @@ public void Initialize(DispatcherQueue dispatcherQueue) _credentialManager.CredentialsChanged += (_, credentialModel) => UpdateFromCredentialModel(credentialModel); UpdateFromCredentialModel(_credentialManager.GetCachedCredentials()); + + _hostnameSuffixGetter.SuffixChanged += (_, suffix) => HandleHostnameSuffixChanged(suffix); + HandleHostnameSuffixChanged(_hostnameSuffixGetter.GetCachedSuffix()); } private void UpdateFromRpcModel(RpcModel rpcModel) @@ -195,7 +198,7 @@ private void UpdateFromRpcModel(RpcModel rpcModel) this, uuid, fqdn, - _hostnameSuffix, + _hostnameSuffixGetter.GetCachedSuffix(), connectionStatus, credentialModel.CoderUrl, workspace?.Name)); @@ -214,7 +217,7 @@ private void UpdateFromRpcModel(RpcModel rpcModel) // Workspace ID is fine as a stand-in here, it shouldn't // conflict with any agent IDs. uuid, - _hostnameSuffix, + _hostnameSuffixGetter.GetCachedSuffix(), AgentConnectionStatus.Gray, credentialModel.CoderUrl, workspace.Name)); @@ -273,6 +276,22 @@ private void UpdateFromCredentialModel(CredentialModel credentialModel) DashboardUrl = credentialModel.CoderUrl?.ToString() ?? DefaultDashboardUrl; } + private void HandleHostnameSuffixChanged(string suffix) + { + // Ensure we're on the UI thread. + if (_dispatcherQueue == null) return; + if (!_dispatcherQueue.HasThreadAccess) + { + _dispatcherQueue.TryEnqueue(() => HandleHostnameSuffixChanged(suffix)); + return; + } + + foreach (var agent in Agents) + { + agent.ConfiguredHostnameSuffix = suffix; + } + } + public void VpnSwitch_Toggled(object sender, RoutedEventArgs e) { if (sender is not ToggleSwitch toggleSwitch) return; diff --git a/CoderSdk/Coder/CoderApiClient.cs b/CoderSdk/Coder/CoderApiClient.cs index 15845bb..a24f364 100644 --- a/CoderSdk/Coder/CoderApiClient.cs +++ b/CoderSdk/Coder/CoderApiClient.cs @@ -49,6 +49,7 @@ public partial interface ICoderApiClient public void SetSessionToken(string token); } +[JsonSerializable(typeof(AgentConnectionInfo))] [JsonSerializable(typeof(BuildInfo))] [JsonSerializable(typeof(Response))] [JsonSerializable(typeof(User))] diff --git a/CoderSdk/Coder/WorkspaceAgents.cs b/CoderSdk/Coder/WorkspaceAgents.cs index d566286..9a7e6ff 100644 --- a/CoderSdk/Coder/WorkspaceAgents.cs +++ b/CoderSdk/Coder/WorkspaceAgents.cs @@ -3,6 +3,14 @@ namespace Coder.Desktop.CoderSdk.Coder; public partial interface ICoderApiClient { public Task GetWorkspaceAgent(string id, CancellationToken ct = default); + public Task GetAgentConnectionInfoGeneric(CancellationToken ct = default); +} + +public class AgentConnectionInfo +{ + public string HostnameSuffix { get; set; } = string.Empty; + // note that we're leaving out several fields including the DERP Map because + // we don't use that information, and it's a complex object to define. } public class WorkspaceAgent @@ -35,4 +43,9 @@ public Task GetWorkspaceAgent(string id, CancellationToken ct = { return SendRequestNoBodyAsync(HttpMethod.Get, "/api/v2/workspaceagents/" + id, ct); } + + public Task GetAgentConnectionInfoGeneric(CancellationToken ct = default) + { + return SendRequestNoBodyAsync(HttpMethod.Get, "/api/v2/workspaceagents/connection", ct); + } } diff --git a/Tests.App/Services/HostnameSuffixGetterTest.cs b/Tests.App/Services/HostnameSuffixGetterTest.cs new file mode 100644 index 0000000..9897d98 --- /dev/null +++ b/Tests.App/Services/HostnameSuffixGetterTest.cs @@ -0,0 +1,121 @@ +using System.ComponentModel.DataAnnotations; +using Coder.Desktop.App.Models; +using Coder.Desktop.App.Services; +using Coder.Desktop.CoderSdk.Coder; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Moq; +using Serilog; + +namespace Coder.Desktop.Tests.App.Services; + +[TestFixture] +public class HostnameSuffixGetterTest +{ + const string coderUrl = "https://coder.test/"; + + [SetUp] + public void SetupMocks() + { + Log.Logger = new LoggerConfiguration().MinimumLevel.Debug().WriteTo.NUnitOutput().CreateLogger(); + var builder = Host.CreateApplicationBuilder(); + builder.Services.AddSerilog(); + _logger = (ILogger)builder.Build().Services + .GetService(typeof(ILogger))!; + + _mCoderApiClientFactory = new Mock(MockBehavior.Strict); + _mCredentialManager = new Mock(MockBehavior.Strict); + _mCoderApiClient = new Mock(MockBehavior.Strict); + _mCoderApiClientFactory.Setup(m => m.Create(It.IsAny())) + .Returns(_mCoderApiClient.Object); + } + + private Mock _mCoderApiClientFactory; + private Mock _mCredentialManager; + private Mock _mCoderApiClient; + private ILogger _logger; + + [Test(Description = "Mainline no errors")] + [CancelAfter(10_000)] + public async Task Mainline(CancellationToken ct) + { + _mCredentialManager.Setup(m => m.GetCachedCredentials()) + .Returns(new CredentialModel() { State = CredentialState.Invalid }); + var hostnameSuffixGetter = + new HostnameSuffixGetter(_mCredentialManager.Object, _mCoderApiClientFactory.Object, _logger); + + // initially, we return the default + Assert.That(hostnameSuffixGetter.GetCachedSuffix(), Is.EqualTo(".coder")); + + // subscribed to suffix changes + var suffixCompletion = new TaskCompletionSource(); + hostnameSuffixGetter.SuffixChanged += (_, suffix) => suffixCompletion.SetResult(suffix); + + // set the client to return "test" as the suffix + _mCoderApiClient.Setup(m => m.SetSessionToken("test-token")); + _mCoderApiClient.Setup(m => m.GetAgentConnectionInfoGeneric(It.IsAny())) + .Returns(Task.FromResult(new AgentConnectionInfo() { HostnameSuffix = "test" })); + + _mCredentialManager.Raise(m => m.CredentialsChanged += null, _mCredentialManager.Object, new CredentialModel + { + State = CredentialState.Valid, + CoderUrl = new Uri(coderUrl), + ApiToken = "test-token", + }); + var gotSuffix = await TaskOrCancellation(suffixCompletion.Task, ct); + Assert.That(gotSuffix, Is.EqualTo(".test")); + + // now, we should return the .test domain going forward + Assert.That(hostnameSuffixGetter.GetCachedSuffix(), Is.EqualTo(".test")); + } + + [Test(Description = "Retries if error")] + [CancelAfter(30_000)] + // TODO: make this test not have to actually wait for the retry. + public async Task RetryError(CancellationToken ct) + { + _mCredentialManager.Setup(m => m.GetCachedCredentials()) + .Returns(new CredentialModel() { State = CredentialState.Invalid }); + var hostnameSuffixGetter = + new HostnameSuffixGetter(_mCredentialManager.Object, _mCoderApiClientFactory.Object, _logger); + + // subscribed to suffix changes + var suffixCompletion = new TaskCompletionSource(); + hostnameSuffixGetter.SuffixChanged += (_, suffix) => suffixCompletion.SetResult(suffix); + + // set the client to fail once, then return successfully + _mCoderApiClient.Setup(m => m.SetSessionToken("test-token")); + var connectionInfoCompletion = new TaskCompletionSource(); + _mCoderApiClient.SetupSequence(m => m.GetAgentConnectionInfoGeneric(It.IsAny())) + .Returns(Task.FromException(new Exception("a bad thing happened"))) + .Returns(Task.FromResult(new AgentConnectionInfo() { HostnameSuffix = "test" })); + + _mCredentialManager.Raise(m => m.CredentialsChanged += null, _mCredentialManager.Object, new CredentialModel + { + State = CredentialState.Valid, + CoderUrl = new Uri(coderUrl), + ApiToken = "test-token", + }); + var gotSuffix = await TaskOrCancellation(suffixCompletion.Task, ct); + Assert.That(gotSuffix, Is.EqualTo(".test")); + + // now, we should return the .test domain going forward + Assert.That(hostnameSuffixGetter.GetCachedSuffix(), Is.EqualTo(".test")); + } + + /// + /// TaskOrCancellation waits for either the task to complete, or the given token to be canceled. + /// + internal static async Task TaskOrCancellation(Task task, + CancellationToken cancellationToken) + { + var cancellationTask = new TaskCompletionSource(); + await using (cancellationToken.Register(() => cancellationTask.TrySetCanceled())) + { + // Wait for either the task or the cancellation + var completedTask = await Task.WhenAny(task, cancellationTask.Task); + // Await to propagate exceptions, if any + return await completedTask; + } + } +}