From e0d0801ac73cdc87d1b56ced0a0eb71e574546c3 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 11 Apr 2024 16:39:10 +0800 Subject: [PATCH 001/129] feat: add GPT4Turbo and GPT4Turbo20240409 (#703) --- completion.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/completion.go b/completion.go index ab1dbd6c5..00f43ff1c 100644 --- a/completion.go +++ b/completion.go @@ -22,6 +22,8 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" GPT4Turbo1106 = "gpt-4-1106-preview" GPT4TurboPreview = "gpt-4-turbo-preview" @@ -84,6 +86,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4VisionPreview: true, GPT4Turbo1106: true, GPT4Turbo0125: true, + GPT4Turbo: true, + GPT4Turbo20240409: true, GPT40314: true, GPT40613: true, GPT432K: true, From ea551f422e5f38a0afc7d938eea5cff1f69494c5 Mon Sep 17 00:00:00 2001 From: Andreas Deininger Date: Sat, 13 Apr 2024 13:32:38 +0200 Subject: [PATCH 002/129] Fixing typos (#706) --- README.md | 2 +- assistant.go | 4 ++-- client_test.go | 2 +- error.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 9a479c0a0..7946f4d9b 100644 --- a/README.md +++ b/README.md @@ -636,7 +636,7 @@ FunctionDefinition{ }, "unit": { Type: jsonschema.String, - Enum: []string{"celcius", "fahrenheit"}, + Enum: []string{"celsius", "fahrenheit"}, }, }, Required: []string{"location"}, diff --git a/assistant.go b/assistant.go index 4ca2dda62..9415325f8 100644 --- a/assistant.go +++ b/assistant.go @@ -181,7 +181,7 @@ func (c *Client) ListAssistants( order *string, after *string, before *string, -) (reponse AssistantsList, err error) { +) (response AssistantsList, err error) { urlValues := url.Values{} if limit != nil { urlValues.Add("limit", fmt.Sprintf("%d", *limit)) @@ -208,7 +208,7 @@ func (c *Client) ListAssistants( return } - err = c.sendRequest(req, &reponse) 
+ err = c.sendRequest(req, &response) return } diff --git a/client_test.go b/client_test.go index bc5133edc..a08d10f21 100644 --- a/client_test.go +++ b/client_test.go @@ -406,7 +406,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } } -func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { +func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { config := DefaultConfig(test.GetTestToken()) client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} diff --git a/error.go b/error.go index b2d01e22e..37959a272 100644 --- a/error.go +++ b/error.go @@ -23,7 +23,7 @@ type InnerError struct { ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` } -// RequestError provides informations about generic request errors. +// RequestError provides information about generic request errors. type RequestError struct { HTTPStatusCode int Err error From 2446f08f94b2750287c40bb9593377f349f5578e Mon Sep 17 00:00:00 2001 From: Andreas Deininger Date: Sat, 13 Apr 2024 13:34:23 +0200 Subject: [PATCH 003/129] Bump GitHub workflow actions to latest versions (#707) --- .github/workflows/close-inactive-issues.yml | 2 +- .github/workflows/pr.yml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml index bfe9b5c96..32723c4e9 100644 --- a/.github/workflows/close-inactive-issues.yml +++ b/.github/workflows/close-inactive-issues.yml @@ -10,7 +10,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v9 with: days-before-issue-stale: 30 days-before-issue-close: 14 diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 8df721f0f..a41fff92f 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -9,19 +9,19 @@ jobs: name: Sanity check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: 
actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: - go-version: '1.19' + go-version: '1.21' - name: Run vet run: | go vet . - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: latest - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v . - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 From a42f51967f5c2f8462f8d8dfd25f7d6a8d7a46fc Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Wed, 17 Apr 2024 03:26:14 +0800 Subject: [PATCH 004/129] [New_Features] Adds recently added Assistant cost saving parameters (#710) * add cost saving parameters * add periods at the end of comments * shorten commnet * further lower comment length * fix type --- run.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/run.go b/run.go index 1f3cb7eb7..7c14779c5 100644 --- a/run.go +++ b/run.go @@ -28,6 +28,16 @@ type Run struct { Metadata map[string]any `json:"metadata"` Usage Usage `json:"usage,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. 
+ TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + httpHeader } @@ -78,8 +88,42 @@ type RunRequest struct { AdditionalInstructions string `json:"additional_instructions,omitempty"` Tools []Tool `json:"tools,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` + + // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. + // lower values are more focused and deterministic. + Temperature *float32 `json:"temperature,omitempty"` + + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` } +// ThreadTruncationStrategy defines the truncation strategy to use for the thread. +// https://platform.openai.com/docs/assistants/how-it-works/truncation-strategy. +type ThreadTruncationStrategy struct { + // default 'auto'. + Type TruncationStrategy `json:"type,omitempty"` + // this field should be set if the truncation strategy is set to LastMessages. + LastMessages *int `json:"last_messages,omitempty"` +} + +// TruncationStrategy defines the existing truncation strategies existing for thread management in an assistant. +type TruncationStrategy string + +const ( + // TruncationStrategyAuto messages in the middle of the thread will be dropped to fit the context length of the model. 
+ TruncationStrategyAuto = TruncationStrategy("auto") + // TruncationStrategyLastMessages the thread will be truncated to the n most recent messages in the thread. + TruncationStrategyLastMessages = TruncationStrategy("last_messages") +) + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From c6a63ed19aeb0e91facc5409c5a08612db550fb2 Mon Sep 17 00:00:00 2001 From: Mike Chaykowsky Date: Tue, 16 Apr 2024 12:28:06 -0700 Subject: [PATCH 005/129] Add PromptFilterResult (#702) --- chat_stream.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 57cfa789f..6ff7078e2 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -19,13 +19,19 @@ type ChatCompletionStreamChoice struct { ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } +type PromptFilterResult struct { + Index int `json:"index"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionStreamChoice `json:"choices"` - PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } // ChatCompletionStream From 8d15a377ec4fa3aaf2e706cd1e2ad986dd6b8242 Mon Sep 17 00:00:00 2001 From: Danai Antoniou <32068609+danai-antoniou@users.noreply.github.com> Date: Wed, 24 Apr 2024 12:59:50 +0100 Subject: [PATCH 006/129] Remove hardcoded assistants version (#719) --- assistant.go | 19 +++++++++---------- client.go | 4 ++-- config.go 
| 14 +++++++++----- messages.go | 17 +++++++++++------ run.go | 27 +++++++++------------------ thread.go | 8 ++++---- 6 files changed, 44 insertions(+), 45 deletions(-) diff --git a/assistant.go b/assistant.go index 9415325f8..661681e83 100644 --- a/assistant.go +++ b/assistant.go @@ -11,7 +11,6 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" - openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { @@ -116,7 +115,7 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -132,7 +131,7 @@ func (c *Client) RetrieveAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -149,7 +148,7 @@ func (c *Client) ModifyAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -165,7 +164,7 @@ func (c *Client) DeleteAssistant( ) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -203,7 +202,7 @@ func (c *Client) ListAssistants( urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, 
encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -221,7 +220,7 @@ func (c *Client) CreateAssistantFile( urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -238,7 +237,7 @@ func (c *Client) RetrieveAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -255,7 +254,7 @@ func (c *Client) DeleteAssistantFile( ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -294,7 +293,7 @@ func (c *Client) ListAssistantFiles( urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/client.go b/client.go index 9a1c8958d..77d693226 100644 --- a/client.go +++ b/client.go @@ -89,9 +89,9 @@ func withContentType(contentType string) requestOption { } } -func withBetaAssistantV1() requestOption { +func withBetaAssistantVersion(version string) requestOption { return func(args *requestOptions) { - args.header.Set("OpenAI-Beta", "assistants=v1") + args.header.Set("OpenAI-Beta", 
fmt.Sprintf("assistants=%s", version)) } } diff --git a/config.go b/config.go index c58b71ec6..599fa89c0 100644 --- a/config.go +++ b/config.go @@ -23,6 +23,8 @@ const ( const AzureAPIKeyHeader = "api-key" +const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string @@ -30,7 +32,8 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient *http.Client @@ -39,10 +42,11 @@ type ClientConfig struct { func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ - authToken: authToken, - BaseURL: openaiAPIURLv1, - APIType: APITypeOpenAI, - OrgID: "", + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + AssistantVersion: defaultAssistantVersion, + OrgID: "", HTTPClient: &http.Client{}, diff --git a/messages.go b/messages.go index 6fd0adbc9..6af118445 100644 --- a/messages.go +++ b/messages.go @@ -76,7 +76,8 @@ type MessageFilesList struct { // CreateMessage creates a new message. 
func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -111,7 +112,8 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, } urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -126,7 +128,8 @@ func (c *Client) RetrieveMessage( threadID, messageID string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -143,7 +146,7 @@ func (c *Client) ModifyMessage( ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -158,7 +161,8 @@ func (c *Client) RetrieveMessageFile( threadID, messageID, fileID string, ) (file MessageFile, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) - req, err 
:= c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -173,7 +177,8 @@ func (c *Client) ListMessageFiles( threadID, messageID string, ) (files MessageFilesList, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/run.go b/run.go index 7c14779c5..094b0a4db 100644 --- a/run.go +++ b/run.go @@ -226,8 +226,7 @@ func (c *Client) CreateRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -247,8 +246,7 @@ func (c *Client) RetrieveRun( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -270,8 +268,7 @@ func (c *Client) ModifyRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -310,8 +307,7 @@ func (c *Client) ListRuns( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -332,8 +328,7 @@ func (c *Client) SubmitToolOutputs( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -352,8 +347,7 @@ func (c *Client) CancelRun( ctx, http.MethodPost, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + 
withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -372,8 +366,7 @@ func (c *Client) CreateThreadAndRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -394,8 +387,7 @@ func (c *Client) RetrieveRunStep( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -435,8 +427,7 @@ func (c *Client) ListRunSteps( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/thread.go b/thread.go index 291f3dcab..900e3f2ea 100644 --- a/thread.go +++ b/thread.go @@ -51,7 +51,7 @@ type ThreadDeleteResponse struct { // CreateThread creates a new thread. func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -64,7 +64,7 @@ func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (respo func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -81,7 +81,7 @@ func (c *Client) ModifyThread( ) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -97,7 +97,7 @@ func (c *Client) DeleteThread( ) 
(response ThreadDeleteResponse, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } From 2d58f8f4b87be26dc0b7ba2b1f0c9496ecf1dfa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=80=E6=97=A5=E3=80=82?= Date: Wed, 24 Apr 2024 20:02:03 +0800 Subject: [PATCH 007/129] chore: add SystemFingerprint for chat completion stream response (#716) * chore: add SystemFingerprint for stream response * chore: add test * lint: format for test --- chat_stream.go | 1 + chat_stream_test.go | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 6ff7078e2..159f9f472 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -30,6 +30,7 @@ type ChatCompletionStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionStreamChoice `json:"choices"` + SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } diff --git a/chat_stream_test.go b/chat_stream_test.go index bd571cb48..bd1c737dd 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -46,12 +46,12 @@ func TestCreateChatCompletionStream(t *testing.T) { dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) 
//nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: done\n")...) 
@@ -77,10 +77,11 @@ func TestCreateChatCompletionStream(t *testing.T) { expectedResponses := []openai.ChatCompletionStreamResponse{ { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: openai.GPT3Dot5Turbo, + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ @@ -91,10 +92,11 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: openai.GPT3Dot5Turbo, + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ From c84ab5f6ae8da3a78826ed2c8dc4c5cf93e30589 Mon Sep 17 00:00:00 2001 From: wurui <1009479218@qq.com> Date: Wed, 24 Apr 2024 20:08:58 +0800 Subject: [PATCH 008/129] feat: support cloudflare AI Gateway flavored azure openai (#715) * feat: support cloudflare AI Gateway flavored azure openai Signed-off-by: STRRL * test: add test for cloudflare azure fullURL --------- Signed-off-by: STRRL Co-authored-by: STRRL --- api_internal_test.go | 36 ++++++++++++++++++++++++++++++++++++ client.go | 10 ++++++++-- config.go | 7 ++++--- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 0fb0f8993..a590ec9ab 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -148,3 +148,39 @@ func TestAzureFullURL(t *testing.T) { }) } } + +func TestCloudflareAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Expect string + }{ + { + "CloudflareAzureBaseURLWithSlashAutoStrip", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + 
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + { + "CloudflareAzureBaseURLWithoutSlashOK", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL) + az.APIType = APITypeCloudflareAzure + + cli := NewClientWithConfig(az) + + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} diff --git a/client.go b/client.go index 77d693226..c57ba17c7 100644 --- a/client.go +++ b/client.go @@ -182,7 +182,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.APIType == APITypeAzure { + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else if c.config.authToken != "" { // OpenAI or Azure AD authentication @@ -246,7 +246,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { ) } - // c.config.APIType == APITypeOpenAI || c.config.APIType == "" + // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ + if c.config.APIType == APITypeCloudflareAzure { + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) + } + return fmt.Sprintf("%s%s", 
c.config.BaseURL, suffix) } diff --git a/config.go b/config.go index 599fa89c0..bb437c97f 100644 --- a/config.go +++ b/config.go @@ -16,9 +16,10 @@ const ( type APIType string const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" + APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" ) const AzureAPIKeyHeader = "api-key" From c9953a7b051bd661254fb071029553e61c78f8bd Mon Sep 17 00:00:00 2001 From: Alireza Ghasemi Date: Sat, 27 Apr 2024 12:55:49 +0330 Subject: [PATCH 009/129] Fixup minor copy-pasta comment typo (#728) imagess -> images --- image_api_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/image_api_test.go b/image_api_test.go index 2eb46f2b4..48416b1e2 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -36,7 +36,7 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { var err error var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -146,7 +146,7 @@ func TestImageEditWithoutMask(t *testing.T) { func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -202,7 +202,7 @@ func TestImageVariation(t *testing.T) { func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } From 3334a9c78a9d594934e33af184e4e6313c4a942b Mon Sep 17 00:00:00 2001 From: Alireza Ghasemi Date: Tue, 7 May 2024 16:10:07 +0330 Subject: 
[PATCH 010/129] Add support for word-level audio transcription timestamp granularity (#733) * Add support for audio transcription timestamp_granularities word * Fixup multiple timestamp granularities --- audio.go | 31 ++++++++++++++++++++++++++----- audio_api_test.go | 4 ++++ audio_test.go | 6 +++++- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/audio.go b/audio.go index 4cbe4fe64..dbc26d154 100644 --- a/audio.go +++ b/audio.go @@ -27,8 +27,14 @@ const ( AudioResponseFormatVTT AudioResponseFormat = "vtt" ) +type TranscriptionTimestampGranularity string + +const ( + TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" + TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" +) + // AudioRequest represents a request structure for audio API. -// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { Model string @@ -38,10 +44,11 @@ type AudioRequest struct { // Reader is an optional io.Reader when you do not want to use an existing file. Reader io.Reader - Prompt string // For translation, it should be in English - Temperature float32 - Language string // For translation, just do not use it. It seems "en" works, not confirmed... - Format AudioResponseFormat + Prompt string + Temperature float32 + Language string // Only for transcription. + Format AudioResponseFormat + TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription. } // AudioResponse represents a response structure for audio API. 
@@ -62,6 +69,11 @@ type AudioResponse struct { NoSpeechProb float64 `json:"no_speech_prob"` Transient bool `json:"transient"` } `json:"segments"` + Words []struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` + } `json:"words"` Text string `json:"text"` httpHeader @@ -179,6 +191,15 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { } } + if len(request.TimestampGranularities) > 0 { + for _, tg := range request.TimestampGranularities { + err = b.WriteField("timestamp_granularities[]", string(tg)) + if err != nil { + return fmt.Errorf("writing timestamp_granularities[]: %w", err) + } + } + } + // Close the multipart writer return b.Close() } diff --git a/audio_api_test.go b/audio_api_test.go index a0efc7921..c24598443 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -105,6 +105,10 @@ func TestAudioWithOptionalArgs(t *testing.T) { Temperature: 0.5, Language: "zh", Format: openai.AudioResponseFormatSRT, + TimestampGranularities: []openai.TranscriptionTimestampGranularity{ + openai.TranscriptionTimestampGranularitySegment, + openai.TranscriptionTimestampGranularityWord, + }, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index 5346244c8..235931f36 100644 --- a/audio_test.go +++ b/audio_test.go @@ -24,6 +24,10 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { Temperature: 0.5, Language: "en", Format: AudioResponseFormatSRT, + TimestampGranularities: []TranscriptionTimestampGranularity{ + TranscriptionTimestampGranularitySegment, + TranscriptionTimestampGranularityWord, + }, } mockFailedErr := fmt.Errorf("mock form builder fail") @@ -47,7 +51,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { return nil } - failOn := []string{"model", "prompt", "temperature", "language", "response_format"} + failOn := []string{"model", "prompt", "temperature", "language", "response_format", "timestamp_granularities[]"} for _, 
failingField := range failOn { failForField = failingField mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) From 6af32202d1ce469674050600efa07c90ec286d03 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 7 May 2024 20:42:24 +0800 Subject: [PATCH 011/129] feat: support stream_options (#736) * feat: support stream_options * fix lint * fix lint --- chat.go | 10 ++++ chat_stream.go | 4 ++ chat_stream_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+) diff --git a/chat.go b/chat.go index efb14fd4c..a1eb11720 100644 --- a/chat.go +++ b/chat.go @@ -216,6 +216,16 @@ type ChatCompletionRequest struct { Tools []Tool `json:"tools,omitempty"` // This can be either a string or an ToolChoice object. ToolChoice any `json:"tool_choice,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` +} + +type StreamOptions struct { + // If set, an additional chunk will be streamed before the data: [DONE] message. + // The usage field on this chunk shows the token usage statistics for the entire request, + // and the choices field will always be an empty array. + // All other chunks will also include a usage field, but with a null value. + IncludeUsage bool `json:"include_usage,omitempty"` } type ToolType string diff --git a/chat_stream.go b/chat_stream.go index 159f9f472..ffd512ff6 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -33,6 +33,10 @@ type ChatCompletionStreamResponse struct { SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + // An optional field that will only be present when you set stream_options: {"include_usage": true} in your request. 
+ // When present, it contains a null value except for the last chunk which contains the token usage statistics + // for the entire request. + Usage *Usage `json:"usage,omitempty"` } // ChatCompletionStream diff --git a/chat_stream_test.go b/chat_stream_test.go index bd1c737dd..63e45ee23 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -388,6 +388,120 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } } +func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + var dataBytes []byte + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) 
+ + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + StreamOptions: &openai.StreamOptions{ + IncludeUsage: true, + }, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "3", + Object: "completion", + Created: 1598069256, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{}, + Usage: &openai.Usage{ + PromptTokens: 1, + CompletionTokens: 1, + TotalTokens: 2, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in 
the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + + checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished") + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) + } +} + // Helper funcs. func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { @@ -401,6 +515,15 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { return false } } + if r1.Usage != nil || r2.Usage != nil { + if r1.Usage == nil || r2.Usage == nil { + return false + } + if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens || + r1.Usage.TotalTokens != r2.Usage.TotalTokens { + return false + } + } return true } From 3b25e09da90715681fe4049955d7c7ce645e218c Mon Sep 17 00:00:00 2001 From: Kevin Mesiab Date: Mon, 13 May 2024 11:48:14 -0700 Subject: [PATCH 012/129] enhancement: Add new GPT4-o and alias to completion enums (#744) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 00f43ff1c..3b4f8952a 100644 --- a/completion.go +++ b/completion.go @@ -22,6 +22,8 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4o = "gpt-4o" + GPT4o20240513 = "gpt-4o-2024-05-13" GPT4Turbo = "gpt-4-turbo" GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" From 9f19d1c93bf986f2a8925be62f35aa5c413a706a Mon Sep 17 00:00:00 2001 From: nullswan Date: Mon, 13 May 2024 21:07:07 +0200 Subject: [PATCH 013/129] Add gpt4o (#742) * Add gpt4o * disabled model for endpoint seen in https://github.com/sashabaranov/go-openai/commit/e0d0801ac73cdc87d1b56ced0a0eb71e574546c3 * Update completion.go --------- Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- 
completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 3b4f8952a..ced8e0606 100644 --- a/completion.go +++ b/completion.go @@ -84,6 +84,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, + GPT4o: true, + GPT4o20240513: true, GPT4TurboPreview: true, GPT4VisionPreview: true, GPT4Turbo1106: true, From 4f4a85687be31607536997e924b27693f5e5211a Mon Sep 17 00:00:00 2001 From: Kshirodra Meher Date: Tue, 14 May 2024 00:38:14 +0530 Subject: [PATCH 014/129] Added DALL.E 3 to readme.md (#741) * Added DALL.E 3 to readme.md Added DALL.E 3 to readme.md as its supported now as per issue https://github.com/sashabaranov/go-openai/issues/494 * Update README.md --------- Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7946f4d9b..799dc602b 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * ChatGPT * GPT-3, GPT-4 -* DALL·E 2 +* DALL·E 2, DALL·E 3 * Whisper ## Installation From 211cb49fc22766f4174fef15301c4d39aef609d3 Mon Sep 17 00:00:00 2001 From: ando-masaki Date: Fri, 24 May 2024 16:18:47 +0900 Subject: [PATCH 015/129] Update client.go to get response header whether there is an error or not. (#751) Update client.go to get response header whether there is an error or not. Because 429 Too Many Requests error response has "Retry-After" header. 
--- client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index c57ba17c7..7bc28e984 100644 --- a/client.go +++ b/client.go @@ -129,14 +129,14 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { defer res.Body.Close() - if isFailureStatusCode(res) { - return c.handleErrorResp(res) - } - if v != nil { v.SetHeader(res.Header) } + if isFailureStatusCode(res) { + return c.handleErrorResp(res) + } + return decodeResponse(res.Body, v) } From 30cf7b879cff5eb56f06fda19c51c9e92fce8b13 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:50:22 -0700 Subject: [PATCH 016/129] feat: add params to RunRequest (#754) --- run.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/run.go b/run.go index 094b0a4db..6bd3933b1 100644 --- a/run.go +++ b/run.go @@ -92,6 +92,7 @@ type RunRequest struct { // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. @@ -103,6 +104,11 @@ type RunRequest struct { // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + // This can be either a string or a ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // This can be either a string or a ResponseFormat object. + ResponseFormat any `json:"response_format,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread. 
@@ -124,6 +130,13 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +// ReponseFormat specifies the format the model must output. +// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format. +// Type can either be text or json_object. +type ReponseFormat struct { + Type string `json:"type"` +} + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From 8618492b98bb91edbb43f8080b3a68275e183663 Mon Sep 17 00:00:00 2001 From: shosato0306 <38198918+shosato0306@users.noreply.github.com> Date: Wed, 5 Jun 2024 20:03:57 +0900 Subject: [PATCH 017/129] feat: add incomplete run status (#763) --- run.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/run.go b/run.go index 6bd3933b1..5598f1dfb 100644 --- a/run.go +++ b/run.go @@ -30,10 +30,10 @@ type Run struct { Temperature *float32 `json:"temperature,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. 
TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` @@ -50,6 +50,7 @@ const ( RunStatusCancelling RunStatus = "cancelling" RunStatusFailed RunStatus = "failed" RunStatusCompleted RunStatus = "completed" + RunStatusIncomplete RunStatus = "incomplete" RunStatusExpired RunStatus = "expired" RunStatusCancelled RunStatus = "cancelled" ) @@ -95,11 +96,11 @@ type RunRequest struct { TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. 
From fd41f7a5f49e6723d97642c186e5e090abaebfe2 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Thu, 13 Jun 2024 06:23:07 -0700 Subject: [PATCH 018/129] Fix integration test (#762) * added TestCompletionStream test moved completion stream testing to seperate function added NoErrorF fixes nil pointer reference on stream object * update integration test models --- api_integration_test.go | 64 ++++++++++++++++++++-------------- completion.go | 31 ++++++++-------- embeddings.go | 2 +- internal/test/checks/checks.go | 7 ++++ 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 736040c50..f34685188 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -26,7 +26,7 @@ func TestAPI(t *testing.T) { _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") - _, err = c.GetEngine(ctx, "davinci") + _, err = c.GetEngine(ctx, openai.GPT3Davinci002) checks.NoError(t, err, "GetEngine error") fileRes, err := c.ListFiles(ctx) @@ -42,7 +42,7 @@ func TestAPI(t *testing.T) { "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: openai.AdaSearchQuery, + Model: openai.AdaEmbeddingV2, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") @@ -77,31 +77,6 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: openai.GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } - _, err = 
c.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ @@ -134,6 +109,41 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } +func TestCompletionStream(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + c := openai.NewClient(apiToken) + ctx := context.Background() + + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: openai.GPT3Babbage002, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + func TestAPIError(t *testing.T) { apiToken := os.Getenv("OPENAI_TOKEN") if apiToken == "" { diff --git a/completion.go b/completion.go index ced8e0606..024f09b14 100644 --- a/completion.go +++ b/completion.go @@ -39,30 +39,33 @@ const ( GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci002 = "text-davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. 
GPT3TextCurie001 = "text-curie-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextBabbage001 = "text-babbage-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextAda001 = "text-ada-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci001 = "text-davinci-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3Davinci002 = "davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use davinci-002 instead. + GPT3Davinci = "davinci" + GPT3Davinci002 = "davinci-002" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3CurieInstructBeta = "curie-instruct-beta" GPT3Curie = "curie" GPT3Curie002 = "curie-002" - GPT3Ada = "ada" - GPT3Ada002 = "ada-002" - GPT3Babbage = "babbage" - GPT3Babbage002 = "babbage-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. diff --git a/embeddings.go b/embeddings.go index c5633a313..b513ba6a7 100644 --- a/embeddings.go +++ b/embeddings.go @@ -16,7 +16,7 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") type EmbeddingModel string const ( - // Deprecated: The following block will be shut down on January 04, 2024. 
Use text-embedding-ada-002 instead. + // Deprecated: The following block is shut down. Use text-embedding-ada-002 instead. AdaSimilarity EmbeddingModel = "text-similarity-ada-001" BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" CurieSimilarity EmbeddingModel = "text-similarity-curie-001" diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 713369157..6bd0964c6 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -12,6 +12,13 @@ func NoError(t *testing.T, err error, message ...string) { } } +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} + func HasError(t *testing.T, err error, message ...string) { t.Helper() if err == nil { From 7e96c712cbdad50b9cf67324b1ca5ef6541b6235 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:15:27 +0400 Subject: [PATCH 019/129] run integration tests (#769) --- .github/workflows/integration-tests.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/integration-tests.yml diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 000000000..19f158e40 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,19 @@ +name: Integration tests + +on: + push: + branches: + - master + +jobs: + integration_tests: + name: Run integration tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Run integration tests + run: go test -v -tags=integration ./api_integration_test.go From c69c3bb1d259375d5de801f890aca40c0b2a8867 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:21:25 +0400 Subject: [PATCH 020/129] integration tests: pass openai secret (#770) * 
pass openai secret * only run in master branch --- .github/workflows/integration-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 19f158e40..7260b00b4 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -16,4 +16,6 @@ jobs: with: go-version: '1.21' - name: Run integration tests + env: + OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} run: go test -v -tags=integration ./api_integration_test.go From 99cc170b5414bd21fc1c55bccba1d6c1bad04516 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 13 Jun 2024 23:24:37 +0800 Subject: [PATCH 021/129] feat: support batches api (#746) * feat: support batches api * update batch_test.go * fix golangci-lint check * fix golangci-lint check * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix: create batch api * update batch_test.go * feat: add `CreateBatchWithUploadFile` * feat: add `UploadBatchFile` * optimize variable and type naming * expose `BatchLineItem` interface * update batches const --- batch.go | 275 ++++++++++++++++++++++++++++++++++++ batch_test.go | 368 +++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 11 ++ files.go | 1 + 4 files changed, 655 insertions(+) create mode 100644 batch.go create mode 100644 batch_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 000000000..4aba966bc --- /dev/null +++ b/batch.go @@ -0,0 +1,275 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem 
interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + 
Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +var ErrUploadBatchFileFailed = errors.New("upload batch file failed") + +// CreateBatch — API call to Create batch. +func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: 
BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + err = errors.Join(ErrUploadBatchFileFailed, err) + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. 
+func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 000000000..4b2261e0e --- /dev/null +++ b/batch_test.go @@ -0,0 +1,368 @@ +package openai_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestUploadBatchFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.UploadBatchFileRequest{} + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := 
client.UploadBatchFile(context.Background(), req) + checks.NoError(t, err, "UploadBatchFile error") +} + +func TestCreateBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + _, err := client.CreateBatch(context.Background(), openai.CreateBatchRequest{ + InputFileID: "file-abc", + Endpoint: openai.BatchEndpointChatCompletions, + CompletionWindow: "24h", + }) + checks.NoError(t, err, "CreateBatch error") +} + +func TestCreateBatchWithUploadFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + req := openai.CreateBatchWithUploadFileRequest{ + Endpoint: openai.BatchEndpointChatCompletions, + } + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.CreateBatchWithUploadFile(context.Background(), req) + checks.NoError(t, err, "CreateBatchWithUploadFile error") +} + +func TestRetrieveBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1", handleRetrieveBatchEndpoint) + _, err := client.RetrieveBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestCancelBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1/cancel", handleCancelBatchEndpoint) + _, err := client.CancelBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestListBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + 
server.RegisterHandler("/v1/batches", handleBatchEndpoint) + after := "batch_abc123" + limit := 10 + _, err := client.ListBatch(context.Background(), &after, &limit) + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestUploadBatchFileRequest_AddChatCompletion(t *testing.T) { + type args struct { + customerID string + body openai.ChatCompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + { + customerID: "req-2", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddChatCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddCompletion(t *testing.T) { + type args struct { + customerID string + body openai.CompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + 
User: "Hello", + }, + }, + { + customerID: "req-2", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { + type args struct { + customerID string + body openai.EmbeddingRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.EmbeddingRequest{ + Model: openai.GPT3Dot5Turbo, + Input: []string{"Hello", "World"}, + }, + }, + { + customerID: "req-2", + body: openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + Input: []string{"Hello", "World"}, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddEmbedding(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func handleBatchEndpoint(w 
http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } else if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job" + } + } + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + }`) + } +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": 
"file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} diff --git a/client_test.go b/client_test.go index a08d10f21..e49da9b3d 100644 --- a/client_test.go +++ b/client_test.go @@ -396,6 +396,17 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateSpeech", func() (any, error) { return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) }}, + {"CreateBatch", func() (any, error) { + return client.CreateBatch(ctx, CreateBatchRequest{}) + }}, + {"CreateBatchWithUploadFile", func() (any, error) { + return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{}) + }}, + {"RetrieveBatch", func() (any, error) { + return client.RetrieveBatch(ctx, "") + }}, + {"CancelBatch", func() (any, error) { return 
client.CancelBatch(ctx, "") }}, + {"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }}, } for _, testCase := range testCases { diff --git a/files.go b/files.go index b40a44f15..26ad6bd70 100644 --- a/files.go +++ b/files.go @@ -22,6 +22,7 @@ const ( PurposeFineTuneResults PurposeType = "fine-tune-results" PurposeAssistants PurposeType = "assistants" PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" ) // FileBytesRequest represents a file upload request. From 68acf22a43903c1b460006e7c4b883ce73e35857 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Thu, 13 Jun 2024 17:26:37 +0200 Subject: [PATCH 022/129] Support Tool Resources properties for Threads (#760) * Support Tool Resources properties for Threads * Add Chunking Strategy for Threads vector stores --- thread.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/thread.go b/thread.go index 900e3f2ea..6f7521454 100644 --- a/thread.go +++ b/thread.go @@ -10,21 +10,74 @@ const ( ) type Thread struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Metadata map[string]any `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + ToolResources ToolResources `json:"tool_resources,omitempty"` httpHeader } type ThreadRequest struct { - Messages []ThreadMessage `json:"messages,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` } +type ToolResources struct { + CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResources `json:"file_search,omitempty"` +} + +type 
CodeInterpreterToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResources struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` +} + +type ToolResourcesRequest struct { + CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResourcesRequest struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResourcesRequest struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` + VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"` +} + +type VectorStoreToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` + ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ChunkingStrategy struct { + Type ChunkingStrategyType `json:"type"` + Static *StaticChunkingStrategy `json:"static,omitempty"` +} + +type StaticChunkingStrategy struct { + MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` + ChunkOverlapTokens int `json:"chunk_overlap_tokens"` +} + +type ChunkingStrategyType string + +const ( + ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" + ChunkingStrategyTypeStatic ChunkingStrategyType = "static" +) + type ModifyThreadRequest struct { - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` + ToolResources *ToolResources `json:"tool_resources,omitempty"` } type ThreadMessageRole string From 0a421308993425afed7796da8f8e0e1abafd4582 Mon Sep 17 00:00:00 2001 From: Peng Guan-Cheng Date: Wed, 19 Jun 2024 16:37:21 +0800 Subject: [PATCH 023/129] feat: provide vector store (#772) * implement vectore store feature * fix after integration testing * fix golint error * improve test to increare code coverage * fix golint anc code coverage problem * add tool_resource in assistant response * chore: 
code style * feat: use pagination param * feat: use pagination param * test: use pagination param * test: rm unused code --------- Co-authored-by: Denny Depok <61371551+kodernubie@users.noreply.github.com> Co-authored-by: eric.p --- assistant.go | 50 ++++--- config.go | 2 +- vector_store.go | 345 ++++++++++++++++++++++++++++++++++++++++++ vector_store_test.go | 349 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 728 insertions(+), 18 deletions(-) create mode 100644 vector_store.go create mode 100644 vector_store_test.go diff --git a/assistant.go b/assistant.go index 661681e83..cc13a3020 100644 --- a/assistant.go +++ b/assistant.go @@ -14,16 +14,17 @@ const ( ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` httpHeader } @@ -34,6 +35,7 @@ const ( AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" AssistantToolTypeRetrieval AssistantToolType = "retrieval" AssistantToolTypeFunction AssistantToolType = "function" + AssistantToolTypeFileSearch AssistantToolType = "file_search" ) type AssistantTool struct { @@ -41,19 +43,33 @@ type AssistantTool struct { Function *FunctionDefinition 
`json:"function,omitempty"` } +type AssistantToolFileSearch struct { + VectorStoreIDs []string `json:"vector_store_ids"` +} + +type AssistantToolCodeInterpreter struct { + FileIDs []string `json:"file_ids"` +} + +type AssistantToolResource struct { + FileSearch *AssistantToolFileSearch `json:"file_search,omitempty"` + CodeInterpreter *AssistantToolCodeInterpreter `json:"code_interpreter,omitempty"` +} + // AssistantRequest provides the assistant request parameters. // When modifying the tools the API functions as the following: // If Tools is undefined, no changes are made to the Assistant's tools. // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases diff --git a/config.go b/config.go index bb437c97f..1347567d7 100644 --- a/config.go +++ b/config.go @@ -24,7 +24,7 @@ const ( const AzureAPIKeyHeader = "api-key" -const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. +const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store // ClientConfig is a configuration of a client. 
type ClientConfig struct { diff --git a/vector_store.go b/vector_store.go new file mode 100644 index 000000000..5c364362a --- /dev/null +++ b/vector_store.go @@ -0,0 +1,345 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + vectorStoresSuffix = "/vector_stores" + vectorStoresFilesSuffix = "/files" + vectorStoresFileBatchesSuffix = "/file_batches" +) + +type VectorStoreFileCount struct { + InProgress int `json:"in_progress"` + Completed int `json:"completed"` + Failed int `json:"failed"` + Cancelled int `json:"cancelled"` + Total int `json:"total"` +} + +type VectorStore struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name string `json:"name"` + UsageBytes int `json:"usage_bytes"` + FileCounts VectorStoreFileCount `json:"file_counts"` + Status string `json:"status"` + ExpiresAfter *VectorStoreExpires `json:"expires_after"` + ExpiresAt *int `json:"expires_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type VectorStoreExpires struct { + Anchor string `json:"anchor"` + Days int `json:"days"` +} + +// VectorStoreRequest provides the vector store request parameters. +type VectorStoreRequest struct { + Name string `json:"name,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + ExpiresAfter *VectorStoreExpires `json:"expires_after,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// VectorStoresList is a list of vector store. 
+type VectorStoresList struct { + VectorStores []VectorStore `json:"data"` + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` + httpHeader +} + +type VectorStoreDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +type VectorStoreFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + UsageBytes int `json:"usage_bytes"` + Status string `json:"status"` + + httpHeader +} + +type VectorStoreFileRequest struct { + FileID string `json:"file_id"` +} + +type VectorStoreFilesList struct { + VectorStoreFiles []VectorStoreFile `json:"data"` + + httpHeader +} + +type VectorStoreFileBatch struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + Status string `json:"status"` + FileCounts VectorStoreFileCount `json:"file_counts"` + + httpHeader +} + +type VectorStoreFileBatchRequest struct { + FileIDs []string `json:"file_ids"` +} + +// CreateVectorStore creates a new vector store. +func (c *Client) CreateVectorStore(ctx context.Context, request VectorStoreRequest) (response VectorStore, err error) { + req, _ := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(vectorStoresSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStore retrieves an vector store. 
+func (c *Client) RetrieveVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ModifyVectorStore modifies a vector store. +func (c *Client) ModifyVectorStore( + ctx context.Context, + vectorStoreID string, + request VectorStoreRequest, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStore deletes an vector store. +func (c *Client) DeleteVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStoreDeleteResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStores Lists the currently available vector store. +func (c *Client) ListVectorStores( + ctx context.Context, + pagination Pagination, +) (response VectorStoresList, err error) { + urlValues := url.Values{} + + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" 
+ urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", vectorStoresSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFile creates a new vector store file. +func (c *Client) CreateVectorStoreFile( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileRequest, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFile retrieves a vector store file. +func (c *Client) RetrieveVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStoreFile deletes an existing file. +func (c *Client) DeleteVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, nil) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. 
+func (c *Client) ListVectorStoreFiles( + ctx context.Context, + vectorStoreID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFileBatch creates a new vector store file batch. +func (c *Client) CreateVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileBatchRequest, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFileBatch retrieves a vector store file batch. 
+func (c *Client) RetrieveVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix, batchID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CancelVectorStoreFileBatch cancel a new vector store file batch. +func (c *Client) CancelVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/cancel") + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFilesInBatch( + ctx context.Context, + vectorStoreID string, + batchID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" 
+ urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/files", encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} diff --git a/vector_store_test.go b/vector_store_test.go new file mode 100644 index 000000000..58b9a857e --- /dev/null +++ b/vector_store_test.go @@ -0,0 +1,349 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestVectorStore Tests the vector store endpoint of the API using the mocked server. +func TestVectorStore(t *testing.T) { + vectorStoreID := "vs_abc123" + vectorStoreName := "TestStore" + vectorStoreFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + vectorStoreFileBatchID := "vsfb_abc123" + limit := 20 + order := "desc" + after := "vs_abc122" + before := "vs_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files/"+vectorStoreFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "vector_store.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: 
[]openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.VectorStoreFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: request.FileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 1, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + 
VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreFileBatchRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: len(request.FileIDs), + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.VectorStore + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "vectorstore_abc123", + 
"object": "vector_store.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 0, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoresList{ + LastID: &vectorStoreID, + FirstID: &vectorStoreID, + VectorStores: []openai.VectorStore{ + { + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + t.Run("create_vector_store", func(t *testing.T) { + _, err := client.CreateVectorStore(ctx, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "CreateVectorStore error") + }) + + t.Run("retrieve_vector_store", func(t *testing.T) { + _, err := client.RetrieveVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "RetrieveVectorStore error") + }) + + t.Run("delete_vector_store", func(t *testing.T) { + _, err := client.DeleteVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "DeleteVectorStore error") + }) + + t.Run("list_vector_store", func(t *testing.T) { + _, err := client.ListVectorStores(context.TODO(), openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStores error") + }) + + t.Run("create_vector_store_file", func(t *testing.T) { + _, err := client.CreateVectorStoreFile(context.TODO(), vectorStoreID, openai.VectorStoreFileRequest{ + 
FileID: vectorStoreFileID, + }) + checks.NoError(t, err, "CreateVectorStoreFile error") + }) + + t.Run("list_vector_store_files", func(t *testing.T) { + _, err := client.ListVectorStoreFiles(ctx, vectorStoreID, openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFiles error") + }) + + t.Run("retrieve_vector_store_file", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "RetrieveVectorStoreFile error") + }) + + t.Run("delete_vector_store_file", func(t *testing.T) { + err := client.DeleteVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "DeleteVectorStoreFile error") + }) + + t.Run("modify_vector_store", func(t *testing.T) { + _, err := client.ModifyVectorStore(ctx, vectorStoreID, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "ModifyVectorStore error") + }) + + t.Run("create_vector_store_file_batch", func(t *testing.T) { + _, err := client.CreateVectorStoreFileBatch(ctx, vectorStoreID, openai.VectorStoreFileBatchRequest{ + FileIDs: []string{vectorStoreFileID}, + }) + checks.NoError(t, err, "CreateVectorStoreFileBatch error") + }) + + t.Run("retrieve_vector_store_file_batch", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "RetrieveVectorStoreFileBatch error") + }) + + t.Run("list_vector_store_files_in_batch", func(t *testing.T) { + _, err := client.ListVectorStoreFilesInBatch( + ctx, + vectorStoreID, + vectorStoreFileBatchID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFilesInBatch error") + }) + + t.Run("cancel_vector_store_file_batch", func(t *testing.T) { + _, err := client.CancelVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + 
checks.NoError(t, err, "CancelVectorStoreFileBatch error") + }) +} From e31185974c45949cc58c24a6cbf5ca969fb0f622 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:06:52 +0100 Subject: [PATCH 024/129] remove errors.Join (#778) --- batch.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/batch.go b/batch.go index 4aba966bc..a43d401ab 100644 --- a/batch.go +++ b/batch.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -109,8 +108,6 @@ type BatchResponse struct { Batch } -var ErrUploadBatchFileFailed = errors.New("upload batch file failed") - // CreateBatch — API call to Create batch. func (c *Client) CreateBatch( ctx context.Context, @@ -202,7 +199,6 @@ func (c *Client) CreateBatchWithUploadFile( Lines: request.Lines, }) if err != nil { - err = errors.Join(ErrUploadBatchFileFailed, err) return } return c.CreateBatch(ctx, CreateBatchRequest{ From 03851d20327b7df5358ff9fb0ac96f476be1875a Mon Sep 17 00:00:00 2001 From: Adrian Liechti Date: Sun, 30 Jun 2024 17:20:10 +0200 Subject: [PATCH 025/129] allow custom voice and speech models (#691) --- speech.go | 31 ------------------------------- speech_test.go | 17 ----------------- 2 files changed, 48 deletions(-) diff --git a/speech.go b/speech.go index 7e22e755c..19b21bdf1 100644 --- a/speech.go +++ b/speech.go @@ -2,7 +2,6 @@ package openai import ( "context" - "errors" "net/http" ) @@ -36,11 +35,6 @@ const ( SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) -var ( - ErrInvalidSpeechModel = errors.New("invalid speech model") - ErrInvalidVoice = errors.New("invalid voice") -) - type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` @@ -49,32 +43,7 @@ type CreateSpeechRequest struct { Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } -func contains[T comparable](s []T, e T) bool { - for _, v := range s { - if v == e { - return true - } - } - 
return false -} - -func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) -} - -func isValidVoice(voice SpeechVoice) bool { - return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) -} - func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - if !isValidSpeechModel(request.Model) { - err = ErrInvalidSpeechModel - return - } - if !isValidVoice(request.Voice) { - err = ErrInvalidVoice - return - } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), withContentType("application/json"), diff --git a/speech_test.go b/speech_test.go index d9ba58b13..f1e405c39 100644 --- a/speech_test.go +++ b/speech_test.go @@ -95,21 +95,4 @@ func TestSpeechIntegration(t *testing.T) { err = os.WriteFile("test.mp3", buf, 0644) checks.NoError(t, err, "Create error") }) - t.Run("invalid model", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: "invalid_model", - Input: "Hello!", - Voice: openai.VoiceAlloy, - }) - checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") - }) - - t.Run("invalid voice", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: openai.TTSModel1, - Input: "Hello!", - Voice: "invalid_voice", - }) - checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") - }) } From 727944c47886924800128d1c33df706b4159eb23 Mon Sep 17 00:00:00 2001 From: Luca Giannini <68999840+LGXerxes@users.noreply.github.com> Date: Fri, 12 Jul 2024 12:31:11 +0200 Subject: [PATCH 026/129] feat: ParallelToolCalls to ChatCompletionRequest with helper functions (#787) * added ParallelToolCalls to ChatCompletionRequest with helper functions * added tests for coverage * changed ParallelToolCalls to any --- 
chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index a1eb11720..eb494f41f 100644 --- a/chat.go +++ b/chat.go @@ -218,6 +218,8 @@ type ChatCompletionRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // Options for streaming response. Only set this when you set stream: true. StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } type StreamOptions struct { From 3e47e6fef4ac861dd5e07f73a8fb240374e8cad3 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 19 Jul 2024 22:06:27 +0800 Subject: [PATCH 027/129] fix: #790 (#798) --- files.go | 1 + 1 file changed, 1 insertion(+) diff --git a/files.go b/files.go index 26ad6bd70..edc9f2a20 100644 --- a/files.go +++ b/files.go @@ -102,6 +102,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File if err != nil { return } + defer fileData.Close() err = builder.CreateFormFile("file", fileData) if err != nil { From 27c1c56f0b50a84740425f7534c46825e227b437 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Fri, 19 Jul 2024 07:06:51 -0700 Subject: [PATCH 028/129] feat: Add GPT-4o Mini model support (#796) --- completion.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/completion.go b/completion.go index 024f09b14..4ff1123c4 100644 --- a/completion.go +++ b/completion.go @@ -24,6 +24,8 @@ const ( GPT40314 = "gpt-4-0314" GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" @@ -89,6 +91,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4o: true, GPT4o20240513: true, + GPT4oMini: true, + GPT4oMini20240718: true, GPT4TurboPreview: true, 
GPT4VisionPreview: true, GPT4Turbo1106: true, From 92f483055f666847f7954e148b7f46771c5581b8 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 19 Jul 2024 22:10:17 +0800 Subject: [PATCH 029/129] fix: #794 (#797) --- client.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 7bc28e984..d5d555c3d 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" utils "github.com/sashabaranov/go-openai/internal" @@ -228,10 +229,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") + parseURL, _ := url.Parse(baseURL) + query := parseURL.Query() + query.Add("api-version", c.config.APIVersion) // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) } azureDeploymentName := "UNKNOWN" if len(args) > 0 { @@ -240,9 +244,9 @@ func (c *Client) fullURL(suffix string, args ...any) string { azureDeploymentName = c.config.GetAzureDeploymentByModel(model) } } - return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", + return fmt.Sprintf("%s/%s/%s/%s%s?%s", baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, c.config.APIVersion, + azureDeploymentName, suffix, query.Encode(), ) } From ae903d7465c4b48654fac6103472767ee4d95e41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edin=20=C4=86orali=C4=87?= <73831203+ecoralic@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:12:20 +0300 Subject: [PATCH 030/129] 
fix: Updated ThreadMessage struct with latest fields based on OpenAI docs (#792) * fix: Updated ThreadMessage struct with latest fields based on OpenAI docs * fix: Reverted FileIDs for backward compatibility of v1 --- thread.go | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/thread.go b/thread.go index 6f7521454..bc08e2bcb 100644 --- a/thread.go +++ b/thread.go @@ -83,14 +83,25 @@ type ModifyThreadRequest struct { type ThreadMessageRole string const ( - ThreadMessageRoleUser ThreadMessageRole = "user" + ThreadMessageRoleAssistant ThreadMessageRole = "assistant" + ThreadMessageRoleUser ThreadMessageRole = "user" ) type ThreadMessage struct { - Role ThreadMessageRole `json:"role"` - Content string `json:"content"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadAttachment struct { + FileID string `json:"file_id"` + Tools []ThreadAttachmentTool `json:"tools"` +} + +type ThreadAttachmentTool struct { + Type string `json:"type"` } type ThreadDeleteResponse struct { From a7e9f0e3880d1487fe8e06a43820f42046b5b622 Mon Sep 17 00:00:00 2001 From: Janusch Jacoby Date: Fri, 19 Jul 2024 16:13:02 +0200 Subject: [PATCH 031/129] add hyperparams (#793) --- fine_tuning_job.go | 4 +++- fine_tuning_job_test.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 9dcb49de1..5a9f54a92 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -26,7 +26,9 @@ type FineTuningJob struct { } type Hyperparameters struct { - Epochs any `json:"n_epochs,omitempty"` + Epochs any `json:"n_epochs,omitempty"` + LearningRateMultiplier any `json:"learning_rate_multiplier,omitempty"` + BatchSize 
any `json:"batch_size,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index d2fbcd4c7..5f63ef24c 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -33,7 +33,9 @@ func TestFineTuningJob(t *testing.T) { ValidationFile: "", TrainingFile: "file-abc123", Hyperparameters: openai.Hyperparameters{ - Epochs: "auto", + Epochs: "auto", + LearningRateMultiplier: "auto", + BatchSize: "auto", }, TrainedTokens: 5768, }) From 966ee682b11ca580c2c2c3ac067c27b51bd6d749 Mon Sep 17 00:00:00 2001 From: VanessaMae23 <60029664+Vanessamae23@users.noreply.github.com> Date: Fri, 19 Jul 2024 22:18:16 +0800 Subject: [PATCH 032/129] Add New Optional Parameters to `AssistantRequest` Struct (#795) * Add more parameters to support Assistant v2 * Add goimports --- assistant.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/assistant.go b/assistant.go index cc13a3020..4c89c1b2f 100644 --- a/assistant.go +++ b/assistant.go @@ -62,14 +62,17 @@ type AssistantToolResource struct { // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. 
type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases From 581da2f12d52617368bdfe2625f5b0ef1dd32758 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Mon, 29 Jul 2024 01:43:45 +0800 Subject: [PATCH 033/129] fix: #804 (#807) --- batch.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/batch.go b/batch.go index a43d401ab..3c1a9d0d7 100644 --- a/batch.go +++ b/batch.go @@ -65,7 +65,7 @@ type Batch struct { Endpoint BatchEndpoint `json:"endpoint"` Errors *struct { Object string `json:"object,omitempty"` - Data struct { + Data []struct { Code string `json:"code,omitempty"` Message string `json:"message,omitempty"` Param *string `json:"param,omitempty"` From dbe726c59f6df65965a4ee25e37706c33e391dc4 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Wed, 7 Aug 2024 20:21:38 +1000 Subject: [PATCH 034/129] Add support for `gpt-4o-2024-08-06` (#812) * feat: Add GPT-4o Mini model support * feat: Add GPT-4o-2024-08-06 
model support --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 4ff1123c4..d435eb382 100644 --- a/completion.go +++ b/completion.go @@ -24,6 +24,7 @@ const ( GPT40314 = "gpt-4-0314" GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" GPT4oMini = "gpt-4o-mini" GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" @@ -91,6 +92,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4o: true, GPT4o20240513: true, + GPT4o20240806: true, GPT4oMini: true, GPT4oMini20240718: true, GPT4TurboPreview: true, From 623074c14a110b97d9a7aac7896bbdccf335257f Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 7 Aug 2024 21:47:48 +0800 Subject: [PATCH 035/129] feat: Support Structured Outputs (#813) * feat: Support Structured Outputs * feat: Support Structured Outputs * update imports * add integration test * update JSON schema comments --- api_integration_test.go | 61 +++++++++++++++++++++++++++++++++++++++++ chat.go | 13 ++++++++- jsonschema/json.go | 8 +++++- 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index f34685188..a487f588a 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,6 +4,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "os" @@ -178,3 +179,63 @@ func TestAPIError(t *testing.T) { t.Fatal("Empty error message occurred") } } + +func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. 
Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{ + Name: "cases", + Schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": jsonschema.Definition{Type: jsonschema.String}, + "CamelCase": jsonschema.Definition{Type: jsonschema.String}, + "KebabCase": jsonschema.Definition{Type: jsonschema.String}, + "SnakeCase": jsonschema.Definition{Type: jsonschema.String}, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + Strict: true, + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") + for _, key := range []string{"PascalCase", "CamelCase", 
"KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/chat.go b/chat.go index eb494f41f..8bfe558b5 100644 --- a/chat.go +++ b/chat.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "net/http" + + "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -175,11 +177,20 @@ type ChatCompletionResponseFormatType string const ( ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type,omitempty"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` +} + +type ChatCompletionResponseFormatJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema jsonschema.Definition `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. diff --git a/jsonschema/json.go b/jsonschema/json.go index cb941eb75..7fd1e11bf 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -29,11 +29,17 @@ type Definition struct { // one element, where each element is unique. You will probably only use this with strings. Enum []string `json:"enum,omitempty"` // Properties describes the properties of an object, if the schema type is Object. - Properties map[string]Definition `json:"properties"` + Properties map[string]Definition `json:"properties,omitempty"` // Required specifies which properties are required, if the schema type is Object. Required []string `json:"required,omitempty"` // Items specifies which data type an array contains, if the schema type is Array. 
Items *Definition `json:"items,omitempty"` + // AdditionalProperties is used to control the handling of properties in an object + // that are not explicitly defined in the properties section of the schema. example: + // additionalProperties: true + // additionalProperties: false + // additionalProperties: jsonschema.Definition{Type: jsonschema.String} + AdditionalProperties any `json:"additionalProperties,omitempty"` } func (d Definition) MarshalJSON() ([]byte, error) { From 6439e1fcc93fc5175accf5d51358e45fa5ea9099 Mon Sep 17 00:00:00 2001 From: Tyler Gannon Date: Wed, 7 Aug 2024 12:40:45 -0700 Subject: [PATCH 036/129] Make reponse format JSONSchema optional (#820) --- chat.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index 8bfe558b5..31fa887d6 100644 --- a/chat.go +++ b/chat.go @@ -182,8 +182,8 @@ const ( ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type,omitempty"` - JSONSchema ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` } type ChatCompletionResponseFormatJSONSchema struct { From 18803333812ea21c409e84d426141606b9a6e692 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Fri, 9 Aug 2024 18:30:32 +0200 Subject: [PATCH 037/129] Run integration tests for PRs (#823) * Unbreak integration tests * Update integration-tests.yml --- api_integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_integration_test.go b/api_integration_test.go index a487f588a..3084268e6 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -211,7 +211,7 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { }, ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, - JSONSchema: 
openai.ChatCompletionResponseFormatJSONSchema{ + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ Name: "cases", Schema: jsonschema.Definition{ Type: jsonschema.Object, From 2c6889e0818b93c4fd724d9528b610896f5e9421 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sun, 11 Aug 2024 05:05:06 +0800 Subject: [PATCH 038/129] fix: #788 (#800) --- completion.go | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/completion.go b/completion.go index d435eb382..bc2a63795 100644 --- a/completion.go +++ b/completion.go @@ -138,25 +138,26 @@ func checkPromptType(prompt any) bool { // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { - Model string `json:"model"` - Prompt any `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - BestOf int `json:"best_of,omitempty"` + Model string `json:"model"` + Prompt any `json:"prompt,omitempty"` + BestOf int `json:"best_of,omitempty"` + Echo bool `json:"echo,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. 
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. From dd7f5824f9a4c3860cccfaf8350d5d09e864038f Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sat, 17 Aug 2024 01:11:38 +0800 Subject: [PATCH 039/129] fix: fullURL endpoint generation (#817) --- api_internal_test.go | 24 ++++++++--- audio.go | 9 ++++- chat.go | 7 +++- chat_stream.go | 7 +++- client.go | 84 ++++++++++++++++++++++++-------------- client_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++++ completion.go | 7 +++- edits.go | 7 +++- embeddings.go | 7 +++- example_test.go | 2 +- image.go | 25 +++++++++--- moderation.go | 7 +++- speech.go | 5 ++- stream.go | 8 +++- 14 files changed, 244 insertions(+), 51 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index a590ec9ab..09677968a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -112,6 +112,7 @@ func TestAzureFullURL(t *testing.T) { Name string BaseURL string AzureModelMapper map[string]string + Suffix string Model string Expect string }{ @@ -119,6 +120,7 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", 
"https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -128,11 +130,20 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithoutSlashOK", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + "/chat/completions?api-version=2023-05-15", }, + { + "", + "https://httpbin.org/", + nil, + "/assistants?limit=10", + "chatgpt-demo", + "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", + }, } for _, c := range cases { @@ -140,7 +151,7 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions", c.Model) + actual := cli.fullURL(c.Suffix, withModel(c.Model)) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } @@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) { cases := []struct { Name string BaseURL string + Suffix string Expect string }{ { "CloudflareAzureBaseURLWithSlashAutoStrip", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "/chat/completions", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, { - "CloudflareAzureBaseURLWithoutSlashOK", + "", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + - "chat/completions?api-version=2023-05-15", + "/assistants?limit=10", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + + 
"/assistants?api-version=2023-05-15&limit=10", }, } @@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/audio.go b/audio.go index dbc26d154..f321f93d6 100644 --- a/audio.go +++ b/audio.go @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), - withBody(&formBody), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(&formBody), + withContentType(builder.FormDataContentType()), + ) if err != nil { return AudioResponse{}, err } diff --git a/chat.go b/chat.go index 31fa887d6..826fd3bd5 100644 --- a/chat.go +++ b/chat.go @@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index ffd512ff6..3f90bc019 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/client.go b/client.go index d5d555c3d..9f547e7cb 100644 --- a/client.go +++ b/client.go @@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error { return nil } +type 
fullURLOptions struct { + model string +} + +type fullURLOption func(*fullURLOptions) + +func withModel(model string) fullURLOption { + return func(args *fullURLOptions) { + args.model = model + } +} + +var azureDeploymentsEndpoints = []string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", +} + // fullURL returns full URL for request. -// args[0] is model name, if API type is Azure, model name is required to get deployment name. -func (c *Client) fullURL(suffix string, args ...any) string { - // /openai/deployments/{model}/chat/completions?api-version={api_version} +func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) + } + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - parseURL, _ := url.Parse(baseURL) - query := parseURL.Query() - query.Add("api-version", c.config.APIVersion) - // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 - // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) - } - azureDeploymentName := "UNKNOWN" - if len(args) > 0 { - model, ok := args[0].(string) - if ok { - azureDeploymentName = c.config.GetAzureDeploymentByModel(model) - } - } - return fmt.Sprintf("%s/%s/%s/%s%s?%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, query.Encode(), - ) + baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model) + } + + if c.config.APIVersion != "" { + suffix = c.suffixWithAPIVersion(suffix) } + return 
fmt.Sprintf("%s%s", baseURL, suffix) +} - // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ - if c.config.APIType == APITypeCloudflareAzure { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) +func (c *Client) suffixWithAPIVersion(suffix string) string { + parsedSuffix, err := url.Parse(suffix) + if err != nil { + panic("failed to parse url suffix") } + query := parsedSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode()) +} - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) +func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { + baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + azureDeploymentName := c.config.GetAzureDeploymentByModel(model) + if azureDeploymentName == "" { + azureDeploymentName = "UNKNOWN" + } + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) + } + return baseURL } func (c *Client) handleErrorResp(resp *http.Response) error { diff --git a/client_test.go b/client_test.go index e49da9b3d..a0d3bb390 100644 --- a/client_test.go +++ b/client_test.go @@ -431,3 +431,99 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestClient_suffixWithAPIVersion(t *testing.T) { + type fields struct { + apiVersion string + } + type args struct { + suffix string + } + tests := []struct { + name string + fields fields + args args + want string + wantPanic string + }{ + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants"}, + "/assistants?api-version=2023-05", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants?limit=5"}, + 
"/assistants?api-version=2023-05&limit=5", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "123:assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "failed to parse url suffix", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + config: ClientConfig{APIVersion: tt.fields.apiVersion}, + } + defer func() { + if r := recover(); r != nil { + if r.(string) != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic) + } + } + }() + if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want { + t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_baseURLWithAzureDeployment(t *testing.T) { + type args struct { + baseURL string + suffix string + model string + } + tests := []struct { + name string + args args + wantNewBaseURL string + }{ + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai/deployments/gpt-4o-mini", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""}, + "https://test.openai.azure.com/openai/deployments/UNKNOWN", + }, + } + client := NewClient("") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotNewBaseURL := client.baseURLWithAzureDeployment( + tt.args.baseURL, + tt.args.suffix, + tt.args.model, + ); gotNewBaseURL != tt.wantNewBaseURL { + t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL) + } + }) + } +} diff --git a/completion.go b/completion.go index bc2a63795..e8e9242c9 100644 --- a/completion.go +++ b/completion.go @@ -213,7 +213,12 @@ func (c *Client) CreateCompletion( return } - req, 
err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/edits.go b/edits.go index 97d026029..fe8ecd0c1 100644 --- a/edits.go +++ b/edits.go @@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024. You can use CreateChatCompletion or CreateChatCompletionStream instead. */ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/edits", withModel(fmt.Sprint(request.Model))), + withBody(request), + ) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index b513ba6a7..74eb8aa57 100644 --- a/embeddings.go +++ b/embeddings.go @@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/embeddings", withModel(string(baseReq.Model))), + withBody(baseReq), + ) if err != nil { return } diff --git a/example_test.go b/example_test.go index de67c57cd..1bdb8496e 100644 --- a/example_test.go +++ b/example_test.go @@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() { return } - fmt.Printf(response.Choices[0].Delta.Content) + fmt.Println(response.Choices[0].Delta.Content) } } diff --git a/image.go b/image.go index 665de1a74..577d7db95 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,12 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. 
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } @@ -132,8 +137,13 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/edits", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } @@ -183,8 +193,13 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/variations", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } diff --git a/moderation.go b/moderation.go index ae285ef83..c8652efc8 100644 --- a/moderation.go +++ b/moderation.go @@ -88,7 +88,12 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re err = ErrModerationInvalidModel return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/moderations", withModel(request.Model)), + withBody(&request), + ) if err != nil { return } diff --git a/speech.go b/speech.go index 19b21bdf1..20b52e334 100644 --- a/speech.go +++ b/speech.go @@ 
-44,7 +44,10 @@ type CreateSpeechRequest struct { } func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/audio/speech", withModel(string(request.Model))), withBody(request), withContentType("application/json"), ) diff --git a/stream.go b/stream.go index b277f3c29..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -3,6 +3,7 @@ package openai import ( "context" "errors" + "net/http" ) var ( @@ -33,7 +34,12 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } From d86425a5cfd09bb76fe2f9239a03a9dbcdca8a9c Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Fri, 16 Aug 2024 13:41:39 -0400 Subject: [PATCH 040/129] Allow structured outputs via function calling (#828) --- api_integration_test.go | 76 +++++++++++++++++++++++++++++++++++++++++ chat.go | 1 + chat_test.go | 26 ++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/api_integration_test.go b/api_integration_test.go index 3084268e6..57f7c40fb 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -239,3 +239,79 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { } } } + +func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. 
Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: "display_cases", + Strict: true, + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": { + Type: jsonschema.String, + }, + "CamelCase": { + Type: jsonschema.String, + }, + "KebabCase": { + Type: jsonschema.String, + }, + "SnakeCase": { + Type: jsonschema.String, + }, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + }, + }, + }, + ToolChoice: openai.ToolChoice{ + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{ + Name: "display_cases", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error") + 
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/chat.go b/chat.go index 826fd3bd5..97c89a497 100644 --- a/chat.go +++ b/chat.go @@ -264,6 +264,7 @@ type ToolFunction struct { type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` // Parameters is an object describing the function. // You can pass json.RawMessage to describe the schema, // or you can pass in a struct which serializes to the proper JSON schema. diff --git a/chat_test.go b/chat_test.go index 520bf5ca4..37dc09d4d 100644 --- a/chat_test.go +++ b/chat_test.go @@ -277,6 +277,32 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) + t.Run("StructuredOutputs", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Strict: true, + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) } func TestAzureChatCompletions(t *testing.T) { From 6d021190f05410a44d9401984815c55f4736b755 Mon Sep 17 00:00:00 2001 From: Yamagami ken-ichi Date: Thu, 22 Aug 2024 23:27:44 +0900 Subject: [PATCH 041/129] feat: Support Delete Message API (#799) * feat: Add DeleteMessage function to API client * fix: linter nolint : Deprecated method split function: cognitive complexity 21 * rename func name for unit-test --- client_test.go | 3 +++ 
fine_tunes.go | 2 +- messages.go | 24 ++++++++++++++++++++++++ messages_test.go | 36 +++++++++++++++++++++++++++++++----- 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/client_test.go b/client_test.go index a0d3bb390..7119d8a7e 100644 --- a/client_test.go +++ b/client_test.go @@ -348,6 +348,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ModifyMessage", func() (any, error) { return client.ModifyMessage(ctx, "", "", nil) }}, + {"DeleteMessage", func() (any, error) { + return client.DeleteMessage(ctx, "", "") + }}, {"RetrieveMessageFile", func() (any, error) { return client.RetrieveMessageFile(ctx, "", "", "") }}, diff --git a/fine_tunes.go b/fine_tunes.go index ca840781c..74b47bf3f 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,7 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) //nolint:lll //this method is deprecated if err != nil { return } diff --git a/messages.go b/messages.go index 6af118445..1fddd6314 100644 --- a/messages.go +++ b/messages.go @@ -73,6 +73,14 @@ type MessageFilesList struct { httpHeader } +type MessageDeletionStatus struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + // CreateMessage creates a new message. 
func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) @@ -186,3 +194,19 @@ func (c *Client) ListMessageFiles( err = c.sendRequest(req, &files) return } + +// DeleteMessage deletes a message.. +func (c *Client) DeleteMessage( + ctx context.Context, + threadID, messageID string, +) (status MessageDeletionStatus, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &status) + return +} diff --git a/messages_test.go b/messages_test.go index a18be20bd..71ceb4d3a 100644 --- a/messages_test.go +++ b/messages_test.go @@ -8,20 +8,17 @@ import ( "testing" "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) var emptyStr = "" -// TestMessages Tests the messages endpoint of the API using the mocked server. 
-func TestMessages(t *testing.T) { +func setupServerForTestMessage(t *testing.T, server *test.ServerTest) { threadID := "thread_abc123" messageID := "msg_abc123" fileID := "file_abc123" - client, server, teardown := setupOpenAITestServer() - defer teardown() - server.RegisterHandler( "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, func(w http.ResponseWriter, r *http.Request) { @@ -115,6 +112,13 @@ func TestMessages(t *testing.T) { Metadata: nil, }) fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + resBytes, _ := json.Marshal(openai.MessageDeletionStatus{ + ID: messageID, + Object: "thread.message.deleted", + Deleted: true, + }) + fmt.Fprintln(w, string(resBytes)) default: t.Fatalf("unsupported messages http method: %s", r.Method) } @@ -176,7 +180,18 @@ func TestMessages(t *testing.T) { } }, ) +} +// TestMessages Tests the messages endpoint of the API using the mocked server. +func TestMessages(t *testing.T) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + setupServerForTestMessage(t, server) ctx := context.Background() // static assertion of return type @@ -225,6 +240,17 @@ func TestMessages(t *testing.T) { t.Fatalf("expected message metadata to get modified") } + msgDel, err := client.DeleteMessage(ctx, threadID, messageID) + checks.NoError(t, err, "DeleteMessage error") + if msgDel.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + if !msgDel.Deleted { + t.Fatalf("expected deleted is true") + } + _, err = client.DeleteMessage(ctx, threadID, "not_exist_id") + checks.HasError(t, err, "DeleteMessage error") + // message files var msgFile openai.MessageFile msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) From 5162adbbf90cef77b8462c1f33c81f7d258a1447 Mon Sep 17 00:00:00 2001 From: Alexey Michurin Date: Fri, 23 Aug 2024 13:47:11 +0300 Subject: [PATCH 042/129] Support http 
client middlewareing (#830) --- config.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 1347567d7..8a9183558 100644 --- a/config.go +++ b/config.go @@ -26,6 +26,10 @@ const AzureAPIKeyHeader = "api-key" const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store +type HTTPDoer interface { + Do(req *http.Request) (*http.Response, error) +} + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string @@ -36,7 +40,7 @@ type ClientConfig struct { APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func - HTTPClient *http.Client + HTTPClient HTTPDoer EmptyMessagesLimit uint } From a3bd2569ac51f1c54d704ec80dcbb91ab9f46acf Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sun, 25 Aug 2024 01:06:08 +0800 Subject: [PATCH 043/129] Improve handling of JSON Schema in OpenAI API Response Context (#819) * feat: add jsonschema.Validate and jsonschema.Unmarshal * fix Sanity check * remove slices.Contains * fix Sanity check * add SchemaWrapper * update api_integration_test.go * update method 'reflectSchema' to support 'omitempty' in JSON tag * add GenerateSchemaForType * update json_test.go * update `Warp` to `Wrap` * fix Sanity check * fix Sanity check * update api_internal_test.go * update README.md * update README.md * remove jsonschema.SchemaWrapper * remove jsonschema.SchemaWrapper * fix Sanity check * optimize code formatting --- README.md | 64 +++++++++++++++++ api_integration_test.go | 36 +++++----- chat.go | 10 ++- example_test.go | 2 +- jsonschema/json.go | 105 +++++++++++++++++++++++++++- jsonschema/validate.go | 89 +++++++++++++++++++++++ jsonschema/validate_test.go | 136 ++++++++++++++++++++++++++++++++++++ 7 files changed, 412 insertions(+), 30 deletions(-) create mode 100644 jsonschema/validate.go create mode 100644 
jsonschema/validate_test.go diff --git a/README.md b/README.md index 799dc602b..0d6aafa40 100644 --- a/README.md +++ b/README.md @@ -743,6 +743,70 @@ func main() { } ``` + +
+Structured Outputs + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + type Result struct { + Steps []struct { + Explanation string `json:"explanation"` + Output string `json:"output"` + } `json:"steps"` + FinalAnswer string `json:"final_answer"` + } + var result Result + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + log.Fatalf("GenerateSchemaForType error: %v", err) + } + resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "You are a helpful math tutor. Guide the user through the solution step by step.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "how can I solve 8x + 7 = -23", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "math_reasoning", + Schema: schema, + Strict: true, + }, + }, + }) + if err != nil { + log.Fatalf("CreateChatCompletion error: %v", err) + } + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + if err != nil { + log.Fatalf("Unmarshal schema error: %v", err) + } + fmt.Println(result) +} +``` +
See the `examples/` folder for more. ## Frequently Asked Questions diff --git a/api_integration_test.go b/api_integration_test.go index 57f7c40fb..8c9f3384f 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,7 +4,6 @@ package openai_test import ( "context" - "encoding/json" "errors" "io" "os" @@ -190,6 +189,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { c := openai.NewClient(apiToken) ctx := context.Background() + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + var result MyStructuredResponse + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") + } resp, err := c.CreateChatCompletion( ctx, openai.ChatCompletionRequest{ @@ -212,31 +222,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ - Name: "cases", - Schema: jsonschema.Definition{ - Type: jsonschema.Object, - Properties: map[string]jsonschema.Definition{ - "PascalCase": jsonschema.Definition{Type: jsonschema.String}, - "CamelCase": jsonschema.Definition{Type: jsonschema.String}, - "KebabCase": jsonschema.Definition{Type: jsonschema.String}, - "SnakeCase": jsonschema.Definition{Type: jsonschema.String}, - }, - Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, - AdditionalProperties: false, - }, + Name: "cases", + Schema: schema, Strict: true, }, }, }, ) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") - 
var result = make(map[string]string) - err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result) - checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") - for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { - if _, ok := result[key]; !ok { - t.Errorf("key:%s does not exist.", key) - } + if err == nil { + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") } } diff --git a/chat.go b/chat.go index 97c89a497..56e99a78b 100644 --- a/chat.go +++ b/chat.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "net/http" - - "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct { } type ChatCompletionResponseFormatJSONSchema struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Schema jsonschema.Definition `json:"schema"` - Strict bool `json:"strict"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. diff --git a/example_test.go b/example_test.go index 1bdb8496e..e5dbf44bf 100644 --- a/example_test.go +++ b/example_test.go @@ -59,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() { } defer stream.Close() - fmt.Printf("Stream response: ") + fmt.Print("Stream response: ") for { var response openai.ChatCompletionStreamResponse response, err = stream.Recv() diff --git a/jsonschema/json.go b/jsonschema/json.go index 7fd1e11bf..bcb253fae 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,7 +4,13 @@ // and/or pass in the schema in []byte format. 
package jsonschema -import "encoding/json" +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) type DataType string @@ -42,7 +48,7 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` } -func (d Definition) MarshalJSON() ([]byte, error) { +func (d *Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) } @@ -50,6 +56,99 @@ func (d Definition) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Alias }{ - Alias: (Alias)(d), + Alias: (Alias)(*d), }) } + +func (d *Definition) Unmarshal(content string, v any) error { + return VerifySchemaAndUnmarshal(*d, []byte(content), v) +} + +func GenerateSchemaForType(v any) (*Definition, error) { + return reflectSchema(reflect.TypeOf(v)) +} + +func reflectSchema(t reflect.Type) (*Definition, error) { + var d Definition + switch t.Kind() { + case reflect.String: + d.Type = String + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + d.Type = Integer + case reflect.Float32, reflect.Float64: + d.Type = Number + case reflect.Bool: + d.Type = Boolean + case reflect.Slice, reflect.Array: + d.Type = Array + items, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d.Items = items + case reflect.Struct: + d.Type = Object + d.AdditionalProperties = false + object, err := reflectSchemaObject(t) + if err != nil { + return nil, err + } + d = *object + case reflect.Ptr: + definition, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d = *definition + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.UnsafePointer: + return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) + default: + } + return &d, nil +} + +func reflectSchemaObject(t reflect.Type) 
(*Definition, error) { + var d = Definition{ + Type: Object, + AdditionalProperties: false, + } + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + var required = true + if jsonTag == "" { + jsonTag = field.Name + } else if strings.HasSuffix(jsonTag, ",omitempty") { + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false + } + + item, err := reflectSchema(field.Type) + if err != nil { + return nil, err + } + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + properties[jsonTag] = *item + + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } + if required { + requiredFields = append(requiredFields, jsonTag) + } + } + d.Required = requiredFields + d.Properties = properties + return &d, nil +} diff --git a/jsonschema/validate.go b/jsonschema/validate.go new file mode 100644 index 000000000..f14ffd4c4 --- /dev/null +++ b/jsonschema/validate.go @@ -0,0 +1,89 @@ +package jsonschema + +import ( + "encoding/json" + "errors" +) + +func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { + var data any + err := json.Unmarshal(content, &data) + if err != nil { + return err + } + if !Validate(schema, data) { + return errors.New("data validation failed against the provided schema") + } + return json.Unmarshal(content, &v) +} + +func Validate(schema Definition, data any) bool { + switch schema.Type { + case Object: + return validateObject(schema, data) + case Array: + return validateArray(schema, data) + case String: + _, ok := data.(string) + return ok + case Number: // float64 and int + _, ok := data.(float64) + if !ok { + _, ok = data.(int) + } + return ok + case Boolean: + _, ok := data.(bool) + return ok + case Integer: + _, ok := data.(int) + return ok + case Null: + return data == nil 
+ default: + return false + } +} + +func validateObject(schema Definition, data any) bool { + dataMap, ok := data.(map[string]any) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value) { + return false + } else if !exists && contains(schema.Required, key) { + return false + } + } + return true +} + +func validateArray(schema Definition, data any) bool { + dataArray, ok := data.([]any) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item) { + return false + } + } + return true +} + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go new file mode 100644 index 000000000..c2c47a2ce --- /dev/null +++ b/jsonschema/validate_test.go @@ -0,0 +1,136 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +func Test_Validate(t *testing.T) { + type args struct { + data any + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want bool + }{ + // string integer number boolean + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.String}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.String}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Integer}}, true}, + {"", args{data: 123.4, schema: jsonschema.Definition{Type: jsonschema.Integer}}, false}, + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.Number}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Number}}, true}, + {"", args{data: false, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, true}, + 
{"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, false}, + {"", args{data: nil, schema: jsonschema.Definition{Type: jsonschema.Null}}, true}, + {"", args{data: 0, schema: jsonschema.Definition{Type: jsonschema.Null}}, false}, + // array + {"", args{data: []any{"a", "b", "c"}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, true}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, false}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, true}, + {"", args{data: []any{1, 2, 3.4}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, false}, + // object + {"", args{data: map[string]any{ + "string": "abc", + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, true}, + {"", args{data: map[string]any{ + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, 
false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := jsonschema.Validate(tt.args.schema, tt.args.data); got != tt.want { + t.Errorf("Validate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshal(t *testing.T) { + type args struct { + schema jsonschema.Definition + content []byte + v any + } + var result1 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + var result2 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + }, + content: []byte(`{"string":"abc","number":123.4}`), + v: &result1, + }, false}, + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + Required: []string{"string", "number"}, + }, + content: []byte(`{"string":"abc"}`), + v: result2, + }, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil { + t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +} From 030b7cb7ed60fc4a8b2fd608f538c470b65b1131 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 24 Aug 2024 18:11:27 +0100 Subject: [PATCH 044/129] fix integration tests (#834) --- api_integration_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/api_integration_test.go b/api_integration_test.go index 8c9f3384f..7828d9451 100644 --- a/api_integration_test.go +++ 
b/api_integration_test.go @@ -4,6 +4,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "os" From c37cf9ab5b887fe0195d3cc6240780e9b1928a04 Mon Sep 17 00:00:00 2001 From: Tommy Mathisen Date: Sun, 1 Sep 2024 18:30:29 +0300 Subject: [PATCH 045/129] Dynamic model (#838) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index e8e9242c9..12ce4b558 100644 --- a/completion.go +++ b/completion.go @@ -25,6 +25,7 @@ const ( GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4oLatest = "chatgpt-4o-latest" GPT4oMini = "gpt-4o-mini" GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" @@ -93,6 +94,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4o: true, GPT4o20240513: true, GPT4o20240806: true, + GPT4oLatest: true, GPT4oMini: true, GPT4oMini20240718: true, GPT4TurboPreview: true, From 643da8d650b1f7db4706076a53b9d0acddccbd17 Mon Sep 17 00:00:00 2001 From: Arun Das <89579096+Arundas666@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:19:57 +0530 Subject: [PATCH 046/129] depricated model GPT3Ada changed to GPT3Babbage002 (#843) * depricated model GPT3Ada changed to GPT3Babbage002 * Delete test.mp3 --- README.md | 4 ++-- example_test.go | 4 ++-- examples/completion/main.go | 2 +- stream_test.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 0d6aafa40..b3ebc1471 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", } @@ -174,7 +174,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, diff --git a/example_test.go b/example_test.go index e5dbf44bf..5910ffb84 100644 --- 
a/example_test.go +++ b/example_test.go @@ -82,7 +82,7 @@ func ExampleClient_CreateCompletion() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, @@ -99,7 +99,7 @@ func ExampleClient_CreateCompletionStream() { stream, err := client.CreateCompletionStream( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, diff --git a/examples/completion/main.go b/examples/completion/main.go index 22af1fd82..8c5cbd5ca 100644 --- a/examples/completion/main.go +++ b/examples/completion/main.go @@ -13,7 +13,7 @@ func main() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, diff --git a/stream_test.go b/stream_test.go index 2822a3535..9dd95bb5f 100644 --- a/stream_test.go +++ b/stream_test.go @@ -169,7 +169,7 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { var apiErr *openai.APIError _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, Prompt: "Hello!", Stream: true, }) From 194a03e763f0d71333a6088bf613a35f65c50447 Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Wed, 11 Sep 2024 22:24:49 +0200 Subject: [PATCH 047/129] Add refusal (#844) * add custom marshaller, documentation and isolate tests * fix linter * add missing field --- chat.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/chat.go b/chat.go index 56e99a78b..dc60f35b9 100644 --- a/chat.go +++ b/chat.go @@ -82,6 +82,7 @@ type ChatMessagePart struct { type ChatCompletionMessage struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart // This 
property isn't in the official documentation, but it's in @@ -107,6 +108,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { msg := struct { Role string `json:"role"` Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"content,omitempty"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` @@ -115,9 +117,11 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { }(m) return json.Marshal(msg) } + msg := struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"-"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` @@ -131,12 +135,14 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"` }{} + if err := json.Unmarshal(bs, &msg); err == nil { *m = ChatCompletionMessage(msg) return nil @@ -144,6 +150,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { multiMsg := struct { Role string `json:"role"` Content string + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"content"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` From a5fb55321b43aa6b31bb3ff57d43cb5a8f2e17ef Mon Sep 17 00:00:00 2001 From: Aaron Batilo Date: Tue, 17 Sep 2024 14:19:47 -0600 Subject: [PATCH 048/129] Support OpenAI reasoning models (#850) These model strings are now available for use. 
More info: https://openai.com/index/introducing-openai-o1-preview/ https://platform.openai.com/docs/guides/reasoning --- completion.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/completion.go b/completion.go index 12ce4b558..e1e065a8b 100644 --- a/completion.go +++ b/completion.go @@ -17,6 +17,10 @@ var ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( + O1Mini = "o1-mini" + O1Mini20240912 = "o1-mini-2024-09-12" + O1Preview = "o1-preview" + O1Preview20240912 = "o1-preview-2024-09-12" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -83,6 +87,10 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { + O1Mini: true, + O1Mini20240912: true, + O1Preview: true, + O1Preview20240912: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, From 1ec8c24ea7ae0e31d5e8332f8a0349d2ecd5b913 Mon Sep 17 00:00:00 2001 From: Wei-An Yen Date: Sat, 21 Sep 2024 02:22:01 +0800 Subject: [PATCH 049/129] fix: jsonschema integer validation (#852) --- jsonschema/validate.go | 4 ++++ jsonschema/validate_test.go | 48 +++++++++++++++++++++++++++++-------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/jsonschema/validate.go b/jsonschema/validate.go index f14ffd4c4..49f9b8859 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -36,6 +36,10 @@ func Validate(schema Definition, data any) bool { _, ok := data.(bool) return ok case Integer: + // Golang unmarshals all numbers as float64, so we need to check if the float64 is an integer + if num, ok := data.(float64); ok { + return num == float64(int64(num)) + } _, ok := data.(int) return ok case Null: diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index c2c47a2ce..6fa30ab0c 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -86,14 +86,6 @@ func TestUnmarshal(t *testing.T) { content 
[]byte v any } - var result1 struct { - String string `json:"string"` - Number float64 `json:"number"` - } - var result2 struct { - String string `json:"string"` - Number float64 `json:"number"` - } tests := []struct { name string args args @@ -108,7 +100,10 @@ func TestUnmarshal(t *testing.T) { }, }, content: []byte(`{"string":"abc","number":123.4}`), - v: &result1, + v: &struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, }, false}, {"", args{ schema: jsonschema.Definition{ @@ -120,7 +115,40 @@ func TestUnmarshal(t *testing.T) { Required: []string{"string", "number"}, }, content: []byte(`{"string":"abc"}`), - v: result2, + v: struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, + }, true}, + {"validate integer", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, + }, false}, + {"validate integer failed", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123.4}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, }, true}, } for _, tt := range tests { From 9add1c348607c14e8fde9966713c97f9a2351919 Mon Sep 17 00:00:00 2001 From: Ivan Timofeev Date: Fri, 20 Sep 2024 23:40:24 +0300 Subject: [PATCH 050/129] add max_completions_tokens for o1 series models (#857) * add max_completions_tokens for o1 series models * add validation for o1 series models validataion + beta limitations --- chat.go | 35 +++++--- 
chat_stream.go | 4 + chat_stream_test.go | 21 +++++ chat_test.go | 211 ++++++++++++++++++++++++++++++++++++++++++++ completion.go | 82 +++++++++++++++++ 5 files changed, 341 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index dc60f35b9..d47c95e4f 100644 --- a/chat.go +++ b/chat.go @@ -200,18 +200,25 @@ type ChatCompletionResponseFormatJSONSchema struct { // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + // MaxTokens The maximum number of tokens that can be generated in the chat completion. + // This value can be used to control costs for text generated via API. + // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. 
+ // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens + MaxTokens int `json:"max_tokens,omitempty"` + // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, + // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning + MaxCompletionsTokens int `json:"max_completions_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias @@ -364,6 +371,10 @@ func (c *Client) CreateChatCompletion( return } + if err = validateRequestForO1Models(request); err != nil { + return + } + req, err := c.newRequest( ctx, http.MethodPost, diff --git a/chat_stream.go b/chat_stream.go index 3f90bc019..f43d01834 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,6 +60,10 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true + if err = validateRequestForO1Models(request); err != nil { + return + } + req, err := c.newRequest( ctx, http.MethodPost, diff --git a/chat_stream_test.go b/chat_stream_test.go index 63e45ee23..2e7c99b45 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -36,6 +36,27 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { } } +func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) { + config := 
openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1/chat/completions" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + req := openai.ChatCompletionRequest{ + Model: openai.O1Preview, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } + _, err := client.CreateChatCompletionStream(ctx, req) + if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) { + t.Fatalf("CreateChatCompletion should return ErrO1BetaLimitationsStreaming, but returned: %v", err) + } +} + func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/chat_test.go b/chat_test.go index 37dc09d4d..a54dd35e0 100644 --- a/chat_test.go +++ b/chat_test.go @@ -52,6 +52,199 @@ func TestChatCompletionsWrongModel(t *testing.T) { checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } +func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "o1-preview_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Preview, + }, + expectedError: openai.ErrO1MaxTokensDeprecated, + }, + { + name: "o1-mini_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Mini, + }, + expectedError: openai.ErrO1MaxTokensDeprecated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func 
TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + LogProbs: true, + Model: openai.O1Preview, + }, + expectedError: openai.ErrO1BetaLimitationsLogprobs, + }, + { + name: "message_type_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + }, + }, + }, + expectedError: openai.ErrO1BetaLimitationsMessageTypes, + }, + { + name: "tool_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + }, + }, + }, + expectedError: openai.ErrO1BetaLimitationsTools, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: 
openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + func TestChatRequestOmitEmpty(t *testing.T) { data, err := json.Marshal(openai.ChatCompletionRequest{ // We set model b/c it's required, so omitempty doesn't make sense @@ -97,6 +290,24 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. 
+func TestO1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O1Preview, + MaxCompletionsTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() diff --git a/completion.go b/completion.go index e1e065a8b..8e3172ace 100644 --- a/completion.go +++ b/completion.go @@ -7,11 +7,20 @@ import ( ) var ( + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionsTokens") //nolint:lll ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll ) +var ( + ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll + ErrO1BetaLimitationsStreaming = errors.New("this model has beta-limitations, streaming not supported") //nolint:lll + ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll + ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") 
//nolint:lll + ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + // GPT3 Defines the models provided by OpenAI to use when generating // completions from OpenAI. // GPT3 Models are designed for text-based tasks. For code-specific @@ -85,6 +94,15 @@ const ( CodexCodeDavinci001 = "code-davinci-001" ) +// O1SeriesModels List of new Series of OpenAI models. +// Some old api attributes not supported. +var O1SeriesModels = map[string]struct{}{ + O1Mini: {}, + O1Mini20240912: {}, + O1Preview: {}, + O1Preview20240912: {}, +} + var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { O1Mini: true, @@ -146,6 +164,70 @@ func checkPromptType(prompt any) bool { return isString || isStringSlice } +var unsupportedToolsForO1Models = map[ToolType]struct{}{ + ToolTypeFunction: {}, +} + +var availableMessageRoleForO1Models = map[string]struct{}{ + ChatMessageRoleUser: {}, + ChatMessageRoleAssistant: {}, +} + +// validateRequestForO1Models checks for deprecated fields of OpenAI models. +func validateRequestForO1Models(request ChatCompletionRequest) error { + if _, found := O1SeriesModels[request.Model]; !found { + return nil + } + + if request.MaxTokens > 0 { + return ErrO1MaxTokensDeprecated + } + + // Beta Limitations + // refs:https://platform.openai.com/docs/guides/reasoning/beta-limitations + // Streaming: not supported + if request.Stream { + return ErrO1BetaLimitationsStreaming + } + // Logprobs: not supported. + if request.LogProbs { + return ErrO1BetaLimitationsLogprobs + } + + // Message types: user and assistant messages only, system messages are not supported. 
+ for _, m := range request.Messages { + if _, found := availableMessageRoleForO1Models[m.Role]; !found { + return ErrO1BetaLimitationsMessageTypes + } + } + + // Tools: tools, function calling, and response format parameters are not supported + for _, t := range request.Tools { + if _, found := unsupportedToolsForO1Models[t.Type]; found { + return ErrO1BetaLimitationsTools + } + } + + // Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0. + if request.Temperature > 0 && request.Temperature != 1 { + return ErrO1BetaLimitationsOther + } + if request.TopP > 0 && request.TopP != 1 { + return ErrO1BetaLimitationsOther + } + if request.N > 0 && request.N != 1 { + return ErrO1BetaLimitationsOther + } + if request.PresencePenalty > 0 { + return ErrO1BetaLimitationsOther + } + if request.FrequencyPenalty > 0 { + return ErrO1BetaLimitationsOther + } + + return nil +} + // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { Model string `json:"model"` From 9a4f3a7dbf8f29408848c94cf933d1530ae64526 Mon Sep 17 00:00:00 2001 From: Jialin Tian Date: Sat, 21 Sep 2024 04:49:28 +0800 Subject: [PATCH 051/129] feat: add ParallelToolCalls to RunRequest (#847) --- run.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/run.go b/run.go index 5598f1dfb..0cdec2bdc 100644 --- a/run.go +++ b/run.go @@ -37,6 +37,8 @@ type Run struct { MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. 
+ ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` httpHeader } From e095df5325a39ed94940dbe3882d2aa14eb64ad0 Mon Sep 17 00:00:00 2001 From: floodwm Date: Fri, 20 Sep 2024 23:54:25 +0300 Subject: [PATCH 052/129] run_id string Optional (#855) Filter messages by the run ID that generated them. Co-authored-by: wappi --- .zshrc | 0 client_test.go | 2 +- messages.go | 5 +++++ messages_test.go | 5 +++-- 4 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 .zshrc diff --git a/.zshrc b/.zshrc new file mode 100644 index 000000000..e69de29bb diff --git a/client_test.go b/client_test.go index 7119d8a7e..3f27b9dd7 100644 --- a/client_test.go +++ b/client_test.go @@ -340,7 +340,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { return client.CreateMessage(ctx, "", MessageRequest{}) }}, {"ListMessage", func() (any, error) { - return client.ListMessage(ctx, "", nil, nil, nil, nil) + return client.ListMessage(ctx, "", nil, nil, nil, nil, nil) }}, {"RetrieveMessage", func() (any, error) { return client.RetrieveMessage(ctx, "", "") diff --git a/messages.go b/messages.go index 1fddd6314..eefc29a36 100644 --- a/messages.go +++ b/messages.go @@ -100,6 +100,7 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, order *string, after *string, before *string, + runID *string, ) (messages MessagesList, err error) { urlValues := url.Values{} if limit != nil { @@ -114,6 +115,10 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, if before != nil { urlValues.Add("before", *before) } + if runID != nil { + urlValues.Add("run_id", *runID) + } + encodedValues := "" if len(urlValues) > 0 { encodedValues = "?" 
+ urlValues.Encode() diff --git a/messages_test.go b/messages_test.go index 71ceb4d3a..b25755f98 100644 --- a/messages_test.go +++ b/messages_test.go @@ -208,7 +208,7 @@ func TestMessages(t *testing.T) { } var msgs openai.MessagesList - msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil, nil) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") @@ -219,7 +219,8 @@ func TestMessages(t *testing.T) { order := "desc" after := "obj_foo" before := "obj_bar" - msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) + runID := "run_abc123" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before, &runID) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") From 38bdc812df391bcec3d7defda2a456ea00bb54e5 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 26 Sep 2024 18:25:56 +0800 Subject: [PATCH 053/129] Optimize Client Error Return (#856) * update client error return * update client_test.go * update client_test.go * update file_api_test.go * update client_test.go * update client_test.go --- client.go | 9 ++++++ client_test.go | 76 +++++++++++++++++++++++++++++++++-------------- error.go | 6 ++-- files_api_test.go | 1 + 4 files changed, 67 insertions(+), 25 deletions(-) diff --git a/client.go b/client.go index 9f547e7cb..583244fe1 100644 --- a/client.go +++ b/client.go @@ -285,10 +285,18 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB } func (c *Client) handleErrorResp(resp *http.Response) error { + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } + return fmt.Errorf("error, status code: %d, status: %s, body: %s", 
resp.StatusCode, resp.Status, body) + } var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ + HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, } @@ -298,6 +306,7 @@ func (c *Client) handleErrorResp(resp *http.Response) error { return reqErr } + errRes.Error.HTTPStatus = resp.Status errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } diff --git a/client_test.go b/client_test.go index 3f27b9dd7..18da787a0 100644 --- a/client_test.go +++ b/client_test.go @@ -134,14 +134,17 @@ func TestHandleErrorResp(t *testing.T) { client := NewClient(mockToken) testCases := []struct { - name string - httpCode int - body io.Reader - expected string + name string + httpCode int + httpStatus string + contentType string + body io.Reader + expected string }{ { - name: "401 Invalid Authentication", - httpCode: http.StatusUnauthorized, + name: "401 Invalid Authentication", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -152,11 +155,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: You didn't provide an API key. ....", + expected: "error, status code: 401, status: , message: You didn't provide an API key. 
....", }, { - name: "401 Azure Access Denied", - httpCode: http.StatusUnauthorized, + name: "401 Azure Access Denied", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -165,11 +169,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.", + expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.", }, { - name: "503 Model Overloaded", - httpCode: http.StatusServiceUnavailable, + name: "503 Model Overloaded", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{ @@ -179,22 +184,53 @@ func TestHandleErrorResp(t *testing.T) { "code":null } }`)), - expected: "error, status code: 503, message: That model...", + expected: "error, status code: 503, status: , message: That model...", }, { - name: "503 no message (Unknown response)", - httpCode: http.StatusServiceUnavailable, + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{} }`)), - expected: "error, status code: 503, message: ", + expected: "error, status code: 503, status: , message: ", + }, + { + name: "413 Request Entity Too Large", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: bytes.NewReader([]byte(` +413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ +`)), + expected: `error, status code: 413, status: , body: +413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ +`, + }, + { + name: "errorReader", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: &errorReader{err: errors.New("errorReader")}, + expected: "error, reading response body: errorReader", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - testCase := &http.Response{} + testCase := &http.Response{ + Header: map[string][]string{ + "Content-Type": {tc.contentType}, + }, + } testCase.StatusCode = tc.httpCode testCase.Body = io.NopCloser(tc.body) err := client.handleErrorResp(testCase) @@ -203,12 +239,6 @@ func TestHandleErrorResp(t *testing.T) { t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected) t.Fail() } - - e := &APIError{} - if !errors.As(err, &e) { - t.Errorf("(%s) Expected error to be of type APIError", tc.name) - t.Fail() - } }) } } diff --git a/error.go b/error.go index 37959a272..1f6a8971d 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ type APIError struct { Message string `json:"message"` Param *string `json:"param,omitempty"` Type string `json:"type"` + HTTPStatus string `json:"-"` HTTPStatusCode int `json:"-"` InnerError *InnerError `json:"innererror,omitempty"` } @@ -25,6 +26,7 @@ type InnerError struct { // RequestError provides information about generic request errors. 
type RequestError struct { + HTTPStatus string HTTPStatusCode int Err error } @@ -35,7 +37,7 @@ type ErrorResponse struct { func (e *APIError) Error() string { if e.HTTPStatusCode > 0 { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message) } return e.Message @@ -101,7 +103,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err) } func (e *RequestError) Unwrap() error { diff --git a/files_api_test.go b/files_api_test.go index c92162a84..aa4fda458 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -152,6 +152,7 @@ func TestGetFileContentReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, wantErrorResp) }) From 7f80303cc393edf2f6806ca37668346f8fa6247e Mon Sep 17 00:00:00 2001 From: Alex Philipp Date: Thu, 26 Sep 2024 05:26:22 -0500 Subject: [PATCH 054/129] Fix max_completion_tokens (#860) The json tag is incorrect, and results in an error from the API when using the o1 model. I didn't modify the struct field name to maintain compatibility if anyone else had started using it, but it wouldn't work for them either. 
--- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index d47c95e4f..dd99c530e 100644 --- a/chat.go +++ b/chat.go @@ -209,7 +209,7 @@ type ChatCompletionRequest struct { MaxTokens int `json:"max_tokens,omitempty"` // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning - MaxCompletionsTokens int `json:"max_completions_tokens,omitempty"` + MaxCompletionsTokens int `json:"max_completion_tokens,omitempty"` Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` N int `json:"n,omitempty"` From e9d8485e90092b8adcce82fdd0dcd7cf10327e8d Mon Sep 17 00:00:00 2001 From: Jialin Tian Date: Thu, 26 Sep 2024 18:26:54 +0800 Subject: [PATCH 055/129] fix: ParallelToolCalls should be added to RunRequest (#861) --- run.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run.go b/run.go index 0cdec2bdc..d3e755f05 100644 --- a/run.go +++ b/run.go @@ -37,8 +37,6 @@ type Run struct { MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` - // Disable the default behavior of parallel tool calls by setting it: false. - ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` httpHeader } @@ -112,6 +110,8 @@ type RunRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // This can be either a string or a ResponseFormat object. ResponseFormat any `json:"response_format,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread. 
From fdd59d93413154cd07b2e46a428b15eda40b26e2 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 26 Sep 2024 18:30:56 +0800 Subject: [PATCH 056/129] feat: usage struct add CompletionTokensDetails (#863) --- common.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/common.go b/common.go index cbfda4e3c..cde14154a 100644 --- a/common.go +++ b/common.go @@ -4,7 +4,13 @@ package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` +} + +// CompletionTokensDetails Breakdown of tokens used in a completion. +type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` } From bac7d5936108965a9666a65d0d4d55bd0fe78808 Mon Sep 17 00:00:00 2001 From: Winston Liu Date: Thu, 3 Oct 2024 12:17:16 -0700 Subject: [PATCH 057/129] fix MaxCompletionTokens typo (#862) * fix spelling error * fix lint * Update chat.go * Update chat.go --- chat.go | 22 +++++++++++----------- chat_test.go | 38 +++++++++++++++++++------------------- completion.go | 2 +- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/chat.go b/chat.go index dd99c530e..9adf2808d 100644 --- a/chat.go +++ b/chat.go @@ -207,18 +207,18 @@ type ChatCompletionRequest struct { // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. 
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens MaxTokens int `json:"max_tokens,omitempty"` - // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, + // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning - MaxCompletionsTokens int `json:"max_completion_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. 
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias diff --git a/chat_test.go b/chat_test.go index a54dd35e0..134026cdb 100644 --- a/chat_test.go +++ b/chat_test.go @@ -100,17 +100,17 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "log_probs_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - LogProbs: true, - Model: openai.O1Preview, + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O1Preview, }, expectedError: openai.ErrO1BetaLimitationsLogprobs, }, { name: "message_type_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleSystem, @@ -122,8 +122,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "tool_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -143,8 +143,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_temperature_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -160,8 +160,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_top_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -178,8 +178,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: 
"set_n_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -197,8 +197,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_presence_penalty_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -214,8 +214,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_frequency_penalty_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -296,8 +296,8 @@ func TestO1ModelChatCompletions(t *testing.T) { defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ - Model: openai.O1Preview, - MaxCompletionsTokens: 1000, + Model: openai.O1Preview, + MaxCompletionTokens: 1000, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, diff --git a/completion.go b/completion.go index 8e3172ace..80c4d39ae 100644 --- a/completion.go +++ b/completion.go @@ -7,7 +7,7 @@ import ( ) var ( - ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionsTokens") //nolint:lll + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll ErrCompletionStreamNotSupported = errors.New("streaming is 
not supported with this method, please use CreateCompletionStream") //nolint:lll ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll From 7c145ebb4be68610bc3bb5377b754944307d44fd Mon Sep 17 00:00:00 2001 From: Julio Martins <89476495+juliomartinsdev@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:19:48 -0300 Subject: [PATCH 058/129] add jailbreak filter result, add ContentFilterResults on output (#864) * add jailbreak filter result * add content filter results on completion output * add profanity content filter --- chat.go | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/chat.go b/chat.go index 9adf2808d..a7dee8e03 100644 --- a/chat.go +++ b/chat.go @@ -41,11 +41,23 @@ type Violence struct { Severity string `json:"severity,omitempty"` } +type JailBreak struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + +type Profanity struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + type ContentFilterResults struct { - Hate Hate `json:"hate,omitempty"` - SelfHarm SelfHarm `json:"self_harm,omitempty"` - Sexual Sexual `json:"sexual,omitempty"` - Violence Violence `json:"violence,omitempty"` + Hate Hate `json:"hate,omitempty"` + SelfHarm SelfHarm `json:"self_harm,omitempty"` + Sexual Sexual `json:"sexual,omitempty"` + Violence Violence `json:"violence,omitempty"` + JailBreak JailBreak `json:"jailbreak,omitempty"` + Profanity Profanity `json:"profanity,omitempty"` } type PromptAnnotation struct { @@ -338,19 +350,21 @@ type ChatCompletionChoice struct { // function_call: The model decided to call a function // content_filter: Omitted content due to a flag from our content filters // null: API response still in progress or incomplete - FinishReason FinishReason `json:"finish_reason"` - LogProbs *LogProbs `json:"logprobs,omitempty"` + FinishReason FinishReason 
`json:"finish_reason"` + LogProbs *LogProbs `json:"logprobs,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } // ChatCompletionResponse represents a response structure for chat completion API. type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage Usage `json:"usage"` - SystemFingerprint string `json:"system_fingerprint"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage Usage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` httpHeader } From 991326480f84981b6e89032b9f9710a3a83a6f0f Mon Sep 17 00:00:00 2001 From: Isaac Seymour Date: Wed, 9 Oct 2024 10:50:27 +0100 Subject: [PATCH 059/129] Completion API: add new params (#870) * Completion API: add 'store' param This param allows you to opt a completion request in to being stored, for use in distillations and evals. 
* Add cached and audio tokens to usage structs These have been added to the completions API recently: https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage --- common.go | 8 ++++++++ completion.go | 27 +++++++++++++++------------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/common.go b/common.go index cde14154a..8cc7289c0 100644 --- a/common.go +++ b/common.go @@ -7,10 +7,18 @@ type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` } // CompletionTokensDetails Breakdown of tokens used in a completion. type CompletionTokensDetails struct { + AudioTokens int `json:"audio_tokens"` ReasoningTokens int `json:"reasoning_tokens"` } + +// PromptTokensDetails Breakdown of tokens used in the prompt. +type PromptTokensDetails struct { + AudioTokens int `json:"audio_tokens"` + CachedTokens int `json:"cached_tokens"` +} diff --git a/completion.go b/completion.go index 80c4d39ae..afcf84671 100644 --- a/completion.go +++ b/completion.go @@ -238,18 +238,21 @@ type CompletionRequest struct { // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. 
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias - LogitBias map[string]int `json:"logit_bias,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - Seed *int `json:"seed,omitempty"` - Stop []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - Suffix string `json:"suffix,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - User string `json:"user,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. 
From cfe15ffd00bb908c32cf0d9e277786a14afdd2c7 Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Mon, 14 Oct 2024 18:50:39 +0530 Subject: [PATCH 060/129] return response body as byte slice for RequestError type (#873) --- client.go | 11 ++++++----- error.go | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 583244fe1..1e228a097 100644 --- a/client.go +++ b/client.go @@ -285,20 +285,21 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB } func (c *Client) handleErrorResp(resp *http.Response) error { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("error, reading response body: %w", err) - } return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) } var errRes ErrorResponse - err := json.NewDecoder(resp.Body).Decode(&errRes) + err = json.Unmarshal(body, &errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, + Body: body, } if errRes.Error != nil { reqErr.Err = errRes.Error diff --git a/error.go b/error.go index 1f6a8971d..fc9e7cdb9 100644 --- a/error.go +++ b/error.go @@ -29,6 +29,7 @@ type RequestError struct { HTTPStatus string HTTPStatusCode int Err error + Body []byte } type ErrorResponse struct { From 21f713457449b1ab386529b9495cbf1f27c0db5a Mon Sep 17 00:00:00 2001 From: Matt Jacobs Date: Mon, 14 Oct 2024 09:21:39 -0400 Subject: [PATCH 061/129] Adding new moderation model constants (#875) --- moderation.go | 12 ++++++++---- moderation_test.go | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/moderation.go b/moderation.go index c8652efc8..a0e09c0ee 100644 --- a/moderation.go +++ b/moderation.go @@ -14,8 +14,10 @@ import 
( // If you use text-moderation-stable, we will provide advanced notice before updating the model. // Accuracy of text-moderation-stable may be slightly lower than for text-moderation-latest. const ( - ModerationTextStable = "text-moderation-stable" - ModerationTextLatest = "text-moderation-latest" + ModerationOmniLatest = "omni-moderation-latest" + ModerationOmni20240926 = "omni-moderation-2024-09-26" + ModerationTextStable = "text-moderation-stable" + ModerationTextLatest = "text-moderation-latest" // Deprecated: use ModerationTextStable and ModerationTextLatest instead. ModerationText001 = "text-moderation-001" ) @@ -25,8 +27,10 @@ var ( ) var validModerationModel = map[string]struct{}{ - ModerationTextStable: {}, - ModerationTextLatest: {}, + ModerationOmniLatest: {}, + ModerationOmni20240926: {}, + ModerationTextStable: {}, + ModerationTextLatest: {}, } // ModerationRequest represents a request structure for moderation API. diff --git a/moderation_test.go b/moderation_test.go index 61171c384..a97f25bc6 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -37,6 +37,8 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) { getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), getModerationModelTestOption(openai.ModerationTextStable, nil), getModerationModelTestOption(openai.ModerationTextLatest, nil), + getModerationModelTestOption(openai.ModerationOmni20240926, nil), + getModerationModelTestOption(openai.ModerationOmniLatest, nil), getModerationModelTestOption("", nil), ) client, server, teardown := setupOpenAITestServer() From b162541513db0cf3d4d48da03be22b05861269cb Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 15 Oct 2024 20:09:34 +0100 Subject: [PATCH 062/129] Cleanup (#879) * remove obsolete files * update readme --- .zshrc | 0 Makefile | 35 ----------------------------------- README.md | 2 +- 3 files changed, 1 insertion(+), 36 deletions(-) delete mode 
100644 .zshrc delete mode 100644 Makefile diff --git a/.zshrc b/.zshrc deleted file mode 100644 index e69de29bb..000000000 diff --git a/Makefile b/Makefile deleted file mode 100644 index 2e608aa0c..000000000 --- a/Makefile +++ /dev/null @@ -1,35 +0,0 @@ -##@ General - -# The help target prints out all targets with their descriptions organized -# beneath their categories. The categories are represented by '##@' and the -# target descriptions by '##'. The awk commands is responsible for reading the -# entire set of makefiles included in this invocation, looking for lines of the -# file as xyz: ## something, and then pretty-format the target and help. Then, -# if there's a line with ##@ something, that gets pretty-printed as a category. -# More info on the usage of ANSI control characters for terminal formatting: -# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters -# More info on the awk command: -# http://linuxcommand.org/lc3_adv_awk.php - -.PHONY: help -help: ## Display this help. - @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) - - -##@ Development - -.PHONY: test -TEST_ARGS ?= -v -TEST_TARGETS ?= ./... -test: ## Test the Go modules within this package. - @ echo ▶️ go test $(TEST_ARGS) $(TEST_TARGETS) - go test $(TEST_ARGS) $(TEST_TARGETS) - @ echo ✅ success! - - -.PHONY: lint -LINT_TARGETS ?= ./... -lint: ## Lint Go code with the installed golangci-lint - @ echo "▶️ golangci-lint run" - golangci-lint run $(LINT_TARGETS) - @ echo "✅ golangci-lint run" diff --git a/README.md b/README.md index b3ebc1471..57d1d35bf 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). 
We support: -* ChatGPT +* ChatGPT 4o, o1 * GPT-3, GPT-4 * DALL·E 2, DALL·E 3 * Whisper From 9fe2c6ce1f5b756cd172ae9a7786beea69b2956f Mon Sep 17 00:00:00 2001 From: Sander Mack-Crane <71154168+smackcrane@users.noreply.github.com> Date: Tue, 15 Oct 2024 14:16:57 -0600 Subject: [PATCH 063/129] Completion API: add Store and Metadata parameters (#878) --- chat.go | 5 +++++ completion.go | 26 ++++++++++++++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index a7dee8e03..2b13f8dd7 100644 --- a/chat.go +++ b/chat.go @@ -255,6 +255,11 @@ type ChatCompletionRequest struct { StreamOptions *StreamOptions `json:"stream_options,omitempty"` // Disable the default behavior of parallel tool calls by setting it: false. ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` } type StreamOptions struct { diff --git a/completion.go b/completion.go index afcf84671..84ef2ad26 100644 --- a/completion.go +++ b/completion.go @@ -241,18 +241,20 @@ type CompletionRequest struct { LogitBias map[string]int `json:"logit_bias,omitempty"` // Store can be set to true to store the output of this completion request for use in distillations and evals. 
// https://platform.openai.com/docs/api-reference/chat/create#chat-create-store - Store bool `json:"store,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - Seed *int `json:"seed,omitempty"` - Stop []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - Suffix string `json:"suffix,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - User string `json:"user,omitempty"` + Store bool `json:"store,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. 
From fb15ff9dcd861e601fc2c54078aac2bbd3c06ce8 Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Tue, 22 Oct 2024 02:19:34 +0530 Subject: [PATCH 064/129] Handling for non-json response (#881) * removed handling for non-json response * added response body in RequestError.Error() and updated tests * done linting --- client.go | 3 --- client_test.go | 35 ++++++++++++++++++++--------------- error.go | 5 ++++- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 1e228a097..ed8595e0b 100644 --- a/client.go +++ b/client.go @@ -289,9 +289,6 @@ func (c *Client) handleErrorResp(resp *http.Response) error { if err != nil { return fmt.Errorf("error, reading response body: %w", err) } - if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { - return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) - } var errRes ErrorResponse err = json.Unmarshal(body, &errRes) if err != nil || errRes.Error == nil { diff --git a/client_test.go b/client_test.go index 18da787a0..354a6b3f5 100644 --- a/client_test.go +++ b/client_test.go @@ -194,26 +194,31 @@ func TestHandleErrorResp(t *testing.T) { { "error":{} }`)), - expected: "error, status code: 503, status: , message: ", + expected: `error, status code: 503, status: , message: , body: + { + "error":{} + }`, }, { name: "413 Request Entity Too Large", httpCode: http.StatusRequestEntityTooLarge, contentType: "text/html", - body: bytes.NewReader([]byte(` -413 Request Entity Too Large - -

413 Request Entity Too Large

-
nginx
- -`)), - expected: `error, status code: 413, status: , body: -413 Request Entity Too Large - -

413 Request Entity Too Large

-
nginx
- -`, + body: bytes.NewReader([]byte(` + + 413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ + `)), + expected: `error, status code: 413, status: , message: invalid character '<' looking for beginning of value, body: + + 413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ + `, }, { name: "errorReader", diff --git a/error.go b/error.go index fc9e7cdb9..8a74bd52c 100644 --- a/error.go +++ b/error.go @@ -104,7 +104,10 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err) + return fmt.Sprintf( + "error, status code: %d, status: %s, message: %s, body: %s", + e.HTTPStatusCode, e.HTTPStatus, e.Err, e.Body, + ) } func (e *RequestError) Unwrap() error { From 3672c0dec601f89037d8d54e7df653d7df1f0c83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edin=20=C4=86orali=C4=87?= <73831203+ecoralic@users.noreply.github.com> Date: Mon, 21 Oct 2024 22:57:02 +0200 Subject: [PATCH 065/129] fix: Updated Assistent struct with latest fields based on OpenAI docs (#883) --- assistant.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/assistant.go b/assistant.go index 4c89c1b2f..8aab5bcf0 100644 --- a/assistant.go +++ b/assistant.go @@ -14,17 +14,20 @@ const ( ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + FileIDs []string 
`json:"file_ids,omitempty"` // Deprecated in v2 + Metadata map[string]any `json:"metadata,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` httpHeader } From 6e087322b77693e6e9227d9950a0c8d8a10a8d1a Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Fri, 25 Oct 2024 19:11:45 +0530 Subject: [PATCH 066/129] Updated checkPromptType function to handle prompt list in completions (#885) * updated checkPromptType function to handle prompt list in completions * removed generated test file * added corresponding unit testcases * Updated to use less nesting with early returns --- completion.go | 18 ++++++++++- completion_test.go | 78 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 85 insertions(+), 11 deletions(-) diff --git a/completion.go b/completion.go index 84ef2ad26..77ea8c3ab 100644 --- a/completion.go +++ b/completion.go @@ -161,7 +161,23 @@ func checkEndpointSupportsModel(endpoint, model string) bool { func checkPromptType(prompt any) bool { _, isString := prompt.(string) _, isStringSlice := prompt.([]string) - return isString || isStringSlice + if isString || isStringSlice { + return true + } + + // check if it is prompt is []string hidden under []any + slice, isSlice := prompt.([]any) + if !isSlice { + return false + } + + for _, item := range slice { + _, itemIsString := item.(string) + if !itemIsString { + return false + } + } + return true // all items in the slice are string, so it is []string } var unsupportedToolsForO1Models = map[ToolType]struct{}{ diff --git a/completion_test.go b/completion_test.go index 89950bf94..935bbe864 100644 --- a/completion_test.go +++ b/completion_test.go @@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) { checks.NoError(t, err, "CreateCompletion error") } +// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server +// where the completions requests has a 
list of prompts with wrong type. +func TestMultiplePromptsCompletionsWrong(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", 9}, + } + _, err := client.CreateCompletion(context.Background(), req) + if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err) + } +} + +// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts. +func TestMultiplePromptsCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", "Lorem ipsum"}, + } + _, err := client.CreateCompletion(context.Background(), req) + checks.NoError(t, err, "CreateCompletion error") +} + // handleCompletionEndpoint Handles the completion endpoint by the test server. 
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error @@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if n == 0 { n = 1 } + // Handle different types of prompts: single string or list of strings + prompts := []string{} + switch v := completionReq.Prompt.(type) { + case string: + prompts = append(prompts, v) + case []interface{}: + for _, item := range v { + if str, ok := item.(string); ok { + prompts = append(prompts, str) + } + } + default: + http.Error(w, "Invalid prompt type", http.StatusBadRequest) + return + } + for i := 0; i < n; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - if completionReq.Echo { - completionStr = completionReq.Prompt.(string) + completionStr + for _, prompt := range prompts { + // Generate a random string of length completionReq.MaxTokens + completionStr := strings.Repeat("a", completionReq.MaxTokens) + if completionReq.Echo { + completionStr = prompt + completionStr + } + + res.Choices = append(res.Choices, openai.CompletionChoice{ + Text: completionStr, + Index: len(res.Choices), + }) } - res.Choices = append(res.Choices, openai.CompletionChoice{ - Text: completionStr, - Index: i, - }) } - inputTokens := numTokens(completionReq.Prompt.(string)) * n - completionTokens := completionReq.MaxTokens * n + + inputTokens := 0 + for _, prompt := range prompts { + inputTokens += numTokens(prompt) + } + inputTokens *= n + completionTokens := completionReq.MaxTokens * len(prompts) * n res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, } + + // Serialize the response and send it back resBytes, _ = json.Marshal(res) fmt.Fprintln(w, string(resBytes)) } From d10f1b81995ddce1aacacfa671d79f2784a68ef4 Mon Sep 17 00:00:00 2001 From: genglixia <62233468+Yu0u@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:22:52 
+0800 Subject: [PATCH 067/129] add chatcompletion stream delta refusal and logprobs (#882) * add chatcompletion stream refusal and logprobs * fix slice to struct * add integration test * fix lint * fix lint * fix: the object should be pointer --------- Co-authored-by: genglixia --- chat_stream.go | 28 ++++- chat_stream_test.go | 265 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 289 insertions(+), 4 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index f43d01834..58b2651c0 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -10,13 +10,33 @@ type ChatCompletionStreamChoiceDelta struct { Role string `json:"role,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Refusal string `json:"refusal,omitempty"` +} + +type ChatCompletionStreamChoiceLogprobs struct { + Content []ChatCompletionTokenLogprob `json:"content,omitempty"` + Refusal []ChatCompletionTokenLogprob `json:"refusal,omitempty"` +} + +type ChatCompletionTokenLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes,omitempty"` + Logprob float64 `json:"logprob,omitempty"` + TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs"` +} + +type ChatCompletionTokenLogprobTopLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes"` + Logprob float64 `json:"logprob"` } type ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason FinishReason `json:"finish_reason"` - ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + Logprobs *ChatCompletionStreamChoiceLogprobs `json:"logprobs,omitempty"` + FinishReason FinishReason `json:"finish_reason"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } type PromptFilterResult struct { diff --git 
a/chat_stream_test.go b/chat_stream_test.go index 2e7c99b45..14684146c 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -358,6 +358,271 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithRefusal(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":" World"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) 
+ + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT4oMini20240718, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: " World", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, 
expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamWithLogprobs(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":{"content":[],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":{"content":[{"token":"Hello","logprob":-0.000020458236,"bytes":[72,101,108,108,111],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":" World"},"logprobs":{"content":[{"token":" World","logprob":-0.00055303273,"bytes":[32,87,111,114,108,100],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) 
+ + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{}, + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: "Hello", + Logprob: -0.000020458236, + Bytes: []int64{72, 101, 108, 108, 111}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + 
SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " World", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: " World", + Logprob: -0.00055303273, + Bytes: []int64{32, 87, 111, 114, 108, 100}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { wantCode := "429" wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " + From f5e6e0e4fed1284bafa4805f6487e5b5f8a4ccd1 Mon Sep 17 00:00:00 2001 From: Matt Davis Date: Fri, 8 Nov 2024 08:53:02 -0500 Subject: [PATCH 068/129] Added Vector Store File List properties that allow for pagination (#891) --- vector_store.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vector_store.go b/vector_store.go index 5c364362a..682bb1cf9 100644 --- a/vector_store.go +++ b/vector_store.go @@ -83,6 +83,9 @@ type VectorStoreFileRequest struct { type 
VectorStoreFilesList struct { VectorStoreFiles []VectorStoreFile `json:"data"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` httpHeader } From 6d066bb12dfbaa3cefa83f204c431fb0d0ef02fa Mon Sep 17 00:00:00 2001 From: Denny Depok <61371551+kodernubie@users.noreply.github.com> Date: Fri, 8 Nov 2024 20:54:27 +0700 Subject: [PATCH 069/129] Support Attachments in MessageRequest (#890) * add attachments in MessageRequest * Move tools const to message * remove const, just use assistanttool const --- messages.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/messages.go b/messages.go index eefc29a36..902363938 100644 --- a/messages.go +++ b/messages.go @@ -52,10 +52,11 @@ type ImageFile struct { } type MessageRequest struct { - Role string `json:"role"` - Content string `json:"content"` - FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility - Metadata map[string]any `json:"metadata,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` } type MessageFile struct { From b3ece4d32e9416105bc2427b735448e82abd448b Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Wed, 20 Nov 2024 02:07:10 +0530 Subject: [PATCH 070/129] Updated client_test to solve lint error (#900) * updated client_test to solve lint error * modified golangci yml to solve linter issues * minor change --- .golangci.yml | 6 +++--- client_test.go | 10 ++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 58fab4a20..724cb7375 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -57,7 +57,7 @@ linters-settings: # Default: true skipRecvDeref: false - gomnd: + mnd: # List of function patterns to exclude from analysis. 
# Values always ignored: `time.Date` # Default: [] @@ -167,7 +167,7 @@ linters: - durationcheck # check for two durations multiplied together - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds + # Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - exhaustive # check exhaustiveness of enum switch statements - exportloopref # checks for pointers to enclosing loop variables - forbidigo # Forbids identifiers @@ -180,7 +180,6 @@ linters: - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomnd # An analyzer to detect magic numbers. - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with f at the end @@ -188,6 +187,7 @@ linters: - lll # Reports long lines - makezero # Finds slice declarations with non-zero initial length # - nakedret # Finds naked returns in functions greater than a specified function length + - mnd # An analyzer to detect magic numbers. - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. 
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value. diff --git a/client_test.go b/client_test.go index 354a6b3f5..2ed82f13c 100644 --- a/client_test.go +++ b/client_test.go @@ -513,8 +513,14 @@ func TestClient_suffixWithAPIVersion(t *testing.T) { } defer func() { if r := recover(); r != nil { - if r.(string) != tt.wantPanic { - t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic) + // Check if the panic message matches the expected panic message + if rStr, ok := r.(string); ok { + if rStr != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", rStr, tt.wantPanic) + } + } else { + // If the panic is not a string, log it + t.Errorf("suffixWithAPIVersion() panicked with non-string value: %v", r) } } }() From 168761616567a1cf2645c98f6f19329877f0beaa Mon Sep 17 00:00:00 2001 From: LinYushen Date: Thu, 21 Nov 2024 04:26:10 +0800 Subject: [PATCH 071/129] o1 model support stream (#904) --- chat_stream_test.go | 21 --------------------- completion.go | 7 ------- 2 files changed, 28 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 14684146c..28a9acf67 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -36,27 +36,6 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { } } -func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) { - config := openai.DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1/chat/completions" - client := openai.NewClientWithConfig(config) - ctx := context.Background() - - req := openai.ChatCompletionRequest{ - Model: openai.O1Preview, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - } - _, err := client.CreateChatCompletionStream(ctx, req) - if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) { - t.Fatalf("CreateChatCompletion should return ErrO1BetaLimitationsStreaming, but returned: %v", err) - } -} - func TestCreateChatCompletionStream(t 
*testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/completion.go b/completion.go index 77ea8c3ab..9e3073694 100644 --- a/completion.go +++ b/completion.go @@ -15,7 +15,6 @@ var ( var ( ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll - ErrO1BetaLimitationsStreaming = errors.New("this model has beta-limitations, streaming not supported") //nolint:lll ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll @@ -199,12 +198,6 @@ func validateRequestForO1Models(request ChatCompletionRequest) error { return ErrO1MaxTokensDeprecated } - // Beta Limitations - // refs:https://platform.openai.com/docs/guides/reasoning/beta-limitations - // Streaming: not supported - if request.Stream { - return ErrO1BetaLimitationsStreaming - } // Logprobs: not supported. 
if request.LogProbs { return ErrO1BetaLimitationsLogprobs From 74ed75f291f8f55d1104a541090d46c021169115 Mon Sep 17 00:00:00 2001 From: nagar-ajay Date: Thu, 21 Nov 2024 02:09:44 +0530 Subject: [PATCH 072/129] Make user field optional in embedding request (#899) * make user optional in embedding request * fix unit test --- batch_test.go | 2 +- embeddings.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/batch_test.go b/batch_test.go index 4b2261e0e..f4714f4eb 100644 --- a/batch_test.go +++ b/batch_test.go @@ -211,7 +211,7 @@ func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { Input: []string{"Hello", "World"}, }, }, - }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/embeddings.go b/embeddings.go index 74eb8aa57..4a0e682da 100644 --- a/embeddings.go +++ b/embeddings.go @@ -155,7 +155,7 @@ const ( type EmbeddingRequest struct { Input any `json:"input"` Model EmbeddingModel `json:"model"` - User string `json:"user"` + User string `json:"user,omitempty"` EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. 
From 21fa42c18dbafef43977ab73c403eef6d694b14a Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Sat, 30 Nov 2024 17:39:47 +0800 Subject: [PATCH 073/129] feat: add gpt-4o-2024-11-20 model (#905) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 9e3073694..f11566081 100644 --- a/completion.go +++ b/completion.go @@ -37,6 +37,7 @@ const ( GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4o20241120 = "gpt-4o-2024-11-20" GPT4oLatest = "chatgpt-4o-latest" GPT4oMini = "gpt-4o-mini" GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" @@ -119,6 +120,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4o: true, GPT4o20240513: true, GPT4o20240806: true, + GPT4o20241120: true, GPT4oLatest: true, GPT4oMini: true, GPT4oMini20240718: true, From c203ca001fecd40210cfcf9923ab69235c92e321 Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Sat, 30 Nov 2024 18:29:05 +0800 Subject: [PATCH 074/129] feat: add RecvRaw (#896) --- stream_reader.go | 39 ++++++++++++++++++++++----------------- stream_reader_test.go | 13 +++++++++++++ 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index 4210a1948..ecfa26807 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -32,17 +32,28 @@ type streamReader[T streamable] struct { } func (stream *streamReader[T]) Recv() (response T, err error) { - if stream.isFinished { - err = io.EOF + rawLine, err := stream.RecvRaw() + if err != nil { return } - response, err = stream.processLines() - return + err = stream.unmarshaler.Unmarshal(rawLine, &response) + if err != nil { + return + } + return response, nil +} + +func (stream *streamReader[T]) RecvRaw() ([]byte, error) { + if stream.isFinished { + return nil, io.EOF + } + + return stream.processLines() } //nolint:gocognit -func (stream *streamReader[T]) processLines() (T, error) { +func (stream *streamReader[T]) processLines() ([]byte, 
error) { var ( emptyMessagesCount uint hasErrorPrefix bool @@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) { if readErr != nil || hasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { - return *new(T), fmt.Errorf("error, %w", respErr.Error) + return nil, fmt.Errorf("error, %w", respErr.Error) } - return *new(T), readErr + return nil, readErr } noSpaceLine := bytes.TrimSpace(rawLine) @@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) { } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { - return *new(T), writeErr + return nil, writeErr } emptyMessagesCount++ if emptyMessagesCount > stream.emptyMessagesLimit { - return *new(T), ErrTooManyEmptyStreamMessages + return nil, ErrTooManyEmptyStreamMessages } continue @@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) { noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) if string(noPrefixLine) == "[DONE]" { stream.isFinished = true - return *new(T), io.EOF - } - - var response T - unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response) - if unmarshalErr != nil { - return *new(T), unmarshalErr + return nil, io.EOF } - return response, nil + return noPrefixLine, nil } } diff --git a/stream_reader_test.go b/stream_reader_test.go index cd6e46eff..449a14b43 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { _, err := stream.Recv() checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) } + +func TestStreamReaderRecvRaw(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))), + } + rawLine, err := stream.RecvRaw() + if err != nil { + t.Fatalf("Did not return raw line: %v", err) + } + if !bytes.Equal(rawLine, 
[]byte("{\"key\": \"value\"}")) { + t.Fatalf("Did not return raw line: %v", string(rawLine)) + } +} From af5355f5b1a7701f891109e8a17b7b245ac5363b Mon Sep 17 00:00:00 2001 From: Tim Misiak Date: Sun, 8 Dec 2024 05:12:05 -0800 Subject: [PATCH 075/129] Fix ID field to be optional (#911) The ID field is not always present for streaming responses. Without omitempty, the entire ToolCall struct will be missing. --- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index 2b13f8dd7..fcaf79cf7 100644 --- a/chat.go +++ b/chat.go @@ -179,7 +179,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` - ID string `json:"id"` + ID string `json:"id,omitempty"` Type ToolType `json:"type"` Function FunctionCall `json:"function"` } From 56a9acf86fc3ce0e9030feafa346d64bade94027 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sun, 8 Dec 2024 13:16:48 +0000 Subject: [PATCH 076/129] Ignore test.mp3 (#913) --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 99b40bf17..b0ac1605c 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,7 @@ # Auth token for tests .openai-token -.idea \ No newline at end of file +.idea + +# Generated by tests +test.mp3 \ No newline at end of file From 2a0ff5ac63e460cbe44cccd0d4199d51bf8682a4 Mon Sep 17 00:00:00 2001 From: Sabuhi Gurbani <51547928+sabuhigr@users.noreply.github.com> Date: Fri, 27 Dec 2024 14:01:16 +0400 Subject: [PATCH 077/129] Added additional_messages (#914) --- run.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/run.go b/run.go index d3e755f05..9c51aaf8d 100644 --- a/run.go +++ b/run.go @@ -83,12 +83,13 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model string `json:"model,omitempty"` - 
Instructions string `json:"instructions,omitempty"` - AdditionalInstructions string `json:"additional_instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. From 7a2915a37dae714f40a4b5575fbf98430fe1d6aa Mon Sep 17 00:00:00 2001 From: Oleksandr Redko Date: Fri, 31 Jan 2025 20:55:41 +0200 Subject: [PATCH 078/129] Simplify tests with T.TempDir (#929) --- .golangci.yml | 1 + audio_api_test.go | 10 ++------- audio_test.go | 8 ++----- image_api_test.go | 42 +++++++++++------------------------ internal/form_builder_test.go | 17 ++++---------- internal/test/helpers.go | 10 --------- openai_test.go | 2 +- speech_test.go | 4 +--- 8 files changed, 24 insertions(+), 70 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 724cb7375..9d22d9bd3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -206,6 +206,7 @@ linters: - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters + - usetesting # Reports uses of functions with replacement inside the testing package - wastedassign # wastedassign finds wasted assignment statements. 
- whitespace # Tool for detection of leading and trailing whitespace ## you may want to enable diff --git a/audio_api_test.go b/audio_api_test.go index c24598443..6c6a35643 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -40,12 +40,9 @@ func TestAudio(t *testing.T) { ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := openai.AudioRequest{ @@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) { ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := openai.AudioRequest{ diff --git a/audio_test.go b/audio_test.go index 235931f36..9f32d5468 100644 --- a/audio_test.go +++ b/audio_test.go @@ -13,9 +13,7 @@ import ( ) func TestAudioWithFailingFormBuilder(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := AudioRequest{ @@ -63,9 +61,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { func TestCreateFileField(t *testing.T) { t.Run("createFileField failing file", func(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := AudioRequest{ diff --git a/image_api_test.go b/image_api_test.go index 48416b1e2..f6057b77d 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "path/filepath" "testing" "time" @@ -86,24 +87,17 @@ func TestImageEdit(t 
*testing.T) { defer teardown() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } + defer origin.Close() - mask, err := os.Create("mask.png") + mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png")) if err != nil { - t.Error("open mask file error") - return + t.Fatalf("open mask file error: %v", err) } - - defer func() { - mask.Close() - origin.Close() - os.Remove("mask.png") - os.Remove("image.png") - }() + defer mask.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, @@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } - - defer func() { - origin.Close() - os.Remove("image.png") - }() + defer origin.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, @@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } - - defer func() { - origin.Close() - os.Remove("image.png") - }() + defer origin.Close() _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ Image: origin, diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index d3faf9982..8df989e3b 100644 --- 
a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,7 +1,6 @@ package openai //nolint:testpackage // testing private field import ( - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -20,15 +19,11 @@ func (*failingWriter) Write([]byte) (int, error) { } func TestFormBuilderWithFailingWriter(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") + file, err := os.CreateTemp(t.TempDir(), "") if err != nil { - t.Errorf("Error creating tmp file: %v", err) + t.Fatalf("Error creating tmp file: %v", err) } defer file.Close() - defer os.Remove(file.Name()) builder := NewFormBuilder(&failingWriter{}) err = builder.CreateFormFile("file", file) @@ -36,15 +31,11 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { } func TestFormBuilderWithClosedFile(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") + file, err := os.CreateTemp(t.TempDir(), "") if err != nil { - t.Errorf("Error creating tmp file: %v", err) + t.Fatalf("Error creating tmp file: %v", err) } file.Close() - defer os.Remove(file.Name()) body := &bytes.Buffer{} builder := NewFormBuilder(body) diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 0e63ae82f..dc5fa6646 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) { file.Close() } -// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called. 
-func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { - t.Helper() - - path, err := os.MkdirTemp(os.TempDir(), "") - checks.NoError(t, err) - - return path, func() { os.RemoveAll(path) } -} - // TokenRoundTripper is a struct that implements the RoundTripper // interface, specifically to handle the authentication token by adding a token // to the request header. We need this because the API requires that each diff --git a/openai_test.go b/openai_test.go index 729d8880c..48a00b9fc 100644 --- a/openai_test.go +++ b/openai_test.go @@ -31,7 +31,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // This function approximates based on the rule of thumb stated by OpenAI: // https://beta.openai.com/tokenizer // -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) +// TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int { return int(float32(len(s)) / 4) } diff --git a/speech_test.go b/speech_test.go index f1e405c39..67a3feabc 100644 --- a/speech_test.go +++ b/speech_test.go @@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) { defer teardown() server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { - dir, cleanup := test.CreateTestDirectory(t) - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) - defer cleanup() // audio endpoints only accept POST requests if r.Method != "POST" { From 9823a8bbbdc00871c1d569ed2b90111af94a4fb2 Mon Sep 17 00:00:00 2001 From: Trevor Creech Date: Fri, 31 Jan 2025 10:57:57 -0800 Subject: [PATCH 079/129] Chat Completion API: add ReasoningEffort and new o1 models (#928) * add reasoning_effort param * add o1 model * fix lint --- chat.go | 2 ++ completion.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/chat.go b/chat.go index fcaf79cf7..7a44fd831 100644 --- a/chat.go +++ b/chat.go @@ -258,6 +258,8 @@ type 
ChatCompletionRequest struct { // Store can be set to true to store the output of this completion request for use in distillations and evals. // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store Store bool `json:"store,omitempty"` + // Controls effort on reasoning for reasoning models. It can be set to "low", "medium", or "high". + ReasoningEffort string `json:"reasoning_effort,omitempty"` // Metadata to store with the completion. Metadata map[string]string `json:"metadata,omitempty"` } diff --git a/completion.go b/completion.go index f11566081..62724688a 100644 --- a/completion.go +++ b/completion.go @@ -29,6 +29,8 @@ const ( O1Mini20240912 = "o1-mini-2024-09-12" O1Preview = "o1-preview" O1Preview20240912 = "o1-preview-2024-09-12" + O1 = "o1" + O120241217 = "o1-2024-12-17" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" From 45aa99607be0b4c225af57c36fb5cff7328957de Mon Sep 17 00:00:00 2001 From: saileshd1402 Date: Sat, 1 Feb 2025 00:35:29 +0530 Subject: [PATCH 080/129] Make "Content" field in "ChatCompletionMessage" omitempty (#926) --- chat.go | 6 +++--- chat_test.go | 2 +- openai_test.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/chat.go b/chat.go index 7a44fd831..8ea7238fe 100644 --- a/chat.go +++ b/chat.go @@ -93,7 +93,7 @@ type ChatMessagePart struct { type ChatCompletionMessage struct { Role string `json:"role"` - Content string `json:"content"` + Content string `json:"content,omitempty"` Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart @@ -132,7 +132,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { msg := struct { Role string `json:"role"` - Content string `json:"content"` + Content string `json:"content,omitempty"` Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"-"` Name string `json:"name,omitempty"` @@ -146,7 +146,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { func (m 
*ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { Role string `json:"role"` - Content string `json:"content"` + Content string `json:"content,omitempty"` Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart Name string `json:"name,omitempty"` diff --git a/chat_test.go b/chat_test.go index 134026cdb..cea549cbd 100644 --- a/chat_test.go +++ b/chat_test.go @@ -631,7 +631,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) { t.Fatalf("Unexpected error") } res = strings.ReplaceAll(string(s), " ", "") - if res != `{"role":"user","content":""}` { + if res != `{"role":"user"}` { t.Fatalf("invalid message: %s", string(s)) } } diff --git a/openai_test.go b/openai_test.go index 48a00b9fc..6c26eebd1 100644 --- a/openai_test.go +++ b/openai_test.go @@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer +// https://beta.openai.com/tokenizer/ // // TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int { From 2054db016c335136eba471aebf49cc78981dd502 Mon Sep 17 00:00:00 2001 From: rory malcolm Date: Thu, 6 Feb 2025 14:53:19 +0000 Subject: [PATCH 081/129] Add support for O3-mini (#930) * Add support for O3-mini - Add support for the o3 mini set of models, including tests that match the constraints in OpenAI's API docs (https://platform.openai.com/docs/models#o3-mini). 
* Deprecate and refactor - Deprecate `ErrO1BetaLimitationsLogprobs` and `ErrO1BetaLimitationsOther` - Implement `validationRequestForReasoningModels`, which works on both o1 & o3, and has per-model-type restrictions on functionality (eg, o3 class are allowed function calls and system messages, o1 isn't) * Move reasoning validation to `reasoning_validator.go` - Add a `NewReasoningValidator` which exposes a `Validate()` method for a given request - Also adds a test for chat streams * Final nits --- chat.go | 3 +- chat_stream.go | 3 +- chat_stream_test.go | 167 +++++++++++++++++++++++++++++++++++++++++ chat_test.go | 153 +++++++++++++++++++++++++++++++++++-- completion.go | 86 +-------------------- reasoning_validator.go | 111 +++++++++++++++++++++++++++ 6 files changed, 431 insertions(+), 92 deletions(-) create mode 100644 reasoning_validator.go diff --git a/chat.go b/chat.go index 8ea7238fe..ce24fa34a 100644 --- a/chat.go +++ b/chat.go @@ -392,7 +392,8 @@ func (c *Client) CreateChatCompletion( return } - if err = validateRequestForO1Models(request); err != nil { + reasoningValidator := NewReasoningValidator() + if err = reasoningValidator.Validate(request); err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 58b2651c0..525b4457a 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -80,7 +80,8 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - if err = validateRequestForO1Models(request); err != nil { + reasoningValidator := NewReasoningValidator() + if err = reasoningValidator.Validate(request); err != nil { return } diff --git a/chat_stream_test.go b/chat_stream_test.go index 28a9acf67..4d992e4d1 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -792,6 +792,173 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { return true } +func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + 
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) 
+ + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxCompletionTokens: 2000, + Model: openai.O3Mini20250131, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Role: "assistant", + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " from", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " O3Mini", + }, + }, + }, + }, + { + ID: "5", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + for 
ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err) + } +} + func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false diff --git a/chat_test.go b/chat_test.go index cea549cbd..fc6c4a936 100644 --- a/chat_test.go +++ b/chat_test.go @@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { MaxTokens: 5, Model: openai.O1Preview, }, - expectedError: openai.ErrO1MaxTokensDeprecated, + expectedError: openai.ErrReasoningModelMaxTokensDeprecated, }, { name: "o1-mini_MaxTokens_deprecated", @@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { MaxTokens: 5, Model: openai.O1Mini, }, - expectedError: openai.ErrO1MaxTokensDeprecated, + expectedError: openai.ErrReasoningModelMaxTokensDeprecated, }, 
} @@ -104,7 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { LogProbs: true, Model: openai.O1Preview, }, - expectedError: openai.ErrO1BetaLimitationsLogprobs, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, }, { name: "message_type_unsupported", @@ -155,7 +155,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { }, Temperature: float32(2), }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, }, { name: "set_top_unsupported", @@ -173,7 +173,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { Temperature: float32(1), TopP: float32(0.1), }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, }, { name: "set_n_unsupported", @@ -192,7 +192,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { TopP: float32(1), N: 2, }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, }, { name: "set_presence_penalty_unsupported", @@ -209,7 +209,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { }, PresencePenalty: float32(1), }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, }, { name: "set_frequency_penalty_unsupported", @@ -226,7 +226,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { }, FrequencyPenalty: float32(0.1), }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model 
error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O3Mini, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { 
+ name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, }, } @@ -308,6 +428,23 @@ func TestO1ModelChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +func TestO3ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O3Mini, + MaxCompletionTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. 
func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() diff --git a/completion.go b/completion.go index 62724688a..1985293f8 100644 --- a/completion.go +++ b/completion.go @@ -2,24 +2,9 @@ package openai import ( "context" - "errors" "net/http" ) -var ( - ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll - ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll - ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll - ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll -) - -var ( - ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll - ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll - ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll - ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll -) - // GPT3 Defines the models provided by OpenAI to use when generating // completions from OpenAI. // GPT3 Models are designed for text-based tasks. 
For code-specific @@ -31,6 +16,8 @@ const ( O1Preview20240912 = "o1-preview-2024-09-12" O1 = "o1" O120241217 = "o1-2024-12-17" + O3Mini = "o3-mini" + O3Mini20250131 = "o3-mini-2025-01-31" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -96,21 +83,14 @@ const ( CodexCodeDavinci001 = "code-davinci-001" ) -// O1SeriesModels List of new Series of OpenAI models. -// Some old api attributes not supported. -var O1SeriesModels = map[string]struct{}{ - O1Mini: {}, - O1Mini20240912: {}, - O1Preview: {}, - O1Preview20240912: {}, -} - var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { O1Mini: true, O1Mini20240912: true, O1Preview: true, O1Preview20240912: true, + O3Mini: true, + O3Mini20250131: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, @@ -183,64 +163,6 @@ func checkPromptType(prompt any) bool { return true // all items in the slice are string, so it is []string } -var unsupportedToolsForO1Models = map[ToolType]struct{}{ - ToolTypeFunction: {}, -} - -var availableMessageRoleForO1Models = map[string]struct{}{ - ChatMessageRoleUser: {}, - ChatMessageRoleAssistant: {}, -} - -// validateRequestForO1Models checks for deprecated fields of OpenAI models. -func validateRequestForO1Models(request ChatCompletionRequest) error { - if _, found := O1SeriesModels[request.Model]; !found { - return nil - } - - if request.MaxTokens > 0 { - return ErrO1MaxTokensDeprecated - } - - // Logprobs: not supported. - if request.LogProbs { - return ErrO1BetaLimitationsLogprobs - } - - // Message types: user and assistant messages only, system messages are not supported. 
- for _, m := range request.Messages { - if _, found := availableMessageRoleForO1Models[m.Role]; !found { - return ErrO1BetaLimitationsMessageTypes - } - } - - // Tools: tools, function calling, and response format parameters are not supported - for _, t := range request.Tools { - if _, found := unsupportedToolsForO1Models[t.Type]; found { - return ErrO1BetaLimitationsTools - } - } - - // Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0. - if request.Temperature > 0 && request.Temperature != 1 { - return ErrO1BetaLimitationsOther - } - if request.TopP > 0 && request.TopP != 1 { - return ErrO1BetaLimitationsOther - } - if request.N > 0 && request.N != 1 { - return ErrO1BetaLimitationsOther - } - if request.PresencePenalty > 0 { - return ErrO1BetaLimitationsOther - } - if request.FrequencyPenalty > 0 { - return ErrO1BetaLimitationsOther - } - - return nil -} - // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { Model string `json:"model"` diff --git a/reasoning_validator.go b/reasoning_validator.go new file mode 100644 index 000000000..42a9fbd2e --- /dev/null +++ b/reasoning_validator.go @@ -0,0 +1,111 @@ +package openai + +import ( + "errors" + "strings" +) + +var ( + // Deprecated: use ErrReasoningModelMaxTokensDeprecated instead. 
+ ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll + ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll + ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll + ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll +) + +var ( + ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll + ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll + // Deprecated: use ErrReasoningModelLimitations* instead. + ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + +var ( + //nolint:lll + ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") + ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + +var unsupportedToolsForO1Models = map[ToolType]struct{}{ + ToolTypeFunction: {}, +} + +var availableMessageRoleForO1Models = map[string]struct{}{ + ChatMessageRoleUser: {}, + ChatMessageRoleAssistant: {}, +} + +// ReasoningValidator 
handles validation for o-series model requests. +type ReasoningValidator struct{} + +// NewReasoningValidator creates a new validator for o-series models. +func NewReasoningValidator() *ReasoningValidator { + return &ReasoningValidator{} +} + +// Validate performs all validation checks for o-series models. +func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { + o1Series := strings.HasPrefix(request.Model, "o1") + o3Series := strings.HasPrefix(request.Model, "o3") + + if !o1Series && !o3Series { + return nil + } + + if err := v.validateReasoningModelParams(request); err != nil { + return err + } + + if o1Series { + if err := v.validateO1Specific(request); err != nil { + return err + } + } + + return nil +} + +// validateReasoningModelParams checks reasoning model parameters. +func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error { + if request.MaxTokens > 0 { + return ErrReasoningModelMaxTokensDeprecated + } + if request.LogProbs { + return ErrReasoningModelLimitationsLogprobs + } + if request.Temperature > 0 && request.Temperature != 1 { + return ErrReasoningModelLimitationsOther + } + if request.TopP > 0 && request.TopP != 1 { + return ErrReasoningModelLimitationsOther + } + if request.N > 0 && request.N != 1 { + return ErrReasoningModelLimitationsOther + } + if request.PresencePenalty > 0 { + return ErrReasoningModelLimitationsOther + } + if request.FrequencyPenalty > 0 { + return ErrReasoningModelLimitationsOther + } + + return nil +} + +// validateO1Specific checks O1-specific limitations. 
+func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error { + for _, m := range request.Messages { + if _, found := availableMessageRoleForO1Models[m.Role]; !found { + return ErrO1BetaLimitationsMessageTypes + } + } + + for _, t := range request.Tools { + if _, found := unsupportedToolsForO1Models[t.Type]; found { + return ErrO1BetaLimitationsTools + } + } + return nil +} From a62919e8c66e35db125c129e8a9d2566a73e1e1f Mon Sep 17 00:00:00 2001 From: Mazyar Yousefiniyae shad Date: Sun, 9 Feb 2025 22:06:44 +0330 Subject: [PATCH 082/129] ref: add image url support to messages (#933) * ref: add image url support to messages * fix linter error * fix linter error --- messages.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/messages.go b/messages.go index 902363938..3852d2e37 100644 --- a/messages.go +++ b/messages.go @@ -41,6 +41,7 @@ type MessageContent struct { Type string `json:"type"` Text *MessageText `json:"text,omitempty"` ImageFile *ImageFile `json:"image_file,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` } type MessageText struct { Value string `json:"value"` @@ -51,6 +52,11 @@ type ImageFile struct { FileID string `json:"file_id"` } +type ImageURL struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + type MessageRequest struct { Role string `json:"role"` Content string `json:"content"` From c0a9a75fe01dbefb16f87d69bab042516009184f Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Wed, 12 Feb 2025 23:05:44 +0800 Subject: [PATCH 083/129] feat: add developer role (#936) --- chat.go | 1 + reasoning_validator.go | 1 + 2 files changed, 2 insertions(+) diff --git a/chat.go b/chat.go index ce24fa34a..995860c40 100644 --- a/chat.go +++ b/chat.go @@ -14,6 +14,7 @@ const ( ChatMessageRoleAssistant = "assistant" ChatMessageRoleFunction = "function" ChatMessageRoleTool = "tool" + ChatMessageRoleDeveloper = "developer" ) const chatCompletionsSuffix = "/chat/completions" diff --git a/reasoning_validator.go 
b/reasoning_validator.go index 42a9fbd2e..4d4671b17 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -35,6 +35,7 @@ var unsupportedToolsForO1Models = map[ToolType]struct{}{ var availableMessageRoleForO1Models = map[string]struct{}{ ChatMessageRoleUser: {}, ChatMessageRoleAssistant: {}, + ChatMessageRoleDeveloper: {}, } // ReasoningValidator handles validation for o-series model requests. From 85f578b865a6ea12ab24307f3bc68c97f85b6580 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Mon, 17 Feb 2025 19:29:18 +0800 Subject: [PATCH 084/129] fix: remove validateO1Specific (#939) * fix: remove validateO1Specific * update golangci-lint-action version * fix actions * fix actions * fix actions * fix actions * remove some o1 test --- .github/workflows/pr.yml | 4 ++-- chat_test.go | 34 ---------------------------------- reasoning_validator.go | 32 -------------------------------- 3 files changed, 2 insertions(+), 68 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index a41fff92f..ea0c327f1 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -18,9 +18,9 @@ jobs: run: | go vet . - name: Run golangci-lint - uses: golangci/golangci-lint-action@v4 + uses: golangci/golangci-lint-action@v6 with: - version: latest + version: v1.63.4 - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v . 
- name: Upload coverage reports to Codecov diff --git a/chat_test.go b/chat_test.go index fc6c4a936..e90142da6 100644 --- a/chat_test.go +++ b/chat_test.go @@ -106,40 +106,6 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { }, expectedError: openai.ErrReasoningModelLimitationsLogprobs, }, - { - name: "message_type_unsupported", - in: openai.ChatCompletionRequest{ - MaxCompletionTokens: 1000, - Model: openai.O1Mini, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - }, - }, - }, - expectedError: openai.ErrO1BetaLimitationsMessageTypes, - }, - { - name: "tool_unsupported", - in: openai.ChatCompletionRequest{ - MaxCompletionTokens: 1000, - Model: openai.O1Mini, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - }, - { - Role: openai.ChatMessageRoleAssistant, - }, - }, - Tools: []openai.Tool{ - { - Type: openai.ToolTypeFunction, - }, - }, - }, - expectedError: openai.ErrO1BetaLimitationsTools, - }, { name: "set_temperature_unsupported", in: openai.ChatCompletionRequest{ diff --git a/reasoning_validator.go b/reasoning_validator.go index 4d4671b17..040d6b495 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -28,16 +28,6 @@ var ( ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll ) -var unsupportedToolsForO1Models = map[ToolType]struct{}{ - ToolTypeFunction: {}, -} - -var availableMessageRoleForO1Models = map[string]struct{}{ - ChatMessageRoleUser: {}, - ChatMessageRoleAssistant: {}, - ChatMessageRoleDeveloper: {}, -} - // ReasoningValidator handles validation for o-series model requests. 
type ReasoningValidator struct{} @@ -59,12 +49,6 @@ func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { return err } - if o1Series { - if err := v.validateO1Specific(request); err != nil { - return err - } - } - return nil } @@ -94,19 +78,3 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion return nil } - -// validateO1Specific checks O1-specific limitations. -func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error { - for _, m := range request.Messages { - if _, found := availableMessageRoleForO1Models[m.Role]; !found { - return ErrO1BetaLimitationsMessageTypes - } - } - - for _, t := range request.Tools { - if _, found := unsupportedToolsForO1Models[t.Type]; found { - return ErrO1BetaLimitationsTools - } - } - return nil -} From be2e2387d4dcb15593ae5d0094e6f7b023ab3f53 Mon Sep 17 00:00:00 2001 From: Dan Ackerson Date: Tue, 25 Feb 2025 12:03:38 +0100 Subject: [PATCH 085/129] feat: add Anthropic API support with custom version header (#934) * feat: add Anthropic API support with custom version header * refactor: use switch statement for API type header handling * refactor: add OpenAI & AzureAD types to be exhaustive * Update client.go need explicit fallthrough in empty case statements * constant for APIVersion; addtl tests --- client.go | 18 +++++++++++++----- client_test.go | 15 +++++++++++++++ config.go | 22 +++++++++++++++++++++- config_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index ed8595e0b..cef375348 100644 --- a/client.go +++ b/client.go @@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication - // Azure API Key authentication - if c.config.APIType == APITypeAzure || c.config.APIType 
== APITypeCloudflareAzure { + switch c.config.APIType { + case APITypeAzure, APITypeCloudflareAzure: + // Azure API Key authentication req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else if c.config.authToken != "" { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + case APITypeAnthropic: + // https://docs.anthropic.com/en/api/versioning + req.Header.Set("anthropic-version", c.config.APIVersion) + case APITypeOpenAI, APITypeAzureAD: + fallthrough + default: + if c.config.authToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } } + if c.config.OrgID != "" { req.Header.Set("OpenAI-Organization", c.config.OrgID) } diff --git a/client_test.go b/client_test.go index 2ed82f13c..321971445 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,21 @@ func TestClient(t *testing.T) { } } +func TestSetCommonHeadersAnthropic(t *testing.T) { + config := DefaultAnthropicConfig("mock-token", "") + client := NewClientWithConfig(config) + req, err := http.NewRequest("GET", "http://example.com/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client.setCommonHeaders(req) + + if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion { + t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got) + } +} + func TestDecodeResponse(t *testing.T) { stringInput := "" diff --git a/config.go b/config.go index 8a9183558..4788ba62a 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,8 @@ const ( azureAPIPrefix = "openai" azureDeploymentsPrefix = "deployments" + + AnthropicAPIVersion = "2023-06-01" ) type APIType string @@ -20,6 +22,7 @@ const ( APITypeAzure APIType = "AZURE" APITypeAzureAD APIType = "AZURE_AD" APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" + APITypeAnthropic APIType = "ANTHROPIC" ) const AzureAPIKeyHeader = "api-key" @@ -37,7 +40,7 @@ type ClientConfig 
struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient HTTPDoer @@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { } } +func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig { + if baseURL == "" { + baseURL = "https://api.anthropic.com/v1" + } + return ClientConfig{ + authToken: apiKey, + BaseURL: baseURL, + OrgID: "", + APIType: APITypeAnthropic, + APIVersion: AnthropicAPIVersion, + + HTTPClient: &http.Client{}, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} + func (ClientConfig) String() string { return "" } diff --git a/config_test.go b/config_test.go index 3e528c3e9..145c26066 100644 --- a/config_test.go +++ b/config_test.go @@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) { }) } } + +func TestDefaultAnthropicConfig(t *testing.T) { + apiKey := "test-key" + baseURL := "https://api.anthropic.com/v1" + + config := openai.DefaultAnthropicConfig(apiKey, baseURL) + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion) + } + + if config.BaseURL != baseURL { + t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL) + } + + if config.EmptyMessagesLimit != 300 { + t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit) + } +} + +func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) { + config := openai.DefaultAnthropicConfig("", "") + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to 
be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion) + } + + expectedBaseURL := "https://api.anthropic.com/v1" + if config.BaseURL != expectedBaseURL { + t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL) + } +} From 261721bfdbeb2edc495f24189b75f2c151f186a7 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 25 Feb 2025 16:56:35 +0000 Subject: [PATCH 086/129] Fix linter (#943) * fix lint * remove linters --- .github/workflows/pr.yml | 4 ++-- .golangci.yml | 14 -------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index ea0c327f1..818a8842b 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -13,14 +13,14 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.24' - name: Run vet run: | go vet . - name: Run golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.63.4 + version: v1.64.5 - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v . - name: Upload coverage reports to Codecov diff --git a/.golangci.yml b/.golangci.yml index 9d22d9bd3..9f2ba52e0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -139,11 +139,6 @@ linters-settings: # Default: false all: true - varcheck: - # Check usage of exported fields and variables. - # Default: false - exported-fields: false # default false # TODO: enable after fixing false positives - linters: disable-all: true @@ -167,9 +162,7 @@ linters: - durationcheck # check for two durations multiplied together - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. 
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - # Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - exhaustive # check exhaustiveness of enum switch statements - - exportloopref # checks for pointers to enclosing loop variables - forbidigo # Forbids identifiers - funlen # Tool for detection of long functions # - gochecknoglobals # check that no global variables exist @@ -201,7 +194,6 @@ linters: - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - stylecheck # Stylecheck is a replacement for golint - - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - unconvert # Remove unnecessary type conversions @@ -239,12 +231,6 @@ linters: #- tagliatelle # Checks the struct tags. #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - ## deprecated - #- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized - #- golint # [deprecated, replaced by revive] Golint differs from gofmt. 
Gofmt reformats Go source code, whereas golint prints out style mistakes - #- interfacer # [deprecated] Linter that suggests narrower interface types - #- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted - #- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs issues: From 74d6449f22dd8bf668ebaeb181263b675b9a668b Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 4 Mar 2025 16:26:59 +0800 Subject: [PATCH 087/129] feat: add gpt-4.5-preview models (#947) --- completion.go | 138 ++++++++++++++++++++++++++------------------------ 1 file changed, 71 insertions(+), 67 deletions(-) diff --git a/completion.go b/completion.go index 1985293f8..015fa2a9f 100644 --- a/completion.go +++ b/completion.go @@ -10,41 +10,43 @@ import ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( - O1Mini = "o1-mini" - O1Mini20240912 = "o1-mini-2024-09-12" - O1Preview = "o1-preview" - O1Preview20240912 = "o1-preview-2024-09-12" - O1 = "o1" - O120241217 = "o1-2024-12-17" - O3Mini = "o3-mini" - O3Mini20250131 = "o3-mini-2025-01-31" - GPT432K0613 = "gpt-4-32k-0613" - GPT432K0314 = "gpt-4-32k-0314" - GPT432K = "gpt-4-32k" - GPT40613 = "gpt-4-0613" - GPT40314 = "gpt-4-0314" - GPT4o = "gpt-4o" - GPT4o20240513 = "gpt-4o-2024-05-13" - GPT4o20240806 = "gpt-4o-2024-08-06" - GPT4o20241120 = "gpt-4o-2024-11-20" - GPT4oLatest = "chatgpt-4o-latest" - GPT4oMini = "gpt-4o-mini" - GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" - GPT4Turbo = "gpt-4-turbo" - GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" - GPT4Turbo0125 = "gpt-4-0125-preview" - GPT4Turbo1106 = "gpt-4-1106-preview" - GPT4TurboPreview = "gpt-4-turbo-preview" - GPT4VisionPreview = "gpt-4-vision-preview" - GPT4 = "gpt-4" - GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" - GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" - GPT3Dot5Turbo0613 = 
"gpt-3.5-turbo-0613" - GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" - GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" - GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" - GPT3Dot5Turbo = "gpt-3.5-turbo" - GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" + O1Mini = "o1-mini" + O1Mini20240912 = "o1-mini-2024-09-12" + O1Preview = "o1-preview" + O1Preview20240912 = "o1-preview-2024-09-12" + O1 = "o1" + O120241217 = "o1-2024-12-17" + O3Mini = "o3-mini" + O3Mini20250131 = "o3-mini-2025-01-31" + GPT432K0613 = "gpt-4-32k-0613" + GPT432K0314 = "gpt-4-32k-0314" + GPT432K = "gpt-4-32k" + GPT40613 = "gpt-4-0613" + GPT40314 = "gpt-4-0314" + GPT4o = "gpt-4o" + GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4o20241120 = "gpt-4o-2024-11-20" + GPT4oLatest = "chatgpt-4o-latest" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" + GPT4Turbo0125 = "gpt-4-0125-preview" + GPT4Turbo1106 = "gpt-4-1106-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" + GPT4VisionPreview = "gpt-4-vision-preview" + GPT4 = "gpt-4" + GPT4Dot5Preview = "gpt-4.5-preview" + GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" + GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" + GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" + GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" + GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" + GPT3Dot5Turbo = "gpt-3.5-turbo" + GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. 
@@ -85,38 +87,40 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { - O1Mini: true, - O1Mini20240912: true, - O1Preview: true, - O1Preview20240912: true, - O3Mini: true, - O3Mini20250131: true, - GPT3Dot5Turbo: true, - GPT3Dot5Turbo0301: true, - GPT3Dot5Turbo0613: true, - GPT3Dot5Turbo1106: true, - GPT3Dot5Turbo0125: true, - GPT3Dot5Turbo16K: true, - GPT3Dot5Turbo16K0613: true, - GPT4: true, - GPT4o: true, - GPT4o20240513: true, - GPT4o20240806: true, - GPT4o20241120: true, - GPT4oLatest: true, - GPT4oMini: true, - GPT4oMini20240718: true, - GPT4TurboPreview: true, - GPT4VisionPreview: true, - GPT4Turbo1106: true, - GPT4Turbo0125: true, - GPT4Turbo: true, - GPT4Turbo20240409: true, - GPT40314: true, - GPT40613: true, - GPT432K: true, - GPT432K0314: true, - GPT432K0613: true, + O1Mini: true, + O1Mini20240912: true, + O1Preview: true, + O1Preview20240912: true, + O3Mini: true, + O3Mini20250131: true, + GPT3Dot5Turbo: true, + GPT3Dot5Turbo0301: true, + GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo1106: true, + GPT3Dot5Turbo0125: true, + GPT3Dot5Turbo16K: true, + GPT3Dot5Turbo16K0613: true, + GPT4: true, + GPT4Dot5Preview: true, + GPT4Dot5Preview20250227: true, + GPT4o: true, + GPT4o20240513: true, + GPT4o20240806: true, + GPT4o20241120: true, + GPT4oLatest: true, + GPT4oMini: true, + GPT4oMini20240718: true, + GPT4TurboPreview: true, + GPT4VisionPreview: true, + GPT4Turbo1106: true, + GPT4Turbo0125: true, + GPT4Turbo: true, + GPT4Turbo20240409: true, + GPT40314: true, + GPT40613: true, + GPT432K: true, + GPT432K0314: true, + GPT432K0613: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, From e99eb54c9d81cc102683921f4952a6d0c1964cbf Mon Sep 17 00:00:00 2001 From: "JT A." 
Date: Sun, 13 Apr 2025 12:00:48 -0600 Subject: [PATCH 088/129] add enum tag to jsonschema (#962) * fix jsonschema tests * ensure all run during PR Github Action * add test for struct to schema * add support for enum tag * support nullable tag --- .github/workflows/pr.yml | 2 +- jsonschema/json.go | 12 ++ jsonschema/json_test.go | 310 ++++++++++++++++++++++++++++++--------- 3 files changed, 252 insertions(+), 72 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 818a8842b..f4cbe7c8b 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -22,6 +22,6 @@ jobs: with: version: v1.64.5 - name: Run tests - run: go test -race -covermode=atomic -coverprofile=coverage.out -v . + run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4 diff --git a/jsonschema/json.go b/jsonschema/json.go index bcb253fae..d458418f3 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -46,6 +46,8 @@ type Definition struct { // additionalProperties: false // additionalProperties: jsonschema.Definition{Type: jsonschema.String} AdditionalProperties any `json:"additionalProperties,omitempty"` + // Whether the schema is nullable or not. 
+ Nullable bool `json:"nullable,omitempty"` } func (d *Definition) MarshalJSON() ([]byte, error) { @@ -139,6 +141,16 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { if description != "" { item.Description = description } + enum := field.Tag.Get("enum") + if enum != "" { + item.Enum = strings.Split(enum, ",") + } + + if n := field.Tag.Get("nullable"); n != "" { + nullable, _ := strconv.ParseBool(n) + item.Nullable = nullable + } + properties[jsonTag] = *item if s := field.Tag.Get("required"); s != "" { diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 744706082..17f0aba8a 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) { { name: "Test with empty Definition", def: jsonschema.Definition{}, - want: `{"properties":{}}`, + want: `{}`, }, { name: "Test with Definition properties set", @@ -31,15 +31,14 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"string", - "description":"A string type", - "properties":{ - "name":{ - "type":"string", - "properties":{} - } - } -}`, + "type":"string", + "description":"A string type", + "properties":{ + "name":{ + "type":"string" + } + } + }`, }, { name: "Test with nested Definition properties", @@ -60,23 +59,21 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"object", - "properties":{ - "user":{ - "type":"object", - "properties":{ - "name":{ - "type":"string", - "properties":{} - }, - "age":{ - "type":"integer", - "properties":{} - } - } - } - } -}`, + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + } + } + } + } + }`, }, { name: "Test with complex nested Definition", @@ -108,36 +105,32 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"object", - "properties":{ - "user":{ - "type":"object", - "properties":{ - "name":{ - "type":"string", - 
"properties":{} - }, - "age":{ - "type":"integer", - "properties":{} - }, - "address":{ - "type":"object", - "properties":{ - "city":{ - "type":"string", - "properties":{} - }, - "country":{ - "type":"string", - "properties":{} - } - } - } - } - } - } -}`, + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + }, + "address":{ + "type":"object", + "properties":{ + "city":{ + "type":"string" + }, + "country":{ + "type":"string" + } + } + } + } + } + } + }`, }, { name: "Test with Array type Definition", @@ -153,20 +146,16 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"array", - "items":{ - "type":"string", - "properties":{ - - } - }, - "properties":{ - "name":{ - "type":"string", - "properties":{} - } - } -}`, + "type":"array", + "items":{ + "type":"string" + }, + "properties":{ + "name":{ + "type":"string" + } + } + }`, }, } @@ -193,6 +182,185 @@ func TestDefinition_MarshalJSON(t *testing.T) { } } +func TestStructToSchema(t *testing.T) { + tests := []struct { + name string + in any + want string + }{ + { + name: "Test with empty struct", + in: struct{}{}, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with struct containing many fields", + in: struct { + Name string `json:"name"` + Age int `json:"age"` + Active bool `json:"active"` + Height float64 `json:"height"` + Cities []struct { + Name string `json:"name"` + State string `json:"state"` + } `json:"cities"` + }{ + Name: "John Doe", + Age: 30, + Cities: []struct { + Name string `json:"name"` + State string `json:"state"` + }{ + {Name: "New York", State: "NY"}, + {Name: "Los Angeles", State: "CA"}, + }, + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + }, + "active":{ + "type":"boolean" + }, + "height":{ + "type":"number" + }, + "cities":{ + "type":"array", + "items":{ + 
"additionalProperties":false, + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "state":{ + "type":"string" + } + }, + "required":["name","state"] + } + } + }, + "required":["name","age","active","height","cities"], + "additionalProperties":false + }`, + }, + { + name: "Test with description tag", + in: struct { + Name string `json:"name" description:"The name of the person"` + }{ + Name: "John Doe", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "description":"The name of the person" + } + }, + "required":["name"], + "additionalProperties":false + }`, + }, + { + name: "Test with required tag", + in: struct { + Name string `json:"name" required:"false"` + }{ + Name: "John Doe", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, + { + name: "Test with enum tag", + in: struct { + Color string `json:"color" enum:"red,green,blue"` + }{ + Color: "red", + }, + want: `{ + "type":"object", + "properties":{ + "color":{ + "type":"string", + "enum":["red","green","blue"] + } + }, + "required":["color"], + "additionalProperties":false + }`, + }, + { + name: "Test with nullable tag", + in: struct { + Name *string `json:"name" nullable:"true"` + }{ + Name: nil, + }, + want: `{ + + "type":"object", + "properties":{ + "name":{ + "type":"string", + "nullable":true + } + }, + "required":["name"], + "additionalProperties":false + }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wantBytes := []byte(tt.want) + + schema, err := jsonschema.GenerateSchemaForType(tt.in) + if err != nil { + t.Errorf("Failed to generate schema: error = %v", err) + return + } + + var want map[string]interface{} + err = json.Unmarshal(wantBytes, &want) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + got := structToMap(t, schema) + gotPtr := structToMap(t, &schema) + + if !reflect.DeepEqual(got, 
want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, want) + } + if !reflect.DeepEqual(gotPtr, want) { + t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) + } + }) + } +} + func structToMap(t *testing.T, v any) map[string]any { t.Helper() gotBytes, err := json.Marshal(v) From d68a6838156049ada8c25d3f4b8fa3befb3b4d1b Mon Sep 17 00:00:00 2001 From: Takahiro Ikeuchi Date: Thu, 24 Apr 2025 06:50:47 +0900 Subject: [PATCH 089/129] feat: add new GPT-4.1 model variants to completion.go (#966) * feat: add new GPT-4.1 model variants to completion.go * feat: add tests for unsupported models in completion endpoint * fix: add missing periods to test function comments in completion_test.go --- completion.go | 13 ++++++++ completion_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/completion.go b/completion.go index 015fa2a9f..0d0c1a8f4 100644 --- a/completion.go +++ b/completion.go @@ -37,6 +37,12 @@ const ( GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT4Dot1 = "gpt-4.1" + GPT4Dot120250414 = "gpt-4.1-2025-04-14" + GPT4Dot1Mini = "gpt-4.1-mini" + GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14" + GPT4Dot1Nano = "gpt-4.1-nano" + GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14" GPT4Dot5Preview = "gpt-4.5-preview" GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" @@ -121,6 +127,13 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT432K: true, GPT432K0314: true, GPT432K0613: true, + O1: true, + GPT4Dot1: true, + GPT4Dot120250414: true, + GPT4Dot1Mini: true, + GPT4Dot1Mini20250414: true, + GPT4Dot1Nano: true, + GPT4Dot1Nano20250414: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, diff --git a/completion_test.go b/completion_test.go index 935bbe864..83bd899a1 100644 --- a/completion_test.go +++ b/completion_test.go @@ -181,3 +181,86 @@ func getCompletionBody(r *http.Request) 
(openai.CompletionRequest, error) { } return completion, nil } + +// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint. +func TestCompletionWithO1Model(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O1, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err) + } +} + +// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint. +func TestCompletionWithGPT4DotModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4Dot1, + openai.GPT4Dot120250414, + openai.GPT4Dot1Mini, + openai.GPT4Dot1Mini20250414, + openai.GPT4Dot1Nano, + openai.GPT4Dot1Nano20250414, + openai.GPT4Dot5Preview, + openai.GPT4Dot5Preview20250227, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} + +// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint. 
+func TestCompletionWithGPT4oModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4o, + openai.GPT4o20240513, + openai.GPT4o20240806, + openai.GPT4o20241120, + openai.GPT4oLatest, + openai.GPT4oMini, + openai.GPT4oMini20240718, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} From 658beda2ba8be4d155bc62208224a5766e0640c0 Mon Sep 17 00:00:00 2001 From: netr Date: Sat, 26 Apr 2025 03:13:43 -0700 Subject: [PATCH 090/129] feat: Add missing TTS models and voices (#958) * feat: Add missing TTS models and voices * feat: Add new instruction field to create speech request - From docs: Control the voice of your generated audio with additional instructions. Does not work with tts-1 or tts-1-hd. 
* fix: add canary-tts back to SpeechModel --- speech.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/speech.go b/speech.go index 20b52e334..60e7694fd 100644 --- a/speech.go +++ b/speech.go @@ -8,20 +8,25 @@ import ( type SpeechModel string const ( - TTSModel1 SpeechModel = "tts-1" - TTSModel1HD SpeechModel = "tts-1-hd" - TTSModelCanary SpeechModel = "canary-tts" + TTSModel1 SpeechModel = "tts-1" + TTSModel1HD SpeechModel = "tts-1-hd" + TTSModelCanary SpeechModel = "canary-tts" + TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts" ) type SpeechVoice string const ( VoiceAlloy SpeechVoice = "alloy" + VoiceAsh SpeechVoice = "ash" + VoiceBallad SpeechVoice = "ballad" + VoiceCoral SpeechVoice = "coral" VoiceEcho SpeechVoice = "echo" VoiceFable SpeechVoice = "fable" VoiceOnyx SpeechVoice = "onyx" VoiceNova SpeechVoice = "nova" VoiceShimmer SpeechVoice = "shimmer" + VoiceVerse SpeechVoice = "verse" ) type SpeechResponseFormat string @@ -39,6 +44,7 @@ type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` Voice SpeechVoice `json:"voice"` + Instructions string `json:"instructions,omitempty"` // Optional, Doesn't work with tts-1 or tts-1-hd. 
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } From 306fbbbe6f09ff7bd718e11cd322e88b442f4496 Mon Sep 17 00:00:00 2001 From: goodenough Date: Tue, 29 Apr 2025 21:24:45 +0800 Subject: [PATCH 091/129] Add support for reasoning_content field in chat completion messages for DeepSeek R1 (#925) * support deepseek field "reasoning_content" * support deepseek field "reasoning_content" * Comment ends in a period (godot) * add comment on field reasoning_content * fix go lint error * chore: trigger CI * make field "content" in MarshalJSON function omitempty * remove reasoning_content in TestO1ModelChatCompletions func * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. --- chat.go | 74 ++++++++++++++++++++++++++-------------------- chat_stream.go | 6 ++++ chat_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++ openai_test.go | 2 +- 4 files changed, 128 insertions(+), 33 deletions(-) diff --git a/chat.go b/chat.go index 995860c40..7112dc7b7 100644 --- a/chat.go +++ b/chat.go @@ -104,6 +104,12 @@ type ChatCompletionMessage struct { // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. 
+ // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. @@ -119,41 +125,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { } if len(m.MultiContent) > 0 { msg := struct { - Role string `json:"role"` - Content string `json:"-"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content,omitempty"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"-"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` 
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &msg); err == nil { @@ -161,14 +170,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { return nil } multiMsg := struct { - Role string `json:"role"` - Content string - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &multiMsg); err != nil { return err diff --git 
a/chat_stream.go b/chat_stream.go index 525b4457a..80d16cc63 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct { FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` Refusal string `json:"refusal,omitempty"` + + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` } type ChatCompletionStreamChoiceLogprobs struct { diff --git a/chat_test.go b/chat_test.go index e90142da6..514706c96 100644 --- a/chat_test.go +++ b/chat_test.go @@ -411,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +func TestDeepseekR1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "deepseek-reasoner", + MaxCompletionTokens: 100, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. 
func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -822,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } +func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq openai.ChatCompletionRequest + if completionReq, err = getChatCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + n := completionReq.N + if n == 0 { + n = 1 + } + if completionReq.MaxCompletionTokens == 0 { + completionReq.MaxCompletionTokens = 1000 + } + for i := 0; i < n; i++ { + reasoningContent := "User says hello! 
And I need to reply" + completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent)) + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + ReasoningContent: reasoningContent, + Content: completionStr, + }, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n + res.Usage = openai.Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + fmt.Fprintln(w, string(resBytes)) +} + // getChatCompletionBody Returns the body of the request to create a completion. func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { completion := openai.ChatCompletionRequest{} diff --git a/openai_test.go b/openai_test.go index 6c26eebd1..a55f3a858 100644 --- a/openai_test.go +++ b/openai_test.go @@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer/ +// https://beta.openai.com/tokenizer. // // TODO: implement an actual tokenizer for GPT-3 and Codex (once available). 
func numTokens(s string) int { From 4cccc6c93455d2ff2cf52660c916cfa5907ddbd3 Mon Sep 17 00:00:00 2001 From: Zhongxian Pan Date: Tue, 29 Apr 2025 21:29:15 +0800 Subject: [PATCH 092/129] Adapt different stream data prefix, with or without space (#945) --- stream_reader.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index ecfa26807..6faefe0a7 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -6,13 +6,14 @@ import ( "fmt" "io" "net/http" + "regexp" utils "github.com/sashabaranov/go-openai/internal" ) var ( - headerData = []byte("data: ") - errorPrefix = []byte(`data: {"error":`) + headerData = regexp.MustCompile(`^data:\s*`) + errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) ) type streamable interface { @@ -70,12 +71,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { } noSpaceLine := bytes.TrimSpace(rawLine) - if bytes.HasPrefix(noSpaceLine, errorPrefix) { + if errorPrefix.Match(noSpaceLine) { hasErrorPrefix = true } - if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if !headerData.Match(noSpaceLine) || hasErrorPrefix { if hasErrorPrefix { - noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { @@ -89,7 +90,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { continue } - noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil) if string(noPrefixLine) == "[DONE]" { stream.isFinished = true return nil, io.EOF From bb5bc275678767d410d25d307959e1c45bc89c90 Mon Sep 17 00:00:00 2001 From: rory malcolm Date: Tue, 29 Apr 2025 14:34:33 +0100 Subject: [PATCH 093/129] Add support for `4o-mini` and `3o` (#968) - This adds supports, and tests, for the 3o and 4o-mini class of models --- chat_stream_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++ 
completion.go | 8 +++++++ completion_test.go | 36 ++++++++++++++++++++++++++++++ models_test.go | 18 +++++++++++++++ reasoning_validator.go | 3 ++- 5 files changed, 114 insertions(+), 1 deletion(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 4d992e4d1..eabb0f3a2 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -959,6 +959,56 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) { } } +func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err) + } +} + +func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O4Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err) + } +} + func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false diff --git 
a/completion.go b/completion.go index 0d0c1a8f4..7a6de3033 100644 --- a/completion.go +++ b/completion.go @@ -16,8 +16,12 @@ const ( O1Preview20240912 = "o1-preview-2024-09-12" O1 = "o1" O120241217 = "o1-2024-12-17" + O3 = "o3" + O320250416 = "o3-2025-04-16" O3Mini = "o3-mini" O3Mini20250131 = "o3-mini-2025-01-31" + O4Mini = "o4-mini" + O4Mini2020416 = "o4-mini-2025-04-16" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -99,6 +103,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ O1Preview20240912: true, O3Mini: true, O3Mini20250131: true, + O4Mini: true, + O4Mini2020416: true, + O3: true, + O320250416: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, diff --git a/completion_test.go b/completion_test.go index 83bd899a1..27e2d150e 100644 --- a/completion_test.go +++ b/completion_test.go @@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) { } } +// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported. +func TestCompletionsWrongModelO3(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O3, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err) + } +} + +// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported. 
+func TestCompletionsWrongModelO4Mini(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O4Mini, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err) + } +} + func TestCompletionWithStream(t *testing.T) { config := openai.DefaultConfig("whatever") client := openai.NewClientWithConfig(config) diff --git a/models_test.go b/models_test.go index 24a28ed23..7fd010c34 100644 --- a/models_test.go +++ b/models_test.go @@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) { checks.NoError(t, err, "GetModel error") } +// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server. +func TestGetModelO3(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o3") + checks.NoError(t, err, "GetModel error for O3") +} + +// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server. 
+func TestGetModelO4Mini(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o4-mini") + checks.NoError(t, err, "GetModel error for O4Mini") +} + func TestAzureGetModel(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() diff --git a/reasoning_validator.go b/reasoning_validator.go index 040d6b495..2910b1395 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -40,8 +40,9 @@ func NewReasoningValidator() *ReasoningValidator { func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { o1Series := strings.HasPrefix(request.Model, "o1") o3Series := strings.HasPrefix(request.Model, "o3") + o4Series := strings.HasPrefix(request.Model, "o4") - if !o1Series && !o3Series { + if !o1Series && !o3Series && !o4Series { return nil } From da5f9bc9bc40537a0c2b451fffa9364efa94dbe1 Mon Sep 17 00:00:00 2001 From: Sean McGinnis Date: Tue, 29 Apr 2025 08:35:26 -0500 Subject: [PATCH 094/129] Add CompletionRequest.StreamOptions (#959) The legacy completion API supports a `stream_options` object when `stream` is set to true [0]. This adds a StreamOptions property to the CompletionRequest struct to support this setting. [0] https://platform.openai.com/docs/api-reference/completions/create#completions-create-stream_options Signed-off-by: Sean McGinnis --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 7a6de3033..9c3a64dd5 100644 --- a/completion.go +++ b/completion.go @@ -215,6 +215,8 @@ type CompletionRequest struct { Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` User string `json:"user,omitempty"` + // Options for streaming response. Only set this when you set stream: true. 
+ StreamOptions *StreamOptions `json:"stream_options,omitempty"` } // CompletionChoice represents one of possible completions. From 6836cf6a6fd0027ea21f8d31bff5d023040d9db4 Mon Sep 17 00:00:00 2001 From: Oleksandr Redko Date: Tue, 29 Apr 2025 16:36:38 +0300 Subject: [PATCH 095/129] Remove redundant typecheck linter (#955) --- .golangci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index 9f2ba52e0..a5988825b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -149,7 +149,6 @@ linters: - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - ineffassign # Detects when assignments to existing variables are not used - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unused # Checks Go code for unused constants, variables, functions and types ## disabled by default # - asasalint # Check for pass []any as any in variadic func(...any) From 93a611cf4f1d227963a4f28a2c4a3422a0d37bfd Mon Sep 17 00:00:00 2001 From: Daniel Peng Date: Tue, 29 Apr 2025 06:38:27 -0700 Subject: [PATCH 096/129] Add Prediction field (#970) * Add Prediction field to ChatCompletionRequest * Include prediction tokens in response --- chat.go | 7 +++++++ common.go | 6 ++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index 7112dc7b7..0f91d481c 100644 --- a/chat.go +++ b/chat.go @@ -273,6 +273,8 @@ type ChatCompletionRequest struct { ReasoningEffort string `json:"reasoning_effort,omitempty"` // Metadata to store with the completion. Metadata map[string]string `json:"metadata,omitempty"` + // Configuration for a predicted output. 
+ Prediction *Prediction `json:"prediction,omitempty"` } type StreamOptions struct { @@ -340,6 +342,11 @@ type LogProbs struct { Content []LogProb `json:"content"` } +type Prediction struct { + Content string `json:"content"` + Type string `json:"type"` +} + type FinishReason string const ( diff --git a/common.go b/common.go index 8cc7289c0..d1936d656 100644 --- a/common.go +++ b/common.go @@ -13,8 +13,10 @@ type Usage struct { // CompletionTokensDetails Breakdown of tokens used in a completion. type CompletionTokensDetails struct { - AudioTokens int `json:"audio_tokens"` - ReasoningTokens int `json:"reasoning_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` } // PromptTokensDetails Breakdown of tokens used in the prompt. From d65f0cb54e8c9c91f4340fad14243beeb38f5a08 Mon Sep 17 00:00:00 2001 From: Ben Katz Date: Sun, 4 May 2025 03:44:48 +0700 Subject: [PATCH 097/129] Fix: Corrected typo in O4Mini20250416 model name and endpoint map. 
(#981) --- completion.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 9c3a64dd5..21d4897c4 100644 --- a/completion.go +++ b/completion.go @@ -21,7 +21,7 @@ const ( O3Mini = "o3-mini" O3Mini20250131 = "o3-mini-2025-01-31" O4Mini = "o4-mini" - O4Mini2020416 = "o4-mini-2025-04-16" + O4Mini20250416 = "o4-mini-2025-04-16" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -104,7 +104,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ O3Mini: true, O3Mini20250131: true, O4Mini: true, - O4Mini2020416: true, + O4Mini20250416: true, O3: true, O320250416: true, GPT3Dot5Turbo: true, From 5ea214a188a7751bddb9f6c899632509d388f643 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 3 May 2025 22:25:14 +0100 Subject: [PATCH 098/129] Improve unit test coverage (#984) * add tests for config * add audio tests * lint * lint * lint --- audio_test.go | 132 +++++++++++++++++++++++++++++++++++++++++++++++++ config_test.go | 21 ++++++++ 2 files changed, 153 insertions(+) diff --git a/audio_test.go b/audio_test.go index 9f32d5468..51b3f465d 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,12 +2,16 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" + "errors" "fmt" "io" + "net/http" "os" "path/filepath" "testing" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -107,3 +111,131 @@ func TestCreateFileField(t *testing.T) { checks.HasError(t, err, "createFileField using file should return error when open file fails") }) } + +// failingFormBuilder always returns an error when creating form files. 
+type failingFormBuilder struct{ err error } + +func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error { + return f.err +} + +func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error { + return f.err +} + +func (f *failingFormBuilder) WriteField(_, _ string) error { + return nil +} + +func (f *failingFormBuilder) Close() error { + return nil +} + +func (f *failingFormBuilder) FormDataContentType() string { + return "multipart/form-data" +} + +// failingAudioRequestBuilder simulates an error during HTTP request construction. +type failingAudioRequestBuilder struct{ err error } + +func (f *failingAudioRequestBuilder) Build( + _ context.Context, + _, _ string, + _ any, + _ http.Header, +) (*http.Request, error) { + return nil, f.err +} + +// errorHTTPClient always returns an error when making HTTP calls. +type errorHTTPClient struct{ err error } + +func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) { + return nil, e.err +} + +func TestCallAudioAPIMultipartFormError(t *testing.T) { + client := NewClient("test-token") + errForm := errors.New("mock create form file failure") + // Override form builder to force an error during multipart form creation. + client.createFormBuilder = func(_ io.Writer) utils.FormBuilder { + return &failingFormBuilder{err: errForm} + } + + // Provide a reader so createFileField uses the reader path (no file open). + req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errForm) { + t.Errorf("expected error %v, got %v", errForm, err) + } +} + +func TestCallAudioAPINewRequestError(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. 
+ tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errBuild := errors.New("mock build failure") + client.requestBuilder = &failingAudioRequestBuilder{err: errBuild} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errBuild) { + t.Errorf("expected error %v, got %v", errBuild, err) + } +} + +func TestCallAudioAPISendRequestErrorJSON(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errHTTP := errors.New("mock HTTPClient failure") + // Override HTTP client to simulate a network error. + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) + } +} + +func TestCallAudioAPISendRequestErrorText(t *testing.T) { + client := NewClient("test-token") + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errHTTP := errors.New("mock HTTPClient failure") + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} + + // Use a non-JSON response format to exercise the text path. 
+ req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) + } +} diff --git a/config_test.go b/config_test.go index 145c26066..960230804 100644 --- a/config_test.go +++ b/config_test.go @@ -100,3 +100,24 @@ func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) { t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL) } } + +func TestClientConfigString(t *testing.T) { + // String() should always return the constant value + conf := openai.DefaultConfig("dummy-token") + expected := "" + got := conf.String() + if got != expected { + t.Errorf("ClientConfig.String() = %q; want %q", got, expected) + } +} + +func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) { + // On a zero-value or DefaultConfig, AzureModelMapperFunc is nil, + // so GetAzureDeploymentByModel should just return the input model. + conf := openai.DefaultConfig("dummy-token") + model := "some-model" + got := conf.GetAzureDeploymentByModel(model) + if got != model { + t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model) + } +} From 77ccac8d342f7704b020e63b23cec2d6f009bff9 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 3 May 2025 22:39:47 +0100 Subject: [PATCH 099/129] Upgrade golangci-lint to 2.1.5 (#986) * Upgrade golangci-lint to 2.1.5 * update action --- .github/workflows/pr.yml | 4 +- .golangci.yml | 418 +++++++++++++++------------------------ 2 files changed, 166 insertions(+), 256 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index f4cbe7c8b..268e3259b 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -18,9 +18,9 @@ jobs: run: | go vet . 
- name: Run golangci-lint - uses: golangci/golangci-lint-action@v6 + uses: golangci/golangci-lint-action@v7 with: - version: v1.64.5 + version: v2.1.5 - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov diff --git a/.golangci.yml b/.golangci.yml index a5988825b..6391ad76f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,258 +1,168 @@ -## Golden config for golangci-lint v1.47.3 -# -# This is the best config for golangci-lint based on my experience and opinion. -# It is very strict, but not extremely strict. -# Feel free to adopt and change it for your needs. - -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 3m - - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 30 - # The maximal average package complexity. - # If it's higher than 0.0 (float) the check is enabled - # Default: 0.0 - package-average: 10.0 - - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: 100 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: 50 - - gocognit: - # Minimal code complexity to report - # Default: 30 (but we recommend 10-20) - min-complexity: 20 - - gocritic: - # Settings passed to gocritic. - # The settings key is the name of a supported gocritic checker. - # The list of supported checkers can be find in https://go-critic.github.io/overview. - settings: - captLocal: - # Whether to restrict checker to params only. 
- # Default: true - paramsOnly: false - underef: - # Whether to skip (*x).method() calls where x is a pointer receiver. - # Default: true - skipRecvDeref: false - - mnd: - # List of function patterns to exclude from analysis. - # Values always ignored: `time.Date` - # Default: [] - ignored-functions: - - os.Chmod - - os.Mkdir - - os.MkdirAll - - os.OpenFile - - os.WriteFile - - prometheus.ExponentialBuckets - - prometheus.ExponentialBucketsRange - - prometheus.LinearBuckets - - strconv.FormatFloat - - strconv.FormatInt - - strconv.FormatUint - - strconv.ParseFloat - - strconv.ParseInt - - strconv.ParseUint - - gomodguard: - blocked: - # List of blocked modules. - # Default: [] - modules: - - github.com/golang/protobuf: - recommendations: - - google.golang.org/protobuf - reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" - - github.com/satori/go.uuid: - recommendations: - - github.com/google/uuid - reason: "satori's package is not maintained" - - github.com/gofrs/uuid: - recommendations: - - github.com/google/uuid - reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" - - govet: - # Enable all analyzers. - # Default: false - enable-all: true - # Disable analyzers by name. - # Run `go tool vet help` to see all analyzers. - # Default: [] - disable: - - fieldalignment # too strict - # Settings per analyzer. - settings: - shadow: - # Whether to be strict about shadowing; can be noisy. - # Default: false - strict: true - - nakedret: - # Make an issue if func has more lines of code than this setting, and it has naked returns. - # Default: 30 - max-func-lines: 0 - - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] - # Enable to require an explanation of nonzero length after each nolint directive. 
- # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - - rowserrcheck: - # database/sql is always checked - # Default: [] - packages: - - github.com/jmoiron/sqlx - - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - - +version: "2" linters: - disable-all: true + default: none enable: - ## enabled by default - - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - - gosimple # Linter for Go source code that specializes in simplifying a code - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # Detects when assignments to existing variables are not used - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - unused # Checks Go code for unused constants, variables, functions and types - ## disabled by default - # - asasalint # Check for pass []any as any in variadic func(...any) - - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - - bidichk # Checks for dangerous unicode character sequences - - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity - - dupl # Tool for code clone detection - - durationcheck # check for two durations multiplied together - - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. 
- - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - exhaustive # check exhaustiveness of enum switch statements - - forbidigo # Forbids identifiers - - funlen # Tool for detection of long functions - # - gochecknoglobals # check that no global variables exist - - gochecknoinits # Checks that no init functions are present in Go code - - gocognit # Computes and checks the cognitive complexity of functions - - goconst # Finds repeated strings that could be replaced by a constant - - gocritic # Provides diagnostics that check for bugs, performance and style issues. - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - - goprintffuncname # Checks that printf-like functions are named with f at the end - - gosec # Inspects source code for security problems - - lll # Reports long lines - - makezero # Finds slice declarations with non-zero initial length - # - nakedret # Finds naked returns in functions greater than a specified function length - - mnd # An analyzer to detect magic numbers. - - nestif # Reports deeply nested if statements - - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. 
- # - noctx # noctx finds sending http request without context.Context - - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - - predeclared # find code that shadows one of Go's predeclared identifiers - - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - rowserrcheck # checks whether Err of rows is checked successfully - - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - - stylecheck # Stylecheck is a replacement for golint - - testpackage # linter that makes you use a separate _test package - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - usetesting # Reports uses of functions with replacement inside the testing package - - wastedassign # wastedassign finds wasted assignment statements. 
- - whitespace # Tool for detection of leading and trailing whitespace - ## you may want to enable - #- decorder # check declaration order and count of types, constants, variables and functions - #- exhaustruct # Checks if all structure fields are initialized - #- goheader # Checks is file header matches to pattern - #- ireturn # Accept Interfaces, Return Concrete Types - #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated - #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope - #- wrapcheck # Checks that errors returned from external packages are wrapped - ## disabled - #- containedctx # containedctx is a linter that detects struct contained context.Context field - #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages - #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - #- forcetypeassert # [replaced by errcheck] finds forced type assertions - #- gci # Gci controls golang package import order and makes it always deterministic. - #- godox # Tool for detection of FIXME, TODO and other comment keywords - #- goerr113 # [too strict] Golang linter to check the errors handling expressions - #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. - #- grouper # An analyzer to analyze expression groups. 
- #- ifshort # Checks that your code uses short syntax for if-statements whenever possible - #- importas # Enforces consistent import aliases - #- maintidx # maintidx measures the maintainability index of each function. - #- misspell # [useless] Finds commonly misspelled English words in comments - #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity - #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 - #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test - #- tagliatelle # Checks the struct tags. - #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! 
- - + - asciicheck + - bidichk + - bodyclose + - contextcheck + - cyclop + - dupl + - durationcheck + - errcheck + - errname + - errorlint + - exhaustive + - forbidigo + - funlen + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godot + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - govet + - ineffassign + - lll + - makezero + - mnd + - nestif + - nilerr + - nilnil + - nolintlint + - nosprintfhostport + - predeclared + - promlinter + - revive + - rowserrcheck + - sqlclosecheck + - staticcheck + - testpackage + - tparallel + - unconvert + - unparam + - unused + - usetesting + - wastedassign + - whitespace + settings: + cyclop: + max-complexity: 30 + package-average: 10 + errcheck: + check-type-assertions: true + funlen: + lines: 100 + statements: 50 + gocognit: + min-complexity: 20 + gocritic: + settings: + captLocal: + paramsOnly: false + underef: + skipRecvDeref: false + gomodguard: + blocked: + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: satori's package is not maintained + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw' + govet: + disable: + - fieldalignment + enable-all: true + settings: + shadow: + strict: true + mnd: + ignored-functions: + - os.Chmod + - os.Mkdir + - os.MkdirAll + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets + - prometheus.ExponentialBucketsRange + - prometheus.LinearBuckets + - strconv.FormatFloat + - strconv.FormatInt + - strconv.FormatUint + - strconv.ParseFloat + - strconv.ParseInt + - strconv.ParseUint + nakedret: + max-func-lines: 0 + nolintlint: + require-explanation: true + require-specific: true + allow-no-explanation: + - funlen + - 
gocognit + - lll + rowserrcheck: + packages: + - github.com/jmoiron/sqlx + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - forbidigo + - mnd + - revive + path : ^examples/.*\.go$ + - linters: + - lll + source: ^//\s*go:generate\s + - linters: + - godot + source: (noinspection|TODO) + - linters: + - gocritic + source: //noinspection + - linters: + - errorlint + source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok { + - linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck + - staticcheck + path: _test\.go + paths: + - third_party$ + - builtin$ + - examples$ issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. - # Default: 3 max-same-issues: 50 - - exclude-rules: - - source: "^//\\s*go:generate\\s" - linters: [ lll ] - - source: "(noinspection|TODO)" - linters: [ godot ] - - source: "//noinspection" - linters: [ gocritic ] - - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" - linters: [ errorlint ] - - path: "_test\\.go" - linters: - - bodyclose - - dupl - - funlen - - goconst - - gosec - - noctx - - wrapcheck +formatters: + enable: + - goimports + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ From 6181facea7e6e5525b6b8da42205d7cce822c249 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sun, 4 May 2025 15:45:40 +0100 Subject: [PATCH 100/129] update codecov action, pass token (#987) --- .github/workflows/pr.yml | 4 +- .golangci.bck.yml | 258 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 .golangci.bck.yml diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 268e3259b..18c720f03 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -24,4 +24,6 @@ jobs: - name: Run tests run: go test -race -covermode=atomic 
-coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.golangci.bck.yml b/.golangci.bck.yml new file mode 100644 index 000000000..a5988825b --- /dev/null +++ b/.golangci.bck.yml @@ -0,0 +1,258 @@ +## Golden config for golangci-lint v1.47.3 +# +# This is the best config for golangci-lint based on my experience and opinion. +# It is very strict, but not extremely strict. +# Feel free to adopt and change it for your needs. + +run: + # Timeout for analysis, e.g. 30s, 5m. + # Default: 1m + timeout: 3m + + +# This file contains only configs which differ from defaults. +# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml +linters-settings: + cyclop: + # The maximal code complexity to report. + # Default: 10 + max-complexity: 30 + # The maximal average package complexity. + # If it's higher than 0.0 (float) the check is enabled + # Default: 0.0 + package-average: 10.0 + + errcheck: + # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. + # Such cases aren't reported by default. + # Default: false + check-type-assertions: true + + funlen: + # Checks the number of lines in a function. + # If lower than 0, disable the check. + # Default: 60 + lines: 100 + # Checks the number of statements in a function. + # If lower than 0, disable the check. + # Default: 40 + statements: 50 + + gocognit: + # Minimal code complexity to report + # Default: 30 (but we recommend 10-20) + min-complexity: 20 + + gocritic: + # Settings passed to gocritic. + # The settings key is the name of a supported gocritic checker. + # The list of supported checkers can be find in https://go-critic.github.io/overview. + settings: + captLocal: + # Whether to restrict checker to params only. 
+ # Default: true + paramsOnly: false + underef: + # Whether to skip (*x).method() calls where x is a pointer receiver. + # Default: true + skipRecvDeref: false + + mnd: + # List of function patterns to exclude from analysis. + # Values always ignored: `time.Date` + # Default: [] + ignored-functions: + - os.Chmod + - os.Mkdir + - os.MkdirAll + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets + - prometheus.ExponentialBucketsRange + - prometheus.LinearBuckets + - strconv.FormatFloat + - strconv.FormatInt + - strconv.FormatUint + - strconv.ParseFloat + - strconv.ParseInt + - strconv.ParseUint + + gomodguard: + blocked: + # List of blocked modules. + # Default: [] + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: "satori's package is not maintained" + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" + + govet: + # Enable all analyzers. + # Default: false + enable-all: true + # Disable analyzers by name. + # Run `go tool vet help` to see all analyzers. + # Default: [] + disable: + - fieldalignment # too strict + # Settings per analyzer. + settings: + shadow: + # Whether to be strict about shadowing; can be noisy. + # Default: false + strict: true + + nakedret: + # Make an issue if func has more lines of code than this setting, and it has naked returns. + # Default: 30 + max-func-lines: 0 + + nolintlint: + # Exclude following linters from requiring an explanation. + # Default: [] + allow-no-explanation: [ funlen, gocognit, lll ] + # Enable to require an explanation of nonzero length after each nolint directive. 
+ # Default: false + require-explanation: true + # Enable to require nolint directives to mention the specific linter being suppressed. + # Default: false + require-specific: true + + rowserrcheck: + # database/sql is always checked + # Default: [] + packages: + - github.com/jmoiron/sqlx + + tenv: + # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. + # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. + # Default: false + all: true + + +linters: + disable-all: true + enable: + ## enabled by default + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - gosimple # Linter for Go source code that specializes in simplifying a code + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - ineffassign # Detects when assignments to existing variables are not used + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - unused # Checks Go code for unused constants, variables, functions and types + ## disabled by default + # - asasalint # Check for pass []any as any in variadic func(...any) + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. 
+ - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - funlen # Tool for detection of long functions + # - gochecknoglobals # check that no global variables exist + - gochecknoinits # Checks that no init functions are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # Provides diagnostics that check for bugs, performance and style issues. + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - goprintffuncname # Checks that printf-like functions are named with f at the end + - gosec # Inspects source code for security problems + - lll # Reports long lines + - makezero # Finds slice declarations with non-zero initial length + # - nakedret # Finds naked returns in functions greater than a specified function length + - mnd # An analyzer to detect magic numbers. + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. 
+ # - noctx # noctx finds sending http request without context.Context + - nolintlint # Reports ill-formed or insufficient nolint directives + # - nonamedreturns # Reports all named returns + - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. + - predeclared # find code that shadows one of Go's predeclared identifiers + - promlinter # Check Prometheus metrics naming via promlint + - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - stylecheck # Stylecheck is a replacement for golint + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - usetesting # Reports uses of functions with replacement inside the testing package + - wastedassign # wastedassign finds wasted assignment statements. 
+ - whitespace # Tool for detection of leading and trailing whitespace + ## you may want to enable + #- decorder # check declaration order and count of types, constants, variables and functions + #- exhaustruct # Checks if all structure fields are initialized + #- goheader # Checks is file header matches to pattern + #- ireturn # Accept Interfaces, Return Concrete Types + #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated + #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope + #- wrapcheck # Checks that errors returned from external packages are wrapped + ## disabled + #- containedctx # containedctx is a linter that detects struct contained context.Context field + #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages + #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. + #- forcetypeassert # [replaced by errcheck] finds forced type assertions + #- gci # Gci controls golang package import order and makes it always deterministic. + #- godox # Tool for detection of FIXME, TODO and other comment keywords + #- goerr113 # [too strict] Golang linter to check the errors handling expressions + #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. + #- grouper # An analyzer to analyze expression groups. 
+ #- ifshort # Checks that your code uses short syntax for if-statements whenever possible + #- importas # Enforces consistent import aliases + #- maintidx # maintidx measures the maintainability index of each function. + #- misspell # [useless] Finds commonly misspelled English words in comments + #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity + #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 + #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test + #- tagliatelle # Checks the struct tags. + #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! + + +issues: + # Maximum count of issues with the same text. + # Set to 0 to disable. 
+ # Default: 3 + max-same-issues: 50 + + exclude-rules: + - source: "^//\\s*go:generate\\s" + linters: [ lll ] + - source: "(noinspection|TODO)" + linters: [ godot ] + - source: "//noinspection" + linters: [ gocritic ] + - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" + linters: [ errorlint ] + - path: "_test\\.go" + linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck From 8ba38f6ba16264760d3fd88892d02b25ef742c24 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 13 May 2025 12:44:16 +0100 Subject: [PATCH 101/129] remove backup file (#996) --- .golangci.bck.yml | 258 ---------------------------------------------- 1 file changed, 258 deletions(-) delete mode 100644 .golangci.bck.yml diff --git a/.golangci.bck.yml b/.golangci.bck.yml deleted file mode 100644 index a5988825b..000000000 --- a/.golangci.bck.yml +++ /dev/null @@ -1,258 +0,0 @@ -## Golden config for golangci-lint v1.47.3 -# -# This is the best config for golangci-lint based on my experience and opinion. -# It is very strict, but not extremely strict. -# Feel free to adopt and change it for your needs. - -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 3m - - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 30 - # The maximal average package complexity. - # If it's higher than 0.0 (float) the check is enabled - # Default: 0.0 - package-average: 10.0 - - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. 
- # Default: 60 - lines: 100 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: 50 - - gocognit: - # Minimal code complexity to report - # Default: 30 (but we recommend 10-20) - min-complexity: 20 - - gocritic: - # Settings passed to gocritic. - # The settings key is the name of a supported gocritic checker. - # The list of supported checkers can be find in https://go-critic.github.io/overview. - settings: - captLocal: - # Whether to restrict checker to params only. - # Default: true - paramsOnly: false - underef: - # Whether to skip (*x).method() calls where x is a pointer receiver. - # Default: true - skipRecvDeref: false - - mnd: - # List of function patterns to exclude from analysis. - # Values always ignored: `time.Date` - # Default: [] - ignored-functions: - - os.Chmod - - os.Mkdir - - os.MkdirAll - - os.OpenFile - - os.WriteFile - - prometheus.ExponentialBuckets - - prometheus.ExponentialBucketsRange - - prometheus.LinearBuckets - - strconv.FormatFloat - - strconv.FormatInt - - strconv.FormatUint - - strconv.ParseFloat - - strconv.ParseInt - - strconv.ParseUint - - gomodguard: - blocked: - # List of blocked modules. - # Default: [] - modules: - - github.com/golang/protobuf: - recommendations: - - google.golang.org/protobuf - reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" - - github.com/satori/go.uuid: - recommendations: - - github.com/google/uuid - reason: "satori's package is not maintained" - - github.com/gofrs/uuid: - recommendations: - - github.com/google/uuid - reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" - - govet: - # Enable all analyzers. - # Default: false - enable-all: true - # Disable analyzers by name. - # Run `go tool vet help` to see all analyzers. - # Default: [] - disable: - - fieldalignment # too strict - # Settings per analyzer. 
- settings: - shadow: - # Whether to be strict about shadowing; can be noisy. - # Default: false - strict: true - - nakedret: - # Make an issue if func has more lines of code than this setting, and it has naked returns. - # Default: 30 - max-func-lines: 0 - - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] - # Enable to require an explanation of nonzero length after each nolint directive. - # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - - rowserrcheck: - # database/sql is always checked - # Default: [] - packages: - - github.com/jmoiron/sqlx - - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - - -linters: - disable-all: true - enable: - ## enabled by default - - errcheck # Errcheck is a program for checking for unchecked errors in go programs. 
These unchecked errors can be critical bugs in some cases - - gosimple # Linter for Go source code that specializes in simplifying a code - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # Detects when assignments to existing variables are not used - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - unused # Checks Go code for unused constants, variables, functions and types - ## disabled by default - # - asasalint # Check for pass []any as any in variadic func(...any) - - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - - bidichk # Checks for dangerous unicode character sequences - - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity - - dupl # Tool for code clone detection - - durationcheck # check for two durations multiplied together - - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - exhaustive # check exhaustiveness of enum switch statements - - forbidigo # Forbids identifiers - - funlen # Tool for detection of long functions - # - gochecknoglobals # check that no global variables exist - - gochecknoinits # Checks that no init functions are present in Go code - - gocognit # Computes and checks the cognitive complexity of functions - - goconst # Finds repeated strings that could be replaced by a constant - - gocritic # Provides diagnostics that check for bugs, performance and style issues. 
- - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - - goprintffuncname # Checks that printf-like functions are named with f at the end - - gosec # Inspects source code for security problems - - lll # Reports long lines - - makezero # Finds slice declarations with non-zero initial length - # - nakedret # Finds naked returns in functions greater than a specified function length - - mnd # An analyzer to detect magic numbers. - - nestif # Reports deeply nested if statements - - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. - # - noctx # noctx finds sending http request without context.Context - - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - - predeclared # find code that shadows one of Go's predeclared identifiers - - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - rowserrcheck # checks whether Err of rows is checked successfully - - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. 
- - stylecheck # Stylecheck is a replacement for golint - - testpackage # linter that makes you use a separate _test package - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - usetesting # Reports uses of functions with replacement inside the testing package - - wastedassign # wastedassign finds wasted assignment statements. - - whitespace # Tool for detection of leading and trailing whitespace - ## you may want to enable - #- decorder # check declaration order and count of types, constants, variables and functions - #- exhaustruct # Checks if all structure fields are initialized - #- goheader # Checks is file header matches to pattern - #- ireturn # Accept Interfaces, Return Concrete Types - #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated - #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope - #- wrapcheck # Checks that errors returned from external packages are wrapped - ## disabled - #- containedctx # containedctx is a linter that detects struct contained context.Context field - #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages - #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - #- forcetypeassert # [replaced by errcheck] finds forced type assertions - #- gci # Gci controls golang package import order and makes it always deterministic. 
- #- godox # Tool for detection of FIXME, TODO and other comment keywords - #- goerr113 # [too strict] Golang linter to check the errors handling expressions - #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. - #- grouper # An analyzer to analyze expression groups. - #- ifshort # Checks that your code uses short syntax for if-statements whenever possible - #- importas # Enforces consistent import aliases - #- maintidx # maintidx measures the maintainability index of each function. - #- misspell # [useless] Finds commonly misspelled English words in comments - #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity - #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 - #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test - #- tagliatelle # Checks the struct tags. - #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - - -issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. 
- # Default: 3 - max-same-issues: 50 - - exclude-rules: - - source: "^//\\s*go:generate\\s" - linters: [ lll ] - - source: "(noinspection|TODO)" - linters: [ godot ] - - source: "//noinspection" - linters: [ gocritic ] - - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" - linters: [ errorlint ] - - path: "_test\\.go" - linters: - - bodyclose - - dupl - - funlen - - goconst - - gosec - - noctx - - wrapcheck From 0116f2994de0960ab65598f60e55d98b43e5fc2d Mon Sep 17 00:00:00 2001 From: Pedro Chaparro <94259578+PChaparro@users.noreply.github.com> Date: Tue, 13 May 2025 06:51:08 -0500 Subject: [PATCH 102/129] feat: add support for image generation using `gpt-image-1` (#971) * feat: add gpt-image-1 support * feat: add example to generate image using gpt-image-1 model * style: missing period in comments * feat: add missing fields to example * docs: add GPT Image 1 to README * revert: keep `examples/images/main.go` unchanged * docs: remove unnecessary newline from example in README file --- README.md | 62 +++++++++++++++++++++++++++++++++- examples/images/main.go | 2 +- image.go | 75 +++++++++++++++++++++++++++++++++++------ 3 files changed, 126 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 57d1d35bf..77b85e519 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * ChatGPT 4o, o1 * GPT-3, GPT-4 -* DALL·E 2, DALL·E 3 +* DALL·E 2, DALL·E 3, GPT Image 1 * Whisper ## Installation @@ -357,6 +357,66 @@ func main() { ``` +
+GPT Image 1 image generation + +```go +package main + +import ( + "context" + "encoding/base64" + "fmt" + "os" + + openai "github.com/sashabaranov/go-openai" +) + +func main() { + c := openai.NewClient("your token") + ctx := context.Background() + + req := openai.ImageRequest{ + Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.", + Background: openai.CreateImageBackgroundOpaque, + Model: openai.CreateImageModelGptImage1, + Size: openai.CreateImageSize1024x1024, + N: 1, + Quality: openai.CreateImageQualityLow, + OutputCompression: 100, + OutputFormat: openai.CreateImageOutputFormatJPEG, + // Moderation: openai.CreateImageModerationLow, + // User: "", + } + + resp, err := c.CreateImage(ctx, req) + if err != nil { + fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err) + return + } + + fmt.Println("Image Base64:", resp.Data[0].B64JSON) + + // Decode the base64 data + imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON) + if err != nil { + fmt.Printf("Base64 decode error: %v\n", err) + return + } + + // Write image to file + outputPath := "generated_image.jpg" + err = os.WriteFile(outputPath, imgBytes, 0644) + if err != nil { + fmt.Printf("Failed to write image file: %v\n", err) + return + } + + fmt.Printf("The image was saved as %s\n", outputPath) +} +``` +
+
Configuring proxy diff --git a/examples/images/main.go b/examples/images/main.go index 5ee649d22..2bfeb7973 100644 --- a/examples/images/main.go +++ b/examples/images/main.go @@ -25,4 +25,4 @@ func main() { return } fmt.Println(respUrl.Data[0].URL) -} +} \ No newline at end of file diff --git a/image.go b/image.go index 577d7db95..d62622a35 100644 --- a/image.go +++ b/image.go @@ -13,51 +13,101 @@ const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" CreateImageSize1024x1024 = "1024x1024" + // dall-e-3 supported only. CreateImageSize1792x1024 = "1792x1024" CreateImageSize1024x1792 = "1024x1792" + + // gpt-image-1 supported only. + CreateImageSize1536x1024 = "1536x1024" // Landscape + CreateImageSize1024x1536 = "1024x1536" // Portrait ) const ( - CreateImageResponseFormatURL = "url" + // dall-e-2 and dall-e-3 only. CreateImageResponseFormatB64JSON = "b64_json" + CreateImageResponseFormatURL = "url" ) const ( - CreateImageModelDallE2 = "dall-e-2" - CreateImageModelDallE3 = "dall-e-3" + CreateImageModelDallE2 = "dall-e-2" + CreateImageModelDallE3 = "dall-e-3" + CreateImageModelGptImage1 = "gpt-image-1" ) const ( CreateImageQualityHD = "hd" CreateImageQualityStandard = "standard" + + // gpt-image-1 only. + CreateImageQualityHigh = "high" + CreateImageQualityMedium = "medium" + CreateImageQualityLow = "low" ) const ( + // dall-e-3 only. CreateImageStyleVivid = "vivid" CreateImageStyleNatural = "natural" ) +const ( + // gpt-image-1 only. + CreateImageBackgroundTransparent = "transparent" + CreateImageBackgroundOpaque = "opaque" +) + +const ( + // gpt-image-1 only. + CreateImageModerationLow = "low" +) + +const ( + // gpt-image-1 only. + CreateImageOutputFormatPNG = "png" + CreateImageOutputFormatJPEG = "jpeg" + CreateImageOutputFormatWEBP = "webp" +) + // ImageRequest represents the request structure for the image API. 
type ImageRequest struct { - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Quality string `json:"quality,omitempty"` - Size string `json:"size,omitempty"` - Style string `json:"style,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + Style string `json:"style,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` + OutputCompression int `json:"output_compression,omitempty"` + OutputFormat string `json:"output_format,omitempty"` } // ImageResponse represents a response structure for image API. type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + Usage ImageResponseUsage `json:"usage,omitempty"` httpHeader } +// ImageResponseInputTokensDetails represents the token breakdown for input tokens. +type ImageResponseInputTokensDetails struct { + TextTokens int `json:"text_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` +} + +// ImageResponseUsage represents the token usage information for image API. +type ImageResponseUsage struct { + TotalTokens int `json:"total_tokens,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"` +} + // ImageResponseDataInner represents a response data structure for image API. 
type ImageResponseDataInner struct { URL string `json:"url,omitempty"` @@ -91,6 +141,8 @@ type ImageEditRequest struct { N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -159,6 +211,7 @@ type ImageVariRequest struct { N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. From 6aaa7322960741a84da11ac360516e4ec813dfff Mon Sep 17 00:00:00 2001 From: Justa Date: Tue, 13 May 2025 19:52:44 +0800 Subject: [PATCH 103/129] add ChatTemplateKwargs to ChatCompletionRequest (#980) Co-authored-by: Justa --- chat.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chat.go b/chat.go index 0f91d481c..c8a3e81b3 100644 --- a/chat.go +++ b/chat.go @@ -275,6 +275,11 @@ type ChatCompletionRequest struct { Metadata map[string]string `json:"metadata,omitempty"` // Configuration for a predicted output. Prediction *Prediction `json:"prediction,omitempty"` + // ChatTemplateKwargs provides a way to add non-standard parameters to the request body. + // Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + // Such as think mode for qwen3. 
"chat_template_kwargs": {"enable_thinking": false} + // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes + ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` } type StreamOptions struct { From 4d2e7ab29d7bf853c740e9e63187fed960e6b0b4 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 13 May 2025 12:59:06 +0100 Subject: [PATCH 104/129] fix lint (#998) --- examples/images/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/images/main.go b/examples/images/main.go index 2bfeb7973..5ee649d22 100644 --- a/examples/images/main.go +++ b/examples/images/main.go @@ -25,4 +25,4 @@ func main() { return } fmt.Println(respUrl.Data[0].URL) -} \ No newline at end of file +} From 8c65b35c57ad4e9ba408def9bf9ff97817aab932 Mon Sep 17 00:00:00 2001 From: Axb12 <67110563+Axb12@users.noreply.github.com> Date: Tue, 20 May 2025 21:45:40 +0800 Subject: [PATCH 105/129] update image api *os.File to io.Reader (#994) * update image api *os.File to io.Reader * update code style * add reader test * supplementary reader test * update the reader in the form builder test * add commnet * update comment * update code style --- image.go | 43 ++++++++++++++++++----------------- image_test.go | 8 +++---- internal/form_builder.go | 35 ++++++++++++++++++++++++++-- internal/form_builder_test.go | 29 +++++++++++++++++++++++ 4 files changed, 88 insertions(+), 27 deletions(-) diff --git a/image.go b/image.go index d62622a35..72077ce41 100644 --- a/image.go +++ b/image.go @@ -3,8 +3,8 @@ package openai import ( "bytes" "context" + "io" "net/http" - "os" "strconv" ) @@ -134,15 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons // ImageEditRequest represents the request structure for the image API. 
type ImageEditRequest struct { - Image *os.File `json:"image,omitempty"` - Mask *os.File `json:"mask,omitempty"` - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Quality string `json:"quality,omitempty"` - User string `json:"user,omitempty"` + Image io.Reader `json:"image,omitempty"` + Mask io.Reader `json:"mask,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -150,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename is not required + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } // mask, it is optional if request.Mask != nil { - err = builder.CreateFormFile("mask", request.Mask) + // mask, filename is not required + err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return } @@ -206,12 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // ImageVariRequest represents the request structure for the image API. 
type ImageVariRequest struct { - Image *os.File `json:"image,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Image io.Reader `json:"image,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. @@ -220,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename is not required + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } diff --git a/image_test.go b/image_test.go index 9332dd5cd..644005515 100644 --- a/image_test.go +++ b/image_test.go @@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) { } mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } _, err := client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { + mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { if name == "mask" { return mockFailedErr } @@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) { req := ImageVariRequest{} mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + 
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } _, err := client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } diff --git a/internal/form_builder.go b/internal/form_builder.go index 2224fad45..1c2513dd9 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -4,8 +4,10 @@ import ( "fmt" "io" "mime/multipart" + "net/textproto" "os" - "path" + "path/filepath" + "strings" ) type FormBuilder interface { @@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er return fb.createFormFile(fieldname, file, file.Name()) } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +// CreateFormFileReader creates a form field with a file reader. +// The filename in parameters can be an empty string. +// The filename in Content-Disposition is required, But it can be an empty string. 
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { - return fb.createFormFile(fieldname, r, path.Base(filename)) + h := make(textproto.MIMEHeader) + h.Set( + "Content-Disposition", + fmt.Sprintf( + `form-data; name="%s"; filename="%s"`, + escapeQuotes(fieldname), + escapeQuotes(filepath.Base(filename)), + ), + ) + + fieldWriter, err := fb.writer.CreatePart(h) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, r) + if err != nil { + return err + } + + return nil } func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 8df989e3b..76922c1ba 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) { checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") } + +type failingReader struct { +} + +var errMockFailingReaderError = errors.New("mock reader failed") + +func (*failingReader) Read([]byte) (int, error) { + return 0, errMockFailingReaderError +} + +func TestFormBuilderWithReader(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatalf("Error creating tmp file: %v", err) + } + defer file.Close() + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFileReader("file", file, file.Name()) + checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") + + builder = NewFormBuilder(&bytes.Buffer{}) + reader := &failingReader{} + err = builder.CreateFormFileReader("file", reader, "") + checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails") + + successReader := &bytes.Buffer{} + err = builder.CreateFormFileReader("file", successReader, "") 
+ checks.NoError(t, err, "formbuilder should not return error") +} From ff9d83a4854790ecbd16e6328415a32c7497efaf Mon Sep 17 00:00:00 2001 From: "JT A." Date: Thu, 29 May 2025 04:31:35 -0600 Subject: [PATCH 106/129] skip json field (#1009) * skip json field * backfill some coverage and tests --- jsonschema/json.go | 7 ++++-- jsonschema/json_test.go | 47 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index d458418f3..03bb68891 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -126,9 +126,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { } jsonTag := field.Tag.Get("json") var required = true - if jsonTag == "" { + switch { + case jsonTag == "-": + continue + case jsonTag == "": jsonTag = field.Name - } else if strings.HasSuffix(jsonTag, ",omitempty") { + case strings.HasSuffix(jsonTag, ",omitempty"): jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") required = false } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 17f0aba8a..84f25fa85 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -329,6 +329,53 @@ func TestStructToSchema(t *testing.T) { "additionalProperties":false }`, }, + { + name: "Test with exclude mark", + in: struct { + Name string `json:"-"` + }{ + Name: "Name", + }, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with no json tag", + in: struct { + Name string + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "Name":{ + "type":"string" + } + }, + "required":["Name"], + "additionalProperties":false + }`, + }, + { + name: "Test with omitempty tag", + in: struct { + Name string `json:"name,omitempty"` + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, } for _, tt := range tests { From d7dca83beda99528c974392a8295630775c1c197 Mon Sep 
17 00:00:00 2001 From: Axb12 <67110563+Axb12@users.noreply.github.com> Date: Tue, 17 Jun 2025 03:08:14 +0800 Subject: [PATCH 107/129] fix image api missing filename bug (#1017) * fix image api missing filename bug * add test * add test * update test --- image.go | 32 +++++++++++++++++++++++++++++--- internal/form_builder.go | 17 +++++++++++++++-- internal/form_builder_test.go | 18 ++++++++++++++++++ 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/image.go b/image.go index 72077ce41..84b9daf02 100644 --- a/image.go +++ b/image.go @@ -132,7 +132,32 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons return } +// WrapReader wraps an io.Reader with filename and Content-type. +func WrapReader(rdr io.Reader, filename string, contentType string) io.Reader { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Name() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} + // ImageEditRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. 
type ImageEditRequest struct { Image io.Reader `json:"image,omitempty"` Mask io.Reader `json:"mask,omitempty"` @@ -150,7 +175,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required + // image, filename verification can be postponed err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return @@ -158,7 +183,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // mask, it is optional if request.Mask != nil { - // mask, filename is not required + // filename verification can be postponed err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return @@ -206,6 +231,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) } // ImageVariRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageVariRequest struct { Image io.Reader `json:"image,omitempty"` Model string `json:"model,omitempty"` @@ -221,7 +247,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required + // image, filename verification can be postponed err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return diff --git a/internal/form_builder.go b/internal/form_builder.go index 1c2513dd9..5b382df20 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -39,9 +39,18 @@ func escapeQuotes(s string) string { } // CreateFormFileReader creates a form field with a file reader. -// The filename in parameters can be an empty string. -// The filename in Content-Disposition is required, But it can be an empty string. +// The filename in Content-Disposition is required. 
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + if filename == "" { + if f, ok := r.(interface{ Name() string }); ok { + filename = f.Name() + } + } + var contentType string + if f, ok := r.(interface{ ContentType() string }); ok { + contentType = f.ContentType() + } + h := make(textproto.MIMEHeader) h.Set( "Content-Disposition", @@ -51,6 +60,10 @@ func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader escapeQuotes(filepath.Base(filename)), ), ) + // content type is optional, but it can be set + if contentType != "" { + h.Set("Content-Type", contentType) + } fieldWriter, err := fb.writer.CreatePart(h) if err != nil { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 76922c1ba..f4958ad5e 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,6 +1,8 @@ package openai //nolint:testpackage // testing private field import ( + "io" + "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -53,6 +55,18 @@ func (*failingReader) Read([]byte) (int, error) { return 0, errMockFailingReaderError } +type readerWithNameAndContentType struct { + io.Reader +} + +func (*readerWithNameAndContentType) Name() string { + return "" +} + +func (*readerWithNameAndContentType) ContentType() string { + return "image/png" +} + func TestFormBuilderWithReader(t *testing.T) { file, err := os.CreateTemp(t.TempDir(), "") if err != nil { @@ -71,4 +85,8 @@ func TestFormBuilderWithReader(t *testing.T) { successReader := &bytes.Buffer{} err = builder.CreateFormFileReader("file", successReader, "") checks.NoError(t, err, "formbuilder should not return error") + + rnc := &readerWithNameAndContentType{Reader: &bytes.Buffer{}} + err = builder.CreateFormFileReader("file", rnc, "") + checks.NoError(t, err, "formbuilder should not return error") } From a931bf7e85af9d39af414de21d907b6835281519 Mon Sep 17 00:00:00 2001 From: Whitea 
<145814986+Whitea029@users.noreply.github.com> Date: Tue, 17 Jun 2025 18:00:15 +0800 Subject: [PATCH 108/129] test: enhance error accumulator and form builder tests, add marshaller tests (#999) * test: enhance error accumulator and form builder tests, add marshaller tests * test: fix some issue form golangci-lint * test: gofmt form builder test * fix * fix * fix lint --- internal/error_accumulator_test.go | 43 +++++++----------- internal/form_builder.go | 3 ++ internal/form_builder_test.go | 73 +++++++++++++++++++++++++++++- internal/marshaller_test.go | 34 ++++++++++++++ internal/unmarshaler_test.go | 37 +++++++++++++++ 5 files changed, 163 insertions(+), 27 deletions(-) create mode 100644 internal/marshaller_test.go create mode 100644 internal/unmarshaler_test.go diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go index d48f28177..3fa9d7714 100644 --- a/internal/error_accumulator_test.go +++ b/internal/error_accumulator_test.go @@ -1,41 +1,32 @@ package openai_test import ( - "bytes" - "errors" "testing" - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" ) -func TestErrorAccumulatorBytes(t *testing.T) { - accumulator := &utils.DefaultErrorAccumulator{ - Buffer: &bytes.Buffer{}, +func TestDefaultErrorAccumulator_WriteMultiple(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") } + checks.NoError(t, ea.Write([]byte("{\"error\": \"test1\"}"))) + checks.NoError(t, ea.Write([]byte("{\"error\": \"test2\"}"))) - errBytes := accumulator.Bytes() - if len(errBytes) != 0 { - t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) - } - - err := accumulator.Write([]byte("{}")) - if err != nil { - t.Fatalf("%+v", err) - } - - errBytes = 
accumulator.Bytes() - if len(errBytes) == 0 { - t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) + expected := "{\"error\": \"test1\"}{\"error\": \"test2\"}" + if string(ea.Bytes()) != expected { + t.Fatalf("Expected %q, got %q", expected, ea.Bytes()) } } -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &utils.DefaultErrorAccumulator{ - Buffer: &test.FailingErrorBuffer{}, +func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") } - err := accumulator.Write([]byte("{")) - if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) + if len(ea.Bytes()) != 0 { + t.Fatal("Buffer should be empty initially") } } diff --git a/internal/form_builder.go b/internal/form_builder.go index 5b382df20..a17e820c0 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -97,6 +97,9 @@ func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, file } func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { + if fieldname == "" { + return fmt.Errorf("fieldname cannot be empty") + } return fb.writer.WriteField(fieldname, value) } diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index f4958ad5e..ddd6b8448 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,16 +1,57 @@ package openai //nolint:testpackage // testing private field import ( + "errors" "io" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" - "errors" "os" "testing" ) +type mockFormBuilder struct { + mockCreateFormFile func(string, *os.File) error + mockWriteField func(string, string) error + mockClose func() error +} + +func (m *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + return m.mockCreateFormFile(fieldname, file) +} 
+ +func (m *mockFormBuilder) WriteField(fieldname, value string) error { + return m.mockWriteField(fieldname, value) +} + +func (m *mockFormBuilder) Close() error { + return m.mockClose() +} + +func (m *mockFormBuilder) FormDataContentType() string { + return "" +} + +func TestCloseMethod(t *testing.T) { + t.Run("NormalClose", func(t *testing.T) { + body := &bytes.Buffer{} + builder := NewFormBuilder(body) + checks.NoError(t, builder.Close(), "正常关闭应成功") + }) + + t.Run("ErrorPropagation", func(t *testing.T) { + errorMock := errors.New("mock close error") + mockBuilder := &mockFormBuilder{ + mockClose: func() error { + return errorMock + }, + } + err := mockBuilder.Close() + checks.ErrorIs(t, err, errorMock, "应传递关闭错误") + }) +} + type failingWriter struct { } @@ -90,3 +131,33 @@ func TestFormBuilderWithReader(t *testing.T) { err = builder.CreateFormFileReader("file", rnc, "") checks.NoError(t, err, "formbuilder should not return error") } + +func TestFormDataContentType(t *testing.T) { + t.Run("ReturnsUnderlyingWriterContentType", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + contentType := builder.FormDataContentType() + if contentType == "" { + t.Errorf("expected non-empty content type, got empty string") + } + }) +} + +func TestWriteField(t *testing.T) { + t.Run("EmptyFieldNameShouldReturnError", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("", "some value") + checks.HasError(t, err, "fieldname is required") + }) + + t.Run("ValidFieldNameShouldSucceed", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("key", "value") + checks.NoError(t, err, "should write field without error") + }) +} diff --git a/internal/marshaller_test.go b/internal/marshaller_test.go new file mode 100644 index 000000000..70694faed --- /dev/null +++ b/internal/marshaller_test.go @@ -0,0 +1,34 @@ +package openai_test + +import ( + 
"testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONMarshaller_Normal(t *testing.T) { + jm := &openai.JSONMarshaller{} + data := map[string]string{"key": "value"} + + b, err := jm.Marshal(data) + checks.NoError(t, err) + if len(b) == 0 { + t.Fatal("should return non-empty bytes") + } +} + +func TestJSONMarshaller_InvalidInput(t *testing.T) { + jm := &openai.JSONMarshaller{} + _, err := jm.Marshal(make(chan int)) + checks.HasError(t, err, "should return error for unsupported type") +} + +func TestJSONMarshaller_EmptyValue(t *testing.T) { + jm := &openai.JSONMarshaller{} + b, err := jm.Marshal(nil) + checks.NoError(t, err) + if string(b) != "null" { + t.Fatalf("unexpected marshaled value: %s", string(b)) + } +} diff --git a/internal/unmarshaler_test.go b/internal/unmarshaler_test.go new file mode 100644 index 000000000..d63eac779 --- /dev/null +++ b/internal/unmarshaler_test.go @@ -0,0 +1,37 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONUnmarshaler_Normal(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{"key":"value"}`) + var v map[string]string + + err := jm.Unmarshal(data, &v) + checks.NoError(t, err) + if v["key"] != "value" { + t.Fatal("unmarshal result mismatch") + } +} + +func TestJSONUnmarshaler_InvalidJSON(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{invalid}`) + var v map[string]interface{} + + err := jm.Unmarshal(data, &v) + checks.HasError(t, err, "should return error for invalid JSON") +} + +func TestJSONUnmarshaler_EmptyInput(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + var v interface{} + + err := jm.Unmarshal(nil, &v) + checks.HasError(t, err, "should return error for nil input") +} From c125ae2ad7b239355161c7f4260a3743ec0c182b Mon Sep 17 00:00:00 2001 From: Hritik Raj Date: 
Wed, 25 Jun 2025 15:34:45 +0530 Subject: [PATCH 109/129] Fix for removing usage in every stream chunk response. (#1022) * Fix for https://github.com/sashabaranov/go-openai/issues/1021: 1. Make Usage field in completions Response to pointer. * Fix for https://github.com/sashabaranov/go-openai/issues/1021: 1. Make Usage field in completions Response to pointer. 2. Add omitempty to json tag Signed-off-by: Hritik003 --------- Signed-off-by: Hritik003 --- completion.go | 2 +- completion_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 21d4897c4..02ce7b016 100644 --- a/completion.go +++ b/completion.go @@ -242,7 +242,7 @@ type CompletionResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []CompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + Usage *Usage `json:"usage,omitempty"` httpHeader } diff --git a/completion_test.go b/completion_test.go index 27e2d150e..f0ead0d63 100644 --- a/completion_test.go +++ b/completion_test.go @@ -192,7 +192,7 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } inputTokens *= n completionTokens := completionReq.MaxTokens * len(prompts) * n - res.Usage = openai.Usage{ + res.Usage = &openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, From 8e9b2ac83ab2d15a34982d689d97ec5db34cbb06 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Tue, 8 Jul 2025 19:19:49 +0800 Subject: [PATCH 110/129] fix: properly unmarshal JSON schema in ChatCompletionResponseFormatJSONSchema.schema (#1028) * feat: #1027 * add tests * feat: #1027 * feat: #1027 * feat: #1027 * update chat_test.go * feat: #1027 * add test cases --- chat.go | 27 ++++++++++ chat_test.go | 139 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+) diff --git a/chat.go b/chat.go index c8a3e81b3..b4a0ad90f 100644 --- a/chat.go +++ b/chat.go @@ -5,6 +5,8 
@@ import ( "encoding/json" "errors" "net/http" + + "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -221,6 +223,31 @@ type ChatCompletionResponseFormatJSONSchema struct { Strict bool `json:"strict"` } +func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error { + type rawJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.RawMessage `json:"schema"` + Strict bool `json:"strict"` + } + var raw rawJSONSchema + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Name = raw.Name + r.Description = raw.Description + r.Strict = raw.Strict + if len(raw.Schema) > 0 && string(raw.Schema) != "null" { + var d jsonschema.Definition + err := json.Unmarshal(raw.Schema, &d) + if err != nil { + return err + } + r.Schema = &d + } + return nil +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { Model string `json:"model"` diff --git a/chat_test.go b/chat_test.go index 514706c96..172ce0740 100644 --- a/chat_test.go +++ b/chat_test.go @@ -946,3 +946,142 @@ func TestFinishReason(t *testing.T) { } } } + +func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + 
"strict": true, + "schema": null + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`[123,456]`), + }, + true, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": 123456 + }`), + }, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r openai.ChatCompletionResponseFormatJSONSchema + err := r.UnmarshalJSON(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) { + type args struct { + bs []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{bs: []byte(`{ + "model": "llama3-1b", + "messages": [ + { "role": "system", "content": "You are a helpful math tutor." }, + { "role": "user", "content": "solve 8x + 31 = 2" } + ], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + } + } +}`)}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var m openai.ChatCompletionRequest + err := json.Unmarshal(tt.args.bs, &m) + if err != nil { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} From bd612cebceb5d84c423bece9e2f2766c7567f8ed Mon Sep 17 00:00:00 2001 From: Matt Tinsley <68241446+mathewtinsley@users.noreply.github.com> Date: Tue, 8 Jul 2025 04:23:41 -0700 Subject: [PATCH 111/129] Add support for Chat Completion Service Tier (#1023) * Add support for Chat Completion 
Service Tier * Add priority service tier --- chat.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/chat.go b/chat.go index b4a0ad90f..0f0c5b5d5 100644 --- a/chat.go +++ b/chat.go @@ -307,6 +307,8 @@ type ChatCompletionRequest struct { // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + // Specifies the latency tier to use for processing the request. + ServiceTier ServiceTier `json:"service_tier,omitempty"` } type StreamOptions struct { @@ -390,6 +392,15 @@ const ( FinishReasonNull FinishReason = "null" ) +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierPriority ServiceTier = "priority" +) + func (r FinishReason) MarshalJSON() ([]byte, error) { if r == FinishReasonNull || r == "" { return []byte("null"), nil @@ -422,6 +433,7 @@ type ChatCompletionResponse struct { Usage Usage `json:"usage"` SystemFingerprint string `json:"system_fingerprint"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` httpHeader } From c650976e492731e3fa758a2dbff2f743ae0e1c98 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 10 Jul 2025 18:35:16 +0800 Subject: [PATCH 112/129] Support $ref and $defs in JSON Schema (#1030) * support $ref and $defs in JSON Schema * update --- jsonschema/json.go | 36 +++++-- jsonschema/json_test.go | 70 ++++++++++++++ jsonschema/validate.go | 65 +++++++++++-- jsonschema/validate_test.go | 187 +++++++++++++++++++++++++++++++++++- 4 files changed, 340 insertions(+), 18 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index 03bb68891..29d15b409 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -48,6 
+48,11 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` // Whether the schema is nullable or not. Nullable bool `json:"nullable,omitempty"` + + // Ref Reference to a definition in $defs or external schema. + Ref string `json:"$ref,omitempty"` + // Defs A map of reusable schema definitions. + Defs map[string]Definition `json:"$defs,omitempty"` } func (d *Definition) MarshalJSON() ([]byte, error) { @@ -67,10 +72,16 @@ func (d *Definition) Unmarshal(content string, v any) error { } func GenerateSchemaForType(v any) (*Definition, error) { - return reflectSchema(reflect.TypeOf(v)) + var defs = make(map[string]Definition) + def, err := reflectSchema(reflect.TypeOf(v), defs) + if err != nil { + return nil, err + } + def.Defs = defs + return def, nil } -func reflectSchema(t reflect.Type) (*Definition, error) { +func reflectSchema(t reflect.Type, defs map[string]Definition) (*Definition, error) { var d Definition switch t.Kind() { case reflect.String: @@ -84,21 +95,32 @@ func reflectSchema(t reflect.Type) (*Definition, error) { d.Type = Boolean case reflect.Slice, reflect.Array: d.Type = Array - items, err := reflectSchema(t.Elem()) + items, err := reflectSchema(t.Elem(), defs) if err != nil { return nil, err } d.Items = items case reflect.Struct: + if t.Name() != "" { + if _, ok := defs[t.Name()]; !ok { + defs[t.Name()] = Definition{} + object, err := reflectSchemaObject(t, defs) + if err != nil { + return nil, err + } + defs[t.Name()] = *object + } + return &Definition{Ref: "#/$defs/" + t.Name()}, nil + } d.Type = Object d.AdditionalProperties = false - object, err := reflectSchemaObject(t) + object, err := reflectSchemaObject(t, defs) if err != nil { return nil, err } d = *object case reflect.Ptr: - definition, err := reflectSchema(t.Elem()) + definition, err := reflectSchema(t.Elem(), defs) if err != nil { return nil, err } @@ -112,7 +134,7 @@ func reflectSchema(t reflect.Type) (*Definition, error) { return &d, nil } -func 
reflectSchemaObject(t reflect.Type) (*Definition, error) { +func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definition, error) { var d = Definition{ Type: Object, AdditionalProperties: false, @@ -136,7 +158,7 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { required = false } - item, err := reflectSchema(field.Type) + item, err := reflectSchema(field.Type, defs) if err != nil { return nil, err } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 84f25fa85..31b54ed1a 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -183,6 +183,17 @@ func TestDefinition_MarshalJSON(t *testing.T) { } func TestStructToSchema(t *testing.T) { + type Tweet struct { + Text string `json:"text"` + } + + type Person struct { + Name string `json:"name,omitempty"` + Age int `json:"age,omitempty"` + Friends []Person `json:"friends,omitempty"` + Tweets []Tweet `json:"tweets,omitempty"` + } + tests := []struct { name string in any @@ -376,6 +387,65 @@ func TestStructToSchema(t *testing.T) { "additionalProperties":false }`, }, + { + name: "Test with $ref and $defs", + in: struct { + Person Person `json:"person"` + Tweets []Tweet `json:"tweets"` + }{}, + want: `{ + "type" : "object", + "properties" : { + "person" : { + "$ref" : "#/$defs/Person" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "required" : [ "person", "tweets" ], + "additionalProperties" : false, + "$defs" : { + "Person" : { + "type" : "object", + "properties" : { + "age" : { + "type" : "integer" + }, + "friends" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Person" + } + }, + "name" : { + "type" : "string" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "additionalProperties" : false + }, + "Tweet" : { + "type" : "object", + "properties" : { + "text" : { + "type" : "string" + } + }, + "required" : [ "text" ], + "additionalProperties" : false + } + } 
+}`, + }, } for _, tt := range tests { diff --git a/jsonschema/validate.go b/jsonschema/validate.go index 49f9b8859..1bd2f809c 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -5,26 +5,68 @@ import ( "errors" ) +func CollectDefs(def Definition) map[string]Definition { + result := make(map[string]Definition) + collectDefsRecursive(def, result, "#") + return result +} + +func collectDefsRecursive(def Definition, result map[string]Definition, prefix string) { + for k, v := range def.Defs { + path := prefix + "/$defs/" + k + result[path] = v + collectDefsRecursive(v, result, path) + } + for k, sub := range def.Properties { + collectDefsRecursive(sub, result, prefix+"/properties/"+k) + } + if def.Items != nil { + collectDefsRecursive(*def.Items, result, prefix) + } +} + func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { var data any err := json.Unmarshal(content, &data) if err != nil { return err } - if !Validate(schema, data) { + if !Validate(schema, data, WithDefs(CollectDefs(schema))) { return errors.New("data validation failed against the provided schema") } return json.Unmarshal(content, &v) } -func Validate(schema Definition, data any) bool { +type validateArgs struct { + Defs map[string]Definition +} + +type ValidateOption func(*validateArgs) + +func WithDefs(defs map[string]Definition) ValidateOption { + return func(option *validateArgs) { + option.Defs = defs + } +} + +func Validate(schema Definition, data any, opts ...ValidateOption) bool { + args := validateArgs{} + for _, opt := range opts { + opt(&args) + } + if len(opts) == 0 { + args.Defs = CollectDefs(schema) + } switch schema.Type { case Object: - return validateObject(schema, data) + return validateObject(schema, data, args.Defs) case Array: - return validateArray(schema, data) + return validateArray(schema, data, args.Defs) case String: - _, ok := data.(string) + v, ok := data.(string) + if ok && len(schema.Enum) > 0 { + return contains(schema.Enum, v) + 
} return ok case Number: // float64 and int _, ok := data.(float64) @@ -45,11 +87,16 @@ func Validate(schema Definition, data any) bool { case Null: return data == nil default: + if schema.Ref != "" && args.Defs != nil { + if v, ok := args.Defs[schema.Ref]; ok { + return Validate(v, data, WithDefs(args.Defs)) + } + } return false } } -func validateObject(schema Definition, data any) bool { +func validateObject(schema Definition, data any, defs map[string]Definition) bool { dataMap, ok := data.(map[string]any) if !ok { return false @@ -61,7 +108,7 @@ func validateObject(schema Definition, data any) bool { } for key, valueSchema := range schema.Properties { value, exists := dataMap[key] - if exists && !Validate(valueSchema, value) { + if exists && !Validate(valueSchema, value, WithDefs(defs)) { return false } else if !exists && contains(schema.Required, key) { return false @@ -70,13 +117,13 @@ func validateObject(schema Definition, data any) bool { return true } -func validateArray(schema Definition, data any) bool { +func validateArray(schema Definition, data any, defs map[string]Definition) bool { dataArray, ok := data.([]any) if !ok { return false } for _, item := range dataArray { - if !Validate(*schema.Items, item) { + if !Validate(*schema.Items, item, WithDefs(defs)) { return false } } diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index 6fa30ab0c..aefdf4069 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -1,6 +1,7 @@ package jsonschema_test import ( + "reflect" "testing" "github.com/sashabaranov/go-openai/jsonschema" @@ -70,6 +71,96 @@ func Test_Validate(t *testing.T) { }, Required: []string{"string"}, }}, false}, + { + "test schema with ref and defs", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "male", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: 
map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, true}, + { + "test enum invalid value", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "other", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: 
map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -156,8 +247,100 @@ func TestUnmarshal(t *testing.T) { err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) - } else if err == nil { - t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +} + +func TestCollectDefs(t *testing.T) { + type args struct { + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want map[string]jsonschema.Definition + }{ + { + "test collect defs", + args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + }, + map[string]jsonschema.Definition{ + 
"#/$defs/Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "#/$defs/Person/$defs/Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + "#/$defs/Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := jsonschema.CollectDefs(tt.args.schema) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CollectDefs() = %v, want %v", got, tt.want) } }) } From 1bf77f6fd6f10b9cd80e1dfdedd616f3f1712e20 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 10 Jul 2025 21:49:20 +0100 Subject: [PATCH 113/129] Improve unit test coverage (#1032) * Add unit tests to improve coverage * Fix type assertion checks in tests --- image_test.go | 50 ++++++++++++++++++++++++++++++ internal/error_accumulator_test.go | 7 +++++ internal/form_builder_test.go | 14 +++++++++ internal/request_builder_test.go | 27 ++++++++++++++++ 4 files changed, 98 insertions(+) diff --git a/image_test.go b/image_test.go index 644005515..bb9a086fd 100644 --- a/image_test.go +++ b/image_test.go @@ -4,6 +4,7 @@ import ( utils "github.com/sashabaranov/go-openai/internal" 
"github.com/sashabaranov/go-openai/internal/test/checks" + "bytes" "context" "fmt" "io" @@ -156,3 +157,52 @@ func TestVariImageFormBuilderFailures(t *testing.T) { _, err = client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") } + +type testNamedReader struct{ io.Reader } + +func (testNamedReader) Name() string { return "named.txt" } + +func TestWrapReader(t *testing.T) { + r := bytes.NewBufferString("data") + wrapped := WrapReader(r, "file.png", "image/png") + f, ok := wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped reader missing Name or ContentType") + } + if f.Name() != "file.png" { + t.Fatalf("expected name file.png, got %s", f.Name()) + } + if f.ContentType() != "image/png" { + t.Fatalf("expected content type image/png, got %s", f.ContentType()) + } + + // test name from underlying reader + nr := testNamedReader{Reader: bytes.NewBufferString("d")} + wrapped = WrapReader(nr, "", "text/plain") + f, ok = wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped named reader missing Name or ContentType") + } + if f.Name() != "named.txt" { + t.Fatalf("expected name named.txt, got %s", f.Name()) + } + if f.ContentType() != "text/plain" { + t.Fatalf("expected content type text/plain, got %s", f.ContentType()) + } + + // no name provided + wrapped = WrapReader(bytes.NewBuffer(nil), "", "") + f2, ok := wrapped.(interface{ Name() string }) + if !ok { + t.Fatal("wrapped anonymous reader missing Name") + } + if f2.Name() != "" { + t.Fatalf("expected empty name, got %s", f2.Name()) + } +} diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go index 3fa9d7714..f6c226c5e 100644 --- a/internal/error_accumulator_test.go +++ b/internal/error_accumulator_test.go @@ -4,6 +4,7 @@ import ( "testing" openai "github.com/sashabaranov/go-openai/internal" + 
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -30,3 +31,9 @@ func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) { t.Fatal("Buffer should be empty initially") } } + +func TestDefaultErrorAccumulator_WriteError(t *testing.T) { + ea := &openai.DefaultErrorAccumulator{Buffer: &test.FailingErrorBuffer{}} + err := ea.Write([]byte("fail")) + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Write should propagate buffer errors") +} diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index ddd6b8448..1cc82ab8a 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -161,3 +161,17 @@ func TestWriteField(t *testing.T) { checks.NoError(t, err, "should write field without error") }) } + +func TestCreateFormFile(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "") + if err == nil { + t.Fatal("expected error for empty filename") + } + + builder = NewFormBuilder(&failingWriter{}) + err = builder.createFormFile("file", bytes.NewBufferString("data"), "name") + checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error") +} diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e26022a6b..1561b87fb 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "io" "net/http" "reflect" "testing" @@ -59,3 +60,29 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { t.Errorf("Build() got = %v, want %v", got, want) } } + +func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) { + b := NewRequestBuilder() + ctx := context.Background() + method := http.MethodPost + url := "/reader" + bodyContent := "hello" + body := bytes.NewBufferString(bodyContent) + header := http.Header{"X-Test": 
[]string{"val"}} + + req, err := b.Build(ctx, method, url, body, header) + if err != nil { + t.Fatalf("Build returned error: %v", err) + } + + gotBody, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("cannot read body: %v", err) + } + if string(gotBody) != bodyContent { + t.Fatalf("expected body %q, got %q", bodyContent, string(gotBody)) + } + if req.Header.Get("X-Test") != "val" { + t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test")) + } +} From a0c185f3628cd616fb9d1648700fc51c40c42dda Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 11 Jul 2025 19:17:33 +0800 Subject: [PATCH 114/129] Removed root $ref from GenerateSchemaForType (#1033) * support $ref and $defs in JSON Schema * update * removed root $ref from JSON Schema * Update json.go * Update json_test.go * Update jsonschema/json.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update jsonschema/json.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- jsonschema/json.go | 44 ++++++++++ jsonschema/json_test.go | 174 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 218 insertions(+) diff --git a/jsonschema/json.go b/jsonschema/json.go index 29d15b409..75e3b5173 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -77,6 +77,27 @@ func GenerateSchemaForType(v any) (*Definition, error) { if err != nil { return nil, err } + // If the schema has a root $ref, resolve it by: + // 1. Extracting the key from the $ref. + // 2. Detaching the referenced definition from $defs. + // 3. Checking for self-references in the detached definition. + // - If a self-reference is found, restore the original $defs structure. + // 4. Flattening the referenced definition into the root schema. + // 5. Clearing the $ref field in the root schema. 
+ if def.Ref != "" { + origRef := def.Ref + key := strings.TrimPrefix(origRef, "#/$defs/") + if root, ok := defs[key]; ok { + delete(defs, key) + root.Defs = defs + if containsRef(root, origRef) { + root.Defs = nil + defs[key] = root + } + *def = root + } + def.Ref = "" + } def.Defs = defs return def, nil } @@ -189,3 +210,26 @@ func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definitio d.Properties = properties return &d, nil } + +func containsRef(def Definition, targetRef string) bool { + if def.Ref == targetRef { + return true + } + + for _, d := range def.Defs { + if containsRef(d, targetRef) { + return true + } + } + + for _, prop := range def.Properties { + if containsRef(prop, targetRef) { + return true + } + } + + if def.Items != nil && containsRef(*def.Items, targetRef) { + return true + } + return false +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 31b54ed1a..34f5d88eb 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -182,6 +182,18 @@ func TestDefinition_MarshalJSON(t *testing.T) { } } +type User struct { + ID int `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Orders []*Order `json:"orders,omitempty"` +} + +type Order struct { + ID int `json:"id,omitempty"` + Amount float64 `json:"amount,omitempty"` + Buyer *User `json:"buyer,omitempty"` +} + func TestStructToSchema(t *testing.T) { type Tweet struct { Text string `json:"text"` @@ -194,6 +206,13 @@ func TestStructToSchema(t *testing.T) { Tweets []Tweet `json:"tweets,omitempty"` } + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + tests := []struct { name string in any @@ -444,6 +463,161 @@ func TestStructToSchema(t 
*testing.T) { "additionalProperties" : false } } +}`, + }, + { + name: "Test Person", + in: Person{}, + want: `{ + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false, + "$defs": { + "Person": { + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false + }, + "Tweet": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + }, + "required": [ + "text" + ], + "additionalProperties": false + } + } +}`, + }, + { + name: "Test MyStructuredResponse", + in: MyStructuredResponse{}, + want: `{ + "type": "object", + "properties": { + "camel_case": { + "type": "string", + "description": "CamelCase" + }, + "kebab_case": { + "type": "string", + "description": "KebabCase" + }, + "pascal_case": { + "type": "string", + "description": "PascalCase" + }, + "snake_case": { + "type": "string", + "description": "SnakeCase" + } + }, + "required": [ + "pascal_case", + "camel_case", + "kebab_case", + "snake_case" + ], + "additionalProperties": false +}`, + }, + { + name: "Test User", + in: User{}, + want: `{ + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false, + "$defs": { + "Order": { + "type": "object", + "properties": { + "amount": { + "type": "number" + }, + "buyer": { + "$ref": "#/$defs/User" + }, + "id": { + "type": "integer" + } + }, + "additionalProperties": false + }, + "User": { + "type": 
"object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false + } + } }`, }, } From 575aff439644c3b6460f2617b81b119044bfc35c Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 13:15:15 +0100 Subject: [PATCH 115/129] Add tests for form and request builders (#1036) --- internal/form_builder_test.go | 13 +++++++++++++ internal/request_builder_test.go | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 1cc82ab8a..53ef11d23 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -8,6 +8,7 @@ import ( "bytes" "os" + "strings" "testing" ) @@ -175,3 +176,15 @@ func TestCreateFormFile(t *testing.T) { err = builder.createFormFile("file", bytes.NewBufferString("data"), "name") checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error") } + +func TestCreateFormFileSuccess(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "foo.txt") + checks.NoError(t, err, "createFormFile should succeed") + + if !strings.Contains(buf.String(), "filename=\"foo.txt\"") { + t.Fatalf("expected filename header, got %q", buf.String()) + } +} diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index 1561b87fb..adccb158e 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -86,3 +86,11 @@ func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) { t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test")) } } + +func TestRequestBuilderInvalidURL(t *testing.T) { + b := NewRequestBuilder() + _, err := b.Build(context.Background(), http.MethodGet, ":", nil, nil) + if err == nil { + 
t.Fatal("expected error for invalid URL") + } +} From 88eb1df90bb6af6a1f4d471a08d9f8a13dd5a8e5 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 15:58:52 +0100 Subject: [PATCH 116/129] Ignore codecov coverage for examples and internal/test (#1038) --- .codecov.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..81773666c --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +coverage: + ignore: + - "examples/**" + - "internal/test/**" From 8d681e7f9a8f172168199f29e8e1f16701d6817a Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:18:23 +0100 Subject: [PATCH 117/129] Increase image.go test coverage to 100% (#1039) --- image_test.go | 303 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 209 insertions(+), 94 deletions(-) diff --git a/image_test.go b/image_test.go index bb9a086fd..c2c8f42dc 100644 --- a/image_test.go +++ b/image_test.go @@ -40,122 +40,237 @@ func (fb *mockFormBuilder) FormDataContentType() string { } func TestImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) utils.FormBuilder { - return mockBuilder - } ctx := context.Background() - - req := ImageEditRequest{ - Mask: &os.File{}, - } - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return mockFailedErr - } - _, err := client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { - if name == "mask" { - return mockFailedErr - } - return nil + newClient := 
func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return nil + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageEditRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "mask", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { + if name == "mask" { + return mockFailedErr + } + return nil + } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "prompt", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "prompt" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: 
ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, } - var failForField string - mockBuilder.mockWriteField = func(fieldname, _ string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateEditImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") + }) } - failForField = "prompt" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "n" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage 
should return error if form builder fails") - - failForField = "size" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "response_format" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr - } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + _, err := client.CreateEditImage(ctx, ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateEditImage should return error if request builder fails") + }) } func TestVariImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) utils.FormBuilder { - return mockBuilder - } ctx := context.Background() - - req := ImageVariRequest{} - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return mockFailedErr - } - _, err := client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return nil + newClient := func(fb *mockFormBuilder) *Client { + cfg := 
DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - var failForField string - mockBuilder.mockWriteField = func(fieldname, _ string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageVariRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, 
string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, } - failForField = "n" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - - failForField = "size" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateVariImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + }) + } - failForField = "response_format" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr - } - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + _, err := client.CreateVariImage(ctx, ImageVariRequest{Image: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateVariImage should return error if request builder fails") + }) } type testNamedReader struct{ io.Reader } From 8665ad7264b64cd2046c2a2d4addf8617c6fffea Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 
2025 16:31:09 +0100 Subject: [PATCH 118/129] Increase jsonschema test coverage (#1040) --- jsonschema/json_additional_test.go | 73 ++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 jsonschema/json_additional_test.go diff --git a/jsonschema/json_additional_test.go b/jsonschema/json_additional_test.go new file mode 100644 index 000000000..70cf37490 --- /dev/null +++ b/jsonschema/json_additional_test.go @@ -0,0 +1,73 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// Test Definition.Unmarshal, including success path, validation error, +// JSON syntax error and type mismatch during unmarshalling. +func TestDefinitionUnmarshal(t *testing.T) { + schema := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + }, + } + + var dst struct { + Name string `json:"name"` + } + if err := schema.Unmarshal(`{"name":"foo"}`, &dst); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dst.Name != "foo" { + t.Errorf("expected name to be foo, got %q", dst.Name) + } + + if err := schema.Unmarshal(`{`, &dst); err == nil { + t.Error("expected error for malformed json") + } + + if err := schema.Unmarshal(`{"name":1}`, &dst); err == nil { + t.Error("expected validation error") + } + + numSchema := jsonschema.Definition{Type: jsonschema.Number} + var s string + if err := numSchema.Unmarshal(`123`, &s); err == nil { + t.Error("expected unmarshal type error") + } +} + +// Ensure GenerateSchemaForType returns an error when encountering unsupported types. +func TestGenerateSchemaForTypeUnsupported(t *testing.T) { + type Bad struct { + Ch chan int `json:"ch"` + } + if _, err := jsonschema.GenerateSchemaForType(Bad{}); err == nil { + t.Fatal("expected error for unsupported type") + } +} + +// Validate should fail when provided data does not match the expected container types. 
+func TestValidateInvalidContainers(t *testing.T) { + objSchema := jsonschema.Definition{Type: jsonschema.Object} + if jsonschema.Validate(objSchema, 1) { + t.Error("expected object validation to fail for non-map input") + } + + arrSchema := jsonschema.Definition{Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}} + if jsonschema.Validate(arrSchema, 1) { + t.Error("expected array validation to fail for non-slice input") + } +} + +// Validate should return false when $ref cannot be resolved. +func TestValidateRefNotFound(t *testing.T) { + refSchema := jsonschema.Definition{Ref: "#/$defs/Missing"} + if jsonschema.Validate(refSchema, "data", jsonschema.WithDefs(map[string]jsonschema.Definition{})) { + t.Error("expected validation to fail when reference is missing") + } +} From 1e912177886f37ae0c4159982a962d9ef8ff921b Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:48:49 +0100 Subject: [PATCH 119/129] test: cover CreateFile request builder failure (#1041) --- files_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/files_test.go b/files_test.go index 3c1b99fb4..486ef892e 100644 --- a/files_test.go +++ b/files_test.go @@ -121,3 +121,20 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } +func TestCreateFileRequestBuilderFailure(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return &mockFormBuilder{ + mockWriteField: func(string, string) error { return nil }, + mockCreateFormFile: func(string, *os.File) error { return nil }, + mockClose: func() error { return nil }, + } + } + + _, err := client.CreateFile(context.Background(), 
FileRequest{FilePath: "client.go"}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateFile should return error if request builder fails") +} From 3bb1014fa7e8d09fa6ff31bc6bce6172a3bf59ae Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:11:07 +0100 Subject: [PATCH 120/129] ci: enable version compatibility vet (#1042) --- .github/workflows/pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 18c720f03..2c9730656 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -16,7 +16,7 @@ jobs: go-version: '1.24' - name: Run vet run: | - go vet . + go vet -stdversion ./... - name: Run golangci-lint uses: golangci/golangci-lint-action@v7 with: From bd36c45dc505d3592ddfc0d53c06561fe8a3dacb Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Fri, 11 Jul 2025 21:45:53 +0530 Subject: [PATCH 121/129] Support for extra_body parameter for embeddings API (#906) * support for extra_body parameter for embeddings API * done linting * added unit tests * improved code coverage and removed unnecessary checks * test cleanup * updated body map creation code * code coverage * minor change * updated testcase comment --- client.go | 14 ++++++++++++++ embeddings.go | 32 +++++++++++++++++++++++++++++++- embeddings_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index cef375348..413b8db03 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,20 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + // Assert that args.body is a map[string]any. + bodyMap, ok := args.body.(map[string]any) + if ok { + // If it's a map[string]any then only add extraBody + // fields to args.body otherwise keep only fields in request struct. 
+ for key, value := range extraBody { + bodyMap[key] = value + } + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) diff --git a/embeddings.go b/embeddings.go index 4a0e682da..8593f8b5b 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/binary" + "encoding/json" "errors" "math" "net/http" @@ -160,6 +161,9 @@ type EmbeddingRequest struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. 
Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() + + // The body map is used to dynamically construct the request payload for the embedding API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + extraBody := baseReq.ExtraBody + baseReq.ExtraBody = nil + + // Serialize baseReq to JSON + jsonData, err := json.Marshal(baseReq) + if err != nil { + return + } + + // Deserialize JSON to map[string]any + var body map[string]any + _ = json.Unmarshal(jsonData, &body) + req, err := c.newRequest( ctx, http.MethodPost, c.fullURL("/embeddings", withModel(string(baseReq.Model))), - withBody(baseReq), + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. 
)
	if err != nil {
		return
diff --git a/embeddings_test.go b/embeddings_test.go
index 438978169..07f1262cb 100644
--- a/embeddings_test.go
+++ b/embeddings_test.go
@@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) {
 		t.Fatalf("Expected embedding request to contain model field")
 	}
 
