Skip to content

feat: add support for RDP URIs #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions App/App.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public partial class App : Application
#endif

private readonly ILogger<App> _logger;
private readonly IUriHandler _uriHandler;

public App()
{
Expand Down Expand Up @@ -72,6 +73,8 @@ public App()
.Bind(builder.Configuration.GetSection(MutagenControllerConfigSection));
services.AddSingleton<ISyncSessionController, MutagenController>();
services.AddSingleton<IUserNotifier, UserNotifier>();
services.AddSingleton<IRdpConnector, RdpConnector>();
services.AddSingleton<IUriHandler, UriHandler>();

// SignInWindow views and view models
services.AddTransient<SignInViewModel>();
Expand All @@ -98,6 +101,7 @@ public App()

_services = services.BuildServiceProvider();
_logger = (ILogger<App>)_services.GetService(typeof(ILogger<App>))!;
_uriHandler = (IUriHandler)_services.GetService(typeof(IUriHandler))!;

InitializeComponent();
}
Expand Down Expand Up @@ -190,7 +194,19 @@ public void OnActivated(object? sender, AppActivationArguments args)
_logger.LogWarning("URI activation with null data");
return;
}
HandleURIActivation(protoArgs.Uri);

// don't need to wait for it to complete.
_uriHandler.HandleUri(protoArgs.Uri).ContinueWith(t =>
{
if (t.Exception != null)
{
// don't log query params, as they contain secrets.
_logger.LogError(t.Exception,
"unhandled exception while processing URI coder://{authority}{path}",
protoArgs.Uri.Authority, protoArgs.Uri.AbsolutePath);
}
});

break;

case ExtendedActivationKind.AppNotification:
Expand All @@ -204,12 +220,6 @@ public void OnActivated(object? sender, AppActivationArguments args)
}
}

public void HandleURIActivation(Uri uri)
{
// don't log the query string as that's where we include some sensitive information like passwords
_logger.LogInformation("handling URI activation for {path}", uri.AbsolutePath);
}

public void HandleNotification(AppNotificationManager? sender, AppNotificationActivatedEventArgs args)
{
// right now, we don't do anything other than log
Expand Down
218 changes: 141 additions & 77 deletions App/Services/CredentialManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ public WindowsCredentialBackend(string credentialsTargetName)

public Task<RawCredentials?> ReadCredentials(CancellationToken ct = default)
{
var raw = NativeApi.ReadCredentials(_credentialsTargetName);
var raw = Wincred.ReadCredentials(_credentialsTargetName);
if (raw == null) return Task.FromResult<RawCredentials?>(null);

RawCredentials? credentials;
Expand All @@ -326,115 +326,179 @@ public WindowsCredentialBackend(string credentialsTargetName)
public Task WriteCredentials(RawCredentials credentials, CancellationToken ct = default)
{
var raw = JsonSerializer.Serialize(credentials, RawCredentialsJsonContext.Default.RawCredentials);
NativeApi.WriteCredentials(_credentialsTargetName, raw);
Wincred.WriteCredentials(_credentialsTargetName, raw);
return Task.CompletedTask;
}

public Task DeleteCredentials(CancellationToken ct = default)
{
NativeApi.DeleteCredentials(_credentialsTargetName);
Wincred.DeleteCredentials(_credentialsTargetName);
return Task.CompletedTask;
}

private static class NativeApi
}

/// <summary>
/// Wincred provides relatively low level wrapped calls to the Wincred.h native API.
/// </summary>
internal static class Wincred
{
private const int CredentialTypeGeneric = 1;
private const int CredentialTypeDomainPassword = 2;
private const int PersistenceTypeLocalComputer = 2;
private const int ErrorNotFound = 1168;
private const int CredMaxCredentialBlobSize = 5 * 512;
private const string PackageNTLM = "NTLM";

public static string? ReadCredentials(string targetName)
{
private const int CredentialTypeGeneric = 1;
private const int PersistenceTypeLocalComputer = 2;
private const int ErrorNotFound = 1168;
private const int CredMaxCredentialBlobSize = 5 * 512;
if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr))
{
var error = Marshal.GetLastWin32Error();
if (error == ErrorNotFound) return null;
throw new InvalidOperationException($"Failed to read credentials (Error {error})");
}

public static string? ReadCredentials(string targetName)
try
{
if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr))
{
var error = Marshal.GetLastWin32Error();
if (error == ErrorNotFound) return null;
throw new InvalidOperationException($"Failed to read credentials (Error {error})");
}
var cred = Marshal.PtrToStructure<CREDENTIALW>(credentialPtr);
return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char));
}
finally
{
CredFree(credentialPtr);
}
}

try
{
var cred = Marshal.PtrToStructure<CREDENTIAL>(credentialPtr);
return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char));
}
finally
public static void WriteCredentials(string targetName, string secret)
{
var byteCount = Encoding.Unicode.GetByteCount(secret);
if (byteCount > CredMaxCredentialBlobSize)
throw new ArgumentOutOfRangeException(nameof(secret),
$"The secret is greater than {CredMaxCredentialBlobSize} bytes");

var credentialBlob = Marshal.StringToHGlobalUni(secret);
var cred = new CREDENTIALW
{
Type = CredentialTypeGeneric,
TargetName = targetName,
CredentialBlobSize = byteCount,
CredentialBlob = credentialBlob,
Persist = PersistenceTypeLocalComputer,
};
try
{
if (!CredWriteW(ref cred, 0))
{
CredFree(credentialPtr);
var error = Marshal.GetLastWin32Error();
throw new InvalidOperationException($"Failed to write credentials (Error {error})");
}
}

public static void WriteCredentials(string targetName, string secret)
finally
{
var byteCount = Encoding.Unicode.GetByteCount(secret);
if (byteCount > CredMaxCredentialBlobSize)
throw new ArgumentOutOfRangeException(nameof(secret),
$"The secret is greater than {CredMaxCredentialBlobSize} bytes");
Marshal.FreeHGlobal(credentialBlob);
}
}

