From d68a6838156049ada8c25d3f4b8fa3befb3b4d1b Mon Sep 17 00:00:00 2001 From: Takahiro Ikeuchi Date: Thu, 24 Apr 2025 06:50:47 +0900 Subject: [PATCH 01/41] feat: add new GPT-4.1 model variants to completion.go (#966) * feat: add new GPT-4.1 model variants to completion.go * feat: add tests for unsupported models in completion endpoint * fix: add missing periods to test function comments in completion_test.go --- completion.go | 13 ++++++++ completion_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/completion.go b/completion.go index 015fa2a9f..0d0c1a8f4 100644 --- a/completion.go +++ b/completion.go @@ -37,6 +37,12 @@ const ( GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT4Dot1 = "gpt-4.1" + GPT4Dot120250414 = "gpt-4.1-2025-04-14" + GPT4Dot1Mini = "gpt-4.1-mini" + GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14" + GPT4Dot1Nano = "gpt-4.1-nano" + GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14" GPT4Dot5Preview = "gpt-4.5-preview" GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" @@ -121,6 +127,13 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT432K: true, GPT432K0314: true, GPT432K0613: true, + O1: true, + GPT4Dot1: true, + GPT4Dot120250414: true, + GPT4Dot1Mini: true, + GPT4Dot1Mini20250414: true, + GPT4Dot1Nano: true, + GPT4Dot1Nano20250414: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, diff --git a/completion_test.go b/completion_test.go index 935bbe864..83bd899a1 100644 --- a/completion_test.go +++ b/completion_test.go @@ -181,3 +181,86 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { } return completion, nil } + +// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint. +func TestCompletionWithO1Model(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O1, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err) + } +} + +// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint. +func TestCompletionWithGPT4DotModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4Dot1, + openai.GPT4Dot120250414, + openai.GPT4Dot1Mini, + openai.GPT4Dot1Mini20250414, + openai.GPT4Dot1Nano, + openai.GPT4Dot1Nano20250414, + openai.GPT4Dot5Preview, + openai.GPT4Dot5Preview20250227, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} + +// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint. +func TestCompletionWithGPT4oModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4o, + openai.GPT4o20240513, + openai.GPT4o20240806, + openai.GPT4o20241120, + openai.GPT4oLatest, + openai.GPT4oMini, + openai.GPT4oMini20240718, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} From 658beda2ba8be4d155bc62208224a5766e0640c0 Mon Sep 17 00:00:00 2001 From: netr Date: Sat, 26 Apr 2025 03:13:43 -0700 Subject: [PATCH 02/41] feat: Add missing TTS models and voices (#958) * feat: Add missing TTS models and voices * feat: Add new instruction field to create speech request - From docs: Control the voice of your generated audio with additional instructions. Does not work with tts-1 or tts-1-hd. * fix: add canary-tts back to SpeechModel --- speech.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/speech.go b/speech.go index 20b52e334..60e7694fd 100644 --- a/speech.go +++ b/speech.go @@ -8,20 +8,25 @@ import ( type SpeechModel string const ( - TTSModel1 SpeechModel = "tts-1" - TTSModel1HD SpeechModel = "tts-1-hd" - TTSModelCanary SpeechModel = "canary-tts" + TTSModel1 SpeechModel = "tts-1" + TTSModel1HD SpeechModel = "tts-1-hd" + TTSModelCanary SpeechModel = "canary-tts" + TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts" ) type SpeechVoice string const ( VoiceAlloy SpeechVoice = "alloy" + VoiceAsh SpeechVoice = "ash" + VoiceBallad SpeechVoice = "ballad" + VoiceCoral SpeechVoice = "coral" VoiceEcho SpeechVoice = "echo" VoiceFable SpeechVoice = "fable" VoiceOnyx SpeechVoice = "onyx" VoiceNova SpeechVoice = "nova" VoiceShimmer SpeechVoice = "shimmer" + VoiceVerse SpeechVoice = "verse" ) type SpeechResponseFormat string @@ -39,6 +44,7 @@ type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` Voice SpeechVoice `json:"voice"` + Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd. ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } From 306fbbbe6f09ff7bd718e11cd322e88b442f4496 Mon Sep 17 00:00:00 2001 From: goodenough Date: Tue, 29 Apr 2025 21:24:45 +0800 Subject: [PATCH 03/41] Add support for reasoning_content field in chat completion messages for DeepSeek R1 (#925) * support deepseek field "reasoning_content" * support deepseek field "reasoning_content" * Comment ends in a period (godot) * add comment on field reasoning_content * fix go lint error * chore: trigger CI * make field "content" in MarshalJSON function omitempty * remove reasoning_content in TestO1ModelChatCompletions func * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. --- chat.go | 74 ++++++++++++++++++++++++++-------------------- chat_stream.go | 6 ++++ chat_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++ openai_test.go | 2 +- 4 files changed, 128 insertions(+), 33 deletions(-) diff --git a/chat.go b/chat.go index 995860c40..7112dc7b7 100644 --- a/chat.go +++ b/chat.go @@ -104,6 +104,12 @@ type ChatCompletionMessage struct { // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. @@ -119,41 +125,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { } if len(m.MultiContent) > 0 { msg := struct { - Role string `json:"role"` - Content string `json:"-"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content,omitempty"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"-"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &msg); err == nil { @@ -161,14 +170,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { return nil } multiMsg := struct { - Role string `json:"role"` - Content string - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &multiMsg); err != nil { return err diff --git a/chat_stream.go b/chat_stream.go index 525b4457a..80d16cc63 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct { FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` Refusal string `json:"refusal,omitempty"` + + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` } type ChatCompletionStreamChoiceLogprobs struct { diff --git a/chat_test.go b/chat_test.go index e90142da6..514706c96 100644 --- a/chat_test.go +++ b/chat_test.go @@ -411,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +func TestDeepseekR1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "deepseek-reasoner", + MaxCompletionTokens: 100, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -822,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } +func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq openai.ChatCompletionRequest + if completionReq, err = getChatCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + n := completionReq.N + if n == 0 { + n = 1 + } + if completionReq.MaxCompletionTokens == 0 { + completionReq.MaxCompletionTokens = 1000 + } + for i := 0; i < n; i++ { + reasoningContent := "User says hello! And I need to reply" + completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent)) + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + ReasoningContent: reasoningContent, + Content: completionStr, + }, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n + res.Usage = openai.Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + fmt.Fprintln(w, string(resBytes)) +} + // getChatCompletionBody Returns the body of the request to create a completion. func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { completion := openai.ChatCompletionRequest{} diff --git a/openai_test.go b/openai_test.go index 6c26eebd1..a55f3a858 100644 --- a/openai_test.go +++ b/openai_test.go @@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer/ +// https://beta.openai.com/tokenizer. // // TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int { From 4cccc6c93455d2ff2cf52660c916cfa5907ddbd3 Mon Sep 17 00:00:00 2001 From: Zhongxian Pan Date: Tue, 29 Apr 2025 21:29:15 +0800 Subject: [PATCH 04/41] Adapt different stream data prefix, with or without space (#945) --- stream_reader.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index ecfa26807..6faefe0a7 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -6,13 +6,14 @@ import ( "fmt" "io" "net/http" + "regexp" utils "github.com/sashabaranov/go-openai/internal" ) var ( - headerData = []byte("data: ") - errorPrefix = []byte(`data: {"error":`) + headerData = regexp.MustCompile(`^data:\s*`) + errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) ) type streamable interface { @@ -70,12 +71,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { } noSpaceLine := bytes.TrimSpace(rawLine) - if bytes.HasPrefix(noSpaceLine, errorPrefix) { + if errorPrefix.Match(noSpaceLine) { hasErrorPrefix = true } - if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if !headerData.Match(noSpaceLine) || hasErrorPrefix { if hasErrorPrefix { - noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { @@ -89,7 +90,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { continue } - noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil) if string(noPrefixLine) == "[DONE]" { stream.isFinished = true return nil, io.EOF From bb5bc275678767d410d25d307959e1c45bc89c90 Mon Sep 17 00:00:00 2001 From: rory malcolm Date: Tue, 29 Apr 2025 14:34:33 +0100 Subject: [PATCH 05/41] Add support for `4o-mini` and `3o` (#968) - This adds supports, and tests, for the 3o and 4o-mini class of models --- chat_stream_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++ completion.go | 8 +++++++ completion_test.go | 36 ++++++++++++++++++++++++++++++ models_test.go | 18 +++++++++++++++ reasoning_validator.go | 3 ++- 5 files changed, 114 insertions(+), 1 deletion(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 4d992e4d1..eabb0f3a2 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -959,6 +959,56 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) { } } +func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err) + } +} + +func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O4Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err) + } +} + func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false diff --git a/completion.go b/completion.go index 0d0c1a8f4..7a6de3033 100644 --- a/completion.go +++ b/completion.go @@ -16,8 +16,12 @@ const ( O1Preview20240912 = "o1-preview-2024-09-12" O1 = "o1" O120241217 = "o1-2024-12-17" + O3 = "o3" + O320250416 = "o3-2025-04-16" O3Mini = "o3-mini" O3Mini20250131 = "o3-mini-2025-01-31" + O4Mini = "o4-mini" + O4Mini2020416 = "o4-mini-2025-04-16" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -99,6 +103,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ O1Preview20240912: true, O3Mini: true, O3Mini20250131: true, + O4Mini: true, + O4Mini2020416: true, + O3: true, + O320250416: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, diff --git a/completion_test.go b/completion_test.go index 83bd899a1..27e2d150e 100644 --- a/completion_test.go +++ b/completion_test.go @@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) { } } +// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported. +func TestCompletionsWrongModelO3(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O3, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err) + } +} + +// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported. +func TestCompletionsWrongModelO4Mini(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O4Mini, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err) + } +} + func TestCompletionWithStream(t *testing.T) { config := openai.DefaultConfig("whatever") client := openai.NewClientWithConfig(config) diff --git a/models_test.go b/models_test.go index 24a28ed23..7fd010c34 100644 --- a/models_test.go +++ b/models_test.go @@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) { checks.NoError(t, err, "GetModel error") } +// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server. +func TestGetModelO3(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o3") + checks.NoError(t, err, "GetModel error for O3") +} + +// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server. +func TestGetModelO4Mini(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o4-mini") + checks.NoError(t, err, "GetModel error for O4Mini") +} + func TestAzureGetModel(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() diff --git a/reasoning_validator.go b/reasoning_validator.go index 040d6b495..2910b1395 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -40,8 +40,9 @@ func NewReasoningValidator() *ReasoningValidator { func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { o1Series := strings.HasPrefix(request.Model, "o1") o3Series := strings.HasPrefix(request.Model, "o3") + o4Series := strings.HasPrefix(request.Model, "o4") - if !o1Series && !o3Series { + if !o1Series && !o3Series && !o4Series { return nil } From da5f9bc9bc40537a0c2b451fffa9364efa94dbe1 Mon Sep 17 00:00:00 2001 From: Sean McGinnis Date: Tue, 29 Apr 2025 08:35:26 -0500 Subject: [PATCH 06/41] Add CompletionRequest.StreamOptions (#959) The legacy completion API supports a `stream_options` object when `stream` is set to true [0]. This adds a StreamOptions property to the CompletionRequest struct to support this setting. [0] https://platform.openai.com/docs/api-reference/completions/create#completions-create-stream_options Signed-off-by: Sean McGinnis --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 7a6de3033..9c3a64dd5 100644 --- a/completion.go +++ b/completion.go @@ -215,6 +215,8 @@ type CompletionRequest struct { Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` User string `json:"user,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` } // CompletionChoice represents one of possible completions. From 6836cf6a6fd0027ea21f8d31bff5d023040d9db4 Mon Sep 17 00:00:00 2001 From: Oleksandr Redko Date: Tue, 29 Apr 2025 16:36:38 +0300 Subject: [PATCH 07/41] Remove redundant typecheck linter (#955) --- .golangci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index 9f2ba52e0..a5988825b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -149,7 +149,6 @@ linters: - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - ineffassign # Detects when assignments to existing variables are not used - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unused # Checks Go code for unused constants, variables, functions and types ## disabled by default # - asasalint # Check for pass []any as any in variadic func(...any) From 93a611cf4f1d227963a4f28a2c4a3422a0d37bfd Mon Sep 17 00:00:00 2001 From: Daniel Peng Date: Tue, 29 Apr 2025 06:38:27 -0700 Subject: [PATCH 08/41] Add Prediction field (#970) * Add Prediction field to ChatCompletionRequest * Include prediction tokens in response --- chat.go | 7 +++++++ common.go | 6 ++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index 7112dc7b7..0f91d481c 100644 --- a/chat.go +++ b/chat.go @@ -273,6 +273,8 @@ type ChatCompletionRequest struct { ReasoningEffort string `json:"reasoning_effort,omitempty"` // Metadata to store with the completion. Metadata map[string]string `json:"metadata,omitempty"` + // Configuration for a predicted output. + Prediction *Prediction `json:"prediction,omitempty"` } type StreamOptions struct { @@ -340,6 +342,11 @@ type LogProbs struct { Content []LogProb `json:"content"` } +type Prediction struct { + Content string `json:"content"` + Type string `json:"type"` +} + type FinishReason string const ( diff --git a/common.go b/common.go index 8cc7289c0..d1936d656 100644 --- a/common.go +++ b/common.go @@ -13,8 +13,10 @@ type Usage struct { // CompletionTokensDetails Breakdown of tokens used in a completion. type CompletionTokensDetails struct { - AudioTokens int `json:"audio_tokens"` - ReasoningTokens int `json:"reasoning_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` } // PromptTokensDetails Breakdown of tokens used in the prompt. From d65f0cb54e8c9c91f4340fad14243beeb38f5a08 Mon Sep 17 00:00:00 2001 From: Ben Katz Date: Sun, 4 May 2025 03:44:48 +0700 Subject: [PATCH 09/41] Fix: Corrected typo in O4Mini20250416 model name and endpoint map. (#981) --- completion.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 9c3a64dd5..21d4897c4 100644 --- a/completion.go +++ b/completion.go @@ -21,7 +21,7 @@ const ( O3Mini = "o3-mini" O3Mini20250131 = "o3-mini-2025-01-31" O4Mini = "o4-mini" - O4Mini2020416 = "o4-mini-2025-04-16" + O4Mini20250416 = "o4-mini-2025-04-16" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -104,7 +104,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ O3Mini: true, O3Mini20250131: true, O4Mini: true, - O4Mini2020416: true, + O4Mini20250416: true, O3: true, O320250416: true, GPT3Dot5Turbo: true, From 5ea214a188a7751bddb9f6c899632509d388f643 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 3 May 2025 22:25:14 +0100 Subject: [PATCH 10/41] Improve unit test coverage (#984) * add tests for config * add audio tests * lint * lint * lint --- audio_test.go | 132 +++++++++++++++++++++++++++++++++++++++++++++++++ config_test.go | 21 ++++++++ 2 files changed, 153 insertions(+) diff --git a/audio_test.go b/audio_test.go index 9f32d5468..51b3f465d 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,12 +2,16 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" + "errors" "fmt" "io" + "net/http" "os" "path/filepath" "testing" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -107,3 +111,131 @@ func TestCreateFileField(t *testing.T) { checks.HasError(t, err, "createFileField using file should return error when open file fails") }) } + +// failingFormBuilder always returns an error when creating form files. +type failingFormBuilder struct{ err error } + +func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error { + return f.err +} + +func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error { + return f.err +} + +func (f *failingFormBuilder) WriteField(_, _ string) error { + return nil +} + +func (f *failingFormBuilder) Close() error { + return nil +} + +func (f *failingFormBuilder) FormDataContentType() string { + return "multipart/form-data" +} + +// failingAudioRequestBuilder simulates an error during HTTP request construction. +type failingAudioRequestBuilder struct{ err error } + +func (f *failingAudioRequestBuilder) Build( + _ context.Context, + _, _ string, + _ any, + _ http.Header, +) (*http.Request, error) { + return nil, f.err +} + +// errorHTTPClient always returns an error when making HTTP calls. +type errorHTTPClient struct{ err error } + +func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) { + return nil, e.err +} + +func TestCallAudioAPIMultipartFormError(t *testing.T) { + client := NewClient("test-token") + errForm := errors.New("mock create form file failure") + // Override form builder to force an error during multipart form creation. + client.createFormBuilder = func(_ io.Writer) utils.FormBuilder { + return &failingFormBuilder{err: errForm} + } + + // Provide a reader so createFileField uses the reader path (no file open). + req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errForm) { + t.Errorf("expected error %v, got %v", errForm, err) + } +} + +func TestCallAudioAPINewRequestError(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errBuild := errors.New("mock build failure") + client.requestBuilder = &failingAudioRequestBuilder{err: errBuild} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errBuild) { + t.Errorf("expected error %v, got %v", errBuild, err) + } +} + +func TestCallAudioAPISendRequestErrorJSON(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errHTTP := errors.New("mock HTTPClient failure") + // Override HTTP client to simulate a network error. + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) + } +} + +func TestCallAudioAPISendRequestErrorText(t *testing.T) { + client := NewClient("test-token") + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errHTTP := errors.New("mock HTTPClient failure") + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} + + // Use a non-JSON response format to exercise the text path. + req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) + } +} diff --git a/config_test.go b/config_test.go index 145c26066..960230804 100644 --- a/config_test.go +++ b/config_test.go @@ -100,3 +100,24 @@ func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) { t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL) } } + +func TestClientConfigString(t *testing.T) { + // String() should always return the constant value + conf := openai.DefaultConfig("dummy-token") + expected := "" + got := conf.String() + if got != expected { + t.Errorf("ClientConfig.String() = %q; want %q", got, expected) + } +} + +func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) { + // On a zero-value or DefaultConfig, AzureModelMapperFunc is nil, + // so GetAzureDeploymentByModel should just return the input model. + conf := openai.DefaultConfig("dummy-token") + model := "some-model" + got := conf.GetAzureDeploymentByModel(model) + if got != model { + t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model) + } +} From 77ccac8d342f7704b020e63b23cec2d6f009bff9 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 3 May 2025 22:39:47 +0100 Subject: [PATCH 11/41] Upgrade golangci-lint to 2.1.5 (#986) * Upgrade golangci-lint to 2.1.5 * update action --- .github/workflows/pr.yml | 4 +- .golangci.yml | 418 +++++++++++++++------------------------ 2 files changed, 166 insertions(+), 256 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index f4cbe7c8b..268e3259b 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -18,9 +18,9 @@ jobs: run: | go vet . - name: Run golangci-lint - uses: golangci/golangci-lint-action@v6 + uses: golangci/golangci-lint-action@v7 with: - version: v1.64.5 + version: v2.1.5 - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov diff --git a/.golangci.yml b/.golangci.yml index a5988825b..6391ad76f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,258 +1,168 @@ -## Golden config for golangci-lint v1.47.3 -# -# This is the best config for golangci-lint based on my experience and opinion. -# It is very strict, but not extremely strict. -# Feel free to adopt and change it for your needs. - -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 3m - - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 30 - # The maximal average package complexity. - # If it's higher than 0.0 (float) the check is enabled - # Default: 0.0 - package-average: 10.0 - - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: 100 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: 50 - - gocognit: - # Minimal code complexity to report - # Default: 30 (but we recommend 10-20) - min-complexity: 20 - - gocritic: - # Settings passed to gocritic. - # The settings key is the name of a supported gocritic checker. - # The list of supported checkers can be find in https://go-critic.github.io/overview. - settings: - captLocal: - # Whether to restrict checker to params only. - # Default: true - paramsOnly: false - underef: - # Whether to skip (*x).method() calls where x is a pointer receiver. - # Default: true - skipRecvDeref: false - - mnd: - # List of function patterns to exclude from analysis. - # Values always ignored: `time.Date` - # Default: [] - ignored-functions: - - os.Chmod - - os.Mkdir - - os.MkdirAll - - os.OpenFile - - os.WriteFile - - prometheus.ExponentialBuckets - - prometheus.ExponentialBucketsRange - - prometheus.LinearBuckets - - strconv.FormatFloat - - strconv.FormatInt - - strconv.FormatUint - - strconv.ParseFloat - - strconv.ParseInt - - strconv.ParseUint - - gomodguard: - blocked: - # List of blocked modules. - # Default: [] - modules: - - github.com/golang/protobuf: - recommendations: - - google.golang.org/protobuf - reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" - - github.com/satori/go.uuid: - recommendations: - - github.com/google/uuid - reason: "satori's package is not maintained" - - github.com/gofrs/uuid: - recommendations: - - github.com/google/uuid - reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" - - govet: - # Enable all analyzers. - # Default: false - enable-all: true - # Disable analyzers by name. - # Run `go tool vet help` to see all analyzers. - # Default: [] - disable: - - fieldalignment # too strict - # Settings per analyzer. - settings: - shadow: - # Whether to be strict about shadowing; can be noisy. - # Default: false - strict: true - - nakedret: - # Make an issue if func has more lines of code than this setting, and it has naked returns. - # Default: 30 - max-func-lines: 0 - - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] - # Enable to require an explanation of nonzero length after each nolint directive. - # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - - rowserrcheck: - # database/sql is always checked - # Default: [] - packages: - - github.com/jmoiron/sqlx - - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - - +version: "2" linters: - disable-all: true + default: none enable: - ## enabled by default - - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - - gosimple # Linter for Go source code that specializes in simplifying a code - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # Detects when assignments to existing variables are not used - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - unused # Checks Go code for unused constants, variables, functions and types - ## disabled by default - # - asasalint # Check for pass []any as any in variadic func(...any) - - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - - bidichk # Checks for dangerous unicode character sequences - - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity - - dupl # Tool for code clone detection - - durationcheck # check for two durations multiplied together - - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - exhaustive # check exhaustiveness of enum switch statements - - forbidigo # Forbids identifiers - - funlen # Tool for detection of long functions - # - gochecknoglobals # check that no global variables exist - - gochecknoinits # Checks that no init functions are present in Go code - - gocognit # Computes and checks the cognitive complexity of functions - - goconst # Finds repeated strings that could be replaced by a constant - - gocritic # Provides diagnostics that check for bugs, performance and style issues. - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - - goprintffuncname # Checks that printf-like functions are named with f at the end - - gosec # Inspects source code for security problems - - lll # Reports long lines - - makezero # Finds slice declarations with non-zero initial length - # - nakedret # Finds naked returns in functions greater than a specified function length - - mnd # An analyzer to detect magic numbers. - - nestif # Reports deeply nested if statements - - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. - # - noctx # noctx finds sending http request without context.Context - - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - - predeclared # find code that shadows one of Go's predeclared identifiers - - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - rowserrcheck # checks whether Err of rows is checked successfully - - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - - stylecheck # Stylecheck is a replacement for golint - - testpackage # linter that makes you use a separate _test package - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - usetesting # Reports uses of functions with replacement inside the testing package - - wastedassign # wastedassign finds wasted assignment statements. - - whitespace # Tool for detection of leading and trailing whitespace - ## you may want to enable - #- decorder # check declaration order and count of types, constants, variables and functions - #- exhaustruct # Checks if all structure fields are initialized - #- goheader # Checks is file header matches to pattern - #- ireturn # Accept Interfaces, Return Concrete Types - #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated - #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope - #- wrapcheck # Checks that errors returned from external packages are wrapped - ## disabled - #- containedctx # containedctx is a linter that detects struct contained context.Context field - #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages - #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - #- forcetypeassert # [replaced by errcheck] finds forced type assertions - #- gci # Gci controls golang package import order and makes it always deterministic. - #- godox # Tool for detection of FIXME, TODO and other comment keywords - #- goerr113 # [too strict] Golang linter to check the errors handling expressions - #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. - #- grouper # An analyzer to analyze expression groups. - #- ifshort # Checks that your code uses short syntax for if-statements whenever possible - #- importas # Enforces consistent import aliases - #- maintidx # maintidx measures the maintainability index of each function. - #- misspell # [useless] Finds commonly misspelled English words in comments - #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity - #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 - #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test - #- tagliatelle # Checks the struct tags. - #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - - + - asciicheck + - bidichk + - bodyclose + - contextcheck + - cyclop + - dupl + - durationcheck + - errcheck + - errname + - errorlint + - exhaustive + - forbidigo + - funlen + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godot + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - govet + - ineffassign + - lll + - makezero + - mnd + - nestif + - nilerr + - nilnil + - nolintlint + - nosprintfhostport + - predeclared + - promlinter + - revive + - rowserrcheck + - sqlclosecheck + - staticcheck + - testpackage + - tparallel + - unconvert + - unparam + - unused + - usetesting + - wastedassign + - whitespace + settings: + cyclop: + max-complexity: 30 + package-average: 10 + errcheck: + check-type-assertions: true + funlen: + lines: 100 + statements: 50 + gocognit: + min-complexity: 20 + gocritic: + settings: + captLocal: + paramsOnly: false + underef: + skipRecvDeref: false + gomodguard: + blocked: + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: satori's package is not maintained + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw' + govet: + disable: + - fieldalignment + enable-all: true + settings: + shadow: + strict: true + mnd: + ignored-functions: + - os.Chmod + - os.Mkdir + - os.MkdirAll + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets + - prometheus.ExponentialBucketsRange + - prometheus.LinearBuckets + - strconv.FormatFloat + - strconv.FormatInt + - strconv.FormatUint + - strconv.ParseFloat + - strconv.ParseInt + - strconv.ParseUint + nakedret: + max-func-lines: 0 + nolintlint: + require-explanation: true + require-specific: true + allow-no-explanation: + - funlen + - gocognit + - lll + rowserrcheck: + packages: + - github.com/jmoiron/sqlx + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - forbidigo + - mnd + - revive + path : ^examples/.*\.go$ + - linters: + - lll + source: ^//\s*go:generate\s + - linters: + - godot + source: (noinspection|TODO) + - linters: + - gocritic + source: //noinspection + - linters: + - errorlint + source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok { + - linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck + - staticcheck + path: _test\.go + paths: + - third_party$ + - builtin$ + - examples$ issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. - # Default: 3 max-same-issues: 50 - - exclude-rules: - - source: "^//\\s*go:generate\\s" - linters: [ lll ] - - source: "(noinspection|TODO)" - linters: [ godot ] - - source: "//noinspection" - linters: [ gocritic ] - - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" - linters: [ errorlint ] - - path: "_test\\.go" - linters: - - bodyclose - - dupl - - funlen - - goconst - - gosec - - noctx - - wrapcheck +formatters: + enable: + - goimports + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ From 6181facea7e6e5525b6b8da42205d7cce822c249 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sun, 4 May 2025 15:45:40 +0100 Subject: [PATCH 12/41] update codecov action, pass token (#987) --- .github/workflows/pr.yml | 4 +- .golangci.bck.yml | 258 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 .golangci.bck.yml diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 268e3259b..18c720f03 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -24,4 +24,6 @@ jobs: - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.golangci.bck.yml b/.golangci.bck.yml new file mode 100644 index 000000000..a5988825b --- /dev/null +++ b/.golangci.bck.yml @@ -0,0 +1,258 @@ +## Golden config for golangci-lint v1.47.3 +# +# This is the best config for golangci-lint based on my experience and opinion. +# It is very strict, but not extremely strict. +# Feel free to adopt and change it for your needs. + +run: + # Timeout for analysis, e.g. 30s, 5m. + # Default: 1m + timeout: 3m + + +# This file contains only configs which differ from defaults. +# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml +linters-settings: + cyclop: + # The maximal code complexity to report. + # Default: 10 + max-complexity: 30 + # The maximal average package complexity. + # If it's higher than 0.0 (float) the check is enabled + # Default: 0.0 + package-average: 10.0 + + errcheck: + # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. + # Such cases aren't reported by default. + # Default: false + check-type-assertions: true + + funlen: + # Checks the number of lines in a function. + # If lower than 0, disable the check. + # Default: 60 + lines: 100 + # Checks the number of statements in a function. + # If lower than 0, disable the check. + # Default: 40 + statements: 50 + + gocognit: + # Minimal code complexity to report + # Default: 30 (but we recommend 10-20) + min-complexity: 20 + + gocritic: + # Settings passed to gocritic. + # The settings key is the name of a supported gocritic checker. + # The list of supported checkers can be find in https://go-critic.github.io/overview. + settings: + captLocal: + # Whether to restrict checker to params only. + # Default: true + paramsOnly: false + underef: + # Whether to skip (*x).method() calls where x is a pointer receiver. + # Default: true + skipRecvDeref: false + + mnd: + # List of function patterns to exclude from analysis. + # Values always ignored: `time.Date` + # Default: [] + ignored-functions: + - os.Chmod + - os.Mkdir + - os.MkdirAll + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets + - prometheus.ExponentialBucketsRange + - prometheus.LinearBuckets + - strconv.FormatFloat + - strconv.FormatInt + - strconv.FormatUint + - strconv.ParseFloat + - strconv.ParseInt + - strconv.ParseUint + + gomodguard: + blocked: + # List of blocked modules. + # Default: [] + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: "satori's package is not maintained" + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" + + govet: + # Enable all analyzers. + # Default: false + enable-all: true + # Disable analyzers by name. + # Run `go tool vet help` to see all analyzers. + # Default: [] + disable: + - fieldalignment # too strict + # Settings per analyzer. + settings: + shadow: + # Whether to be strict about shadowing; can be noisy. + # Default: false + strict: true + + nakedret: + # Make an issue if func has more lines of code than this setting, and it has naked returns. + # Default: 30 + max-func-lines: 0 + + nolintlint: + # Exclude following linters from requiring an explanation. + # Default: [] + allow-no-explanation: [ funlen, gocognit, lll ] + # Enable to require an explanation of nonzero length after each nolint directive. + # Default: false + require-explanation: true + # Enable to require nolint directives to mention the specific linter being suppressed. + # Default: false + require-specific: true + + rowserrcheck: + # database/sql is always checked + # Default: [] + packages: + - github.com/jmoiron/sqlx + + tenv: + # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. + # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. + # Default: false + all: true + + +linters: + disable-all: true + enable: + ## enabled by default + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - gosimple # Linter for Go source code that specializes in simplifying a code + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - ineffassign # Detects when assignments to existing variables are not used + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - unused # Checks Go code for unused constants, variables, functions and types + ## disabled by default + # - asasalint # Check for pass []any as any in variadic func(...any) + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - funlen # Tool for detection of long functions + # - gochecknoglobals # check that no global variables exist + - gochecknoinits # Checks that no init functions are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # Provides diagnostics that check for bugs, performance and style issues. + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - goprintffuncname # Checks that printf-like functions are named with f at the end + - gosec # Inspects source code for security problems + - lll # Reports long lines + - makezero # Finds slice declarations with non-zero initial length + # - nakedret # Finds naked returns in functions greater than a specified function length + - mnd # An analyzer to detect magic numbers. + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. + # - noctx # noctx finds sending http request without context.Context + - nolintlint # Reports ill-formed or insufficient nolint directives + # - nonamedreturns # Reports all named returns + - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. + - predeclared # find code that shadows one of Go's predeclared identifiers + - promlinter # Check Prometheus metrics naming via promlint + - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - stylecheck # Stylecheck is a replacement for golint + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - usetesting # Reports uses of functions with replacement inside the testing package + - wastedassign # wastedassign finds wasted assignment statements. + - whitespace # Tool for detection of leading and trailing whitespace + ## you may want to enable + #- decorder # check declaration order and count of types, constants, variables and functions + #- exhaustruct # Checks if all structure fields are initialized + #- goheader # Checks is file header matches to pattern + #- ireturn # Accept Interfaces, Return Concrete Types + #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated + #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope + #- wrapcheck # Checks that errors returned from external packages are wrapped + ## disabled + #- containedctx # containedctx is a linter that detects struct contained context.Context field + #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages + #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. + #- forcetypeassert # [replaced by errcheck] finds forced type assertions + #- gci # Gci controls golang package import order and makes it always deterministic. + #- godox # Tool for detection of FIXME, TODO and other comment keywords + #- goerr113 # [too strict] Golang linter to check the errors handling expressions + #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. + #- grouper # An analyzer to analyze expression groups. + #- ifshort # Checks that your code uses short syntax for if-statements whenever possible + #- importas # Enforces consistent import aliases + #- maintidx # maintidx measures the maintainability index of each function. + #- misspell # [useless] Finds commonly misspelled English words in comments + #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity + #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 + #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test + #- tagliatelle # Checks the struct tags. + #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! + + +issues: + # Maximum count of issues with the same text. + # Set to 0 to disable. + # Default: 3 + max-same-issues: 50 + + exclude-rules: + - source: "^//\\s*go:generate\\s" + linters: [ lll ] + - source: "(noinspection|TODO)" + linters: [ godot ] + - source: "//noinspection" + linters: [ gocritic ] + - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" + linters: [ errorlint ] + - path: "_test\\.go" + linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck From 8ba38f6ba16264760d3fd88892d02b25ef742c24 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 13 May 2025 12:44:16 +0100 Subject: [PATCH 13/41] remove backup file (#996) --- .golangci.bck.yml | 258 ---------------------------------------------- 1 file changed, 258 deletions(-) delete mode 100644 .golangci.bck.yml diff --git a/.golangci.bck.yml b/.golangci.bck.yml deleted file mode 100644 index a5988825b..000000000 --- a/.golangci.bck.yml +++ /dev/null @@ -1,258 +0,0 @@ -## Golden config for golangci-lint v1.47.3 -# -# This is the best config for golangci-lint based on my experience and opinion. -# It is very strict, but not extremely strict. -# Feel free to adopt and change it for your needs. - -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 3m - - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 30 - # The maximal average package complexity. - # If it's higher than 0.0 (float) the check is enabled - # Default: 0.0 - package-average: 10.0 - - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: 100 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: 50 - - gocognit: - # Minimal code complexity to report - # Default: 30 (but we recommend 10-20) - min-complexity: 20 - - gocritic: - # Settings passed to gocritic. - # The settings key is the name of a supported gocritic checker. - # The list of supported checkers can be find in https://go-critic.github.io/overview. - settings: - captLocal: - # Whether to restrict checker to params only. - # Default: true - paramsOnly: false - underef: - # Whether to skip (*x).method() calls where x is a pointer receiver. - # Default: true - skipRecvDeref: false - - mnd: - # List of function patterns to exclude from analysis. - # Values always ignored: `time.Date` - # Default: [] - ignored-functions: - - os.Chmod - - os.Mkdir - - os.MkdirAll - - os.OpenFile - - os.WriteFile - - prometheus.ExponentialBuckets - - prometheus.ExponentialBucketsRange - - prometheus.LinearBuckets - - strconv.FormatFloat - - strconv.FormatInt - - strconv.FormatUint - - strconv.ParseFloat - - strconv.ParseInt - - strconv.ParseUint - - gomodguard: - blocked: - # List of blocked modules. - # Default: [] - modules: - - github.com/golang/protobuf: - recommendations: - - google.golang.org/protobuf - reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" - - github.com/satori/go.uuid: - recommendations: - - github.com/google/uuid - reason: "satori's package is not maintained" - - github.com/gofrs/uuid: - recommendations: - - github.com/google/uuid - reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" - - govet: - # Enable all analyzers. - # Default: false - enable-all: true - # Disable analyzers by name. - # Run `go tool vet help` to see all analyzers. - # Default: [] - disable: - - fieldalignment # too strict - # Settings per analyzer. - settings: - shadow: - # Whether to be strict about shadowing; can be noisy. - # Default: false - strict: true - - nakedret: - # Make an issue if func has more lines of code than this setting, and it has naked returns. - # Default: 30 - max-func-lines: 0 - - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] - # Enable to require an explanation of nonzero length after each nolint directive. - # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - - rowserrcheck: - # database/sql is always checked - # Default: [] - packages: - - github.com/jmoiron/sqlx - - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - - -linters: - disable-all: true - enable: - ## enabled by default - - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - - gosimple # Linter for Go source code that specializes in simplifying a code - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # Detects when assignments to existing variables are not used - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - unused # Checks Go code for unused constants, variables, functions and types - ## disabled by default - # - asasalint # Check for pass []any as any in variadic func(...any) - - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - - bidichk # Checks for dangerous unicode character sequences - - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity - - dupl # Tool for code clone detection - - durationcheck # check for two durations multiplied together - - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - exhaustive # check exhaustiveness of enum switch statements - - forbidigo # Forbids identifiers - - funlen # Tool for detection of long functions - # - gochecknoglobals # check that no global variables exist - - gochecknoinits # Checks that no init functions are present in Go code - - gocognit # Computes and checks the cognitive complexity of functions - - goconst # Finds repeated strings that could be replaced by a constant - - gocritic # Provides diagnostics that check for bugs, performance and style issues. - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - - goprintffuncname # Checks that printf-like functions are named with f at the end - - gosec # Inspects source code for security problems - - lll # Reports long lines - - makezero # Finds slice declarations with non-zero initial length - # - nakedret # Finds naked returns in functions greater than a specified function length - - mnd # An analyzer to detect magic numbers. - - nestif # Reports deeply nested if statements - - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. - # - noctx # noctx finds sending http request without context.Context - - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - - predeclared # find code that shadows one of Go's predeclared identifiers - - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - rowserrcheck # checks whether Err of rows is checked successfully - - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - - stylecheck # Stylecheck is a replacement for golint - - testpackage # linter that makes you use a separate _test package - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - usetesting # Reports uses of functions with replacement inside the testing package - - wastedassign # wastedassign finds wasted assignment statements. - - whitespace # Tool for detection of leading and trailing whitespace - ## you may want to enable - #- decorder # check declaration order and count of types, constants, variables and functions - #- exhaustruct # Checks if all structure fields are initialized - #- goheader # Checks is file header matches to pattern - #- ireturn # Accept Interfaces, Return Concrete Types - #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated - #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope - #- wrapcheck # Checks that errors returned from external packages are wrapped - ## disabled - #- containedctx # containedctx is a linter that detects struct contained context.Context field - #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages - #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - #- forcetypeassert # [replaced by errcheck] finds forced type assertions - #- gci # Gci controls golang package import order and makes it always deterministic. - #- godox # Tool for detection of FIXME, TODO and other comment keywords - #- goerr113 # [too strict] Golang linter to check the errors handling expressions - #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. - #- grouper # An analyzer to analyze expression groups. - #- ifshort # Checks that your code uses short syntax for if-statements whenever possible - #- importas # Enforces consistent import aliases - #- maintidx # maintidx measures the maintainability index of each function. - #- misspell # [useless] Finds commonly misspelled English words in comments - #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity - #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 - #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test - #- tagliatelle # Checks the struct tags. - #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - - -issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. - # Default: 3 - max-same-issues: 50 - - exclude-rules: - - source: "^//\\s*go:generate\\s" - linters: [ lll ] - - source: "(noinspection|TODO)" - linters: [ godot ] - - source: "//noinspection" - linters: [ gocritic ] - - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" - linters: [ errorlint ] - - path: "_test\\.go" - linters: - - bodyclose - - dupl - - funlen - - goconst - - gosec - - noctx - - wrapcheck From 0116f2994de0960ab65598f60e55d98b43e5fc2d Mon Sep 17 00:00:00 2001 From: Pedro Chaparro <94259578+PChaparro@users.noreply.github.com> Date: Tue, 13 May 2025 06:51:08 -0500 Subject: [PATCH 14/41] feat: add support for image generation using `gpt-image-1` (#971) * feat: add gpt-image-1 support * feat: add example to generate image using gpt-image-1 model * style: missing period in comments * feat: add missing fields to example * docs: add GPT Image 1 to README * revert: keep `examples/images/main.go` unchanged * docs: remove unnecessary newline from example in README file --- README.md | 62 +++++++++++++++++++++++++++++++++- examples/images/main.go | 2 +- image.go | 75 +++++++++++++++++++++++++++++++++++------ 3 files changed, 126 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 57d1d35bf..77b85e519 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * ChatGPT 4o, o1 * GPT-3, GPT-4 -* DALL·E 2, DALL·E 3 +* DALL·E 2, DALL·E 3, GPT Image 1 * Whisper ## Installation @@ -357,6 +357,66 @@ func main() { ``` +
+GPT Image 1 image generation + +```go +package main + +import ( + "context" + "encoding/base64" + "fmt" + "os" + + openai "github.com/sashabaranov/go-openai" +) + +func main() { + c := openai.NewClient("your token") + ctx := context.Background() + + req := openai.ImageRequest{ + Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.", + Background: openai.CreateImageBackgroundOpaque, + Model: openai.CreateImageModelGptImage1, + Size: openai.CreateImageSize1024x1024, + N: 1, + Quality: openai.CreateImageQualityLow, + OutputCompression: 100, + OutputFormat: openai.CreateImageOutputFormatJPEG, + // Moderation: openai.CreateImageModerationLow, + // User: "", + } + + resp, err := c.CreateImage(ctx, req) + if err != nil { + fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err) + return + } + + fmt.Println("Image Base64:", resp.Data[0].B64JSON) + + // Decode the base64 data + imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON) + if err != nil { + fmt.Printf("Base64 decode error: %v\n", err) + return + } + + // Write image to file + outputPath := "generated_image.jpg" + err = os.WriteFile(outputPath, imgBytes, 0644) + if err != nil { + fmt.Printf("Failed to write image file: %v\n", err) + return + } + + fmt.Printf("The image was saved as %s\n", outputPath) +} +``` +
+
Configuring proxy diff --git a/examples/images/main.go b/examples/images/main.go index 5ee649d22..2bfeb7973 100644 --- a/examples/images/main.go +++ b/examples/images/main.go @@ -25,4 +25,4 @@ func main() { return } fmt.Println(respUrl.Data[0].URL) -} +} \ No newline at end of file diff --git a/image.go b/image.go index 577d7db95..d62622a35 100644 --- a/image.go +++ b/image.go @@ -13,51 +13,101 @@ const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" CreateImageSize1024x1024 = "1024x1024" + // dall-e-3 supported only. CreateImageSize1792x1024 = "1792x1024" CreateImageSize1024x1792 = "1024x1792" + + // gpt-image-1 supported only. + CreateImageSize1536x1024 = "1536x1024" // Landscape + CreateImageSize1024x1536 = "1024x1536" // Portrait ) const ( - CreateImageResponseFormatURL = "url" + // dall-e-2 and dall-e-3 only. CreateImageResponseFormatB64JSON = "b64_json" + CreateImageResponseFormatURL = "url" ) const ( - CreateImageModelDallE2 = "dall-e-2" - CreateImageModelDallE3 = "dall-e-3" + CreateImageModelDallE2 = "dall-e-2" + CreateImageModelDallE3 = "dall-e-3" + CreateImageModelGptImage1 = "gpt-image-1" ) const ( CreateImageQualityHD = "hd" CreateImageQualityStandard = "standard" + + // gpt-image-1 only. + CreateImageQualityHigh = "high" + CreateImageQualityMedium = "medium" + CreateImageQualityLow = "low" ) const ( + // dall-e-3 only. CreateImageStyleVivid = "vivid" CreateImageStyleNatural = "natural" ) +const ( + // gpt-image-1 only. + CreateImageBackgroundTransparent = "transparent" + CreateImageBackgroundOpaque = "opaque" +) + +const ( + // gpt-image-1 only. + CreateImageModerationLow = "low" +) + +const ( + // gpt-image-1 only. + CreateImageOutputFormatPNG = "png" + CreateImageOutputFormatJPEG = "jpeg" + CreateImageOutputFormatWEBP = "webp" +) + // ImageRequest represents the request structure for the image API. type ImageRequest struct { - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Quality string `json:"quality,omitempty"` - Size string `json:"size,omitempty"` - Style string `json:"style,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + Style string `json:"style,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` + OutputCompression int `json:"output_compression,omitempty"` + OutputFormat string `json:"output_format,omitempty"` } // ImageResponse represents a response structure for image API. type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + Usage ImageResponseUsage `json:"usage,omitempty"` httpHeader } +// ImageResponseInputTokensDetails represents the token breakdown for input tokens. +type ImageResponseInputTokensDetails struct { + TextTokens int `json:"text_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` +} + +// ImageResponseUsage represents the token usage information for image API. +type ImageResponseUsage struct { + TotalTokens int `json:"total_tokens,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"` +} + // ImageResponseDataInner represents a response data structure for image API. type ImageResponseDataInner struct { URL string `json:"url,omitempty"` @@ -91,6 +141,8 @@ type ImageEditRequest struct { N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -159,6 +211,7 @@ type ImageVariRequest struct { N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. From 6aaa7322960741a84da11ac360516e4ec813dfff Mon Sep 17 00:00:00 2001 From: Justa Date: Tue, 13 May 2025 19:52:44 +0800 Subject: [PATCH 15/41] add ChatTemplateKwargs to ChatCompletionRequest (#980) Co-authored-by: Justa --- chat.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chat.go b/chat.go index 0f91d481c..c8a3e81b3 100644 --- a/chat.go +++ b/chat.go @@ -275,6 +275,11 @@ type ChatCompletionRequest struct { Metadata map[string]string `json:"metadata,omitempty"` // Configuration for a predicted output. Prediction *Prediction `json:"prediction,omitempty"` + // ChatTemplateKwargs provides a way to add non-standard parameters to the request body. + // Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} + // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes + ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` } type StreamOptions struct { From 4d2e7ab29d7bf853c740e9e63187fed960e6b0b4 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 13 May 2025 12:59:06 +0100 Subject: [PATCH 16/41] fix lint (#998) --- examples/images/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/images/main.go b/examples/images/main.go index 2bfeb7973..5ee649d22 100644 --- a/examples/images/main.go +++ b/examples/images/main.go @@ -25,4 +25,4 @@ func main() { return } fmt.Println(respUrl.Data[0].URL) -} \ No newline at end of file +} From 8c65b35c57ad4e9ba408def9bf9ff97817aab932 Mon Sep 17 00:00:00 2001 From: Axb12 <67110563+Axb12@users.noreply.github.com> Date: Tue, 20 May 2025 21:45:40 +0800 Subject: [PATCH 17/41] update image api *os.File to io.Reader (#994) * update image api *os.File to io.Reader * update code style * add reader test * supplementary reader test * update the reader in the form builder test * add commnet * update comment * update code style --- image.go | 43 ++++++++++++++++++----------------- image_test.go | 8 +++---- internal/form_builder.go | 35 ++++++++++++++++++++++++++-- internal/form_builder_test.go | 29 +++++++++++++++++++++++ 4 files changed, 88 insertions(+), 27 deletions(-) diff --git a/image.go b/image.go index d62622a35..72077ce41 100644 --- a/image.go +++ b/image.go @@ -3,8 +3,8 @@ package openai import ( "bytes" "context" + "io" "net/http" - "os" "strconv" ) @@ -134,15 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons // ImageEditRequest represents the request structure for the image API. type ImageEditRequest struct { - Image *os.File `json:"image,omitempty"` - Mask *os.File `json:"mask,omitempty"` - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Quality string `json:"quality,omitempty"` - User string `json:"user,omitempty"` + Image io.Reader `json:"image,omitempty"` + Mask io.Reader `json:"mask,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -150,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename is not required + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } // mask, it is optional if request.Mask != nil { - err = builder.CreateFormFile("mask", request.Mask) + // mask, filename is not required + err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return } @@ -206,12 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // ImageVariRequest represents the request structure for the image API. type ImageVariRequest struct { - Image *os.File `json:"image,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Image io.Reader `json:"image,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. @@ -220,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename is not required + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } diff --git a/image_test.go b/image_test.go index 9332dd5cd..644005515 100644 --- a/image_test.go +++ b/image_test.go @@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) { } mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } _, err := client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { + mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { if name == "mask" { return mockFailedErr } @@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) { req := ImageVariRequest{} mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } _, err := client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } diff --git a/internal/form_builder.go b/internal/form_builder.go index 2224fad45..1c2513dd9 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -4,8 +4,10 @@ import ( "fmt" "io" "mime/multipart" + "net/textproto" "os" - "path" + "path/filepath" + "strings" ) type FormBuilder interface { @@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er return fb.createFormFile(fieldname, file, file.Name()) } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +// CreateFormFileReader creates a form field with a file reader. +// The filename in parameters can be an empty string. +// The filename in Content-Disposition is required, But it can be an empty string. func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { - return fb.createFormFile(fieldname, r, path.Base(filename)) + h := make(textproto.MIMEHeader) + h.Set( + "Content-Disposition", + fmt.Sprintf( + `form-data; name="%s"; filename="%s"`, + escapeQuotes(fieldname), + escapeQuotes(filepath.Base(filename)), + ), + ) + + fieldWriter, err := fb.writer.CreatePart(h) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, r) + if err != nil { + return err + } + + return nil } func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 8df989e3b..76922c1ba 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) { checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") } + +type failingReader struct { +} + +var errMockFailingReaderError = errors.New("mock reader failed") + +func (*failingReader) Read([]byte) (int, error) { + return 0, errMockFailingReaderError +} + +func TestFormBuilderWithReader(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatalf("Error creating tmp file: %v", err) + } + defer file.Close() + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFileReader("file", file, file.Name()) + checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") + + builder = NewFormBuilder(&bytes.Buffer{}) + reader := &failingReader{} + err = builder.CreateFormFileReader("file", reader, "") + checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails") + + successReader := &bytes.Buffer{} + err = builder.CreateFormFileReader("file", successReader, "") + checks.NoError(t, err, "formbuilder should not return error") +} From ff9d83a4854790ecbd16e6328415a32c7497efaf Mon Sep 17 00:00:00 2001 From: "JT A." Date: Thu, 29 May 2025 04:31:35 -0600 Subject: [PATCH 18/41] skip json field (#1009) * skip json field * backfill some coverage and tests --- jsonschema/json.go | 7 ++++-- jsonschema/json_test.go | 47 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index d458418f3..03bb68891 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -126,9 +126,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { } jsonTag := field.Tag.Get("json") var required = true - if jsonTag == "" { + switch { + case jsonTag == "-": + continue + case jsonTag == "": jsonTag = field.Name - } else if strings.HasSuffix(jsonTag, ",omitempty") { + case strings.HasSuffix(jsonTag, ",omitempty"): jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") required = false } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 17f0aba8a..84f25fa85 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -329,6 +329,53 @@ func TestStructToSchema(t *testing.T) { "additionalProperties":false }`, }, + { + name: "Test with exclude mark", + in: struct { + Name string `json:"-"` + }{ + Name: "Name", + }, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with no json tag", + in: struct { + Name string + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "Name":{ + "type":"string" + } + }, + "required":["Name"], + "additionalProperties":false + }`, + }, + { + name: "Test with omitempty tag", + in: struct { + Name string `json:"name,omitempty"` + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, } for _, tt := range tests { From d7dca83beda99528c974392a8295630775c1c197 Mon Sep 17 00:00:00 2001 From: Axb12 <67110563+Axb12@users.noreply.github.com> Date: Tue, 17 Jun 2025 03:08:14 +0800 Subject: [PATCH 19/41] fix image api missing filename bug (#1017) * fix image api missing filename bug * add test * add test * update test --- image.go | 32 +++++++++++++++++++++++++++++--- internal/form_builder.go | 17 +++++++++++++++-- internal/form_builder_test.go | 18 ++++++++++++++++++ 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/image.go b/image.go index 72077ce41..84b9daf02 100644 --- a/image.go +++ b/image.go @@ -132,7 +132,32 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons return } +// WrapReader wraps an io.Reader with filename and Content-type. +func WrapReader(rdr io.Reader, filename string, contentType string) io.Reader { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Name() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} + // ImageEditRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageEditRequest struct { Image io.Reader `json:"image,omitempty"` Mask io.Reader `json:"mask,omitempty"` @@ -150,7 +175,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required + // image, filename verification can be postponed err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return @@ -158,7 +183,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // mask, it is optional if request.Mask != nil { - // mask, filename is not required + // filename verification can be postponed err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return @@ -206,6 +231,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) } // ImageVariRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageVariRequest struct { Image io.Reader `json:"image,omitempty"` Model string `json:"model,omitempty"` @@ -221,7 +247,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required + // image, filename verification can be postponed err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return diff --git a/internal/form_builder.go b/internal/form_builder.go index 1c2513dd9..5b382df20 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -39,9 +39,18 @@ func escapeQuotes(s string) string { } // CreateFormFileReader creates a form field with a file reader. -// The filename in parameters can be an empty string. -// The filename in Content-Disposition is required, But it can be an empty string. +// The filename in Content-Disposition is required. func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + if filename == "" { + if f, ok := r.(interface{ Name() string }); ok { + filename = f.Name() + } + } + var contentType string + if f, ok := r.(interface{ ContentType() string }); ok { + contentType = f.ContentType() + } + h := make(textproto.MIMEHeader) h.Set( "Content-Disposition", @@ -51,6 +60,10 @@ func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader escapeQuotes(filepath.Base(filename)), ), ) + // content type is optional, but it can be set + if contentType != "" { + h.Set("Content-Type", contentType) + } fieldWriter, err := fb.writer.CreatePart(h) if err != nil { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 76922c1ba..f4958ad5e 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,6 +1,8 @@ package openai //nolint:testpackage // testing private field import ( + "io" + "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -53,6 +55,18 @@ func (*failingReader) Read([]byte) (int, error) { return 0, errMockFailingReaderError } +type readerWithNameAndContentType struct { + io.Reader +} + +func (*readerWithNameAndContentType) Name() string { + return "" +} + +func (*readerWithNameAndContentType) ContentType() string { + return "image/png" +} + func TestFormBuilderWithReader(t *testing.T) { file, err := os.CreateTemp(t.TempDir(), "") if err != nil { @@ -71,4 +85,8 @@ func TestFormBuilderWithReader(t *testing.T) { successReader := &bytes.Buffer{} err = builder.CreateFormFileReader("file", successReader, "") checks.NoError(t, err, "formbuilder should not return error") + + rnc := &readerWithNameAndContentType{Reader: &bytes.Buffer{}} + err = builder.CreateFormFileReader("file", rnc, "") + checks.NoError(t, err, "formbuilder should not return error") } From a931bf7e85af9d39af414de21d907b6835281519 Mon Sep 17 00:00:00 2001 From: Whitea <145814986+Whitea029@users.noreply.github.com> Date: Tue, 17 Jun 2025 18:00:15 +0800 Subject: [PATCH 20/41] test: enhance error accumulator and form builder tests, add marshaller tests (#999) * test: enhance error accumulator and form builder tests, add marshaller tests * test: fix some issue form golangci-lint * test: gofmt form builder test * fix * fix * fix lint --- internal/error_accumulator_test.go | 43 +++++++----------- internal/form_builder.go | 3 ++ internal/form_builder_test.go | 73 +++++++++++++++++++++++++++++- internal/marshaller_test.go | 34 ++++++++++++++ internal/unmarshaler_test.go | 37 +++++++++++++++ 5 files changed, 163 insertions(+), 27 deletions(-) create mode 100644 internal/marshaller_test.go create mode 100644 internal/unmarshaler_test.go diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go index d48f28177..3fa9d7714 100644 --- a/internal/error_accumulator_test.go +++ b/internal/error_accumulator_test.go @@ -1,41 +1,32 @@ package openai_test import ( - "bytes" - "errors" "testing" - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" ) -func TestErrorAccumulatorBytes(t *testing.T) { - accumulator := &utils.DefaultErrorAccumulator{ - Buffer: &bytes.Buffer{}, +func TestDefaultErrorAccumulator_WriteMultiple(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") } + checks.NoError(t, ea.Write([]byte("{\"error\": \"test1\"}"))) + checks.NoError(t, ea.Write([]byte("{\"error\": \"test2\"}"))) - errBytes := accumulator.Bytes() - if len(errBytes) != 0 { - t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) - } - - err := accumulator.Write([]byte("{}")) - if err != nil { - t.Fatalf("%+v", err) - } - - errBytes = accumulator.Bytes() - if len(errBytes) == 0 { - t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) + expected := "{\"error\": \"test1\"}{\"error\": \"test2\"}" + if string(ea.Bytes()) != expected { + t.Fatalf("Expected %q, got %q", expected, ea.Bytes()) } } -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &utils.DefaultErrorAccumulator{ - Buffer: &test.FailingErrorBuffer{}, +func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") } - err := accumulator.Write([]byte("{")) - if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) + if len(ea.Bytes()) != 0 { + t.Fatal("Buffer should be empty initially") } } diff --git a/internal/form_builder.go b/internal/form_builder.go index 5b382df20..a17e820c0 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -97,6 +97,9 @@ func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, file } func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { + if fieldname == "" { + return fmt.Errorf("fieldname cannot be empty") + } return fb.writer.WriteField(fieldname, value) } diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index f4958ad5e..ddd6b8448 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,16 +1,57 @@ package openai //nolint:testpackage // testing private field import ( + "errors" "io" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" - "errors" "os" "testing" ) +type mockFormBuilder struct { + mockCreateFormFile func(string, *os.File) error + mockWriteField func(string, string) error + mockClose func() error +} + +func (m *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + return m.mockCreateFormFile(fieldname, file) +} + +func (m *mockFormBuilder) WriteField(fieldname, value string) error { + return m.mockWriteField(fieldname, value) +} + +func (m *mockFormBuilder) Close() error { + return m.mockClose() +} + +func (m *mockFormBuilder) FormDataContentType() string { + return "" +} + +func TestCloseMethod(t *testing.T) { + t.Run("NormalClose", func(t *testing.T) { + body := &bytes.Buffer{} + builder := NewFormBuilder(body) + checks.NoError(t, builder.Close(), "正常关闭应成功") + }) + + t.Run("ErrorPropagation", func(t *testing.T) { + errorMock := errors.New("mock close error") + mockBuilder := &mockFormBuilder{ + mockClose: func() error { + return errorMock + }, + } + err := mockBuilder.Close() + checks.ErrorIs(t, err, errorMock, "应传递关闭错误") + }) +} + type failingWriter struct { } @@ -90,3 +131,33 @@ func TestFormBuilderWithReader(t *testing.T) { err = builder.CreateFormFileReader("file", rnc, "") checks.NoError(t, err, "formbuilder should not return error") } + +func TestFormDataContentType(t *testing.T) { + t.Run("ReturnsUnderlyingWriterContentType", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + contentType := builder.FormDataContentType() + if contentType == "" { + t.Errorf("expected non-empty content type, got empty string") + } + }) +} + +func TestWriteField(t *testing.T) { + t.Run("EmptyFieldNameShouldReturnError", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("", "some value") + checks.HasError(t, err, "fieldname is required") + }) + + t.Run("ValidFieldNameShouldSucceed", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("key", "value") + checks.NoError(t, err, "should write field without error") + }) +} diff --git a/internal/marshaller_test.go b/internal/marshaller_test.go new file mode 100644 index 000000000..70694faed --- /dev/null +++ b/internal/marshaller_test.go @@ -0,0 +1,34 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONMarshaller_Normal(t *testing.T) { + jm := &openai.JSONMarshaller{} + data := map[string]string{"key": "value"} + + b, err := jm.Marshal(data) + checks.NoError(t, err) + if len(b) == 0 { + t.Fatal("should return non-empty bytes") + } +} + +func TestJSONMarshaller_InvalidInput(t *testing.T) { + jm := &openai.JSONMarshaller{} + _, err := jm.Marshal(make(chan int)) + checks.HasError(t, err, "should return error for unsupported type") +} + +func TestJSONMarshaller_EmptyValue(t *testing.T) { + jm := &openai.JSONMarshaller{} + b, err := jm.Marshal(nil) + checks.NoError(t, err) + if string(b) != "null" { + t.Fatalf("unexpected marshaled value: %s", string(b)) + } +} diff --git a/internal/unmarshaler_test.go b/internal/unmarshaler_test.go new file mode 100644 index 000000000..d63eac779 --- /dev/null +++ b/internal/unmarshaler_test.go @@ -0,0 +1,37 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONUnmarshaler_Normal(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{"key":"value"}`) + var v map[string]string + + err := jm.Unmarshal(data, &v) + checks.NoError(t, err) + if v["key"] != "value" { + t.Fatal("unmarshal result mismatch") + } +} + +func TestJSONUnmarshaler_InvalidJSON(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{invalid}`) + var v map[string]interface{} + + err := jm.Unmarshal(data, &v) + checks.HasError(t, err, "should return error for invalid JSON") +} + +func TestJSONUnmarshaler_EmptyInput(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + var v interface{} + + err := jm.Unmarshal(nil, &v) + checks.HasError(t, err, "should return error for nil input") +} From c125ae2ad7b239355161c7f4260a3743ec0c182b Mon Sep 17 00:00:00 2001 From: Hritik Raj Date: Wed, 25 Jun 2025 15:34:45 +0530 Subject: [PATCH 21/41] Fix for removing usage in every stream chunk response. (#1022) * Fix for https://github.com/sashabaranov/go-openai/issues/1021: 1. Make Usage field in completions Response to pointer. * Fix for https://github.com/sashabaranov/go-openai/issues/1021: 1. Make Usage field in completions Response to pointer. 2. Add omitempty to json tag Signed-off-by: Hritik003 --------- Signed-off-by: Hritik003 --- completion.go | 2 +- completion_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 21d4897c4..02ce7b016 100644 --- a/completion.go +++ b/completion.go @@ -242,7 +242,7 @@ type CompletionResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []CompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + Usage *Usage `json:"usage,omitempty"` httpHeader } diff --git a/completion_test.go b/completion_test.go index 27e2d150e..f0ead0d63 100644 --- a/completion_test.go +++ b/completion_test.go @@ -192,7 +192,7 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } inputTokens *= n completionTokens := completionReq.MaxTokens * len(prompts) * n - res.Usage = openai.Usage{ + res.Usage = &openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, From 8e9b2ac83ab2d15a34982d689d97ec5db34cbb06 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Tue, 8 Jul 2025 19:19:49 +0800 Subject: [PATCH 22/41] fix: properly unmarshal JSON schema in ChatCompletionResponseFormatJSONSchema.schema (#1028) * feat: #1027 * add tests * feat: #1027 * feat: #1027 * feat: #1027 * update chat_test.go * feat: #1027 * add test cases --- chat.go | 27 ++++++++++ chat_test.go | 139 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+) diff --git a/chat.go b/chat.go index c8a3e81b3..b4a0ad90f 100644 --- a/chat.go +++ b/chat.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "net/http" + + "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -221,6 +223,31 @@ type ChatCompletionResponseFormatJSONSchema struct { Strict bool `json:"strict"` } +func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error { + type rawJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.RawMessage `json:"schema"` + Strict bool `json:"strict"` + } + var raw rawJSONSchema + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Name = raw.Name + r.Description = raw.Description + r.Strict = raw.Strict + if len(raw.Schema) > 0 && string(raw.Schema) != "null" { + var d jsonschema.Definition + err := json.Unmarshal(raw.Schema, &d) + if err != nil { + return err + } + r.Schema = &d + } + return nil +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { Model string `json:"model"` diff --git a/chat_test.go b/chat_test.go index 514706c96..172ce0740 100644 --- a/chat_test.go +++ b/chat_test.go @@ -946,3 +946,142 @@ func TestFinishReason(t *testing.T) { } } } + +func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": null + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`[123,456]`), + }, + true, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": 123456 + }`), + }, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r openai.ChatCompletionResponseFormatJSONSchema + err := r.UnmarshalJSON(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) { + type args struct { + bs []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{bs: []byte(`{ + "model": "llama3-1b", + "messages": [ + { "role": "system", "content": "You are a helpful math tutor." }, + { "role": "user", "content": "solve 8x + 31 = 2" } + ], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + } + } +}`)}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var m openai.ChatCompletionRequest + err := json.Unmarshal(tt.args.bs, &m) + if err != nil { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} From bd612cebceb5d84c423bece9e2f2766c7567f8ed Mon Sep 17 00:00:00 2001 From: Matt Tinsley <68241446+mathewtinsley@users.noreply.github.com> Date: Tue, 8 Jul 2025 04:23:41 -0700 Subject: [PATCH 23/41] Add support for Chat Completion Service Tier (#1023) * Add support for Chat Completion Service Tier * Add priority service tier --- chat.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/chat.go b/chat.go index b4a0ad90f..0f0c5b5d5 100644 --- a/chat.go +++ b/chat.go @@ -307,6 +307,8 @@ type ChatCompletionRequest struct { // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + // Specifies the latency tier to use for processing the request. + ServiceTier ServiceTier `json:"service_tier,omitempty"` } type StreamOptions struct { @@ -390,6 +392,15 @@ const ( FinishReasonNull FinishReason = "null" ) +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierPriority ServiceTier = "priority" +) + func (r FinishReason) MarshalJSON() ([]byte, error) { if r == FinishReasonNull || r == "" { return []byte("null"), nil @@ -422,6 +433,7 @@ type ChatCompletionResponse struct { Usage Usage `json:"usage"` SystemFingerprint string `json:"system_fingerprint"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` httpHeader } From c650976e492731e3fa758a2dbff2f743ae0e1c98 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 10 Jul 2025 18:35:16 +0800 Subject: [PATCH 24/41] Support $ref and $defs in JSON Schema (#1030) * support $ref and $defs in JSON Schema * update --- jsonschema/json.go | 36 +++++-- jsonschema/json_test.go | 70 ++++++++++++++ jsonschema/validate.go | 65 +++++++++++-- jsonschema/validate_test.go | 187 +++++++++++++++++++++++++++++++++++- 4 files changed, 340 insertions(+), 18 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index 03bb68891..29d15b409 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -48,6 +48,11 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` // Whether the schema is nullable or not. Nullable bool `json:"nullable,omitempty"` + + // Ref Reference to a definition in $defs or external schema. + Ref string `json:"$ref,omitempty"` + // Defs A map of reusable schema definitions. + Defs map[string]Definition `json:"$defs,omitempty"` } func (d *Definition) MarshalJSON() ([]byte, error) { @@ -67,10 +72,16 @@ func (d *Definition) Unmarshal(content string, v any) error { } func GenerateSchemaForType(v any) (*Definition, error) { - return reflectSchema(reflect.TypeOf(v)) + var defs = make(map[string]Definition) + def, err := reflectSchema(reflect.TypeOf(v), defs) + if err != nil { + return nil, err + } + def.Defs = defs + return def, nil } -func reflectSchema(t reflect.Type) (*Definition, error) { +func reflectSchema(t reflect.Type, defs map[string]Definition) (*Definition, error) { var d Definition switch t.Kind() { case reflect.String: @@ -84,21 +95,32 @@ func reflectSchema(t reflect.Type) (*Definition, error) { d.Type = Boolean case reflect.Slice, reflect.Array: d.Type = Array - items, err := reflectSchema(t.Elem()) + items, err := reflectSchema(t.Elem(), defs) if err != nil { return nil, err } d.Items = items case reflect.Struct: + if t.Name() != "" { + if _, ok := defs[t.Name()]; !ok { + defs[t.Name()] = Definition{} + object, err := reflectSchemaObject(t, defs) + if err != nil { + return nil, err + } + defs[t.Name()] = *object + } + return &Definition{Ref: "#/$defs/" + t.Name()}, nil + } d.Type = Object d.AdditionalProperties = false - object, err := reflectSchemaObject(t) + object, err := reflectSchemaObject(t, defs) if err != nil { return nil, err } d = *object case reflect.Ptr: - definition, err := reflectSchema(t.Elem()) + definition, err := reflectSchema(t.Elem(), defs) if err != nil { return nil, err } @@ -112,7 +134,7 @@ func reflectSchema(t reflect.Type) (*Definition, error) { return &d, nil } -func reflectSchemaObject(t reflect.Type) (*Definition, error) { +func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definition, error) { var d = Definition{ Type: Object, AdditionalProperties: false, @@ -136,7 +158,7 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { required = false } - item, err := reflectSchema(field.Type) + item, err := reflectSchema(field.Type, defs) if err != nil { return nil, err } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 84f25fa85..31b54ed1a 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -183,6 +183,17 @@ func TestDefinition_MarshalJSON(t *testing.T) { } func TestStructToSchema(t *testing.T) { + type Tweet struct { + Text string `json:"text"` + } + + type Person struct { + Name string `json:"name,omitempty"` + Age int `json:"age,omitempty"` + Friends []Person `json:"friends,omitempty"` + Tweets []Tweet `json:"tweets,omitempty"` + } + tests := []struct { name string in any @@ -376,6 +387,65 @@ func TestStructToSchema(t *testing.T) { "additionalProperties":false }`, }, + { + name: "Test with $ref and $defs", + in: struct { + Person Person `json:"person"` + Tweets []Tweet `json:"tweets"` + }{}, + want: `{ + "type" : "object", + "properties" : { + "person" : { + "$ref" : "#/$defs/Person" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "required" : [ "person", "tweets" ], + "additionalProperties" : false, + "$defs" : { + "Person" : { + "type" : "object", + "properties" : { + "age" : { + "type" : "integer" + }, + "friends" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Person" + } + }, + "name" : { + "type" : "string" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "additionalProperties" : false + }, + "Tweet" : { + "type" : "object", + "properties" : { + "text" : { + "type" : "string" + } + }, + "required" : [ "text" ], + "additionalProperties" : false + } + } +}`, + }, } for _, tt := range tests { diff --git a/jsonschema/validate.go b/jsonschema/validate.go index 49f9b8859..1bd2f809c 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -5,26 +5,68 @@ import ( "errors" ) +func CollectDefs(def Definition) map[string]Definition { + result := make(map[string]Definition) + collectDefsRecursive(def, result, "#") + return result +} + +func collectDefsRecursive(def Definition, result map[string]Definition, prefix string) { + for k, v := range def.Defs { + path := prefix + "/$defs/" + k + result[path] = v + collectDefsRecursive(v, result, path) + } + for k, sub := range def.Properties { + collectDefsRecursive(sub, result, prefix+"/properties/"+k) + } + if def.Items != nil { + collectDefsRecursive(*def.Items, result, prefix) + } +} + func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { var data any err := json.Unmarshal(content, &data) if err != nil { return err } - if !Validate(schema, data) { + if !Validate(schema, data, WithDefs(CollectDefs(schema))) { return errors.New("data validation failed against the provided schema") } return json.Unmarshal(content, &v) } -func Validate(schema Definition, data any) bool { +type validateArgs struct { + Defs map[string]Definition +} + +type ValidateOption func(*validateArgs) + +func WithDefs(defs map[string]Definition) ValidateOption { + return func(option *validateArgs) { + option.Defs = defs + } +} + +func Validate(schema Definition, data any, opts ...ValidateOption) bool { + args := validateArgs{} + for _, opt := range opts { + opt(&args) + } + if len(opts) == 0 { + args.Defs = CollectDefs(schema) + } switch schema.Type { case Object: - return validateObject(schema, data) + return validateObject(schema, data, args.Defs) case Array: - return validateArray(schema, data) + return validateArray(schema, data, args.Defs) case String: - _, ok := data.(string) + v, ok := data.(string) + if ok && len(schema.Enum) > 0 { + return contains(schema.Enum, v) + } return ok case Number: // float64 and int _, ok := data.(float64) @@ -45,11 +87,16 @@ func Validate(schema Definition, data any) bool { case Null: return data == nil default: + if schema.Ref != "" && args.Defs != nil { + if v, ok := args.Defs[schema.Ref]; ok { + return Validate(v, data, WithDefs(args.Defs)) + } + } return false } } -func validateObject(schema Definition, data any) bool { +func validateObject(schema Definition, data any, defs map[string]Definition) bool { dataMap, ok := data.(map[string]any) if !ok { return false @@ -61,7 +108,7 @@ func validateObject(schema Definition, data any) bool { } for key, valueSchema := range schema.Properties { value, exists := dataMap[key] - if exists && !Validate(valueSchema, value) { + if exists && !Validate(valueSchema, value, WithDefs(defs)) { return false } else if !exists && contains(schema.Required, key) { return false @@ -70,13 +117,13 @@ func validateObject(schema Definition, data any) bool { return true } -func validateArray(schema Definition, data any) bool { +func validateArray(schema Definition, data any, defs map[string]Definition) bool { dataArray, ok := data.([]any) if !ok { return false } for _, item := range dataArray { - if !Validate(*schema.Items, item) { + if !Validate(*schema.Items, item, WithDefs(defs)) { return false } } diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index 6fa30ab0c..aefdf4069 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -1,6 +1,7 @@ package jsonschema_test import ( + "reflect" "testing" "github.com/sashabaranov/go-openai/jsonschema" @@ -70,6 +71,96 @@ func Test_Validate(t *testing.T) { }, Required: []string{"string"}, }}, false}, + { + "test schema with ref and defs", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "male", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, true}, + { + "test enum invalid value", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "other", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -156,8 +247,100 @@ func TestUnmarshal(t *testing.T) { err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) - } else if err == nil { - t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +} + +func TestCollectDefs(t *testing.T) { + type args struct { + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want map[string]jsonschema.Definition + }{ + { + "test collect defs", + args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + }, + map[string]jsonschema.Definition{ + "#/$defs/Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "#/$defs/Person/$defs/Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + "#/$defs/Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := jsonschema.CollectDefs(tt.args.schema) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CollectDefs() = %v, want %v", got, tt.want) } }) } From 1bf77f6fd6f10b9cd80e1dfdedd616f3f1712e20 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 10 Jul 2025 21:49:20 +0100 Subject: [PATCH 25/41] Improve unit test coverage (#1032) * Add unit tests to improve coverage * Fix type assertion checks in tests --- image_test.go | 50 ++++++++++++++++++++++++++++++ internal/error_accumulator_test.go | 7 +++++ internal/form_builder_test.go | 14 +++++++++ internal/request_builder_test.go | 27 ++++++++++++++++ 4 files changed, 98 insertions(+) diff --git a/image_test.go b/image_test.go index 644005515..bb9a086fd 100644 --- a/image_test.go +++ b/image_test.go @@ -4,6 +4,7 @@ import ( utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test/checks" + "bytes" "context" "fmt" "io" @@ -156,3 +157,52 @@ func TestVariImageFormBuilderFailures(t *testing.T) { _, err = client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") } + +type testNamedReader struct{ io.Reader } + +func (testNamedReader) Name() string { return "named.txt" } + +func TestWrapReader(t *testing.T) { + r := bytes.NewBufferString("data") + wrapped := WrapReader(r, "file.png", "image/png") + f, ok := wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped reader missing Name or ContentType") + } + if f.Name() != "file.png" { + t.Fatalf("expected name file.png, got %s", f.Name()) + } + if f.ContentType() != "image/png" { + t.Fatalf("expected content type image/png, got %s", f.ContentType()) + } + + // test name from underlying reader + nr := testNamedReader{Reader: bytes.NewBufferString("d")} + wrapped = WrapReader(nr, "", "text/plain") + f, ok = wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped named reader missing Name or ContentType") + } + if f.Name() != "named.txt" { + t.Fatalf("expected name named.txt, got %s", f.Name()) + } + if f.ContentType() != "text/plain" { + t.Fatalf("expected content type text/plain, got %s", f.ContentType()) + } + + // no name provided + wrapped = WrapReader(bytes.NewBuffer(nil), "", "") + f2, ok := wrapped.(interface{ Name() string }) + if !ok { + t.Fatal("wrapped anonymous reader missing Name") + } + if f2.Name() != "" { + t.Fatalf("expected empty name, got %s", f2.Name()) + } +} diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go index 3fa9d7714..f6c226c5e 100644 --- a/internal/error_accumulator_test.go +++ b/internal/error_accumulator_test.go @@ -4,6 +4,7 @@ import ( "testing" openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -30,3 +31,9 @@ func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) { t.Fatal("Buffer should be empty initially") } } + +func TestDefaultErrorAccumulator_WriteError(t *testing.T) { + ea := &openai.DefaultErrorAccumulator{Buffer: &test.FailingErrorBuffer{}} + err := ea.Write([]byte("fail")) + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Write should propagate buffer errors") +} diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index ddd6b8448..1cc82ab8a 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -161,3 +161,17 @@ func TestWriteField(t *testing.T) { checks.NoError(t, err, "should write field without error") }) } + +func TestCreateFormFile(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "") + if err == nil { + t.Fatal("expected error for empty filename") + } + + builder = NewFormBuilder(&failingWriter{}) + err = builder.createFormFile("file", bytes.NewBufferString("data"), "name") + checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error") +} diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e26022a6b..1561b87fb 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "io" "net/http" "reflect" "testing" @@ -59,3 +60,29 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { t.Errorf("Build() got = %v, want %v", got, want) } } + +func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) { + b := NewRequestBuilder() + ctx := context.Background() + method := http.MethodPost + url := "/reader" + bodyContent := "hello" + body := bytes.NewBufferString(bodyContent) + header := http.Header{"X-Test": []string{"val"}} + + req, err := b.Build(ctx, method, url, body, header) + if err != nil { + t.Fatalf("Build returned error: %v", err) + } + + gotBody, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("cannot read body: %v", err) + } + if string(gotBody) != bodyContent { + t.Fatalf("expected body %q, got %q", bodyContent, string(gotBody)) + } + if req.Header.Get("X-Test") != "val" { + t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test")) + } +} From a0c185f3628cd616fb9d1648700fc51c40c42dda Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 11 Jul 2025 19:17:33 +0800 Subject: [PATCH 26/41] Removed root $ref from GenerateSchemaForType (#1033) * support $ref and $defs in JSON Schema * update * removed root $ref from JSON Schema * Update json.go * Update json_test.go * Update jsonschema/json.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update jsonschema/json.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- jsonschema/json.go | 44 ++++++++++ jsonschema/json_test.go | 174 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 218 insertions(+) diff --git a/jsonschema/json.go b/jsonschema/json.go index 29d15b409..75e3b5173 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -77,6 +77,27 @@ func GenerateSchemaForType(v any) (*Definition, error) { if err != nil { return nil, err } + // If the schema has a root $ref, resolve it by: + // 1. Extracting the key from the $ref. + // 2. Detaching the referenced definition from $defs. + // 3. Checking for self-references in the detached definition. + // - If a self-reference is found, restore the original $defs structure. + // 4. Flattening the referenced definition into the root schema. + // 5. Clearing the $ref field in the root schema. + if def.Ref != "" { + origRef := def.Ref + key := strings.TrimPrefix(origRef, "#/$defs/") + if root, ok := defs[key]; ok { + delete(defs, key) + root.Defs = defs + if containsRef(root, origRef) { + root.Defs = nil + defs[key] = root + } + *def = root + } + def.Ref = "" + } def.Defs = defs return def, nil } @@ -189,3 +210,26 @@ func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definitio d.Properties = properties return &d, nil } + +func containsRef(def Definition, targetRef string) bool { + if def.Ref == targetRef { + return true + } + + for _, d := range def.Defs { + if containsRef(d, targetRef) { + return true + } + } + + for _, prop := range def.Properties { + if containsRef(prop, targetRef) { + return true + } + } + + if def.Items != nil && containsRef(*def.Items, targetRef) { + return true + } + return false +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 31b54ed1a..34f5d88eb 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -182,6 +182,18 @@ func TestDefinition_MarshalJSON(t *testing.T) { } } +type User struct { + ID int `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Orders []*Order `json:"orders,omitempty"` +} + +type Order struct { + ID int `json:"id,omitempty"` + Amount float64 `json:"amount,omitempty"` + Buyer *User `json:"buyer,omitempty"` +} + func TestStructToSchema(t *testing.T) { type Tweet struct { Text string `json:"text"` @@ -194,6 +206,13 @@ func TestStructToSchema(t *testing.T) { Tweets []Tweet `json:"tweets,omitempty"` } + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + tests := []struct { name string in any @@ -444,6 +463,161 @@ func TestStructToSchema(t *testing.T) { "additionalProperties" : false } } +}`, + }, + { + name: "Test Person", + in: Person{}, + want: `{ + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false, + "$defs": { + "Person": { + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false + }, + "Tweet": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + }, + "required": [ + "text" + ], + "additionalProperties": false + } + } +}`, + }, + { + name: "Test MyStructuredResponse", + in: MyStructuredResponse{}, + want: `{ + "type": "object", + "properties": { + "camel_case": { + "type": "string", + "description": "CamelCase" + }, + "kebab_case": { + "type": "string", + "description": "KebabCase" + }, + "pascal_case": { + "type": "string", + "description": "PascalCase" + }, + "snake_case": { + "type": "string", + "description": "SnakeCase" + } + }, + "required": [ + "pascal_case", + "camel_case", + "kebab_case", + "snake_case" + ], + "additionalProperties": false +}`, + }, + { + name: "Test User", + in: User{}, + want: `{ + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false, + "$defs": { + "Order": { + "type": "object", + "properties": { + "amount": { + "type": "number" + }, + "buyer": { + "$ref": "#/$defs/User" + }, + "id": { + "type": "integer" + } + }, + "additionalProperties": false + }, + "User": { + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false + } + } }`, }, } From 575aff439644c3b6460f2617b81b119044bfc35c Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 13:15:15 +0100 Subject: [PATCH 27/41] Add tests for form and request builders (#1036) --- internal/form_builder_test.go | 13 +++++++++++++ internal/request_builder_test.go | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 1cc82ab8a..53ef11d23 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -8,6 +8,7 @@ import ( "bytes" "os" + "strings" "testing" ) @@ -175,3 +176,15 @@ func TestCreateFormFile(t *testing.T) { err = builder.createFormFile("file", bytes.NewBufferString("data"), "name") checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error") } + +func TestCreateFormFileSuccess(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "foo.txt") + checks.NoError(t, err, "createFormFile should succeed") + + if !strings.Contains(buf.String(), "filename=\"foo.txt\"") { + t.Fatalf("expected filename header, got %q", buf.String()) + } +} diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index 1561b87fb..adccb158e 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -86,3 +86,11 @@ func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) { t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test")) } } + +func TestRequestBuilderInvalidURL(t *testing.T) { + b := NewRequestBuilder() + _, err := b.Build(context.Background(), http.MethodGet, ":", nil, nil) + if err == nil { + t.Fatal("expected error for invalid URL") + } +} From 88eb1df90bb6af6a1f4d471a08d9f8a13dd5a8e5 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 15:58:52 +0100 Subject: [PATCH 28/41] Ignore codecov coverage for examples and internal/test (#1038) --- .codecov.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..81773666c --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +coverage: + ignore: + - "examples/**" + - "internal/test/**" From 8d681e7f9a8f172168199f29e8e1f16701d6817a Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:18:23 +0100 Subject: [PATCH 29/41] Increase image.go test coverage to 100% (#1039) --- image_test.go | 303 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 209 insertions(+), 94 deletions(-) diff --git a/image_test.go b/image_test.go index bb9a086fd..c2c8f42dc 100644 --- a/image_test.go +++ b/image_test.go @@ -40,122 +40,237 @@ func (fb *mockFormBuilder) FormDataContentType() string { } func TestImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) utils.FormBuilder { - return mockBuilder - } ctx := context.Background() - - req := ImageEditRequest{ - Mask: &os.File{}, - } - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return mockFailedErr - } - _, err := client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { - if name == "mask" { - return mockFailedErr - } - return nil + newClient := func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return nil + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageEditRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "mask", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { + if name == "mask" { + return mockFailedErr + } + return nil + } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "prompt", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "prompt" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, } - var failForField string - mockBuilder.mockWriteField = func(fieldname, _ string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateEditImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") + }) } - failForField = "prompt" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "n" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "size" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "response_format" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr - } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + _, err := client.CreateEditImage(ctx, ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateEditImage should return error if request builder fails") + }) } func TestVariImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) utils.FormBuilder { - return mockBuilder - } ctx := context.Background() - - req := ImageVariRequest{} - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return mockFailedErr - } - _, err := client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return nil + newClient := func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - var failForField string - mockBuilder.mockWriteField = func(fieldname, _ string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageVariRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, } - failForField = "n" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - - failForField = "size" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateVariImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + }) + } - failForField = "response_format" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr - } - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + _, err := client.CreateVariImage(ctx, ImageVariRequest{Image: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateVariImage should return error if request builder fails") + }) } type testNamedReader struct{ io.Reader } From 8665ad7264b64cd2046c2a2d4addf8617c6fffea Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:31:09 +0100 Subject: [PATCH 30/41] Increase jsonschema test coverage (#1040) --- jsonschema/json_additional_test.go | 73 ++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 jsonschema/json_additional_test.go diff --git a/jsonschema/json_additional_test.go b/jsonschema/json_additional_test.go new file mode 100644 index 000000000..70cf37490 --- /dev/null +++ b/jsonschema/json_additional_test.go @@ -0,0 +1,73 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// Test Definition.Unmarshal, including success path, validation error, +// JSON syntax error and type mismatch during unmarshalling. +func TestDefinitionUnmarshal(t *testing.T) { + schema := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + }, + } + + var dst struct { + Name string `json:"name"` + } + if err := schema.Unmarshal(`{"name":"foo"}`, &dst); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dst.Name != "foo" { + t.Errorf("expected name to be foo, got %q", dst.Name) + } + + if err := schema.Unmarshal(`{`, &dst); err == nil { + t.Error("expected error for malformed json") + } + + if err := schema.Unmarshal(`{"name":1}`, &dst); err == nil { + t.Error("expected validation error") + } + + numSchema := jsonschema.Definition{Type: jsonschema.Number} + var s string + if err := numSchema.Unmarshal(`123`, &s); err == nil { + t.Error("expected unmarshal type error") + } +} + +// Ensure GenerateSchemaForType returns an error when encountering unsupported types. +func TestGenerateSchemaForTypeUnsupported(t *testing.T) { + type Bad struct { + Ch chan int `json:"ch"` + } + if _, err := jsonschema.GenerateSchemaForType(Bad{}); err == nil { + t.Fatal("expected error for unsupported type") + } +} + +// Validate should fail when provided data does not match the expected container types. +func TestValidateInvalidContainers(t *testing.T) { + objSchema := jsonschema.Definition{Type: jsonschema.Object} + if jsonschema.Validate(objSchema, 1) { + t.Error("expected object validation to fail for non-map input") + } + + arrSchema := jsonschema.Definition{Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}} + if jsonschema.Validate(arrSchema, 1) { + t.Error("expected array validation to fail for non-slice input") + } +} + +// Validate should return false when $ref cannot be resolved. +func TestValidateRefNotFound(t *testing.T) { + refSchema := jsonschema.Definition{Ref: "#/$defs/Missing"} + if jsonschema.Validate(refSchema, "data", jsonschema.WithDefs(map[string]jsonschema.Definition{})) { + t.Error("expected validation to fail when reference is missing") + } +} From 1e912177886f37ae0c4159982a962d9ef8ff921b Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:48:49 +0100 Subject: [PATCH 31/41] test: cover CreateFile request builder failure (#1041) --- files_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/files_test.go b/files_test.go index 3c1b99fb4..486ef892e 100644 --- a/files_test.go +++ b/files_test.go @@ -121,3 +121,20 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } +func TestCreateFileRequestBuilderFailure(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return &mockFormBuilder{ + mockWriteField: func(string, string) error { return nil }, + mockCreateFormFile: func(string, *os.File) error { return nil }, + mockClose: func() error { return nil }, + } + } + + _, err := client.CreateFile(context.Background(), FileRequest{FilePath: "client.go"}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateFile should return error if request builder fails") +} From 3bb1014fa7e8d09fa6ff31bc6bce6172a3bf59ae Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:11:07 +0100 Subject: [PATCH 32/41] ci: enable version compatibility vet (#1042) --- .github/workflows/pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 18c720f03..2c9730656 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -16,7 +16,7 @@ jobs: go-version: '1.24' - name: Run vet run: | - go vet . + go vet -stdversion ./... - name: Run golangci-lint uses: golangci/golangci-lint-action@v7 with: From bd36c45dc505d3592ddfc0d53c06561fe8a3dacb Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Fri, 11 Jul 2025 21:45:53 +0530 Subject: [PATCH 33/41] Support for extra_body parameter for embeddings API (#906) * support for extra_body parameter for embeddings API * done linting * added unit tests * improved code coverage and removed unnecessary checks * test cleanup * updated body map creation code * code coverage * minor change * updated testcase comment --- client.go | 14 ++++++++++++++ embeddings.go | 32 +++++++++++++++++++++++++++++++- embeddings_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index cef375348..413b8db03 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,20 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + // Assert that args.body is a map[string]any. + bodyMap, ok := args.body.(map[string]any) + if ok { + // If it's a map[string]any then only add extraBody + // fields to args.body otherwise keep only fields in request struct. + for key, value := range extraBody { + bodyMap[key] = value + } + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) diff --git a/embeddings.go b/embeddings.go index 4a0e682da..8593f8b5b 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/binary" + "encoding/json" "errors" "math" "net/http" @@ -160,6 +161,9 @@ type EmbeddingRequest struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() + + // The body map is used to dynamically construct the request payload for the embedding API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + extraBody := baseReq.ExtraBody + baseReq.ExtraBody = nil + + // Serialize baseReq to JSON + jsonData, err := json.Marshal(baseReq) + if err != nil { + return + } + + // Deserialize JSON to map[string]any + var body map[string]any + _ = json.Unmarshal(jsonData, &body) + req, err := c.newRequest( ctx, http.MethodPost, c.fullURL("/embeddings", withModel(string(baseReq.Model))), - withBody(baseReq), + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. ) if err != nil { return diff --git a/embeddings_test.go b/embeddings_test.go index 438978169..07f1262cb 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) { t.Fatalf("Expected embedding request to contain model field") } + // test embedding request with strings and extra_body param + embeddingReqWithExtraBody := openai.EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + } + marshaled, err = json.Marshal(embeddingReqWithExtraBody) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + // test embedding request with strings embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ @@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } - // test create embeddings with strings (simple embedding request) + // test create embeddings with strings (ExtraBody in request) + res, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + Dimensions: 1, + }, + ) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (ExtraBody in request and ) + _, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Input: make(chan int), // Channels are not serializable + Model: "example_model", + }, + ) + checks.HasError(t, err, "CreateEmbeddings error") + + // test failed (Serialize JSON error) res, err = client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ From e6c1d3e5bde0bae5966070ab4edee874b6c8c73f Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:39:39 +0100 Subject: [PATCH 34/41] Increase jsonschema test coverage (#1043) * test: expand jsonschema coverage * test: fix package name for containsref tests --- jsonschema/containsref_test.go | 48 ++++++++++++++++++++++++++++++++++ jsonschema/json_errors_test.go | 27 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 jsonschema/containsref_test.go create mode 100644 jsonschema/json_errors_test.go diff --git a/jsonschema/containsref_test.go b/jsonschema/containsref_test.go new file mode 100644 index 000000000..dc1842775 --- /dev/null +++ b/jsonschema/containsref_test.go @@ -0,0 +1,48 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// SelfRef struct used to produce a self-referential schema. +type SelfRef struct { + Friends []SelfRef `json:"friends"` +} + +// Address struct referenced by Person without self-reference. +type Address struct { + Street string `json:"street"` +} + +type Person struct { + Address Address `json:"address"` +} + +// TestGenerateSchemaForType_SelfRef ensures that self-referential types are not +// flattened during schema generation. +func TestGenerateSchemaForType_SelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(SelfRef{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["SelfRef"]; !ok { + t.Fatal("expected defs to contain SelfRef for self reference") + } +} + +// TestGenerateSchemaForType_NoSelfRef ensures that non-self-referential types +// are flattened and do not reappear in $defs. +func TestGenerateSchemaForType_NoSelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(Person{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["Person"]; ok { + t.Fatal("unexpected Person definition in defs") + } + if _, ok := schema.Defs["Address"]; !ok { + t.Fatal("expected Address definition in defs") + } +} diff --git a/jsonschema/json_errors_test.go b/jsonschema/json_errors_test.go new file mode 100644 index 000000000..3b864fc21 --- /dev/null +++ b/jsonschema/json_errors_test.go @@ -0,0 +1,27 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// TestGenerateSchemaForType_ErrorPaths verifies error handling for unsupported types. +func TestGenerateSchemaForType_ErrorPaths(t *testing.T) { + type anon struct{ Ch chan int } + tests := []struct { + name string + v any + }{ + {"slice", []chan int{}}, + {"anon struct", anon{}}, + {"pointer", (*chan int)(nil)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := jsonschema.GenerateSchemaForType(tt.v); err == nil { + t.Errorf("expected error for %s", tt.name) + } + }) + } +} From 181c0e8fd7358b1f37f902262b246d513acd1e29 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 18:29:23 +0100 Subject: [PATCH 35/41] Add tests for internal utilities (#1044) * Add unit tests for internal test utilities * Fix lint issues in internal tests --- internal/test/checks/checks_test.go | 19 +++++++ internal/test/failer_test.go | 24 +++++++++ internal/test/helpers_test.go | 54 +++++++++++++++++++ internal/test/server.go | 12 +++++ internal/test/server_test.go | 80 +++++++++++++++++++++++++++++ 5 files changed, 189 insertions(+) create mode 100644 internal/test/checks/checks_test.go create mode 100644 internal/test/failer_test.go create mode 100644 internal/test/helpers_test.go create mode 100644 internal/test/server_test.go diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go new file mode 100644 index 000000000..0677054df --- /dev/null +++ b/internal/test/checks/checks_test.go @@ -0,0 +1,19 @@ +package checks_test + +import ( + "errors" + "testing" + + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestChecksSuccessPaths(t *testing.T) { + checks.NoError(t, nil) + checks.NoErrorF(t, nil) + checks.HasError(t, errors.New("err")) + target := errors.New("x") + checks.ErrorIs(t, target, target) + checks.ErrorIsF(t, target, target, "msg") + checks.ErrorIsNot(t, errors.New("y"), target) + checks.ErrorIsNotf(t, errors.New("y"), target, "msg") +} diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go new file mode 100644 index 000000000..fb1f4bf06 --- /dev/null +++ b/internal/test/failer_test.go @@ -0,0 +1,24 @@ +//nolint:testpackage // need access to unexported fields and types for testing +package test + +import ( + "errors" + "testing" +) + +func TestFailingErrorBuffer(t *testing.T) { + buf := &FailingErrorBuffer{} + n, err := buf.Write([]byte("test")) + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err) + } + if n != 0 { + t.Fatalf("expected n=0, got %d", n) + } + if buf.Len() != 0 { + t.Fatalf("expected Len 0, got %d", buf.Len()) + } + if len(buf.Bytes()) != 0 { + t.Fatalf("expected empty bytes") + } +} diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go new file mode 100644 index 000000000..aa177679b --- /dev/null +++ b/internal/test/helpers_test.go @@ -0,0 +1,54 @@ +package test_test + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestCreateTestFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + internaltest.CreateTestFile(t, path) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read created file: %v", err) + } + if string(data) != "hello" { + t.Fatalf("unexpected file contents: %q", string(data)) + } +} + +func TestTokenRoundTripperAddsHeader(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+internaltest.GetTestToken() { + t.Fatalf("authorization header not set") + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := srv.Client() + client.Transport = &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: client.Transport} + + req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatalf("request error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client request error: %v", err) + } + if _, err = io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } +} diff --git a/internal/test/server.go b/internal/test/server.go index 127d4c16f..d32c3e4cb 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -23,6 +23,18 @@ func NewTestServer() *ServerTest { return &ServerTest{handlers: make(map[string]handler)} } +// HandlerCount returns the number of registered handlers. +func (ts *ServerTest) HandlerCount() int { + return len(ts.handlers) +} + +// HasHandler checks if a handler was registered for the given path. +func (ts *ServerTest) HasHandler(path string) bool { + path = strings.ReplaceAll(path, "*", ".*") + _, ok := ts.handlers[path] + return ok +} + func (ts *ServerTest) RegisterHandler(path string, handler handler) { // to make the registered paths friendlier to a regex match in the route handler // in OpenAITestServer diff --git a/internal/test/server_test.go b/internal/test/server_test.go new file mode 100644 index 000000000..f8ce731d1 --- /dev/null +++ b/internal/test/server_test.go @@ -0,0 +1,80 @@ +package test_test + +import ( + "io" + "net/http" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestGetTestToken(t *testing.T) { + if internaltest.GetTestToken() != "this-is-my-secure-token-do-not-steal!!" { + t.Fatalf("unexpected token") + } +} + +func TestNewTestServer(t *testing.T) { + ts := internaltest.NewTestServer() + if ts == nil { + t.Fatalf("server not properly initialized") + } + if ts.HandlerCount() != 0 { + t.Fatalf("expected no handlers initially") + } +} + +func TestRegisterHandlerTransformsPath(t *testing.T) { + ts := internaltest.NewTestServer() + h := func(_ http.ResponseWriter, _ *http.Request) {} + ts.RegisterHandler("/foo/*", h) + if !ts.HasHandler("/foo/*") { + t.Fatalf("handler not registered with transformed path") + } +} + +func TestOpenAITestServer(t *testing.T) { + ts := internaltest.NewTestServer() + ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "ok"); err != nil { + t.Fatalf("write: %v", err) + } + }) + srv := ts.OpenAITestServer() + srv.Start() + defer srv.Close() + + base := srv.Client().Transport + client := &http.Client{Transport: &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: base}} + resp, err := client.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatalf("read response body: %v", err) + } + if resp.StatusCode != http.StatusOK || string(body) != "ok" { + t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body)) + } + + // unregistered path + resp, err = client.Get(srv.URL + "/unknown") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } + + // missing token should return unauthorized + clientNoToken := srv.Client() + resp, err = clientNoToken.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } +} From 4f87294cebdd14457cc2e1013cdf5acf5db3d27d Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Sun, 20 Jul 2025 01:29:02 +0530 Subject: [PATCH 36/41] Add GuidedChoice to ChatCompletionRequest (#1034) * Add GuidedChoice to ChatCompletionRequest * made separate NonOpenAIExtensions * fixed lint issue * renamed struct and removed inline json tag * Update chat.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update chat.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- chat.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/chat.go b/chat.go index 0f0c5b5d5..e14acd9d9 100644 --- a/chat.go +++ b/chat.go @@ -248,6 +248,16 @@ func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) erro return nil } +// ChatCompletionRequestExtensions contains third-party OpenAI API extensions (e.g., vendor-specific implementations like vLLM). +type ChatCompletionRequestExtensions struct { + // GuidedChoice is a vLLM-specific extension that restricts the model's output + // to one of the predefined string choices provided in this field. This feature + // is used to constrain the model's responses to a controlled set of options, + // ensuring predictable and consistent outputs in scenarios where specific + // choices are required. + GuidedChoice []string `json:"guided_choice,omitempty"` +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { Model string `json:"model"` @@ -309,6 +319,8 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // Embedded struct for non-OpenAI extensions + ChatCompletionRequestExtensions } type StreamOptions struct { From c4273cb5f46031ee478601f4edd82d2fad401e77 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 19 Jul 2025 21:07:34 +0100 Subject: [PATCH 37/41] fix(chat): shorten comment to pass linter (#1050) --- chat.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index e14acd9d9..0bb2e98ee 100644 --- a/chat.go +++ b/chat.go @@ -248,7 +248,8 @@ func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) erro return nil } -// ChatCompletionRequestExtensions contains third-party OpenAI API extensions (e.g., vendor-specific implementations like vLLM). +// ChatCompletionRequestExtensions contains third-party OpenAI API extensions +// (e.g., vendor-specific implementations like vLLM). type ChatCompletionRequestExtensions struct { // GuidedChoice is a vLLM-specific extension that restricts the model's output // to one of the predefined string choices provided in this field. This feature @@ -264,7 +265,7 @@ type ChatCompletionRequest struct { Messages []ChatCompletionMessage `json:"messages"` // MaxTokens The maximum number of tokens that can be generated in the chat completion. // This value can be used to control costs for text generated via API. - // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. + // Deprecated: use MaxCompletionTokens. Not compatible with o1-series models. // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens MaxTokens int `json:"max_tokens,omitempty"` // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, From f7d6ece81065bde1e93576c89f7506d1b35cb205 Mon Sep 17 00:00:00 2001 From: Behzad Soltanpour Date: Mon, 11 Aug 2025 12:45:50 +0330 Subject: [PATCH 38/41] add GPT-5 model constants and reasoning validation (#1062) --- chat_test.go | 120 +++++++++++++++++++++++++++++++++++++++++ completion.go | 8 +++ completion_test.go | 29 ++++++++++ reasoning_validator.go | 9 ++-- 4 files changed, 162 insertions(+), 4 deletions(-) diff --git a/chat_test.go b/chat_test.go index 172ce0740..236cff736 100644 --- a/chat_test.go +++ b/chat_test.go @@ -331,6 +331,126 @@ func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { } } +func TestGPT5ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.GPT5, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Nano, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5ChatLatest, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + func TestChatRequestOmitEmpty(t *testing.T) { data, err := json.Marshal(openai.ChatCompletionRequest{ // We set model b/c it's required, so omitempty doesn't make sense diff --git a/completion.go b/completion.go index 02ce7b016..27d69f587 100644 --- a/completion.go +++ b/completion.go @@ -49,6 +49,10 @@ const ( GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14" GPT4Dot5Preview = "gpt-4.5-preview" GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" + GPT5 = "gpt-5" + GPT5Mini = "gpt-5-mini" + GPT5Nano = "gpt-5-nano" + GPT5ChatLatest = "gpt-5-chat-latest" GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" @@ -142,6 +146,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4Dot1Mini20250414: true, GPT4Dot1Nano: true, GPT4Dot1Nano20250414: true, + GPT5: true, + GPT5Mini: true, + GPT5Nano: true, + GPT5ChatLatest: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, diff --git a/completion_test.go b/completion_test.go index f0ead0d63..abfc3007e 100644 --- a/completion_test.go +++ b/completion_test.go @@ -300,3 +300,32 @@ func TestCompletionWithGPT4oModels(t *testing.T) { }) } } + +// TestCompletionWithGPT5Models Tests that GPT5 models are not supported for completion endpoint. +func TestCompletionWithGPT5Models(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT5, + openai.GPT5Mini, + openai.GPT5Nano, + openai.GPT5ChatLatest, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} diff --git a/reasoning_validator.go b/reasoning_validator.go index 2910b1395..1d26ca047 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -28,21 +28,22 @@ var ( ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll ) -// ReasoningValidator handles validation for o-series model requests. +// ReasoningValidator handles validation for reasoning model requests. type ReasoningValidator struct{} -// NewReasoningValidator creates a new validator for o-series models. +// NewReasoningValidator creates a new validator for reasoning models. func NewReasoningValidator() *ReasoningValidator { return &ReasoningValidator{} } -// Validate performs all validation checks for o-series models. +// Validate performs all validation checks for reasoning models. func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { o1Series := strings.HasPrefix(request.Model, "o1") o3Series := strings.HasPrefix(request.Model, "o3") o4Series := strings.HasPrefix(request.Model, "o4") + gpt5Series := strings.HasPrefix(request.Model, "gpt-5") - if !o1Series && !o3Series && !o4Series { + if !o1Series && !o3Series && !o4Series && !gpt5Series { return nil } From f71d1a622abab7fa75159b02318cbcf6e2dcb0a0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 12 Aug 2025 18:03:57 +0800 Subject: [PATCH 39/41] feat: add safety_identifier params (#1066) --- chat.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chat.go b/chat.go index 0bb2e98ee..9719f6b92 100644 --- a/chat.go +++ b/chat.go @@ -320,6 +320,11 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. + // The IDs should be a string that uniquely identifies each user. + // We recommend hashing their username or email address, in order to avoid sending us any identifying information. + // https://platform.openai.com/docs/api-reference/chat/create#chat_create-safety_identifier + SafetyIdentifier string `json:"safety_identifier,omitempty"` // Embedded struct for non-OpenAI extensions ChatCompletionRequestExtensions } From 8e5611cc5efdc2533b80e5667b69741c3fad875c Mon Sep 17 00:00:00 2001 From: Amady Azdaev Date: Fri, 29 Aug 2025 20:29:03 +0300 Subject: [PATCH 40/41] Add Verbosity parameter to Chat Completion Request (#1064) * add verbosity param to ChatCompletionRequest * edit comment about verbosity --- chat.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chat.go b/chat.go index 9719f6b92..0aa018715 100644 --- a/chat.go +++ b/chat.go @@ -320,6 +320,12 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // Verbosity determines how many output tokens are generated. Lowering the number of + // tokens reduces overall latency. It can be set to "low", "medium", or "high". + // Note: This field is only confirmed to work with gpt-5, gpt-5-mini and gpt-5-nano. + // Also, it is not in the API reference of chat completion at the time of writing, + // though it is supported by the API. + Verbosity string `json:"verbosity,omitempty"` // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. // The IDs should be a string that uniquely identifies each user. // We recommend hashing their username or email address, in order to avoid sending us any identifying information. From 5d7a276f4c0e48f97354fe555cb793b52e350e62 Mon Sep 17 00:00:00 2001 From: Christopher Petito <47751006+krissetto@users.noreply.github.com> Date: Tue, 21 Oct 2025 21:27:33 +0200 Subject: [PATCH 41/41] Stop stripping dots in azure model mapper for models that aren't 3.5 based (#1079) fixes #978 Signed-off-by: Christopher Petito --- config.go | 7 ++++++- config_test.go | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 4788ba62a..4b8cfb6fb 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package openai import ( "net/http" "regexp" + "strings" ) const ( @@ -70,7 +71,11 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { APIType: APITypeAzure, APIVersion: "2023-05-15", AzureModelMapperFunc: func(model string) string { - return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + // only 3.5 models have the "." stripped in their names + if strings.Contains(model, "3.5") { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + } + return strings.ReplaceAll(model, ":", "") }, HTTPClient: &http.Client{}, diff --git a/config_test.go b/config_test.go index 960230804..f44b80825 100644 --- a/config_test.go +++ b/config_test.go @@ -20,6 +20,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) { Model: "gpt-3.5-turbo-0301", Expect: "gpt-35-turbo-0301", }, + { + Model: "gpt-4.1", + Expect: "gpt-4.1", + }, { Model: "text-embedding-ada-002", Expect: "text-embedding-ada-002",