Skip to content

Commit c4ef279

Browse files
committed
Merge branch 'main' into dean/app-buttons
2 parents b3c2eaa + 9b8408d commit c4ef279

12 files changed

+619
-88
lines changed

App/App.xaml.cs

+16-7
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ public partial class App : Application
4242
#endif
4343

4444
private readonly ILogger<App> _logger;
45+
private readonly IUriHandler _uriHandler;
4546

4647
public App()
4748
{
@@ -76,6 +77,8 @@ public App()
7677
.Bind(builder.Configuration.GetSection(MutagenControllerConfigSection));
7778
services.AddSingleton<ISyncSessionController, MutagenController>();
7879
services.AddSingleton<IUserNotifier, UserNotifier>();
80+
services.AddSingleton<IRdpConnector, RdpConnector>();
81+
services.AddSingleton<IUriHandler, UriHandler>();
7982

8083
// SignInWindow views and view models
8184
services.AddTransient<SignInViewModel>();
@@ -104,6 +107,7 @@ public App()
104107

105108
_services = services.BuildServiceProvider();
106109
_logger = (ILogger<App>)_services.GetService(typeof(ILogger<App>))!;
110+
_uriHandler = (IUriHandler)_services.GetService(typeof(IUriHandler))!;
107111

108112
InitializeComponent();
109113
}
@@ -197,7 +201,18 @@ public void OnActivated(object? sender, AppActivationArguments args)
197201
return;
198202
}
199203

200-
HandleURIActivation(protoArgs.Uri);
204+
// don't need to wait for it to complete.
205+
_uriHandler.HandleUri(protoArgs.Uri).ContinueWith(t =>
206+
{
207+
if (t.Exception != null)
208+
{
209+
// don't log query params, as they contain secrets.
210+
_logger.LogError(t.Exception,
211+
"unhandled exception while processing URI coder://{authority}{path}",
212+
protoArgs.Uri.Authority, protoArgs.Uri.AbsolutePath);
213+
}
214+
});
215+
201216
break;
202217

203218
case ExtendedActivationKind.AppNotification:
@@ -211,12 +226,6 @@ public void OnActivated(object? sender, AppActivationArguments args)
211226
}
212227
}
213228

214-
public void HandleURIActivation(Uri uri)
215-
{
216-
// don't log the query string as that's where we include some sensitive information like passwords
217-
_logger.LogInformation("handling URI activation for {path}", uri.AbsolutePath);
218-
}
219-
220229
public void HandleNotification(AppNotificationManager? sender, AppNotificationActivatedEventArgs args)
221230
{
222231
// right now, we don't do anything other than log

App/Services/CredentialManager.cs

+141-77
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ public WindowsCredentialBackend(string credentialsTargetName)
322322

323323
public Task<RawCredentials?> ReadCredentials(CancellationToken ct = default)
324324
{
325-
var raw = NativeApi.ReadCredentials(_credentialsTargetName);
325+
var raw = Wincred.ReadCredentials(_credentialsTargetName);
326326
if (raw == null) return Task.FromResult<RawCredentials?>(null);
327327

328328
RawCredentials? credentials;
@@ -341,115 +341,179 @@ public WindowsCredentialBackend(string credentialsTargetName)
341341
public Task WriteCredentials(RawCredentials credentials, CancellationToken ct = default)
342342
{
343343
var raw = JsonSerializer.Serialize(credentials, RawCredentialsJsonContext.Default.RawCredentials);
344-
NativeApi.WriteCredentials(_credentialsTargetName, raw);
344+
Wincred.WriteCredentials(_credentialsTargetName, raw);
345345
return Task.CompletedTask;
346346
}
347347

348348
public Task DeleteCredentials(CancellationToken ct = default)
349349
{
350-
NativeApi.DeleteCredentials(_credentialsTargetName);
350+
Wincred.DeleteCredentials(_credentialsTargetName);
351351
return Task.CompletedTask;
352352
}
353353

354-
private static class NativeApi
354+
}
355+
356+
/// <summary>
357+
/// Wincred provides relatively low level wrapped calls to the Wincred.h native API.
358+
/// </summary>
359+
internal static class Wincred
360+
{
361+
private const int CredentialTypeGeneric = 1;
362+
private const int CredentialTypeDomainPassword = 2;
363+
private const int PersistenceTypeLocalComputer = 2;
364+
private const int ErrorNotFound = 1168;
365+
private const int CredMaxCredentialBlobSize = 5 * 512;
366+
private const string PackageNTLM = "NTLM";
367+
368+
public static string? ReadCredentials(string targetName)
355369
{
356-
private const int CredentialTypeGeneric = 1;
357-
private const int PersistenceTypeLocalComputer = 2;
358-
private const int ErrorNotFound = 1168;
359-
private const int CredMaxCredentialBlobSize = 5 * 512;
370+
if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr))
371+
{
372+
var error = Marshal.GetLastWin32Error();
373+
if (error == ErrorNotFound) return null;
374+
throw new InvalidOperationException($"Failed to read credentials (Error {error})");
375+
}
360376

361-
public static string? ReadCredentials(string targetName)
377+
try
362378
{
363-
if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr))
364-
{
365-
var error = Marshal.GetLastWin32Error();
366-
if (error == ErrorNotFound) return null;
367-
throw new InvalidOperationException($"Failed to read credentials (Error {error})");
368-
}
379+
var cred = Marshal.PtrToStructure<CREDENTIALW>(credentialPtr);
380+
return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char));
381+
}
382+
finally
383+
{
384+
CredFree(credentialPtr);
385+
}
386+
}
369387

