diff --git a/Coder.Desktop.sln b/Coder.Desktop.sln index 342963b..c5fc598 100644 --- a/Coder.Desktop.sln +++ b/Coder.Desktop.sln @@ -5,7 +5,15 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn", "Vpn\Vp EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Proto", "Vpn.Proto\Vpn.Proto.csproj", "{318E78BB-E6AD-410F-8F3F-B680F6880293}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests", "Tests\Tests.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Service", "Vpn.Service\Vpn.Service.csproj", "{51B91794-0A2A-4F84-9935-8E17DD2AB260}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn", "Tests.Vpn\Tests.Vpn.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn.Proto", "Tests.Vpn.Proto\Tests.Vpn.Proto.csproj", "{AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn.Service", "Tests.Vpn.Service\Tests.Vpn.Service.csproj", "{D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.CoderSdk", "CoderSdk\CoderSdk.csproj", "{A3D2B2B3-A051-46BD-A190-5487A9F24C28}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -21,9 +29,25 @@ Global {318E78BB-E6AD-410F-8F3F-B680F6880293}.Debug|Any CPU.Build.0 = Debug|Any CPU {318E78BB-E6AD-410F-8F3F-B680F6880293}.Release|Any CPU.ActiveCfg = Release|Any CPU {318E78BB-E6AD-410F-8F3F-B680F6880293}.Release|Any CPU.Build.0 = Release|Any CPU + {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Debug|Any CPU.Build.0 = Debug|Any CPU + {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Release|Any CPU.ActiveCfg = Release|Any CPU + {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Release|Any CPU.Build.0 = Release|Any CPU {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.Build.0 = Debug|Any CPU {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.ActiveCfg = Release|Any CPU {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.Build.0 = Release|Any CPU + {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Release|Any CPU.Build.0 = Release|Any CPU + {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Release|Any CPU.Build.0 = Release|Any CPU + {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection EndGlobal diff --git a/Coder.Desktop.sln.DotSettings b/Coder.Desktop.sln.DotSettings index 636b95d..176e490 100644 --- a/Coder.Desktop.sln.DotSettings +++ b/Coder.Desktop.sln.DotSettings @@ -253,4 +253,7 @@ </Patterns> True + True + True + True True \ No newline at end of file diff --git a/CoderSdk/CoderApiClient.cs b/CoderSdk/CoderApiClient.cs new file mode 100644 index 0000000..90343f3 --- /dev/null +++ b/CoderSdk/CoderApiClient.cs @@ -0,0 +1,81 @@ +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace CoderSdk; + +/// +/// Changes names from PascalCase to snake_case. +/// +internal class SnakeCaseNamingPolicy : JsonNamingPolicy +{ + public override string ConvertName(string name) + { + return string.Concat( + name.Select((x, i) => i > 0 && char.IsUpper(x) ? "_" + char.ToLower(x) : char.ToLower(x).ToString()) + ); + } +} + +/// +/// Provides a limited selection of API methods for a Coder instance. +/// +public partial class CoderApiClient +{ + // TODO: allow adding headers + private readonly HttpClient _httpClient = new(); + private readonly JsonSerializerOptions _jsonOptions; + + public CoderApiClient(string baseUrl) + { + var url = new Uri(baseUrl, UriKind.Absolute); + if (url.PathAndQuery != "/") + throw new ArgumentException($"Base URL '{baseUrl}' must not contain a path", nameof(baseUrl)); + _httpClient.BaseAddress = url; + _jsonOptions = new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = new SnakeCaseNamingPolicy(), + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; + } + + public CoderApiClient(string baseUrl, string token) : this(baseUrl) + { + SetSessionToken(token); + } + + public void SetSessionToken(string token) + { + _httpClient.DefaultRequestHeaders.Remove("Coder-Session-Token"); + _httpClient.DefaultRequestHeaders.Add("Coder-Session-Token", token); + } + + private async Task SendRequestAsync(HttpMethod method, string path, + object? payload, CancellationToken ct = default) + { + try + { + var request = new HttpRequestMessage(method, path); + + if (payload is not null) + { + var json = JsonSerializer.Serialize(payload, _jsonOptions); + request.Content = new StringContent(json, Encoding.UTF8, "application/json"); + } + + var res = await _httpClient.SendAsync(request, ct); + // TODO: this should be improved to try and parse a codersdk.Error response + res.EnsureSuccessStatusCode(); + + var content = await res.Content.ReadAsStringAsync(ct); + var data = JsonSerializer.Deserialize(content, _jsonOptions); + if (data is null) throw new JsonException("Deserialized response is null"); + return data; + } + catch (Exception e) + { + throw new Exception($"API Request: {method} {path} (req body: {payload is not null})", e); + } + } +} diff --git a/CoderSdk/CoderSdk.csproj b/CoderSdk/CoderSdk.csproj new file mode 100644 index 0000000..3a63532 --- /dev/null +++ b/CoderSdk/CoderSdk.csproj @@ -0,0 +1,9 @@ + + + + net8.0 + enable + enable + + + diff --git a/CoderSdk/Deployment.cs b/CoderSdk/Deployment.cs new file mode 100644 index 0000000..b00d49f --- /dev/null +++ b/CoderSdk/Deployment.cs @@ -0,0 +1,22 @@ +namespace CoderSdk; + +public class BuildInfo +{ + public string ExternalUrl { get; set; } = ""; + public string Version { get; set; } = ""; + public string DashboardUrl { get; set; } = ""; + public bool Telemetry { get; set; } = false; + public bool WorkspaceProxy { get; set; } = false; + public string AgentApiVersion { get; set; } = ""; + public string ProvisionerApiVersion { get; set; } = ""; + public string UpgradeMessage { get; set; } = ""; + public string DeploymentId { get; set; } = ""; +} + +public partial class CoderApiClient +{ + public Task GetBuildInfo(CancellationToken ct = default) + { + return SendRequestAsync(HttpMethod.Get, "/api/v2/buildinfo", null, ct); + } +} diff --git a/CoderSdk/Users.cs b/CoderSdk/Users.cs new file mode 100644 index 0000000..58ff474 --- /dev/null +++ b/CoderSdk/Users.cs @@ -0,0 +1,17 @@ +namespace CoderSdk; + +public class User +{ + public const string Me = "me"; + + // TODO: fill out more fields + public string Username { get; set; } = ""; +} + +public partial class CoderApiClient +{ + public Task GetUser(string user, CancellationToken ct = default) + { + return SendRequestAsync(HttpMethod.Get, $"/api/v2/users/{user}", null, ct); + } +} diff --git a/Tests/Vpn.Proto/RpcHeaderTest.cs b/Tests.Vpn.Proto/RpcHeaderTest.cs similarity index 85% rename from Tests/Vpn.Proto/RpcHeaderTest.cs rename to Tests.Vpn.Proto/RpcHeaderTest.cs index 8e19d0e..55edeea 100644 --- a/Tests/Vpn.Proto/RpcHeaderTest.cs +++ b/Tests.Vpn.Proto/RpcHeaderTest.cs @@ -11,14 +11,14 @@ public void Valid() { var headerStr = "codervpn manager 1.3,2.1"; var header = RpcHeader.Parse(headerStr); - Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Manager)); + Assert.That(header.Role, Is.EqualTo("manager")); Assert.That(header.VersionList, Is.EqualTo(new RpcVersionList(new RpcVersion(1, 3), new RpcVersion(2, 1)))); Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n")); Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n"))); headerStr = "codervpn tunnel 1.0"; header = RpcHeader.Parse(headerStr); - Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Tunnel)); + Assert.That(header.Role, Is.EqualTo("tunnel")); Assert.That(header.VersionList, Is.EqualTo(new RpcVersionList(new RpcVersion(1, 0)))); Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n")); Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n"))); @@ -35,7 +35,8 @@ public void ParseInvalid() Assert.That(ex.Message, Does.Contain("Wrong number of parts")); ex = Assert.Throws(() => RpcHeader.Parse("cats manager 1.0")); Assert.That(ex.Message, Does.Contain("Invalid preamble")); - ex = Assert.Throws(() => RpcHeader.Parse("codervpn cats 1.0")); - Assert.That(ex.Message, Does.Contain("Unknown role 'cats'")); + // RpcHeader doesn't care about the role string as long as it isn't empty. + ex = Assert.Throws(() => RpcHeader.Parse("codervpn 1.0")); + Assert.That(ex.Message, Does.Contain("Invalid role in header string")); } } diff --git a/Tests.Vpn.Proto/RpcMessageTest.cs b/Tests.Vpn.Proto/RpcMessageTest.cs new file mode 100644 index 0000000..e254120 --- /dev/null +++ b/Tests.Vpn.Proto/RpcMessageTest.cs @@ -0,0 +1,39 @@ +using Coder.Desktop.Vpn.Proto; + +namespace Coder.Desktop.Tests.Vpn.Proto; + +[TestFixture] +public class RpcRoleAttributeTest +{ + [Test] + public void Ok() + { + var role = new RpcRoleAttribute("manager"); + Assert.That(role.Role, Is.EqualTo("manager")); + role = new RpcRoleAttribute("tunnel"); + Assert.That(role.Role, Is.EqualTo("tunnel")); + role = new RpcRoleAttribute("service"); + Assert.That(role.Role, Is.EqualTo("service")); + role = new RpcRoleAttribute("client"); + Assert.That(role.Role, Is.EqualTo("client")); + } +} + +[TestFixture] +public class RpcMessageTest +{ + [Test] + public void GetRole() + { + // RpcMessage is not a supported message type and doesn't have an + // RpcRoleAttribute + var ex = Assert.Throws(() => _ = RpcMessage.GetRole()); + Assert.That(ex.Message, + Does.Contain("Message type 'Coder.Desktop.Vpn.Proto.RPC' does not have a RpcRoleAttribute")); + + Assert.That(ManagerMessage.GetRole(), Is.EqualTo("manager")); + Assert.That(TunnelMessage.GetRole(), Is.EqualTo("tunnel")); + Assert.That(ServiceMessage.GetRole(), Is.EqualTo("service")); + Assert.That(ClientMessage.GetRole(), Is.EqualTo("client")); + } +} diff --git a/Tests/Vpn.Proto/RpcVersionTest.cs b/Tests.Vpn.Proto/RpcVersionTest.cs similarity index 100% rename from Tests/Vpn.Proto/RpcVersionTest.cs rename to Tests.Vpn.Proto/RpcVersionTest.cs diff --git a/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj b/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj new file mode 100644 index 0000000..54b7b33 --- /dev/null +++ b/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj @@ -0,0 +1,35 @@ + + + + Coder.Desktop.Tests.Vpn.Proto + net8.0 + enable + enable + + false + true + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs new file mode 100644 index 0000000..ae3a0a0 --- /dev/null +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -0,0 +1,362 @@ +using System.Security.Cryptography; +using System.Text; +using Coder.Desktop.Vpn.Service; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Coder.Desktop.Tests.Vpn.Service; + +public class TestDownloadValidator : IDownloadValidator +{ + private readonly Exception _e; + + public TestDownloadValidator(Exception e) + { + _e = e; + } + + public Task ValidateAsync(string path, CancellationToken ct = default) + { + throw _e; + } +} + +[TestFixture] +public class AuthenticodeDownloadValidatorTest +{ + [Test(Description = "Test an unsigned binary")] + [CancelAfter(30_000)] + public void Unsigned(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Test an untrusted binary")] + [CancelAfter(30_000)] + public void Untrusted(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Test an binary with a detached signature (catalog file)")] + [CancelAfter(30_000)] + public void DifferentCertTrusted(CancellationToken ct) + { + // notepad.exe uses a catalog file for its signature. + var ex = Assert.ThrowsAsync(() => + AuthenticodeDownloadValidator.Coder.ValidateAsync(@"C:\Windows\System32\notepad.exe", ct)); + Assert.That(ex.Message, + Does.Contain("File is not signed with an embedded Authenticode signature: Kind=Catalog")); + } + + [Test(Description = "Test a binary signed by a different certificate")] + [CancelAfter(30_000)] + public void DifferentCertUntrusted(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Test a binary signed by Coder's certificate")] + [CancelAfter(30_000)] + public async Task CoderSigned(CancellationToken ct) + { + // TODO: this + await Task.CompletedTask; + } +} + +[TestFixture] +public class AssemblyVersionDownloadValidatorTest +{ + [Test(Description = "No version on binary")] + [CancelAfter(30_000)] + public void NoVersion(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Version mismatch")] + [CancelAfter(30_000)] + public void VersionMismatch(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Version match")] + [CancelAfter(30_000)] + public async Task VersionMatch(CancellationToken ct) + { + // TODO: this + await Task.CompletedTask; + } +} + +[TestFixture] +public class CombinationDownloadValidatorTest +{ + [Test(Description = "All validators pass")] + [CancelAfter(30_000)] + public async Task AllPass(CancellationToken ct) + { + var validator = new CombinationDownloadValidator( + NullDownloadValidator.Instance, + NullDownloadValidator.Instance + ); + await validator.ValidateAsync("test", ct); + } + + [Test(Description = "A validator fails")] + [CancelAfter(30_000)] + public void Fail(CancellationToken ct) + { + var validator = new CombinationDownloadValidator( + NullDownloadValidator.Instance, + new TestDownloadValidator(new Exception("test exception")) + ); + var ex = Assert.ThrowsAsync(() => validator.ValidateAsync("test", ct)); + Assert.That(ex.Message, Is.EqualTo("test exception")); + } +} + +[TestFixture] +public class DownloaderTest +{ + // FYI, SetUp and TearDown get called before and after each test. + [SetUp] + public void Setup() + { + _tempDir = Path.Combine(Path.GetTempPath(), "Coder.Desktop.Tests.Vpn.Service_" + Path.GetRandomFileName()); + Directory.CreateDirectory(_tempDir); + } + + [TearDown] + public void TearDown() + { + Directory.Delete(_tempDir, true); + } + + private string _tempDir; + + private static TestHttpServer EchoServer() + { + // Create webserver that replies to `/xyz` with a test file containing + // `xyz`. + return new TestHttpServer(async ctx => + { + // Get the path without the leading slash. + var path = ctx.Request.Url!.AbsolutePath[1..]; + var pathBytes = Encoding.UTF8.GetBytes(path); + + // If the client sends an If-None-Match header with the correct ETag, + // return 304 Not Modified. + var etag = "\"" + Convert.ToHexString(SHA1.HashData(pathBytes)).ToLower() + "\""; + if (ctx.Request.Headers["If-None-Match"] == etag) + { + ctx.Response.StatusCode = 304; + return; + } + + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("ETag", etag); + ctx.Response.ContentType = "text/plain"; + ctx.Response.ContentLength64 = pathBytes.Length; + await ctx.Response.OutputStream.WriteAsync(pathBytes); + }); + } + + [Test(Description = "Perform a download")] + [CancelAfter(30_000)] + public async Task Download(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + await dlTask.Task; + Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); + Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(dlTask.Progress, Is.EqualTo(1)); + Assert.That(dlTask.IsCompleted, Is.True); + Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); + } + + [Test(Description = "Download with custom headers")] + [CancelAfter(30_000)] + public async Task WithHeaders(CancellationToken ct) + { + using var httpServer = new TestHttpServer(ctx => + { + Assert.That(ctx.Request.Headers["X-Custom-Header"], Is.EqualTo("custom-value")); + ctx.Response.StatusCode = 200; + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + var req = new HttpRequestMessage(HttpMethod.Get, url); + req.Headers.Add("X-Custom-Header", "custom-value"); + var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); + await dlTask.Task; + } + + [Test(Description = "Perform a download against an existing identical file")] + [CancelAfter(30_000)] + public async Task DownloadExisting(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + // Create the destination file with a very old timestamp. + await File.WriteAllTextAsync(destPath, "test", ct); + File.SetLastWriteTime(destPath, DateTime.Now - TimeSpan.FromDays(365)); + + var manager = new Downloader(NullLogger.Instance); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + await dlTask.Task; + Assert.That(dlTask.BytesRead, Is.Zero); + Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); + Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1))); + } + + [Test(Description = "Perform a download against an existing file with different content")] + [CancelAfter(30_000)] + public async Task DownloadExistingDifferentContent(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + // Create the destination file with a very old timestamp. + await File.WriteAllTextAsync(destPath, "TEST", ct); + File.SetLastWriteTime(destPath, DateTime.Now - TimeSpan.FromDays(365)); + + var manager = new Downloader(NullLogger.Instance); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + await dlTask.Task; + Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); + Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1))); + } + + [Test(Description = "Unexpected response code from server")] + [CancelAfter(30_000)] + public void UnexpectedResponseCode(CancellationToken ct) + { + using var httpServer = new TestHttpServer(ctx => { ctx.Response.StatusCode = 404; }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + // The "outer" Task should fail. + var ex = Assert.ThrowsAsync(async () => + await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct)); + Assert.That(ex.Message, Does.Contain("404")); + } + + // TODO: It would be nice to have a test that tests mismatched + // Content-Length, but it seems HttpListener doesn't allow that. + + [Test(Description = "Mismatched ETag")] + [CancelAfter(30_000)] + public async Task MismatchedETag(CancellationToken ct) + { + using var httpServer = new TestHttpServer(ctx => + { + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("ETag", "\"beef\""); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + // The "inner" Task should fail. + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + var ex = Assert.ThrowsAsync(async () => await dlTask.Task); + Assert.That(ex.Message, Does.Contain("ETag does not match SHA1 hash of downloaded file").And.Contains("beef")); + } + + [Test(Description = "Timeout on response headers")] + [CancelAfter(30_000)] + public void CancelledOuter(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async _ => { await Task.Delay(TimeSpan.FromSeconds(5), ct); }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + // The "outer" Task should fail. + var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token; + Assert.ThrowsAsync( + async () => await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, smallerCt)); + } + + [Test(Description = "Timeout on response body")] + [CancelAfter(30_000)] + public async Task CancelledInner(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async ctx => + { + ctx.Response.StatusCode = 200; + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); + await ctx.Response.OutputStream.FlushAsync(ct); + await Task.Delay(TimeSpan.FromSeconds(5), ct); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + // The "inner" Task should fail. + var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token; + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, smallerCt); + var ex = Assert.ThrowsAsync(async () => await dlTask.Task); + Assert.That(ex.CancellationToken, Is.EqualTo(smallerCt)); + } + + [Test(Description = "Validation failure")] + [CancelAfter(30_000)] + public async Task ValidationFailure(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + new TestDownloadValidator(new Exception("test exception")), ct); + + var ex = Assert.ThrowsAsync(async () => await dlTask.Task); + Assert.That(ex.Message, Does.Contain("Downloaded file failed validation")); + Assert.That(ex.InnerException, Is.Not.Null); + Assert.That(ex.InnerException!.Message, Is.EqualTo("test exception")); + } + + [Test(Description = "Validation failure on existing file")] + [CancelAfter(30_000)] + public async Task ValidationFailureExistingFile(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + await File.WriteAllTextAsync(destPath, "test", ct); + + var manager = new Downloader(NullLogger.Instance); + // The "outer" Task should fail because the inner task never starts. + var ex = Assert.ThrowsAsync(async () => + { + await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + new TestDownloadValidator(new Exception("test exception")), ct); + }); + Assert.That(ex.Message, Does.Contain("Existing file failed validation")); + Assert.That(ex.InnerException, Is.Not.Null); + Assert.That(ex.InnerException!.Message, Is.EqualTo("test exception")); + } +} diff --git a/Tests.Vpn.Service/TestHttpServer.cs b/Tests.Vpn.Service/TestHttpServer.cs new file mode 100644 index 0000000..d33697f --- /dev/null +++ b/Tests.Vpn.Service/TestHttpServer.cs @@ -0,0 +1,106 @@ +using System.Net; +using System.Text; + +namespace Coder.Desktop.Tests.Vpn.Service; + +public class TestHttpServer : IDisposable +{ + // IANA suggested range for dynamic or private ports + private const int MinPort = 49215; + private const int MaxPort = 65535; + private const int PortRangeSize = MaxPort - MinPort + 1; + + private readonly CancellationTokenSource _cts = new(); + private readonly Func _handler; + private readonly HttpListener _listener; + private readonly Thread _listenerThread; + + public string BaseUrl { get; private set; } + + public TestHttpServer(Action handler) : this(ctx => + { + handler(ctx); + return Task.CompletedTask; + }) + { + } + + public TestHttpServer(Func handler) + { + _handler = handler; + + // Yes, this is the best way to get an unused port using HttpListener. + // It sucks. + // + // This implementation picks a random start point between MinPort and + // MaxPort, then iterates through the entire range (wrapping around at + // the end) until it finds a free port. + var port = 0; + var random = new Random(); + var startPort = random.Next(MinPort, MaxPort + 1); + for (var i = 0; i < PortRangeSize; i++) + { + port = MinPort + (startPort - MinPort + i) % PortRangeSize; + + var attempt = new HttpListener(); + attempt.Prefixes.Add($"http://localhost:{port}/"); + try + { + attempt.Start(); + _listener = attempt; + break; + } + catch + { + // Listener disposes itself on failure + } + } + + if (_listener == null || port == 0) + throw new InvalidOperationException("Could not find a free port to listen on"); + BaseUrl = $"http://localhost:{port}"; + + _listenerThread = new Thread(() => + { + while (!_cts.Token.IsCancellationRequested) + try + { + var context = _listener.GetContext(); + Task.Run(() => HandleRequest(context)); + } + catch (HttpListenerException) when (_cts.Token.IsCancellationRequested) + { + break; + } + }); + + _listenerThread.Start(); + } + + public void Dispose() + { + _cts.Cancel(); + _listener.Stop(); + _listenerThread.Join(); + GC.SuppressFinalize(this); + } + + private async Task HandleRequest(HttpListenerContext context) + { + try + { + await _handler(context); + } + catch (Exception e) + { + await Console.Error.WriteLineAsync($"Exception while serving HTTP request: {e}"); + context.Response.StatusCode = 500; + var response = Encoding.UTF8.GetBytes($"Internal Server Error: {e.Message}"); + await context.Response.OutputStream.WriteAsync(response); + } + finally + { + context.Response.Close(); + } + } +} diff --git a/Tests.Vpn.Service/Tests.Vpn.Service.csproj b/Tests.Vpn.Service/Tests.Vpn.Service.csproj new file mode 100644 index 0000000..2fdfa76 --- /dev/null +++ b/Tests.Vpn.Service/Tests.Vpn.Service.csproj @@ -0,0 +1,35 @@ + + + + Coder.Desktop.Tests.Vpn.Service + net8.0-windows + enable + enable + + false + true + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + diff --git a/Tests/Vpn/SerdesTest.cs b/Tests.Vpn/SerdesTest.cs similarity index 65% rename from Tests/Vpn/SerdesTest.cs rename to Tests.Vpn/SerdesTest.cs index 7673d6a..3266f14 100644 --- a/Tests/Vpn/SerdesTest.cs +++ b/Tests.Vpn/SerdesTest.cs @@ -1,6 +1,7 @@ using System.Buffers.Binary; using Coder.Desktop.Vpn; using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; using Google.Protobuf; namespace Coder.Desktop.Tests.Vpn; @@ -9,26 +10,26 @@ namespace Coder.Desktop.Tests.Vpn; public class SerdesTest { [Test(Description = "Tests that writing and reading a message works")] - [Timeout(5_000)] - public async Task WriteReadMessage() + [CancelAfter(30_000)] + public async Task WriteReadMessage(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); var msg = new ManagerMessage { Start = new StartRequest(), }; - await serdes.WriteMessage(stream1, msg); - var got = await serdes.ReadMessage(stream2); + await serdes.WriteMessage(stream1, msg, ct); + var got = await serdes.ReadMessage(stream2, ct); Assert.That(msg, Is.EqualTo(got)); } [Test(Description = "Tests that writing a message larger than 16 MiB throws an exception")] - [Timeout(5_000)] - public void WriteMessageTooLarge() + [CancelAfter(30_000)] + public void WriteMessageTooLarge(CancellationToken ct) { - var (stream1, _) = BidirectionalPipe.New(); + var (stream1, _) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); var msg = new ManagerMessage @@ -39,51 +40,51 @@ public void WriteMessageTooLarge() CoderUrl = "test", }, }; - Assert.ThrowsAsync(() => serdes.WriteMessage(stream1, msg)); + Assert.ThrowsAsync(() => serdes.WriteMessage(stream1, msg, ct)); } [Test(Description = "Tests that attempting to read a message larger than 16 MiB throws an exception")] - [Timeout(5_000)] - public async Task ReadMessageTooLarge() + [CancelAfter(30_000)] + public async Task ReadMessageTooLarge(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); // In this test we don't actually write a message as the parser should // bail out immediately after reading the message length var lenBytes = new byte[4]; BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0x1000001); - await stream1.WriteAsync(lenBytes); - Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + await stream1.WriteAsync(lenBytes, ct); + Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct)); } [Test(Description = "Read an empty (size 0) message from the stream")] - [Timeout(5_000)] - public async Task ReadEmptyMessage() + [CancelAfter(30_000)] + public async Task ReadEmptyMessage(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); // Write an empty message. var lenBytes = new byte[4]; BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0); - await stream1.WriteAsync(lenBytes); - var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + await stream1.WriteAsync(lenBytes, ct); + var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct)); Assert.That(ex.Message, Does.Contain("Received message size 0")); } [Test(Description = "Read an invalid/corrupt message from the stream")] - [Timeout(5_000)] - public async Task ReadInvalidMessage() + [CancelAfter(30_000)] + public async Task ReadInvalidMessage(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); var lenBytes = new byte[4]; BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 1); - await stream1.WriteAsync(lenBytes); - await stream1.WriteAsync(new byte[1]); - var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + await stream1.WriteAsync(lenBytes, ct); + await stream1.WriteAsync(new byte[1], ct); + var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct)); Assert.That(ex.InnerException, Is.TypeOf(typeof(InvalidProtocolBufferException))); } } diff --git a/Tests/Vpn/SpeakerTest.cs b/Tests.Vpn/SpeakerTest.cs similarity index 74% rename from Tests/Vpn/SpeakerTest.cs rename to Tests.Vpn/SpeakerTest.cs index f06c62f..51950f7 100644 --- a/Tests/Vpn/SpeakerTest.cs +++ b/Tests.Vpn/SpeakerTest.cs @@ -1,84 +1,12 @@ -using System.Buffers; -using System.IO.Pipelines; using System.Reflection; using System.Text; using System.Threading.Channels; using Coder.Desktop.Vpn; using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; namespace Coder.Desktop.Tests.Vpn; -#region BidrectionalPipe - -internal class BidirectionalPipe(PipeReader reader, PipeWriter writer) : Stream -{ - public override bool CanRead => true; - public override bool CanSeek => false; - public override bool CanWrite => true; - public override long Length => -1; - - public override long Position - { - get => -1; - set => throw new NotImplementedException("BidirectionalPipe does not support setting position"); - } - - public static (BidirectionalPipe, BidirectionalPipe) New() - { - var pipe1 = new Pipe(); - var pipe2 = new Pipe(); - return (new BidirectionalPipe(pipe1.Reader, pipe2.Writer), new BidirectionalPipe(pipe2.Reader, pipe1.Writer)); - } - - public override void Flush() - { - } - - public override int Read(byte[] buffer, int offset, int count) - { - return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) - { - var result = await reader.ReadAtLeastAsync(1, ct); - var n = Math.Min((int)result.Buffer.Length, count); - // Copy result.Buffer[0:n] to buffer[offset:offset+n] - result.Buffer.Slice(0, n).CopyTo(buffer.AsMemory(offset, n).Span); - if (!result.IsCompleted) reader.AdvanceTo(result.Buffer.GetPosition(n)); - return n; - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotImplementedException("BidirectionalPipe does not support seeking"); - } - - public override void SetLength(long value) - { - throw new NotImplementedException("BidirectionalPipe does not support setting length"); - } - - public override void Write(byte[] buffer, int offset, int count) - { - WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) - { - await writer.WriteAsync(buffer.AsMemory(offset, count), ct); - } - - protected override void Dispose(bool disposing) - { - base.Dispose(disposing); - writer.Complete(); - reader.Complete(); - } -} - -#endregion - #region FailableStream internal class FailableStream : Stream @@ -88,13 +16,6 @@ internal class FailableStream : Stream private readonly TaskCompletionSource _writeTcs = new(); - public FailableStream(Stream inner, Exception? writeException, Exception? readException) - { - _inner = inner; - if (writeException != null) _writeTcs.SetException(writeException); - if (readException != null) _readTcs.SetException(readException); - } - public override bool CanRead => _inner.CanRead; public override bool CanSeek => _inner.CanSeek; public override bool CanWrite => _inner.CanWrite; @@ -106,6 +27,13 @@ public override long Position set => _inner.Position = value; } + public FailableStream(Stream inner, Exception? writeException, Exception? readException) + { + _inner = inner; + if (writeException != null) _writeTcs.SetException(writeException); + if (readException != null) _readTcs.SetException(readException); + } + public void SetWriteException(Exception ex) { _writeTcs.SetException(ex); @@ -172,10 +100,10 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, public class SpeakerTest { [Test(Description = "Send a message from speaker1 to speaker2, receive it, and send a reply back")] - [Timeout(30_000)] - public async Task SendReceiveReplyReceive() + [CancelAfter(30_000)] + public async Task SendReceiveReplyReceive(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); await using var speaker1 = new Speaker(stream1); var speaker1Ch = Channel @@ -190,14 +118,14 @@ public async Task SendReceiveReplyReceive() speaker2.Error += ex => { Assert.Fail($"speaker2 error: {ex}"); }; // Start both speakers simultaneously - Task.WaitAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); // Send a normal message from speaker2 to speaker1 await speaker2.SendMessage(new TunnelMessage { PeerUpdate = new PeerUpdate(), - }); - var receivedMessage = await speaker1Ch.Reader.ReadAsync(); + }, ct); + var receivedMessage = await speaker1Ch.Reader.ReadAsync(ct); Assert.That(receivedMessage.RpcField, Is.Null); // not a request Assert.That(receivedMessage.Message.PeerUpdate, Is.Not.Null); @@ -209,10 +137,10 @@ await speaker2.SendMessage(new TunnelMessage ApiToken = "test", CoderUrl = "test", }, - }); + }, ct); // Receive the message in speaker2 - var message = await speaker2Ch.Reader.ReadAsync(); + var message = await speaker2Ch.Reader.ReadAsync(ct); Assert.That(message.RpcField, Is.Not.Null); Assert.That(message.RpcField!.MsgId, Is.Not.EqualTo(0)); Assert.That(message.RpcField!.ResponseTo, Is.EqualTo(0)); @@ -225,7 +153,7 @@ await message.SendReply(new TunnelMessage { Success = true, }, - }); + }, ct); // Receive the reply in speaker1 by awaiting sendTask var reply = await sendTask; @@ -236,57 +164,58 @@ await message.SendReply(new TunnelMessage } [Test(Description = "Encounter a write error during handshake")] - [Timeout(30_000)] - public async Task WriteError() + [CancelAfter(30_000)] + public async Task WriteError(CancellationToken ct) { - var (stream1, _) = BidirectionalPipe.New(); + var (stream1, _) = BidirectionalPipe.NewInMemory(); var writeEx = new IOException("Test write error"); var failStream = new FailableStream(stream1, writeEx, null); await using var speaker = new Speaker(failStream); - var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync()); + var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync(ct)); Assert.That(gotEx, Is.EqualTo(writeEx)); } [Test(Description = "Encounter a read error during handshake")] - [Timeout(30_000)] - public async Task ReadError() + [CancelAfter(30_000)] + public async Task ReadError(CancellationToken ct) { - var (stream1, _) = BidirectionalPipe.New(); + var (stream1, _) = BidirectionalPipe.NewInMemory(); var readEx = new IOException("Test read error"); var failStream = new FailableStream(stream1, null, readEx); await using var speaker = new Speaker(failStream); - var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync()); + var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync(ct)); Assert.That(gotEx, Is.EqualTo(readEx)); } [Test(Description = "Receive a header that exceeds 256 bytes")] - [Timeout(30_000)] - public async Task ReadLargeHeader() + [CancelAfter(30_000)] + public async Task ReadLargeHeader(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); await using var speaker1 = new Speaker(stream1); var header = new byte[257]; for (var i = 0; i < header.Length; i++) header[i] = (byte)'a'; - await stream2.WriteAsync(header); + await stream2.WriteAsync(header, ct); - var gotEx = Assert.ThrowsAsync(() => speaker1.StartAsync()); + var gotEx = Assert.ThrowsAsync(() => speaker1.StartAsync(ct)); Assert.That(gotEx.Message, Does.Contain("Header malformed or too large")); } [Test(Description = "Receive an invalid header")] - [Timeout(30_000)] - public async Task ReceiveInvalidHeader() + [CancelAfter(30_000)] + public async Task ReceiveInvalidHeader(CancellationToken ct) { var cases = new Dictionary { { "invalid\n", ("Failed to parse peer header", "Wrong number of parts in header string") }, { "cats tunnel 1.0\n", ("Failed to parse peer header", "Invalid preamble in header string") }, - { "codervpn cats 1.0\n", ("Failed to parse peer header", "Unknown role 'cats'") }, + { "codervpn 1.0\n", ("Failed to parse peer header", "Invalid role in header string") }, + { "codervpn cats 1.0\n", ("Expected peer role 'tunnel' but got 'cats'", null) }, { "codervpn manager 1.0\n", ("Expected peer role 'tunnel' but got 'manager'", null) }, { "codervpn tunnel 1000.1\n", @@ -299,12 +228,12 @@ public async Task ReceiveInvalidHeader() foreach (var (header, (expectedOuter, expectedInner)) in cases) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); await using var speaker1 = new Speaker(stream1); - await stream2.WriteAsync(Encoding.UTF8.GetBytes(header)); + await stream2.WriteAsync(Encoding.UTF8.GetBytes(header), ct); - var gotEx = Assert.CatchAsync(() => speaker1.StartAsync(), $"header: '{header}'"); + var gotEx = Assert.CatchAsync(() => speaker1.StartAsync(ct), $"header: '{header}'"); Assert.That(gotEx.Message, Does.Contain(expectedOuter), $"header: '{header}'"); if (expectedInner is null) { @@ -318,10 +247,10 @@ public async Task ReceiveInvalidHeader() } [Test(Description = "Encounter a write error during message send")] - [Timeout(30_000)] - public async Task SendMessageWriteError() + [CancelAfter(30_000)] + public async Task SendMessageWriteError(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var failStream = new FailableStream(stream1, null, null); await using var speaker1 = new Speaker(failStream); @@ -330,7 +259,7 @@ public async Task SendMessageWriteError() await using var speaker2 = new Speaker(stream2); speaker2.Receive += msg => Assert.Fail($"speaker2 received message: {msg}"); speaker2.Error += ex => Assert.Fail($"speaker2 error: {ex}"); - await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); var writeEx = new IOException("Test write error"); failStream.SetWriteException(writeEx); @@ -338,15 +267,15 @@ public async Task SendMessageWriteError() var gotEx = Assert.ThrowsAsync(() => speaker1.SendMessage(new ManagerMessage { Start = new StartRequest(), - })); + }, ct)); Assert.That(gotEx, Is.EqualTo(writeEx)); } [Test(Description = "Encounter a read error during message receive")] - [Timeout(30_000)] - public async Task ReceiveMessageReadError() + [CancelAfter(30_000)] + public async Task ReceiveMessageReadError(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var failStream = new FailableStream(stream1, null, null); // Speaker1 is bound to failStream and will write an error to errorCh @@ -359,13 +288,13 @@ public async Task ReceiveMessageReadError() await using var speaker2 = new Speaker(stream2); speaker2.Receive += msg => Assert.Fail($"speaker2 received message: {msg}"); speaker2.Error += ex => Assert.Fail($"speaker2 error: {ex}"); - await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); // Now the handshake is complete, cause all reads to fail var readEx = new IOException("Test write error"); failStream.SetReadException(readEx); - var gotEx = await errorCh.Reader.ReadAsync(); + var gotEx = await errorCh.Reader.ReadAsync(ct); Assert.That(gotEx, Is.EqualTo(readEx)); // The receive loop should be stopped within a timely fashion. @@ -377,24 +306,24 @@ public async Task ReceiveMessageReadError() } else { - var delayTask = Task.Delay(TimeSpan.FromSeconds(5)); + var delayTask = Task.Delay(TimeSpan.FromSeconds(5), ct); await Task.WhenAny(receiveLoopTask, delayTask); Assert.That(receiveLoopTask.IsCompleted, Is.True); } } [Test(Description = "Handle dispose while receive loop is running")] - [Timeout(30_000)] - public async Task DisposeWhileReceiveLoopRunning() + [CancelAfter(30_000)] + public async Task DisposeWhileReceiveLoopRunning(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var speaker1 = new Speaker(stream1); await using var speaker2 = new Speaker(stream2); - await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); // Dispose should happen in a timely fashion var disposeTask = speaker1.DisposeAsync(); - var delayTask = Task.Delay(TimeSpan.FromSeconds(5)); + var delayTask = Task.Delay(TimeSpan.FromSeconds(5), ct); await Task.WhenAny(disposeTask.AsTask(), delayTask); Assert.That(disposeTask.IsCompleted, Is.True); @@ -408,19 +337,19 @@ public async Task DisposeWhileReceiveLoopRunning() } [Test(Description = "Handle dispose while a message is awaiting a reply")] - [Timeout(30_000)] - public async Task DisposeWhileAwaitingReply() + [CancelAfter(30_000)] + public async Task DisposeWhileAwaitingReply(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var speaker1 = new Speaker(stream1); await using var speaker2 = new Speaker(stream2); - await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); // Send a message from speaker1 to speaker2 var sendTask = speaker1.SendRequestAwaitReply(new ManagerMessage { Start = new StartRequest(), - }); + }, ct); // Dispose speaker1 await speaker1.DisposeAsync(); diff --git a/Tests/Tests.csproj b/Tests.Vpn/Tests.Vpn.csproj similarity index 54% rename from Tests/Tests.csproj rename to Tests.Vpn/Tests.Vpn.csproj index cccd5dc..df00e81 100644 --- a/Tests/Tests.csproj +++ b/Tests.Vpn/Tests.Vpn.csproj @@ -1,7 +1,7 @@ - Coder.Desktop.Tests + Coder.Desktop.Tests.Vpn net8.0 enable enable @@ -11,12 +11,17 @@ - - - - - - + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/Tests.Vpn/Utilities/TaskUtilitiesTest.cs b/Tests.Vpn/Utilities/TaskUtilitiesTest.cs new file mode 100644 index 0000000..a6a4583 --- /dev/null +++ b/Tests.Vpn/Utilities/TaskUtilitiesTest.cs @@ -0,0 +1,141 @@ +using Coder.Desktop.Vpn.Utilities; + +namespace Coder.Desktop.Tests.Vpn.Utilities; + +[TestFixture] +public class TaskUtilitiesTest +{ + [Test(Description = "CancellableWhenAll with no tasks should complete immediately")] + [Timeout(30_000)] + public void CancellableWhenAll_NoTasks() + { + var task = TaskUtilities.CancellableWhenAll(new CancellationTokenSource()); + Assert.That(task.IsCompleted, Is.True); + } + + [Test(Description = "CancellableWhenAll with a single task should complete")] + [Timeout(30_000)] + public async Task CancellableWhenAll_SingleTask() + { + var innerTask = new TaskCompletionSource(); + var task = TaskUtilities.CancellableWhenAll(new CancellationTokenSource(), innerTask.Task); + Assert.That(task.IsCompleted, Is.False); + innerTask.SetResult(); + await task; + } + + [Test(Description = "CancellableWhenAll with a single task that faults should propagate the exception")] + [Timeout(30_000)] + public void CancellableWhenAll_SingleTaskFault() + { + var cts = new CancellationTokenSource(); + var innerTask = new TaskCompletionSource(); + var task = TaskUtilities.CancellableWhenAll(cts, innerTask.Task); + Assert.That(task.IsCompleted, Is.False); + innerTask.SetException(new InvalidOperationException("Test")); + Assert.ThrowsAsync(async () => await task); + Assert.That(cts.IsCancellationRequested, Is.True); + } + + [Test(Description = "CancellableWhenAll with a single task that is canceled should propagate the cancellation")] + [Timeout(30_000)] + public void CancellableWhenAll_SingleTaskCanceled() + { + var cts = new CancellationTokenSource(); + var innerTask = new TaskCompletionSource(); + var task = TaskUtilities.CancellableWhenAll(cts, innerTask.Task); + Assert.That(task.IsCompleted, Is.False); + innerTask.SetCanceled(); + Assert.ThrowsAsync(async () => await task); + Assert.That(cts.IsCancellationRequested, Is.True); + } + + [Test(Description = "CancellableWhenAll with multiple tasks should complete when all tasks are completed")] + [Timeout(30_000)] + public async Task CancellableWhenAll_MultipleTasks() + { + var cts = new CancellationTokenSource(); + var innerTask1 = new TaskCompletionSource(); + var innerTask2 = new TaskCompletionSource(); + + var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task); + Assert.That(task.IsCompleted, Is.False); + // This dance of awaiting a newly added continuation task before + // completing the TCS is to ensure that the original continuation task + // finished since it's inlinable. + var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { }); + innerTask1.SetResult(); + await task1ContinueTask; + Assert.That(task.IsCompleted, Is.False); + var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { }); + innerTask2.SetResult(); + await task2ContinueTask; + await task; + } + + [Test(Description = "CancellableWhenAll with multiple tasks that fault should propagate the first exception only")] + [Timeout(30_000)] + public async Task CancellableWhenAll_MultipleTasksFault() + { + var cts = new CancellationTokenSource(); + var innerTask1 = new TaskCompletionSource(); + var innerTask2 = new TaskCompletionSource(); + + var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task); + Assert.That(task.IsCompleted, Is.False); + var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { }); + innerTask1.SetException(new Exception("Test1")); + await task1ContinueTask; + Assert.That(task.IsCompleted, Is.False); + var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { }); + innerTask2.SetException(new Exception("Test2")); + await task2ContinueTask; + var ex = Assert.ThrowsAsync(async () => await task); + Assert.That(ex.Message, Is.EqualTo("Test1")); + } + + [Test(Description = "CancellableWhenAll with an exception and a cancellation should propagate the first thing")] + [Timeout(30_000)] + public async Task CancellableWhenAll_MultipleTasksFaultAndCanceled() + { + var cts = new CancellationTokenSource(); + var innerTask1 = new TaskCompletionSource(); + var innerTask2 = new TaskCompletionSource(); + var innertask3 = Task.CompletedTask; + + var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task, innertask3); + Assert.That(task.IsCompleted, Is.False); + var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { }); + innerTask1.SetException(new Exception("Test1")); + await task1ContinueTask; + Assert.That(task.IsCompleted, Is.False); + Assert.That(cts.IsCancellationRequested, Is.True); + var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { }); + innerTask2.SetCanceled(); + await task2ContinueTask; + var ex = Assert.ThrowsAsync(async () => await task); + Assert.That(ex.Message, Is.EqualTo("Test1")); + } + + [Test(Description = "CancellableWhenAll with a cancellation and an exception should propagate the first thing")] + [Timeout(30_000)] + public async Task CancellableWhenAll_MultipleTasksCanceledAndFault() + { + var cts = new CancellationTokenSource(); + var innerTask1 = new TaskCompletionSource(); + var innerTask2 = new TaskCompletionSource(); + var innertask3 = Task.CompletedTask; + + var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task, innertask3); + Assert.That(task.IsCompleted, Is.False); + var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { }); + innerTask1.SetCanceled(); + await task1ContinueTask; + Assert.That(task.IsCompleted, Is.False); + Assert.That(cts.IsCancellationRequested, Is.True); + var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { }); + innerTask2.SetException(new Exception("Test2")); + await task2ContinueTask; + Assert.ThrowsAsync(async () => await task); + } +} diff --git a/Tests/Vpn.Proto/RpcMessageTest.cs b/Tests/Vpn.Proto/RpcMessageTest.cs deleted file mode 100644 index 36de12d..0000000 --- a/Tests/Vpn.Proto/RpcMessageTest.cs +++ /dev/null @@ -1,39 +0,0 @@ -using Coder.Desktop.Vpn.Proto; - -namespace Coder.Desktop.Tests.Vpn.Proto; - -[TestFixture] -public class RpcRoleAttributeTest -{ - [Test] - public void Valid() - { - var role = new RpcRoleAttribute(RpcRole.Manager); - Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Manager)); - role = new RpcRoleAttribute(RpcRole.Tunnel); - Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Tunnel)); - } - - [Test] - public void Invalid() - { - Assert.Throws(() => _ = new RpcRoleAttribute("cats")); - } -} - -[TestFixture] -public class RpcMessageTest -{ - [Test] - public void GetRole() - { - // RpcMessage is not a supported message type and doesn't have an - // RpcRoleAttribute - var ex = Assert.Throws(() => _ = RpcMessage.GetRole()); - Assert.That(ex.Message, - Does.Contain("Message type 'Coder.Desktop.Vpn.Proto.RPC' does not have a RpcRoleAttribute")); - - Assert.That(ManagerMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Manager)); - Assert.That(TunnelMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Tunnel)); - } -} diff --git a/Tests/Vpn.Proto/RpcRoleTest.cs b/Tests/Vpn.Proto/RpcRoleTest.cs deleted file mode 100644 index f39d5cb..0000000 --- a/Tests/Vpn.Proto/RpcRoleTest.cs +++ /dev/null @@ -1,22 +0,0 @@ -using Coder.Desktop.Vpn.Proto; - -namespace Coder.Desktop.Tests.Vpn.Proto; - -[TestFixture] -public class RpcRoleTest -{ - [Test(Description = "Instantiate a RpcRole with a valid name")] - public void ValidRole() - { - var role = new RpcRole(RpcRole.Manager); - Assert.That(role.ToString(), Is.EqualTo(RpcRole.Manager)); - role = new RpcRole(RpcRole.Tunnel); - Assert.That(role.ToString(), Is.EqualTo(RpcRole.Tunnel)); - } - - [Test(Description = "Try to instantiate a RpcRole with an invalid name")] - public void InvalidRole() - { - Assert.Throws(() => _ = new RpcRole("cats")); - } -} diff --git a/Vpn.Proto/RpcHeader.cs b/Vpn.Proto/RpcHeader.cs index 0b840db..cf7ffcc 100644 --- a/Vpn.Proto/RpcHeader.cs +++ b/Vpn.Proto/RpcHeader.cs @@ -3,16 +3,22 @@ namespace Coder.Desktop.Vpn.Proto; /// -/// A header to write or read from a stream to identify the speaker's role and version. +/// A header to write or read from a stream to identify the peer role and version. /// -/// Role of the speaker -/// Version of the speaker -public class RpcHeader(RpcRole role, RpcVersionList versionList) +public class RpcHeader { private const string Preamble = "codervpn"; - public RpcRole Role { get; } = role; - public RpcVersionList VersionList { get; } = versionList; + public string Role { get; } + public RpcVersionList VersionList { get; } + + /// Role of the peer + /// Version of the peer + public RpcHeader(string role, RpcVersionList versionList) + { + Role = role; + VersionList = versionList; + } /// /// Parse a header string into a SpeakerHeader. @@ -25,10 +31,10 @@ public static RpcHeader Parse(string header) var parts = header.Split(' '); if (parts.Length != 3) throw new ArgumentException($"Wrong number of parts in header string '{header}'"); if (parts[0] != Preamble) throw new ArgumentException($"Invalid preamble in header string '{header}'"); + if (string.IsNullOrEmpty(parts[1])) throw new ArgumentException($"Invalid role in header string '{header}'"); - var role = new RpcRole(parts[1]); var versionList = RpcVersionList.Parse(parts[2]); - return new RpcHeader(role, versionList); + return new RpcHeader(parts[1], versionList); } /// diff --git a/Vpn.Proto/RpcMessage.cs b/Vpn.Proto/RpcMessage.cs index c44168c..bfe4d82 100644 --- a/Vpn.Proto/RpcMessage.cs +++ b/Vpn.Proto/RpcMessage.cs @@ -4,11 +4,23 @@ namespace Coder.Desktop.Vpn.Proto; [AttributeUsage(AttributeTargets.Class, Inherited = false)] -public class RpcRoleAttribute(string role) : Attribute +public class RpcRoleAttribute : Attribute { - public RpcRole Role { get; } = new(role); + public string Role { get; } + + public RpcRoleAttribute(string role) + { + Role = role; + } } +/// +/// IRpcMessageCompatibleWith is a marker interface that indicates that a +/// message type can be used to peer with another message type. +/// +/// +public interface IRpcMessageCompatibleWith; + /// /// Represents an actual over-the-wire message type. /// @@ -36,9 +48,9 @@ public abstract class RpcMessage where T : IMessage /// /// Gets the RpcRole of the message type from it's RpcRole attribute. /// - /// + /// The role string /// The message type does not have an RpcRoleAttribute - public static RpcRole GetRole() + public static string GetRole() { var type = typeof(T); var attr = type.GetCustomAttribute(); @@ -47,8 +59,8 @@ public static RpcRole GetRole() } } -[RpcRole(RpcRole.Manager)] -public partial class ManagerMessage : RpcMessage +[RpcRole("manager")] +public partial class ManagerMessage : RpcMessage, IRpcMessageCompatibleWith { public override RPC? RpcField { @@ -64,8 +76,8 @@ public override void Validate() } } -[RpcRole(RpcRole.Tunnel)] -public partial class TunnelMessage : RpcMessage +[RpcRole("tunnel")] +public partial class TunnelMessage : RpcMessage, IRpcMessageCompatibleWith { public override RPC? RpcField { @@ -80,3 +92,37 @@ public override void Validate() if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type"); } } + +[RpcRole("service")] +public partial class ServiceMessage : RpcMessage, IRpcMessageCompatibleWith +{ + public override RPC? RpcField + { + get => Rpc; + set => Rpc = value; + } + + public override ServiceMessage Message => this; + + public override void Validate() + { + if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type"); + } +} + +[RpcRole("client")] +public partial class ClientMessage : RpcMessage, IRpcMessageCompatibleWith +{ + public override RPC? RpcField + { + get => Rpc; + set => Rpc = value; + } + + public override ClientMessage Message => this; + + public override void Validate() + { + if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type"); + } +} diff --git a/Vpn.Proto/RpcRole.cs b/Vpn.Proto/RpcRole.cs deleted file mode 100644 index 9190281..0000000 --- a/Vpn.Proto/RpcRole.cs +++ /dev/null @@ -1,56 +0,0 @@ -namespace Coder.Desktop.Vpn.Proto; - -/// -/// Represents a role that either side of the connection can fulfil. -/// -public sealed class RpcRole -{ - public const string Manager = "manager"; - public const string Tunnel = "tunnel"; - - public RpcRole(string role) - { - if (role != Manager && role != Tunnel) throw new ArgumentException($"Unknown role '{role}'"); - - Role = role; - } - - private string Role { get; } - - public override string ToString() - { - return Role; - } - - #region SpeakerRole equality - - public static bool operator ==(RpcRole a, RpcRole b) - { - return a.Equals(b); - } - - public static bool operator !=(RpcRole a, RpcRole b) - { - return !a.Equals(b); - } - - private bool Equals(RpcRole other) - { - return Role == other.Role; - } - - public override bool Equals(object? obj) - { - if (obj is null) return false; - if (ReferenceEquals(this, obj)) return true; - if (obj.GetType() != GetType()) return false; - return Equals((RpcRole)obj); - } - - public override int GetHashCode() - { - return Role.GetHashCode(); - } - - #endregion -} diff --git a/Vpn.Proto/RpcVersion.cs b/Vpn.Proto/RpcVersion.cs index a9b1914..574768d 100644 --- a/Vpn.Proto/RpcVersion.cs +++ b/Vpn.Proto/RpcVersion.cs @@ -3,14 +3,20 @@ /// /// A version of the RPC API. Can be compared other versions to determine compatibility between two peers. /// -/// The major version of the peer -/// The minor version of the peer -public class RpcVersion(ulong major, ulong minor) +public class RpcVersion { public static readonly RpcVersion Current = new(1, 0); - public ulong Major { get; } = major; - public ulong Minor { get; } = minor; + public ulong Major { get; } + public ulong Minor { get; } + + /// The major version of the peer + /// The minor version of the peer + public RpcVersion(ulong major, ulong minor) + { + Major = major; + Minor = minor; + } /// /// Parse a string in the format "major.minor" into an ApiVersion. diff --git a/Vpn.Proto/Vpn.Proto.csproj b/Vpn.Proto/Vpn.Proto.csproj index 5380bd4..6acb12e 100644 --- a/Vpn.Proto/Vpn.Proto.csproj +++ b/Vpn.Proto/Vpn.Proto.csproj @@ -12,8 +12,8 @@ - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto index 33a3ff4..a03978a 100644 --- a/Vpn.Proto/vpn.proto +++ b/Vpn.Proto/vpn.proto @@ -44,6 +44,24 @@ message TunnelMessage { } } +// ClientMessage is a message from the client (to the service). +message ClientMessage { + RPC rpc = 1; + oneof msg { + StartRequest start = 2; + StopRequest stop = 3; + } +} + +// ServiceMessage is a message from the service (to the client). +message ServiceMessage { + RPC rpc = 1; + oneof msg { + StartResponse start = 2; + StopResponse stop = 3; + } +} + // Log is a log message generated by the tunnel. The manager should log it to the system log. It is // one-way tunnel -> manager with no response. message Log { @@ -105,7 +123,7 @@ message Agent { bytes id = 1; // UUID string name = 2; bytes workspace_id = 3; // UUID - string fqdn = 4; + repeated string fqdn = 4; repeated string ip_addrs = 5; // last_handshake is the primary indicator of whether we are connected to a peer. Zero value or // anything longer than 5 minutes ago means there is a problem. @@ -179,6 +197,12 @@ message StartRequest { int32 tunnel_file_descriptor = 1; string coder_url = 2; string api_token = 3; + // Additional HTTP headers added to all requests + message Header { + string name = 1; + string value = 2; + } + repeated Header headers = 4; } message StartResponse { diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs new file mode 100644 index 0000000..83eda24 --- /dev/null +++ b/Vpn.Service/Downloader.cs @@ -0,0 +1,355 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Coder.Desktop.Vpn.Utilities; +using Microsoft.Extensions.Logging; +using Microsoft.Security.Extensions; + +namespace Coder.Desktop.Vpn.Service; + +public interface IDownloader +{ + Task StartDownloadAsync(HttpRequestMessage req, string destinationPath, IDownloadValidator validator, + CancellationToken ct = default); +} + +public interface IDownloadValidator +{ + /// + /// Validates the downloaded file at the given path. This method should throw an exception if the file is invalid. + /// + /// The path of the file + /// Cancellation token + Task ValidateAsync(string path, CancellationToken ct = default); +} + +public class NullDownloadValidator : IDownloadValidator +{ + public static NullDownloadValidator Instance => new(); + + public Task ValidateAsync(string path, CancellationToken ct = default) + { + return Task.CompletedTask; + } +} + +/// +/// Ensures the downloaded binary is signed by the expected authenticode organization. +/// +public class AuthenticodeDownloadValidator : IDownloadValidator +{ + private readonly string _expectedName; + + public static AuthenticodeDownloadValidator Coder => new("Coder Technologies Inc."); + + public AuthenticodeDownloadValidator(string expectedName) + { + _expectedName = expectedName; + } + + public async Task ValidateAsync(string path, CancellationToken ct = default) + { + FileSignatureInfo fileSigInfo; + await using (var fileStream = File.OpenRead(path)) + { + fileSigInfo = FileSignatureInfo.GetFromFileStream(fileStream); + } + + if (fileSigInfo.State != SignatureState.SignedAndTrusted) + throw new Exception( + $"File is not signed and trusted with an Authenticode signature: State={fileSigInfo.State}"); + + // Coder will only use embedded signatures because we are downloading + // individual binaries and not installers which can ship catalog files. + if (fileSigInfo.Kind != SignatureKind.Embedded) + throw new Exception($"File is not signed with an embedded Authenticode signature: Kind={fileSigInfo.Kind}"); + + // TODO: check that it's an extended validation certificate + + var actualName = fileSigInfo.SigningCertificate.GetNameInfo(X509NameType.SimpleName, false); + if (actualName != _expectedName) + throw new Exception( + $"File is signed by an unexpected certificate: ExpectedName='{_expectedName}', ActualName='{actualName}'"); + } +} + +public class AssemblyVersionDownloadValidator : IDownloadValidator +{ + private readonly string _expectedAssemblyVersion; + + public AssemblyVersionDownloadValidator(string expectedAssemblyVersion) + { + _expectedAssemblyVersion = expectedAssemblyVersion; + } + + public Task ValidateAsync(string path, CancellationToken ct = default) + { + var info = FileVersionInfo.GetVersionInfo(path); + if (string.IsNullOrEmpty(info.ProductVersion)) + throw new Exception("File ProductVersion is empty or null, was the binary compiled correctly?"); + if (info.ProductVersion != _expectedAssemblyVersion) + throw new Exception( + $"File ProductVersion is '{info.ProductVersion}', but expected '{_expectedAssemblyVersion}'"); + return Task.CompletedTask; + } +} + +/// +/// Combines multiple download validators into a single validator. All validators will be run in order. +/// +public class CombinationDownloadValidator : IDownloadValidator +{ + private readonly IDownloadValidator[] _validators; + + /// Validators to run + public CombinationDownloadValidator(params IDownloadValidator[] validators) + { + _validators = validators; + } + + public async Task ValidateAsync(string path, CancellationToken ct = default) + { + foreach (var validator in _validators) + await validator.ValidateAsync(path, ct); + } +} + +/// +/// Handles downloading files from the internet. Downloads are performed asynchronously using DownloadTask. +/// Single-flight is provided to avoid performing the same download multiple times. +/// +public class Downloader : IDownloader +{ + private readonly ConcurrentDictionary _downloads = new(); + private readonly ILogger _logger; + + // ReSharper disable once ConvertToPrimaryConstructor + public Downloader(ILogger logger) + { + _logger = logger; + } + + /// + /// Starts a download with the given request. The If-None-Match header will be set to the SHA1 ETag of any existing + /// file in the destination location. + /// + /// Request message + /// Path to write file to (will be overwritten) + /// Validator for the downloaded file + /// Cancellation token + /// A DownloadTask representing the ongoing download operation after it starts + public async Task StartDownloadAsync(HttpRequestMessage req, string destinationPath, + IDownloadValidator validator, CancellationToken ct = default) + { + while (true) + { + var task = _downloads.GetOrAdd(destinationPath, + _ => new DownloadTask(_logger, req, destinationPath, validator)); + await task.EnsureStartedAsync(ct); + + // If the existing (or new) task is for the same URL, return it. + if (task.Request.RequestUri == req.RequestUri) + return task; + + // If the existing task is for a different URL, await its completion + // then retry the loop to create a new task. This could potentially + // get stuck if there are a lot of download operations for different + // URLs and the same destination path, but in our use case this + // shouldn't happen unless the user keeps changing the access URL. + _logger.LogWarning( + "Download for '{DestinationPath}' is already in progress, but is for a different Url - awaiting completion", + destinationPath); + await task.Task; + } + } +} + +/// +/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final +/// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if +/// it hasn't changed. +/// +public class DownloadTask +{ + private const int BufferSize = 4096; + + private static readonly HttpClient HttpClient = new(); + private readonly string _destinationDirectory; + + private readonly ILogger _logger; + + private readonly RaiiSemaphoreSlim _semaphore = new(1, 1); + private readonly IDownloadValidator _validator; + public readonly string DestinationPath; + + public readonly HttpRequestMessage Request; + public readonly string TempDestinationPath; + + public ulong? TotalBytes { get; private set; } + public ulong BytesRead { get; private set; } + public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync + + public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value; + public bool IsCompleted => Task.IsCompleted; + + internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator) + { + _logger = logger; + Request = req; + _validator = validator; + + if (string.IsNullOrWhiteSpace(destinationPath)) + throw new ArgumentException("Destination path must not be empty", nameof(destinationPath)); + DestinationPath = Path.GetFullPath(destinationPath); + if (Path.EndsInDirectorySeparator(DestinationPath)) + throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator", + nameof(destinationPath)); + + _destinationDirectory = Path.GetDirectoryName(DestinationPath) + ?? throw new ArgumentException( + $"Destination path '{DestinationPath}' must have a parent directory", + nameof(destinationPath)); + + TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) + + ".download-" + Path.GetRandomFileName()); + } + + internal async Task EnsureStartedAsync(CancellationToken ct = default) + { + using var _ = await _semaphore.LockAsync(ct); + if (Task == null!) + Task = await StartDownloadAsync(ct); + + return Task; + } + + /// + /// Starts downloading the file. The request will be performed in this task, but once started, the task will complete + /// and the download will continue in the background. The provided CancellationToken can be used to cancel the + /// download. + /// + private async Task StartDownloadAsync(CancellationToken ct = default) + { + Directory.CreateDirectory(_destinationDirectory); + + // If the destination path exists, generate a Coder SHA1 ETag and send + // it in the If-None-Match header to the server. + if (File.Exists(DestinationPath)) + { + await using var stream = File.OpenRead(DestinationPath); + var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower(); + Request.Headers.Add("If-None-Match", "\"" + etag + "\""); + } + + var res = await HttpClient.SendAsync(Request, HttpCompletionOption.ResponseHeadersRead, ct); + if (res.StatusCode == HttpStatusCode.NotModified) + { + _logger.LogInformation("File has not been modified, skipping download"); + try + { + await _validator.ValidateAsync(DestinationPath, ct); + } + catch (Exception e) + { + _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath); + throw new Exception("Existing file failed validation after 304 Not Modified", e); + } + + Task = Task.CompletedTask; + return Task; + } + + if (res.StatusCode != HttpStatusCode.OK) + { + _logger.LogWarning("Failed to download file '{Request.RequestUri}': {StatusCode} {ReasonPhrase}", + Request.RequestUri, res.StatusCode, + res.ReasonPhrase); + throw new HttpRequestException( + $"Failed to download file '{Request.RequestUri}': {(int)res.StatusCode} {res.ReasonPhrase}"); + } + + if (res.Content == null) + { + _logger.LogWarning("File {Request.RequestUri} has no content", Request.RequestUri); + throw new HttpRequestException("Response has no content"); + } + + if (res.Content.Headers.ContentLength >= 0) + TotalBytes = (ulong)res.Content.Headers.ContentLength; + + FileStream tempFile; + try + { + tempFile = File.Create(TempDestinationPath, BufferSize, + FileOptions.Asynchronous | FileOptions.SequentialScan); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath); + throw; + } + + Task = DownloadAsync(res, tempFile, ct); + return Task; + } + + private async Task DownloadAsync(HttpResponseMessage res, FileStream tempFile, CancellationToken ct) + { + try + { + var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null; + await using (tempFile) + { + var stream = await res.Content.ReadAsStreamAsync(ct); + var buffer = new byte[BufferSize]; + int n; + while ((n = await stream.ReadAsync(buffer, ct)) > 0) + { + await tempFile.WriteAsync(buffer.AsMemory(0, n), ct); + sha1?.TransformBlock(buffer, 0, n, null, 0); + BytesRead += (ulong)n; + } + } + + if (TotalBytes != null && BytesRead != TotalBytes) + throw new IOException( + $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}"); + + // Verify the ETag if it was sent by the server. + if (res.Headers.Contains("ETag") && sha1 != null) + { + var etag = res.Headers.ETag!.Tag.Trim('"'); + _ = sha1.TransformFinalBlock([], 0, 0); + var hashStr = Convert.ToHexString(sha1.Hash!).ToLower(); + if (etag != hashStr) + throw new HttpRequestException( + $"ETag does not match SHA1 hash of downloaded file: ETag='{etag}', Local='{hashStr}'"); + } + + try + { + await _validator.ValidateAsync(TempDestinationPath, ct); + } + catch (Exception e) + { + _logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation", + TempDestinationPath); + throw new HttpRequestException("Downloaded file failed validation", e); + } + + File.Move(TempDestinationPath, DestinationPath, true); + } + finally + { +#if DEBUG + _logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode", + TempDestinationPath); +#else + if (File.Exists(TempDestinationPath)) + File.Delete(TempDestinationPath); +#endif + } + } +} diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs new file mode 100644 index 0000000..0f11f34 --- /dev/null +++ b/Vpn.Service/Manager.cs @@ -0,0 +1,215 @@ +using System.Runtime.InteropServices; +using Coder.Desktop.Vpn.Proto; +using CoderSdk; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Semver; + +namespace Coder.Desktop.Vpn.Service; + +public interface IManager : IDisposable +{ + public Task HandleClientRpcMessage(ReplyableRpcMessage message, + CancellationToken ct = default); + + public Task StopAsync(CancellationToken ct = default); +} + +/// +/// Manager provides handling for RPC requests from the client and from the tunnel. +/// +public class Manager : IManager +{ + // TODO: determine a suitable value for this + private const string ServerVersionRange = ">=0.0.0"; + + private readonly ManagerConfig _config; + private readonly IDownloader _downloader; + private readonly ILogger _logger; + private readonly ITunnelSupervisor _tunnelSupervisor; + + // ReSharper disable once ConvertToPrimaryConstructor + public Manager(IOptions config, ILogger logger, IDownloader downloader, + ITunnelSupervisor tunnelSupervisor) + { + _config = config.Value; + _logger = logger; + _downloader = downloader; + _tunnelSupervisor = tunnelSupervisor; + } + + public void Dispose() + { + GC.SuppressFinalize(this); + } + + /// + /// Processes a message sent from a Client to the ManagerRpcService over the codervpn RPC protocol. + /// + /// Client message + /// Cancellation token + public async Task HandleClientRpcMessage(ReplyableRpcMessage message, + CancellationToken ct = default) + { + _logger.LogInformation("ClientMessage: {MessageType}", message.Message.MsgCase); + // TODO: break out each into it's own method? + switch (message.Message.MsgCase) + { + case ClientMessage.MsgOneofCase.Start: + // TODO: these sub-methods should be managed by some Task list and cancelled/awaited on stop + await HandleClientMessageStart(message, ct); + break; + case ClientMessage.MsgOneofCase.Stop: + await HandleClientMessageStop(message, ct); + break; + case ClientMessage.MsgOneofCase.None: + default: + _logger.LogWarning("Received unknown message type {MessageType}", message.Message.MsgCase); + break; + } + } + + public async Task StopAsync(CancellationToken ct = default) + { + await _tunnelSupervisor.StopAsync(ct); + } + + private async Task HandleClientMessageStart(ReplyableRpcMessage message, + CancellationToken ct) + { + try + { + // TODO: if the credentials and URL are identical and the server + // version hasn't changed we should not do anything + // TODO: this should be broken out into it's own method + _logger.LogInformation("ClientMessage.Start: testing server '{ServerUrl}'", message.Message.Start.CoderUrl); + var client = new CoderApiClient(message.Message.Start.CoderUrl, message.Message.Start.ApiToken); + var buildInfo = await client.GetBuildInfo(ct); + _logger.LogInformation("ClientMessage.Start: server version '{ServerVersion}'", buildInfo.Version); + var serverVersion = SemVersion.Parse(buildInfo.Version); + if (!serverVersion.Satisfies(ServerVersionRange)) + throw new InvalidOperationException( + $"Server version '{serverVersion}' is not within required server version range '{ServerVersionRange}'"); + var user = await client.GetUser(User.Me, ct); + _logger.LogInformation("ClientMessage.Start: authenticated as '{Username}'", user.Username); + + await DownloadTunnelBinaryAsync(message.Message.Start.CoderUrl, serverVersion, ct); + await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage, + HandleTunnelRpcError, + ct); + } + catch (Exception e) + { + _logger.LogWarning(e, "ClientMessage.Start: Failed to start VPN client"); + await message.SendReply(new ServiceMessage + { + Start = new StartResponse + { + Success = false, + ErrorMessage = e.Message, + }, + }, ct); + } + } + + private async Task HandleClientMessageStop(ReplyableRpcMessage message, + CancellationToken ct) + { + try + { + // This will handle sending the Stop message for us. + await _tunnelSupervisor.StopAsync(ct); + } + catch (Exception e) + { + _logger.LogWarning(e, "ClientMessage.Stop: Failed to stop VPN client"); + await message.SendReply(new ServiceMessage + { + Stop = new StopResponse + { + Success = false, + ErrorMessage = e.Message, + }, + }, ct); + } + } + + private void HandleTunnelRpcMessage(ReplyableRpcMessage message) + { + // TODO: this + } + + private void HandleTunnelRpcError(Exception e) + { + // TODO: this probably happens during an ongoing start or stop operation, and we should definitely ignore those + _logger.LogError(e, "Manager<->Tunnel RPC error"); + try + { + _tunnelSupervisor.StopAsync(); + } + catch (Exception e2) + { + _logger.LogError(e2, "Failed to stop tunnel supervisor after RPC error"); + } + } + + /// + /// Returns the architecture of the current system. + /// + /// A golang architecture string for the binary + /// Unsupported architecture + private static string SystemArchitecture() + { + // ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault + return RuntimeInformation.ProcessArchitecture switch + { + Architecture.X64 => "amd64", + Architecture.Arm64 => "arm64", + // We only support amd64 and arm64 on Windows currently. + _ => throw new PlatformNotSupportedException( + $"Unsupported architecture '{RuntimeInformation.ProcessArchitecture}'. Coder only supports amd64 and arm64."), + }; + } + + /// + /// Fetches the "/bin/coder-windows-{architecture}.exe" binary from the given base URL and writes it to the + /// destination path after validating the signature and checksum. + /// + /// Server base URL to download the binary from + /// The version of the server to expect in the binary + /// Cancellation token + /// If the base URL is invalid + private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expectedVersion, + CancellationToken ct = default) + { + var architecture = SystemArchitecture(); + Uri url; + try + { + url = new Uri(baseUrl, UriKind.Absolute); + if (url.PathAndQuery != "/") + throw new ArgumentException("Base URL must not contain a path", nameof(baseUrl)); + url = new Uri(url, $"/bin/coder-windows-{architecture}.exe"); + } + catch (Exception e) + { + throw new ArgumentException($"Invalid base URL '{baseUrl}'", e); + } + + _logger.LogInformation("Downloading VPN binary from '{url}' to '{DestinationPath}'", url, + _config.TunnelBinaryPath); + var req = new HttpRequestMessage(HttpMethod.Get, url); + var validators = new CombinationDownloadValidator( + AuthenticodeDownloadValidator.Coder, + new AssemblyVersionDownloadValidator( + $"{expectedVersion.Major}.{expectedVersion.Minor}.{expectedVersion.Patch}.0") + ); + var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct); + + // TODO: monitor and report progress when we have a mechanism to do so + + // Awaiting this will check the checksum (via the ETag) if provided, + // and will also validate the signature using the validator we supplied. + await downloadTask.Task; + } +} diff --git a/Vpn.Service/ManagerConfig.cs b/Vpn.Service/ManagerConfig.cs new file mode 100644 index 0000000..906a0b8 --- /dev/null +++ b/Vpn.Service/ManagerConfig.cs @@ -0,0 +1,16 @@ +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; + +namespace Coder.Desktop.Vpn.Service; + +[SuppressMessage("ReSharper", "AutoPropertyCanBeMadeGetOnly.Global")] +public class ManagerConfig +{ + [Required] + [RegularExpression(@"^([a-zA-Z0-9_-]+\.)*[a-zA-Z0-9_-]+$")] + public string ServiceRpcPipeName { get; set; } = "Coder.Desktop.Vpn"; + + // TODO: pick a better default path + [Required] + public string TunnelBinaryPath { get; set; } = @"C:\coder-vpn.exe"; +} diff --git a/Vpn.Service/ManagerRpcService.cs b/Vpn.Service/ManagerRpcService.cs new file mode 100644 index 0000000..ce2b17e --- /dev/null +++ b/Vpn.Service/ManagerRpcService.cs @@ -0,0 +1,128 @@ +using System.Collections.Concurrent; +using System.IO.Pipes; +using Coder.Desktop.Vpn.Proto; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Coder.Desktop.Vpn.Service; + +/// +/// Provides a named pipe server for communication between multiple RpcRole.Client and RpcRole.Manager. +/// +public class ManagerRpcService : BackgroundService, IAsyncDisposable +{ + private readonly ConcurrentDictionary _activeClientTasks = new(); + private readonly ManagerConfig _config; + private readonly CancellationTokenSource _cts = new(); + private readonly ILogger _logger; + private readonly IManager _manager; + + public ManagerRpcService(IOptions config, ILogger logger, IManager manager) + { + _logger = logger; + _manager = manager; + _config = config.Value; + } + + public async ValueTask DisposeAsync() + { + await _cts.CancelAsync(); + while (!_activeClientTasks.IsEmpty) await Task.WhenAny(_activeClientTasks.Values); + _cts.Dispose(); + GC.SuppressFinalize(this); + } + + public override async Task StopAsync(CancellationToken cancellationToken) + { + await _cts.CancelAsync(); + while (!_activeClientTasks.IsEmpty) await Task.WhenAny(_activeClientTasks.Values); + } + + /// + /// Starts the named pipe server, listens for incoming connections and starts handling them asynchronously. + /// + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + _logger.LogInformation(@"Starting continuous named pipe RPC server at \\.\pipe\{PipeName}", + _config.ServiceRpcPipeName); + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(stoppingToken, _cts.Token); + while (!linkedCts.IsCancellationRequested) + { + var pipeServer = new NamedPipeServerStream(_config.ServiceRpcPipeName, PipeDirection.InOut, + NamedPipeServerStream.MaxAllowedServerInstances, PipeTransmissionMode.Byte, PipeOptions.Asynchronous); + + try + { + try + { + _logger.LogDebug("Waiting for new named pipe client connection"); + await pipeServer.WaitForConnectionAsync(linkedCts.Token); + } + finally + { + await pipeServer.DisposeAsync(); + } + + _logger.LogInformation("Handling named pipe client connection"); + var clientTask = HandleRpcClientAsync(pipeServer, linkedCts.Token); + _activeClientTasks.TryAdd(clientTask.Id, clientTask); + _ = clientTask.ContinueWith(RpcClientContinuation, CancellationToken.None); + } + catch (OperationCanceledException) + { + throw; + } + catch (Exception e) + { + _logger.LogWarning(e, "Failed to accept named pipe client"); + } + } + } + + private async Task HandleRpcClientAsync(NamedPipeServerStream pipeServer, CancellationToken ct) + { + var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); + await using (pipeServer) + { + await using var speaker = new Speaker(pipeServer); + + var tcs = new TaskCompletionSource(); + var activeTasks = new ConcurrentDictionary(); + speaker.Receive += msg => + { + var task = HandleRpcMessageAsync(msg, linkedCts.Token); + activeTasks.TryAdd(task.Id, task); + task.ContinueWith(t => + { + if (t.IsFaulted) + _logger.LogWarning(t.Exception, "Client RPC message handler task faulted"); + activeTasks.TryRemove(t.Id, out _); + }, CancellationToken.None); + }; + speaker.Error += tcs.SetException; + await using (ct.Register(() => tcs.SetCanceled(ct))) + { + await speaker.StartAsync(ct); + await tcs.Task; + await linkedCts.CancelAsync(); + while (!activeTasks.IsEmpty) + await Task.WhenAny(activeTasks.Values); + } + } + } + + private void RpcClientContinuation(Task task) + { + if (task.IsFaulted) + _logger.LogWarning(task.Exception, "Client RPC task faulted"); + _activeClientTasks.TryRemove(task.Id, out _); + } + + private async Task HandleRpcMessageAsync(ReplyableRpcMessage message, + CancellationToken ct) + { + _logger.LogInformation("Received RPC message: {Message}", message.Message); + await _manager.HandleClientRpcMessage(message, ct); + } +} diff --git a/Vpn.Service/ManagerService.cs b/Vpn.Service/ManagerService.cs new file mode 100644 index 0000000..b7b2e34 --- /dev/null +++ b/Vpn.Service/ManagerService.cs @@ -0,0 +1,33 @@ +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.Vpn.Service; + +/// +/// Wraps Manager to provide a BackgroundService that informs the singleton Manager to shut down when stop is +/// requested. +/// +public class ManagerService : BackgroundService +{ + private readonly ILogger _logger; + private readonly IManager _manager; + + // ReSharper disable once ConvertToPrimaryConstructor + public ManagerService(ILogger logger, IManager manager) + { + _logger = logger; + _manager = manager; + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + // Block until the service is stopped. + await Task.Delay(-1, stoppingToken); + } + + public override async Task StopAsync(CancellationToken cancellationToken) + { + _logger.LogInformation("Informing Manager to stop"); + await _manager.StopAsync(cancellationToken); + } +} diff --git a/Vpn.Service/Program.cs b/Vpn.Service/Program.cs new file mode 100644 index 0000000..78fbff2 --- /dev/null +++ b/Vpn.Service/Program.cs @@ -0,0 +1,30 @@ +using Coder.Desktop.Vpn.Service; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Win32; + +var builder = Host.CreateApplicationBuilder(args); + +// Configuration sources +builder.Configuration.Sources.Clear(); +(builder.Configuration as IConfigurationBuilder).Add( + new RegistryConfigurationSource(Registry.LocalMachine, @"SOFTWARE\Coder\Coder VPN")); +builder.Configuration.AddEnvironmentVariables("CODER_MANAGER_"); +builder.Configuration.AddCommandLine(args); + +// Options types (these get registered as IOptions singletons) +builder.Services.AddOptions() + .Bind(builder.Configuration.GetSection("Manager")) + .ValidateDataAnnotations(); + +// Singletons +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); + +// Services +builder.Services.AddHostedService(); +builder.Services.AddHostedService(); + +builder.Build().Run(); diff --git a/Vpn.Service/RegistryConfigurationSource.cs b/Vpn.Service/RegistryConfigurationSource.cs new file mode 100644 index 0000000..7ac2764 --- /dev/null +++ b/Vpn.Service/RegistryConfigurationSource.cs @@ -0,0 +1,41 @@ +using Microsoft.Extensions.Configuration; +using Microsoft.Win32; + +namespace Coder.Desktop.Vpn.Service; + +public class RegistryConfigurationSource : IConfigurationSource +{ + private readonly RegistryKey _root; + private readonly string _subKeyName; + + public RegistryConfigurationSource(RegistryKey root, string subKeyName) + { + _root = root; + _subKeyName = subKeyName; + } + + public IConfigurationProvider Build(IConfigurationBuilder builder) + { + return new RegistryConfigurationProvider(_root, _subKeyName); + } +} + +public class RegistryConfigurationProvider : ConfigurationProvider +{ + private readonly RegistryKey _root; + private readonly string _subKeyName; + + public RegistryConfigurationProvider(RegistryKey root, string subKeyName) + { + _root = root; + _subKeyName = subKeyName; + } + + public override void Load() + { + using var key = _root.OpenSubKey(_subKeyName); + if (key == null) return; + + foreach (var valueName in key.GetValueNames()) Data[valueName] = key.GetValue(valueName)?.ToString(); + } +} diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs new file mode 100644 index 0000000..9ea5b05 --- /dev/null +++ b/Vpn.Service/TunnelSupervisor.cs @@ -0,0 +1,271 @@ +using System.Diagnostics; +using System.IO.Pipes; +using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.Vpn.Service; + +public interface ITunnelSupervisor : IAsyncDisposable +{ + /// + /// Starts the tunnel subprocess with the given executable path. If the subprocess is already running, this method will + /// kill it first. + /// + /// Path to the executable + /// Handler to call with each RPC message + /// + /// Handler for permanent errors from the RPC Speaker. The recipient should call StopAsync after + /// receiving this. + /// + /// Cancellation token + public Task StartAsync(string binPath, + Speaker.OnReceiveDelegate messageHandler, + Speaker.OnErrorDelegate errorHandler, + CancellationToken ct = default); + + /// + /// Stops the tunnel subprocess. If the subprocess is not running, this method does nothing. + /// + /// + /// + public Task StopAsync(CancellationToken ct = default); +} + +/// +/// Launches and supervises the tunnel subprocess. Provides RPC communication with the subprocess. +/// +public class TunnelSupervisor : ITunnelSupervisor +{ + private readonly CancellationTokenSource _cts = new(); + private readonly ILogger _logger; + private readonly SemaphoreSlim _operationLock = new(1, 1); + private AnonymousPipeServerStream? _inPipe; + private AnonymousPipeServerStream? _outPipe; + private Speaker? _speaker; + + private Process? _subprocess; + + // ReSharper disable once ConvertToPrimaryConstructor + public TunnelSupervisor(ILogger logger) + { + _logger = logger; + } + + public async Task StartAsync(string binPath, + Speaker.OnReceiveDelegate messageHandler, + Speaker.OnErrorDelegate errorHandler, + CancellationToken ct = default) + { + _logger.LogInformation("StartAsync(\"{binPath}\")", binPath); + if (!await _operationLock.WaitAsync(0, ct)) + throw new InvalidOperationException( + "Another TunnelSupervisor Start or Stop operation is already in progress"); + + try + { + await CleanupAsync(ct); + + _outPipe = new AnonymousPipeServerStream(PipeDirection.Out, HandleInheritability.Inheritable); + _inPipe = new AnonymousPipeServerStream(PipeDirection.In, HandleInheritability.Inheritable); + _subprocess = new Process + { + StartInfo = new ProcessStartInfo + { + FileName = binPath, + ArgumentList = { "vpn-daemon", "run" }, + UseShellExecute = false, + CreateNoWindow = true, + }, + }; + + // Pass the other end of the pipes to the subprocess and dispose + // the local copies. + _subprocess.StartInfo.Environment.Add("CODER_VPN_DAEMON_RPC_READ_HANDLE", + _outPipe.GetClientHandleAsString()); + _subprocess.StartInfo.Environment.Add("CODER_VPN_DAEMON_RPC_WRITE_HANDLE", + _inPipe.GetClientHandleAsString()); + _outPipe.DisposeLocalCopyOfClientHandle(); + _inPipe.DisposeLocalCopyOfClientHandle(); + + _logger.LogInformation("StartAsync: starting subprocess"); + _subprocess.Start(); + _logger.LogInformation("StartAsync: subprocess started"); + + // We don't use the supplied CancellationToken here because we want it to only apply to the startup + // procedure. + _ = _subprocess.WaitForExitAsync(_cts.Token).ContinueWith(OnProcessExited, CancellationToken.None); + + // Start the RPC Speaker. + try + { + var stream = new BidirectionalPipe(_inPipe, _outPipe); + _speaker = new Speaker(stream); + _speaker.Receive += messageHandler; + _speaker.Error += errorHandler; + // Handshakes already have a 5-second timeout. + await _speaker.StartAsync(ct); + } + catch (Exception e) + { + throw new Exception("Failed to start RPC Speaker on pipes to subprocess", e); + } + } + catch (Exception e) + { + _logger.LogError(e, "StartAsync: failed to start or connect to subprocess"); + await CleanupAsync(ct); + throw; + } + finally + { + _operationLock.Release(); + } + } + + public async Task StopAsync(CancellationToken ct = default) + { + _logger.LogInformation("StopAsync()"); + if (!await _operationLock.WaitAsync(0, ct)) + throw new InvalidOperationException( + "Another TunnelSupervisor Start or Stop operation is already in progress"); + + try + { + await CleanupAsync(ct); + } + finally + { + _operationLock.Release(); + } + } + + public async ValueTask DisposeAsync() + { + _cts.Dispose(); + await CleanupAsync(); + GC.SuppressFinalize(this); + } + + private async Task OnProcessExited(Task task) + { + if (task.IsFaulted) + { + _logger.LogError(task.Exception, "OnProcessExited: subprocess exited with an exception"); + return; + } + + if (!await _operationLock.WaitAsync(0)) _logger.LogInformation("OnProcessExited: subprocess exited"); + + try + { + await CleanupAsync(); + _logger.LogInformation("OnProcessExited: subprocess exited with code {ExitCode}", + _subprocess?.ExitCode ?? -1); + } + finally + { + _operationLock.Release(); + } + } + + /// + /// Cleans up the pipes and the subprocess if it's still running. This method should not be called without holding the + /// semaphore. + /// + private async Task CleanupAsync(CancellationToken ct = default) + { + if (_speaker != null) + { + try + { + _logger.LogInformation("CleanupAsync: Sending stop message to subprocess"); + var stopCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + stopCts.CancelAfter(5000); + await _speaker.SendRequestAwaitReply(new ManagerMessage + { + Stop = new StopRequest(), + }, stopCts.Token); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to send stop message to subprocess"); + } + + try + { + _logger.LogInformation("CleanupAsync: Disposing _speaker"); + await _speaker.DisposeAsync(); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to stop/dispose _speaker"); + } + finally + { + _speaker = null; + } + } + + if (_outPipe != null) + { + _logger.LogInformation("CleanupAsync: Disposing _outPipe"); + try + { + await _outPipe.DisposeAsync(); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to dispose _outPipe"); + } + finally + { + _outPipe = null; + } + } + + if (_inPipe != null) + { + _logger.LogInformation("CleanupAsync: Disposing _inPipe"); + try + { + await _inPipe.DisposeAsync(); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to dispose _inPipe"); + } + finally + { + _inPipe = null; + } + } + + if (_subprocess != null) + try + { + if (!_subprocess.HasExited) + { + // TODO: is there a nicer way we can do this? + _logger.LogInformation("CleanupAsync: Killing un-exited _subprocess"); + _subprocess.Kill(); + // Since we just killed the process ideally it should exit + // immediately. + var exitCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + exitCts.CancelAfter(5000); + await _subprocess.WaitForExitAsync(exitCts.Token); + } + + _logger.LogInformation("CleanupAsync: Disposing _subprocess"); + _subprocess.Dispose(); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to kill/dispose _subprocess"); + } + finally + { + _subprocess = null; + } + } +} diff --git a/Vpn.Service/Vpn.Service.csproj b/Vpn.Service/Vpn.Service.csproj new file mode 100644 index 0000000..e6da70d --- /dev/null +++ b/Vpn.Service/Vpn.Service.csproj @@ -0,0 +1,24 @@ + + + + Coder.Desktop.Vpn.Service + Exe + net8.0-windows + enable + enable + + + + + + + + + + + + + + + + diff --git a/Vpn/Serdes.cs b/Vpn/Serdes.cs index 317417b..00837b7 100644 --- a/Vpn/Serdes.cs +++ b/Vpn/Serdes.cs @@ -1,32 +1,10 @@ using System.Buffers.Binary; using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; using Google.Protobuf; namespace Coder.Desktop.Vpn; -/// -/// RaiiSemaphoreSlim is a wrapper around SemaphoreSlim that provides RAII-style locking. -/// -internal class RaiiSemaphoreSlim(int initialCount, int maxCount) -{ - private readonly SemaphoreSlim _semaphore = new(initialCount, maxCount); - - public async ValueTask LockAsync(CancellationToken ct = default) - { - await _semaphore.WaitAsync(ct); - return new Lock(_semaphore); - } - - private class Lock(SemaphoreSlim semaphore) : IDisposable - { - public void Dispose() - { - semaphore.Release(); - GC.SuppressFinalize(this); - } - } -} - /// /// Serdes provides serialization and deserialization of messages read from a Stream. /// diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs index 5bccbe4..4c6ef3c 100644 --- a/Vpn/Speaker.cs +++ b/Vpn/Speaker.cs @@ -9,29 +9,43 @@ namespace Coder.Desktop.Vpn; /// /// Thrown when the two peers are incompatible with each other. /// -public class RpcVersionCompatibilityException(RpcVersionList localVersion, RpcVersionList remoteVersion) - : Exception($"No RPC versions are compatible: local={localVersion}, remote={remoteVersion}"); +public class RpcVersionCompatibilityException : Exception +{ + public RpcVersionCompatibilityException(RpcVersionList localVersion, RpcVersionList remoteVersion) : base( + $"No RPC versions are compatible: local={localVersion}, remote={remoteVersion}") + { + } +} /// /// Wraps a RpcMessage to allow easily sending a reply via the Speaker. /// -/// Speaker to use for sending reply -/// Original received message -public class ReplyableRpcMessage(Speaker speaker, TR message) : RpcMessage - where TS : RpcMessage, IMessage - where TR : RpcMessage, IMessage, new() +public class ReplyableRpcMessage : RpcMessage + where TS : RpcMessage, IRpcMessageCompatibleWith, IMessage + where TR : RpcMessage, IRpcMessageCompatibleWith, IMessage, new() { + private readonly TR _message; + private readonly Speaker _speaker; + public override RPC? RpcField { - get => message.RpcField; - set => message.RpcField = value; + get => _message.RpcField; + set => _message.RpcField = value; } - public override TR Message => message; + public override TR Message => _message; + + /// Speaker to use for sending reply + /// Original received message + public ReplyableRpcMessage(Speaker speaker, TR message) + { + _speaker = speaker; + _message = message; + } public override void Validate() { - message.Validate(); + _message.Validate(); } /// @@ -41,7 +55,7 @@ public override void Validate() /// Optional cancellation token public async Task SendReply(TS reply, CancellationToken ct = default) { - await speaker.SendReply(message, reply, ct); + await _speaker.SendReply(_message, reply, ct); } } @@ -51,8 +65,8 @@ public async Task SendReply(TS reply, CancellationToken ct = default) /// The message type for sent messages /// The message type for received messages public class Speaker : IAsyncDisposable - where TS : RpcMessage, IMessage - where TR : RpcMessage, IMessage, new() + where TS : RpcMessage, IRpcMessageCompatibleWith, IMessage + where TR : RpcMessage, IRpcMessageCompatibleWith, IMessage, new() { public delegate void OnErrorDelegate(Exception e); diff --git a/Vpn/Utilities/BidirectionalPipe.cs b/Vpn/Utilities/BidirectionalPipe.cs new file mode 100644 index 0000000..72e633b --- /dev/null +++ b/Vpn/Utilities/BidirectionalPipe.cs @@ -0,0 +1,101 @@ +using System.IO.Pipelines; + +namespace Coder.Desktop.Vpn.Utilities; + +/// +/// BidirectionalPipe implements Stream using a read-only Stream and a write-only Stream. +/// +public class BidirectionalPipe : Stream +{ + private readonly Stream _reader; + private readonly Stream _writer; + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => -1; + + public override long Position + { + get => -1; + set => throw new NotImplementedException("BidirectionalPipe does not support setting position"); + } + + /// The stream to perform reads from + /// The stream to write data to + public BidirectionalPipe(Stream reader, Stream writer) + { + _reader = reader; + _writer = writer; + } + + /// + /// Creates a new pair of BidirectionalPipes that are connected to each other using buffered in-memory pipes. + /// + /// Two pipes connected to each other + public static (BidirectionalPipe, BidirectionalPipe) NewInMemory() + { + var pipe1 = new Pipe(); + var pipe2 = new Pipe(); + return ( + new BidirectionalPipe(pipe1.Reader.AsStream(), pipe2.Writer.AsStream()), + new BidirectionalPipe(pipe2.Reader.AsStream(), pipe1.Writer.AsStream()) + ); + } + + public override void Flush() + { + _writer.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return _reader.Read(buffer, offset, count); + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) + { +#pragma warning disable CA1835 + return await _reader.ReadAsync(buffer, offset, count, ct); +#pragma warning restore CA1835 + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return _reader.ReadAsync(buffer, cancellationToken); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException("BidirectionalPipe does not support seeking"); + } + + public override void SetLength(long value) + { + throw new NotImplementedException("BidirectionalPipe does not support setting length"); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _writer.Write(buffer, offset, count); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) + { +#pragma warning disable CA1835 + await _writer.WriteAsync(buffer, offset, count, ct); +#pragma warning restore CA1835 + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return _writer.WriteAsync(buffer, cancellationToken); + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + _writer.Dispose(); + _reader.Dispose(); + } +} diff --git a/Vpn/Utilities/RaiiSemaphoreSlim.cs b/Vpn/Utilities/RaiiSemaphoreSlim.cs new file mode 100644 index 0000000..f4ecee6 --- /dev/null +++ b/Vpn/Utilities/RaiiSemaphoreSlim.cs @@ -0,0 +1,42 @@ +namespace Coder.Desktop.Vpn.Utilities; + +/// +/// RaiiSemaphoreSlim is a wrapper around SemaphoreSlim that provides RAII-style locking. +/// +public class RaiiSemaphoreSlim : IDisposable +{ + private readonly SemaphoreSlim _semaphore; + + public RaiiSemaphoreSlim(int initialCount, int maxCount) + { + _semaphore = new SemaphoreSlim(initialCount, maxCount); + } + + public void Dispose() + { + _semaphore.Dispose(); + GC.SuppressFinalize(this); + } + + public async ValueTask LockAsync(CancellationToken ct = default) + { + await _semaphore.WaitAsync(ct); + return new Lock(_semaphore); + } + + private class Lock : IDisposable + { + private readonly SemaphoreSlim _semaphore1; + + public Lock(SemaphoreSlim semaphore) + { + _semaphore1 = semaphore; + } + + public void Dispose() + { + _semaphore1.Release(); + GC.SuppressFinalize(this); + } + } +} diff --git a/Vpn/Utilities/TaskUtilities.cs b/Vpn/Utilities/TaskUtilities.cs index 8a2bfdb..4105c9e 100644 --- a/Vpn/Utilities/TaskUtilities.cs +++ b/Vpn/Utilities/TaskUtilities.cs @@ -1,6 +1,6 @@ namespace Coder.Desktop.Vpn.Utilities; -internal static class TaskUtilities +public static class TaskUtilities { /// /// Waits for all tasks to complete, but cancels the provided CancellationTokenSource if any task is canceled or diff --git a/Vpn/Vpn.csproj b/Vpn/Vpn.csproj index bcef1b5..22b585f 100644 --- a/Vpn/Vpn.csproj +++ b/Vpn/Vpn.csproj @@ -11,4 +11,8 @@ + + + +