diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index f7682b012b..1340a3a946 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -189,9 +189,10 @@ protected virtual int CountTokens(string? text, ReadOnlySpan<char> textSpan, Enc
         /// <param name="text">The text to encode.</param>
         /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
         /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <param name="maxTokenCount">The maximum number of tokens to count; counting stops once this limit is reached.</param>
         /// <returns>The number of token Ids that the input text will be encoded to.</returns>
-        public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
-            => CountTokens(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
+        public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true, int maxTokenCount = int.MaxValue)
+            => CountTokens(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount });
 
         /// <summary>
         /// Get the number of tokens that the input text will be encoded to.
@@ -199,9 +200,10 @@ public int CountTokens(string text, bool considerPreTokenization = true, bool co
         /// <param name="text">The text to encode.</param>
         /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
         /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <param name="maxTokenCount">The maximum number of tokens to count; counting stops once this limit is reached.</param>
         /// <returns>The number of token Ids that the input text will be encoded to.</returns>
-        public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
-            => CountTokens(null, text, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
+        public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true, int maxTokenCount = int.MaxValue)
+            => CountTokens(null, text, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount });
 
         /// <summary>
         /// Find the index of the maximum encoding capacity without surpassing the token limit.
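For illustration only, a minimal usage sketch of the new optional `maxTokenCount` parameter on the public `CountTokens` overloads. The `TiktokenTokenizer.CreateForModel` call, model name, and sample text are assumptions made for the example (any `Tokenizer` instance exposing these overloads would do); they are not part of this diff.

```csharp
using System;
using Microsoft.ML.Tokenizers;

// Assumption: a Tiktoken tokenizer created for a known model; this requires the
// corresponding tokenizer data to be available and is only an example setup.
Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");

string text = "The quick brown fox jumps over the lazy dog.";

// Existing behavior: count every token the text encodes to.
int total = tokenizer.CountTokens(text);

// New optional parameter: the count is capped at maxTokenCount, so the result
// here is at most 5 even if the full encoding is longer.
int capped = tokenizer.CountTokens(text, maxTokenCount: 5);

Console.WriteLine($"total: {total}, capped: {capped}");
```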
diff --git a/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs
index 1e7cad6890..6946cf4c56 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs
@@ -391,6 +391,20 @@ public void TestEncodeR50kBase()
             TestDecodingWithSpan((R50kBase as TiktokenTokenizer)!, encoded.ToArray(), text);
         }
 
+        [Fact]
+        public void TestCountingTokens()
+        {
+            string text = ReadAndSanitizeFile("./Data/lib.rs.txt");
+            IReadOnlyList<int> encoded = R50kBase.EncodeToIds(text);
+            int idsCount = R50kBase.CountTokens(text);
+            Assert.Equal(11378, encoded.Count);
+            Assert.Equal(encoded.Count, idsCount);
+
+            // count with max tokens to encode
+            int idsCountMax1000 = R50kBase.CountTokens(text, maxTokenCount: 1000);
+            Assert.Equal(1000, idsCountMax1000);
+        }
+
         [Theory]
         [InlineData("o1")]
         [InlineData("o1-")]
diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs
index 7d18ecb1be..386e0c03c2 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs
@@ -48,6 +48,14 @@ public void CountTokens_DefaultImplementation()
             Assert.Equal(5, tokenizer.CountTokens("hello"));
         }
 
+        [Fact]
+        public void CountTokens_WithMaxTokenCount()
+        {
+            var tokenizer = new EnglishAlphabetTokenizer();
+
+            Assert.Equal(3, tokenizer.CountTokens("hello", maxTokenCount: 3));
+        }
+
         [Fact]
         public void GetIndexByTokenCount_DefaultImplementation()
         {
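As a usage note, the capped count also makes it straightforward to check whether a text fits within a token budget, as the capped test above does against a known count. Below is a sketch under the same illustrative assumptions (the tokenizer creation, budget value, and file path are hypothetical, not taken from this diff):

```csharp
using System;
using System.IO;
using Microsoft.ML.Tokenizers;

// Illustrative assumptions: the tokenizer, budget, and input file are example values.
Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");
const int budget = 1000;

string prompt = File.ReadAllText("prompt.txt");

// Request at most budget + 1 tokens: if the capped count comes back as budget + 1,
// the text exceeds the budget; otherwise it fits. Depending on the tokenizer,
// counting may stop early rather than encoding the entire input.
bool fits = tokenizer.CountTokens(prompt, maxTokenCount: budget + 1) <= budget;

Console.WriteLine(fits ? "prompt fits the token budget" : "prompt exceeds the token budget");
```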