370-
try
371-
{
372-
var cred = Marshal.PtrToStructure<CREDENTIAL>(credentialPtr);
373-
return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char));
374-
}
375-
finally
388+
public static void WriteCredentials(string targetName, string secret)
389+
{
390+
var byteCount = Encoding.Unicode.GetByteCount(secret);
391+
if (byteCount > CredMaxCredentialBlobSize)
392+
throw new ArgumentOutOfRangeException(nameof(secret),
393+
$"The secret is greater than {CredMaxCredentialBlobSize} bytes");
394+
395+
var credentialBlob = Marshal.StringToHGlobalUni(secret);
396+
var cred = new CREDENTIALW
397+
{
398+
Type = CredentialTypeGeneric,
399+
TargetName = targetName,
400+
CredentialBlobSize = byteCount,
401+
CredentialBlob = credentialBlob,
402+
Persist = PersistenceTypeLocalComputer,
403+
};
404+
try
405+
{
406+
if (!CredWriteW(ref cred, 0))
376407
{
377-
CredFree(credentialPtr);
408+
var error = Marshal.GetLastWin32Error();
409+
throw new InvalidOperationException($"Failed to write credentials (Error {error})");
378410
}
379411
}
380-
381-
public static void WriteCredentials(string targetName, string secret)
412+
finally
382413
{
383-
var byteCount = Encoding.Unicode.GetByteCount(secret);
384-
if (byteCount > CredMaxCredentialBlobSize)
385-
throw new ArgumentOutOfRangeException(nameof(secret),
386-
$"The secret is greater than {CredMaxCredentialBlobSize} bytes");
414+
Marshal.FreeHGlobal(credentialBlob);
415+
}
416+
}
387417