+	// test embedding request with strings and extra_body param
+	embeddingReqWithExtraBody := openai.EmbeddingRequest{
+		Input: []string{
+			"The food was delicious and the waiter",
+			"Other examples of embedding request",
+		},
+		Model: model,
+		ExtraBody: map[string]any{
+			"input_type": "query",
+			"truncate":   "NONE",
+		},
+	}
+	marshaled, err = json.Marshal(embeddingReqWithExtraBody)
+	checks.NoError(t, err, "Could not marshal embedding request")
+	if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
+		t.Fatalf("Expected embedding request to contain model field")
+	}
+
 	// test embedding request with strings
 	embeddingReqStrings := openai.EmbeddingRequestStrings{
 		Input: []string{
@@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) {
 		t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
 	}
 
-	// test create embeddings with strings (simple embedding request)
+	// test create embeddings with strings (ExtraBody in request)
+	res, err = client.CreateEmbeddings(
+		context.Background(),
+		openai.EmbeddingRequest{
+			ExtraBody: map[string]any{
+				"input_type": "query",
+				"truncate":   "NONE",
+			},
+			Dimensions: 1,
+		},
+	)
+	checks.NoError(t, err, "CreateEmbeddings error")
+	if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
+		t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
+	}
+
+	// test create embeddings with strings (ExtraBody in request and unserializable input)
+	_, err = client.CreateEmbeddings(
+		context.Background(),
+		openai.EmbeddingRequest{
+			Input: make(chan int), // Channels are not serializable
+			Model: "example_model",
+		},
+	)
+	checks.HasError(t, err, "CreateEmbeddings error")
+
+	// test failed (Serialize JSON error)
	res, err =