var credentialBlob = Marshal.StringToHGlobalUni(secret);
var cred = new CREDENTIAL
{
Type = CredentialTypeGeneric,
TargetName = targetName,
CredentialBlobSize = byteCount,
CredentialBlob = credentialBlob,
Persist = PersistenceTypeLocalComputer,
};
try
{
if (!CredWriteW(ref cred, 0))
{
var error = Marshal.GetLastWin32Error();
throw new InvalidOperationException($"Failed to write credentials (Error {error})");
}
}
finally
{
Marshal.FreeHGlobal(credentialBlob);
}
public static void DeleteCredentials(string targetName)
{
if (!CredDeleteW(targetName, CredentialTypeGeneric, 0))
{
var error = Marshal.GetLastWin32Error();
if (error == ErrorNotFound) return;
throw new InvalidOperationException($"Failed to delete credentials (Error {error})");
}
}

public static void WriteDomainCredentials(string domainName, string serverName, string username, string password)
{
var targetName = $"{domainName}/{serverName}";
var targetInfo = new CREDENTIAL_TARGET_INFORMATIONW
{
TargetName = targetName,
DnsServerName = serverName,
DnsDomainName = domainName,
PackageName = PackageNTLM,
};
var byteCount = Encoding.Unicode.GetByteCount(password);
if (byteCount > CredMaxCredentialBlobSize)
throw new ArgumentOutOfRangeException(nameof(password),
$"The secret is greater than {CredMaxCredentialBlobSize} bytes");

public static void DeleteCredentials(string targetName)
var credentialBlob = Marshal.StringToHGlobalUni(password);
var cred = new CREDENTIALW
{
if (!CredDeleteW(targetName, CredentialTypeGeneric, 0))
Type = CredentialTypeDomainPassword,
TargetName = targetName,
CredentialBlobSize = byteCount,
CredentialBlob = credentialBlob,
Persist = PersistenceTypeLocalComputer,
UserName = username,
};
try
{
if (!CredWriteDomainCredentialsW(ref targetInfo, ref cred, 0))
{
var error = Marshal.GetLastWin32Error();
if (error == ErrorNotFound) return;
throw new InvalidOperationException($"Failed to delete credentials (Error {error})");
throw new InvalidOperationException($"Failed to write credentials (Error {error})");
}
}
finally
{
Marshal.FreeHGlobal(credentialBlob);
}
}

[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr);
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr);

[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern bool CredWriteW([In] ref CREDENTIAL userCredential, [In] uint flags);
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern bool CredWriteW([In] ref CREDENTIALW userCredential, [In] uint flags);

[DllImport("Advapi32.dll", SetLastError = true)]
private static extern void CredFree([In] IntPtr cred);
[DllImport("Advapi32.dll", SetLastError = true)]
private static extern void CredFree([In] IntPtr cred);

[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern bool CredDeleteW(string target, int type, int flags);
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern bool CredDeleteW(string target, int type, int flags);

[StructLayout(LayoutKind.Sequential)]
private struct CREDENTIAL
{
public int Flags;
public int Type;
[DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern bool CredWriteDomainCredentialsW([In] ref CREDENTIAL_TARGET_INFORMATIONW target, [In] ref CREDENTIALW userCredential, [In] uint flags);

[MarshalAs(UnmanagedType.LPWStr)] public string TargetName;
[StructLayout(LayoutKind.Sequential)]
private struct CREDENTIALW
{
public int Flags;
public int Type;

[MarshalAs(UnmanagedType.LPWStr)] public string Comment;
[MarshalAs(UnmanagedType.LPWStr)] public string TargetName;

public long LastWritten;
public int CredentialBlobSize;
public IntPtr CredentialBlob;
public int Persist;
public int AttributeCount;
public IntPtr Attributes;
[MarshalAs(UnmanagedType.LPWStr)] public string Comment;

[MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias;
public long LastWritten;
public int CredentialBlobSize;
public IntPtr CredentialBlob;
public int Persist;
public int AttributeCount;
public IntPtr Attributes;

[MarshalAs(UnmanagedType.LPWStr)] public string UserName;
}
[MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias;

[MarshalAs(UnmanagedType.LPWStr)] public string UserName;
}

[StructLayout(LayoutKind.Sequential)]
private struct CREDENTIAL_TARGET_INFORMATIONW
{
[MarshalAs(UnmanagedType.LPWStr)] public string TargetName;
[MarshalAs(UnmanagedType.LPWStr)] public string NetbiosServerName;
[MarshalAs(UnmanagedType.LPWStr)] public string DnsServerName;
[MarshalAs(UnmanagedType.LPWStr)] public string NetbiosDomainName;
[MarshalAs(UnmanagedType.LPWStr)] public string DnsDomainName;
[MarshalAs(UnmanagedType.LPWStr)] public string DnsTreeName;
[MarshalAs(UnmanagedType.LPWStr)] public string PackageName;

public uint Flags;
public uint CredTypeCount;
public IntPtr CredTypes;
}
}
Loading
Loading