388-
var credentialBlob = Marshal.StringToHGlobalUni(secret);
389-
var cred = new CREDENTIAL
390-
{
391-
Type = CredentialTypeGeneric,
392-
TargetName = targetName,
393-
CredentialBlobSize = byteCount,
394-
CredentialBlob = credentialBlob,
395-
Persist = PersistenceTypeLocalComputer,
396-
};
397-
try
398-
{
399-
if (!CredWriteW(ref cred, 0))
400-
{
401-
var error = Marshal.GetLastWin32Error();
402-
throw new InvalidOperationException($"Failed to write credentials (Error {error})");
403-
}
404-
}
405-
finally
406-
{
407-
Marshal.FreeHGlobal(credentialBlob);
408-
}
418+
public static void DeleteCredentials(string targetName)
419+
{
420+
if (!CredDeleteW(targetName, CredentialTypeGeneric, 0))
421+
{
422+
var error = Marshal.GetLastWin32Error();
423+
if (error == ErrorNotFound) return;
424+
throw new InvalidOperationException($"Failed to delete credentials (Error {error})");
409425
}
426+
}
427+
428+
public static void WriteDomainCredentials(string domainName, string serverName, string username, string password)
429+
{
430+
var targetName = $"{domainName}/{serverName}";
431+
var targetInfo = new CREDENTIAL_TARGET_INFORMATIONW
432+
{
433+
TargetName = targetName,
434+
DnsServerName = serverName,
435+
DnsDomainName = domainName,
436+
PackageName = PackageNTLM,
437+
};
438+
var byteCount = Encoding.Unicode.GetByteCount(password);
439+
if (byteCount > CredMaxCredentialBlobSize)
440+
throw new ArgumentOutOfRangeException(nameof(password),
441+
$"The secret is greater than {CredMaxCredentialBlobSize} bytes");
410442

411-
public static void DeleteCredentials(string targetName)
443+
var credentialBlob = Marshal.StringToHGlobalUni(password);
444+
var cred = new CREDENTIALW
412445
{
413-
if (!CredDeleteW(targetName, CredentialTypeGeneric, 0))
446+
Type = CredentialTypeDomainPassword,
447+
TargetName = targetName,
448+
CredentialBlobSize = byteCount,
449+
CredentialBlob = credentialBlob,
450+
Persist = PersistenceTypeLocalComputer,
451+
UserName = username,
452+
};
453+
try
454+
{
455+
if (!CredWriteDomainCredentialsW(ref targetInfo, ref cred, 0))
414456
{
415457
var error = Marshal.GetLastWin32Error();
416-
if (error == ErrorNotFound) return;
417-
throw new InvalidOperationException($"Failed to delete credentials (Error {error})");
458+
throw new InvalidOperationException($"Failed to write credentials (Error {error})");
418459
}
419460
}
461+
finally
462+
{
463+
Marshal.FreeHGlobal(credentialBlob);
464+
}
465+
}
420466

421-
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
422-
private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr);
467+
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
468+
private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr);
423469

424-
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
425-
private static extern bool CredWriteW([In] ref CREDENTIAL userCredential, [In] uint flags);
470+
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
471+
private static extern bool CredWriteW([In] ref CREDENTIALW userCredential, [In] uint flags);
426472

427-
[DllImport("Advapi32.dll", SetLastError = true)]
428-
private static extern void CredFree([In] IntPtr cred);
473+
[DllImport("Advapi32.dll", SetLastError = true)]
474+
private static extern void CredFree([In] IntPtr cred);
429475

430-
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
431-
private static extern bool CredDeleteW(string target, int type, int flags);
476+
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
477+
private static extern bool CredDeleteW(string target, int type, int flags);
432478

