From c0440760ac363d817cbdca87e1ab7eff7e74a025 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Thu, 3 Jul 2025 14:50:37 -0700 Subject: [PATCH 1/7] Bump version to 0.3.0-preview.3 --- src/Directory.Build.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Directory.Build.props b/src/Directory.Build.props index b3d159455..b8408bacd 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -6,7 +6,7 @@ https://github.com/modelcontextprotocol/csharp-sdk git 0.3.0 - preview.2 + preview.3 ModelContextProtocolOfficial © Anthropic and Contributors. ModelContextProtocol;mcp;ai;llm From aa75f6f37edd94afa31ae5a4ac3967ad9d36a949 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Tue, 8 Jul 2025 18:50:37 +0300 Subject: [PATCH 2/7] Enable netfx testing. (#588) * Enable netfx testing. * Skip impacted tests and re-enable netfx testing. * Remove redundant comment. * Remove unnecessary project configuration. * Simplify test server copying. --- Directory.Packages.props | 6 +- .../CancellableStreamReader.cs | 1395 +++++++++++++++++ .../TextReaderExtensions.cs | 15 + .../ValueStringBuilder.cs | 317 ++++ .../Polyfills/System/IO/StreamExtensions.cs | 28 + .../System/IO/TextReaderExtensions.cs | 10 - src/Common/Throw.cs | 9 + .../Client/StdioClientSessionTransport.cs | 2 +- .../Client/StdioClientTransport.cs | 10 +- .../Client/StreamClientSessionTransport.cs | 40 + .../Client/StreamClientTransport.cs | 5 +- .../ModelContextProtocol.Core.csproj | 2 + .../Server/McpServerPrimitiveCollection.cs | 2 +- .../Server/StreamServerTransport.cs | 4 + tests/Common/Utils/MockHttpHandler.cs | 4 +- tests/Common/Utils/ProcessExtensions.cs | 15 + .../ModelContextProtocol.TestServer.csproj | 1 + .../ClientIntegrationTestFixture.cs | 4 +- .../McpServerBuilderExtensionsToolsTests.cs | 4 +- .../EverythingSseServerFixture.cs | 7 +- .../GlobalUsings.cs | 1 + .../ModelContextProtocol.Tests.csproj | 19 +- .../PlatformDetection.cs | 6 + .../Server/McpServerDelegatesTests.cs | 8 + .../Server/McpServerLoggingLevelTests.cs | 8 + .../Server/McpServerPromptTests.cs | 8 + .../Server/McpServerResourceTests.cs | 12 + .../Server/McpServerTests.cs | 17 +- .../Server/McpServerToolTests.cs | 10 +- .../StdioServerIntegrationTests.cs | 6 +- .../SseResponseStreamTransportTests.cs | 12 +- .../Transport/StdioClientTransportTests.cs | 6 +- 32 files changed, 1940 insertions(+), 53 deletions(-) create mode 100644 src/Common/CancellableStreamReader/CancellableStreamReader.cs create mode 100644 src/Common/CancellableStreamReader/TextReaderExtensions.cs create mode 100644 src/Common/CancellableStreamReader/ValueStringBuilder.cs delete mode 100644 src/Common/Polyfills/System/IO/TextReaderExtensions.cs create mode 100644 tests/Common/Utils/ProcessExtensions.cs create mode 100644 tests/ModelContextProtocol.Tests/PlatformDetection.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 2e377c4ff..70eb82f3a 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -60,7 +60,7 @@ - + @@ -75,8 +75,8 @@ - - + + diff --git a/src/Common/CancellableStreamReader/CancellableStreamReader.cs b/src/Common/CancellableStreamReader/CancellableStreamReader.cs new file mode 100644 index 000000000..f6df72d22 --- /dev/null +++ b/src/Common/CancellableStreamReader/CancellableStreamReader.cs @@ -0,0 +1,1395 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using ModelContextProtocol; + +namespace System.IO; + +/// +/// Netfx-compatible polyfill of System.IO.StreamReader that supports cancellation. +/// +internal class CancellableStreamReader : TextReader +{ + // CancellableStreamReader.Null is threadsafe. + public static new readonly CancellableStreamReader Null = new NullCancellableStreamReader(); + + // Using a 1K byte buffer and a 4K FileStream buffer works out pretty well + // perf-wise. On even a 40 MB text file, any perf loss by using a 4K + // buffer is negated by the win of allocating a smaller byte[], which + // saves construction time. This does break adaptive buffering, + // but this is slightly faster. + private const int DefaultBufferSize = 1024; // Byte buffer size + private const int DefaultFileStreamBufferSize = 4096; + private const int MinBufferSize = 128; + + private readonly Stream _stream; + private Encoding _encoding = null!; // only null in NullCancellableStreamReader where this is never used + private readonly byte[] _encodingPreamble = null!; // only null in NullCancellableStreamReader where this is never used + private Decoder _decoder = null!; // only null in NullCancellableStreamReader where this is never used + private readonly byte[] _byteBuffer = null!; // only null in NullCancellableStreamReader where this is never used + private char[] _charBuffer = null!; // only null in NullCancellableStreamReader where this is never used + private int _charPos; + private int _charLen; + // Record the number of valid bytes in the byteBuffer, for a few checks. + private int _byteLen; + // This is used only for preamble detection + private int _bytePos; + + // This is the maximum number of chars we can get from one call to + // ReadBuffer. Used so ReadBuffer can tell when to copy data into + // a user's char[] directly, instead of our internal char[]. + private int _maxCharsPerBuffer; + + /// True if the writer has been disposed; otherwise, false. + private bool _disposed; + + // We will support looking for byte order marks in the stream and trying + // to decide what the encoding might be from the byte order marks, IF they + // exist. But that's all we'll do. + private bool _detectEncoding; + + // Whether we must still check for the encoding's given preamble at the + // beginning of this file. + private bool _checkPreamble; + + // Whether the stream is most likely not going to give us back as much + // data as we want the next time we call it. We must do the computation + // before we do any byte order mark handling and save the result. Note + // that we need this to allow users to handle streams used for an + // interactive protocol, where they block waiting for the remote end + // to send a response, like logging in on a Unix machine. + private bool _isBlocked; + + // The intent of this field is to leave open the underlying stream when + // disposing of this CancellableStreamReader. A name like _leaveOpen is better, + // but this type is serializable, and this field's name was _closable. + private readonly bool _closable; // Whether to close the underlying stream. + + // We don't guarantee thread safety on CancellableStreamReader, but we should at + // least prevent users from trying to read anything while an Async + // read from the same thread is in progress. + private Task _asyncReadTask = Task.CompletedTask; + + private void CheckAsyncTaskInProgress() + { + // We are not locking the access to _asyncReadTask because this is not meant to guarantee thread safety. + // We are simply trying to deter calling any Read APIs while an async Read from the same thread is in progress. + if (!_asyncReadTask.IsCompleted) + { + ThrowAsyncIOInProgress(); + } + } + + [DoesNotReturn] + private static void ThrowAsyncIOInProgress() => + throw new InvalidOperationException("Async IO is in progress"); + + // CancellableStreamReader by default will ignore illegal UTF8 characters. We don't want to + // throw here because we want to be able to read ill-formed data without choking. + // The high level goal is to be tolerant of encoding errors when we read and very strict + // when we write. Hence, default StreamWriter encoding will throw on error. + + private CancellableStreamReader() + { + Debug.Assert(this is NullCancellableStreamReader); + _stream = Stream.Null; + _closable = true; + } + + public CancellableStreamReader(Stream stream) + : this(stream, true) + { + } + + public CancellableStreamReader(Stream stream, bool detectEncodingFromByteOrderMarks) + : this(stream, Encoding.UTF8, detectEncodingFromByteOrderMarks, DefaultBufferSize, false) + { + } + + public CancellableStreamReader(Stream stream, Encoding? encoding) + : this(stream, encoding, true, DefaultBufferSize, false) + { + } + + public CancellableStreamReader(Stream stream, Encoding? encoding, bool detectEncodingFromByteOrderMarks) + : this(stream, encoding, detectEncodingFromByteOrderMarks, DefaultBufferSize, false) + { + } + + // Creates a new CancellableStreamReader for the given stream. The + // character encoding is set by encoding and the buffer size, + // in number of 16-bit characters, is set by bufferSize. + // + // Note that detectEncodingFromByteOrderMarks is a very + // loose attempt at detecting the encoding by looking at the first + // 3 bytes of the stream. It will recognize UTF-8, little endian + // unicode, and big endian unicode text, but that's it. If neither + // of those three match, it will use the Encoding you provided. + // + public CancellableStreamReader(Stream stream, Encoding? encoding, bool detectEncodingFromByteOrderMarks, int bufferSize) + : this(stream, encoding, detectEncodingFromByteOrderMarks, bufferSize, false) + { + } + + public CancellableStreamReader(Stream stream, Encoding? encoding = null, bool detectEncodingFromByteOrderMarks = true, int bufferSize = -1, bool leaveOpen = false) + { + Throw.IfNull(stream); + + if (!stream.CanRead) + { + throw new ArgumentException("Stream not readable."); + } + + if (bufferSize == -1) + { + bufferSize = DefaultBufferSize; + } + + if (bufferSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize), bufferSize, "Buffer size must be greater than zero."); + } + + _stream = stream; + _encoding = encoding ??= Encoding.UTF8; + _decoder = encoding.GetDecoder(); + if (bufferSize < MinBufferSize) + { + bufferSize = MinBufferSize; + } + + _byteBuffer = new byte[bufferSize]; + _maxCharsPerBuffer = encoding.GetMaxCharCount(bufferSize); + _charBuffer = new char[_maxCharsPerBuffer]; + _detectEncoding = detectEncodingFromByteOrderMarks; + _encodingPreamble = encoding.GetPreamble(); + + // If the preamble length is larger than the byte buffer length, + // we'll never match it and will enter an infinite loop. This + // should never happen in practice, but just in case, we'll skip + // the preamble check for absurdly long preambles. + int preambleLength = _encodingPreamble.Length; + _checkPreamble = preambleLength > 0 && preambleLength <= bufferSize; + + _closable = !leaveOpen; + } + + public override void Close() + { + Dispose(true); + } + + protected override void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + _disposed = true; + + // Dispose of our resources if this CancellableStreamReader is closable. + if (_closable) + { + try + { + // Note that Stream.Close() can potentially throw here. So we need to + // ensure cleaning up internal resources, inside the finally block. + if (disposing) + { + _stream.Close(); + } + } + finally + { + _charPos = 0; + _charLen = 0; + base.Dispose(disposing); + } + } + } + + public virtual Encoding CurrentEncoding => _encoding; + + public virtual Stream BaseStream => _stream; + + // DiscardBufferedData tells CancellableStreamReader to throw away its internal + // buffer contents. This is useful if the user needs to seek on the + // underlying stream to a known location then wants the CancellableStreamReader + // to start reading from this new point. This method should be called + // very sparingly, if ever, since it can lead to very poor performance. + // However, it may be the only way of handling some scenarios where + // users need to re-read the contents of a CancellableStreamReader a second time. + public void DiscardBufferedData() + { + CheckAsyncTaskInProgress(); + + _byteLen = 0; + _charLen = 0; + _charPos = 0; + // in general we'd like to have an invariant that encoding isn't null. However, + // for startup improvements for NullCancellableStreamReader, we want to delay load encoding. + if (_encoding != null) + { + _decoder = _encoding.GetDecoder(); + } + _isBlocked = false; + } + + public bool EndOfStream + { + get + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (_charPos < _charLen) + { + return false; + } + + // This may block on pipes! + int numRead = ReadBuffer(); + return numRead == 0; + } + } + + public override int Peek() + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (_charPos == _charLen) + { + if (ReadBuffer() == 0) + { + return -1; + } + } + return _charBuffer[_charPos]; + } + + public override int Read() + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (_charPos == _charLen) + { + if (ReadBuffer() == 0) + { + return -1; + } + } + int result = _charBuffer[_charPos]; + _charPos++; + return result; + } + + public override int Read(char[] buffer, int index, int count) + { + Throw.IfNull(buffer); + Throw.IfNegative(index); + Throw.IfNegative(count); + + return ReadSpan(new Span(buffer, index, count)); + } + + private int ReadSpan(Span buffer) + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + int charsRead = 0; + // As a perf optimization, if we had exactly one buffer's worth of + // data read in, let's try writing directly to the user's buffer. + bool readToUserBuffer = false; + int count = buffer.Length; + while (count > 0) + { + int n = _charLen - _charPos; + if (n == 0) + { + n = ReadBuffer(buffer.Slice(charsRead), out readToUserBuffer); + } + if (n == 0) + { + break; // We're at EOF + } + if (n > count) + { + n = count; + } + if (!readToUserBuffer) + { + new Span(_charBuffer, _charPos, n).CopyTo(buffer.Slice(charsRead)); + _charPos += n; + } + + charsRead += n; + count -= n; + // This function shouldn't block for an indefinite amount of time, + // or reading from a network stream won't work right. If we got + // fewer bytes than we requested, then we want to break right here. + if (_isBlocked) + { + break; + } + } + + return charsRead; + } + + public override string ReadToEnd() + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + // Call ReadBuffer, then pull data out of charBuffer. + StringBuilder sb = new StringBuilder(_charLen - _charPos); + do + { + sb.Append(_charBuffer, _charPos, _charLen - _charPos); + _charPos = _charLen; // Note we consumed these characters + ReadBuffer(); + } while (_charLen > 0); + return sb.ToString(); + } + + public override int ReadBlock(char[] buffer, int index, int count) + { + Throw.IfNull(buffer); + + Throw.IfNegative(index); + Throw.IfNegative(count); + if (buffer.Length - index < count) + { + throw new ArgumentException("invalid offset length."); + } + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + return base.ReadBlock(buffer, index, count); + } + + // Trims n bytes from the front of the buffer. + private void CompressBuffer(int n) + { + Debug.Assert(_byteLen >= n, "CompressBuffer was called with a number of bytes greater than the current buffer length. Are two threads using this CancellableStreamReader at the same time?"); + byte[] byteBuffer = _byteBuffer; + _ = byteBuffer.Length; // allow JIT to prove object is not null + new ReadOnlySpan(byteBuffer, n, _byteLen - n).CopyTo(byteBuffer); + _byteLen -= n; + } + + private void DetectEncoding() + { + Debug.Assert(_byteLen >= 2, "Caller should've validated that at least 2 bytes were available."); + + byte[] byteBuffer = _byteBuffer; + _detectEncoding = false; + bool changedEncoding = false; + + ushort firstTwoBytes = BinaryPrimitives.ReadUInt16LittleEndian(byteBuffer); + if (firstTwoBytes == 0xFFFE) + { + // Big Endian Unicode + _encoding = Encoding.BigEndianUnicode; + CompressBuffer(2); + changedEncoding = true; + } + else if (firstTwoBytes == 0xFEFF) + { + // Little Endian Unicode, or possibly little endian UTF32 + if (_byteLen < 4 || byteBuffer[2] != 0 || byteBuffer[3] != 0) + { + _encoding = Encoding.Unicode; + CompressBuffer(2); + changedEncoding = true; + } + else + { + _encoding = Encoding.UTF32; + CompressBuffer(4); + changedEncoding = true; + } + } + else if (_byteLen >= 3 && firstTwoBytes == 0xBBEF && byteBuffer[2] == 0xBF) + { + // UTF-8 + _encoding = Encoding.UTF8; + CompressBuffer(3); + changedEncoding = true; + } + else if (_byteLen >= 4 && firstTwoBytes == 0 && byteBuffer[2] == 0xFE && byteBuffer[3] == 0xFF) + { + // Big Endian UTF32 + _encoding = new UTF32Encoding(bigEndian: true, byteOrderMark: true); + CompressBuffer(4); + changedEncoding = true; + } + else if (_byteLen == 2) + { + _detectEncoding = true; + } + // Note: in the future, if we change this algorithm significantly, + // we can support checking for the preamble of the given encoding. + + if (changedEncoding) + { + _decoder = _encoding.GetDecoder(); + int newMaxCharsPerBuffer = _encoding.GetMaxCharCount(byteBuffer.Length); + if (newMaxCharsPerBuffer > _maxCharsPerBuffer) + { + _charBuffer = new char[newMaxCharsPerBuffer]; + } + _maxCharsPerBuffer = newMaxCharsPerBuffer; + } + } + + // Trims the preamble bytes from the byteBuffer. This routine can be called multiple times + // and we will buffer the bytes read until the preamble is matched or we determine that + // there is no match. If there is no match, every byte read previously will be available + // for further consumption. If there is a match, we will compress the buffer for the + // leading preamble bytes + private bool IsPreamble() + { + if (!_checkPreamble) + { + return false; + } + + return IsPreambleWorker(); // move this call out of the hot path + bool IsPreambleWorker() + { + Debug.Assert(_checkPreamble); + ReadOnlySpan preamble = _encodingPreamble; + + Debug.Assert(_bytePos < preamble.Length, "_compressPreamble was called with the current bytePos greater than the preamble buffer length. Are two threads using this CancellableStreamReader at the same time?"); + int len = Math.Min(_byteLen, preamble.Length); + + for (int i = _bytePos; i < len; i++) + { + if (_byteBuffer[i] != preamble[i]) + { + _bytePos = 0; // preamble match failed; back up to beginning of buffer + _checkPreamble = false; + return false; + } + } + _bytePos = len; // we've matched all bytes up to this point + + Debug.Assert(_bytePos <= preamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + + if (_bytePos == preamble.Length) + { + // We have a match + CompressBuffer(preamble.Length); + _bytePos = 0; + _checkPreamble = false; + _detectEncoding = false; + } + + return _checkPreamble; + } + } + + internal virtual int ReadBuffer() + { + _charLen = 0; + _charPos = 0; + + if (!_checkPreamble) + { + _byteLen = 0; + } + + bool eofReached = false; + + do + { + if (_checkPreamble) + { + Debug.Assert(_bytePos <= _encodingPreamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + int len = _stream.Read(_byteBuffer, _bytePos, _byteBuffer.Length - _bytePos); + Debug.Assert(len >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (len == 0) + { + eofReached = true; + break; + } + + _byteLen += len; + } + else + { + Debug.Assert(_bytePos == 0, "bytePos can be non zero only when we are trying to _checkPreamble. Are two threads using this CancellableStreamReader at the same time?"); + _byteLen = _stream.Read(_byteBuffer, 0, _byteBuffer.Length); + Debug.Assert(_byteLen >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (_byteLen == 0) + { + eofReached = true; + break; + } + } + + // _isBlocked == whether we read fewer bytes than we asked for. + // Note we must check it here because CompressBuffer or + // DetectEncoding will change byteLen. + _isBlocked = (_byteLen < _byteBuffer.Length); + + // Check for preamble before detect encoding. This is not to override the + // user supplied Encoding for the one we implicitly detect. The user could + // customize the encoding which we will loose, such as ThrowOnError on UTF8 + if (IsPreamble()) + { + continue; + } + + // If we're supposed to detect the encoding and haven't done so yet, + // do it. Note this may need to be called more than once. + if (_detectEncoding && _byteLen >= 2) + { + DetectEncoding(); + } + + Debug.Assert(_charPos == 0 && _charLen == 0, "We shouldn't be trying to decode more data if we made progress in an earlier iteration."); + _charLen = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: false); + } while (_charLen == 0); + + if (eofReached) + { + // EOF has been reached - perform final flush. + // We need to reset _bytePos and _byteLen just in case we hadn't + // finished processing the preamble before we reached EOF. + + Debug.Assert(_charPos == 0 && _charLen == 0, "We shouldn't be looking for EOF unless we have an empty char buffer."); + _charLen = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: true); + _bytePos = 0; + _byteLen = 0; + } + + return _charLen; + } + + + // This version has a perf optimization to decode data DIRECTLY into the + // user's buffer, bypassing CancellableStreamReader's own buffer. + // This gives a > 20% perf improvement for our encodings across the board, + // but only when asking for at least the number of characters that one + // buffer's worth of bytes could produce. + // This optimization, if run, will break SwitchEncoding, so we must not do + // this on the first call to ReadBuffer. + private int ReadBuffer(Span userBuffer, out bool readToUserBuffer) + { + _charLen = 0; + _charPos = 0; + + if (!_checkPreamble) + { + _byteLen = 0; + } + + bool eofReached = false; + int charsRead = 0; + + // As a perf optimization, we can decode characters DIRECTLY into a + // user's char[]. We absolutely must not write more characters + // into the user's buffer than they asked for. Calculating + // encoding.GetMaxCharCount(byteLen) each time is potentially very + // expensive - instead, cache the number of chars a full buffer's + // worth of data may produce. Yes, this makes the perf optimization + // less aggressive, in that all reads that asked for fewer than AND + // returned fewer than _maxCharsPerBuffer chars won't get the user + // buffer optimization. This affects reads where the end of the + // Stream comes in the middle somewhere, and when you ask for + // fewer chars than your buffer could produce. + readToUserBuffer = userBuffer.Length >= _maxCharsPerBuffer; + + do + { + Debug.Assert(charsRead == 0); + + if (_checkPreamble) + { + Debug.Assert(_bytePos <= _encodingPreamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + int len = _stream.Read(_byteBuffer, _bytePos, _byteBuffer.Length - _bytePos); + Debug.Assert(len >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (len == 0) + { + eofReached = true; + break; + } + + _byteLen += len; + } + else + { + Debug.Assert(_bytePos == 0, "bytePos can be non zero only when we are trying to _checkPreamble. Are two threads using this CancellableStreamReader at the same time?"); + _byteLen = _stream.Read(_byteBuffer, 0, _byteBuffer.Length); + Debug.Assert(_byteLen >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (_byteLen == 0) + { + eofReached = true; + break; + } + } + + // _isBlocked == whether we read fewer bytes than we asked for. + // Note we must check it here because CompressBuffer or + // DetectEncoding will change byteLen. + _isBlocked = (_byteLen < _byteBuffer.Length); + + // Check for preamble before detect encoding. This is not to override the + // user supplied Encoding for the one we implicitly detect. The user could + // customize the encoding which we will loose, such as ThrowOnError on UTF8 + // Note: we don't need to recompute readToUserBuffer optimization as IsPreamble + // doesn't change the encoding or affect _maxCharsPerBuffer + if (IsPreamble()) + { + continue; + } + + // On the first call to ReadBuffer, if we're supposed to detect the encoding, do it. + if (_detectEncoding && _byteLen >= 2) + { + DetectEncoding(); + // DetectEncoding changes some buffer state. Recompute this. + readToUserBuffer = userBuffer.Length >= _maxCharsPerBuffer; + } + + Debug.Assert(charsRead == 0 && _charPos == 0 && _charLen == 0, "We shouldn't be trying to decode more data if we made progress in an earlier iteration."); + if (readToUserBuffer) + { + charsRead = GetChars(_decoder, new ReadOnlySpan(_byteBuffer, 0, _byteLen), userBuffer, flush: false); + } + else + { + charsRead = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: false); + _charLen = charsRead; // Number of chars in CancellableStreamReader's buffer. + } + } while (charsRead == 0); + + if (eofReached) + { + // EOF has been reached - perform final flush. + // We need to reset _bytePos and _byteLen just in case we hadn't + // finished processing the preamble before we reached EOF. + + Debug.Assert(charsRead == 0 && _charPos == 0 && _charLen == 0, "We shouldn't be looking for EOF unless we have an empty char buffer."); + + if (readToUserBuffer) + { + charsRead = GetChars(_decoder, new ReadOnlySpan(_byteBuffer, 0, _byteLen), userBuffer, flush: true); + } + else + { + charsRead = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: true); + _charLen = charsRead; // Number of chars in CancellableStreamReader's buffer. + } + _bytePos = 0; + _byteLen = 0; + } + + _isBlocked &= charsRead < userBuffer.Length; + + return charsRead; + } + + + // Reads a line. A line is defined as a sequence of characters followed by + // a carriage return ('\r'), a line feed ('\n'), or a carriage return + // immediately followed by a line feed. The resulting string does not + // contain the terminating carriage return and/or line feed. The returned + // value is null if the end of the input stream has been reached. + // + public override string? ReadLine() + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (_charPos == _charLen) + { + if (ReadBuffer() == 0) + { + return null; + } + } + + var vsb = new ValueStringBuilder(stackalloc char[256]); + do + { + // Look for '\r' or \'n'. + ReadOnlySpan charBufferSpan = _charBuffer.AsSpan(_charPos, _charLen - _charPos); + Debug.Assert(!charBufferSpan.IsEmpty, "ReadBuffer returned > 0 but didn't bump _charLen?"); + + int idxOfNewline = charBufferSpan.IndexOfAny('\r', '\n'); + if (idxOfNewline >= 0) + { + string retVal; + if (vsb.Length == 0) + { + retVal = charBufferSpan.Slice(0, idxOfNewline).ToString(); + } + else + { + retVal = string.Concat(vsb.AsSpan().ToString(), charBufferSpan.Slice(0, idxOfNewline).ToString()); + vsb.Dispose(); + } + + char matchedChar = charBufferSpan[idxOfNewline]; + _charPos += idxOfNewline + 1; + + // If we found '\r', consume any immediately following '\n'. + if (matchedChar == '\r') + { + if (_charPos < _charLen || ReadBuffer() > 0) + { + if (_charBuffer[_charPos] == '\n') + { + _charPos++; + } + } + } + + return retVal; + } + + // We didn't find '\r' or '\n'. Add it to the StringBuilder + // and loop until we reach a newline or EOF. + + vsb.Append(charBufferSpan); + } while (ReadBuffer() > 0); + + return vsb.ToString(); + } + + public override Task ReadLineAsync() => + ReadLineAsync(default).AsTask(); + + /// + /// Reads a line of characters asynchronously from the current stream and returns the data as a string. + /// + /// The token to monitor for cancellation requests. + /// A value task that represents the asynchronous read operation. The value of the TResult + /// parameter contains the next line from the stream, or is if all of the characters have been read. + /// The number of characters in the next line is larger than . + /// The stream reader has been disposed. + /// The reader is currently in use by a previous read operation. + /// + /// The following example shows how to read and print all lines from the file until the end of the file is reached or the operation timed out. + /// + /// using CancellationTokenSource tokenSource = new (TimeSpan.FromSeconds(1)); + /// using CancellableStreamReader reader = File.OpenText("existingfile.txt"); + /// + /// string line; + /// while ((line = await reader.ReadLineAsync(tokenSource.Token)) is not null) + /// { + /// Console.WriteLine(line); + /// } + /// + /// + /// + /// If this method is canceled via , some data + /// that has been read from the current but not stored (by the + /// ) or returned (to the caller) may be lost. + /// + public virtual ValueTask ReadLineAsync(CancellationToken cancellationToken) + { + // If we have been inherited into a subclass, the following implementation could be incorrect + // since it does not call through to Read() which a subclass might have overridden. + // To be safe we will only use this implementation in cases where we know it is safe to do so, + // and delegate to our base class (which will call into Read) when we are not sure. + if (GetType() != typeof(CancellableStreamReader)) + { + return new ValueTask(base.ReadLineAsync()!); + } + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + Task task = ReadLineAsyncInternal(cancellationToken); + _asyncReadTask = task; + + return new ValueTask(task); + } + + private async Task ReadLineAsyncInternal(CancellationToken cancellationToken) + { + if (_charPos == _charLen && (await ReadBufferAsync(cancellationToken).ConfigureAwait(false)) == 0) + { + return null; + } + + string retVal; + char[]? arrayPoolBuffer = null; + int arrayPoolBufferPos = 0; + + do + { + char[] charBuffer = _charBuffer; + int charLen = _charLen; + int charPos = _charPos; + + // Look for '\r' or \'n'. + Debug.Assert(charPos < charLen, "ReadBuffer returned > 0 but didn't bump _charLen?"); + + int idxOfNewline = charBuffer.AsSpan(charPos, charLen - charPos).IndexOfAny('\r', '\n'); + if (idxOfNewline >= 0) + { + if (arrayPoolBuffer is null) + { + retVal = new string(charBuffer, charPos, idxOfNewline); + } + else + { + retVal = string.Concat(arrayPoolBuffer.AsSpan(0, arrayPoolBufferPos).ToString(), charBuffer.AsSpan(charPos, idxOfNewline).ToString()); + ArrayPool.Shared.Return(arrayPoolBuffer); + } + + charPos += idxOfNewline; + char matchedChar = charBuffer[charPos++]; + _charPos = charPos; + + // If we found '\r', consume any immediately following '\n'. + if (matchedChar == '\r') + { + if (charPos < charLen || (await ReadBufferAsync(cancellationToken).ConfigureAwait(false)) > 0) + { + if (_charBuffer[_charPos] == '\n') + { + _charPos++; + } + } + } + + return retVal; + } + + // We didn't find '\r' or '\n'. Add the read data to the pooled buffer + // and loop until we reach a newline or EOF. + if (arrayPoolBuffer is null) + { + arrayPoolBuffer = ArrayPool.Shared.Rent(charLen - charPos + 80); + } + else if ((arrayPoolBuffer.Length - arrayPoolBufferPos) < (charLen - charPos)) + { + char[] newBuffer = ArrayPool.Shared.Rent(checked(arrayPoolBufferPos + charLen - charPos)); + arrayPoolBuffer.AsSpan(0, arrayPoolBufferPos).CopyTo(newBuffer); + ArrayPool.Shared.Return(arrayPoolBuffer); + arrayPoolBuffer = newBuffer; + } + charBuffer.AsSpan(charPos, charLen - charPos).CopyTo(arrayPoolBuffer.AsSpan(arrayPoolBufferPos)); + arrayPoolBufferPos += charLen - charPos; + } + while (await ReadBufferAsync(cancellationToken).ConfigureAwait(false) > 0); + + if (arrayPoolBuffer is not null) + { + retVal = new string(arrayPoolBuffer, 0, arrayPoolBufferPos); + ArrayPool.Shared.Return(arrayPoolBuffer); + } + else + { + retVal = string.Empty; + } + + return retVal; + } + + public override Task ReadToEndAsync() => ReadToEndAsync(default); + + /// + /// Reads all characters from the current position to the end of the stream asynchronously and returns them as one string. + /// + /// The token to monitor for cancellation requests. + /// A task that represents the asynchronous read operation. The value of the TResult parameter contains + /// a string with the characters from the current position to the end of the stream. + /// The number of characters is larger than . + /// The stream reader has been disposed. + /// The reader is currently in use by a previous read operation. + /// + /// The following example shows how to read the contents of a file by using the method. + /// + /// using CancellationTokenSource tokenSource = new (TimeSpan.FromSeconds(1)); + /// using CancellableStreamReader reader = File.OpenText("existingfile.txt"); + /// + /// Console.WriteLine(await reader.ReadToEndAsync(tokenSource.Token)); + /// + /// + /// + /// If this method is canceled via , some data + /// that has been read from the current but not stored (by the + /// ) or returned (to the caller) may be lost. + /// + public virtual Task ReadToEndAsync(CancellationToken cancellationToken) + { + // If we have been inherited into a subclass, the following implementation could be incorrect + // since it does not call through to Read() which a subclass might have overridden. + // To be safe we will only use this implementation in cases where we know it is safe to do so, + // and delegate to our base class (which will call into Read) when we are not sure. + if (GetType() != typeof(CancellableStreamReader)) + { + return base.ReadToEndAsync(); + } + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + Task task = ReadToEndAsyncInternal(cancellationToken); + _asyncReadTask = task; + + return task; + } + + private async Task ReadToEndAsyncInternal(CancellationToken cancellationToken) + { + // Call ReadBuffer, then pull data out of charBuffer. + StringBuilder sb = new StringBuilder(_charLen - _charPos); + do + { + int tmpCharPos = _charPos; + sb.Append(_charBuffer, tmpCharPos, _charLen - tmpCharPos); + _charPos = _charLen; // We consumed these characters + await ReadBufferAsync(cancellationToken).ConfigureAwait(false); + } while (_charLen > 0); + + return sb.ToString(); + } + + public override Task ReadAsync(char[] buffer, int index, int count) + { + Throw.IfNull(buffer); + + Throw.IfNegative(index); + Throw.IfNegative(count); + if (buffer.Length - index < count) + { + throw new ArgumentException("invalid offset length."); + } + + // If we have been inherited into a subclass, the following implementation could be incorrect + // since it does not call through to Read() which a subclass might have overridden. + // To be safe we will only use this implementation in cases where we know it is safe to do so, + // and delegate to our base class (which will call into Read) when we are not sure. + if (GetType() != typeof(CancellableStreamReader)) + { + return base.ReadAsync(buffer, index, count); + } + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + Task task = ReadAsyncInternal(new Memory(buffer, index, count), CancellationToken.None).AsTask(); + _asyncReadTask = task; + + return task; + } + + public virtual ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + Debug.Assert(GetType() == typeof(CancellableStreamReader)); + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + return ReadAsyncInternal(buffer, cancellationToken); + } + + private protected virtual async ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) + { + if (_charPos == _charLen && (await ReadBufferAsync(cancellationToken).ConfigureAwait(false)) == 0) + { + return 0; + } + + int charsRead = 0; + + // As a perf optimization, if we had exactly one buffer's worth of + // data read in, let's try writing directly to the user's buffer. + bool readToUserBuffer = false; + + byte[] tmpByteBuffer = _byteBuffer; + Stream tmpStream = _stream; + + int count = buffer.Length; + while (count > 0) + { + // n is the characters available in _charBuffer + int n = _charLen - _charPos; + + // charBuffer is empty, let's read from the stream + if (n == 0) + { + _charLen = 0; + _charPos = 0; + + if (!_checkPreamble) + { + _byteLen = 0; + } + + readToUserBuffer = count >= _maxCharsPerBuffer; + + // We loop here so that we read in enough bytes to yield at least 1 char. + // We break out of the loop if the stream is blocked (EOF is reached). + do + { + Debug.Assert(n == 0); + + if (_checkPreamble) + { + Debug.Assert(_bytePos <= _encodingPreamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + int tmpBytePos = _bytePos; + int len = await tmpStream.ReadAsync(new Memory(tmpByteBuffer, tmpBytePos, tmpByteBuffer.Length - tmpBytePos), cancellationToken).ConfigureAwait(false); + Debug.Assert(len >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (len == 0) + { + // EOF but we might have buffered bytes from previous + // attempts to detect preamble that needs to be decoded now + if (_byteLen > 0) + { + if (readToUserBuffer) + { + n = GetChars(_decoder, new ReadOnlySpan(tmpByteBuffer, 0, _byteLen), buffer.Span.Slice(charsRead), flush: false); + _charLen = 0; // CancellableStreamReader's buffer is empty. + } + else + { + n = _decoder.GetChars(tmpByteBuffer, 0, _byteLen, _charBuffer, 0); + _charLen += n; // Number of chars in CancellableStreamReader's buffer. + } + } + + // How can part of the preamble yield any chars? + Debug.Assert(n == 0); + + _isBlocked = true; + break; + } + else + { + _byteLen += len; + } + } + else + { + Debug.Assert(_bytePos == 0, "_bytePos can be non zero only when we are trying to _checkPreamble. Are two threads using this CancellableStreamReader at the same time?"); + + _byteLen = await tmpStream.ReadAsync(new Memory(tmpByteBuffer), cancellationToken).ConfigureAwait(false); + + Debug.Assert(_byteLen >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (_byteLen == 0) // EOF + { + _isBlocked = true; + break; + } + } + + // _isBlocked == whether we read fewer bytes than we asked for. + // Note we must check it here because CompressBuffer or + // DetectEncoding will change _byteLen. + _isBlocked = (_byteLen < tmpByteBuffer.Length); + + // Check for preamble before detect encoding. This is not to override the + // user supplied Encoding for the one we implicitly detect. The user could + // customize the encoding which we will loose, such as ThrowOnError on UTF8 + // Note: we don't need to recompute readToUserBuffer optimization as IsPreamble + // doesn't change the encoding or affect _maxCharsPerBuffer + if (IsPreamble()) + { + continue; + } + + // On the first call to ReadBuffer, if we're supposed to detect the encoding, do it. + if (_detectEncoding && _byteLen >= 2) + { + DetectEncoding(); + // DetectEncoding changes some buffer state. Recompute this. + readToUserBuffer = count >= _maxCharsPerBuffer; + } + + Debug.Assert(n == 0); + + _charPos = 0; + if (readToUserBuffer) + { + n = GetChars(_decoder, new ReadOnlySpan(tmpByteBuffer, 0, _byteLen), buffer.Span.Slice(charsRead), flush: false); + _charLen = 0; // CancellableStreamReader's buffer is empty. + } + else + { + n = _decoder.GetChars(tmpByteBuffer, 0, _byteLen, _charBuffer, 0); + _charLen += n; // Number of chars in CancellableStreamReader's buffer. + } + } while (n == 0); + + if (n == 0) + { + break; // We're at EOF + } + } // if (n == 0) + + // Got more chars in charBuffer than the user requested + if (n > count) + { + n = count; + } + + if (!readToUserBuffer) + { + new Span(_charBuffer, _charPos, n).CopyTo(buffer.Span.Slice(charsRead)); + _charPos += n; + } + + charsRead += n; + count -= n; + + // This function shouldn't block for an indefinite amount of time, + // or reading from a network stream won't work right. If we got + // fewer bytes than we requested, then we want to break right here. + if (_isBlocked) + { + break; + } + } // while (count > 0) + + return charsRead; + } + + public override Task ReadBlockAsync(char[] buffer, int index, int count) + { + Throw.IfNull(buffer); + + Throw.IfNegative(index); + Throw.IfNegative(count); + if (buffer.Length - index < count) + { + throw new ArgumentException("invalid offset length."); + } + + // If we have been inherited into a subclass, the following implementation could be incorrect + // since it does not call through to Read() which a subclass might have overridden. + // To be safe we will only use this implementation in cases where we know it is safe to do so, + // and delegate to our base class (which will call into Read) when we are not sure. + if (GetType() != typeof(CancellableStreamReader)) + { + return base.ReadBlockAsync(buffer, index, count); + } + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + Task task = base.ReadBlockAsync(buffer, index, count); + _asyncReadTask = task; + + return task; + } + + public virtual ValueTask ReadBlockAsync(Memory buffer, CancellationToken cancellationToken = default) + { + Debug.Assert(GetType() == typeof(CancellableStreamReader)); + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + ValueTask vt = ReadBlockAsyncInternal(buffer, cancellationToken); + if (vt.IsCompletedSuccessfully) + { + return vt; + } + + Task t = vt.AsTask(); + _asyncReadTask = t; + return new ValueTask(t); + } + + private async ValueTask ReadBufferAsync(CancellationToken cancellationToken) + { + _charLen = 0; + _charPos = 0; + byte[] tmpByteBuffer = _byteBuffer; + Stream tmpStream = _stream; + + if (!_checkPreamble) + { + _byteLen = 0; + } + + bool eofReached = false; + + do + { + if (_checkPreamble) + { + Debug.Assert(_bytePos <= _encodingPreamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + int tmpBytePos = _bytePos; + int len = await tmpStream.ReadAsync(tmpByteBuffer.AsMemory(tmpBytePos), cancellationToken).ConfigureAwait(false); + Debug.Assert(len >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (len == 0) + { + eofReached = true; + break; + } + + _byteLen += len; + } + else + { + Debug.Assert(_bytePos == 0, "_bytePos can be non zero only when we are trying to _checkPreamble. Are two threads using this CancellableStreamReader at the same time?"); + _byteLen = await tmpStream.ReadAsync(new Memory(tmpByteBuffer), cancellationToken).ConfigureAwait(false); + Debug.Assert(_byteLen >= 0, "Stream.Read returned a negative number! Bug in stream class."); + + if (_byteLen == 0) + { + eofReached = true; + break; + } + } + + // _isBlocked == whether we read fewer bytes than we asked for. + // Note we must check it here because CompressBuffer or + // DetectEncoding will change _byteLen. + _isBlocked = (_byteLen < tmpByteBuffer.Length); + + // Check for preamble before detect encoding. This is not to override the + // user supplied Encoding for the one we implicitly detect. The user could + // customize the encoding which we will loose, such as ThrowOnError on UTF8 + if (IsPreamble()) + { + continue; + } + + // If we're supposed to detect the encoding and haven't done so yet, + // do it. Note this may need to be called more than once. + if (_detectEncoding && _byteLen >= 2) + { + DetectEncoding(); + } + + Debug.Assert(_charPos == 0 && _charLen == 0, "We shouldn't be trying to decode more data if we made progress in an earlier iteration."); + _charLen = _decoder.GetChars(tmpByteBuffer, 0, _byteLen, _charBuffer, 0, flush: false); + } while (_charLen == 0); + + if (eofReached) + { + // EOF has been reached - perform final flush. + // We need to reset _bytePos and _byteLen just in case we hadn't + // finished processing the preamble before we reached EOF. + + Debug.Assert(_charPos == 0 && _charLen == 0, "We shouldn't be looking for EOF unless we have an empty char buffer."); + _charLen = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: true); + _bytePos = 0; + _byteLen = 0; + } + + return _charLen; + } + + private async ValueTask ReadBlockAsyncInternal(Memory buffer, CancellationToken cancellationToken) + { + int n = 0, i; + do + { + i = await ReadAsyncInternal(buffer.Slice(n), cancellationToken).ConfigureAwait(false); + n += i; + } while (i > 0 && n < buffer.Length); + + return n; + } + + private static unsafe int GetChars(Decoder decoder, ReadOnlySpan bytes, Span chars, bool flush = false) + { + Throw.IfNull(decoder); + if (decoder is null || bytes.IsEmpty || chars.IsEmpty) + { + return 0; + } + + fixed (byte* pBytes = bytes) + fixed (char* pChars = chars) + { + return decoder.GetChars(pBytes, bytes.Length, pChars, chars.Length, flush); + } + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + ThrowObjectDisposedException(); + } + + void ThrowObjectDisposedException() => throw new ObjectDisposedException(GetType().Name, "reader has been closed."); + } + + // No data, class doesn't need to be serializable. + // Note this class is threadsafe. + internal sealed class NullCancellableStreamReader : CancellableStreamReader + { + public override Encoding CurrentEncoding => Encoding.Unicode; + + protected override void Dispose(bool disposing) + { + // Do nothing - this is essentially unclosable. + } + + public override int Peek() => -1; + + public override int Read() => -1; + + public override int Read(char[] buffer, int index, int count) => 0; + + public override Task ReadAsync(char[] buffer, int index, int count) => Task.FromResult(0); + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : default; + + public override int ReadBlock(char[] buffer, int index, int count) => 0; + + public override Task ReadBlockAsync(char[] buffer, int index, int count) => Task.FromResult(0); + + public override ValueTask ReadBlockAsync(Memory buffer, CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : default; + + public override string? ReadLine() => null; + + public override Task ReadLineAsync() => Task.FromResult(null); + + public override ValueTask ReadLineAsync(CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : default; + + public override string ReadToEnd() => ""; + + public override Task ReadToEndAsync() => Task.FromResult(""); + + public override Task ReadToEndAsync(CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : Task.FromResult(""); + + private protected override ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : default; + + internal override int ReadBuffer() => 0; + } +} \ No newline at end of file diff --git a/src/Common/CancellableStreamReader/TextReaderExtensions.cs b/src/Common/CancellableStreamReader/TextReaderExtensions.cs new file mode 100644 index 000000000..11ba0565e --- /dev/null +++ b/src/Common/CancellableStreamReader/TextReaderExtensions.cs @@ -0,0 +1,15 @@ +namespace System.IO; + +internal static class TextReaderExtensions +{ + public static ValueTask ReadLineAsync(this TextReader reader, CancellationToken cancellationToken) + { + if (reader is CancellableStreamReader cancellableReader) + { + return cancellableReader.ReadLineAsync(cancellationToken)!; + } + + cancellationToken.ThrowIfCancellationRequested(); + return new ValueTask(reader.ReadLineAsync()); + } +} \ No newline at end of file diff --git a/src/Common/CancellableStreamReader/ValueStringBuilder.cs b/src/Common/CancellableStreamReader/ValueStringBuilder.cs new file mode 100644 index 000000000..27bea693e --- /dev/null +++ b/src/Common/CancellableStreamReader/ValueStringBuilder.cs @@ -0,0 +1,317 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +#nullable enable + +namespace System.Text +{ + internal ref partial struct ValueStringBuilder + { + private char[]? _arrayToReturnToPool; + private Span _chars; + private int _pos; + + public ValueStringBuilder(Span initialBuffer) + { + _arrayToReturnToPool = null; + _chars = initialBuffer; + _pos = 0; + } + + public ValueStringBuilder(int initialCapacity) + { + _arrayToReturnToPool = ArrayPool.Shared.Rent(initialCapacity); + _chars = _arrayToReturnToPool; + _pos = 0; + } + + public int Length + { + get => _pos; + set + { + Debug.Assert(value >= 0); + Debug.Assert(value <= _chars.Length); + _pos = value; + } + } + + public int Capacity => _chars.Length; + + public void EnsureCapacity(int capacity) + { + // This is not expected to be called this with negative capacity + Debug.Assert(capacity >= 0); + + // If the caller has a bug and calls this with negative capacity, make sure to call Grow to throw an exception. + if ((uint)capacity > (uint)_chars.Length) + Grow(capacity - _pos); + } + + /// + /// Get a pinnable reference to the builder. + /// Does not ensure there is a null char after + /// This overload is pattern matched in the C# 7.3+ compiler so you can omit + /// the explicit method call, and write eg "fixed (char* c = builder)" + /// + public ref char GetPinnableReference() + { + return ref MemoryMarshal.GetReference(_chars); + } + + /// + /// Get a pinnable reference to the builder. + /// + /// Ensures that the builder has a null char after + public ref char GetPinnableReference(bool terminate) + { + if (terminate) + { + EnsureCapacity(Length + 1); + _chars[Length] = '\0'; + } + return ref MemoryMarshal.GetReference(_chars); + } + + public ref char this[int index] + { + get + { + Debug.Assert(index < _pos); + return ref _chars[index]; + } + } + + public override string ToString() + { + string s = _chars.Slice(0, _pos).ToString(); + Dispose(); + return s; + } + + /// Returns the underlying storage of the builder. + public Span RawChars => _chars; + + /// + /// Returns a span around the contents of the builder. + /// + /// Ensures that the builder has a null char after + public ReadOnlySpan AsSpan(bool terminate) + { + if (terminate) + { + EnsureCapacity(Length + 1); + _chars[Length] = '\0'; + } + return _chars.Slice(0, _pos); + } + + public ReadOnlySpan AsSpan() => _chars.Slice(0, _pos); + public ReadOnlySpan AsSpan(int start) => _chars.Slice(start, _pos - start); + public ReadOnlySpan AsSpan(int start, int length) => _chars.Slice(start, length); + + public bool TryCopyTo(Span destination, out int charsWritten) + { + if (_chars.Slice(0, _pos).TryCopyTo(destination)) + { + charsWritten = _pos; + Dispose(); + return true; + } + else + { + charsWritten = 0; + Dispose(); + return false; + } + } + + public void Insert(int index, char value, int count) + { + if (_pos > _chars.Length - count) + { + Grow(count); + } + + int remaining = _pos - index; + _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); + _chars.Slice(index, count).Fill(value); + _pos += count; + } + + public void Insert(int index, string? s) + { + if (s == null) + { + return; + } + + int count = s.Length; + + if (_pos > (_chars.Length - count)) + { + Grow(count); + } + + int remaining = _pos - index; + _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); + s +#if !NET + .AsSpan() +#endif + .CopyTo(_chars.Slice(index)); + _pos += count; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(char c) + { + int pos = _pos; + Span chars = _chars; + if ((uint)pos < (uint)chars.Length) + { + chars[pos] = c; + _pos = pos + 1; + } + else + { + GrowAndAppend(c); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(string? s) + { + if (s == null) + { + return; + } + + int pos = _pos; + if (s.Length == 1 && (uint)pos < (uint)_chars.Length) // very common case, e.g. appending strings from NumberFormatInfo like separators, percent symbols, etc. + { + _chars[pos] = s[0]; + _pos = pos + 1; + } + else + { + AppendSlow(s); + } + } + + private void AppendSlow(string s) + { + int pos = _pos; + if (pos > _chars.Length - s.Length) + { + Grow(s.Length); + } + + s +#if !NET + .AsSpan() +#endif + .CopyTo(_chars.Slice(pos)); + _pos += s.Length; + } + + public void Append(char c, int count) + { + if (_pos > _chars.Length - count) + { + Grow(count); + } + + Span dst = _chars.Slice(_pos, count); + for (int i = 0; i < dst.Length; i++) + { + dst[i] = c; + } + _pos += count; + } + + public void Append(scoped ReadOnlySpan value) + { + int pos = _pos; + if (pos > _chars.Length - value.Length) + { + Grow(value.Length); + } + + value.CopyTo(_chars.Slice(_pos)); + _pos += value.Length; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span AppendSpan(int length) + { + int origPos = _pos; + if (origPos > _chars.Length - length) + { + Grow(length); + } + + _pos = origPos + length; + return _chars.Slice(origPos, length); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void GrowAndAppend(char c) + { + Grow(1); + Append(c); + } + + /// + /// Resize the internal buffer either by doubling current buffer size or + /// by adding to + /// whichever is greater. + /// + /// + /// Number of chars requested beyond current position. + /// + [MethodImpl(MethodImplOptions.NoInlining)] + private void Grow(int additionalCapacityBeyondPos) + { + Debug.Assert(additionalCapacityBeyondPos > 0); + Debug.Assert(_pos > _chars.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed."); + + const uint ArrayMaxLength = 0x7FFFFFC7; // same as Array.MaxLength + + // Increase to at least the required size (_pos + additionalCapacityBeyondPos), but try + // to double the size if possible, bounding the doubling to not go beyond the max array length. + int newCapacity = (int)Math.Max( + (uint)(_pos + additionalCapacityBeyondPos), + Math.Min((uint)_chars.Length * 2, ArrayMaxLength)); + + // Make sure to let Rent throw an exception if the caller has a bug and the desired capacity is negative. + // This could also go negative if the actual required length wraps around. + char[] poolArray = ArrayPool.Shared.Rent(newCapacity); + + _chars.Slice(0, _pos).CopyTo(poolArray); + + char[]? toReturn = _arrayToReturnToPool; + _chars = _arrayToReturnToPool = poolArray; + if (toReturn != null) + { + ArrayPool.Shared.Return(toReturn); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Dispose() + { + char[]? toReturn = _arrayToReturnToPool; + this = default; // for safety, to avoid using pooled array if this instance is erroneously appended to again + if (toReturn != null) + { + ArrayPool.Shared.Return(toReturn); + } + } + } +} \ No newline at end of file diff --git a/src/Common/Polyfills/System/IO/StreamExtensions.cs b/src/Common/Polyfills/System/IO/StreamExtensions.cs index d58ffaf3b..4dc8e2a5a 100644 --- a/src/Common/Polyfills/System/IO/StreamExtensions.cs +++ b/src/Common/Polyfills/System/IO/StreamExtensions.cs @@ -1,6 +1,7 @@ using ModelContextProtocol; using System.Buffers; using System.Runtime.InteropServices; +using System.Text; namespace System.IO; @@ -33,4 +34,31 @@ static async ValueTask WriteAsyncCore(Stream stream, ReadOnlyMemory buffer } } } + + public static ValueTask ReadAsync(this Stream stream, Memory buffer, CancellationToken cancellationToken) + { + Throw.IfNull(stream); + if (MemoryMarshal.TryGetArray(buffer, out ArraySegment segment)) + { + return new ValueTask(stream.ReadAsync(segment.Array, segment.Offset, segment.Count, cancellationToken)); + } + else + { + return ReadAsyncCore(stream, buffer, cancellationToken); + static async ValueTask ReadAsyncCore(Stream stream, Memory buffer, CancellationToken cancellationToken) + { + byte[] array = ArrayPool.Shared.Rent(buffer.Length); + try + { + int bytesRead = await stream.ReadAsync(array, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + array.AsSpan(0, bytesRead).CopyTo(buffer.Span); + return bytesRead; + } + finally + { + ArrayPool.Shared.Return(array); + } + } + } + } } \ No newline at end of file diff --git a/src/Common/Polyfills/System/IO/TextReaderExtensions.cs b/src/Common/Polyfills/System/IO/TextReaderExtensions.cs deleted file mode 100644 index 63b3db25e..000000000 --- a/src/Common/Polyfills/System/IO/TextReaderExtensions.cs +++ /dev/null @@ -1,10 +0,0 @@ -namespace System.IO; - -internal static class TextReaderExtensions -{ - public static Task ReadLineAsync(this TextReader reader, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - return reader.ReadLineAsync(); - } -} \ No newline at end of file diff --git a/src/Common/Throw.cs b/src/Common/Throw.cs index 0c6927e5f..ed80036f7 100644 --- a/src/Common/Throw.cs +++ b/src/Common/Throw.cs @@ -25,6 +25,15 @@ public static void IfNullOrWhiteSpace([NotNull] string? arg, [CallerArgumentExpr } } + public static void IfNegative(int arg, [CallerArgumentExpression(nameof(arg))] string? parameterName = null) + { + if (arg < 0) + { + Throw(parameterName); + static void Throw(string? parameterName) => throw new ArgumentOutOfRangeException(parameterName, "must not be negative."); + } + } + [DoesNotReturn] private static void ThrowArgumentNullOrWhiteSpaceException(string? parameterName) { diff --git a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs index 3c7210ecb..2ce32cb72 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs @@ -13,7 +13,7 @@ internal sealed class StdioClientSessionTransport : StreamClientSessionTransport private int _cleanedUp = 0; public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, ILoggerFactory? loggerFactory) - : base(process.StandardInput, process.StandardOutput, endpointName, loggerFactory) + : base(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory) { _process = process; _options = options; diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs index e00ddbabe..c026acb93 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs @@ -69,8 +69,6 @@ public async Task ConnectAsync(CancellationToken cancellationToken = { LogTransportConnecting(logger, endpointName); - UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false); - ProcessStartInfo startInfo = new() { FileName = command, @@ -80,10 +78,10 @@ public async Task ConnectAsync(CancellationToken cancellationToken = UseShellExecute = false, CreateNoWindow = true, WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, - StandardOutputEncoding = noBomUTF8, - StandardErrorEncoding = noBomUTF8, + StandardOutputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding, + StandardErrorEncoding = StreamClientSessionTransport.NoBomUtf8Encoding, #if NET - StandardInputEncoding = noBomUTF8, + StandardInputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding, #endif }; @@ -164,7 +162,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = Encoding originalInputEncoding = Console.InputEncoding; try { - Console.InputEncoding = noBomUTF8; + Console.InputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding; processStarted = process.Start(); } finally diff --git a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs index e35e2b18e..dfcccf61b 100644 --- a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; +using System.Text; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -7,6 +8,8 @@ namespace ModelContextProtocol.Client; /// Provides the client side of a stream-based session transport. internal class StreamClientSessionTransport : TransportBase { + internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false); + private readonly TextReader _serverOutput; private readonly TextWriter _serverInput; private readonly SemaphoreSlim _sendLock = new(1, 1); @@ -54,6 +57,43 @@ public StreamClientSessionTransport( readTask.Start(); } + /// + /// Initializes a new instance of the class. + /// + /// + /// The server's input stream. Messages written to this stream will be sent to the server. + /// + /// + /// The server's output stream. Messages read from this stream will be received from the server. + /// + /// + /// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null. + /// + /// + /// A name that identifies this transport endpoint in logs. + /// + /// + /// Optional factory for creating loggers. If null, a NullLogger will be used. + /// + /// + /// This constructor starts a background task to read messages from the server output stream. + /// The transport will be marked as connected once initialized. + /// + public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory) + : this( + new StreamWriter(serverInput, encoding ?? NoBomUtf8Encoding), +#if NET + new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding), +#else + new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding), +#endif + endpointName, + loggerFactory) + { + Throw.IfNull(serverInput); + Throw.IfNull(serverOutput); + } + /// public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { diff --git a/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs b/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs index 30607e574..a0e335be6 100644 --- a/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs @@ -47,8 +47,9 @@ public StreamClientTransport( public Task ConnectAsync(CancellationToken cancellationToken = default) { return Task.FromResult(new StreamClientSessionTransport( - new StreamWriter(_serverInput), - new StreamReader(_serverOutput), + _serverInput, + _serverOutput, + encoding: null, "Client (stream)", _loggerFactory)); } diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj index 928c76d97..f3ab71819 100644 --- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj +++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj @@ -7,6 +7,7 @@ ModelContextProtocol.Core Core .NET SDK for the Model Context Protocol (MCP) README.md + True @@ -20,6 +21,7 @@ + diff --git a/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs b/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs index f891858eb..a6188773d 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs @@ -17,7 +17,7 @@ public class McpServerPrimitiveCollection : ICollection, IReadOnlyCollecti /// public McpServerPrimitiveCollection(IEqualityComparer? keyComparer = null) { - _primitives = new(keyComparer); + _primitives = new(keyComparer ?? EqualityComparer.Default); } /// Occurs when the collection is changed. diff --git a/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs index 9528e4f42..915d7813c 100644 --- a/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs @@ -46,7 +46,11 @@ public StreamServerTransport(Stream inputStream, Stream outputStream, string? se _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; +#if NET _inputReader = new StreamReader(inputStream, Encoding.UTF8); +#else + _inputReader = new CancellableStreamReader(inputStream, Encoding.UTF8); +#endif _outputStream = outputStream; SetConnected(); diff --git a/tests/Common/Utils/MockHttpHandler.cs b/tests/Common/Utils/MockHttpHandler.cs index 5e58a6cd5..d15ec3dc0 100644 --- a/tests/Common/Utils/MockHttpHandler.cs +++ b/tests/Common/Utils/MockHttpHandler.cs @@ -1,4 +1,6 @@ -namespace ModelContextProtocol.Tests.Utils; +using System.Net.Http; + +namespace ModelContextProtocol.Tests.Utils; public class MockHttpHandler : HttpMessageHandler { diff --git a/tests/Common/Utils/ProcessExtensions.cs b/tests/Common/Utils/ProcessExtensions.cs new file mode 100644 index 000000000..186ecc9c6 --- /dev/null +++ b/tests/Common/Utils/ProcessExtensions.cs @@ -0,0 +1,15 @@ +namespace System.Diagnostics; + +public static class ProcessExtensions +{ + public static async Task WaitForExitAsync(this Process process, TimeSpan timeout) + { +#if NET + using var shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + shutdownCts.CancelAfter(timeout); + await process.WaitForExitAsync(shutdownCts.Token); +#else + process.WaitForExit(milliseconds: (int)timeout.TotalMilliseconds); +#endif + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj b/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj index 75de837aa..f38a35859 100644 --- a/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj +++ b/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj @@ -15,6 +15,7 @@ + diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index dec5ad057..ebc7171e2 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -25,14 +25,14 @@ public ClientIntegrationTestFixture() TestServerTransportOptions = new() { - Command = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "TestServer.exe" : "dotnet", + Command = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "TestServer.exe" : PlatformDetection.IsMonoRuntime ? "mono" : "dotnet", Name = "TestServer", }; if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { // Change to Arguments to "mcp-server-everything" if you want to run the server locally after creating a symlink - TestServerTransportOptions.Arguments = ["TestServer.dll"]; + TestServerTransportOptions.Arguments = [PlatformDetection.IsMonoRuntime ? "TestServer.exe" : "TestServer.dll"]; } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 38c688cce..d2080e1fc 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -606,7 +606,7 @@ public async Task HandlesIProgressParameter() McpClientTool progressTool = tools.First(t => t.Name == "sends_progress_notifications"); - TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); int remainingNotifications = 10; ConcurrentQueue notifications = new(); @@ -618,7 +618,7 @@ public async Task HandlesIProgressParameter() notifications.Enqueue(pn); if (Interlocked.Decrement(ref remainingNotifications) == 0) { - tcs.SetResult(); + tcs.SetResult(true); } } diff --git a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs index ffd8859a0..7a019c896 100644 --- a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs +++ b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs @@ -49,7 +49,7 @@ public async ValueTask DisposeAsync() using var stopProcess = Process.Start(stopInfo) ?? throw new InvalidOperationException($"Could not stop process for {stopInfo.FileName} with '{stopInfo.Arguments}'."); - await stopProcess.WaitForExitAsync(); + await stopProcess.WaitForExitAsync(TimeSpan.FromSeconds(10)); } catch (Exception ex) { @@ -60,6 +60,7 @@ public async ValueTask DisposeAsync() private static bool CheckIsDockerAvailable() { +#if NET try { ProcessStartInfo processStartInfo = new() @@ -78,5 +79,9 @@ private static bool CheckIsDockerAvailable() { return false; } +#else + // Do not run docker tests using .NET framework. + return false; +#endif } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/GlobalUsings.cs b/tests/ModelContextProtocol.Tests/GlobalUsings.cs index c802f4480..6d129626d 100644 --- a/tests/ModelContextProtocol.Tests/GlobalUsings.cs +++ b/tests/ModelContextProtocol.Tests/GlobalUsings.cs @@ -1 +1,2 @@ global using Xunit; +global using System.Net.Http; \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 6ddad70ff..993564bf0 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -2,7 +2,7 @@ Exe - net9.0;net8.0 + net9.0;net8.0;net472 enable enable @@ -27,6 +27,11 @@ + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive @@ -44,7 +49,9 @@ + + @@ -60,16 +67,10 @@ - - PreserveNewest - - - PreserveNewest - - + PreserveNewest - + PreserveNewest diff --git a/tests/ModelContextProtocol.Tests/PlatformDetection.cs b/tests/ModelContextProtocol.Tests/PlatformDetection.cs new file mode 100644 index 000000000..1eef99420 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/PlatformDetection.cs @@ -0,0 +1,6 @@ +namespace ModelContextProtocol.Tests; + +internal static class PlatformDetection +{ + public static bool IsMonoRuntime { get; } = Type.GetType("Mono.Runtime") is not null; +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs index 97b63157f..30675f7b8 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs @@ -1,10 +1,18 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Runtime.InteropServices; namespace ModelContextProtocol.Tests.Server; public class McpServerHandlerTests { + public McpServerHandlerTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void AllPropertiesAreSettable() { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs index 7cdbdb5b1..b2e748730 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs @@ -1,11 +1,19 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Runtime.InteropServices; namespace ModelContextProtocol.Tests.Server; public class McpServerLoggingLevelTests { + public McpServerLoggingLevelTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void CanCreateServerWithLoggingLevelHandler() { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index 90998e24b..39e9b72ff 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -7,6 +7,7 @@ using System.ComponentModel; using System.Diagnostics; using System.Reflection; +using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -14,6 +15,13 @@ namespace ModelContextProtocol.Tests.Server; public class McpServerPromptTests { + public McpServerPromptTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void Create_InvalidArgs_Throws() { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index fb0772d04..011c4f2b6 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -4,12 +4,20 @@ using ModelContextProtocol.Server; using Moq; using System.Reflection; +using System.Runtime.InteropServices; using System.Text.Json.Serialization; namespace ModelContextProtocol.Tests.Server; public partial class McpServerResourceTests { + public McpServerResourceTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void CanCreateServerWithResource() { @@ -191,6 +199,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e411.2.3.4", ((TextResourceContents)result.Contents[0]).Text); +#if NET t = McpServerResource.Create((Half a2, Int128 a3, UInt128 a4, IntPtr a5) => (a3 + (Int128)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( @@ -206,6 +215,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); +#endif t = McpServerResource.Create((bool? a2, char? a3, byte? a4, sbyte? a5) => a2?.ToString() + a3 + a4 + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); @@ -239,6 +249,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e41", ((TextResourceContents)result.Contents[0]).Text); +#if NET t = McpServerResource.Create((Half? a2, Int128? a3, UInt128? a4, IntPtr? a5) => (a3 + (Int128?)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( @@ -254,6 +265,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); +#endif } [Theory] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 260b9bdd7..6750b2cad 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using System.Reflection; +using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -15,6 +16,9 @@ public class McpServerTests : LoggedTest public McpServerTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif _options = CreateOptions(); } @@ -212,6 +216,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Initialize_Requests() { + AssemblyName expectedAssemblyName = (Assembly.GetEntryAssembly() ?? typeof(IMcpServer).Assembly).GetName(); await Can_Handle_Requests( serverCapabilities: null, method: RequestMethods.Initialize, @@ -220,8 +225,8 @@ await Can_Handle_Requests( { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result); - Assert.Equal("ModelContextProtocol.Tests", result.ServerInfo.Name); - Assert.Equal("1.0.0.0", result.ServerInfo.Version); + Assert.Equal(expectedAssemblyName.Name, result.ServerInfo.Name); + Assert.Equal(expectedAssemblyName.Version?.ToString() ?? "1.0.0", result.ServerInfo.Version); Assert.Equal("2024", result.ProtocolVersion); }); } @@ -518,10 +523,10 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s }; await transport.SendMessageAsync( - new JsonRpcRequest - { - Method = method, - Id = new RequestId(55) + new JsonRpcRequest + { + Method = method, + Id = new RequestId(55) } ); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 0f67f2a58..5cc6fa78a 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -7,6 +7,7 @@ using ModelContextProtocol.Tests.Utils; using Moq; using System.Reflection; +using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; @@ -17,6 +18,13 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerToolTests { + public McpServerToolTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void Create_InvalidArgs_Throws() { @@ -525,7 +533,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch Assert.Null(tool.ProtocolTool.OutputSchema); Assert.Null(result.StructuredContent); - tool = McpServerTool.Create(() => ValueTask.CompletedTask); + tool = McpServerTool.Create(() => default(ValueTask)); request = new RequestContext(mockServer.Object) { Params = new CallToolRequestParams { Name = "tool" }, diff --git a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs index db22ec244..f3927be62 100644 --- a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs @@ -9,7 +9,7 @@ namespace ModelContextProtocol.Tests; public class StdioServerIntegrationTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { - public static bool CanSendSigInt { get; } = RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX); + public static bool CanSendSigInt { get; } = (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) && !PlatformDetection.IsMonoRuntime; private const int SIGINT = 2; [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(CanSendSigInt))] @@ -46,9 +46,7 @@ public async Task SigInt_DisposesTestServerWithHosting_Gracefully() // https://github.com/dotnet/runtime/issues/109432, https://github.com/dotnet/runtime/issues/44944 Assert.Equal(0, kill(process.Id, SIGINT)); - using var shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - shutdownCts.CancelAfter(TimeSpan.FromSeconds(10)); - await process.WaitForExitAsync(shutdownCts.Token); + await process.WaitForExitAsync(TimeSpan.FromSeconds(10)); Assert.True(process.HasExited); Assert.Equal(0, process.ExitCode); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs index 416d17193..b49542784 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs @@ -15,10 +15,18 @@ public async Task Can_Customize_MessageEndpoint() var transportRunTask = transport.RunAsync(TestContext.Current.CancellationToken); using var responseStreamReader = new StreamReader(responsePipe.Reader.AsStream()); - var firstLine = await responseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + var firstLine = await responseStreamReader.ReadLineAsync( +#if NET + TestContext.Current.CancellationToken +#endif + ); Assert.Equal("event: endpoint", firstLine); - var secondLine = await responseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + var secondLine = await responseStreamReader.ReadLineAsync( +#if NET + TestContext.Current.CancellationToken +#endif + ); Assert.Equal("data: /my-message-endpoint", secondLine); responsePipe.Reader.Complete(); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index 40602a9ed..93cbcec82 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -7,6 +7,8 @@ namespace ModelContextProtocol.Tests.Transport; public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { + public static bool IsStdErrCallbackSupported => !PlatformDetection.IsMonoRuntime; + [Fact] public async Task CreateAsync_ValidProcessInvalidServer_Throws() { @@ -19,8 +21,8 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws() IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.Contains(id, e.ToString()); } - - [Fact] + + [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() { string id = Guid.NewGuid().ToString("N"); From 984aa9ae1b03a04833955d92423db3b6d2e04e41 Mon Sep 17 00:00:00 2001 From: David Parks Date: Wed, 9 Jul 2025 12:05:00 -0700 Subject: [PATCH 3/7] fix: Prevent crash when Options.ResourceMetadata is null but handled by event (#603) --- .../McpAuthenticationHandler.cs | 12 +- .../AuthEventTests.cs | 310 ++++++++++++++++++ 2 files changed, 316 insertions(+), 6 deletions(-) create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs index 942db1b65..f8c6f41cd 100644 --- a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs @@ -43,8 +43,7 @@ public async Task HandleRequestAsync() return false; } - var cancellationToken = Request.HttpContext.RequestAborted; - await HandleResourceMetadataRequestAsync(cancellationToken); + await HandleResourceMetadataRequestAsync(); return true; } @@ -82,8 +81,7 @@ private string GetAbsoluteResourceMetadataUri() /// /// Handles the resource metadata request. /// - /// A token to cancel the operation. - private async Task HandleResourceMetadataRequestAsync(CancellationToken cancellationToken = default) + private async Task HandleResourceMetadataRequestAsync() { var resourceMetadata = Options.ResourceMetadata; @@ -95,12 +93,14 @@ private async Task HandleResourceMetadataRequestAsync(CancellationToken cancella }; await Options.Events.OnResourceMetadataRequest(context); + resourceMetadata = context.ResourceMetadata; } - if (resourceMetadata == null) { - throw new InvalidOperationException("ResourceMetadata has not been configured. Please set McpAuthenticationOptions.ResourceMetadata."); + throw new InvalidOperationException( + "ResourceMetadata has not been configured. Please set McpAuthenticationOptions.ResourceMetadata or ensure context.ResourceMetadata is set inside McpAuthenticationOptions.Events.OnResourceMetadataRequest." + ); } await Results.Json(resourceMetadata, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))).ExecuteAsync(Context); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs new file mode 100644 index 000000000..6a48c21d2 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs @@ -0,0 +1,310 @@ +using System.Net; +using System.Net.Http.Json; +using System.Text.Json; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Authentication; +using ModelContextProtocol.Client; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for MCP authentication when resource metadata is provided via events rather than static configuration. +/// +public class AuthEventTests : KestrelInMemoryTest, IAsyncDisposable +{ + private const string McpServerUrl = "http://localhost:5000/"; + private const string OAuthServerUrl = "https://localhost:7029/"; + + private readonly CancellationTokenSource _testCts = new(); + private readonly TestOAuthServer.Program _testOAuthServer; + private readonly Task _testOAuthRunTask; + + public AuthEventTests(ITestOutputHelper outputHelper) + : base(outputHelper) + { + // Let the HandleAuthorizationUrlAsync take a look at the Location header + SocketsHttpHandler.AllowAutoRedirect = false; + // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. + // The easiest workaround is to disable cert validation for testing purposes. + SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; + + _testOAuthServer = new TestOAuthServer.Program( + XunitLoggerProvider, + KestrelInMemoryTransport + ); + _testOAuthRunTask = _testOAuthServer.RunServerAsync(cancellationToken: _testCts.Token); + + Builder + .Services.AddAuthentication(options => + { + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + }) + .AddJwtBearer(options => + { + options.Backchannel = HttpClient; + options.Authority = OAuthServerUrl; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidateAudience = true, + ValidateLifetime = true, + ValidateIssuerSigningKey = true, + ValidAudience = McpServerUrl, + ValidIssuer = OAuthServerUrl, + NameClaimType = "name", + RoleClaimType = "roles", + }; + }) + .AddMcp(options => + { + // Note: ResourceMetadata is NOT set here - it will be provided via events + options.Events.OnResourceMetadataRequest = async context => + { + // Dynamically provide the resource metadata + context.ResourceMetadata = new ProtectedResourceMetadata + { + Resource = new Uri(McpServerUrl), + AuthorizationServers = { new Uri(OAuthServerUrl) }, + ScopesSupported = ["mcp:tools"], + }; + await Task.CompletedTask; + }; + }); + + Builder.Services.AddAuthorization(); + } + + public async ValueTask DisposeAsync() + { + _testCts.Cancel(); + try + { + await _testOAuthRunTask; + } + catch (OperationCanceledException) { } + finally + { + _testCts.Dispose(); + } + } + + [Fact] + public async Task CanAuthenticate_WithResourceMetadataFromEvent() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport( + new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + }, + }, + HttpClient, + LoggerFactory + ); + + await using var client = await McpClientFactory.CreateAsync( + transport, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken + ); + } + + [Fact] + public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport( + new() + { + Endpoint = new(McpServerUrl), + OAuth = new ClientOAuthOptions() + { + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + ClientName = "Test MCP Client", + ClientUri = new Uri("https://example.com/"), + Scopes = ["mcp:tools"], + }, + }, + HttpClient, + LoggerFactory + ); + + await using var client = await McpClientFactory.CreateAsync( + transport, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken + ); + } + + [Fact] + public async Task ResourceMetadataEndpoint_ReturnsCorrectMetadata_FromEvent() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + // Make a direct request to the resource metadata endpoint + using var response = await HttpClient.GetAsync( + "/.well-known/oauth-protected-resource", + TestContext.Current.CancellationToken + ); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var metadata = await response.Content.ReadFromJsonAsync( + McpJsonUtilities.DefaultOptions, + TestContext.Current.CancellationToken + ); + + Assert.NotNull(metadata); + Assert.Equal(new Uri(McpServerUrl), metadata.Resource); + Assert.Contains(new Uri(OAuthServerUrl), metadata.AuthorizationServers); + Assert.Contains("mcp:tools", metadata.ScopesSupported); + } + + [Fact] + public async Task ResourceMetadataEndpoint_CanModifyExistingMetadata_InEvent() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + // Override the configuration to test modification of existing metadata + Builder.Services.Configure( + McpAuthenticationDefaults.AuthenticationScheme, + options => + { + // Set initial metadata + options.ResourceMetadata = new ProtectedResourceMetadata + { + Resource = new Uri(McpServerUrl), + AuthorizationServers = { new Uri(OAuthServerUrl) }, + ScopesSupported = ["mcp:basic"], + }; + + // Override the event to modify the metadata + options.Events.OnResourceMetadataRequest = async context => + { + // Start with the existing metadata and modify it + if (context.ResourceMetadata != null) + { + context.ResourceMetadata.ScopesSupported.Add("mcp:tools"); + context.ResourceMetadata.ResourceName = "Dynamic Test Resource"; + } + await Task.CompletedTask; + }; + } + ); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + // Make a direct request to the resource metadata endpoint + using var response = await HttpClient.GetAsync( + "/.well-known/oauth-protected-resource", + TestContext.Current.CancellationToken + ); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var metadata = await response.Content.ReadFromJsonAsync( + McpJsonUtilities.DefaultOptions, + TestContext.Current.CancellationToken + ); + + Assert.NotNull(metadata); + Assert.Equal(new Uri(McpServerUrl), metadata.Resource); + Assert.Contains(new Uri(OAuthServerUrl), metadata.AuthorizationServers); + Assert.Contains("mcp:basic", metadata.ScopesSupported); + Assert.Contains("mcp:tools", metadata.ScopesSupported); + Assert.Equal("Dynamic Test Resource", metadata.ResourceName); + } + + [Fact] + public async Task ResourceMetadataEndpoint_ThrowsException_WhenNoMetadataProvided() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + // Override the configuration to test the error case where no metadata is provided + Builder.Services.Configure( + McpAuthenticationDefaults.AuthenticationScheme, + options => + { + // Don't set ResourceMetadata and provide an event that doesn't set it either + options.ResourceMetadata = null; + options.Events.OnResourceMetadataRequest = async context => + { + // Intentionally don't set context.ResourceMetadata to test error handling + await Task.CompletedTask; + }; + } + ); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + // Make a direct request to the resource metadata endpoint - this should fail + using var response = await HttpClient.GetAsync( + "/.well-known/oauth-protected-resource", + TestContext.Current.CancellationToken + ); + + // The request should fail with an internal server error due to the InvalidOperationException + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } + + private async Task HandleAuthorizationUrlAsync( + Uri authorizationUri, + Uri redirectUri, + CancellationToken cancellationToken + ) + { + using var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); + Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); + var location = redirectResponse.Headers.Location; + + if (location is not null && !string.IsNullOrEmpty(location.Query)) + { + var queryParams = QueryHelpers.ParseQuery(location.Query); + return queryParams["code"]; + } + + return null; + } +} From 12324561536c86e5429b18aafb57248ac5f4bacf Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 9 Jul 2025 23:41:57 -0400 Subject: [PATCH 4/7] Update to M.E.AI 9.7.0 (#602) --- Directory.Packages.props | 6 +++--- samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs | 9 ++------- .../Server/McpServerExtensions.cs | 5 +++++ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index 70eb82f3a..4df4ea73a 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -3,7 +3,7 @@ true 9.0.5 10.0.0-preview.4.25258.110 - 9.6.0 + 9.7.0 @@ -13,7 +13,7 @@ - + @@ -53,7 +53,7 @@ all - + diff --git a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs b/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs index 4fbca594a..247619dbb 100644 --- a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs +++ b/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs @@ -17,19 +17,14 @@ public static async Task SampleLLM( [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) { - ChatMessage[] messages = - [ - new(ChatRole.System, "You are a helpful test server."), - new(ChatRole.User, prompt), - ]; - ChatOptions options = new() { + Instructions = "You are a helpful test server.", MaxOutputTokens = maxTokens, Temperature = 0.7f, }; - var samplingResponse = await thisServer.AsSamplingChatClient().GetResponseAsync(messages, options, cancellationToken); + var samplingResponse = await thisServer.AsSamplingChatClient().GetResponseAsync(prompt, options, cancellationToken); return $"LLM sampling result: {samplingResponse}"; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs index d00c41a6b..277ed737b 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs @@ -64,6 +64,11 @@ public static async Task SampleAsync( StringBuilder? systemPrompt = null; + if (options?.Instructions is { } instructions) + { + (systemPrompt ??= new()).Append(instructions); + } + List samplingMessages = []; foreach (var message in messages) { From 5f1c74f1ffa8eb7d1f68029691427556d03727ee Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 14 Jul 2025 10:30:48 -0400 Subject: [PATCH 5/7] Ensure IsExternalInit is type forwarded on NET builds (#619) It's included as internal on netstandard/net472, and the C# compiler may bake a reference to that into a consumer. If that consumer is then used with a net8.0+ build, the IsExternalInit needs to be there and forwarded to the real one. This switches our polyfill files to always be included but the actual contents ifdef'd out on TFMs that already have the contents. That then makes it easier to do specialized ifdef'ing in the future, as this does for IsExternalInit. --- .../System/Collections/Generic/CollectionExtensions.cs | 4 +++- .../CodeAnalysis/DynamicallyAccessedMemberTypes.cs | 2 ++ .../CodeAnalysis/DynamicallyAccessedMembersAttribute.cs | 2 ++ .../System/Diagnostics/CodeAnalysis/NullableAttributes.cs | 2 ++ .../CodeAnalysis/RequiresDynamicCodeAttribute.cs | 2 ++ .../Diagnostics/CodeAnalysis/RequiresUnreferencedCode.cs | 2 ++ .../CodeAnalysis/SetsRequiredMembersAttribute.cs | 2 ++ .../Diagnostics/CodeAnalysis/StringSyntaxAttribute.cs | 4 +++- .../CodeAnalysis/UnconditionalSuppressMessageAttribute.cs | 4 +++- src/Common/Polyfills/System/IO/StreamExtensions.cs | 4 +++- src/Common/Polyfills/System/IO/TextWriterExtensions.cs | 4 +++- .../Polyfills/System/Net/Http/HttpClientExtensions.cs | 4 +++- src/Common/Polyfills/System/PasteArguments.cs | 4 +++- .../CompilerServices/CallerArgumentExpressionAttribute.cs | 2 ++ .../CompilerServices/CompilerFeatureRequiredAttribute.cs | 2 ++ .../CompilerServices/DefaultInterpolatedStringHandler.cs | 4 +++- .../System/Runtime/CompilerServices/IsExternalInit.cs | 6 ++++++ .../Runtime/CompilerServices/RequiredMemberAttribute.cs | 2 ++ .../System/Threading/CancellationTokenSourceExtensions.cs | 4 +++- .../System/Threading/Channels/ChannelExtensions.cs | 4 +++- src/Common/Polyfills/System/Threading/ForceYielding.cs | 4 +++- .../Polyfills/System/Threading/Tasks/TaskExtensions.cs | 4 +++- src/Directory.Build.props | 4 ++++ .../ModelContextProtocol.Core.csproj | 1 - src/ModelContextProtocol/ModelContextProtocol.csproj | 4 ---- 25 files changed, 64 insertions(+), 17 deletions(-) diff --git a/src/Common/Polyfills/System/Collections/Generic/CollectionExtensions.cs b/src/Common/Polyfills/System/Collections/Generic/CollectionExtensions.cs index ae4f697bc..fe5e09931 100644 --- a/src/Common/Polyfills/System/Collections/Generic/CollectionExtensions.cs +++ b/src/Common/Polyfills/System/Collections/Generic/CollectionExtensions.cs @@ -1,3 +1,4 @@ +#if !NET using ModelContextProtocol; namespace System.Collections.Generic; @@ -18,4 +19,5 @@ public static TValue GetValueOrDefault(this IReadOnlyDictionary ToDictionary(this IEnumerable> source) => source.ToDictionary(kv => kv.Key, kv => kv.Value); -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/DynamicallyAccessedMemberTypes.cs b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/DynamicallyAccessedMemberTypes.cs index ee6fa51a3..fcb09e4ff 100644 --- a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/DynamicallyAccessedMemberTypes.cs +++ b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/DynamicallyAccessedMemberTypes.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Diagnostics.CodeAnalysis; /// @@ -162,3 +163,4 @@ internal enum DynamicallyAccessedMemberTypes /// All = ~None } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/DynamicallyAccessedMembersAttribute.cs b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/DynamicallyAccessedMembersAttribute.cs index 2d0140477..c99bb8e0a 100644 --- a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/DynamicallyAccessedMembersAttribute.cs +++ b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/DynamicallyAccessedMembersAttribute.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Diagnostics.CodeAnalysis; /// @@ -48,3 +49,4 @@ public DynamicallyAccessedMembersAttribute(DynamicallyAccessedMemberTypes member /// public DynamicallyAccessedMemberTypes MemberTypes { get; } } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/NullableAttributes.cs b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/NullableAttributes.cs index ef577a9ca..0e7425e01 100644 --- a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/NullableAttributes.cs +++ b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/NullableAttributes.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Diagnostics.CodeAnalysis { /// Specifies that null is allowed as an input even if the corresponding type disallows it. @@ -137,3 +138,4 @@ public MemberNotNullWhenAttribute(bool returnValue, params string[] members) public string[] Members { get; } } } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/RequiresDynamicCodeAttribute.cs b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/RequiresDynamicCodeAttribute.cs index 817ec6eaa..554a699a4 100644 --- a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/RequiresDynamicCodeAttribute.cs +++ b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/RequiresDynamicCodeAttribute.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Diagnostics.CodeAnalysis; /// @@ -36,3 +37,4 @@ public RequiresDynamicCodeAttribute(string message) /// public string? Url { get; set; } } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/RequiresUnreferencedCode.cs b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/RequiresUnreferencedCode.cs index 3e845a534..eb91908ec 100644 --- a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/RequiresUnreferencedCode.cs +++ b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/RequiresUnreferencedCode.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Diagnostics.CodeAnalysis; /// @@ -37,3 +38,4 @@ public RequiresUnreferencedCodeAttribute(string message) /// public string? Url { get; set; } } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/SetsRequiredMembersAttribute.cs b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/SetsRequiredMembersAttribute.cs index 83d6793b3..b778c95fc 100644 --- a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/SetsRequiredMembersAttribute.cs +++ b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/SetsRequiredMembersAttribute.cs @@ -1,5 +1,7 @@ +#if !NET namespace System.Diagnostics.CodeAnalysis { [AttributeUsage(AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)] internal sealed class SetsRequiredMembersAttribute : Attribute; } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/StringSyntaxAttribute.cs b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/StringSyntaxAttribute.cs index a8ab9bd28..1b884fa14 100644 --- a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/StringSyntaxAttribute.cs +++ b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/StringSyntaxAttribute.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Diagnostics.CodeAnalysis; /// Specifies the syntax used in a string. @@ -65,4 +66,5 @@ public StringSyntaxAttribute(string syntax, params object?[] arguments) /// The syntax identifier for strings containing XML. public const string Xml = nameof(Xml); -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/UnconditionalSuppressMessageAttribute.cs b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/UnconditionalSuppressMessageAttribute.cs index b06d9ed1a..db80655a7 100644 --- a/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/UnconditionalSuppressMessageAttribute.cs +++ b/src/Common/Polyfills/System/Diagnostics/CodeAnalysis/UnconditionalSuppressMessageAttribute.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Diagnostics.CodeAnalysis; /// @@ -81,4 +82,5 @@ public UnconditionalSuppressMessageAttribute(string category, string checkId) /// Gets or sets the justification for suppressing the code analysis message. /// public string? Justification { get; set; } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/IO/StreamExtensions.cs b/src/Common/Polyfills/System/IO/StreamExtensions.cs index 4dc8e2a5a..452b80321 100644 --- a/src/Common/Polyfills/System/IO/StreamExtensions.cs +++ b/src/Common/Polyfills/System/IO/StreamExtensions.cs @@ -3,6 +3,7 @@ using System.Runtime.InteropServices; using System.Text; +#if !NET namespace System.IO; internal static class StreamExtensions @@ -61,4 +62,5 @@ static async ValueTask ReadAsyncCore(Stream stream, Memory buffer, Ca } } } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/IO/TextWriterExtensions.cs b/src/Common/Polyfills/System/IO/TextWriterExtensions.cs index 637cc09b0..a8dabd1fc 100644 --- a/src/Common/Polyfills/System/IO/TextWriterExtensions.cs +++ b/src/Common/Polyfills/System/IO/TextWriterExtensions.cs @@ -1,3 +1,4 @@ +#if !NET namespace System.IO; internal static class TextWriterExtensions @@ -7,4 +8,5 @@ public static async Task FlushAsync(this TextWriter writer, CancellationToken ca cancellationToken.ThrowIfCancellationRequested(); await writer.FlushAsync(); } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Net/Http/HttpClientExtensions.cs b/src/Common/Polyfills/System/Net/Http/HttpClientExtensions.cs index 85612b1d5..96a2948fb 100644 --- a/src/Common/Polyfills/System/Net/Http/HttpClientExtensions.cs +++ b/src/Common/Polyfills/System/Net/Http/HttpClientExtensions.cs @@ -1,3 +1,4 @@ +#if !NET using ModelContextProtocol; namespace System.Net.Http; @@ -19,4 +20,5 @@ public static async Task ReadAsStringAsync(this HttpContent content, Can cancellationToken.ThrowIfCancellationRequested(); return await content.ReadAsStringAsync(); } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/PasteArguments.cs b/src/Common/Polyfills/System/PasteArguments.cs index 32eb4c69f..d838ec023 100644 --- a/src/Common/Polyfills/System/PasteArguments.cs +++ b/src/Common/Polyfills/System/PasteArguments.cs @@ -5,6 +5,7 @@ // https://github.com/dotnet/runtime/blob/d2650b6ae7023a2d9d2c74c56116f1f18472ab04/src/libraries/System.Private.CoreLib/src/System/PasteArguments.cs // and changed from using ValueStringBuilder to StringBuilder. +#if !NET using System.Text; namespace System; @@ -98,4 +99,5 @@ private static bool ContainsNoWhitespaceOrQuotes(string s) private const char Quote = '\"'; private const char Backslash = '\\'; -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs b/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs index 968c31e8a..553afbea1 100644 --- a/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs +++ b/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Runtime.CompilerServices; [AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = false)] @@ -13,3 +14,4 @@ public CallerArgumentExpressionAttribute(string parameterName) public string ParameterName { get; } } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Runtime/CompilerServices/CompilerFeatureRequiredAttribute.cs b/src/Common/Polyfills/System/Runtime/CompilerServices/CompilerFeatureRequiredAttribute.cs index 12f3e5d2d..1df9c1c21 100644 --- a/src/Common/Polyfills/System/Runtime/CompilerServices/CompilerFeatureRequiredAttribute.cs +++ b/src/Common/Polyfills/System/Runtime/CompilerServices/CompilerFeatureRequiredAttribute.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET namespace System.Runtime.CompilerServices { /// @@ -30,3 +31,4 @@ public CompilerFeatureRequiredAttribute(string featureName) public const string RequiredMembers = nameof(RequiredMembers); } } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Runtime/CompilerServices/DefaultInterpolatedStringHandler.cs b/src/Common/Polyfills/System/Runtime/CompilerServices/DefaultInterpolatedStringHandler.cs index 244f0875e..24622096f 100644 --- a/src/Common/Polyfills/System/Runtime/CompilerServices/DefaultInterpolatedStringHandler.cs +++ b/src/Common/Polyfills/System/Runtime/CompilerServices/DefaultInterpolatedStringHandler.cs @@ -5,6 +5,7 @@ // https://github.com/dotnet/runtime/blob/dd75c45c123055baacd7aa4418f425f412797a29/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/DefaultInterpolatedStringHandler.cs // and then modified to build on netstandard2.0. +#if !NET using System.Buffers; using System.Diagnostics; using System.Globalization; @@ -614,4 +615,5 @@ private static uint Clamp(uint value, uint min, uint max) return value; } } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Runtime/CompilerServices/IsExternalInit.cs b/src/Common/Polyfills/System/Runtime/CompilerServices/IsExternalInit.cs index 70443090c..9ae535381 100644 --- a/src/Common/Polyfills/System/Runtime/CompilerServices/IsExternalInit.cs +++ b/src/Common/Polyfills/System/Runtime/CompilerServices/IsExternalInit.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET using System.ComponentModel; namespace System.Runtime.CompilerServices @@ -12,3 +13,8 @@ namespace System.Runtime.CompilerServices [EditorBrowsable(EditorBrowsableState.Never)] internal static class IsExternalInit; } +#else +// The compiler emits a reference to the internal copy of this type in the non-.NET builds, +// so we must include a forward to be compatible. +[assembly: System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.CompilerServices.IsExternalInit))] +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Runtime/CompilerServices/RequiredMemberAttribute.cs b/src/Common/Polyfills/System/Runtime/CompilerServices/RequiredMemberAttribute.cs index 6930dc4f1..35b6948a7 100644 --- a/src/Common/Polyfills/System/Runtime/CompilerServices/RequiredMemberAttribute.cs +++ b/src/Common/Polyfills/System/Runtime/CompilerServices/RequiredMemberAttribute.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#if !NET using System.ComponentModel; namespace System.Runtime.CompilerServices @@ -10,3 +11,4 @@ namespace System.Runtime.CompilerServices [EditorBrowsable(EditorBrowsableState.Never)] internal sealed class RequiredMemberAttribute : Attribute; } +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs b/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs index 95acac96c..d5508153f 100644 --- a/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs +++ b/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs @@ -1,3 +1,4 @@ +#if !NET using ModelContextProtocol; namespace System.Threading.Tasks; @@ -11,4 +12,5 @@ public static Task CancelAsync(this CancellationTokenSource cancellationTokenSou cancellationTokenSource.Cancel(); return Task.CompletedTask; } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs b/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs index 89822eff1..6cda43e94 100644 --- a/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs +++ b/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs @@ -1,3 +1,4 @@ +#if !NET using System.Runtime.CompilerServices; namespace System.Threading.Channels; @@ -14,4 +15,5 @@ public static async IAsyncEnumerable ReadAllAsync(this ChannelReader re } } } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Threading/ForceYielding.cs b/src/Common/Polyfills/System/Threading/ForceYielding.cs index a25baa977..4ce99c87c 100644 --- a/src/Common/Polyfills/System/Threading/ForceYielding.cs +++ b/src/Common/Polyfills/System/Threading/ForceYielding.cs @@ -1,3 +1,4 @@ +#if !NET using System.Runtime.CompilerServices; namespace System.Threading; @@ -14,4 +15,5 @@ namespace System.Threading; public void OnCompleted(Action continuation) => ThreadPool.QueueUserWorkItem(a => ((Action)a!)(), continuation); public void UnsafeOnCompleted(Action continuation) => ThreadPool.UnsafeQueueUserWorkItem(a => ((Action)a!)(), continuation); public void GetResult() { } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Common/Polyfills/System/Threading/Tasks/TaskExtensions.cs b/src/Common/Polyfills/System/Threading/Tasks/TaskExtensions.cs index bee89a25d..68eb073d7 100644 --- a/src/Common/Polyfills/System/Threading/Tasks/TaskExtensions.cs +++ b/src/Common/Polyfills/System/Threading/Tasks/TaskExtensions.cs @@ -1,3 +1,4 @@ +#if !NET using ModelContextProtocol; namespace System.Threading.Tasks; @@ -49,4 +50,5 @@ public static async Task WaitAsync(this Task task, TimeSpan timeout, Cancellatio await task.ConfigureAwait(false); } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/src/Directory.Build.props b/src/Directory.Build.props index b8408bacd..7859ba39a 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -23,6 +23,10 @@ + + + + diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj index f3ab71819..07a5ec1b0 100644 --- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj +++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj @@ -20,7 +20,6 @@ - diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index 994f3dcc5..963ba0fed 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -17,10 +17,6 @@ - - - - From a88ef0deb712ea86f0bea1c5680fa6a227242c78 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Tue, 15 Jul 2025 13:58:43 -0700 Subject: [PATCH 6/7] Flow ExecutionContext with JsonRpcMessage (#616) The primary goal of this change is to support IHttpContextAccessor in tool calls when the Streamable HTTP is in its default non-Stateless mode. --- .../HttpServerTransportOptions.cs | 14 ++++++++++++++ .../StreamableHttpHandler.cs | 1 + src/ModelContextProtocol.Core/McpSession.cs | 15 ++++++++++++--- .../Protocol/JsonRpcMessage.cs | 14 ++++++++++++++ .../Server/StreamableHttpPostTransport.cs | 5 +++++ .../Server/StreamableHttpServerTransport.cs | 17 +++++++++++++---- .../MapMcpTests.cs | 6 ------ .../StreamableHttpServerConformanceTests.cs | 3 ++- 8 files changed, 61 insertions(+), 14 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 677606eb6..2a34a17a1 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -35,6 +35,20 @@ public class HttpServerTransportOptions /// public bool Stateless { get; set; } + /// + /// Gets or sets whether the server should use a single execution context for the entire session. + /// If , handlers like tools get called with the + /// belonging to the corresponding HTTP request which can change throughout the MCP session. + /// If , handlers will get called with the same + /// used to call and . + /// + /// + /// Enabling a per-session can be useful for setting variables + /// that persist for the entire session, but it prevents you from using IHttpContextAccessor in handlers. + /// Defaults to . + /// + public bool PerSessionExecutionContext { get; set; } + /// /// Gets or sets the duration of time the server will wait between any active requests before timing out an MCP session. /// diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index aeac38bf0..6dac1c3e4 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -188,6 +188,7 @@ private async ValueTask> StartNewS transport = new() { SessionId = sessionId, + FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, }; context.Response.Headers[McpSessionIdHeaderName] = sessionId; } diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index 47c4d212b..06b2894b0 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -115,7 +115,16 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken) LogMessageRead(EndpointName, message.GetType().Name); // Fire and forget the message handling to avoid blocking the transport. - _ = ProcessMessageAsync(); + if (message.ExecutionContext is null) + { + _ = ProcessMessageAsync(); + } + else + { + // Flow the execution context from the HTTP request corresponding to this message if provided. + ExecutionContext.Run(message.ExecutionContext, _ => _ = ProcessMessageAsync(), null); + } + async Task ProcessMessageAsync() { JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId; @@ -609,9 +618,9 @@ private static void AddExceptionTags(ref TagList tags, Activity? activity, Excep e = ae.InnerException; } - int? intErrorCode = + int? intErrorCode = (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : - e is JsonException ? (int)McpErrorCode.ParseError : + e is JsonException ? (int)McpErrorCode.ParseError : null; string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName; diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs index 77866add2..b3176937c 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs @@ -1,3 +1,4 @@ +using ModelContextProtocol.Server; using System.ComponentModel; using System.Text.Json; using System.Text.Json.Serialization; @@ -38,6 +39,19 @@ private protected JsonRpcMessage() [JsonIgnore] public ITransport? RelatedTransport { get; set; } + /// + /// Gets or sets the that should be used to run any handlers + /// + /// + /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, + /// the outlives the initial HTTP request context it was created on, and new + /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the + /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor + /// in tool calls. + /// + [JsonIgnore] + public ExecutionContext? ExecutionContext { get; set; } + /// /// Provides a for messages, /// handling polymorphic deserialization of different message types. diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 343b57485..9d225caa8 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -91,6 +91,11 @@ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, Cancella message.RelatedTransport = this; + if (parentTransport.FlowExecutionContextFromRequests) + { + message.ExecutionContext = ExecutionContext.Capture(); + } + await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 1f5775e66..b63c8a651 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -10,7 +10,7 @@ namespace ModelContextProtocol.Server; /// /// /// This transport provides one-way communication from server to client using the SSE protocol over HTTP, -/// while receiving client messages through a separate mechanism. It writes messages as +/// while receiving client messages through a separate mechanism. It writes messages as /// SSE events to a response stream, typically associated with an HTTP response. /// /// @@ -36,6 +36,9 @@ public sealed class StreamableHttpServerTransport : ITransport private int _getRequestStarted; + /// + public string? SessionId { get; set; } + /// /// Configures whether the transport should be in stateless mode that does not require all requests for a given session /// to arrive to the same ASP.NET Core application process. Unsolicited server-to-client messages are not supported in this mode, @@ -45,6 +48,15 @@ public sealed class StreamableHttpServerTransport : ITransport /// public bool Stateless { get; init; } + /// + /// Gets a value indicating whether the execution context should flow from the calls to + /// to the corresponding emitted by the . + /// + /// + /// Defaults to . + /// + public bool FlowExecutionContextFromRequests { get; init; } + /// /// Gets or sets a callback to be invoked before handling the initialize request. /// @@ -55,9 +67,6 @@ public sealed class StreamableHttpServerTransport : ITransport internal ChannelWriter MessageWriter => _incomingChannel.Writer; - /// - public string? SessionId { get; set; } - /// /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by /// writing any unsolicited JSON-RPC messages sent via diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index cf54e7774..4d0d73562 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -52,12 +52,6 @@ public async Task MapMcp_ThrowsInvalidOperationException_IfWithHttpTransportIsNo [Fact] public async Task Can_UseIHttpContextAccessor_InTool() { - Assert.SkipWhen(UseStreamableHttp && !Stateless, - """ - IHttpContextAccessor is not currently supported with non-stateless Streamable HTTP. - TODO: Support it in stateless mode by manually capturing and flowing execution context. - """); - Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); Builder.Services.AddHttpContextAccessor(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index 8c7f736db..0b3ae4c2a 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -387,7 +387,7 @@ public async Task Progress_IsReported_InSameSseResponseAsRpcResponse() } [Fact] - public async Task AsyncLocalSetInRunSessionHandlerCallback_Flows_ToAllToolCalls() + public async Task AsyncLocalSetInRunSessionHandlerCallback_Flows_ToAllToolCalls_IfPerSessionExecutionContextEnabled() { var asyncLocal = new AsyncLocal(); var totalSessionCount = 0; @@ -395,6 +395,7 @@ public async Task AsyncLocalSetInRunSessionHandlerCallback_Flows_ToAllToolCalls( Builder.Services.AddMcpServer() .WithHttpTransport(options => { + options.PerSessionExecutionContext = true; options.RunSessionHandler = async (httpContext, mcpServer, cancellationToken) => { asyncLocal.Value = $"RunSessionHandler ({totalSessionCount++})"; From 3156818d663fa67c164679c355c3880cce75449b Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Wed, 16 Jul 2025 16:17:49 +0300 Subject: [PATCH 7/7] Update MEAI version and add regression test for #601. (#628) --- Directory.Packages.props | 2 +- .../Server/McpServerToolTests.cs | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index 4df4ea73a..6da9521f7 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -3,7 +3,7 @@ true 9.0.5 10.0.0-preview.4.25258.110 - 9.7.0 + 9.7.1 diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 5cc6fa78a..bd0ca5ef9 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -563,6 +563,27 @@ public async Task StructuredOutput_Disabled_ReturnsExpectedSchema(T value) Assert.Null(result.StructuredContent); } + [Theory] + [InlineData(JsonNumberHandling.Strict)] + [InlineData(JsonNumberHandling.AllowReadingFromString)] + public async Task ToolWithNullableParameters_ReturnsExpectedSchema(JsonNumberHandling nunmberHandling) + { + JsonSerializerOptions options = new(JsonContext2.Default.Options) { NumberHandling = nunmberHandling }; + McpServerTool tool = McpServerTool.Create((int? x = 42, DateTimeOffset? y = null) => { }, new() { SerializerOptions = options }); + + JsonElement expectedSchema = JsonDocument.Parse(""" + { + "type": "object", + "properties": { + "x": { "type": ["integer", "null"], "default": 42 }, + "y": { "type": ["string", "null"], "format": "date-time", "default": null } + } + } + """).RootElement; + + Assert.True(JsonElement.DeepEquals(expectedSchema, tool.ProtocolTool.InputSchema)); + } + public static IEnumerable StructuredOutput_ReturnsExpectedSchema_Inputs() { yield return new object[] { "string" }; @@ -695,5 +716,7 @@ record Person(string Name, int Age); [JsonSerializable(typeof(JsonSchema))] [JsonSerializable(typeof(List))] [JsonSerializable(typeof(List))] + [JsonSerializable(typeof(int?))] + [JsonSerializable(typeof(DateTimeOffset?))] partial class JsonContext2 : JsonSerializerContext; }