client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ From e6c1d3e5bde0bae5966070ab4edee874b6c8c73f Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:39:39 +0100 Subject: [PATCH 122/129] Increase jsonschema test coverage (#1043) * test: expand jsonschema coverage * test: fix package name for containsref tests --- jsonschema/containsref_test.go | 48 ++++++++++++++++++++++++++++++++++ jsonschema/json_errors_test.go | 27 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 jsonschema/containsref_test.go create mode 100644 jsonschema/json_errors_test.go diff --git a/jsonschema/containsref_test.go b/jsonschema/containsref_test.go new file mode 100644 index 000000000..dc1842775 --- /dev/null +++ b/jsonschema/containsref_test.go @@ -0,0 +1,48 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// SelfRef struct used to produce a self-referential schema. +type SelfRef struct { + Friends []SelfRef `json:"friends"` +} + +// Address struct referenced by Person without self-reference. +type Address struct { + Street string `json:"street"` +} + +type Person struct { + Address Address `json:"address"` +} + +// TestGenerateSchemaForType_SelfRef ensures that self-referential types are not +// flattened during schema generation. +func TestGenerateSchemaForType_SelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(SelfRef{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["SelfRef"]; !ok { + t.Fatal("expected defs to contain SelfRef for self reference") + } +} + +// TestGenerateSchemaForType_NoSelfRef ensures that non-self-referential types +// are flattened and do not reappear in $defs. 
+func TestGenerateSchemaForType_NoSelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(Person{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["Person"]; ok { + t.Fatal("unexpected Person definition in defs") + } + if _, ok := schema.Defs["Address"]; !ok { + t.Fatal("expected Address definition in defs") + } +} diff --git a/jsonschema/json_errors_test.go b/jsonschema/json_errors_test.go new file mode 100644 index 000000000..3b864fc21 --- /dev/null +++ b/jsonschema/json_errors_test.go @@ -0,0 +1,27 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// TestGenerateSchemaForType_ErrorPaths verifies error handling for unsupported types. +func TestGenerateSchemaForType_ErrorPaths(t *testing.T) { + type anon struct{ Ch chan int } + tests := []struct { + name string + v any + }{ + {"slice", []chan int{}}, + {"anon struct", anon{}}, + {"pointer", (*chan int)(nil)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := jsonschema.GenerateSchemaForType(tt.v); err == nil { + t.Errorf("expected error for %s", tt.name) + } + }) + } +} From 181c0e8fd7358b1f37f902262b246d513acd1e29 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 18:29:23 +0100 Subject: [PATCH 123/129] Add tests for internal utilities (#1044) * Add unit tests for internal test utilities * Fix lint issues in internal tests --- internal/test/checks/checks_test.go | 19 +++++++ internal/test/failer_test.go | 24 +++++++++ internal/test/helpers_test.go | 54 +++++++++++++++++++ internal/test/server.go | 12 +++++ internal/test/server_test.go | 80 +++++++++++++++++++++++++++++ 5 files changed, 189 insertions(+) create mode 100644 internal/test/checks/checks_test.go create mode 100644 internal/test/failer_test.go create mode 100644 internal/test/helpers_test.go create mode 100644 internal/test/server_test.go 
diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go new file mode 100644 index 000000000..0677054df --- /dev/null +++ b/internal/test/checks/checks_test.go @@ -0,0 +1,19 @@ +package checks_test + +import ( + "errors" + "testing" + + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestChecksSuccessPaths(t *testing.T) { + checks.NoError(t, nil) + checks.NoErrorF(t, nil) + checks.HasError(t, errors.New("err")) + target := errors.New("x") + checks.ErrorIs(t, target, target) + checks.ErrorIsF(t, target, target, "msg") + checks.ErrorIsNot(t, errors.New("y"), target) + checks.ErrorIsNotf(t, errors.New("y"), target, "msg") +} diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go new file mode 100644 index 000000000..fb1f4bf06 --- /dev/null +++ b/internal/test/failer_test.go @@ -0,0 +1,24 @@ +//nolint:testpackage // need access to unexported fields and types for testing +package test + +import ( + "errors" + "testing" +) + +func TestFailingErrorBuffer(t *testing.T) { + buf := &FailingErrorBuffer{} + n, err := buf.Write([]byte("test")) + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err) + } + if n != 0 { + t.Fatalf("expected n=0, got %d", n) + } + if buf.Len() != 0 { + t.Fatalf("expected Len 0, got %d", buf.Len()) + } + if len(buf.Bytes()) != 0 { + t.Fatalf("expected empty bytes") + } +} diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go new file mode 100644 index 000000000..aa177679b --- /dev/null +++ b/internal/test/helpers_test.go @@ -0,0 +1,54 @@ +package test_test + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestCreateTestFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + internaltest.CreateTestFile(t, path) + data, err := 
os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read created file: %v", err) + } + if string(data) != "hello" { + t.Fatalf("unexpected file contents: %q", string(data)) + } +} + +func TestTokenRoundTripperAddsHeader(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+internaltest.GetTestToken() { + t.Fatalf("authorization header not set") + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := srv.Client() + client.Transport = &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: client.Transport} + + req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatalf("request error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client request error: %v", err) + } + if _, err = io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } +} diff --git a/internal/test/server.go b/internal/test/server.go index 127d4c16f..d32c3e4cb 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -23,6 +23,18 @@ func NewTestServer() *ServerTest { return &ServerTest{handlers: make(map[string]handler)} } +// HandlerCount returns the number of registered handlers. +func (ts *ServerTest) HandlerCount() int { + return len(ts.handlers) +} + +// HasHandler checks if a handler was registered for the given path. 
+func (ts *ServerTest) HasHandler(path string) bool { + path = strings.ReplaceAll(path, "*", ".*") + _, ok := ts.handlers[path] + return ok +} + func (ts *ServerTest) RegisterHandler(path string, handler handler) { // to make the registered paths friendlier to a regex match in the route handler // in OpenAITestServer diff --git a/internal/test/server_test.go b/internal/test/server_test.go new file mode 100644 index 000000000..f8ce731d1 --- /dev/null +++ b/internal/test/server_test.go @@ -0,0 +1,80 @@ +package test_test + +import ( + "io" + "net/http" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestGetTestToken(t *testing.T) { + if internaltest.GetTestToken() != "this-is-my-secure-token-do-not-steal!!" { + t.Fatalf("unexpected token") + } +} + +func TestNewTestServer(t *testing.T) { + ts := internaltest.NewTestServer() + if ts == nil { + t.Fatalf("server not properly initialized") + } + if ts.HandlerCount() != 0 { + t.Fatalf("expected no handlers initially") + } +} + +func TestRegisterHandlerTransformsPath(t *testing.T) { + ts := internaltest.NewTestServer() + h := func(_ http.ResponseWriter, _ *http.Request) {} + ts.RegisterHandler("/foo/*", h) + if !ts.HasHandler("/foo/*") { + t.Fatalf("handler not registered with transformed path") + } +} + +func TestOpenAITestServer(t *testing.T) { + ts := internaltest.NewTestServer() + ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "ok"); err != nil { + t.Fatalf("write: %v", err) + } + }) + srv := ts.OpenAITestServer() + srv.Start() + defer srv.Close() + + base := srv.Client().Transport + client := &http.Client{Transport: &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: base}} + resp, err := client.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatalf("read response 
body: %v", err) + } + if resp.StatusCode != http.StatusOK || string(body) != "ok" { + t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body)) + } + + // unregistered path + resp, err = client.Get(srv.URL + "/unknown") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } + + // missing token should return unauthorized + clientNoToken := srv.Client() + resp, err = clientNoToken.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } +} From 4f87294cebdd14457cc2e1013cdf5acf5db3d27d Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Sun, 20 Jul 2025 01:29:02 +0530 Subject: [PATCH 124/129] Add GuidedChoice to ChatCompletionRequest (#1034) * Add GuidedChoice to ChatCompletionRequest * made separate NonOpenAIExtensions * fixed lint issue * renamed struct and removed inline json tag * Update chat.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update chat.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- chat.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/chat.go b/chat.go index 0f0c5b5d5..e14acd9d9 100644 --- a/chat.go +++ b/chat.go @@ -248,6 +248,16 @@ func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) erro return nil } +// ChatCompletionRequestExtensions contains third-party OpenAI API extensions (e.g., vendor-specific implementations like vLLM). +type ChatCompletionRequestExtensions struct { + // GuidedChoice is a vLLM-specific extension that restricts the model's output + // to one of the predefined string choices provided in this field. 
This feature + // is used to constrain the model's responses to a controlled set of options, + // ensuring predictable and consistent outputs in scenarios where specific + // choices are required. + GuidedChoice []string `json:"guided_choice,omitempty"` +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { Model string `json:"model"` @@ -309,6 +319,8 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // Embedded struct for non-OpenAI extensions + ChatCompletionRequestExtensions } type StreamOptions struct { From c4273cb5f46031ee478601f4edd82d2fad401e77 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 19 Jul 2025 21:07:34 +0100 Subject: [PATCH 125/129] fix(chat): shorten comment to pass linter (#1050) --- chat.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index e14acd9d9..0bb2e98ee 100644 --- a/chat.go +++ b/chat.go @@ -248,7 +248,8 @@ func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) erro return nil } -// ChatCompletionRequestExtensions contains third-party OpenAI API extensions (e.g., vendor-specific implementations like vLLM). +// ChatCompletionRequestExtensions contains third-party OpenAI API extensions +// (e.g., vendor-specific implementations like vLLM). type ChatCompletionRequestExtensions struct { // GuidedChoice is a vLLM-specific extension that restricts the model's output // to one of the predefined string choices provided in this field. This feature @@ -264,7 +265,7 @@ type ChatCompletionRequest struct { Messages []ChatCompletionMessage `json:"messages"` // MaxTokens The maximum number of tokens that can be generated in the chat completion. 
// This value can be used to control costs for text generated via API. - // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. + // Deprecated: use MaxCompletionTokens. Not compatible with o1-series models. // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens MaxTokens int `json:"max_tokens,omitempty"` // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, From f7d6ece81065bde1e93576c89f7506d1b35cb205 Mon Sep 17 00:00:00 2001 From: Behzad Soltanpour Date: Mon, 11 Aug 2025 12:45:50 +0330 Subject: [PATCH 126/129] add GPT-5 model constants and reasoning validation (#1062) --- chat_test.go | 120 +++++++++++++++++++++++++++++++++++++++++ completion.go | 8 +++ completion_test.go | 29 ++++++++++ reasoning_validator.go | 9 ++-- 4 files changed, 162 insertions(+), 4 deletions(-) diff --git a/chat_test.go b/chat_test.go index 172ce0740..236cff736 100644 --- a/chat_test.go +++ b/chat_test.go @@ -331,6 +331,126 @@ func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { } } +func TestGPT5ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.GPT5, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: 
openai.GPT5Nano, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5ChatLatest, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + func TestChatRequestOmitEmpty(t *testing.T) { data, err := 
json.Marshal(openai.ChatCompletionRequest{ // We set model b/c it's required, so omitempty doesn't make sense diff --git a/completion.go b/completion.go index 02ce7b016..27d69f587 100644 --- a/completion.go +++ b/completion.go @@ -49,6 +49,10 @@ const ( GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14" GPT4Dot5Preview = "gpt-4.5-preview" GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" + GPT5 = "gpt-5" + GPT5Mini = "gpt-5-mini" + GPT5Nano = "gpt-5-nano" + GPT5ChatLatest = "gpt-5-chat-latest" GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" @@ -142,6 +146,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4Dot1Mini20250414: true, GPT4Dot1Nano: true, GPT4Dot1Nano20250414: true, + GPT5: true, + GPT5Mini: true, + GPT5Nano: true, + GPT5ChatLatest: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, diff --git a/completion_test.go b/completion_test.go index f0ead0d63..abfc3007e 100644 --- a/completion_test.go +++ b/completion_test.go @@ -300,3 +300,32 @@ func TestCompletionWithGPT4oModels(t *testing.T) { }) } } + +// TestCompletionWithGPT5Models Tests that GPT5 models are not supported for completion endpoint. 
+func TestCompletionWithGPT5Models(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT5, + openai.GPT5Mini, + openai.GPT5Nano, + openai.GPT5ChatLatest, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} diff --git a/reasoning_validator.go b/reasoning_validator.go index 2910b1395..1d26ca047 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -28,21 +28,22 @@ var ( ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll ) -// ReasoningValidator handles validation for o-series model requests. +// ReasoningValidator handles validation for reasoning model requests. type ReasoningValidator struct{} -// NewReasoningValidator creates a new validator for o-series models. +// NewReasoningValidator creates a new validator for reasoning models. func NewReasoningValidator() *ReasoningValidator { return &ReasoningValidator{} } -// Validate performs all validation checks for o-series models. +// Validate performs all validation checks for reasoning models. 
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { o1Series := strings.HasPrefix(request.Model, "o1") o3Series := strings.HasPrefix(request.Model, "o3") o4Series := strings.HasPrefix(request.Model, "o4") + gpt5Series := strings.HasPrefix(request.Model, "gpt-5") - if !o1Series && !o3Series && !o4Series { + if !o1Series && !o3Series && !o4Series && !gpt5Series { return nil } From f71d1a622abab7fa75159b02318cbcf6e2dcb0a0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 12 Aug 2025 18:03:57 +0800 Subject: [PATCH 127/129] feat: add safety_identifier params (#1066) --- chat.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chat.go b/chat.go index 0bb2e98ee..9719f6b92 100644 --- a/chat.go +++ b/chat.go @@ -320,6 +320,11 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. + // The IDs should be a string that uniquely identifies each user. + // We recommend hashing their username or email address, in order to avoid sending us any identifying information. 
+ // https://platform.openai.com/docs/api-reference/chat/create#chat_create-safety_identifier + SafetyIdentifier string `json:"safety_identifier,omitempty"` // Embedded struct for non-OpenAI extensions ChatCompletionRequestExtensions } From 8e5611cc5efdc2533b80e5667b69741c3fad875c Mon Sep 17 00:00:00 2001 From: Amady Azdaev Date: Fri, 29 Aug 2025 20:29:03 +0300 Subject: [PATCH 128/129] Add Verbosity parameter to Chat Completion Request (#1064) * add verbosity param to ChatCompletionRequest * edit comment about verbosity --- chat.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chat.go b/chat.go index 9719f6b92..0aa018715 100644 --- a/chat.go +++ b/chat.go @@ -320,6 +320,12 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // Verbosity determines how many output tokens are generated. Lowering the number of + // tokens reduces overall latency. It can be set to "low", "medium", or "high". + // Note: This field is only confirmed to work with gpt-5, gpt-5-mini and gpt-5-nano. + // Also, it is not in the API reference of chat completion at the time of writing, + // though it is supported by the API. + Verbosity string `json:"verbosity,omitempty"` // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. // The IDs should be a string that uniquely identifies each user. // We recommend hashing their username or email address, in order to avoid sending us any identifying information. 
From 5d7a276f4c0e48f97354fe555cb793b52e350e62 Mon Sep 17 00:00:00 2001 From: Christopher Petito <47751006+krissetto@users.noreply.github.com> Date: Tue, 21 Oct 2025 21:27:33 +0200 Subject: [PATCH 129/129] Stop stripping dots in azure model mapper for models that aren't 3.5 based (#1079) fixes #978 Signed-off-by: Christopher Petito --- config.go | 7 ++++++- config_test.go | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 4788ba62a..4b8cfb6fb 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package openai import ( "net/http" "regexp" + "strings" ) const ( @@ -70,7 +71,11 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { APIType: APITypeAzure, APIVersion: "2023-05-15", AzureModelMapperFunc: func(model string) string { - return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + // only 3.5 models have the "." stripped in their names + if strings.Contains(model, "3.5") { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + } + return strings.ReplaceAll(model, ":", "") }, HTTPClient: &http.Client{}, diff --git a/config_test.go b/config_test.go index 960230804..f44b80825 100644 --- a/config_test.go +++ b/config_test.go @@ -20,6 +20,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) { Model: "gpt-3.5-turbo-0301", Expect: "gpt-35-turbo-0301", }, + { + Model: "gpt-4.1", + Expect: "gpt-4.1", + }, { Model: "text-embedding-ada-002", Expect: "text-embedding-ada-002",