433-
[StructLayout(LayoutKind.Sequential)]
434-
private struct CREDENTIAL
435-
{
436-
public int Flags;
437-
public int Type;
479+
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
480+
private static extern bool CredWriteDomainCredentialsW([In] ref CREDENTIAL_TARGET_INFORMATIONW target, [In] ref CREDENTIALW userCredential, [In] uint flags);
438481

439-
[MarshalAs(UnmanagedType.LPWStr)] public string TargetName;
482+
[StructLayout(LayoutKind.Sequential)]
483+
private struct CREDENTIALW
484+
{
485+
public int Flags;
486+
public int Type;
440487

441-
[MarshalAs(UnmanagedType.LPWStr)] public string Comment;
488+
[MarshalAs(UnmanagedType.LPWStr)] public string TargetName;
442489

443-
public long LastWritten;
444-
public int CredentialBlobSize;
445-
public IntPtr CredentialBlob;
446-
public int Persist;
447-
public int AttributeCount;
448-
public IntPtr Attributes;
490+
[MarshalAs(UnmanagedType.LPWStr)] public string Comment;
449491

450-
[MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias;
492+
public long LastWritten;
493+
public int CredentialBlobSize;
494+
public IntPtr CredentialBlob;
495+
public int Persist;
496+
public int AttributeCount;
497+
public IntPtr Attributes;
451498

452-
[MarshalAs(UnmanagedType.LPWStr)] public string UserName;
453-
}
499+
[MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias;
500+
501+
[MarshalAs(UnmanagedType.LPWStr)] public string UserName;
502+
}
503+
504+
[StructLayout(LayoutKind.Sequential)]
505+
private struct CREDENTIAL_TARGET_INFORMATIONW
506+
{
507+
[MarshalAs(UnmanagedType.LPWStr)] public string TargetName;
508+
[MarshalAs(UnmanagedType.LPWStr)] public string NetbiosServerName;
509+
[MarshalAs(UnmanagedType.LPWStr)] public string DnsServerName;
510+
[MarshalAs(UnmanagedType.LPWStr)] public string NetbiosDomainName;
511+
[MarshalAs(UnmanagedType.LPWStr)] public string DnsDomainName;
512+
[MarshalAs(UnmanagedType.LPWStr)] public string DnsTreeName;
513+
[MarshalAs(UnmanagedType.LPWStr)] public string PackageName;
514+
515+
public uint Flags;
516+
public uint CredTypeCount;
517+
public IntPtr CredTypes;
454518
}
455519
}

App/Services/RdpConnector.cs

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
using System;
2+
using System.Diagnostics;
3+
using System.Threading;
4+
using System.Threading.Tasks;
5+
using Microsoft.Extensions.Logging;
6+
7+
namespace Coder.Desktop.App.Services;
8+
9+
public struct RdpCredentials(string username, string password)
10+
{
11+
public readonly string Username = username;
12+
public readonly string Password = password;
13+
}
14+
15+
public interface IRdpConnector
16+
{
17+
public const int DefaultPort = 3389;
18+
19+
public void WriteCredentials(string fqdn, RdpCredentials credentials);
20+
21+
public Task Connect(string fqdn, int port = DefaultPort, CancellationToken ct = default);
22+
}
23+
24+
public class RdpConnector(ILogger<RdpConnector> logger) : IRdpConnector
25+
{
26+
// Remote Desktop always uses TERMSRV as the domain; RDP is a part of Windows "Terminal Services".
27+
private const string RdpDomain = "TERMSRV";
28+
29+
public void WriteCredentials(string fqdn, RdpCredentials credentials)
30+
{
31+
// writing credentials is idempotent for the same domain and server name.
32+
Wincred.WriteDomainCredentials(RdpDomain, fqdn, credentials.Username, credentials.Password);
33+
logger.LogDebug("wrote domain credential for {serverName} with username {username}", fqdn,
34+
credentials.Username);
35+
return;
36+
}
37+
38+
public Task Connect(string fqdn, int port = IRdpConnector.DefaultPort, CancellationToken ct = default)
39+
{
40+
// use mstsc to launch the RDP connection
41+
var mstscProc = new Process();
42+
mstscProc.StartInfo.FileName = "mstsc";
43+
var args = $"/v {fqdn}";
44+
if (port != IRdpConnector.DefaultPort)
45+
{
46+
args = $"/v {fqdn}:{port}";
47+
}
48+
49+
mstscProc.StartInfo.Arguments = args;
50+
mstscProc.StartInfo.CreateNoWindow = true;
51+
mstscProc.StartInfo.UseShellExecute = false;
52+
try
53+
{
54+
if (!mstscProc.Start())
55+
throw new InvalidOperationException("Failed to start mstsc, Start returned false");
56+
}
57+
catch (Exception e)
58+
{
59+
logger.LogWarning(e, "mstsc failed to start");
60+
61+
try
62+
{
63+
mstscProc.Kill();
64+
}
65+
catch
66+
{
67+
// ignored, the process likely doesn't exist
68+
}
69+
70+
mstscProc.Dispose();
71+
throw;
72+
}
73+
74+
return mstscProc.WaitForExitAsync(ct);
75+
}
76+
}

0 commit comments

Comments
 (0)