diff --git a/Coder.Desktop.sln b/Coder.Desktop.sln
index 342963b..c5fc598 100644
--- a/Coder.Desktop.sln
+++ b/Coder.Desktop.sln
@@ -5,7 +5,15 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn", "Vpn\Vp
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Proto", "Vpn.Proto\Vpn.Proto.csproj", "{318E78BB-E6AD-410F-8F3F-B680F6880293}"
EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests", "Tests\Tests.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}"
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Service", "Vpn.Service\Vpn.Service.csproj", "{51B91794-0A2A-4F84-9935-8E17DD2AB260}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn", "Tests.Vpn\Tests.Vpn.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn.Proto", "Tests.Vpn.Proto\Tests.Vpn.Proto.csproj", "{AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn.Service", "Tests.Vpn.Service\Tests.Vpn.Service.csproj", "{D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.CoderSdk", "CoderSdk\CoderSdk.csproj", "{A3D2B2B3-A051-46BD-A190-5487A9F24C28}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -21,9 +29,25 @@ Global
{318E78BB-E6AD-410F-8F3F-B680F6880293}.Debug|Any CPU.Build.0 = Debug|Any CPU
{318E78BB-E6AD-410F-8F3F-B680F6880293}.Release|Any CPU.ActiveCfg = Release|Any CPU
{318E78BB-E6AD-410F-8F3F-B680F6880293}.Release|Any CPU.Build.0 = Release|Any CPU
+ {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Release|Any CPU.Build.0 = Release|Any CPU
{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.Build.0 = Release|Any CPU
+ {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
EndGlobal
diff --git a/Coder.Desktop.sln.DotSettings b/Coder.Desktop.sln.DotSettings
index 636b95d..176e490 100644
--- a/Coder.Desktop.sln.DotSettings
+++ b/Coder.Desktop.sln.DotSettings
@@ -253,4 +253,7 @@
</Patterns>
True
+ True
+ True
+ True
True
\ No newline at end of file
diff --git a/CoderSdk/CoderApiClient.cs b/CoderSdk/CoderApiClient.cs
new file mode 100644
index 0000000..90343f3
--- /dev/null
+++ b/CoderSdk/CoderApiClient.cs
@@ -0,0 +1,81 @@
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace CoderSdk;
+
+///
+/// Changes names from PascalCase to snake_case.
+///
+internal class SnakeCaseNamingPolicy : JsonNamingPolicy
+{
+ public override string ConvertName(string name)
+ {
+ return string.Concat(
+ name.Select((x, i) => i > 0 && char.IsUpper(x) ? "_" + char.ToLower(x) : char.ToLower(x).ToString())
+ );
+ }
+}
+
+///
+/// Provides a limited selection of API methods for a Coder instance.
+///
+public partial class CoderApiClient
+{
+ // TODO: allow adding headers
+ private readonly HttpClient _httpClient = new();
+ private readonly JsonSerializerOptions _jsonOptions;
+
+ public CoderApiClient(string baseUrl)
+ {
+ var url = new Uri(baseUrl, UriKind.Absolute);
+ if (url.PathAndQuery != "/")
+ throw new ArgumentException($"Base URL '{baseUrl}' must not contain a path", nameof(baseUrl));
+ _httpClient.BaseAddress = url;
+ _jsonOptions = new JsonSerializerOptions
+ {
+ PropertyNameCaseInsensitive = true,
+ PropertyNamingPolicy = new SnakeCaseNamingPolicy(),
+ DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
+ };
+ }
+
+ public CoderApiClient(string baseUrl, string token) : this(baseUrl)
+ {
+ SetSessionToken(token);
+ }
+
+ public void SetSessionToken(string token)
+ {
+ _httpClient.DefaultRequestHeaders.Remove("Coder-Session-Token");
+ _httpClient.DefaultRequestHeaders.Add("Coder-Session-Token", token);
+ }
+
+ private async Task SendRequestAsync(HttpMethod method, string path,
+ object? payload, CancellationToken ct = default)
+ {
+ try
+ {
+ var request = new HttpRequestMessage(method, path);
+
+ if (payload is not null)
+ {
+ var json = JsonSerializer.Serialize(payload, _jsonOptions);
+ request.Content = new StringContent(json, Encoding.UTF8, "application/json");
+ }
+
+ var res = await _httpClient.SendAsync(request, ct);
+ // TODO: this should be improved to try and parse a codersdk.Error response
+ res.EnsureSuccessStatusCode();
+
+ var content = await res.Content.ReadAsStringAsync(ct);
+ var data = JsonSerializer.Deserialize(content, _jsonOptions);
+ if (data is null) throw new JsonException("Deserialized response is null");
+ return data;
+ }
+ catch (Exception e)
+ {
+ throw new Exception($"API Request: {method} {path} (req body: {payload is not null})", e);
+ }
+ }
+}
diff --git a/CoderSdk/CoderSdk.csproj b/CoderSdk/CoderSdk.csproj
new file mode 100644
index 0000000..3a63532
--- /dev/null
+++ b/CoderSdk/CoderSdk.csproj
@@ -0,0 +1,9 @@
+
+
+
+ net8.0
+ enable
+ enable
+
+
+
diff --git a/CoderSdk/Deployment.cs b/CoderSdk/Deployment.cs
new file mode 100644
index 0000000..b00d49f
--- /dev/null
+++ b/CoderSdk/Deployment.cs
@@ -0,0 +1,22 @@
+namespace CoderSdk;
+
+public class BuildInfo
+{
+ public string ExternalUrl { get; set; } = "";
+ public string Version { get; set; } = "";
+ public string DashboardUrl { get; set; } = "";
+ public bool Telemetry { get; set; } = false;
+ public bool WorkspaceProxy { get; set; } = false;
+ public string AgentApiVersion { get; set; } = "";
+ public string ProvisionerApiVersion { get; set; } = "";
+ public string UpgradeMessage { get; set; } = "";
+ public string DeploymentId { get; set; } = "";
+}
+
+public partial class CoderApiClient
+{
+ public Task GetBuildInfo(CancellationToken ct = default)
+ {
+ return SendRequestAsync(HttpMethod.Get, "/api/v2/buildinfo", null, ct);
+ }
+}
diff --git a/CoderSdk/Users.cs b/CoderSdk/Users.cs
new file mode 100644
index 0000000..58ff474
--- /dev/null
+++ b/CoderSdk/Users.cs
@@ -0,0 +1,17 @@
+namespace CoderSdk;
+
+public class User
+{
+ public const string Me = "me";
+
+ // TODO: fill out more fields
+ public string Username { get; set; } = "";
+}
+
+public partial class CoderApiClient
+{
+ public Task GetUser(string user, CancellationToken ct = default)
+ {
+ return SendRequestAsync(HttpMethod.Get, $"/api/v2/users/{user}", null, ct);
+ }
+}
diff --git a/Tests/Vpn.Proto/RpcHeaderTest.cs b/Tests.Vpn.Proto/RpcHeaderTest.cs
similarity index 85%
rename from Tests/Vpn.Proto/RpcHeaderTest.cs
rename to Tests.Vpn.Proto/RpcHeaderTest.cs
index 8e19d0e..55edeea 100644
--- a/Tests/Vpn.Proto/RpcHeaderTest.cs
+++ b/Tests.Vpn.Proto/RpcHeaderTest.cs
@@ -11,14 +11,14 @@ public void Valid()
{
var headerStr = "codervpn manager 1.3,2.1";
var header = RpcHeader.Parse(headerStr);
- Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Manager));
+ Assert.That(header.Role, Is.EqualTo("manager"));
Assert.That(header.VersionList, Is.EqualTo(new RpcVersionList(new RpcVersion(1, 3), new RpcVersion(2, 1))));
Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n"));
Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n")));
headerStr = "codervpn tunnel 1.0";
header = RpcHeader.Parse(headerStr);
- Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Tunnel));
+ Assert.That(header.Role, Is.EqualTo("tunnel"));
Assert.That(header.VersionList, Is.EqualTo(new RpcVersionList(new RpcVersion(1, 0))));
Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n"));
Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n")));
@@ -35,7 +35,8 @@ public void ParseInvalid()
Assert.That(ex.Message, Does.Contain("Wrong number of parts"));
ex = Assert.Throws(() => RpcHeader.Parse("cats manager 1.0"));
Assert.That(ex.Message, Does.Contain("Invalid preamble"));
- ex = Assert.Throws(() => RpcHeader.Parse("codervpn cats 1.0"));
- Assert.That(ex.Message, Does.Contain("Unknown role 'cats'"));
+ // RpcHeader doesn't care about the role string as long as it isn't empty.
+ ex = Assert.Throws(() => RpcHeader.Parse("codervpn 1.0"));
+ Assert.That(ex.Message, Does.Contain("Invalid role in header string"));
}
}
diff --git a/Tests.Vpn.Proto/RpcMessageTest.cs b/Tests.Vpn.Proto/RpcMessageTest.cs
new file mode 100644
index 0000000..e254120
--- /dev/null
+++ b/Tests.Vpn.Proto/RpcMessageTest.cs
@@ -0,0 +1,39 @@
+using Coder.Desktop.Vpn.Proto;
+
+namespace Coder.Desktop.Tests.Vpn.Proto;
+
+[TestFixture]
+public class RpcRoleAttributeTest
+{
+ [Test]
+ public void Ok()
+ {
+ var role = new RpcRoleAttribute("manager");
+ Assert.That(role.Role, Is.EqualTo("manager"));
+ role = new RpcRoleAttribute("tunnel");
+ Assert.That(role.Role, Is.EqualTo("tunnel"));
+ role = new RpcRoleAttribute("service");
+ Assert.That(role.Role, Is.EqualTo("service"));
+ role = new RpcRoleAttribute("client");
+ Assert.That(role.Role, Is.EqualTo("client"));
+ }
+}
+
+[TestFixture]
+public class RpcMessageTest
+{
+ [Test]
+ public void GetRole()
+ {
+ // RpcMessage is not a supported message type and doesn't have an
+ // RpcRoleAttribute
+ var ex = Assert.Throws(() => _ = RpcMessage.GetRole());
+ Assert.That(ex.Message,
+ Does.Contain("Message type 'Coder.Desktop.Vpn.Proto.RPC' does not have a RpcRoleAttribute"));
+
+ Assert.That(ManagerMessage.GetRole(), Is.EqualTo("manager"));
+ Assert.That(TunnelMessage.GetRole(), Is.EqualTo("tunnel"));
+ Assert.That(ServiceMessage.GetRole(), Is.EqualTo("service"));
+ Assert.That(ClientMessage.GetRole(), Is.EqualTo("client"));
+ }
+}
diff --git a/Tests/Vpn.Proto/RpcVersionTest.cs b/Tests.Vpn.Proto/RpcVersionTest.cs
similarity index 100%
rename from Tests/Vpn.Proto/RpcVersionTest.cs
rename to Tests.Vpn.Proto/RpcVersionTest.cs
diff --git a/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj b/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj
new file mode 100644
index 0000000..54b7b33
--- /dev/null
+++ b/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj
@@ -0,0 +1,35 @@
+
+
+
+ Coder.Desktop.Tests.Vpn.Proto
+ net8.0
+ enable
+ enable
+
+ false
+ true
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs
new file mode 100644
index 0000000..ae3a0a0
--- /dev/null
+++ b/Tests.Vpn.Service/DownloaderTest.cs
@@ -0,0 +1,362 @@
+using System.Security.Cryptography;
+using System.Text;
+using Coder.Desktop.Vpn.Service;
+using Microsoft.Extensions.Logging.Abstractions;
+
+namespace Coder.Desktop.Tests.Vpn.Service;
+
+public class TestDownloadValidator : IDownloadValidator
+{
+ private readonly Exception _e;
+
+ public TestDownloadValidator(Exception e)
+ {
+ _e = e;
+ }
+
+ public Task ValidateAsync(string path, CancellationToken ct = default)
+ {
+ throw _e;
+ }
+}
+
+[TestFixture]
+public class AuthenticodeDownloadValidatorTest
+{
+ [Test(Description = "Test an unsigned binary")]
+ [CancelAfter(30_000)]
+ public void Unsigned(CancellationToken ct)
+ {
+ // TODO: this
+ }
+
+ [Test(Description = "Test an untrusted binary")]
+ [CancelAfter(30_000)]
+ public void Untrusted(CancellationToken ct)
+ {
+ // TODO: this
+ }
+
+ [Test(Description = "Test an binary with a detached signature (catalog file)")]
+ [CancelAfter(30_000)]
+ public void DifferentCertTrusted(CancellationToken ct)
+ {
+ // notepad.exe uses a catalog file for its signature.
+ var ex = Assert.ThrowsAsync(() =>
+ AuthenticodeDownloadValidator.Coder.ValidateAsync(@"C:\Windows\System32\notepad.exe", ct));
+ Assert.That(ex.Message,
+ Does.Contain("File is not signed with an embedded Authenticode signature: Kind=Catalog"));
+ }
+
+ [Test(Description = "Test a binary signed by a different certificate")]
+ [CancelAfter(30_000)]
+ public void DifferentCertUntrusted(CancellationToken ct)
+ {
+ // TODO: this
+ }
+
+ [Test(Description = "Test a binary signed by Coder's certificate")]
+ [CancelAfter(30_000)]
+ public async Task CoderSigned(CancellationToken ct)
+ {
+ // TODO: this
+ await Task.CompletedTask;
+ }
+}
+
+[TestFixture]
+public class AssemblyVersionDownloadValidatorTest
+{
+ [Test(Description = "No version on binary")]
+ [CancelAfter(30_000)]
+ public void NoVersion(CancellationToken ct)
+ {
+ // TODO: this
+ }
+
+ [Test(Description = "Version mismatch")]
+ [CancelAfter(30_000)]
+ public void VersionMismatch(CancellationToken ct)
+ {
+ // TODO: this
+ }
+
+ [Test(Description = "Version match")]
+ [CancelAfter(30_000)]
+ public async Task VersionMatch(CancellationToken ct)
+ {
+ // TODO: this
+ await Task.CompletedTask;
+ }
+}
+
+[TestFixture]
+public class CombinationDownloadValidatorTest
+{
+ [Test(Description = "All validators pass")]
+ [CancelAfter(30_000)]
+ public async Task AllPass(CancellationToken ct)
+ {
+ var validator = new CombinationDownloadValidator(
+ NullDownloadValidator.Instance,
+ NullDownloadValidator.Instance
+ );
+ await validator.ValidateAsync("test", ct);
+ }
+
+ [Test(Description = "A validator fails")]
+ [CancelAfter(30_000)]
+ public void Fail(CancellationToken ct)
+ {
+ var validator = new CombinationDownloadValidator(
+ NullDownloadValidator.Instance,
+ new TestDownloadValidator(new Exception("test exception"))
+ );
+ var ex = Assert.ThrowsAsync(() => validator.ValidateAsync("test", ct));
+ Assert.That(ex.Message, Is.EqualTo("test exception"));
+ }
+}
+
+[TestFixture]
+public class DownloaderTest
+{
+ // FYI, SetUp and TearDown get called before and after each test.
+ [SetUp]
+ public void Setup()
+ {
+ _tempDir = Path.Combine(Path.GetTempPath(), "Coder.Desktop.Tests.Vpn.Service_" + Path.GetRandomFileName());
+ Directory.CreateDirectory(_tempDir);
+ }
+
+ [TearDown]
+ public void TearDown()
+ {
+ Directory.Delete(_tempDir, true);
+ }
+
+ private string _tempDir;
+
+ private static TestHttpServer EchoServer()
+ {
+ // Create webserver that replies to `/xyz` with a test file containing
+ // `xyz`.
+ return new TestHttpServer(async ctx =>
+ {
+ // Get the path without the leading slash.
+ var path = ctx.Request.Url!.AbsolutePath[1..];
+ var pathBytes = Encoding.UTF8.GetBytes(path);
+
+ // If the client sends an If-None-Match header with the correct ETag,
+ // return 304 Not Modified.
+ var etag = "\"" + Convert.ToHexString(SHA1.HashData(pathBytes)).ToLower() + "\"";
+ if (ctx.Request.Headers["If-None-Match"] == etag)
+ {
+ ctx.Response.StatusCode = 304;
+ return;
+ }
+
+ ctx.Response.StatusCode = 200;
+ ctx.Response.Headers.Add("ETag", etag);
+ ctx.Response.ContentType = "text/plain";
+ ctx.Response.ContentLength64 = pathBytes.Length;
+ await ctx.Response.OutputStream.WriteAsync(pathBytes);
+ });
+ }
+
+ [Test(Description = "Perform a download")]
+ [CancelAfter(30_000)]
+ public async Task Download(CancellationToken ct)
+ {
+ using var httpServer = EchoServer();
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ var manager = new Downloader(NullLogger.Instance);
+ var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ NullDownloadValidator.Instance, ct);
+ await dlTask.Task;
+ Assert.That(dlTask.TotalBytes, Is.EqualTo(4));
+ Assert.That(dlTask.BytesRead, Is.EqualTo(4));
+ Assert.That(dlTask.Progress, Is.EqualTo(1));
+ Assert.That(dlTask.IsCompleted, Is.True);
+ Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
+ }
+
+ [Test(Description = "Download with custom headers")]
+ [CancelAfter(30_000)]
+ public async Task WithHeaders(CancellationToken ct)
+ {
+ using var httpServer = new TestHttpServer(ctx =>
+ {
+ Assert.That(ctx.Request.Headers["X-Custom-Header"], Is.EqualTo("custom-value"));
+ ctx.Response.StatusCode = 200;
+ });
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ var manager = new Downloader(NullLogger.Instance);
+ var req = new HttpRequestMessage(HttpMethod.Get, url);
+ req.Headers.Add("X-Custom-Header", "custom-value");
+ var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);
+ await dlTask.Task;
+ }
+
+ [Test(Description = "Perform a download against an existing identical file")]
+ [CancelAfter(30_000)]
+ public async Task DownloadExisting(CancellationToken ct)
+ {
+ using var httpServer = EchoServer();
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ // Create the destination file with a very old timestamp.
+ await File.WriteAllTextAsync(destPath, "test", ct);
+ File.SetLastWriteTime(destPath, DateTime.Now - TimeSpan.FromDays(365));
+
+ var manager = new Downloader(NullLogger.Instance);
+ var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ NullDownloadValidator.Instance, ct);
+ await dlTask.Task;
+ Assert.That(dlTask.BytesRead, Is.Zero);
+ Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
+ Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1)));
+ }
+
+ [Test(Description = "Perform a download against an existing file with different content")]
+ [CancelAfter(30_000)]
+ public async Task DownloadExistingDifferentContent(CancellationToken ct)
+ {
+ using var httpServer = EchoServer();
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ // Create the destination file with a very old timestamp.
+ await File.WriteAllTextAsync(destPath, "TEST", ct);
+ File.SetLastWriteTime(destPath, DateTime.Now - TimeSpan.FromDays(365));
+
+ var manager = new Downloader(NullLogger.Instance);
+ var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ NullDownloadValidator.Instance, ct);
+ await dlTask.Task;
+ Assert.That(dlTask.BytesRead, Is.EqualTo(4));
+ Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
+ Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1)));
+ }
+
+ [Test(Description = "Unexpected response code from server")]
+ [CancelAfter(30_000)]
+ public void UnexpectedResponseCode(CancellationToken ct)
+ {
+ using var httpServer = new TestHttpServer(ctx => { ctx.Response.StatusCode = 404; });
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ var manager = new Downloader(NullLogger.Instance);
+ // The "outer" Task should fail.
+ var ex = Assert.ThrowsAsync(async () =>
+ await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ NullDownloadValidator.Instance, ct));
+ Assert.That(ex.Message, Does.Contain("404"));
+ }
+
+ // TODO: It would be nice to have a test that tests mismatched
+ // Content-Length, but it seems HttpListener doesn't allow that.
+
+ [Test(Description = "Mismatched ETag")]
+ [CancelAfter(30_000)]
+ public async Task MismatchedETag(CancellationToken ct)
+ {
+ using var httpServer = new TestHttpServer(ctx =>
+ {
+ ctx.Response.StatusCode = 200;
+ ctx.Response.Headers.Add("ETag", "\"beef\"");
+ });
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ var manager = new Downloader(NullLogger.Instance);
+ // The "inner" Task should fail.
+ var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ NullDownloadValidator.Instance, ct);
+ var ex = Assert.ThrowsAsync(async () => await dlTask.Task);
+ Assert.That(ex.Message, Does.Contain("ETag does not match SHA1 hash of downloaded file").And.Contains("beef"));
+ }
+
+ [Test(Description = "Timeout on response headers")]
+ [CancelAfter(30_000)]
+ public void CancelledOuter(CancellationToken ct)
+ {
+ using var httpServer = new TestHttpServer(async _ => { await Task.Delay(TimeSpan.FromSeconds(5), ct); });
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ var manager = new Downloader(NullLogger.Instance);
+ // The "outer" Task should fail.
+ var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token;
+ Assert.ThrowsAsync(
+ async () => await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ NullDownloadValidator.Instance, smallerCt));
+ }
+
+ [Test(Description = "Timeout on response body")]
+ [CancelAfter(30_000)]
+ public async Task CancelledInner(CancellationToken ct)
+ {
+ using var httpServer = new TestHttpServer(async ctx =>
+ {
+ ctx.Response.StatusCode = 200;
+ await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
+ await ctx.Response.OutputStream.FlushAsync(ct);
+ await Task.Delay(TimeSpan.FromSeconds(5), ct);
+ });
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ var manager = new Downloader(NullLogger.Instance);
+ // The "inner" Task should fail.
+ var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token;
+ var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ NullDownloadValidator.Instance, smallerCt);
+ var ex = Assert.ThrowsAsync(async () => await dlTask.Task);
+ Assert.That(ex.CancellationToken, Is.EqualTo(smallerCt));
+ }
+
+ [Test(Description = "Validation failure")]
+ [CancelAfter(30_000)]
+ public async Task ValidationFailure(CancellationToken ct)
+ {
+ using var httpServer = EchoServer();
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+
+ var manager = new Downloader(NullLogger.Instance);
+ var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ new TestDownloadValidator(new Exception("test exception")), ct);
+
+ var ex = Assert.ThrowsAsync(async () => await dlTask.Task);
+ Assert.That(ex.Message, Does.Contain("Downloaded file failed validation"));
+ Assert.That(ex.InnerException, Is.Not.Null);
+ Assert.That(ex.InnerException!.Message, Is.EqualTo("test exception"));
+ }
+
+ [Test(Description = "Validation failure on existing file")]
+ [CancelAfter(30_000)]
+ public async Task ValidationFailureExistingFile(CancellationToken ct)
+ {
+ using var httpServer = EchoServer();
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+ await File.WriteAllTextAsync(destPath, "test", ct);
+
+ var manager = new Downloader(NullLogger.Instance);
+ // The "outer" Task should fail because the inner task never starts.
+ var ex = Assert.ThrowsAsync(async () =>
+ {
+ await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
+ new TestDownloadValidator(new Exception("test exception")), ct);
+ });
+ Assert.That(ex.Message, Does.Contain("Existing file failed validation"));
+ Assert.That(ex.InnerException, Is.Not.Null);
+ Assert.That(ex.InnerException!.Message, Is.EqualTo("test exception"));
+ }
+}
diff --git a/Tests.Vpn.Service/TestHttpServer.cs b/Tests.Vpn.Service/TestHttpServer.cs
new file mode 100644
index 0000000..d33697f
--- /dev/null
+++ b/Tests.Vpn.Service/TestHttpServer.cs
@@ -0,0 +1,106 @@
+using System.Net;
+using System.Text;
+
+namespace Coder.Desktop.Tests.Vpn.Service;
+
+public class TestHttpServer : IDisposable
+{
+ // IANA suggested range for dynamic or private ports
+ private const int MinPort = 49215;
+ private const int MaxPort = 65535;
+ private const int PortRangeSize = MaxPort - MinPort + 1;
+
+ private readonly CancellationTokenSource _cts = new();
+ private readonly Func _handler;
+ private readonly HttpListener _listener;
+ private readonly Thread _listenerThread;
+
+ public string BaseUrl { get; private set; }
+
+ public TestHttpServer(Action handler) : this(ctx =>
+ {
+ handler(ctx);
+ return Task.CompletedTask;
+ })
+ {
+ }
+
+ public TestHttpServer(Func handler)
+ {
+ _handler = handler;
+
+ // Yes, this is the best way to get an unused port using HttpListener.
+ // It sucks.
+ //
+ // This implementation picks a random start point between MinPort and
+ // MaxPort, then iterates through the entire range (wrapping around at
+ // the end) until it finds a free port.
+ var port = 0;
+ var random = new Random();
+ var startPort = random.Next(MinPort, MaxPort + 1);
+ for (var i = 0; i < PortRangeSize; i++)
+ {
+ port = MinPort + (startPort - MinPort + i) % PortRangeSize;
+
+ var attempt = new HttpListener();
+ attempt.Prefixes.Add($"http://localhost:{port}/");
+ try
+ {
+ attempt.Start();
+ _listener = attempt;
+ break;
+ }
+ catch
+ {
+ // Listener disposes itself on failure
+ }
+ }
+
+ if (_listener == null || port == 0)
+ throw new InvalidOperationException("Could not find a free port to listen on");
+ BaseUrl = $"http://localhost:{port}";
+
+ _listenerThread = new Thread(() =>
+ {
+ while (!_cts.Token.IsCancellationRequested)
+ try
+ {
+ var context = _listener.GetContext();
+ Task.Run(() => HandleRequest(context));
+ }
+ catch (HttpListenerException) when (_cts.Token.IsCancellationRequested)
+ {
+ break;
+ }
+ });
+
+ _listenerThread.Start();
+ }
+
+ public void Dispose()
+ {
+ _cts.Cancel();
+ _listener.Stop();
+ _listenerThread.Join();
+ GC.SuppressFinalize(this);
+ }
+
+ private async Task HandleRequest(HttpListenerContext context)
+ {
+ try
+ {
+ await _handler(context);
+ }
+ catch (Exception e)
+ {
+ await Console.Error.WriteLineAsync($"Exception while serving HTTP request: {e}");
+ context.Response.StatusCode = 500;
+ var response = Encoding.UTF8.GetBytes($"Internal Server Error: {e.Message}");
+ await context.Response.OutputStream.WriteAsync(response);
+ }
+ finally
+ {
+ context.Response.Close();
+ }
+ }
+}
diff --git a/Tests.Vpn.Service/Tests.Vpn.Service.csproj b/Tests.Vpn.Service/Tests.Vpn.Service.csproj
new file mode 100644
index 0000000..2fdfa76
--- /dev/null
+++ b/Tests.Vpn.Service/Tests.Vpn.Service.csproj
@@ -0,0 +1,35 @@
+
+
+
+ Coder.Desktop.Tests.Vpn.Service
+ net8.0-windows
+ enable
+ enable
+
+ false
+ true
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/Tests/Vpn/SerdesTest.cs b/Tests.Vpn/SerdesTest.cs
similarity index 65%
rename from Tests/Vpn/SerdesTest.cs
rename to Tests.Vpn/SerdesTest.cs
index 7673d6a..3266f14 100644
--- a/Tests/Vpn/SerdesTest.cs
+++ b/Tests.Vpn/SerdesTest.cs
@@ -1,6 +1,7 @@
using System.Buffers.Binary;
using Coder.Desktop.Vpn;
using Coder.Desktop.Vpn.Proto;
+using Coder.Desktop.Vpn.Utilities;
using Google.Protobuf;
namespace Coder.Desktop.Tests.Vpn;
@@ -9,26 +10,26 @@ namespace Coder.Desktop.Tests.Vpn;
public class SerdesTest
{
[Test(Description = "Tests that writing and reading a message works")]
- [Timeout(5_000)]
- public async Task WriteReadMessage()
+ [CancelAfter(30_000)]
+ public async Task WriteReadMessage(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
var serdes = new Serdes();
var msg = new ManagerMessage
{
Start = new StartRequest(),
};
- await serdes.WriteMessage(stream1, msg);
- var got = await serdes.ReadMessage(stream2);
+ await serdes.WriteMessage(stream1, msg, ct);
+ var got = await serdes.ReadMessage(stream2, ct);
Assert.That(msg, Is.EqualTo(got));
}
[Test(Description = "Tests that writing a message larger than 16 MiB throws an exception")]
- [Timeout(5_000)]
- public void WriteMessageTooLarge()
+ [CancelAfter(30_000)]
+ public void WriteMessageTooLarge(CancellationToken ct)
{
- var (stream1, _) = BidirectionalPipe.New();
+ var (stream1, _) = BidirectionalPipe.NewInMemory();
var serdes = new Serdes();
var msg = new ManagerMessage
@@ -39,51 +40,51 @@ public void WriteMessageTooLarge()
CoderUrl = "test",
},
};
- Assert.ThrowsAsync(() => serdes.WriteMessage(stream1, msg));
+ Assert.ThrowsAsync(() => serdes.WriteMessage(stream1, msg, ct));
}
[Test(Description = "Tests that attempting to read a message larger than 16 MiB throws an exception")]
- [Timeout(5_000)]
- public async Task ReadMessageTooLarge()
+ [CancelAfter(30_000)]
+ public async Task ReadMessageTooLarge(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
var serdes = new Serdes();
// In this test we don't actually write a message as the parser should
// bail out immediately after reading the message length
var lenBytes = new byte[4];
BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0x1000001);
- await stream1.WriteAsync(lenBytes);
- Assert.ThrowsAsync(() => serdes.ReadMessage(stream2));
+ await stream1.WriteAsync(lenBytes, ct);
+ Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct));
}
[Test(Description = "Read an empty (size 0) message from the stream")]
- [Timeout(5_000)]
- public async Task ReadEmptyMessage()
+ [CancelAfter(30_000)]
+ public async Task ReadEmptyMessage(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
var serdes = new Serdes();
// Write an empty message.
var lenBytes = new byte[4];
BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0);
- await stream1.WriteAsync(lenBytes);
- var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2));
+ await stream1.WriteAsync(lenBytes, ct);
+ var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct));
Assert.That(ex.Message, Does.Contain("Received message size 0"));
}
[Test(Description = "Read an invalid/corrupt message from the stream")]
- [Timeout(5_000)]
- public async Task ReadInvalidMessage()
+ [CancelAfter(30_000)]
+ public async Task ReadInvalidMessage(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
var serdes = new Serdes();
var lenBytes = new byte[4];
BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 1);
- await stream1.WriteAsync(lenBytes);
- await stream1.WriteAsync(new byte[1]);
- var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2));
+ await stream1.WriteAsync(lenBytes, ct);
+ await stream1.WriteAsync(new byte[1], ct);
+ var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct));
Assert.That(ex.InnerException, Is.TypeOf(typeof(InvalidProtocolBufferException)));
}
}
diff --git a/Tests/Vpn/SpeakerTest.cs b/Tests.Vpn/SpeakerTest.cs
similarity index 74%
rename from Tests/Vpn/SpeakerTest.cs
rename to Tests.Vpn/SpeakerTest.cs
index f06c62f..51950f7 100644
--- a/Tests/Vpn/SpeakerTest.cs
+++ b/Tests.Vpn/SpeakerTest.cs
@@ -1,84 +1,12 @@
-using System.Buffers;
-using System.IO.Pipelines;
using System.Reflection;
using System.Text;
using System.Threading.Channels;
using Coder.Desktop.Vpn;
using Coder.Desktop.Vpn.Proto;
+using Coder.Desktop.Vpn.Utilities;
namespace Coder.Desktop.Tests.Vpn;
-#region BidrectionalPipe
-
-internal class BidirectionalPipe(PipeReader reader, PipeWriter writer) : Stream
-{
- public override bool CanRead => true;
- public override bool CanSeek => false;
- public override bool CanWrite => true;
- public override long Length => -1;
-
- public override long Position
- {
- get => -1;
- set => throw new NotImplementedException("BidirectionalPipe does not support setting position");
- }
-
- public static (BidirectionalPipe, BidirectionalPipe) New()
- {
- var pipe1 = new Pipe();
- var pipe2 = new Pipe();
- return (new BidirectionalPipe(pipe1.Reader, pipe2.Writer), new BidirectionalPipe(pipe2.Reader, pipe1.Writer));
- }
-
- public override void Flush()
- {
- }
-
- public override int Read(byte[] buffer, int offset, int count)
- {
- return ReadAsync(buffer, offset, count).GetAwaiter().GetResult();
- }
-
- public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct)
- {
- var result = await reader.ReadAtLeastAsync(1, ct);
- var n = Math.Min((int)result.Buffer.Length, count);
- // Copy result.Buffer[0:n] to buffer[offset:offset+n]
- result.Buffer.Slice(0, n).CopyTo(buffer.AsMemory(offset, n).Span);
- if (!result.IsCompleted) reader.AdvanceTo(result.Buffer.GetPosition(n));
- return n;
- }
-
- public override long Seek(long offset, SeekOrigin origin)
- {
- throw new NotImplementedException("BidirectionalPipe does not support seeking");
- }
-
- public override void SetLength(long value)
- {
- throw new NotImplementedException("BidirectionalPipe does not support setting length");
- }
-
- public override void Write(byte[] buffer, int offset, int count)
- {
- WriteAsync(buffer, offset, count).GetAwaiter().GetResult();
- }
-
- public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct)
- {
- await writer.WriteAsync(buffer.AsMemory(offset, count), ct);
- }
-
- protected override void Dispose(bool disposing)
- {
- base.Dispose(disposing);
- writer.Complete();
- reader.Complete();
- }
-}
-
-#endregion
-
#region FailableStream
internal class FailableStream : Stream
@@ -88,13 +16,6 @@ internal class FailableStream : Stream
private readonly TaskCompletionSource _writeTcs = new();
- public FailableStream(Stream inner, Exception? writeException, Exception? readException)
- {
- _inner = inner;
- if (writeException != null) _writeTcs.SetException(writeException);
- if (readException != null) _readTcs.SetException(readException);
- }
-
public override bool CanRead => _inner.CanRead;
public override bool CanSeek => _inner.CanSeek;
public override bool CanWrite => _inner.CanWrite;
@@ -106,6 +27,13 @@ public override long Position
set => _inner.Position = value;
}
+ public FailableStream(Stream inner, Exception? writeException, Exception? readException)
+ {
+ _inner = inner;
+ if (writeException != null) _writeTcs.SetException(writeException);
+ if (readException != null) _readTcs.SetException(readException);
+ }
+
public void SetWriteException(Exception ex)
{
_writeTcs.SetException(ex);
@@ -172,10 +100,10 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer,
public class SpeakerTest
{
[Test(Description = "Send a message from speaker1 to speaker2, receive it, and send a reply back")]
- [Timeout(30_000)]
- public async Task SendReceiveReplyReceive()
+ [CancelAfter(30_000)]
+ public async Task SendReceiveReplyReceive(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
await using var speaker1 = new Speaker(stream1);
var speaker1Ch = Channel
@@ -190,14 +118,14 @@ public async Task SendReceiveReplyReceive()
speaker2.Error += ex => { Assert.Fail($"speaker2 error: {ex}"); };
// Start both speakers simultaneously
- Task.WaitAll(speaker1.StartAsync(), speaker2.StartAsync());
+ await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct));
// Send a normal message from speaker2 to speaker1
await speaker2.SendMessage(new TunnelMessage
{
PeerUpdate = new PeerUpdate(),
- });
- var receivedMessage = await speaker1Ch.Reader.ReadAsync();
+ }, ct);
+ var receivedMessage = await speaker1Ch.Reader.ReadAsync(ct);
Assert.That(receivedMessage.RpcField, Is.Null); // not a request
Assert.That(receivedMessage.Message.PeerUpdate, Is.Not.Null);
@@ -209,10 +137,10 @@ await speaker2.SendMessage(new TunnelMessage
ApiToken = "test",
CoderUrl = "test",
},
- });
+ }, ct);
// Receive the message in speaker2
- var message = await speaker2Ch.Reader.ReadAsync();
+ var message = await speaker2Ch.Reader.ReadAsync(ct);
Assert.That(message.RpcField, Is.Not.Null);
Assert.That(message.RpcField!.MsgId, Is.Not.EqualTo(0));
Assert.That(message.RpcField!.ResponseTo, Is.EqualTo(0));
@@ -225,7 +153,7 @@ await message.SendReply(new TunnelMessage
{
Success = true,
},
- });
+ }, ct);
// Receive the reply in speaker1 by awaiting sendTask
var reply = await sendTask;
@@ -236,57 +164,58 @@ await message.SendReply(new TunnelMessage
}
[Test(Description = "Encounter a write error during handshake")]
- [Timeout(30_000)]
- public async Task WriteError()
+ [CancelAfter(30_000)]
+ public async Task WriteError(CancellationToken ct)
{
- var (stream1, _) = BidirectionalPipe.New();
+ var (stream1, _) = BidirectionalPipe.NewInMemory();
var writeEx = new IOException("Test write error");
var failStream = new FailableStream(stream1, writeEx, null);
await using var speaker = new Speaker(failStream);
- var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync());
+ var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync(ct));
Assert.That(gotEx, Is.EqualTo(writeEx));
}
[Test(Description = "Encounter a read error during handshake")]
- [Timeout(30_000)]
- public async Task ReadError()
+ [CancelAfter(30_000)]
+ public async Task ReadError(CancellationToken ct)
{
- var (stream1, _) = BidirectionalPipe.New();
+ var (stream1, _) = BidirectionalPipe.NewInMemory();
var readEx = new IOException("Test read error");
var failStream = new FailableStream(stream1, null, readEx);
await using var speaker = new Speaker(failStream);
- var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync());
+ var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync(ct));
Assert.That(gotEx, Is.EqualTo(readEx));
}
[Test(Description = "Receive a header that exceeds 256 bytes")]
- [Timeout(30_000)]
- public async Task ReadLargeHeader()
+ [CancelAfter(30_000)]
+ public async Task ReadLargeHeader(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
await using var speaker1 = new Speaker(stream1);
var header = new byte[257];
for (var i = 0; i < header.Length; i++) header[i] = (byte)'a';
- await stream2.WriteAsync(header);
+ await stream2.WriteAsync(header, ct);
- var gotEx = Assert.ThrowsAsync(() => speaker1.StartAsync());
+ var gotEx = Assert.ThrowsAsync(() => speaker1.StartAsync(ct));
Assert.That(gotEx.Message, Does.Contain("Header malformed or too large"));
}
[Test(Description = "Receive an invalid header")]
- [Timeout(30_000)]
- public async Task ReceiveInvalidHeader()
+ [CancelAfter(30_000)]
+ public async Task ReceiveInvalidHeader(CancellationToken ct)
{
var cases = new Dictionary
{
{ "invalid\n", ("Failed to parse peer header", "Wrong number of parts in header string") },
{ "cats tunnel 1.0\n", ("Failed to parse peer header", "Invalid preamble in header string") },
- { "codervpn cats 1.0\n", ("Failed to parse peer header", "Unknown role 'cats'") },
+ { "codervpn 1.0\n", ("Failed to parse peer header", "Invalid role in header string") },
+ { "codervpn cats 1.0\n", ("Expected peer role 'tunnel' but got 'cats'", null) },
{ "codervpn manager 1.0\n", ("Expected peer role 'tunnel' but got 'manager'", null) },
{
"codervpn tunnel 1000.1\n",
@@ -299,12 +228,12 @@ public async Task ReceiveInvalidHeader()
foreach (var (header, (expectedOuter, expectedInner)) in cases)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
await using var speaker1 = new Speaker(stream1);
- await stream2.WriteAsync(Encoding.UTF8.GetBytes(header));
+ await stream2.WriteAsync(Encoding.UTF8.GetBytes(header), ct);
- var gotEx = Assert.CatchAsync(() => speaker1.StartAsync(), $"header: '{header}'");
+ var gotEx = Assert.CatchAsync(() => speaker1.StartAsync(ct), $"header: '{header}'");
Assert.That(gotEx.Message, Does.Contain(expectedOuter), $"header: '{header}'");
if (expectedInner is null)
{
@@ -318,10 +247,10 @@ public async Task ReceiveInvalidHeader()
}
[Test(Description = "Encounter a write error during message send")]
- [Timeout(30_000)]
- public async Task SendMessageWriteError()
+ [CancelAfter(30_000)]
+ public async Task SendMessageWriteError(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
var failStream = new FailableStream(stream1, null, null);
await using var speaker1 = new Speaker(failStream);
@@ -330,7 +259,7 @@ public async Task SendMessageWriteError()
await using var speaker2 = new Speaker(stream2);
speaker2.Receive += msg => Assert.Fail($"speaker2 received message: {msg}");
speaker2.Error += ex => Assert.Fail($"speaker2 error: {ex}");
- await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync());
+ await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct));
var writeEx = new IOException("Test write error");
failStream.SetWriteException(writeEx);
@@ -338,15 +267,15 @@ public async Task SendMessageWriteError()
var gotEx = Assert.ThrowsAsync(() => speaker1.SendMessage(new ManagerMessage
{
Start = new StartRequest(),
- }));
+ }, ct));
Assert.That(gotEx, Is.EqualTo(writeEx));
}
[Test(Description = "Encounter a read error during message receive")]
- [Timeout(30_000)]
- public async Task ReceiveMessageReadError()
+ [CancelAfter(30_000)]
+ public async Task ReceiveMessageReadError(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
var failStream = new FailableStream(stream1, null, null);
// Speaker1 is bound to failStream and will write an error to errorCh
@@ -359,13 +288,13 @@ public async Task ReceiveMessageReadError()
await using var speaker2 = new Speaker(stream2);
speaker2.Receive += msg => Assert.Fail($"speaker2 received message: {msg}");
speaker2.Error += ex => Assert.Fail($"speaker2 error: {ex}");
- await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync());
+ await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct));
// Now the handshake is complete, cause all reads to fail
var readEx = new IOException("Test write error");
failStream.SetReadException(readEx);
- var gotEx = await errorCh.Reader.ReadAsync();
+ var gotEx = await errorCh.Reader.ReadAsync(ct);
Assert.That(gotEx, Is.EqualTo(readEx));
// The receive loop should be stopped within a timely fashion.
@@ -377,24 +306,24 @@ public async Task ReceiveMessageReadError()
}
else
{
- var delayTask = Task.Delay(TimeSpan.FromSeconds(5));
+ var delayTask = Task.Delay(TimeSpan.FromSeconds(5), ct);
await Task.WhenAny(receiveLoopTask, delayTask);
Assert.That(receiveLoopTask.IsCompleted, Is.True);
}
}
[Test(Description = "Handle dispose while receive loop is running")]
- [Timeout(30_000)]
- public async Task DisposeWhileReceiveLoopRunning()
+ [CancelAfter(30_000)]
+ public async Task DisposeWhileReceiveLoopRunning(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
var speaker1 = new Speaker(stream1);
await using var speaker2 = new Speaker(stream2);
- await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync());
+ await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct));
// Dispose should happen in a timely fashion
var disposeTask = speaker1.DisposeAsync();
- var delayTask = Task.Delay(TimeSpan.FromSeconds(5));
+ var delayTask = Task.Delay(TimeSpan.FromSeconds(5), ct);
await Task.WhenAny(disposeTask.AsTask(), delayTask);
Assert.That(disposeTask.IsCompleted, Is.True);
@@ -408,19 +337,19 @@ public async Task DisposeWhileReceiveLoopRunning()
}
[Test(Description = "Handle dispose while a message is awaiting a reply")]
- [Timeout(30_000)]
- public async Task DisposeWhileAwaitingReply()
+ [CancelAfter(30_000)]
+ public async Task DisposeWhileAwaitingReply(CancellationToken ct)
{
- var (stream1, stream2) = BidirectionalPipe.New();
+ var (stream1, stream2) = BidirectionalPipe.NewInMemory();
var speaker1 = new Speaker(stream1);
await using var speaker2 = new Speaker(stream2);
- await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync());
+ await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct));
// Send a message from speaker1 to speaker2
var sendTask = speaker1.SendRequestAwaitReply(new ManagerMessage
{
Start = new StartRequest(),
- });
+ }, ct);
// Dispose speaker1
await speaker1.DisposeAsync();
diff --git a/Tests/Tests.csproj b/Tests.Vpn/Tests.Vpn.csproj
similarity index 54%
rename from Tests/Tests.csproj
rename to Tests.Vpn/Tests.Vpn.csproj
index cccd5dc..df00e81 100644
--- a/Tests/Tests.csproj
+++ b/Tests.Vpn/Tests.Vpn.csproj
@@ -1,7 +1,7 @@
- Coder.Desktop.Tests
+ Coder.Desktop.Tests.Vpn
net8.0
enable
enable
@@ -11,12 +11,17 @@
-
-
-
-
-
-
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
diff --git a/Tests.Vpn/Utilities/TaskUtilitiesTest.cs b/Tests.Vpn/Utilities/TaskUtilitiesTest.cs
new file mode 100644
index 0000000..a6a4583
--- /dev/null
+++ b/Tests.Vpn/Utilities/TaskUtilitiesTest.cs
@@ -0,0 +1,141 @@
+using Coder.Desktop.Vpn.Utilities;
+
+namespace Coder.Desktop.Tests.Vpn.Utilities;
+
+[TestFixture]
+public class TaskUtilitiesTest
+{
+ [Test(Description = "CancellableWhenAll with no tasks should complete immediately")]
+ [Timeout(30_000)]
+ public void CancellableWhenAll_NoTasks()
+ {
+ var task = TaskUtilities.CancellableWhenAll(new CancellationTokenSource());
+ Assert.That(task.IsCompleted, Is.True);
+ }
+
+ [Test(Description = "CancellableWhenAll with a single task should complete")]
+ [Timeout(30_000)]
+ public async Task CancellableWhenAll_SingleTask()
+ {
+ var innerTask = new TaskCompletionSource();
+ var task = TaskUtilities.CancellableWhenAll(new CancellationTokenSource(), innerTask.Task);
+ Assert.That(task.IsCompleted, Is.False);
+ innerTask.SetResult();
+ await task;
+ }
+
+ [Test(Description = "CancellableWhenAll with a single task that faults should propagate the exception")]
+ [Timeout(30_000)]
+ public void CancellableWhenAll_SingleTaskFault()
+ {
+ var cts = new CancellationTokenSource();
+ var innerTask = new TaskCompletionSource();
+ var task = TaskUtilities.CancellableWhenAll(cts, innerTask.Task);
+ Assert.That(task.IsCompleted, Is.False);
+ innerTask.SetException(new InvalidOperationException("Test"));
+ Assert.ThrowsAsync(async () => await task);
+ Assert.That(cts.IsCancellationRequested, Is.True);
+ }
+
+ [Test(Description = "CancellableWhenAll with a single task that is canceled should propagate the cancellation")]
+ [Timeout(30_000)]
+ public void CancellableWhenAll_SingleTaskCanceled()
+ {
+ var cts = new CancellationTokenSource();
+ var innerTask = new TaskCompletionSource();
+ var task = TaskUtilities.CancellableWhenAll(cts, innerTask.Task);
+ Assert.That(task.IsCompleted, Is.False);
+ innerTask.SetCanceled();
+ Assert.ThrowsAsync(async () => await task);
+ Assert.That(cts.IsCancellationRequested, Is.True);
+ }
+
+ [Test(Description = "CancellableWhenAll with multiple tasks should complete when all tasks are completed")]
+ [Timeout(30_000)]
+ public async Task CancellableWhenAll_MultipleTasks()
+ {
+ var cts = new CancellationTokenSource();
+ var innerTask1 = new TaskCompletionSource();
+ var innerTask2 = new TaskCompletionSource();
+
+ var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task);
+ Assert.That(task.IsCompleted, Is.False);
+ // This dance of awaiting a newly added continuation task before
+ // completing the TCS is to ensure that the original continuation task
+ // finished since it's inlinable.
+ var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { });
+ innerTask1.SetResult();
+ await task1ContinueTask;
+ Assert.That(task.IsCompleted, Is.False);
+ var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { });
+ innerTask2.SetResult();
+ await task2ContinueTask;
+ await task;
+ }
+
+ [Test(Description = "CancellableWhenAll with multiple tasks that fault should propagate the first exception only")]
+ [Timeout(30_000)]
+ public async Task CancellableWhenAll_MultipleTasksFault()
+ {
+ var cts = new CancellationTokenSource();
+ var innerTask1 = new TaskCompletionSource();
+ var innerTask2 = new TaskCompletionSource();
+
+ var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task);
+ Assert.That(task.IsCompleted, Is.False);
+ var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { });
+ innerTask1.SetException(new Exception("Test1"));
+ await task1ContinueTask;
+ Assert.That(task.IsCompleted, Is.False);
+ var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { });
+ innerTask2.SetException(new Exception("Test2"));
+ await task2ContinueTask;
+ var ex = Assert.ThrowsAsync(async () => await task);
+ Assert.That(ex.Message, Is.EqualTo("Test1"));
+ }
+
+ [Test(Description = "CancellableWhenAll with an exception and a cancellation should propagate the first thing")]
+ [Timeout(30_000)]
+ public async Task CancellableWhenAll_MultipleTasksFaultAndCanceled()
+ {
+ var cts = new CancellationTokenSource();
+ var innerTask1 = new TaskCompletionSource();
+ var innerTask2 = new TaskCompletionSource();
+ var innertask3 = Task.CompletedTask;
+
+ var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task, innertask3);
+ Assert.That(task.IsCompleted, Is.False);
+ var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { });
+ innerTask1.SetException(new Exception("Test1"));
+ await task1ContinueTask;
+ Assert.That(task.IsCompleted, Is.False);
+ Assert.That(cts.IsCancellationRequested, Is.True);
+ var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { });
+ innerTask2.SetCanceled();
+ await task2ContinueTask;
+ var ex = Assert.ThrowsAsync(async () => await task);
+ Assert.That(ex.Message, Is.EqualTo("Test1"));
+ }
+
+ [Test(Description = "CancellableWhenAll with a cancellation and an exception should propagate the first thing")]
+ [Timeout(30_000)]
+ public async Task CancellableWhenAll_MultipleTasksCanceledAndFault()
+ {
+ var cts = new CancellationTokenSource();
+ var innerTask1 = new TaskCompletionSource();
+ var innerTask2 = new TaskCompletionSource();
+ var innertask3 = Task.CompletedTask;
+
+ var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task, innertask3);
+ Assert.That(task.IsCompleted, Is.False);
+ var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { });
+ innerTask1.SetCanceled();
+ await task1ContinueTask;
+ Assert.That(task.IsCompleted, Is.False);
+ Assert.That(cts.IsCancellationRequested, Is.True);
+ var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { });
+ innerTask2.SetException(new Exception("Test2"));
+ await task2ContinueTask;
+ Assert.ThrowsAsync(async () => await task);
+ }
+}
diff --git a/Tests/Vpn.Proto/RpcMessageTest.cs b/Tests/Vpn.Proto/RpcMessageTest.cs
deleted file mode 100644
index 36de12d..0000000
--- a/Tests/Vpn.Proto/RpcMessageTest.cs
+++ /dev/null
@@ -1,39 +0,0 @@
-using Coder.Desktop.Vpn.Proto;
-
-namespace Coder.Desktop.Tests.Vpn.Proto;
-
-[TestFixture]
-public class RpcRoleAttributeTest
-{
- [Test]
- public void Valid()
- {
- var role = new RpcRoleAttribute(RpcRole.Manager);
- Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Manager));
- role = new RpcRoleAttribute(RpcRole.Tunnel);
- Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Tunnel));
- }
-
- [Test]
- public void Invalid()
- {
- Assert.Throws(() => _ = new RpcRoleAttribute("cats"));
- }
-}
-
-[TestFixture]
-public class RpcMessageTest
-{
- [Test]
- public void GetRole()
- {
- // RpcMessage is not a supported message type and doesn't have an
- // RpcRoleAttribute
- var ex = Assert.Throws(() => _ = RpcMessage.GetRole());
- Assert.That(ex.Message,
- Does.Contain("Message type 'Coder.Desktop.Vpn.Proto.RPC' does not have a RpcRoleAttribute"));
-
- Assert.That(ManagerMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Manager));
- Assert.That(TunnelMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Tunnel));
- }
-}
diff --git a/Tests/Vpn.Proto/RpcRoleTest.cs b/Tests/Vpn.Proto/RpcRoleTest.cs
deleted file mode 100644
index f39d5cb..0000000
--- a/Tests/Vpn.Proto/RpcRoleTest.cs
+++ /dev/null
@@ -1,22 +0,0 @@
-using Coder.Desktop.Vpn.Proto;
-
-namespace Coder.Desktop.Tests.Vpn.Proto;
-
-[TestFixture]
-public class RpcRoleTest
-{
- [Test(Description = "Instantiate a RpcRole with a valid name")]
- public void ValidRole()
- {
- var role = new RpcRole(RpcRole.Manager);
- Assert.That(role.ToString(), Is.EqualTo(RpcRole.Manager));
- role = new RpcRole(RpcRole.Tunnel);
- Assert.That(role.ToString(), Is.EqualTo(RpcRole.Tunnel));
- }
-
- [Test(Description = "Try to instantiate a RpcRole with an invalid name")]
- public void InvalidRole()
- {
- Assert.Throws(() => _ = new RpcRole("cats"));
- }
-}
diff --git a/Vpn.Proto/RpcHeader.cs b/Vpn.Proto/RpcHeader.cs
index 0b840db..cf7ffcc 100644
--- a/Vpn.Proto/RpcHeader.cs
+++ b/Vpn.Proto/RpcHeader.cs
@@ -3,16 +3,22 @@
namespace Coder.Desktop.Vpn.Proto;
///
-/// A header to write or read from a stream to identify the speaker's role and version.
+/// A header to write or read from a stream to identify the peer role and version.
///
-/// Role of the speaker
-/// Version of the speaker
-public class RpcHeader(RpcRole role, RpcVersionList versionList)
+public class RpcHeader
{
private const string Preamble = "codervpn";
- public RpcRole Role { get; } = role;
- public RpcVersionList VersionList { get; } = versionList;
+ public string Role { get; }
+ public RpcVersionList VersionList { get; }
+
+ /// Role of the peer
+ /// Version of the peer
+ public RpcHeader(string role, RpcVersionList versionList)
+ {
+ Role = role;
+ VersionList = versionList;
+ }
///
/// Parse a header string into a SpeakerHeader.
@@ -25,10 +31,10 @@ public static RpcHeader Parse(string header)
var parts = header.Split(' ');
if (parts.Length != 3) throw new ArgumentException($"Wrong number of parts in header string '{header}'");
if (parts[0] != Preamble) throw new ArgumentException($"Invalid preamble in header string '{header}'");
+ if (string.IsNullOrEmpty(parts[1])) throw new ArgumentException($"Invalid role in header string '{header}'");
- var role = new RpcRole(parts[1]);
var versionList = RpcVersionList.Parse(parts[2]);
- return new RpcHeader(role, versionList);
+ return new RpcHeader(parts[1], versionList);
}
///
diff --git a/Vpn.Proto/RpcMessage.cs b/Vpn.Proto/RpcMessage.cs
index c44168c..bfe4d82 100644
--- a/Vpn.Proto/RpcMessage.cs
+++ b/Vpn.Proto/RpcMessage.cs
@@ -4,11 +4,23 @@
namespace Coder.Desktop.Vpn.Proto;
[AttributeUsage(AttributeTargets.Class, Inherited = false)]
-public class RpcRoleAttribute(string role) : Attribute
+public class RpcRoleAttribute : Attribute
{
- public RpcRole Role { get; } = new(role);
+ public string Role { get; }
+
+ public RpcRoleAttribute(string role)
+ {
+ Role = role;
+ }
}
+///
+/// IRpcMessageCompatibleWith is a marker interface that indicates that a
+/// message type can be used to peer with another message type.
+///
+///
+public interface IRpcMessageCompatibleWith;
+
///
/// Represents an actual over-the-wire message type.
///
@@ -36,9 +48,9 @@ public abstract class RpcMessage where T : IMessage
///
/// Gets the RpcRole of the message type from it's RpcRole attribute.
///
- ///
+ /// The role string
/// The message type does not have an RpcRoleAttribute
- public static RpcRole GetRole()
+ public static string GetRole()
{
var type = typeof(T);
var attr = type.GetCustomAttribute();
@@ -47,8 +59,8 @@ public static RpcRole GetRole()
}
}
-[RpcRole(RpcRole.Manager)]
-public partial class ManagerMessage : RpcMessage
+[RpcRole("manager")]
+public partial class ManagerMessage : RpcMessage, IRpcMessageCompatibleWith
{
public override RPC? RpcField
{
@@ -64,8 +76,8 @@ public override void Validate()
}
}
-[RpcRole(RpcRole.Tunnel)]
-public partial class TunnelMessage : RpcMessage
+[RpcRole("tunnel")]
+public partial class TunnelMessage : RpcMessage, IRpcMessageCompatibleWith
{
public override RPC? RpcField
{
@@ -80,3 +92,37 @@ public override void Validate()
if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type");
}
}
+
+[RpcRole("service")]
+public partial class ServiceMessage : RpcMessage, IRpcMessageCompatibleWith
+{
+ public override RPC? RpcField
+ {
+ get => Rpc;
+ set => Rpc = value;
+ }
+
+ public override ServiceMessage Message => this;
+
+ public override void Validate()
+ {
+ if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type");
+ }
+}
+
+[RpcRole("client")]
+public partial class ClientMessage : RpcMessage, IRpcMessageCompatibleWith
+{
+ public override RPC? RpcField
+ {
+ get => Rpc;
+ set => Rpc = value;
+ }
+
+ public override ClientMessage Message => this;
+
+ public override void Validate()
+ {
+ if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type");
+ }
+}
diff --git a/Vpn.Proto/RpcRole.cs b/Vpn.Proto/RpcRole.cs
deleted file mode 100644
index 9190281..0000000
--- a/Vpn.Proto/RpcRole.cs
+++ /dev/null
@@ -1,56 +0,0 @@
-namespace Coder.Desktop.Vpn.Proto;
-
-///
-/// Represents a role that either side of the connection can fulfil.
-///
-public sealed class RpcRole
-{
- public const string Manager = "manager";
- public const string Tunnel = "tunnel";
-
- public RpcRole(string role)
- {
- if (role != Manager && role != Tunnel) throw new ArgumentException($"Unknown role '{role}'");
-
- Role = role;
- }
-
- private string Role { get; }
-
- public override string ToString()
- {
- return Role;
- }
-
- #region SpeakerRole equality
-
- public static bool operator ==(RpcRole a, RpcRole b)
- {
- return a.Equals(b);
- }
-
- public static bool operator !=(RpcRole a, RpcRole b)
- {
- return !a.Equals(b);
- }
-
- private bool Equals(RpcRole other)
- {
- return Role == other.Role;
- }
-
- public override bool Equals(object? obj)
- {
- if (obj is null) return false;
- if (ReferenceEquals(this, obj)) return true;
- if (obj.GetType() != GetType()) return false;
- return Equals((RpcRole)obj);
- }
-
- public override int GetHashCode()
- {
- return Role.GetHashCode();
- }
-
- #endregion
-}
diff --git a/Vpn.Proto/RpcVersion.cs b/Vpn.Proto/RpcVersion.cs
index a9b1914..574768d 100644
--- a/Vpn.Proto/RpcVersion.cs
+++ b/Vpn.Proto/RpcVersion.cs
@@ -3,14 +3,20 @@
///
/// A version of the RPC API. Can be compared other versions to determine compatibility between two peers.
///
-/// The major version of the peer
-/// The minor version of the peer
-public class RpcVersion(ulong major, ulong minor)
+public class RpcVersion
{
public static readonly RpcVersion Current = new(1, 0);
- public ulong Major { get; } = major;
- public ulong Minor { get; } = minor;
+ public ulong Major { get; }
+ public ulong Minor { get; }
+
+ /// The major version of the peer
+ /// The minor version of the peer
+ public RpcVersion(ulong major, ulong minor)
+ {
+ Major = major;
+ Minor = minor;
+ }
///
/// Parse a string in the format "major.minor" into an ApiVersion.
diff --git a/Vpn.Proto/Vpn.Proto.csproj b/Vpn.Proto/Vpn.Proto.csproj
index 5380bd4..6acb12e 100644
--- a/Vpn.Proto/Vpn.Proto.csproj
+++ b/Vpn.Proto/Vpn.Proto.csproj
@@ -12,8 +12,8 @@
-
-
+
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto
index 33a3ff4..a03978a 100644
--- a/Vpn.Proto/vpn.proto
+++ b/Vpn.Proto/vpn.proto
@@ -44,6 +44,24 @@ message TunnelMessage {
}
}
+// ClientMessage is a message from the client (to the service).
+message ClientMessage {
+ RPC rpc = 1;
+ oneof msg {
+ StartRequest start = 2;
+ StopRequest stop = 3;
+ }
+}
+
+// ServiceMessage is a message from the service (to the client).
+message ServiceMessage {
+ RPC rpc = 1;
+ oneof msg {
+ StartResponse start = 2;
+ StopResponse stop = 3;
+ }
+}
+
// Log is a log message generated by the tunnel. The manager should log it to the system log. It is
// one-way tunnel -> manager with no response.
message Log {
@@ -105,7 +123,7 @@ message Agent {
bytes id = 1; // UUID
string name = 2;
bytes workspace_id = 3; // UUID
- string fqdn = 4;
+ repeated string fqdn = 4;
repeated string ip_addrs = 5;
// last_handshake is the primary indicator of whether we are connected to a peer. Zero value or
// anything longer than 5 minutes ago means there is a problem.
@@ -179,6 +197,12 @@ message StartRequest {
int32 tunnel_file_descriptor = 1;
string coder_url = 2;
string api_token = 3;
+ // Additional HTTP headers added to all requests
+ message Header {
+ string name = 1;
+ string value = 2;
+ }
+ repeated Header headers = 4;
}
message StartResponse {
diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs
new file mode 100644
index 0000000..83eda24
--- /dev/null
+++ b/Vpn.Service/Downloader.cs
@@ -0,0 +1,355 @@
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+using Coder.Desktop.Vpn.Utilities;
+using Microsoft.Extensions.Logging;
+using Microsoft.Security.Extensions;
+
+namespace Coder.Desktop.Vpn.Service;
+
+public interface IDownloader
+{
+ Task StartDownloadAsync(HttpRequestMessage req, string destinationPath, IDownloadValidator validator,
+ CancellationToken ct = default);
+}
+
+public interface IDownloadValidator
+{
+ ///
+ /// Validates the downloaded file at the given path. This method should throw an exception if the file is invalid.
+ ///
+ /// The path of the file
+ /// Cancellation token
+ Task ValidateAsync(string path, CancellationToken ct = default);
+}
+
+public class NullDownloadValidator : IDownloadValidator
+{
+ public static NullDownloadValidator Instance => new();
+
+ public Task ValidateAsync(string path, CancellationToken ct = default)
+ {
+ return Task.CompletedTask;
+ }
+}
+
+///
+/// Ensures the downloaded binary is signed by the expected authenticode organization.
+///
+public class AuthenticodeDownloadValidator : IDownloadValidator
+{
+ private readonly string _expectedName;
+
+ public static AuthenticodeDownloadValidator Coder => new("Coder Technologies Inc.");
+
+ public AuthenticodeDownloadValidator(string expectedName)
+ {
+ _expectedName = expectedName;
+ }
+
+ public async Task ValidateAsync(string path, CancellationToken ct = default)
+ {
+ FileSignatureInfo fileSigInfo;
+ await using (var fileStream = File.OpenRead(path))
+ {
+ fileSigInfo = FileSignatureInfo.GetFromFileStream(fileStream);
+ }
+
+ if (fileSigInfo.State != SignatureState.SignedAndTrusted)
+ throw new Exception(
+ $"File is not signed and trusted with an Authenticode signature: State={fileSigInfo.State}");
+
+ // Coder will only use embedded signatures because we are downloading
+ // individual binaries and not installers which can ship catalog files.
+ if (fileSigInfo.Kind != SignatureKind.Embedded)
+ throw new Exception($"File is not signed with an embedded Authenticode signature: Kind={fileSigInfo.Kind}");
+
+ // TODO: check that it's an extended validation certificate
+
+ var actualName = fileSigInfo.SigningCertificate.GetNameInfo(X509NameType.SimpleName, false);
+ if (actualName != _expectedName)
+ throw new Exception(
+ $"File is signed by an unexpected certificate: ExpectedName='{_expectedName}', ActualName='{actualName}'");
+ }
+}
+
+public class AssemblyVersionDownloadValidator : IDownloadValidator
+{
+ private readonly string _expectedAssemblyVersion;
+
+ public AssemblyVersionDownloadValidator(string expectedAssemblyVersion)
+ {
+ _expectedAssemblyVersion = expectedAssemblyVersion;
+ }
+
+ public Task ValidateAsync(string path, CancellationToken ct = default)
+ {
+ var info = FileVersionInfo.GetVersionInfo(path);
+ if (string.IsNullOrEmpty(info.ProductVersion))
+ throw new Exception("File ProductVersion is empty or null, was the binary compiled correctly?");
+ if (info.ProductVersion != _expectedAssemblyVersion)
+ throw new Exception(
+ $"File ProductVersion is '{info.ProductVersion}', but expected '{_expectedAssemblyVersion}'");
+ return Task.CompletedTask;
+ }
+}
+
+///
+/// Combines multiple download validators into a single validator. All validators will be run in order.
+///
+public class CombinationDownloadValidator : IDownloadValidator
+{
+ private readonly IDownloadValidator[] _validators;
+
+ /// Validators to run
+ public CombinationDownloadValidator(params IDownloadValidator[] validators)
+ {
+ _validators = validators;
+ }
+
+ public async Task ValidateAsync(string path, CancellationToken ct = default)
+ {
+ foreach (var validator in _validators)
+ await validator.ValidateAsync(path, ct);
+ }
+}
+
+///
+/// Handles downloading files from the internet. Downloads are performed asynchronously using DownloadTask.
+/// Single-flight is provided to avoid performing the same download multiple times.
+///
+public class Downloader : IDownloader
+{
+ private readonly ConcurrentDictionary _downloads = new();
+ private readonly ILogger _logger;
+
+ // ReSharper disable once ConvertToPrimaryConstructor
+ public Downloader(ILogger logger)
+ {
+ _logger = logger;
+ }
+
+ ///
+ /// Starts a download with the given request. The If-None-Match header will be set to the SHA1 ETag of any existing
+ /// file in the destination location.
+ ///
+ /// Request message
+ /// Path to write file to (will be overwritten)
+ /// Validator for the downloaded file
+ /// Cancellation token
+ /// A DownloadTask representing the ongoing download operation after it starts
+ public async Task StartDownloadAsync(HttpRequestMessage req, string destinationPath,
+ IDownloadValidator validator, CancellationToken ct = default)
+ {
+ while (true)
+ {
+ var task = _downloads.GetOrAdd(destinationPath,
+ _ => new DownloadTask(_logger, req, destinationPath, validator));
+ await task.EnsureStartedAsync(ct);
+
+ // If the existing (or new) task is for the same URL, return it.
+ if (task.Request.RequestUri == req.RequestUri)
+ return task;
+
+ // If the existing task is for a different URL, await its completion
+ // then retry the loop to create a new task. This could potentially
+ // get stuck if there are a lot of download operations for different
+ // URLs and the same destination path, but in our use case this
+ // shouldn't happen unless the user keeps changing the access URL.
+ _logger.LogWarning(
+ "Download for '{DestinationPath}' is already in progress, but is for a different Url - awaiting completion",
+ destinationPath);
+ await task.Task;
+ }
+ }
+}
+
+///
+/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final
+/// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if
+/// it hasn't changed.
+///
+public class DownloadTask
+{
+ private const int BufferSize = 4096;
+
+ private static readonly HttpClient HttpClient = new();
+ private readonly string _destinationDirectory;
+
+ private readonly ILogger _logger;
+
+ private readonly RaiiSemaphoreSlim _semaphore = new(1, 1);
+ private readonly IDownloadValidator _validator;
+ public readonly string DestinationPath;
+
+ public readonly HttpRequestMessage Request;
+ public readonly string TempDestinationPath;
+
+ public ulong? TotalBytes { get; private set; }
+ public ulong BytesRead { get; private set; }
+ public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync
+
+ public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value;
+ public bool IsCompleted => Task.IsCompleted;
+
+ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator)
+ {
+ _logger = logger;
+ Request = req;
+ _validator = validator;
+
+ if (string.IsNullOrWhiteSpace(destinationPath))
+ throw new ArgumentException("Destination path must not be empty", nameof(destinationPath));
+ DestinationPath = Path.GetFullPath(destinationPath);
+ if (Path.EndsInDirectorySeparator(DestinationPath))
+ throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator",
+ nameof(destinationPath));
+
+ _destinationDirectory = Path.GetDirectoryName(DestinationPath)
+ ?? throw new ArgumentException(
+ $"Destination path '{DestinationPath}' must have a parent directory",
+ nameof(destinationPath));
+
+ TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) +
+ ".download-" + Path.GetRandomFileName());
+ }
+
+ internal async Task EnsureStartedAsync(CancellationToken ct = default)
+ {
+ using var _ = await _semaphore.LockAsync(ct);
+ if (Task == null!)
+ Task = await StartDownloadAsync(ct);
+
+ return Task;
+ }
+
+ ///
+ /// Starts downloading the file. The request will be performed in this task, but once started, the task will complete
+ /// and the download will continue in the background. The provided CancellationToken can be used to cancel the
+ /// download.
+ ///
+ private async Task StartDownloadAsync(CancellationToken ct = default)
+ {
+ Directory.CreateDirectory(_destinationDirectory);
+
+ // If the destination path exists, generate a Coder SHA1 ETag and send
+ // it in the If-None-Match header to the server.
+ if (File.Exists(DestinationPath))
+ {
+ await using var stream = File.OpenRead(DestinationPath);
+ var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower();
+ Request.Headers.Add("If-None-Match", "\"" + etag + "\"");
+ }
+
+ var res = await HttpClient.SendAsync(Request, HttpCompletionOption.ResponseHeadersRead, ct);
+ if (res.StatusCode == HttpStatusCode.NotModified)
+ {
+ _logger.LogInformation("File has not been modified, skipping download");
+ try
+ {
+ await _validator.ValidateAsync(DestinationPath, ct);
+ }
+ catch (Exception e)
+ {
+ _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath);
+ throw new Exception("Existing file failed validation after 304 Not Modified", e);
+ }
+
+ Task = Task.CompletedTask;
+ return Task;
+ }
+
+ if (res.StatusCode != HttpStatusCode.OK)
+ {
+ _logger.LogWarning("Failed to download file '{Request.RequestUri}': {StatusCode} {ReasonPhrase}",
+ Request.RequestUri, res.StatusCode,
+ res.ReasonPhrase);
+ throw new HttpRequestException(
+ $"Failed to download file '{Request.RequestUri}': {(int)res.StatusCode} {res.ReasonPhrase}");
+ }
+
+ if (res.Content == null)
+ {
+ _logger.LogWarning("File {Request.RequestUri} has no content", Request.RequestUri);
+ throw new HttpRequestException("Response has no content");
+ }
+
+ if (res.Content.Headers.ContentLength >= 0)
+ TotalBytes = (ulong)res.Content.Headers.ContentLength;
+
+ FileStream tempFile;
+ try
+ {
+ tempFile = File.Create(TempDestinationPath, BufferSize,
+ FileOptions.Asynchronous | FileOptions.SequentialScan);
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath);
+ throw;
+ }
+
+ Task = DownloadAsync(res, tempFile, ct);
+ return Task;
+ }
+
+ private async Task DownloadAsync(HttpResponseMessage res, FileStream tempFile, CancellationToken ct)
+ {
+ try
+ {
+ var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null;
+ await using (tempFile)
+ {
+ var stream = await res.Content.ReadAsStreamAsync(ct);
+ var buffer = new byte[BufferSize];
+ int n;
+ while ((n = await stream.ReadAsync(buffer, ct)) > 0)
+ {
+ await tempFile.WriteAsync(buffer.AsMemory(0, n), ct);
+ sha1?.TransformBlock(buffer, 0, n, null, 0);
+ BytesRead += (ulong)n;
+ }
+ }
+
+ if (TotalBytes != null && BytesRead != TotalBytes)
+ throw new IOException(
+ $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}");
+
+ // Verify the ETag if it was sent by the server.
+ if (res.Headers.Contains("ETag") && sha1 != null)
+ {
+ var etag = res.Headers.ETag!.Tag.Trim('"');
+ _ = sha1.TransformFinalBlock([], 0, 0);
+ var hashStr = Convert.ToHexString(sha1.Hash!).ToLower();
+ if (etag != hashStr)
+ throw new HttpRequestException(
+ $"ETag does not match SHA1 hash of downloaded file: ETag='{etag}', Local='{hashStr}'");
+ }
+
+ try
+ {
+ await _validator.ValidateAsync(TempDestinationPath, ct);
+ }
+ catch (Exception e)
+ {
+ _logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation",
+ TempDestinationPath);
+ throw new HttpRequestException("Downloaded file failed validation", e);
+ }
+
+ File.Move(TempDestinationPath, DestinationPath, true);
+ }
+ finally
+ {
+#if DEBUG
+ _logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode",
+ TempDestinationPath);
+#else
+ if (File.Exists(TempDestinationPath))
+ File.Delete(TempDestinationPath);
+#endif
+ }
+ }
+}
diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs
new file mode 100644
index 0000000..0f11f34
--- /dev/null
+++ b/Vpn.Service/Manager.cs
@@ -0,0 +1,215 @@
+using System.Runtime.InteropServices;
+using Coder.Desktop.Vpn.Proto;
+using CoderSdk;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
+using Semver;
+
+namespace Coder.Desktop.Vpn.Service;
+
+public interface IManager : IDisposable
+{
+ public Task HandleClientRpcMessage(ReplyableRpcMessage message,
+ CancellationToken ct = default);
+
+ public Task StopAsync(CancellationToken ct = default);
+}
+
+///
+/// Manager provides handling for RPC requests from the client and from the tunnel.
+///
+public class Manager : IManager
+{
+ // TODO: determine a suitable value for this
+ private const string ServerVersionRange = ">=0.0.0";
+
+ private readonly ManagerConfig _config;
+ private readonly IDownloader _downloader;
+ private readonly ILogger _logger;
+ private readonly ITunnelSupervisor _tunnelSupervisor;
+
+ // ReSharper disable once ConvertToPrimaryConstructor
+ public Manager(IOptions config, ILogger logger, IDownloader downloader,
+ ITunnelSupervisor tunnelSupervisor)
+ {
+ _config = config.Value;
+ _logger = logger;
+ _downloader = downloader;
+ _tunnelSupervisor = tunnelSupervisor;
+ }
+
+ public void Dispose()
+ {
+ GC.SuppressFinalize(this);
+ }
+
+ ///
+ /// Processes a message sent from a Client to the ManagerRpcService over the codervpn RPC protocol.
+ ///
+ /// Client message
+ /// Cancellation token
+ public async Task HandleClientRpcMessage(ReplyableRpcMessage message,
+ CancellationToken ct = default)
+ {
+ _logger.LogInformation("ClientMessage: {MessageType}", message.Message.MsgCase);
+ // TODO: break out each into it's own method?
+ switch (message.Message.MsgCase)
+ {
+ case ClientMessage.MsgOneofCase.Start:
+ // TODO: these sub-methods should be managed by some Task list and cancelled/awaited on stop
+ await HandleClientMessageStart(message, ct);
+ break;
+ case ClientMessage.MsgOneofCase.Stop:
+ await HandleClientMessageStop(message, ct);
+ break;
+ case ClientMessage.MsgOneofCase.None:
+ default:
+ _logger.LogWarning("Received unknown message type {MessageType}", message.Message.MsgCase);
+ break;
+ }
+ }
+
+ public async Task StopAsync(CancellationToken ct = default)
+ {
+ await _tunnelSupervisor.StopAsync(ct);
+ }
+
+ private async Task HandleClientMessageStart(ReplyableRpcMessage message,
+ CancellationToken ct)
+ {
+ try
+ {
+ // TODO: if the credentials and URL are identical and the server
+ // version hasn't changed we should not do anything
+ // TODO: this should be broken out into it's own method
+ _logger.LogInformation("ClientMessage.Start: testing server '{ServerUrl}'", message.Message.Start.CoderUrl);
+ var client = new CoderApiClient(message.Message.Start.CoderUrl, message.Message.Start.ApiToken);
+ var buildInfo = await client.GetBuildInfo(ct);
+ _logger.LogInformation("ClientMessage.Start: server version '{ServerVersion}'", buildInfo.Version);
+ var serverVersion = SemVersion.Parse(buildInfo.Version);
+ if (!serverVersion.Satisfies(ServerVersionRange))
+ throw new InvalidOperationException(
+ $"Server version '{serverVersion}' is not within required server version range '{ServerVersionRange}'");
+ var user = await client.GetUser(User.Me, ct);
+ _logger.LogInformation("ClientMessage.Start: authenticated as '{Username}'", user.Username);
+
+ await DownloadTunnelBinaryAsync(message.Message.Start.CoderUrl, serverVersion, ct);
+ await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage,
+ HandleTunnelRpcError,
+ ct);
+ }
+ catch (Exception e)
+ {
+ _logger.LogWarning(e, "ClientMessage.Start: Failed to start VPN client");
+ await message.SendReply(new ServiceMessage
+ {
+ Start = new StartResponse
+ {
+ Success = false,
+ ErrorMessage = e.Message,
+ },
+ }, ct);
+ }
+ }
+
+ private async Task HandleClientMessageStop(ReplyableRpcMessage message,
+ CancellationToken ct)
+ {
+ try
+ {
+ // This will handle sending the Stop message for us.
+ await _tunnelSupervisor.StopAsync(ct);
+ }
+ catch (Exception e)
+ {
+ _logger.LogWarning(e, "ClientMessage.Stop: Failed to stop VPN client");
+ await message.SendReply(new ServiceMessage
+ {
+ Stop = new StopResponse
+ {
+ Success = false,
+ ErrorMessage = e.Message,
+ },
+ }, ct);
+ }
+ }
+
+ private void HandleTunnelRpcMessage(ReplyableRpcMessage message)
+ {
+ // TODO: this
+ }
+
+ private void HandleTunnelRpcError(Exception e)
+ {
+ // TODO: this probably happens during an ongoing start or stop operation, and we should definitely ignore those
+ _logger.LogError(e, "Manager<->Tunnel RPC error");
+ try
+ {
+ _tunnelSupervisor.StopAsync();
+ }
+ catch (Exception e2)
+ {
+ _logger.LogError(e2, "Failed to stop tunnel supervisor after RPC error");
+ }
+ }
+
+ ///
+ /// Returns the architecture of the current system.
+ ///
+ /// A golang architecture string for the binary
+ /// Unsupported architecture
+ private static string SystemArchitecture()
+ {
+ // ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault
+ return RuntimeInformation.ProcessArchitecture switch
+ {
+ Architecture.X64 => "amd64",
+ Architecture.Arm64 => "arm64",
+ // We only support amd64 and arm64 on Windows currently.
+ _ => throw new PlatformNotSupportedException(
+ $"Unsupported architecture '{RuntimeInformation.ProcessArchitecture}'. Coder only supports amd64 and arm64."),
+ };
+ }
+
+ ///
+ /// Fetches the "/bin/coder-windows-{architecture}.exe" binary from the given base URL and writes it to the
+ /// destination path after validating the signature and checksum.
+ ///
+ /// Server base URL to download the binary from
+ /// The version of the server to expect in the binary
+ /// Cancellation token
+ /// If the base URL is invalid
+ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expectedVersion,
+ CancellationToken ct = default)
+ {
+ var architecture = SystemArchitecture();
+ Uri url;
+ try
+ {
+ url = new Uri(baseUrl, UriKind.Absolute);
+ if (url.PathAndQuery != "/")
+ throw new ArgumentException("Base URL must not contain a path", nameof(baseUrl));
+ url = new Uri(url, $"/bin/coder-windows-{architecture}.exe");
+ }
+ catch (Exception e)
+ {
+ throw new ArgumentException($"Invalid base URL '{baseUrl}'", e);
+ }
+
+ _logger.LogInformation("Downloading VPN binary from '{url}' to '{DestinationPath}'", url,
+ _config.TunnelBinaryPath);
+ var req = new HttpRequestMessage(HttpMethod.Get, url);
+ var validators = new CombinationDownloadValidator(
+ AuthenticodeDownloadValidator.Coder,
+ new AssemblyVersionDownloadValidator(
+ $"{expectedVersion.Major}.{expectedVersion.Minor}.{expectedVersion.Patch}.0")
+ );
+ var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct);
+
+ // TODO: monitor and report progress when we have a mechanism to do so
+
+ // Awaiting this will check the checksum (via the ETag) if provided,
+ // and will also validate the signature using the validator we supplied.
+ await downloadTask.Task;
+ }
+}
diff --git a/Vpn.Service/ManagerConfig.cs b/Vpn.Service/ManagerConfig.cs
new file mode 100644
index 0000000..906a0b8
--- /dev/null
+++ b/Vpn.Service/ManagerConfig.cs
@@ -0,0 +1,16 @@
+using System.ComponentModel.DataAnnotations;
+using System.Diagnostics.CodeAnalysis;
+
+namespace Coder.Desktop.Vpn.Service;
+
+[SuppressMessage("ReSharper", "AutoPropertyCanBeMadeGetOnly.Global")]
+public class ManagerConfig
+{
+ [Required]
+ [RegularExpression(@"^([a-zA-Z0-9_-]+\.)*[a-zA-Z0-9_-]+$")]
+ public string ServiceRpcPipeName { get; set; } = "Coder.Desktop.Vpn";
+
+ // TODO: pick a better default path
+ [Required]
+ public string TunnelBinaryPath { get; set; } = @"C:\coder-vpn.exe";
+}
diff --git a/Vpn.Service/ManagerRpcService.cs b/Vpn.Service/ManagerRpcService.cs
new file mode 100644
index 0000000..ce2b17e
--- /dev/null
+++ b/Vpn.Service/ManagerRpcService.cs
@@ -0,0 +1,128 @@
+using System.Collections.Concurrent;
+using System.IO.Pipes;
+using Coder.Desktop.Vpn.Proto;
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
+
+namespace Coder.Desktop.Vpn.Service;
+
+///
+/// Provides a named pipe server for communication between multiple RpcRole.Client and RpcRole.Manager.
+///
+public class ManagerRpcService : BackgroundService, IAsyncDisposable
+{
+ private readonly ConcurrentDictionary _activeClientTasks = new();
+ private readonly ManagerConfig _config;
+ private readonly CancellationTokenSource _cts = new();
+ private readonly ILogger _logger;
+ private readonly IManager _manager;
+
+ public ManagerRpcService(IOptions config, ILogger logger, IManager manager)
+ {
+ _logger = logger;
+ _manager = manager;
+ _config = config.Value;
+ }
+
+ public async ValueTask DisposeAsync()
+ {
+ await _cts.CancelAsync();
+ while (!_activeClientTasks.IsEmpty) await Task.WhenAny(_activeClientTasks.Values);
+ _cts.Dispose();
+ GC.SuppressFinalize(this);
+ }
+
+ public override async Task StopAsync(CancellationToken cancellationToken)
+ {
+ await _cts.CancelAsync();
+ while (!_activeClientTasks.IsEmpty) await Task.WhenAny(_activeClientTasks.Values);
+ }
+
+ ///
+ /// Starts the named pipe server, listens for incoming connections and starts handling them asynchronously.
+ ///
+ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
+ {
+ _logger.LogInformation(@"Starting continuous named pipe RPC server at \\.\pipe\{PipeName}",
+ _config.ServiceRpcPipeName);
+ using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(stoppingToken, _cts.Token);
+ while (!linkedCts.IsCancellationRequested)
+ {
+ var pipeServer = new NamedPipeServerStream(_config.ServiceRpcPipeName, PipeDirection.InOut,
+ NamedPipeServerStream.MaxAllowedServerInstances, PipeTransmissionMode.Byte, PipeOptions.Asynchronous);
+
+ try
+ {
+ try
+ {
+ _logger.LogDebug("Waiting for new named pipe client connection");
+ await pipeServer.WaitForConnectionAsync(linkedCts.Token);
+ }
+ finally
+ {
+ await pipeServer.DisposeAsync();
+ }
+
+ _logger.LogInformation("Handling named pipe client connection");
+ var clientTask = HandleRpcClientAsync(pipeServer, linkedCts.Token);
+ _activeClientTasks.TryAdd(clientTask.Id, clientTask);
+ _ = clientTask.ContinueWith(RpcClientContinuation, CancellationToken.None);
+ }
+ catch (OperationCanceledException)
+ {
+ throw;
+ }
+ catch (Exception e)
+ {
+ _logger.LogWarning(e, "Failed to accept named pipe client");
+ }
+ }
+ }
+
+ private async Task HandleRpcClientAsync(NamedPipeServerStream pipeServer, CancellationToken ct)
+ {
+ var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
+ await using (pipeServer)
+ {
+ await using var speaker = new Speaker(pipeServer);
+
+ var tcs = new TaskCompletionSource();
+ var activeTasks = new ConcurrentDictionary();
+ speaker.Receive += msg =>
+ {
+ var task = HandleRpcMessageAsync(msg, linkedCts.Token);
+ activeTasks.TryAdd(task.Id, task);
+ task.ContinueWith(t =>
+ {
+ if (t.IsFaulted)
+ _logger.LogWarning(t.Exception, "Client RPC message handler task faulted");
+ activeTasks.TryRemove(t.Id, out _);
+ }, CancellationToken.None);
+ };
+ speaker.Error += tcs.SetException;
+ await using (ct.Register(() => tcs.SetCanceled(ct)))
+ {
+ await speaker.StartAsync(ct);
+ await tcs.Task;
+ await linkedCts.CancelAsync();
+ while (!activeTasks.IsEmpty)
+ await Task.WhenAny(activeTasks.Values);
+ }
+ }
+ }
+
+ private void RpcClientContinuation(Task task)
+ {
+ if (task.IsFaulted)
+ _logger.LogWarning(task.Exception, "Client RPC task faulted");
+ _activeClientTasks.TryRemove(task.Id, out _);
+ }
+
+ private async Task HandleRpcMessageAsync(ReplyableRpcMessage message,
+ CancellationToken ct)
+ {
+ _logger.LogInformation("Received RPC message: {Message}", message.Message);
+ await _manager.HandleClientRpcMessage(message, ct);
+ }
+}
diff --git a/Vpn.Service/ManagerService.cs b/Vpn.Service/ManagerService.cs
new file mode 100644
index 0000000..b7b2e34
--- /dev/null
+++ b/Vpn.Service/ManagerService.cs
@@ -0,0 +1,33 @@
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+
+namespace Coder.Desktop.Vpn.Service;
+
+///
+/// Wraps Manager to provide a BackgroundService that informs the singleton Manager to shut down when stop is
+/// requested.
+///
+public class ManagerService : BackgroundService
+{
+ private readonly ILogger _logger;
+ private readonly IManager _manager;
+
+ // ReSharper disable once ConvertToPrimaryConstructor
+ public ManagerService(ILogger logger, IManager manager)
+ {
+ _logger = logger;
+ _manager = manager;
+ }
+
+ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
+ {
+ // Block until the service is stopped.
+ await Task.Delay(-1, stoppingToken);
+ }
+
+ public override async Task StopAsync(CancellationToken cancellationToken)
+ {
+ _logger.LogInformation("Informing Manager to stop");
+ await _manager.StopAsync(cancellationToken);
+ }
+}
diff --git a/Vpn.Service/Program.cs b/Vpn.Service/Program.cs
new file mode 100644
index 0000000..78fbff2
--- /dev/null
+++ b/Vpn.Service/Program.cs
@@ -0,0 +1,30 @@
+using Coder.Desktop.Vpn.Service;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Hosting;
+using Microsoft.Win32;
+
+var builder = Host.CreateApplicationBuilder(args);
+
+// Configuration sources
+builder.Configuration.Sources.Clear();
+(builder.Configuration as IConfigurationBuilder).Add(
+ new RegistryConfigurationSource(Registry.LocalMachine, @"SOFTWARE\Coder\Coder VPN"));
+builder.Configuration.AddEnvironmentVariables("CODER_MANAGER_");
+builder.Configuration.AddCommandLine(args);
+
+// Options types (these get registered as IOptions singletons)
+builder.Services.AddOptions()
+ .Bind(builder.Configuration.GetSection("Manager"))
+ .ValidateDataAnnotations();
+
+// Singletons
+builder.Services.AddSingleton();
+builder.Services.AddSingleton();
+builder.Services.AddSingleton();
+
+// Services
+builder.Services.AddHostedService();
+builder.Services.AddHostedService();
+
+builder.Build().Run();
diff --git a/Vpn.Service/RegistryConfigurationSource.cs b/Vpn.Service/RegistryConfigurationSource.cs
new file mode 100644
index 0000000..7ac2764
--- /dev/null
+++ b/Vpn.Service/RegistryConfigurationSource.cs
@@ -0,0 +1,41 @@
+using Microsoft.Extensions.Configuration;
+using Microsoft.Win32;
+
+namespace Coder.Desktop.Vpn.Service;
+
+public class RegistryConfigurationSource : IConfigurationSource
+{
+ private readonly RegistryKey _root;
+ private readonly string _subKeyName;
+
+ public RegistryConfigurationSource(RegistryKey root, string subKeyName)
+ {
+ _root = root;
+ _subKeyName = subKeyName;
+ }
+
+ public IConfigurationProvider Build(IConfigurationBuilder builder)
+ {
+ return new RegistryConfigurationProvider(_root, _subKeyName);
+ }
+}
+
+public class RegistryConfigurationProvider : ConfigurationProvider
+{
+ private readonly RegistryKey _root;
+ private readonly string _subKeyName;
+
+ public RegistryConfigurationProvider(RegistryKey root, string subKeyName)
+ {
+ _root = root;
+ _subKeyName = subKeyName;
+ }
+
+ public override void Load()
+ {
+ using var key = _root.OpenSubKey(_subKeyName);
+ if (key == null) return;
+
+ foreach (var valueName in key.GetValueNames()) Data[valueName] = key.GetValue(valueName)?.ToString();
+ }
+}
diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs
new file mode 100644
index 0000000..9ea5b05
--- /dev/null
+++ b/Vpn.Service/TunnelSupervisor.cs
@@ -0,0 +1,271 @@
+using System.Diagnostics;
+using System.IO.Pipes;
+using Coder.Desktop.Vpn.Proto;
+using Coder.Desktop.Vpn.Utilities;
+using Microsoft.Extensions.Logging;
+
+namespace Coder.Desktop.Vpn.Service;
+
+public interface ITunnelSupervisor : IAsyncDisposable
+{
+ ///
+ /// Starts the tunnel subprocess with the given executable path. If the subprocess is already running, this method will
+ /// kill it first.
+ ///
+ /// Path to the executable
+ /// Handler to call with each RPC message
+ ///
+ /// Handler for permanent errors from the RPC Speaker. The recipient should call StopAsync after
+ /// receiving this.
+ ///
+ /// Cancellation token
+ public Task StartAsync(string binPath,
+ Speaker.OnReceiveDelegate messageHandler,
+ Speaker.OnErrorDelegate errorHandler,
+ CancellationToken ct = default);
+
+ ///
+ /// Stops the tunnel subprocess. If the subprocess is not running, this method does nothing.
+ ///
+ ///
+ ///
+ public Task StopAsync(CancellationToken ct = default);
+}
+
+///
+/// Launches and supervises the tunnel subprocess. Provides RPC communication with the subprocess.
+///
+public class TunnelSupervisor : ITunnelSupervisor
+{
+ private readonly CancellationTokenSource _cts = new();
+ private readonly ILogger _logger;
+ private readonly SemaphoreSlim _operationLock = new(1, 1);
+ private AnonymousPipeServerStream? _inPipe;
+ private AnonymousPipeServerStream? _outPipe;
+ private Speaker? _speaker;
+
+ private Process? _subprocess;
+
+ // ReSharper disable once ConvertToPrimaryConstructor
+ public TunnelSupervisor(ILogger logger)
+ {
+ _logger = logger;
+ }
+
+ public async Task StartAsync(string binPath,
+ Speaker.OnReceiveDelegate messageHandler,
+ Speaker.OnErrorDelegate errorHandler,
+ CancellationToken ct = default)
+ {
+ _logger.LogInformation("StartAsync(\"{binPath}\")", binPath);
+ if (!await _operationLock.WaitAsync(0, ct))
+ throw new InvalidOperationException(
+ "Another TunnelSupervisor Start or Stop operation is already in progress");
+
+ try
+ {
+ await CleanupAsync(ct);
+
+ _outPipe = new AnonymousPipeServerStream(PipeDirection.Out, HandleInheritability.Inheritable);
+ _inPipe = new AnonymousPipeServerStream(PipeDirection.In, HandleInheritability.Inheritable);
+ _subprocess = new Process
+ {
+ StartInfo = new ProcessStartInfo
+ {
+ FileName = binPath,
+ ArgumentList = { "vpn-daemon", "run" },
+ UseShellExecute = false,
+ CreateNoWindow = true,
+ },
+ };
+
+ // Pass the other end of the pipes to the subprocess and dispose
+ // the local copies.
+ _subprocess.StartInfo.Environment.Add("CODER_VPN_DAEMON_RPC_READ_HANDLE",
+ _outPipe.GetClientHandleAsString());
+ _subprocess.StartInfo.Environment.Add("CODER_VPN_DAEMON_RPC_WRITE_HANDLE",
+ _inPipe.GetClientHandleAsString());
+ _outPipe.DisposeLocalCopyOfClientHandle();
+ _inPipe.DisposeLocalCopyOfClientHandle();
+
+ _logger.LogInformation("StartAsync: starting subprocess");
+ _subprocess.Start();
+ _logger.LogInformation("StartAsync: subprocess started");
+
+ // We don't use the supplied CancellationToken here because we want it to only apply to the startup
+ // procedure.
+ _ = _subprocess.WaitForExitAsync(_cts.Token).ContinueWith(OnProcessExited, CancellationToken.None);
+
+ // Start the RPC Speaker.
+ try
+ {
+ var stream = new BidirectionalPipe(_inPipe, _outPipe);
+ _speaker = new Speaker(stream);
+ _speaker.Receive += messageHandler;
+ _speaker.Error += errorHandler;
+ // Handshakes already have a 5-second timeout.
+ await _speaker.StartAsync(ct);
+ }
+ catch (Exception e)
+ {
+ throw new Exception("Failed to start RPC Speaker on pipes to subprocess", e);
+ }
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "StartAsync: failed to start or connect to subprocess");
+ await CleanupAsync(ct);
+ throw;
+ }
+ finally
+ {
+ _operationLock.Release();
+ }
+ }
+
+ public async Task StopAsync(CancellationToken ct = default)
+ {
+ _logger.LogInformation("StopAsync()");
+ if (!await _operationLock.WaitAsync(0, ct))
+ throw new InvalidOperationException(
+ "Another TunnelSupervisor Start or Stop operation is already in progress");
+
+ try
+ {
+ await CleanupAsync(ct);
+ }
+ finally
+ {
+ _operationLock.Release();
+ }
+ }
+
+ public async ValueTask DisposeAsync()
+ {
+ _cts.Dispose();
+ await CleanupAsync();
+ GC.SuppressFinalize(this);
+ }
+
+ private async Task OnProcessExited(Task task)
+ {
+ if (task.IsFaulted)
+ {
+ _logger.LogError(task.Exception, "OnProcessExited: subprocess exited with an exception");
+ return;
+ }
+
+ if (!await _operationLock.WaitAsync(0)) _logger.LogInformation("OnProcessExited: subprocess exited");
+
+ try
+ {
+ await CleanupAsync();
+ _logger.LogInformation("OnProcessExited: subprocess exited with code {ExitCode}",
+ _subprocess?.ExitCode ?? -1);
+ }
+ finally
+ {
+ _operationLock.Release();
+ }
+ }
+
+ ///
+ /// Cleans up the pipes and the subprocess if it's still running. This method should not be called without holding the
+ /// semaphore.
+ ///
+ private async Task CleanupAsync(CancellationToken ct = default)
+ {
+ if (_speaker != null)
+ {
+ try
+ {
+ _logger.LogInformation("CleanupAsync: Sending stop message to subprocess");
+ var stopCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
+ stopCts.CancelAfter(5000);
+ await _speaker.SendRequestAwaitReply(new ManagerMessage
+ {
+ Stop = new StopRequest(),
+ }, stopCts.Token);
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "CleanupAsync: Failed to send stop message to subprocess");
+ }
+
+ try
+ {
+ _logger.LogInformation("CleanupAsync: Disposing _speaker");
+ await _speaker.DisposeAsync();
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "CleanupAsync: Failed to stop/dispose _speaker");
+ }
+ finally
+ {
+ _speaker = null;
+ }
+ }
+
+ if (_outPipe != null)
+ {
+ _logger.LogInformation("CleanupAsync: Disposing _outPipe");
+ try
+ {
+ await _outPipe.DisposeAsync();
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "CleanupAsync: Failed to dispose _outPipe");
+ }
+ finally
+ {
+ _outPipe = null;
+ }
+ }
+
+ if (_inPipe != null)
+ {
+ _logger.LogInformation("CleanupAsync: Disposing _inPipe");
+ try
+ {
+ await _inPipe.DisposeAsync();
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "CleanupAsync: Failed to dispose _inPipe");
+ }
+ finally
+ {
+ _inPipe = null;
+ }
+ }
+
+ if (_subprocess != null)
+ try
+ {
+ if (!_subprocess.HasExited)
+ {
+ // TODO: is there a nicer way we can do this?
+ _logger.LogInformation("CleanupAsync: Killing un-exited _subprocess");
+ _subprocess.Kill();
+ // Since we just killed the process ideally it should exit
+ // immediately.
+ var exitCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
+ exitCts.CancelAfter(5000);
+ await _subprocess.WaitForExitAsync(exitCts.Token);
+ }
+
+ _logger.LogInformation("CleanupAsync: Disposing _subprocess");
+ _subprocess.Dispose();
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "CleanupAsync: Failed to kill/dispose _subprocess");
+ }
+ finally
+ {
+ _subprocess = null;
+ }
+ }
+}
diff --git a/Vpn.Service/Vpn.Service.csproj b/Vpn.Service/Vpn.Service.csproj
new file mode 100644
index 0000000..e6da70d
--- /dev/null
+++ b/Vpn.Service/Vpn.Service.csproj
@@ -0,0 +1,24 @@
+
+
+
+ Coder.Desktop.Vpn.Service
+ Exe
+ net8.0-windows
+ enable
+ enable
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/Vpn/Serdes.cs b/Vpn/Serdes.cs
index 317417b..00837b7 100644
--- a/Vpn/Serdes.cs
+++ b/Vpn/Serdes.cs
@@ -1,32 +1,10 @@
using System.Buffers.Binary;
using Coder.Desktop.Vpn.Proto;
+using Coder.Desktop.Vpn.Utilities;
using Google.Protobuf;
namespace Coder.Desktop.Vpn;
-///
-/// RaiiSemaphoreSlim is a wrapper around SemaphoreSlim that provides RAII-style locking.
-///
-internal class RaiiSemaphoreSlim(int initialCount, int maxCount)
-{
- private readonly SemaphoreSlim _semaphore = new(initialCount, maxCount);
-
- public async ValueTask LockAsync(CancellationToken ct = default)
- {
- await _semaphore.WaitAsync(ct);
- return new Lock(_semaphore);
- }
-
- private class Lock(SemaphoreSlim semaphore) : IDisposable
- {
- public void Dispose()
- {
- semaphore.Release();
- GC.SuppressFinalize(this);
- }
- }
-}
-
///
/// Serdes provides serialization and deserialization of messages read from a Stream.
///
diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs
index 5bccbe4..4c6ef3c 100644
--- a/Vpn/Speaker.cs
+++ b/Vpn/Speaker.cs
@@ -9,29 +9,43 @@ namespace Coder.Desktop.Vpn;
///
/// Thrown when the two peers are incompatible with each other.
///
-public class RpcVersionCompatibilityException(RpcVersionList localVersion, RpcVersionList remoteVersion)
- : Exception($"No RPC versions are compatible: local={localVersion}, remote={remoteVersion}");
+public class RpcVersionCompatibilityException : Exception
+{
+ public RpcVersionCompatibilityException(RpcVersionList localVersion, RpcVersionList remoteVersion) : base(
+ $"No RPC versions are compatible: local={localVersion}, remote={remoteVersion}")
+ {
+ }
+}
///
/// Wraps a RpcMessage to allow easily sending a reply via the Speaker.
///
-/// Speaker to use for sending reply
-/// Original received message
-public class ReplyableRpcMessage(Speaker speaker, TR message) : RpcMessage
- where TS : RpcMessage, IMessage
- where TR : RpcMessage, IMessage
, new()
+public class ReplyableRpcMessage : RpcMessage
+ where TS : RpcMessage, IRpcMessageCompatibleWith, IMessage
+ where TR : RpcMessage, IRpcMessageCompatibleWith, IMessage, new()
{
+ private readonly TR _message;
+ private readonly Speaker _speaker;
+
public override RPC? RpcField
{
- get => message.RpcField;
- set => message.RpcField = value;
+ get => _message.RpcField;
+ set => _message.RpcField = value;
}
- public override TR Message => message;
+ public override TR Message => _message;
+
+ /// Speaker to use for sending reply
+ /// Original received message
+ public ReplyableRpcMessage(Speaker speaker, TR message)
+ {
+ _speaker = speaker;
+ _message = message;
+ }
public override void Validate()
{
- message.Validate();
+ _message.Validate();
}
///
@@ -41,7 +55,7 @@ public override void Validate()
/// Optional cancellation token
public async Task SendReply(TS reply, CancellationToken ct = default)
{
- await speaker.SendReply(message, reply, ct);
+ await _speaker.SendReply(_message, reply, ct);
}
}
@@ -51,8 +65,8 @@ public async Task SendReply(TS reply, CancellationToken ct = default)
/// The message type for sent messages
/// The message type for received messages
public class Speaker : IAsyncDisposable
- where TS : RpcMessage, IMessage
- where TR : RpcMessage, IMessage
, new()
+ where TS : RpcMessage, IRpcMessageCompatibleWith, IMessage
+ where TR : RpcMessage, IRpcMessageCompatibleWith, IMessage, new()
{
public delegate void OnErrorDelegate(Exception e);
diff --git a/Vpn/Utilities/BidirectionalPipe.cs b/Vpn/Utilities/BidirectionalPipe.cs
new file mode 100644
index 0000000..72e633b
--- /dev/null
+++ b/Vpn/Utilities/BidirectionalPipe.cs
@@ -0,0 +1,101 @@
+using System.IO.Pipelines;
+
+namespace Coder.Desktop.Vpn.Utilities;
+
+///
+/// BidirectionalPipe implements Stream using a read-only Stream and a write-only Stream.
+///
+public class BidirectionalPipe : Stream
+{
+ private readonly Stream _reader;
+ private readonly Stream _writer;
+
+ public override bool CanRead => true;
+ public override bool CanSeek => false;
+ public override bool CanWrite => true;
+ public override long Length => -1;
+
+ public override long Position
+ {
+ get => -1;
+ set => throw new NotImplementedException("BidirectionalPipe does not support setting position");
+ }
+
+ /// The stream to perform reads from
+ /// The stream to write data to
+ public BidirectionalPipe(Stream reader, Stream writer)
+ {
+ _reader = reader;
+ _writer = writer;
+ }
+
+ ///
+ /// Creates a new pair of BidirectionalPipes that are connected to each other using buffered in-memory pipes.
+ ///
+ /// Two pipes connected to each other
+ public static (BidirectionalPipe, BidirectionalPipe) NewInMemory()
+ {
+ var pipe1 = new Pipe();
+ var pipe2 = new Pipe();
+ return (
+ new BidirectionalPipe(pipe1.Reader.AsStream(), pipe2.Writer.AsStream()),
+ new BidirectionalPipe(pipe2.Reader.AsStream(), pipe1.Writer.AsStream())
+ );
+ }
+
+ public override void Flush()
+ {
+ _writer.Flush();
+ }
+
+ public override int Read(byte[] buffer, int offset, int count)
+ {
+ return _reader.Read(buffer, offset, count);
+ }
+
+ public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct)
+ {
+#pragma warning disable CA1835
+ return await _reader.ReadAsync(buffer, offset, count, ct);
+#pragma warning restore CA1835
+ }
+
+ public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default)
+ {
+ return _reader.ReadAsync(buffer, cancellationToken);
+ }
+
+ public override long Seek(long offset, SeekOrigin origin)
+ {
+ throw new NotImplementedException("BidirectionalPipe does not support seeking");
+ }
+
+ public override void SetLength(long value)
+ {
+ throw new NotImplementedException("BidirectionalPipe does not support setting length");
+ }
+
+ public override void Write(byte[] buffer, int offset, int count)
+ {
+ _writer.Write(buffer, offset, count);
+ }
+
+ public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct)
+ {
+#pragma warning disable CA1835
+ await _writer.WriteAsync(buffer, offset, count, ct);
+#pragma warning restore CA1835
+ }
+
+ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default)
+ {
+ return _writer.WriteAsync(buffer, cancellationToken);
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ base.Dispose(disposing);
+ _writer.Dispose();
+ _reader.Dispose();
+ }
+}
diff --git a/Vpn/Utilities/RaiiSemaphoreSlim.cs b/Vpn/Utilities/RaiiSemaphoreSlim.cs
new file mode 100644
index 0000000..f4ecee6
--- /dev/null
+++ b/Vpn/Utilities/RaiiSemaphoreSlim.cs
@@ -0,0 +1,42 @@
+namespace Coder.Desktop.Vpn.Utilities;
+
+///
+/// RaiiSemaphoreSlim is a wrapper around SemaphoreSlim that provides RAII-style locking.
+///
+public class RaiiSemaphoreSlim : IDisposable
+{
+ private readonly SemaphoreSlim _semaphore;
+
+ public RaiiSemaphoreSlim(int initialCount, int maxCount)
+ {
+ _semaphore = new SemaphoreSlim(initialCount, maxCount);
+ }
+
+ public void Dispose()
+ {
+ _semaphore.Dispose();
+ GC.SuppressFinalize(this);
+ }
+
+ public async ValueTask LockAsync(CancellationToken ct = default)
+ {
+ await _semaphore.WaitAsync(ct);
+ return new Lock(_semaphore);
+ }
+
+ private class Lock : IDisposable
+ {
+ private readonly SemaphoreSlim _semaphore1;
+
+ public Lock(SemaphoreSlim semaphore)
+ {
+ _semaphore1 = semaphore;
+ }
+
+ public void Dispose()
+ {
+ _semaphore1.Release();
+ GC.SuppressFinalize(this);
+ }
+ }
+}
diff --git a/Vpn/Utilities/TaskUtilities.cs b/Vpn/Utilities/TaskUtilities.cs
index 8a2bfdb..4105c9e 100644
--- a/Vpn/Utilities/TaskUtilities.cs
+++ b/Vpn/Utilities/TaskUtilities.cs
@@ -1,6 +1,6 @@
namespace Coder.Desktop.Vpn.Utilities;
-internal static class TaskUtilities
+public static class TaskUtilities
{
///
/// Waits for all tasks to complete, but cancels the provided CancellationTokenSource if any task is canceled or
diff --git a/Vpn/Vpn.csproj b/Vpn/Vpn.csproj
index bcef1b5..22b585f 100644
--- a/Vpn/Vpn.csproj
+++ b/Vpn/Vpn.csproj
@@ -11,4 +11,8 @@
+
+
+
+