Commit ab7de4b

Add maxTokenCount parameter to CountTokens method
Updated the CountTokens method in Tokenizer.cs to include a maxTokenCount parameter for limiting token counts. Added tests in TiktokenTests.cs and TokenizerTests.cs to verify the new functionality and ensure correct behavior with the maximum token count.
1 parent e3219a9 commit ab7de4b
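
To make the change concrete, a minimal usage sketch of the new parameter follows (not part of the commit). The TiktokenTokenizer.CreateForModel factory and the "gpt-4" model name are assumptions about the package surface; any Tokenizer instance would do. The capped call is expected to return at most the requested maximum, matching the tests added below.

    using System;
    using Microsoft.ML.Tokenizers;

    class CountTokensSketch
    {
        static void Main()
        {
            // Assumption: this factory and model name are available in the version of
            // Microsoft.ML.Tokenizers being built here; any Tokenizer instance works.
            Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");

            string text = "The quick brown fox jumps over the lazy dog.";

            // Unlimited count: same behavior as before this commit.
            int fullCount = tokenizer.CountTokens(text);

            // Capped count: the result never exceeds maxTokenCount, so this is at most 5.
            int cappedCount = tokenizer.CountTokens(text, maxTokenCount: 5);

            Console.WriteLine($"full={fullCount}, capped={cappedCount}");
        }
    }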

3 files changed, +28 -4 lines changed

src/Microsoft.ML.Tokenizers/Tokenizer.cs

Lines changed: 6 additions & 4 deletions
@@ -189,19 +189,21 @@ protected virtual int CountTokens(string? text, ReadOnlySpan<char> textSpan, Enc
         /// <param name="text">The text to encode.</param>
         /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
         /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <param name="maxTokenCount">Indicate whether to consider a max token count for counting tokens.</param>
         /// <returns>The number of token Ids that the input text will be encoded to.</returns>
-        public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
-            => CountTokens(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
+        public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true, int maxTokenCount = int.MaxValue)
+            => CountTokens(text, text.AsSpan(), new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount });

         /// <summary>
         /// Get the number of tokens that the input text will be encoded to.
         /// </summary>
         /// <param name="text">The text to encode.</param>
         /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
         /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+        /// <param name="maxTokenCount">Indicate whether to consider a max token count for counting tokens.</param>
         /// <returns>The number of token Ids that the input text will be encoded to.</returns>
-        public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
-            => CountTokens(null, text, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization });
+        public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true, int maxTokenCount = int.MaxValue)
+            => CountTokens(null, text, new EncodeSettings { ConsiderPreTokenization = considerPreTokenization, ConsiderNormalization = considerNormalization, MaxTokenCount = maxTokenCount });

         /// <summary>
         /// Find the index of the maximum encoding capacity without surpassing the token limit.

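Because the parameter is added to both overloads above, the span-based form can cap a count over a window of an existing buffer without allocating a substring. A small sketch, reusing the same assumed tokenizer setup as in the earlier example:

    using System;
    using Microsoft.ML.Tokenizers;

    class SpanCountSketch
    {
        static void Main()
        {
            // Assumption: hypothetical factory/model name, as in the earlier sketch.
            Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");

            // Count only a window of an existing buffer; no substring allocation needed.
            char[] buffer = "An already-buffered document that we only want to sample.".ToCharArray();
            ReadOnlySpan<char> window = buffer.AsSpan(0, 21);

            // The cap applies exactly as in the string overload: the result is at most 4.
            int count = tokenizer.CountTokens(window, maxTokenCount: 4);
            Console.WriteLine(count);
        }
    }
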
test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs

Lines changed: 14 additions & 0 deletions
@@ -391,6 +391,20 @@ public void TestEncodeR50kBase()
             TestDecodingWithSpan((R50kBase as TiktokenTokenizer)!, encoded.ToArray(), text);
         }

+        [Fact]
+        public void TestCountingTokens()
+        {
+            string text = ReadAndSanitizeFile("./Data/lib.rs.txt");
+            IReadOnlyList<int> encoded = R50kBase.EncodeToIds(text);
+            int idsCount = R50kBase.CountTokens(text);
+            Assert.Equal(11378, encoded.Count);
+            Assert.Equal(encoded.Count, idsCount);
+
+            // count with max tokens to encode
+            int idsCountMax1000 = R50kBase.CountTokens(text, maxTokenCount: 1000);
+            Assert.Equal(1000, idsCountMax1000);
+        }
+
         [Theory]
         [InlineData("o1")]
         [InlineData("o1-")]

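The test above pins down the semantics: the sample file encodes to 11378 tokens, yet CountTokens(text, maxTokenCount: 1000) returns 1000. Assuming counting stops once the cap is reached (the apparent point of the parameter), this enables a cheap token-budget check. The FitsWithinBudget helper below is a hypothetical illustration, not part of the commit:

    using Microsoft.ML.Tokenizers;

    static class TokenBudget
    {
        // Hypothetical helper: returns true when text encodes to no more than budget tokens.
        public static bool FitsWithinBudget(Tokenizer tokenizer, string text, int budget)
        {
            // Ask for at most budget + 1 tokens: if the capped count exceeds the budget,
            // the text is too long, and counting did not have to walk the whole input.
            int count = tokenizer.CountTokens(text, maxTokenCount: budget + 1);
            return count <= budget;
        }
    }
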
test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs

Lines changed: 8 additions & 0 deletions
@@ -48,6 +48,14 @@ public void CountTokens_DefaultImplementation()
             Assert.Equal(5, tokenizer.CountTokens("hello"));
         }

+        [Fact]
+        public void CountTokens_WithMaxTokenCount()
+        {
+            var tokenizer = new EnglishAlphabetTokenizer();
+
+            Assert.Equal(3, tokenizer.CountTokens("hello", maxTokenCount: 3));
+        }
+
         [Fact]
         public void GetIndexByTokenCount_DefaultImplementation()
         {

0 commit comments
